diff --git a/src/webwright/models/anthropic_model.py b/src/webwright/models/anthropic_model.py index 53393b5..cf54077 100644 --- a/src/webwright/models/anthropic_model.py +++ b/src/webwright/models/anthropic_model.py @@ -101,21 +101,33 @@ def _usage_from_anthropic_payload(payload: dict[str, Any]) -> dict[str, int]: input_tokens = _safe_int(usage.get("input_tokens")) output_tokens = _safe_int(usage.get("output_tokens")) cached_input_tokens = _safe_int(usage.get("cache_read_input_tokens")) + cache_creation_tokens = _safe_int(usage.get("cache_creation_input_tokens")) return { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": input_tokens + output_tokens, "cached_input_tokens": cached_input_tokens, + "cache_creation_input_tokens": cache_creation_tokens, "reasoning_output_tokens": 0, } def _metrics_input_from_anthropic( - system_prompt: str | None, anthropic_messages: list[dict[str, Any]] + system_prompt: str | list[dict[str, Any]] | None, + anthropic_messages: list[dict[str, Any]], ) -> list[dict[str, Any]]: items: list[dict[str, Any]] = [] if system_prompt: - items.append({"content": [{"type": "input_text", "text": system_prompt}]}) + if isinstance(system_prompt, str): + items.append({"content": [{"type": "input_text", "text": system_prompt}]}) + else: + parts = [ + {"type": "input_text", "text": block.get("text", "")} + for block in system_prompt + if isinstance(block, dict) and block.get("type") == "text" + ] + if parts: + items.append({"content": parts}) for msg in anthropic_messages: content = msg.get("content", "") if isinstance(content, str): @@ -179,7 +191,13 @@ def _build_payload(self, messages: list[dict[str, Any]]) -> dict[str, Any]: "max_tokens": self.config.max_output_tokens, } if system_prompt: - payload["system"] = system_prompt + payload["system"] = [ + { + "type": "text", + "text": system_prompt, + "cache_control": {"type": "ephemeral"}, + } + ] return payload def _request_metrics_input(self, payload: dict[str, Any]) -> list[dict[str, Any]]: diff --git a/src/webwright/models/base.py b/src/webwright/models/base.py index 7728bb8..ba4115e 100644 --- a/src/webwright/models/base.py +++ b/src/webwright/models/base.py @@ -197,6 +197,7 @@ def _request_metrics_from_serialized_input(serialized_input: list[dict[str, Any] "output_tokens", "total_tokens", "cached_input_tokens", + "cache_creation_input_tokens", "reasoning_output_tokens", )