diff --git a/src/tether/runtime/server.py b/src/tether/runtime/server.py index fabe005..119b402 100644 --- a/src/tether/runtime/server.py +++ b/src/tether/runtime/server.py @@ -19,6 +19,7 @@ from __future__ import annotations import base64 +import contextvars import hashlib import inspect import io @@ -43,6 +44,7 @@ compute_config_hash, compute_model_hash, ) +from .auth import generate_request_id from .tracing import get_tracer, setup_tracing, shutdown_tracing # Optional Prometheus metrics — gated on the [serve] extra (prometheus-client). @@ -78,6 +80,12 @@ def track_in_flight(*args, **kwargs): logger = logging.getLogger(__name__) _tracer = get_tracer(__name__) +REQUEST_ID_HEADER = "X-Tether-Request-ID" +REQUEST_ID_ALIASES = (REQUEST_ID_HEADER, "X-Request-ID") +_request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar( + "tether_request_id", + default="", +) try: from tether import __version__ as _TETHER_VERSION @@ -97,6 +105,16 @@ def _coerce_optional_float(value: Any) -> float | None: return out +def _resolve_http_request_id(request: Any) -> str: + for header in REQUEST_ID_ALIASES: + value = request.headers.get(header) + if value: + value = value.strip() + if value: + return value[:128] + return generate_request_id() + + def _call_accepts_keyword(fn: Any, keyword: str) -> bool: try: params = inspect.signature(fn).parameters @@ -3014,6 +3032,18 @@ async def _heartbeat_loop(): lifespan=lifespan, ) + @app.middleware("http") + async def _request_id_middleware(request, call_next): + req_id = _resolve_http_request_id(request) + request.state.request_id = req_id + token = _request_id_var.set(req_id) + try: + response = await call_next(request) + finally: + _request_id_var.reset(token) + response.headers[REQUEST_ID_HEADER] = req_id + return response + # Bearer auth dependency (Phase 1 auth-bearer feature). # If api_key is set at app-creation time, every protected route requires # the caller to pass `Authorization: Bearer ` (preferred) OR the @@ -3108,6 +3138,9 @@ async def act(request: PredictRequest, _auth: None = Depends(_require_api_key)): # Non-standard attrs under gen_ai.action.* — proposed for upstream # OTel GenAI working group contribution (Phase 2 per spec). span.set_attribute("gen_ai.action.embodiment", _emb_label) + _req_id = _request_id_var.get() + if _req_id: + span.set_attribute("tether.request_id", _req_id) # chunk_size + denoise_steps are set AFTER predict returns (we don't # know them until the result is in hand). See ~line 1590 below. span.set_attribute( diff --git a/tests/test_request_id_header.py b/tests/test_request_id_header.py new file mode 100644 index 0000000..6d84b8f --- /dev/null +++ b/tests/test_request_id_header.py @@ -0,0 +1,100 @@ +"""Request-id middleware coverage for the real Tether FastAPI app.""" +from __future__ import annotations + +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + + +class _StubServer: + def __init__(self, export_dir, *args, **kwargs): + self.export_dir = Path(export_dir) + self._ready = True + self.health_state = "ready" + self._inference_mode = "stub" + self._vlm_loaded = False + self.consecutive_crash_count = 0 + self.max_consecutive_crashes = 5 + self.robot_id = "" + + @property + def ready(self): + return self._ready + + async def load(self): + self._ready = True + self.health_state = "ready" + + +@pytest.fixture +def app(tmp_path, monkeypatch): + from tether.runtime import server as runtime_server + + monkeypatch.setattr(runtime_server, "TetherServer", _StubServer) + export_dir = tmp_path / "export" + export_dir.mkdir() + app = runtime_server.create_app(str(export_dir), device="cpu") + + @app.get("/_test/request-id") + async def request_id_probe(): + return {"request_id": runtime_server._request_id_var.get()} + + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +def test_health_response_has_tether_request_id_header(client): + response = client.get("/health") + + assert response.headers["X-Tether-Request-ID"].startswith("req-") + assert "X-Reflex-Request-ID" not in response.headers + + +def test_each_request_gets_a_unique_generated_id(client): + ids = {client.get("/health").headers["X-Tether-Request-ID"] for _ in range(5)} + + assert len(ids) == 5 + + +def test_tether_request_id_header_is_echoed(client): + response = client.get( + "/health", + headers={"X-Tether-Request-ID": " req-user-supplied "}, + ) + + assert response.headers["X-Tether-Request-ID"] == "req-user-supplied" + + +def test_generic_request_id_header_is_accepted_when_tether_header_missing(client): + response = client.get("/health", headers={"X-Request-ID": "edge-proxy-123"}) + + assert response.headers["X-Tether-Request-ID"] == "edge-proxy-123" + + +def test_tether_header_wins_over_generic_request_id(client): + response = client.get( + "/health", + headers={ + "X-Tether-Request-ID": "req-tether", + "X-Request-ID": "proxy-request", + }, + ) + + assert response.headers["X-Tether-Request-ID"] == "req-tether" + + +def test_request_id_is_available_inside_route_context(app): + client = TestClient(app) + + response = client.get( + "/_test/request-id", + headers={"X-Tether-Request-ID": "req-route-context"}, + ) + + assert response.json()["request_id"] == "req-route-context" + assert response.headers["X-Tether-Request-ID"] == "req-route-context"