diff --git a/src/webwright/models/base.py b/src/webwright/models/base.py
index 7728bb8..87357fc 100644
--- a/src/webwright/models/base.py
+++ b/src/webwright/models/base.py
@@ -212,6 +212,8 @@ class BaseModelConfig(PydanticBaseModel):
     format_error_template: OptStr = DEFAULT_FORMAT_ERROR_TEMPLATE
     attach_observation_screenshot: bool = True
     action_field: str = "bash_command"
+    throttle_rate: float = 0.0
+    throttle_capacity: int = 1
 
     @field_validator("action_field")
     @classmethod
@@ -429,6 +431,13 @@ def _format_repair_message(self, *, raw_text: str, error: str) -> dict[str, Any]
         )
 
     async def _post_with_retries(self, payload: dict[str, Any]) -> dict[str, Any]:
+        if self.config.throttle_rate > 0:
+            from webwright.utils.throttle import get_global_throttle
+
+            bucket = await get_global_throttle(
+                self.config.throttle_rate, self.config.throttle_capacity
+            )
+            await bucket.acquire()
         headers = self._request_headers()
         url = self._post_url()
         for attempt in range(max(self._MAX_RATE_LIMIT_RETRIES, self._MAX_TRANSIENT_RETRIES) + 1):
diff --git a/src/webwright/utils/throttle.py b/src/webwright/utils/throttle.py
new file mode 100644
index 0000000..39a07a5
--- /dev/null
+++ b/src/webwright/utils/throttle.py
@@ -0,0 +1,76 @@
+"""Process-global async token bucket for throttling model API calls."""
+
+from __future__ import annotations
+
+import asyncio
+import time
+
+
+class AsyncTokenBucket:
+    """Classic token bucket rate limiter for ``asyncio`` callers.
+
+    Parameters
+    ----------
+    rate:
+        Tokens added per second (refill rate).
+    capacity:
+        Maximum burst size (bucket depth). Defaults to ``1``.
+    """
+
+    def __init__(self, rate: float, capacity: int = 1) -> None:
+        if rate <= 0:
+            raise ValueError("rate must be positive")
+        if capacity < 1:
+            raise ValueError("capacity must be >= 1")
+        self.rate = rate
+        self.capacity = capacity
+        self._tokens: float = float(capacity)
+        self._last_refill: float = time.monotonic()
+        self._lock = asyncio.Lock()
+
+    def _refill(self) -> None:
+        now = time.monotonic()
+        elapsed = now - self._last_refill
+        self._tokens = min(self.capacity, self._tokens + elapsed * self.rate)
+        self._last_refill = now
+
+    async def acquire(self) -> None:
+        """Wait until a token is available, then consume it."""
+        while True:
+            async with self._lock:
+                self._refill()
+                if self._tokens >= 1.0:
+                    self._tokens -= 1.0
+                    return
+                # How long until at least one token is available?
+                wait = (1.0 - self._tokens) / self.rate
+            await asyncio.sleep(wait)
+
+
+# ---- process-global registry -------------------------------------------------
+
+_throttle_registry: dict[tuple[float, int], AsyncTokenBucket] = {}
+_registry_lock = asyncio.Lock()
+
+
+async def get_global_throttle(rate: float, capacity: int = 1) -> AsyncTokenBucket:
+    """Return (and lazily create) a throttle bucket for the given config.
+
+    Each unique ``(rate, capacity)`` pair receives its own bucket so that
+    different model configurations coexisting in the same process are
+    throttled independently.
+    """
+    key = (rate, capacity)
+    bucket = _throttle_registry.get(key)
+    if bucket is not None:
+        return bucket
+    async with _registry_lock:
+        # Double-check after acquiring the lock.
+        if key not in _throttle_registry:
+            _throttle_registry[key] = AsyncTokenBucket(rate, capacity)
+        return _throttle_registry[key]
+
+
+def reset_global_throttle() -> None:
+    """Clear the registry — mainly for tests."""
+    _throttle_registry.clear()
diff --git a/tests/unit/test_throttle.py b/tests/unit/test_throttle.py
new file mode 100644
index 0000000..8b6b6c0
--- /dev/null
+++ b/tests/unit/test_throttle.py
@@ -0,0 +1,88 @@
+"""Unit tests for webwright.utils.throttle."""
+
+from __future__ import annotations
+
+import asyncio
+import time
+
+import pytest
+
+from webwright.utils.throttle import AsyncTokenBucket, get_global_throttle, reset_global_throttle
+
+
+# ---- AsyncTokenBucket --------------------------------------------------------
+
+
+def test_constructor_rejects_non_positive_rate() -> None:
+    with pytest.raises(ValueError, match="rate must be positive"):
+        AsyncTokenBucket(rate=0, capacity=1)
+    with pytest.raises(ValueError, match="rate must be positive"):
+        AsyncTokenBucket(rate=-1, capacity=1)
+
+
+def test_constructor_rejects_zero_capacity() -> None:
+    with pytest.raises(ValueError, match="capacity must be >= 1"):
+        AsyncTokenBucket(rate=1.0, capacity=0)
+
+
+@pytest.mark.asyncio
+async def test_acquire_within_capacity() -> None:
+    bucket = AsyncTokenBucket(rate=100.0, capacity=3)
+    # Should be able to grab 3 tokens immediately.
+    for _ in range(3):
+        await bucket.acquire()
+
+
+@pytest.mark.asyncio
+async def test_acquire_blocks_when_empty() -> None:
+    bucket = AsyncTokenBucket(rate=20.0, capacity=1)
+    await bucket.acquire()  # drain the single token
+
+    start = time.monotonic()
+    await bucket.acquire()  # must wait ~0.05s for refill
+    elapsed = time.monotonic() - start
+
+    assert elapsed >= 0.03, f"Expected to block ~50ms, but took only {elapsed:.3f}s"
+
+
+@pytest.mark.asyncio
+async def test_burst_capacity_replenishes() -> None:
+    bucket = AsyncTokenBucket(rate=1000.0, capacity=5)
+    # Drain all 5 tokens.
+    for _ in range(5):
+        await bucket.acquire()
+    # After a short sleep tokens should have been added back.
+    await asyncio.sleep(0.01)
+    await bucket.acquire()
+
+
+# ---- Registry ----------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_registry() -> None:
+    reset_global_throttle()
+    yield  # type: ignore[misc]
+    reset_global_throttle()
+
+
+@pytest.mark.asyncio
+async def test_same_config_returns_same_instance() -> None:
+    a = await get_global_throttle(10.0, 2)
+    b = await get_global_throttle(10.0, 2)
+    assert a is b
+
+
+@pytest.mark.asyncio
+async def test_different_config_returns_different_instance() -> None:
+    a = await get_global_throttle(10.0, 2)
+    b = await get_global_throttle(5.0, 1)
+    assert a is not b
+
+
+@pytest.mark.asyncio
+async def test_reset_clears_registry() -> None:
+    a = await get_global_throttle(10.0, 2)
+    reset_global_throttle()
+    b = await get_global_throttle(10.0, 2)
+    assert a is not b