Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/tether/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import base64
import contextvars
import hashlib
import inspect
import io
Expand All @@ -43,6 +44,7 @@
compute_config_hash,
compute_model_hash,
)
from .auth import generate_request_id
from .tracing import get_tracer, setup_tracing, shutdown_tracing

# Optional Prometheus metrics — gated on the [serve] extra (prometheus-client).
Expand Down Expand Up @@ -78,6 +80,12 @@ def track_in_flight(*args, **kwargs):

# Module-level logger and tracer, both keyed on this module's dotted name.
logger = logging.getLogger(__name__)
_tracer = get_tracer(__name__)

# Header name under which the request ID is written onto HTTP responses.
REQUEST_ID_HEADER = "X-Tether-Request-ID"

# Inbound header names consulted, in priority order, for a caller-supplied ID.
REQUEST_ID_ALIASES = (REQUEST_ID_HEADER, "X-Request-ID")

# Request ID of the HTTP request currently being handled. The empty-string
# default means "no request ID has been set in this context".
_request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "tether_request_id", default=""
)

try:
from tether import __version__ as _TETHER_VERSION
Expand All @@ -97,6 +105,16 @@ def _coerce_optional_float(value: Any) -> float | None:
return out


def _resolve_http_request_id(request: Any) -> str:
    """Return the request ID to use for an inbound HTTP request.

    The headers named in ``REQUEST_ID_ALIASES`` are consulted in order and the
    first one carrying a non-blank value wins. When none of them is usable, a
    fresh ID is obtained from ``generate_request_id()``.

    Args:
        request: Object exposing a ``headers`` mapping that supports
            ``.get(name)``.

    Returns:
        The caller-supplied ID with surrounding whitespace removed and capped
        at 128 characters, or a newly generated ID.
    """
    for header in REQUEST_ID_ALIASES:
        raw = request.headers.get(header)
        if not raw:
            continue
        # Trim again after truncating: the 128-character cut can land right
        # after interior whitespace, and this value is echoed back in a
        # response header, where trailing whitespace is not a valid field
        # value.
        candidate = raw.strip()[:128].rstrip()
        if candidate:
            return candidate
    return generate_request_id()


def _call_accepts_keyword(fn: Any, keyword: str) -> bool:
try:
params = inspect.signature(fn).parameters
Expand Down Expand Up @@ -3014,6 +3032,18 @@ async def _heartbeat_loop():
lifespan=lifespan,
)

@app.middleware("http")
async def _request_id_middleware(request, call_next):
    """Attach a request ID to the request, the context, and the response.

    The resolved ID is stored on ``request.state.request_id`` and published
    through ``_request_id_var`` while the downstream handler runs. The context
    variable is restored even when the handler raises. On success the ID is
    written to the ``REQUEST_ID_HEADER`` response header.
    """
    request_id = _resolve_http_request_id(request)
    request.state.request_id = request_id
    reset_token = _request_id_var.set(request_id)
    try:
        downstream_response = await call_next(request)
    finally:
        # Always restore the previous value so the ID does not outlive
        # this request's handling.
        _request_id_var.reset(reset_token)
    downstream_response.headers[REQUEST_ID_HEADER] = request_id
    return downstream_response

