Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/webwright/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ class BaseModelConfig(PydanticBaseModel):
format_error_template: OptStr = DEFAULT_FORMAT_ERROR_TEMPLATE
attach_observation_screenshot: bool = True
action_field: str = "bash_command"
throttle_rate: float = 0.0
throttle_capacity: int = 1

@field_validator("action_field")
@classmethod
Expand Down Expand Up @@ -429,6 +431,13 @@ def _format_repair_message(self, *, raw_text: str, error: str) -> dict[str, Any]
)

async def _post_with_retries(self, payload: dict[str, Any]) -> dict[str, Any]:
if self.config.throttle_rate > 0:
from webwright.utils.throttle import get_global_throttle

bucket = await get_global_throttle(
self.config.throttle_rate, self.config.throttle_capacity
)
await bucket.acquire()
headers = self._request_headers()
url = self._post_url()
for attempt in range(max(self._MAX_RATE_LIMIT_RETRIES, self._MAX_TRANSIENT_RETRIES) + 1):
Expand Down
76 changes: 76 additions & 0 deletions src/webwright/utils/throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Process-global async token bucket for throttling model API calls."""

from __future__ import annotations

import asyncio
import time


class AsyncTokenBucket:
"""Classic token bucket rate limiter for ``asyncio`` callers.

Parameters
----------
rate:
Tokens added per second (refill rate).
capacity:
Maximum burst size (bucket depth). Defaults to ``1``.
"""

def __init__(self, rate: float, capacity: int = 1) -> None:
if rate <= 0:
raise ValueError("rate must be positive")
if capacity < 1:
raise ValueError("capacity must be >= 1")
self.rate = rate
self.capacity = capacity
self._tokens: float = float(capacity)
self._last_refill: float = time.monotonic()
self._lock = asyncio.Lock()

def _refill(self) -> None:
now = time.monotonic()
elapsed = now - self._last_refill
self._tokens = min(self.capacity, self._tokens + elapsed * self.rate)
self._last_refill = now

async def acquire(self) -> None:
"""Wait until a token is available, then consume it."""
while True:
async with self._lock:
self._refill()
if self._tokens >= 1.0:
self._tokens -= 1.0
return
# How long until at least one token is available?
wait = (1.0 - self._tokens) / self.rate
await asyncio.sleep(wait)


# ---- process-global registry -------------------------------------------------

_throttle_registry: dict[tuple[float, int], AsyncTokenBucket] = {}
_registry_lock = asyncio.Lock()


async def get_global_throttle(rate: float, capacity: int = 1) -> AsyncTokenBucket:
"""Return (and lazily create) a throttle bucket for the given config.

Each unique ``(rate, capacity)`` pair receives its own bucket so that
different model configurations coexisting in the same process are
throttled independently.
"""
key = (rate, capacity)
bucket = _throttle_registry.get(key)
if bucket is not None:
return bucket
async with _registry_lock:
# Double-check after acquiring the lock.
if key not in _throttle_registry:
_throttle_registry[key] = AsyncTokenBucket(rate, capacity)
return _throttle_registry[key]


def reset_global_throttle() -> None:
"""Clear the registry — mainly for tests."""
_throttle_registry.clear()
88 changes: 88 additions & 0 deletions tests/unit/test_throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Unit tests for webwright.utils.throttle."""

from __future__ import annotations

import asyncio
import time

import pytest

from webwright.utils.throttle import AsyncTokenBucket, get_global_throttle, reset_global_throttle


# ---- AsyncTokenBucket --------------------------------------------------------


def test_constructor_rejects_non_positive_rate() -> None:
with pytest.raises(ValueError, match="rate must be positive"):
AsyncTokenBucket(rate=0, capacity=1)
with pytest.raises(ValueError, match="rate must be positive"):
AsyncTokenBucket(rate=-1, capacity=1)


def test_constructor_rejects_zero_capacity() -> None:
with pytest.raises(ValueError, match="capacity must be >= 1"):
AsyncTokenBucket(rate=1.0, capacity=0)


@pytest.mark.asyncio
async def test_acquire_within_capacity() -> None:
bucket = AsyncTokenBucket(rate=100.0, capacity=3)
# Should be able to grab 3 tokens immediately.
for _ in range(3):
await bucket.acquire()


@pytest.mark.asyncio
async def test_acquire_blocks_when_empty() -> None:
bucket = AsyncTokenBucket(rate=20.0, capacity=1)
await bucket.acquire() # drain the single token

start = time.monotonic()
await bucket.acquire() # must wait ~0.05s for refill
elapsed = time.monotonic() - start

assert elapsed >= 0.03, f"Expected to block ~50ms, but took only {elapsed:.3f}s"


@pytest.mark.asyncio
async def test_burst_capacity_replenishes() -> None:
bucket = AsyncTokenBucket(rate=1000.0, capacity=5)
# Drain all 5 tokens.
for _ in range(5):
await bucket.acquire()
# After a short sleep tokens should have been added back.
await asyncio.sleep(0.01)
await bucket.acquire()


# ---- Registry ----------------------------------------------------------------


@pytest.fixture(autouse=True)
def _reset_registry() -> None:
reset_global_throttle()
yield # type: ignore[misc]
reset_global_throttle()


@pytest.mark.asyncio
async def test_same_config_returns_same_instance() -> None:
a = await get_global_throttle(10.0, 2)
b = await get_global_throttle(10.0, 2)
assert a is b


@pytest.mark.asyncio
async def test_different_config_returns_different_instance() -> None:
a = await get_global_throttle(10.0, 2)
b = await get_global_throttle(5.0, 1)
assert a is not b


@pytest.mark.asyncio
async def test_reset_clears_registry() -> None:
a = await get_global_throttle(10.0, 2)
reset_global_throttle()
b = await get_global_throttle(10.0, 2)
assert a is not b