Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions api/src/scripts/populate_db_gbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from shared.common.license_utils import assign_license_by_url
from shared.database.database import generate_unique_id, configure_polymorphic_mappers
from shared.database_gen.sqlacodegen_models import Gbfsfeed, Location, Externalid
from shared.notifications.notification_event_service import emit_url_replaced, urls_differ

GBFS_PUBSUB_TOPIC_NAME = "validate-gbfs-feed"

Expand Down Expand Up @@ -108,9 +109,18 @@ def populate_db(self, session, fetch_url=True):
gbfs_feed.operator = row["Name"]
gbfs_feed.provider = row["Name"]
gbfs_feed.operator_url = row["URL"]
gbfs_feed.producer_url = row["Auto-Discovery URL"]
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
old_producer_url = gbfs_feed.producer_url
new_producer_url = row["Auto-Discovery URL"]
gbfs_feed.producer_url = new_producer_url
gbfs_feed.auto_discovery_url = new_producer_url
gbfs_feed.updated_at = datetime.now(pytz.utc)
if not is_new_feed and old_producer_url and urls_differ(old_producer_url, new_producer_url):
emit_url_replaced(
feed_stable_id=stable_id,
old_url=old_producer_url,
new_url=new_producer_url,
source="populate_db_gbfs",
)

if not gbfs_feed.locations: # If locations are empty, create a new location (no overwrite)
country_code = self.get_safe_value(row, "Country Code", "")
Expand Down
21 changes: 21 additions & 0 deletions api/src/scripts/populate_db_gtfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
Location,
Redirectingid,
)
from shared.notifications.notification_event_service import (
emit_feed_redirected,
emit_url_replaced,
urls_differ,
)
from utils.data_utils import set_up_defaults

if TYPE_CHECKING:
Expand Down Expand Up @@ -200,6 +205,14 @@ def process_redirects(self, session: "Session"):
)
# Flush to avoid FK violation
session.flush()
emit_feed_redirected(
source_stable_id=stable_id,
target_stable_id=target_stable_id,
old_url=getattr(feed, "producer_url", None),
new_url=getattr(target_feed, "producer_url", None),
source="populate_db_gtfs",
extra_data={"redirect_comment": comment} if comment else None,
)

def populate_db(self, session: "Session", fetch_url: bool = True):
"""
Expand Down Expand Up @@ -252,7 +265,15 @@ def populate_db(self, session: "Session", fetch_url: bool = True):
feed.note = self.get_safe_value(row, "note", "")
producer_url = self.get_safe_value(row, "urls.direct_download", "")
if "transitfeeds" not in producer_url: # Avoid setting transitfeeds as producer_url
old_producer_url = feed.producer_url
feed.producer_url = producer_url
if not is_new_feed and old_producer_url and urls_differ(old_producer_url, producer_url):
emit_url_replaced(
feed_stable_id=stable_id,
old_url=old_producer_url,
new_url=producer_url,
source="populate_db_gtfs",
)
feed.authentication_type = str(int(float(self.get_safe_value(row, "urls.authentication_type", "0"))))
feed.authentication_info_url = self.get_safe_value(row, "urls.authentication_info", "")
feed.api_key_parameter_name = self.get_safe_value(row, "urls.api_key_parameter_name", "")
Expand Down
141 changes: 141 additions & 0 deletions api/src/shared/common/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#
# MobilityData 2026
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Generic, reusable client-side rate limiting.

Provides a thread-safe token-bucket :class:`RateLimiter` and a small named
registry (:func:`get_rate_limiter`) so any outbound API caller can share a
single process-wide bucket keyed by a logical name (e.g. ``"brevo"``,
``"tdg"``). The algorithm is API-agnostic; callers only choose a name and rate.

Example::

    limiter = get_rate_limiter("tdg", rate=10)  # 10 requests/second
    limiter.acquire()  # blocks if necessary
    response = requests.get(url)
"""

from __future__ import annotations

import threading
import time
from typing import Callable, Dict, Optional