# Bearer auth dependency (Phase 1 auth-bearer feature).
# If api_key is set at app-creation time, every protected route requires
# the caller to pass `Authorization: Bearer <token>` (preferred) OR the
Expand Down Expand Up @@ -3108,6 +3138,9 @@ async def act(request: PredictRequest, _auth: None = Depends(_require_api_key)):
# Non-standard attrs under gen_ai.action.* — proposed for upstream
# OTel GenAI working group contribution (Phase 2 per spec).
span.set_attribute("gen_ai.action.embodiment", _emb_label)
_req_id = _request_id_var.get()
if _req_id:
span.set_attribute("tether.request_id", _req_id)
# chunk_size + denoise_steps are set AFTER predict returns (we don't
# know them until the result is in hand). See ~line 1590 below.
span.set_attribute(
Expand Down
100 changes: 100 additions & 0 deletions tests/test_request_id_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Request-id middleware coverage for the real Tether FastAPI app."""
from __future__ import annotations

from pathlib import Path

import pytest
from fastapi.testclient import TestClient


class _StubServer:
    """Minimal stand-in for the runtime server used by these tests.

    It accepts and ignores any extra constructor arguments, reports itself
    as ready from construction onwards, and does no work in ``load``
    beyond re-asserting the ready state.
    """

    def __init__(self, export_dir, *args, **kwargs):
        self.export_dir = Path(export_dir)
        self._mark_ready()
        self._inference_mode = "stub"
        self._vlm_loaded = False
        self.consecutive_crash_count = 0
        self.max_consecutive_crashes = 5
        self.robot_id = ""

    def _mark_ready(self):
        """Set the ready flag and the matching health state together."""
        self._ready = True
        self.health_state = "ready"

    @property
    def ready(self):
        """Whether the stub currently reports itself as ready."""
        return self._ready

    async def load(self):
        """Pretend to load a model by switching to the ready state."""
        self._mark_ready()


@pytest.fixture
def app(tmp_path, monkeypatch):
    """Build the runtime app around ``_StubServer`` and add a probe route.

    The probe route ``/_test/request-id`` returns the value currently held
    in the server module's ``_request_id_var`` so tests can observe it from
    inside a route handler.
    """
    from tether.runtime import server as runtime_server

    # Swap in the stub before the app is created.
    monkeypatch.setattr(runtime_server, "TetherServer", _StubServer)

    export_root = tmp_path / "export"
    export_root.mkdir()
    application = runtime_server.create_app(str(export_root), device="cpu")

    @application.get("/_test/request-id")
    async def request_id_probe():
        current_id = runtime_server._request_id_var.get()
        return {"request_id": current_id}

    return application


@pytest.fixture
def client(app):
    """Provide a ``TestClient`` bound to the app built by the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client


def test_health_response_has_tether_request_id_header(client):
    """A plain ``/health`` call carries a generated ``req-``-prefixed ID."""
    response_headers = client.get("/health").headers

    generated_id = response_headers["X-Tether-Request-ID"]
    assert generated_id.startswith("req-")
    # The response must not expose the ID under the "Reflex" header name.
    assert "X-Reflex-Request-ID" not in response_headers


def test_each_request_gets_a_unique_generated_id(client):
    """Five header-less requests yield five distinct generated IDs."""
    seen_ids = set()
    for _ in range(5):
        health_response = client.get("/health")
        seen_ids.add(health_response.headers["X-Tether-Request-ID"])

    assert len(seen_ids) == 5


def test_tether_request_id_header_is_echoed(client):
    """A caller-supplied Tether ID is echoed back without its padding."""
    request_headers = {"X-Tether-Request-ID": " req-user-supplied "}

    response = client.get("/health", headers=request_headers)

    echoed_id = response.headers["X-Tether-Request-ID"]
    assert echoed_id == "req-user-supplied"


def test_generic_request_id_header_is_accepted_when_tether_header_missing(client):
    """``X-Request-ID`` alone is accepted and echoed under the Tether header."""
    request_headers = {"X-Request-ID": "edge-proxy-123"}

    response = client.get("/health", headers=request_headers)

    echoed_id = response.headers["X-Tether-Request-ID"]
    assert echoed_id == "edge-proxy-123"


def test_tether_header_wins_over_generic_request_id(client):
    """When both headers are sent, the Tether-specific one takes priority."""
    both_headers = {
        "X-Tether-Request-ID": "req-tether",
        "X-Request-ID": "proxy-request",
    }

    response = client.get("/health", headers=both_headers)

    echoed_id = response.headers["X-Tether-Request-ID"]
    assert echoed_id == "req-tether"


def test_request_id_is_available_inside_route_context(app):
    """The probe route sees the same ID that is echoed in the response."""
    supplied_headers = {"X-Tether-Request-ID": "req-route-context"}

    response = TestClient(app).get("/_test/request-id", headers=supplied_headers)

    payload = response.json()
    assert payload["request_id"] == "req-route-context"
    assert response.headers["X-Tether-Request-ID"] == "req-route-context"