From 21f5804a3cd2932bfbedf3af96b1de69f8a07ca0 Mon Sep 17 00:00:00 2001 From: yuyili Date: Tue, 16 Jun 2026 15:25:00 +0800 Subject: [PATCH] =?UTF-8?q?FEAT:=20=E5=9C=A8=E6=A8=A1=E5=9E=8B=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E5=A4=B1=E8=B4=A5=E6=97=B6=E8=83=BD=E5=A4=9F=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E9=87=8D=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/llmagent_with_model_retry/.env | 15 + examples/llmagent_with_model_retry/README.md | 156 ++++++++ .../agent/__init__.py | 0 .../llmagent_with_model_retry/agent/agent.py | 43 +++ .../llmagent_with_model_retry/agent/config.py | 57 +++ .../agent/prompts.py | 13 + .../llmagent_with_model_retry/agent/tools.py | 40 ++ .../llmagent_with_model_retry/run_agent.py | 87 +++++ tests/configs/test_model_retry_config.py | 96 +++++ tests/models/test_anthropic_model.py | 28 ++ tests/models/test_anthropic_model_ext.py | 9 +- tests/models/test_litellm_model.py | 42 ++- tests/models/test_openai_model.py | 10 +- tests/models/test_retry.py | 346 ++++++++++++++++++ trpc_agent_sdk/configs/__init__.py | 12 + trpc_agent_sdk/configs/_model_retry_config.py | 124 +++++++ trpc_agent_sdk/models/_anthropic_model.py | 40 +- trpc_agent_sdk/models/_litellm_model.py | 50 +-- trpc_agent_sdk/models/_llm_model.py | 13 +- trpc_agent_sdk/models/_openai_model.py | 28 +- trpc_agent_sdk/models/_retry.py | 271 ++++++++++++++ 21 files changed, 1367 insertions(+), 113 deletions(-) create mode 100644 examples/llmagent_with_model_retry/.env create mode 100644 examples/llmagent_with_model_retry/README.md create mode 100644 examples/llmagent_with_model_retry/agent/__init__.py create mode 100644 examples/llmagent_with_model_retry/agent/agent.py create mode 100644 examples/llmagent_with_model_retry/agent/config.py create mode 100644 examples/llmagent_with_model_retry/agent/prompts.py create mode 100644 examples/llmagent_with_model_retry/agent/tools.py create mode 100644 examples/llmagent_with_model_retry/run_agent.py create mode 100644 tests/configs/test_model_retry_config.py create mode 100644 tests/models/test_retry.py create mode 100644 trpc_agent_sdk/configs/_model_retry_config.py create mode 100644 trpc_agent_sdk/models/_retry.py diff --git a/examples/llmagent_with_model_retry/.env b/examples/llmagent_with_model_retry/.env new file mode 100644 index 00000000..4f79bfc3 --- /dev/null +++ b/examples/llmagent_with_model_retry/.env @@ -0,0 +1,15 @@ +# Copy this file or edit it in place before running the example. +# The example uses an OpenAI-compatible endpoint. + +TRPC_AGENT_API_KEY= +TRPC_AGENT_BASE_URL= +TRPC_AGENT_MODEL_NAME= + +# Optional model retry tuning. +# num_retries is the number of retry attempts in addition to the initial call. +TRPC_AGENT_MODEL_RETRY_NUM_RETRIES=2 +TRPC_AGENT_MODEL_RETRY_INITIAL_BACKOFF=1.0 +TRPC_AGENT_MODEL_RETRY_MAX_BACKOFF=8.0 +TRPC_AGENT_MODEL_RETRY_BACKOFF_MULTIPLIER=2.0 +TRPC_AGENT_MODEL_RETRY_JITTER=true +TRPC_AGENT_MODEL_RETRY_RESPECT_RETRY_AFTER=true diff --git a/examples/llmagent_with_model_retry/README.md b/examples/llmagent_with_model_retry/README.md new file mode 100644 index 00000000..fb17f354 --- /dev/null +++ b/examples/llmagent_with_model_retry/README.md @@ -0,0 +1,156 @@ +# LLM Agent 模型重试示例 + +本示例演示如何在模型构造时传入 `ModelRetryConfig`,让业务代码不用自己实现重试逻辑;当 LLM 请求遇到限流、超时、网络波动等瞬时错误时,SDK 会自动重试。 + +## 关键特性 + +- **按需开启重试**:只有显式传入 `ModelRetryConfig` 的模型才会启用 SDK 托管重试。 +- **一套配置,多种模型可用**:同一个 `ModelRetryConfig` 可以传给 OpenAI、Anthropic 等模型实现,统一控制重试行为。 +- **只重试临时问题**:限流、超时、服务端内部错误、连接异常会在次数限制内重试。 +- **避免重复输出内容**:流式输出已经产生内容后如果再失败,错误会直接透出,不会重试并重复输出内容。 +- **优先使用服务端等待时间**:默认会读取 `Retry-After` / `retry-after-ms`,有服务端等待时间时优先使用。 + +## Agent 层级结构说明 + +本例是单 Agent 示例,重试配置绑定在模型上: + +```text +weather_retry_agent (LlmAgent) +├── model: OpenAIModel(..., model_retry_config=ModelRetryConfig(...)) +├── tool: get_weather_report(city) +└── runner: 无自定义重试逻辑 +``` + +关键文件: + +- [examples/llmagent_with_model_retry/agent/agent.py](./agent/agent.py):创建模型并注入 `ModelRetryConfig` +- [examples/llmagent_with_model_retry/agent/config.py](./agent/config.py):读取模型连接与重试环境变量 +- [examples/llmagent_with_model_retry/agent/tools.py](./agent/tools.py):天气工具实现 +- [examples/llmagent_with_model_retry/agent/prompts.py](./agent/prompts.py):提示词 +- [examples/llmagent_with_model_retry/run_agent.py](./run_agent.py):运行入口,展示业务层无需手写重试 + +## 关键代码解释 + +### 1) 创建重试配置 + +`agent/config.py` 从环境变量构造: + +```python +from trpc_agent_sdk.configs import ExponentialBackoffConfig +from trpc_agent_sdk.configs import ModelRetryConfig + +ModelRetryConfig( + num_retries=2, + backoff=ExponentialBackoffConfig( + initial_backoff=1.0, + max_backoff=8.0, + multiplier=2.0, + jitter=True, + respect_retry_after=True, + ), +) +``` + +### 2) 注入模型 + +`agent/agent.py` 将配置传给模型构造器: + +```python +OpenAIModel( + model_name=model_name, + api_key=api_key, + base_url=base_url, + model_retry_config=retry_config, +) +``` + +### 3) Runner 不需要重试逻辑 + +`run_agent.py` 仍然只调用: + +```python +async for event in runner.run_async(...): + ... +``` + +如果模型调用在产出内容前遇到可重试错误,SDK 会按 `ModelRetryConfig` 自动重试。 + +## 会重试和不会重试的场景 + +### 会重试 + +- `429` 限流 +- `408` / `409` 超时或模型服务锁等待超时 +- `5xx` 服务端内部错误 +- 超时异常 +- 连接或传输异常 + +### 不会重试 + +- `400` 错误请求 +- `401` / `403` 认证或权限错误 +- 未识别错误类别 +- 重试次数已耗尽 +- 流式输出已经产生内容后才发生的错误 + +## 环境与运行 + +### 环境要求 + +- Python 3.10+ + +### 安装步骤 + +```bash +git clone https://github.com/trpc-group/trpc-agent-python.git +cd trpc-agent-python +python3 -m venv .venv +source .venv/bin/activate +pip3 install -e . +``` + +### 环境变量要求 + +在 [examples/llmagent_with_model_retry/.env](./.env) 中配置(或通过 `export` 设置): + +```bash +# 必填:模型连接配置 +TRPC_AGENT_API_KEY= +TRPC_AGENT_BASE_URL= +TRPC_AGENT_MODEL_NAME= + +# 可选:模型重试配置;不设置时使用示例默认值 +TRPC_AGENT_MODEL_RETRY_NUM_RETRIES=2 +TRPC_AGENT_MODEL_RETRY_INITIAL_BACKOFF=1.0 +TRPC_AGENT_MODEL_RETRY_MAX_BACKOFF=8.0 +TRPC_AGENT_MODEL_RETRY_BACKOFF_MULTIPLIER=2.0 +TRPC_AGENT_MODEL_RETRY_JITTER=true +TRPC_AGENT_MODEL_RETRY_RESPECT_RETRY_AFTER=true +``` + +### 运行命令 + +```bash +cd examples/llmagent_with_model_retry +python3 run_agent.py +``` + +## 运行结果示例 + +```text +Model retry enabled: {'num_retries': 2, 'rules': {'retryable_error_codes': ['408', '409', '429', '500', '502', '503', '504'], 'non_retryable_error_codes': ['400', '401', '403'], 'retryable_exception_class_name_parts': ['Timeout', 'Connection', 'Transport']}, 'backoff': {'type': 'exponential', 'initial_backoff': 1.0, 'max_backoff': 8.0, 'multiplier': 2.0, 'jitter': True, 'respect_retry_after': True}} +Session ID: fdb9e370... +SDK-managed retry is configured on the model; no retry loop is needed in this runner. +User: What's the current weather in Beijing? +Assistant: +Invoke Tool: get_weather_report({'city': 'Beijing'}) +Tool Result: {'temperature': '25°C', 'condition': 'Sunny', 'humidity': '60%'} + +# 当模型服务返回可重试错误时,SDK 会在模型层自动重试: +[WARNING] Model call failed (category=rate_limit); retrying in 0.46s (attempt 1/2). +``` + +## 适用场景建议 + +- 在模型服务限流,或其他瞬时错误时自动重试 +- 需要结合服务端 `Retry-After` 控制退避等待时间。 diff --git a/examples/llmagent_with_model_retry/agent/__init__.py b/examples/llmagent_with_model_retry/agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/llmagent_with_model_retry/agent/agent.py b/examples/llmagent_with_model_retry/agent/agent.py new file mode 100644 index 00000000..25b3e7d4 --- /dev/null +++ b/examples/llmagent_with_model_retry/agent/agent.py @@ -0,0 +1,43 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent module for the model retry example.""" + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import LLMModel +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import FunctionTool + +from .config import get_model_config +from .config import get_model_retry_config +from .prompts import INSTRUCTION +from .tools import get_weather_report + + +def _create_model() -> LLMModel: + """Create an OpenAI-compatible model with SDK-managed retry enabled.""" + api_key, base_url, model_name = get_model_config() + retry_config = get_model_retry_config() + print(f"Model retry enabled: {retry_config.model_dump()}") + return OpenAIModel( + model_name=model_name, + api_key=api_key, + base_url=base_url, + model_retry_config=retry_config, + ) + + +def create_agent() -> LlmAgent: + """Create a weather agent that uses model-level retry.""" + return LlmAgent( + name="weather_retry_agent", + description="A weather assistant with SDK-managed model retry enabled.", + model=_create_model(), + instruction=INSTRUCTION, + tools=[FunctionTool(get_weather_report)], + ) + + +root_agent = create_agent() diff --git a/examples/llmagent_with_model_retry/agent/config.py b/examples/llmagent_with_model_retry/agent/config.py new file mode 100644 index 00000000..105a1943 --- /dev/null +++ b/examples/llmagent_with_model_retry/agent/config.py @@ -0,0 +1,57 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Configuration helpers for the model retry example.""" + +import os + +from trpc_agent_sdk.configs import ExponentialBackoffConfig +from trpc_agent_sdk.configs import ModelRetryConfig + + +def _get_bool(name: str, default: bool) -> bool: + value = os.getenv(name) + if value is None or value == "": + return default + return value.strip().lower() in {"1", "true", "yes", "y", "on"} + + +def _get_float(name: str, default: float) -> float: + value = os.getenv(name) + if value is None or value == "": + return default + return float(value) + + +def _get_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None or value == "": + return default + return int(value) + + +def get_model_config() -> tuple[str, str, str]: + """Get model connection settings from environment variables.""" + api_key = os.getenv("TRPC_AGENT_API_KEY", "") + base_url = os.getenv("TRPC_AGENT_BASE_URL", "") + model_name = os.getenv("TRPC_AGENT_MODEL_NAME", "") + if not api_key or not base_url or not model_name: + raise ValueError("TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, and " + "TRPC_AGENT_MODEL_NAME must be set in environment variables") + return api_key, base_url, model_name + + +def get_model_retry_config() -> ModelRetryConfig: + """Build the opt-in SDK-managed model retry configuration.""" + return ModelRetryConfig( + num_retries=_get_int("TRPC_AGENT_MODEL_RETRY_NUM_RETRIES", 2), + backoff=ExponentialBackoffConfig( + initial_backoff=_get_float("TRPC_AGENT_MODEL_RETRY_INITIAL_BACKOFF", 1.0), + max_backoff=_get_float("TRPC_AGENT_MODEL_RETRY_MAX_BACKOFF", 8.0), + multiplier=_get_float("TRPC_AGENT_MODEL_RETRY_BACKOFF_MULTIPLIER", 2.0), + jitter=_get_bool("TRPC_AGENT_MODEL_RETRY_JITTER", True), + respect_retry_after=_get_bool("TRPC_AGENT_MODEL_RETRY_RESPECT_RETRY_AFTER", True), + ), + ) diff --git a/examples/llmagent_with_model_retry/agent/prompts.py b/examples/llmagent_with_model_retry/agent/prompts.py new file mode 100644 index 00000000..f9f4b76d --- /dev/null +++ b/examples/llmagent_with_model_retry/agent/prompts.py @@ -0,0 +1,13 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompts for the model retry example agent.""" + +INSTRUCTION = """You are a practical weather assistant. + +When the user asks for weather, identify the city and call get_weather_report. +If the city is missing, ask one short clarification question. +After receiving tool results, summarize the weather clearly and mention the retry configuration only if the user asks about reliability. +""" diff --git a/examples/llmagent_with_model_retry/agent/tools.py b/examples/llmagent_with_model_retry/agent/tools.py new file mode 100644 index 00000000..a5715eb0 --- /dev/null +++ b/examples/llmagent_with_model_retry/agent/tools.py @@ -0,0 +1,40 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Tools for the model retry example agent.""" + + +def get_weather_report(city: str) -> dict: + """Get weather information for the specified city.""" + weather_data = { + "Beijing": { + "temperature": "25°C", + "condition": "Sunny", + "humidity": "60%", + }, + "Shanghai": { + "temperature": "28°C", + "condition": "Cloudy", + "humidity": "70%", + }, + "Guangzhou": { + "temperature": "32°C", + "condition": "Thunderstorm", + "humidity": "85%", + }, + "Shenzhen": { + "temperature": "30°C", + "condition": "Light rain", + "humidity": "78%", + }, + } + return weather_data.get( + city, + { + "temperature": "Unknown", + "condition": "Data not available", + "humidity": "Unknown", + }, + ) diff --git a/examples/llmagent_with_model_retry/run_agent.py b/examples/llmagent_with_model_retry/run_agent.py new file mode 100644 index 00000000..a053ec81 --- /dev/null +++ b/examples/llmagent_with_model_retry/run_agent.py @@ -0,0 +1,87 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Run the model retry weather agent example.""" + +import asyncio +import uuid + +from dotenv import load_dotenv +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + + +async def run_weather_agent() -> None: + """Run the weather query agent with model-level retry enabled.""" + app_name = "model_retry_weather_demo" + + from agent.agent import root_agent + + session_service = InMemorySessionService() + runner = Runner(app_name=app_name, agent=root_agent, session_service=session_service) + + user_id = "demo_user" + session_id = str(uuid.uuid4()) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={ + "user_name": user_id, + "user_city": "Beijing", + }, + ) + + query = "What's the current weather in Beijing?" + print(f"Session ID: {session_id[:8]}...") + print("SDK-managed retry is configured on the model; no retry loop is needed in this runner.") + print(f"User: {query}") + print("Assistant: ", end="", flush=True) + + user_content = Content(parts=[Part.from_text(text=query)]) + assistant_started = True + + async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=user_content): + if event.is_error(): + if assistant_started: + print() + assistant_started = False + print(f"Error: {event.error_code}: {event.error_message}") + continue + + if not event.content or not event.content.parts: + continue + + if event.partial: + for part in event.content.parts: + if part.text and not part.thought: + if not assistant_started: + print("Assistant: ", end="", flush=True) + assistant_started = True + print(part.text, end="", flush=True) + continue + + for part in event.content.parts: + if part.thought: + continue + if part.function_call: + print(f"\nInvoke Tool: {part.function_call.name}({part.function_call.args})") + assistant_started = False + elif part.function_response: + print(f"Tool Result: {part.function_response.response}") + elif part.text and not assistant_started: + print("Assistant: ", end="", flush=True) + print(part.text, end="", flush=True) + assistant_started = True + + print("\n") + + +if __name__ == "__main__": + asyncio.run(run_weather_agent()) diff --git a/tests/configs/test_model_retry_config.py b/tests/configs/test_model_retry_config.py new file mode 100644 index 00000000..bb69e344 --- /dev/null +++ b/tests/configs/test_model_retry_config.py @@ -0,0 +1,96 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Unit tests for ModelRetryConfig (_model_retry_config).""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from trpc_agent_sdk.configs import ExponentialBackoffConfig +from trpc_agent_sdk.configs import FixedBackoffConfig +from trpc_agent_sdk.configs import ModelRetryConfig +from trpc_agent_sdk.configs import RegisteredBackoffConfig +from trpc_agent_sdk.configs import RetryRuleConfig +from trpc_agent_sdk.configs._model_retry_config import ExponentialBackoffConfig as ExponentialBackoffConfigDirect +from trpc_agent_sdk.configs._model_retry_config import FixedBackoffConfig as FixedBackoffConfigDirect +from trpc_agent_sdk.configs._model_retry_config import ModelRetryConfig as ModelRetryConfigDirect +from trpc_agent_sdk.configs._model_retry_config import RegisteredBackoffConfig as RegisteredBackoffConfigDirect +from trpc_agent_sdk.configs._model_retry_config import RetryRuleConfig as RetryRuleConfigDirect + + +class TestDefaults: + def test_default_values(self): + cfg = ModelRetryConfig() + assert cfg.num_retries == 2 + assert cfg.rules.retryable_error_codes == ["408", "409", "429", "500", "502", "503", "504"] + assert cfg.rules.non_retryable_error_codes == ["400", "401", "403"] + assert cfg.rules.retryable_exception_class_name_parts == ["Timeout", "Connection", "Transport"] + assert isinstance(cfg.backoff, ExponentialBackoffConfig) + assert cfg.backoff.type == "exponential" + assert cfg.backoff.initial_backoff == 1.0 + assert cfg.backoff.max_backoff == 10.0 + assert cfg.backoff.multiplier == 2.0 + assert cfg.backoff.jitter is True + assert cfg.backoff.respect_retry_after is True + + def test_fixed_backoff_values(self): + cfg = ModelRetryConfig(backoff=FixedBackoffConfig(interval=3.0, max_backoff=5.0, jitter=False)) + assert isinstance(cfg.backoff, FixedBackoffConfig) + assert cfg.backoff.type == "fixed" + assert cfg.backoff.interval == 3.0 + assert cfg.backoff.max_backoff == 5.0 + assert cfg.backoff.jitter is False + + +class TestValidation: + def test_negative_num_retries_rejected(self): + with pytest.raises(ValidationError): + ModelRetryConfig(num_retries=-1) + + def test_zero_num_retries_allowed(self): + cfg = ModelRetryConfig(num_retries=0) + assert cfg.num_retries == 0 + + def test_multiplier_below_one_rejected(self): + with pytest.raises(ValidationError): + ExponentialBackoffConfig(multiplier=0.5) + + def test_negative_exponential_backoff_rejected(self): + with pytest.raises(ValidationError): + ExponentialBackoffConfig(initial_backoff=-1.0) + + def test_negative_fixed_interval_rejected(self): + with pytest.raises(ValidationError): + FixedBackoffConfig(interval=-1.0) + + def test_registered_backoff_config_allows_unknown_type(self): + cfg = ModelRetryConfig(backoff={"type": "linear", "step": 0.5}) + assert isinstance(cfg.backoff, RegisteredBackoffConfig) + assert cfg.backoff.type == "linear" + assert cfg.backoff.step == 0.5 + + def test_old_flat_fields_are_rejected(self): + with pytest.raises(ValidationError): + ModelRetryConfig(initial_backoff=1.0) + with pytest.raises(ValidationError): + ModelRetryConfig(backoff_strategy="fixed") + with pytest.raises(ValidationError): + ModelRetryConfig(retryable_error_codes=["429"]) + + def test_custom_error_codes_allowed_in_rules(self): + cfg = ModelRetryConfig(rules=RetryRuleConfig(retryable_error_codes=["429"], non_retryable_error_codes=["400"])) + assert cfg.rules.retryable_error_codes == ["429"] + assert cfg.rules.non_retryable_error_codes == ["400"] + + +class TestExport: + def test_reexports_are_same_classes(self): + assert ModelRetryConfig is ModelRetryConfigDirect + assert ExponentialBackoffConfig is ExponentialBackoffConfigDirect + assert FixedBackoffConfig is FixedBackoffConfigDirect + assert RegisteredBackoffConfig is RegisteredBackoffConfigDirect + assert RetryRuleConfig is RetryRuleConfigDirect diff --git a/tests/models/test_anthropic_model.py b/tests/models/test_anthropic_model.py index f1d6380e..f800b5ea 100644 --- a/tests/models/test_anthropic_model.py +++ b/tests/models/test_anthropic_model.py @@ -701,6 +701,34 @@ def test_all_breakpoints_inject_all_points(self): assert api_params["messages"][0]["content"][0]["cache_control"]["type"] == "ephemeral" +class TestAnthropicModelRetryErrors: + + @pytest.mark.asyncio + async def test_generate_single_error_raises_and_closes_client(self): + model = AnthropicModel(model_name="claude-3-5-sonnet-20241022", api_key="test-key") + client = MagicMock() + client.messages.create = AsyncMock(side_effect=TimeoutError("timeout")) + client.close = AsyncMock() + + with patch.object(model, "_create_async_client", return_value=client): + with pytest.raises(TimeoutError): + await model._generate_single({}, LlmRequest(contents=[])) + + client.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_generate_async_converts_provider_exception_to_retry_error_response(self): + model = AnthropicModel(model_name="claude-3-5-sonnet-20241022", api_key="test-key") + request = LlmRequest(contents=[Content(parts=[Part.from_text(text="hi")], role="user")]) + + with patch.object(model, "_generate_single", side_effect=ConnectionError("offline")): + responses = [response async for response in model.generate_async(request, stream=False)] + + assert len(responses) == 1 + assert responses[0].error_code == "API_ERROR" + assert responses[0].custom_metadata == {"error": "offline"} + + class TestAnthropicBuildUsageMetadata: """Tests for AnthropicModel._build_usage_metadata cache-inclusive normalization.""" diff --git a/tests/models/test_anthropic_model_ext.py b/tests/models/test_anthropic_model_ext.py index 60630c0e..8b7e609e 100644 --- a/tests/models/test_anthropic_model_ext.py +++ b/tests/models/test_anthropic_model_ext.py @@ -467,16 +467,15 @@ def test_no_warning_for_supported_options(self): class TestGenerateSingleError: @pytest.mark.asyncio - async def test_api_error_returns_error_response(self): - """API error during single generation returns error LlmResponse.""" + async def test_api_error_raises_and_closes_client(self): + """API error during single generation is raised to the retry layer.""" model = _model() mock_client = AsyncMock() mock_client.messages.create = AsyncMock(side_effect=RuntimeError("timeout")) mock_client.close = AsyncMock() with patch.object(model, "_create_async_client", return_value=mock_client): - resp = await model._generate_single({}, LlmRequest(contents=[])) - assert resp.error_code == "API_ERROR" - assert "timeout" in resp.error_message + with pytest.raises(RuntimeError, match="timeout"): + await model._generate_single({}, LlmRequest(contents=[])) mock_client.close.assert_awaited_once() diff --git a/tests/models/test_litellm_model.py b/tests/models/test_litellm_model.py index a7810015..9366bc42 100644 --- a/tests/models/test_litellm_model.py +++ b/tests/models/test_litellm_model.py @@ -361,6 +361,7 @@ def fake_import(name, *a, **k): assert len(responses) == 1 assert responses[0].error_code == "API_ERROR" assert "Connection refused" in (responses[0].error_message or "") + assert responses[0].custom_metadata == {"error": responses[0].error_message} @pytest.mark.asyncio async def test_generate_async_passes_response_format_for_openai_model(self): @@ -586,6 +587,7 @@ def fake_import(name, *a, **k): assert len(responses) == 1 assert responses[0].error_code == "STREAMING_ERROR" assert "Stream broken" in (responses[0].error_message or "") + assert responses[0].custom_metadata == {"error": responses[0].error_message} class TestLiteLLMModelValidateRequest: @@ -697,9 +699,18 @@ def test_injection_points_added_for_system_and_messages(self): api_params = { "tools": [], "messages": [ - {"role": "user", "content": "hi"}, - {"role": "assistant", "content": "hello"}, - {"role": "user", "content": "again"}, + { + "role": "user", + "content": "hi" + }, + { + "role": "assistant", + "content": "hello" + }, + { + "role": "user", + "content": "again" + }, ], } model._apply_prompt_cache(api_params, None) @@ -960,9 +971,18 @@ def test_system_breakpoint_adds_message_role_system(self): def test_messages_breakpoint_adds_latest_assistant_index(self): messages = [ - {"role": "user", "content": "hi"}, - {"role": "assistant", "content": "hello"}, - {"role": "user", "content": "again"}, + { + "role": "user", + "content": "hi" + }, + { + "role": "assistant", + "content": "hello" + }, + { + "role": "user", + "content": "again" + }, ] points = self._build("anthropic/claude-3", ["messages"], messages=messages) assert any(p.get("index") == 1 for p in points) @@ -995,8 +1015,14 @@ def test_no_ttl_produces_ephemeral_only_control(self): def test_all_non_bedrock_breakpoints_no_tool_config_point(self): """All three breakpoints for a non-Bedrock provider: no tool_config point.""" messages = [ - {"role": "user", "content": "hi"}, - {"role": "assistant", "content": "hello"}, + { + "role": "user", + "content": "hi" + }, + { + "role": "assistant", + "content": "hello" + }, ] points = self._build("anthropic/claude-3", ["tools", "system", "messages"], messages=messages) assert not any(p.get("location") == "tool_config" for p in points) diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py index 7f15ea0d..ab5994b5 100644 --- a/tests/models/test_openai_model.py +++ b/tests/models/test_openai_model.py @@ -285,15 +285,15 @@ async def test_generate_async_simple_text_response(self): @pytest.mark.asyncio async def test_generate_async_validation_failure(self): - """Test generate_async raises ValueError on invalid request.""" + """Test generate_async converts validation failures to error responses.""" model = OpenAIModel(model_name="gpt-4", api_key="test_key") - # Empty contents request = LlmRequest(contents=[], config=None, tools_dict={}) - with pytest.raises(ValueError, match="At least one content is required"): - async for _ in model.generate_async(request, stream=False): - pass + responses = [response async for response in model.generate_async(request, stream=False)] + assert len(responses) == 1 + assert responses[0].error_code == "API_ERROR" + assert "At least one content is required" in (responses[0].error_message or "") @pytest.mark.asyncio async def test_generate_async_with_config_parameters(self): diff --git a/tests/models/test_retry.py b/tests/models/test_retry.py new file mode 100644 index 00000000..0e137944 --- /dev/null +++ b/tests/models/test_retry.py @@ -0,0 +1,346 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Unit tests for the model retry policy layer.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from typing import Optional +from unittest.mock import AsyncMock +from unittest.mock import patch + +from trpc_agent_sdk.configs import ExponentialBackoffConfig +from trpc_agent_sdk.configs import FixedBackoffConfig +from trpc_agent_sdk.configs import ModelRetryConfig +from trpc_agent_sdk.configs import RegisteredBackoffConfig +from trpc_agent_sdk.configs import RetryRuleConfig +from trpc_agent_sdk.models._llm_response import LlmResponse +from trpc_agent_sdk.models._retry import BackoffStrategy +from trpc_agent_sdk.models._retry import DefaultRetryPolicy +from trpc_agent_sdk.models._retry import ExponentialBackoffStrategy +from trpc_agent_sdk.models._retry import FixedBackoffStrategy +from trpc_agent_sdk.models._retry import RetryErrorInfo +from trpc_agent_sdk.models._retry import register_backoff_strategy +from trpc_agent_sdk.models._retry import _analyze_exception +from trpc_agent_sdk.models._retry import retry_model_call +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + + +class _StatusError(Exception): + + def __init__(self, status_code: int | str, headers: Optional[dict] = None): + super().__init__(f"status {status_code}") + self.status_code = status_code + if headers is not None: + self.response = type("Resp", (), {"headers": headers})() + + +class _HeadersError(Exception): + + def __init__(self, headers: dict): + super().__init__("headers") + self.headers = headers + + +class _LiteLlmHeadersError(Exception): + + def __init__(self, headers: dict): + super().__init__("litellm headers") + self.litellm_response_headers = headers + + +class _ResponseWithoutHeaderGet(Exception): + + def __init__(self): + super().__init__("bad headers") + self.response = type("Resp", (), {"headers": object()})() + + +class _NamedTimeout(Exception): + pass + + +class _NamedConnection(Exception): + pass + + +class _ZeroDelayBackoffStrategy(BackoffStrategy): + + def compute_delay(self, config, attempt: int, retry_after: Optional[float]) -> float: + return 0.0 + + +class _ConstantDelayBackoffStrategy(BackoffStrategy): + + def compute_delay(self, config, attempt: int, retry_after: Optional[float]) -> float: + return 7.0 + + +def _content_response(text: str = "hi", partial: bool = False) -> LlmResponse: + return LlmResponse(content=Content(parts=[Part.from_text(text=text)], role="model"), partial=partial) + + +async def _collect(call_model, config: Optional[ModelRetryConfig] = None) -> list[LlmResponse]: + return [response async for response in retry_model_call(call_model, config)] + + +class TestAnalyzeException: + + def test_status_code_extraction(self): + info = _analyze_exception(_StatusError("429"), RetryRuleConfig()) + assert info.status_code == 429 + assert info.exception_class_name_matched is False + + def test_retryable_exception_class_name_match(self): + rules = RetryRuleConfig() + assert _analyze_exception(_NamedTimeout("x"), rules).exception_class_name_matched is True + assert _analyze_exception(_NamedConnection("x"), rules).exception_class_name_matched is True + assert _analyze_exception(TimeoutError("x"), rules).exception_class_name_matched is True + assert _analyze_exception(ConnectionError("x"), rules).exception_class_name_matched is True + assert _analyze_exception(ValueError("x"), rules).exception_class_name_matched is False + + def test_custom_type_name_rules(self): + rules = RetryRuleConfig(retryable_exception_class_name_parts=["Slow", "Socket"]) + slow_error = type("ProviderSlowError", (Exception,), {}) + socket_error = type("ProviderSocketError", (Exception,), {}) + assert _analyze_exception(slow_error("x"), rules).exception_class_name_matched is True + assert _analyze_exception(socket_error("x"), rules).exception_class_name_matched is True + assert _analyze_exception(_NamedTimeout("x"), rules).exception_class_name_matched is False + + def test_retry_after_headers(self): + rules = RetryRuleConfig() + assert _analyze_exception(_StatusError(429, {"retry-after-ms": "2500"}), rules).retry_after == 2.5 + assert _analyze_exception(_StatusError(429, {"retry-after": "7"}), rules).retry_after == 7.0 + assert _analyze_exception(_StatusError(429, {"retry-after": "not-a-date"}), rules).retry_after is None + assert _analyze_exception(_HeadersError({"retry-after": "3"}), rules).retry_after == 3.0 + assert _analyze_exception(_LiteLlmHeadersError({"retry-after": "4"}), rules).retry_after == 4.0 + assert _analyze_exception(_ResponseWithoutHeaderGet(), rules).retry_after is None + + +class TestBackoffStrategies: + + def test_exponential_backoff(self): + cfg = ExponentialBackoffConfig(jitter=False, initial_backoff=1.0, max_backoff=10.0, multiplier=2.0) + strategy = ExponentialBackoffStrategy() + assert strategy.compute_delay(cfg, 0, None) == 1.0 + assert strategy.compute_delay(cfg, 1, None) == 2.0 + assert strategy.compute_delay(cfg, 10, None) == 10.0 + + def test_fixed_backoff(self): + cfg = FixedBackoffConfig(interval=3.0, max_backoff=10.0, jitter=False) + strategy = FixedBackoffStrategy() + assert strategy.compute_delay(cfg, 0, None) == 3.0 + assert strategy.compute_delay(cfg, 5, None) == 3.0 + + def test_retry_after_overrides_backoff(self): + assert ExponentialBackoffStrategy().compute_delay(ExponentialBackoffConfig(jitter=False), 0, 5.0) == 5.0 + assert FixedBackoffStrategy().compute_delay(FixedBackoffConfig(jitter=False), 0, 5.0) == 5.0 + + def test_jitter_bounds(self): + cfg = ExponentialBackoffConfig(jitter=True, initial_backoff=1.0, max_backoff=10.0, multiplier=2.0) + for _ in range(50): + delay = ExponentialBackoffStrategy().compute_delay(cfg, 2, None) + assert 0.0 <= delay <= 4.0 + + +class TestDefaultRetryPolicy: + + def test_retryable_decision(self): + cfg = ModelRetryConfig(num_retries=2, backoff=ExponentialBackoffConfig(jitter=False)) + error_info = RetryErrorInfo(status_code=429) + decision = DefaultRetryPolicy().decide(cfg, 0, error_info) + assert decision.should_retry is True + assert decision.delay == 1.0 + + def test_budget_exhausted(self): + cfg = ModelRetryConfig(num_retries=1, backoff=ExponentialBackoffConfig(jitter=False)) + error_info = RetryErrorInfo(status_code=429) + decision = DefaultRetryPolicy().decide(cfg, 1, error_info) + assert decision.should_retry is False + + def test_non_retryable_status(self): + cfg = ModelRetryConfig(num_retries=2, backoff=ExponentialBackoffConfig(jitter=False)) + error_info = RetryErrorInfo(status_code=400) + decision = DefaultRetryPolicy().decide(cfg, 0, error_info) + assert decision.should_retry is False + + def test_custom_status_codes_override_defaults(self): + cfg = ModelRetryConfig( + num_retries=2, + backoff=ExponentialBackoffConfig(jitter=False), + rules=RetryRuleConfig( + retryable_error_codes=["418"], + non_retryable_error_codes=["429"], + ), + ) + assert DefaultRetryPolicy().decide(cfg, 0, RetryErrorInfo(status_code=429)).should_retry is False + assert DefaultRetryPolicy().decide(cfg, 0, RetryErrorInfo(status_code=418)).should_retry is True + assert DefaultRetryPolicy().decide(cfg, 0, RetryErrorInfo(status_code=500)).should_retry is False + + def test_retryable_exception_class_name_match(self): + cfg = ModelRetryConfig(num_retries=2, backoff=ExponentialBackoffConfig(jitter=False)) + assert DefaultRetryPolicy().decide( + cfg, 0, RetryErrorInfo(exception_class_name_matched=True)).should_retry is True + assert DefaultRetryPolicy().decide( + cfg, 0, RetryErrorInfo(exception_class_name_matched=False)).should_retry is False + + def test_explicit_backoff_strategy_overrides_config_registry(self): + cfg = ModelRetryConfig(num_retries=2, backoff=ExponentialBackoffConfig(jitter=False)) + error_info = RetryErrorInfo(status_code=429) + decision = DefaultRetryPolicy(backoff_strategy=_ConstantDelayBackoffStrategy()).decide(cfg, 0, error_info) + assert decision.delay == 7.0 + + def test_backoff_strategy_registry_can_replace_builtin_strategy(self): + cfg = ModelRetryConfig(num_retries=2, backoff=FixedBackoffConfig(interval=3.0, jitter=False)) + error_info = RetryErrorInfo(status_code=429) + try: + register_backoff_strategy("fixed", _ConstantDelayBackoffStrategy()) + decision = DefaultRetryPolicy().decide(cfg, 0, error_info) + assert decision.delay == 7.0 + finally: + register_backoff_strategy("fixed", FixedBackoffStrategy()) + + def test_registered_backoff_strategy_supports_new_config_type(self): + cfg = ModelRetryConfig(num_retries=2, backoff=RegisteredBackoffConfig(type="linear", step=0.5)) + error_info = RetryErrorInfo(status_code=429) + register_backoff_strategy("linear", _ConstantDelayBackoffStrategy()) + decision = DefaultRetryPolicy().decide(cfg, 0, error_info) + assert decision.delay == 7.0 + + def test_unregistered_backoff_strategy_fails_fast(self): + cfg = ModelRetryConfig(num_retries=2, backoff=RegisteredBackoffConfig(type="unknown")) + error_info = RetryErrorInfo(status_code=429) + try: + DefaultRetryPolicy().decide(cfg, 0, error_info) + except ValueError as exc: + assert "No backoff strategy registered" in str(exc) + else: + raise AssertionError("Expected ValueError for unregistered backoff strategy") + + +class TestRetryModelCall: + + def _retry_cfg(self, **kw): + base = dict( + num_retries=2, + backoff=ExponentialBackoffConfig(jitter=False, initial_backoff=0.0, max_backoff=0.0), + ) + base.update(kw) + return ModelRetryConfig(**base) + + async def test_no_config_converts_exception_without_retry(self): + attempts = 0 + + async def call_model() -> AsyncGenerator[LlmResponse, None]: + nonlocal attempts + attempts += 1 + raise _StatusError(429) + yield + + responses = await _collect(call_model) + assert attempts == 1 + assert responses[0].error_code == "API_ERROR" + assert responses[0].custom_metadata == {"error": "status 429"} + + async def test_retry_exception_then_success(self): + attempts = 0 + + async def call_model() -> AsyncGenerator[LlmResponse, None]: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise _StatusError(429) + yield _content_response("ok") + + with patch("trpc_agent_sdk.models._retry.asyncio.sleep", new=AsyncMock()) as sleep: + responses = await _collect(call_model, self._retry_cfg()) + assert attempts == 2 + assert sleep.await_count == 1 + assert responses[-1].content.parts[0].text == "ok" + assert all(response.error_code is None for response in responses) + + async def test_exhausts_budget_then_yields_error(self): + attempts = 0 + + async def call_model() -> AsyncGenerator[LlmResponse, None]: + nonlocal attempts + attempts += 1 + raise _StatusError(500) + yield + + with patch("trpc_agent_sdk.models._retry.asyncio.sleep", new=AsyncMock()) as sleep: + responses = await _collect(call_model, self._retry_cfg(num_retries=2)) + assert attempts == 3 + assert sleep.await_count == 2 + assert responses[-1].error_code == "API_ERROR" + assert responses[-1].custom_metadata == {"error": "status 500"} + + async def test_non_retryable_not_retried(self): + attempts = 0 + + async def call_model() -> AsyncGenerator[LlmResponse, None]: + nonlocal attempts + attempts += 1 + raise _StatusError(400) + yield + + with patch("trpc_agent_sdk.models._retry.asyncio.sleep", new=AsyncMock()) as sleep: + responses = await _collect(call_model, self._retry_cfg()) + assert attempts == 1 + assert sleep.await_count == 0 + assert responses[-1].error_code == "API_ERROR" + + async def test_no_status_code_exception_class_name_retried(self): + attempts = 0 + + async def call_model() -> AsyncGenerator[LlmResponse, None]: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise _NamedConnection("offline") + yield _content_response("ok") + + with patch("trpc_agent_sdk.models._retry.asyncio.sleep", new=AsyncMock()) as sleep: + responses = await _collect(call_model, self._retry_cfg()) + assert attempts == 2 + assert sleep.await_count == 1 + assert responses[-1].content.parts[0].text == "ok" + + async def test_no_retry_after_content_emitted(self): + attempts = 0 + + async def call_model() -> AsyncGenerator[LlmResponse, None]: + nonlocal attempts + attempts += 1 + yield _content_response("partial", partial=True) + raise _StatusError(429) + + with patch("trpc_agent_sdk.models._retry.asyncio.sleep", new=AsyncMock()) as sleep: + responses = await _collect(call_model, self._retry_cfg()) + assert attempts == 1 + assert sleep.await_count == 0 + assert responses[0].content.parts[0].text == "partial" + assert responses[1].error_code == "API_ERROR" + + async def test_closes_interrupted_attempt_before_retry(self): + closed_attempts = [] + + async def first_attempt() -> AsyncGenerator[LlmResponse, None]: + try: + raise _StatusError(429) + yield + finally: + closed_attempts.append("first") + + async def second_attempt() -> AsyncGenerator[LlmResponse, None]: + yield _content_response("ok") + + attempts = iter([first_attempt, second_attempt]) + with patch("trpc_agent_sdk.models._retry.asyncio.sleep", new=AsyncMock()): + responses = await _collect(lambda: next(attempts)(), self._retry_cfg()) + assert closed_attempts == ["first"] + assert responses[-1].content.parts[0].text == "ok" diff --git a/trpc_agent_sdk/configs/__init__.py b/trpc_agent_sdk/configs/__init__.py index 3b1c56c5..2019e077 100644 --- a/trpc_agent_sdk/configs/__init__.py +++ b/trpc_agent_sdk/configs/__init__.py @@ -5,10 +5,22 @@ # tRPC-Agent-Python is licensed under Apache-2.0. """Configs for TRPC Agent framework.""" +from ._model_retry_config import BackoffConfig +from ._model_retry_config import ExponentialBackoffConfig +from ._model_retry_config import FixedBackoffConfig +from ._model_retry_config import ModelRetryConfig +from ._model_retry_config import RegisteredBackoffConfig +from ._model_retry_config import RetryRuleConfig from ._prompt_cache_config import PromptCacheConfig from ._run_config import RunConfig __all__ = [ + "BackoffConfig", + "ExponentialBackoffConfig", + "FixedBackoffConfig", + "ModelRetryConfig", "PromptCacheConfig", + "RegisteredBackoffConfig", + "RetryRuleConfig", "RunConfig", ] diff --git a/trpc_agent_sdk/configs/_model_retry_config.py b/trpc_agent_sdk/configs/_model_retry_config.py new file mode 100644 index 00000000..76cd383e --- /dev/null +++ b/trpc_agent_sdk/configs/_model_retry_config.py @@ -0,0 +1,124 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Model retry configuration for TRPC Agent framework.""" + +from __future__ import annotations + +from typing import Any +from typing import Literal + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import model_validator + + +class RetryRuleConfig(BaseModel): + """Provider-agnostic retryability rules for model errors.""" + + model_config = ConfigDict(extra="forbid") + + retryable_error_codes: list[str] = Field(default_factory=lambda: ["408", "409", "429", "500", "502", "503", "504"]) + """Exact HTTP status codes that should be retried.""" + + non_retryable_error_codes: list[str] = Field(default_factory=lambda: ["400", "401", "403"]) + """Exact HTTP status codes that should never be retried.""" + + retryable_exception_class_name_parts: list[str] = Field(default_factory=lambda: ["Timeout", "Connection", "Transport"]) + """Exception class name fragments that should be retried when no HTTP status code is available.""" + + +class ExponentialBackoffConfig(BaseModel): + """Configuration for exponential retry backoff.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["exponential"] = "exponential" + """Backoff strategy discriminator.""" + + initial_backoff: float = Field(default=1.0, ge=0.0) + """Base backoff in seconds for the first exponential retry.""" + + max_backoff: float = Field(default=10.0, ge=0.0) + """Upper bound in seconds for any single computed backoff.""" + + multiplier: float = Field(default=2.0, ge=1.0) + """Exponential growth factor per attempt.""" + + jitter: bool = True + """Whether to apply full jitter to computed backoff values.""" + + respect_retry_after: bool = True + """Honor provider ``Retry-After`` or ``retry-after-ms`` hints when present.""" + + +class FixedBackoffConfig(BaseModel): + """Configuration for fixed-interval retry backoff.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["fixed"] = "fixed" + """Backoff strategy discriminator.""" + + interval: float = Field(default=1.0, ge=0.0) + """Fixed interval in seconds between retry attempts.""" + + max_backoff: float = Field(default=10.0, ge=0.0) + """Upper bound in seconds for any single computed backoff.""" + + jitter: bool = True + """Whether to apply full jitter to computed backoff values.""" + + respect_retry_after: bool = True + """Honor provider ``Retry-After`` or ``retry-after-ms`` hints when present.""" + + +class RegisteredBackoffConfig(BaseModel): + """Configuration for externally registered retry backoff strategies.""" + + model_config = ConfigDict(extra="allow") + + type: str + """Backoff strategy discriminator registered with the retry strategy registry.""" + + +BackoffConfig = ExponentialBackoffConfig | FixedBackoffConfig | RegisteredBackoffConfig +"""Retry backoff configuration, including built-in and registered strategy configs.""" + + +class ModelRetryConfig(BaseModel): + """SDK-managed, provider-agnostic model retry configuration.""" + + model_config = ConfigDict(extra="forbid") + + num_retries: int = Field(default=2, ge=0) + """Retry attempts in addition to the initial call.""" + + rules: RetryRuleConfig = Field(default_factory=RetryRuleConfig) + """Rules used to decide whether a classified model error is retryable.""" + + backoff: BackoffConfig = Field(default_factory=ExponentialBackoffConfig) + """Backoff strategy configuration used between retries.""" + + @model_validator(mode="before") + @classmethod + def _normalize_backoff(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + backoff = data.get("backoff") + if not isinstance(backoff, dict): + return data + + backoff_type = backoff.get("type", "exponential") + normalized = dict(data) + if backoff_type == "exponential": + normalized["backoff"] = ExponentialBackoffConfig(**backoff) + elif backoff_type == "fixed": + normalized["backoff"] = FixedBackoffConfig(**backoff) + else: + normalized["backoff"] = RegisteredBackoffConfig(**backoff) + return normalized diff --git a/trpc_agent_sdk/models/_anthropic_model.py b/trpc_agent_sdk/models/_anthropic_model.py index 39912a9d..0eac8500 100644 --- a/trpc_agent_sdk/models/_anthropic_model.py +++ b/trpc_agent_sdk/models/_anthropic_model.py @@ -37,7 +37,6 @@ from ._llm_request import LlmRequest from ._llm_response import LlmResponse from ._registry import register_model - _EPHEMERAL = "ephemeral" @@ -533,14 +532,7 @@ async def _generate_single( client = self._create_async_client() try: response = await client.messages.create(**api_params) - return self._message_to_llm_response(response) - except Exception as ex: # pylint: disable=broad-except - logger.error("Anthropic API error: %s", ex) - return LlmResponse(content=None, - error_code="API_ERROR", - error_message=str(ex), - custom_metadata={"error": str(ex)}) finally: await client.close() @@ -565,7 +557,6 @@ async def _generate_stream( client = self._create_async_client() try: logger.debug("Anthropic invoke with params: %s", api_params) - logger.debug("Anthropic invoke with params: %s", api_params) async with client.messages.stream(**api_params) as stream: async for event in stream: @@ -667,16 +658,9 @@ async def _generate_stream( partial=False, custom_metadata={"stream_complete": True}) - except Exception as ex: # pylint: disable=broad-except - logger.error("Error in streaming response: %s", ex, exc_info=True) - logger.error("Error in streaming response: %s", ex, exc_info=True) - yield LlmResponse( - content=None, - error_code="STREAMING_ERROR", - error_message=f"Error in streaming: {str(ex)}", - partial=False, - custom_metadata={"error": str(ex)}, - ) + except Exception: + logger.error("Error in streaming response", exc_info=True) + raise finally: await client.close() @@ -790,17 +774,9 @@ async def _generate_async_impl(self, self._apply_prompt_cache(api_params, ctx) - try: - if stream: - async for response in self._generate_stream(api_params, request, ctx): - yield response - else: - response = await self._generate_single(api_params, request, ctx) + if stream: + async for response in self._generate_stream(api_params, request, ctx): yield response - except Exception as ex: # pylint: disable=broad-except - logger.error("Anthropic API error: %s", ex) - # Create error response - yield LlmResponse(content=None, - error_code="API_ERROR", - error_message=str(ex), - custom_metadata={"error": str(ex)}) + else: + response = await self._generate_single(api_params, request, ctx) + yield response diff --git a/trpc_agent_sdk/models/_litellm_model.py b/trpc_agent_sdk/models/_litellm_model.py index 5b5a2a0c..bd4d0b94 100644 --- a/trpc_agent_sdk/models/_litellm_model.py +++ b/trpc_agent_sdk/models/_litellm_model.py @@ -30,7 +30,6 @@ from ._openai_model import FinishReason from ._openai_model import OpenAIModel from ._registry import register_model - # Cache families for LiteLLM provider routing. _ANTHROPIC_FAMILY = "anthropic" # uses cache_control breakpoints _OPENAI_FAMILY = "openai_managed" # uses provider-managed prefix caching @@ -485,21 +484,12 @@ async def _generate_async_impl( self._apply_prompt_cache(api_params, ctx) - try: - if stream: - async for response in self._generate_stream(api_params, request, ctx): - yield response - else: - response = await self._generate_single(api_params, request, ctx) + if stream: + async for response in self._generate_stream(api_params, request, ctx): yield response - except Exception as ex: # pylint: disable=broad-except - logger.error("LiteLLM API error: %s", ex) - yield LlmResponse( - content=None, - error_code="API_ERROR", - error_message=str(ex), - custom_metadata={"error": str(ex)}, - ) + else: + response = await self._generate_single(api_params, request, ctx) + yield response async def _generate_single( self, @@ -508,20 +498,11 @@ async def _generate_single( ctx: InvocationContext | None = None, ) -> LlmResponse: """One-shot acompletion → LlmResponse.""" - try: - litellm = __import__("litellm") - acompletion = getattr(litellm, "acompletion") - response = await acompletion(**api_params) - response_dict: Dict[str, Any] = (response.model_dump() if hasattr(response, "model_dump") else response) - return self._create_response_with_content(response_dict, partial=False) - except Exception as ex: # pylint: disable=broad-except - logger.error("LiteLLM API error: %s", ex) - return LlmResponse( - content=None, - error_code="API_ERROR", - error_message=str(ex), - custom_metadata={"error": str(ex)}, - ) + litellm = __import__("litellm") + acompletion = getattr(litellm, "acompletion") + response = await acompletion(**api_params) + response_dict: Dict[str, Any] = (response.model_dump() if hasattr(response, "model_dump") else response) + return self._create_response_with_content(response_dict, partial=False) async def _generate_stream( self, @@ -589,11 +570,6 @@ async def _generate_stream( partial=False, custom_metadata={"stream_complete": True}, ) - except Exception as ex: # pylint: disable=broad-except - logger.error("Error in streaming response: %s", ex, exc_info=True) - yield LlmResponse( - content=None, - error_code="STREAMING_ERROR", - error_message=str(ex), - custom_metadata={"error": str(ex)}, - ) + except Exception: + logger.error("Error in streaming response", exc_info=True) + raise diff --git a/trpc_agent_sdk/models/_llm_model.py b/trpc_agent_sdk/models/_llm_model.py index e892b31a..a3aaf3a3 100644 --- a/trpc_agent_sdk/models/_llm_model.py +++ b/trpc_agent_sdk/models/_llm_model.py @@ -17,16 +17,17 @@ from typing import Optional from typing import final +from trpc_agent_sdk.configs import ModelRetryConfig from trpc_agent_sdk.configs import PromptCacheConfig from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.context import create_agent_context from trpc_agent_sdk.filter import BaseFilter from trpc_agent_sdk.filter import FilterRunner from trpc_agent_sdk.filter import FilterType - from . import _constants as const from ._llm_request import LlmRequest from ._llm_response import LlmResponse +from ._retry import retry_model_call _VALID_ROLES: set[str] = {const.USER, const.ASSISTANT, const.MODEL, const.SYSTEM} @@ -39,6 +40,7 @@ def __init__( model_name: str, filters_name: Optional[list[str]] = None, prompt_cache_config: Optional[PromptCacheConfig] = None, + model_retry_config: Optional[ModelRetryConfig] = None, **kwargs, ): filters: list = kwargs.get("filters", []) @@ -46,6 +48,7 @@ def __init__( self._model_name = model_name self.config = kwargs self.prompt_cache_config = prompt_cache_config + self.model_retry_config = model_retry_config self._type = FilterType.MODEL self._init_filters() self._api_key: str = kwargs.get(const.API_KEY, "") @@ -111,7 +114,9 @@ async def generate_async(self, For streaming, yields multiple partial responses. Error responses should have error_code and error_message set. """ - handle = partial(self._generate_async_impl, request, stream, ctx) # type: ignore + call_model = partial(self._generate_async_impl, request, stream, ctx) # type: ignore + error_code = "STREAMING_ERROR" if stream else "API_ERROR" + run_with_retry = partial(retry_model_call, call_model, self.model_retry_config, error_code=error_code) extra_filters: list[BaseFilter] = [] if ctx: agent_context = ctx.agent_context @@ -122,9 +127,11 @@ async def generate_async(self, else: agent_context = create_agent_context() - async for event in self._run_stream_filters(agent_context, request, handle, extra_filters): # type: ignore + async for event in self._run_stream_filters(agent_context, request, run_with_retry, + extra_filters): # type: ignore yield event # type: ignore + @abstractmethod async def _generate_async_impl(self, request: LlmRequest, diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py index 1819da15..6056011c 100644 --- a/trpc_agent_sdk/models/_openai_model.py +++ b/trpc_agent_sdk/models/_openai_model.py @@ -1511,20 +1511,12 @@ async def _generate_async_impl(self, # set thinking params self._set_thinking(request, http_options) - try: - if stream: - async for response in self._generate_stream(api_params, request, http_options, ctx): - yield response - else: - response = await self._generate_single(api_params, request, http_options, ctx) + if stream: + async for response in self._generate_stream(api_params, request, http_options, ctx): yield response - except Exception as ex: # pylint: disable=broad-except - logger.error("OpenAI API error: %s", ex) - # Create error response using LlmResponse fields - yield LlmResponse(content=None, - error_code="API_ERROR", - error_message=str(ex), - custom_metadata={"error": str(ex)}) + else: + response = await self._generate_single(api_params, request, http_options, ctx) + yield response async def _generate_stream(self, api_params: Dict, @@ -1761,15 +1753,5 @@ async def _generate_stream(self, custom_metadata={"stream_complete": True}, ) - except Exception as ex: # pylint: disable=broad-except - logger.error("Error in streaming response: %s", ex, exc_info=True) - # Create error response using LlmResponse fields - yield LlmResponse( - content=None, - error_code="STREAMING_ERROR", - error_message=f"Error in streaming: {str(ex)}", - partial=False, - custom_metadata={"error": str(ex)}, - ) finally: await client.close() diff --git a/trpc_agent_sdk/models/_retry.py b/trpc_agent_sdk/models/_retry.py new file mode 100644 index 00000000..4d550a1e --- /dev/null +++ b/trpc_agent_sdk/models/_retry.py @@ -0,0 +1,271 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Provider-agnostic model retry policy and execution utilities.""" + +from __future__ import annotations + +import asyncio +import email.utils +import random +import time +from abc import ABC +from abc import abstractmethod +from collections.abc import AsyncGenerator +from collections.abc import Callable +from dataclasses import dataclass +from typing import Optional + +from trpc_agent_sdk.configs import BackoffConfig +from trpc_agent_sdk.configs import ExponentialBackoffConfig +from trpc_agent_sdk.configs import FixedBackoffConfig +from trpc_agent_sdk.configs import ModelRetryConfig +from trpc_agent_sdk.configs import RetryRuleConfig +from trpc_agent_sdk.log import logger + +from ._llm_response import LlmResponse + +@dataclass(frozen=True) +class RetryErrorInfo: + """Provider error information used by retry rules.""" + + status_code: Optional[int] = None + retry_after: Optional[float] = None + exception_class_name_matched: bool = False + + +@dataclass(frozen=True) +class RetryDecision: + """A policy decision for one failed attempt.""" + + should_retry: bool + delay: Optional[float] + error_info: RetryErrorInfo + + +class BackoffStrategy(ABC): + """Computes the delay before the next retry attempt.""" + + @abstractmethod + def compute_delay(self, config: BackoffConfig, attempt: int, retry_after: Optional[float]) -> float: + """Return seconds to wait before retrying.""" + + +class RetryPolicy(ABC): + """Decides whether a model error should be retried.""" + + @abstractmethod + def decide(self, config: ModelRetryConfig, attempt: int, error_info: RetryErrorInfo) -> RetryDecision: + """Return the retry decision for the failed attempt.""" + + +def _analyze_exception(ex: Exception, rules: RetryRuleConfig) -> RetryErrorInfo: + status_code = _extract_status_code(ex) + return RetryErrorInfo( + status_code=status_code, + retry_after=_extract_retry_after(ex), + exception_class_name_matched=( + status_code is None and _has_type_name(ex, *rules.retryable_exception_class_name_parts) + ), + ) + + +def _extract_status_code(ex: Exception) -> Optional[int]: + status_code = getattr(ex, "status_code", None) + if isinstance(status_code, int): + return status_code + if isinstance(status_code, str): + try: + return int(status_code) + except ValueError: + return None + return None + + +def _has_type_name(ex: Exception, *needles: str) -> bool: + for klass in type(ex).__mro__: + if any(needle in klass.__name__ for needle in needles): + return True + return False + + +def _extract_retry_after(ex: Exception) -> Optional[float]: + headers = _extract_headers(ex) + get_header = getattr(headers, "get", None) + if get_header is None: + return None + + retry_after_ms = get_header("retry-after-ms") + if retry_after_ms is not None: + try: + return float(retry_after_ms) / 1000.0 + except (TypeError, ValueError): + pass + + retry_after = get_header("retry-after") + if retry_after is None: + return None + try: + return float(retry_after) + except (TypeError, ValueError): + parsed = email.utils.parsedate_tz(retry_after) + if parsed is None: + return None + delta = email.utils.mktime_tz(parsed) - time.time() + return delta if delta > 0 else 0.0 + + +def _extract_headers(ex: Exception): + litellm_headers = getattr(ex, "litellm_response_headers", None) + if litellm_headers is not None: + return litellm_headers + + response = getattr(ex, "response", None) + response_headers = getattr(response, "headers", None) + if response_headers is not None: + return response_headers + + return getattr(ex, "headers", None) + + +class ExponentialBackoffStrategy(BackoffStrategy): + """Exponential backoff with optional full jitter.""" + + def compute_delay(self, config: BackoffConfig, attempt: int, retry_after: Optional[float]) -> float: + if not isinstance(config, ExponentialBackoffConfig): + raise TypeError("ExponentialBackoffStrategy requires ExponentialBackoffConfig") + + if config.respect_retry_after and retry_after is not None and retry_after >= 0: + return float(retry_after) + + delay = config.initial_backoff * (config.multiplier**attempt) + delay = min(delay, config.max_backoff) + if config.jitter: + return random.uniform(0.0, delay) + return delay + + +class FixedBackoffStrategy(BackoffStrategy): + """Fixed backoff with optional full jitter.""" + + def compute_delay(self, config: BackoffConfig, attempt: int, retry_after: Optional[float]) -> float: + if not isinstance(config, FixedBackoffConfig): + raise TypeError("FixedBackoffStrategy requires FixedBackoffConfig") + + if config.respect_retry_after and retry_after is not None and retry_after >= 0: + return float(retry_after) + + delay = min(config.interval, config.max_backoff) + if config.jitter: + return random.uniform(0.0, delay) + return delay + + +_BACKOFF_STRATEGIES: dict[str, BackoffStrategy] = {} + + +def register_backoff_strategy(type_name: str, strategy: BackoffStrategy) -> None: + """Register or replace a backoff strategy for a backoff config discriminator.""" + _BACKOFF_STRATEGIES[type_name] = strategy + + +class DefaultRetryPolicy(RetryPolicy): + """Default retry policy used by model calls.""" + + def __init__(self, backoff_strategy: Optional[BackoffStrategy] = None): + self.backoff_strategy = backoff_strategy + + def decide(self, config: ModelRetryConfig, attempt: int, error_info: RetryErrorInfo) -> RetryDecision: + if attempt >= config.num_retries: + return RetryDecision(False, None, error_info) + + if not self._is_retryable(config, error_info): + return RetryDecision(False, None, error_info) + + strategy = self.backoff_strategy or _make_backoff_strategy(config.backoff) + delay = strategy.compute_delay(config.backoff, attempt, error_info.retry_after) + return RetryDecision(True, delay, error_info) + + def _is_retryable(self, config: ModelRetryConfig, error_info: RetryErrorInfo) -> bool: + status_code = error_info.status_code + if status_code is not None: + status = str(status_code) + non_retryable_codes = set(config.rules.non_retryable_error_codes) + retryable_codes = set(config.rules.retryable_error_codes) + if status in non_retryable_codes: + return False + return status in retryable_codes + + return error_info.exception_class_name_matched + + +def _make_backoff_strategy(config: BackoffConfig) -> BackoffStrategy: + strategy = _BACKOFF_STRATEGIES.get(config.type) + if strategy is None: + raise ValueError(f"No backoff strategy registered for type: {config.type}") + return strategy + + +register_backoff_strategy("exponential", ExponentialBackoffStrategy()) +register_backoff_strategy("fixed", FixedBackoffStrategy()) + + +def _build_error_response(ex: Exception, error_code: str) -> LlmResponse: + return LlmResponse( + content=None, + error_code=error_code, + error_message=str(ex), + custom_metadata={"error": str(ex)}, + ) + + +async def retry_model_call( + call_model: Callable[[], AsyncGenerator[LlmResponse, None]], + retry_config: Optional[ModelRetryConfig], + *, + error_code: str = "API_ERROR", + policy: Optional[RetryPolicy] = None, +) -> AsyncGenerator[LlmResponse, None]: + """Execute a model call with SDK-managed retry. + + Retries only when an attempt raises before emitting user-visible content. Once + content has been yielded, subsequent failures are converted to a final error + response and surfaced without replaying the request. + """ + retry_rules = retry_config.rules if retry_config is not None else RetryRuleConfig() + active_policy = policy or DefaultRetryPolicy() + attempt = 0 + + while True: + produced_content = False + attempt_stream = call_model() + try: + async for response in attempt_stream: + if response.has_content(): + produced_content = True + yield response + return + except Exception as ex: # pylint: disable=broad-except + error_info = _analyze_exception(ex, retry_rules) + if retry_config is None or produced_content: + yield _build_error_response(ex, error_code) + return + + decision = active_policy.decide(retry_config, attempt, error_info) + if not decision.should_retry: + yield _build_error_response(ex, error_code) + return + + logger.warning( + "Model call failed (status_code=%s); retrying in %.2fs (attempt %d/%d).", + error_info.status_code, + decision.delay or 0.0, + attempt + 1, + retry_config.num_retries, + ) + await asyncio.sleep(decision.delay or 0.0) + attempt += 1 + finally: + await attempt_stream.aclose()