class RateLimiter:
    """Token-bucket rate limiter that can be shared between threads.

    The bucket starts full, holding ``capacity`` tokens, and is topped up
    continuously at ``rate`` tokens per second, never exceeding ``capacity``
    (the largest burst). :meth:`acquire` removes tokens and, when the bucket
    is short, sleeps for the time needed to refill the shortfall.

    The time source (``clock``) and the blocking primitive (``sleep``) are
    constructor arguments so tests can drive the limiter with fakes instead
    of waiting on real time.

    Args:
        rate: Refill speed in tokens per second; must be greater than 0.
        capacity: Maximum number of stored tokens; defaults to ``rate``.
            Must be greater than 0 when given.
        clock: Zero-argument callable returning the current time in seconds.
        sleep: Callable that blocks for the given number of seconds.

    Raises:
        ValueError: If ``rate`` or an explicit ``capacity`` is not positive.
    """

    def __init__(
        self,
        rate: float,
        capacity: Optional[float] = None,
        clock: Callable[[], float] = time.monotonic,
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        if rate <= 0:
            raise ValueError("rate must be greater than 0")
        if capacity is not None and capacity <= 0:
            raise ValueError("capacity must be greater than 0")

        burst = rate if capacity is None else capacity

        self._rate = float(rate)
        self._capacity = float(burst)
        self._clock = clock
        self._sleep = sleep
        # The bucket starts full so an initial burst of ``capacity`` is allowed.
        self._tokens = self._capacity
        self._timestamp = clock()
        self._lock = threading.Lock()

    @property
    def rate(self) -> float:
        """Refill speed in tokens per second."""
        return self._rate

    @property
    def capacity(self) -> float:
        """Maximum number of tokens the bucket can hold."""
        return self._capacity

    def _refill(self) -> None:
        """Credit tokens for the time elapsed since the previous refill."""
        current = self._clock()
        delta = current - self._timestamp
        # The reference time always moves to ``current``; tokens are only
        # credited when time has actually advanced.
        self._timestamp = current
        if delta > 0:
            topped_up = self._tokens + delta * self._rate
            self._tokens = min(self._capacity, topped_up)

    def acquire(self, n: float = 1) -> float:
        """Take ``n`` tokens from the bucket, sleeping first if it is short.

        The lock is held for the whole call, including the sleep, so
        concurrent callers are serialized against the one shared bucket.

        Args:
            n: Number of tokens to consume. A value of 0 or less is a no-op.

        Returns:
            The number of seconds this call asked ``sleep`` to wait; ``0.0``
            when enough tokens were already available.
        """
        if n <= 0:
            return 0.0

        with self._lock:
            self._refill()
            if self._tokens < n:
                # Sleep exactly as long as the refill rate needs to cover
                # the shortfall, then credit the elapsed time.
                shortfall = n - self._tokens
                delay = shortfall / self._rate
                self._sleep(delay)
                self._refill()
            else:
                delay = 0.0
            self._tokens -= n

        return delay

    def __enter__(self) -> "RateLimiter":
        """Acquire a single token on entry to a ``with`` block."""
        self.acquire()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Do nothing on exit; exceptions from the block propagate."""
        return None


# Module-level cache of limiters, keyed by the logical name callers pass in.
_registry: Dict[str, RateLimiter] = {}
# Held while entries are added to or removed from ``_registry``.
_registry_lock = threading.Lock()


def get_rate_limiter(
    name: str,
    rate: float,
    capacity: Optional[float] = None,
) -> RateLimiter:
    """Look up, or lazily create, the shared :class:`RateLimiter` for ``name``.

    Only the first call for a given ``name`` builds a limiter, using that
    call's ``rate`` and ``capacity``. Later calls get the same instance back
    and their ``rate``/``capacity`` arguments are ignored. Tests can call
    :func:`reset_rate_limiter` to drop the cached instance and reconfigure.

    Args:
        name: Logical key identifying the shared bucket.
        rate: Tokens per second, used only when the limiter is created.
        capacity: Maximum burst, used only when the limiter is created.

    Returns:
        The limiter registered under ``name``.
    """
    cached = _registry.get(name)
    if cached is not None:
        return cached

    with _registry_lock:
        # Look again under the lock: another thread may have registered the
        # limiter between the unlocked lookup above and acquiring the lock.
        cached = _registry.get(name)
        if cached is None:
            cached = RateLimiter(rate, capacity=capacity)
            _registry[name] = cached
    return cached


def reset_rate_limiter(name: Optional[str] = None) -> None:
    """Forget cached limiters so the next lookup creates a fresh one.

    Args:
        name: Key of the limiter to drop. When ``None``, every cached limiter
            is removed. A name that is not cached is silently ignored.
    """
    with _registry_lock:
        if name is not None:
            _registry.pop(name, None)
            return
        _registry.clear()
6 changes: 6 additions & 0 deletions api/src/shared/notifications/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Shared notification utilities.

Modules in this package:
    notification_event_service — emit_feed_redirected / emit_url_replaced
    brevo_notification_sender — send_single / send_digest
"""
Loading
Loading