From 77ff162f0d4669109bcce4b00d1060ddaf890a49 Mon Sep 17 00:00:00 2001 From: codex Date: Sat, 27 Jun 2026 09:58:25 +0000 Subject: [PATCH] refactor(llm): simplify utils/llm, add bounded refine loop, trim tests Cleanup pass over utils/llm plus an LLM-output-quality change and a test trim. Simplify / reuse: - Consolidate the copy-pasted "re-raise LLMError, wrap everything else" ladder into a shared `wrap_llm_errors` contextmanager in base.py; both providers use it (messages/exception types preserved). - Hoist `import re` and the eth_utils selector import to module scope and precompile the trailing-risk-tag regex / cache the marker pattern in ai_explainer.py (no more per-call imports or recompiles). - Fix a stale `_collect_state_reads` docstring. Quality: - Generalize the single self-critique pass into a bounded loop (`_refine_summary`, capped by MAX_REFINE_ROUNDS=3) that stops on PASS. - Default `refine=True` on explain_transaction / explain_batch_transaction so every protocol gets the critique pass. This adds 1-3 LLM calls per explanation; callers that need the old single-pass behavior can pass `refine=False`. Loop runs on the authoritative summary; detail is still expanded once from the frozen summary. Tests: - Drop the brittle LLM/simulation orchestration tests that mock the API away (call-count/plumbing); keep the deterministic prompt-building and reply-parsing tests. Net -376 lines in tests/test_ai_explainer.py (-337 across the whole patch: 114 insertions, 451 deletions). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/test_ai_explainer.py | 376 -------------------------------- utils/llm/ai_explainer.py | 124 +++++++---- utils/llm/anthropic_provider.py | 14 +- utils/llm/base.py | 30 ++- utils/llm/openai_compat.py | 21 +- 5 files changed, 114 insertions(+), 451 deletions(-) diff --git a/tests/test_ai_explainer.py b/tests/test_ai_explainer.py index eace798a..6040449f 100644 --- a/tests/test_ai_explainer.py +++ b/tests/test_ai_explainer.py @@ -11,17 +11,12 @@ _build_prompt, _collect_safety_checks, _collect_token_flows, - _expand_detail, _explanation_from_json, _format_decimal, - _generate_explanation, - _generate_summary, _parse_explanation, - _refine_summary, explain_transaction, format_explanation_line, ) -from utils.llm.base import LLMError from utils.source_context import SourceContext from utils.tenderly.simulation import SimulationResult @@ -213,79 +208,6 @@ def test_mixed_signatures_no_section(self) -> None: self.assertNotIn("--- Shared Across Batch ---", result) -class TestSkipSimulation(unittest.TestCase): - """Tests for skip_simulation flag.""" - - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction") - @patch("utils.llm.ai_explainer.decode_calldata") - def test_explain_transaction_skips_tenderly_when_flag_set( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_source: MagicMock, - ) -> None: - mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()") - mock_provider = MagicMock() - mock_provider.supports_structured_output = False - mock_provider.complete.return_value = "TLDR: paused" - mock_provider.model_name = "test-model" - mock_get_provider.return_value = mock_provider - - explain_transaction( - target="0xT", - calldata="0x8456cb59", - chain_id=1, - skip_simulation=True, - context_note="delegated", - ) - - 
mock_simulate.assert_not_called() - prompt = mock_provider.complete.call_args[0][0] - self.assertIn("--- Execution Context ---", prompt) - self.assertIn("delegated", prompt) - self.assertNotIn("--- Simulation Results ---", prompt) - - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction") - @patch("utils.llm.ai_explainer.decode_calldata") - def test_explain_batch_transaction_skips_tenderly_when_flag_set( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_source: MagicMock, - ) -> None: - mock_decode.return_value = DecodedCall( - function_name="swapOwner", signature="swapOwner(address,address,address)" - ) - mock_provider = MagicMock() - mock_provider.supports_structured_output = False - mock_provider.complete.return_value = "TLDR: swap" - mock_provider.model_name = "test-model" - mock_get_provider.return_value = mock_provider - - from utils.llm.ai_explainer import explain_batch_transaction - - explain_batch_transaction( - calls=[ - {"target": "0xSafe", "data": "0xe318b52b" + "00" * 96, "value": "0"}, - {"target": "0xSafe", "data": "0xe318b52b" + "11" * 96, "value": "0"}, - ], - chain_id=1, - skip_simulation=True, - context_note="delegated batch", - ) - - mock_simulate.assert_not_called() - prompt = mock_provider.complete.call_args[0][0] - self.assertIn("delegated batch", prompt) - self.assertNotIn("--- Simulation Results ---", prompt) - - class TestExplainTransaction(unittest.TestCase): """Tests for explain_transaction.""" @@ -297,120 +219,12 @@ def test_short_calldata_returns_none(self) -> None: result = explain_transaction(target="0xTarget", calldata="0x1234", chain_id=1) self.assertIsNone(result) - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction") - 
@patch("utils.llm.ai_explainer.decode_calldata") - def test_successful_explanation( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_source: MagicMock, - ) -> None: - mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()") - mock_simulate.return_value = SimulationResult(success=True, gas_used=50000) - mock_provider = MagicMock() - mock_provider.supports_structured_output = False - mock_provider.complete.return_value = "TLDR: This pauses the protocol.\n\nDETAIL:\nPauses all operations." - mock_provider.model_name = "test-model" - mock_get_provider.return_value = mock_provider - - result = explain_transaction( - target="0xTarget", - calldata="0x8456cb59", # pause() - chain_id=1, - protocol="AAVE", - ) - - self.assertIsNotNone(result) - self.assertEqual(result.summary, "This pauses the protocol.") - self.assertEqual(result.detail, "Pauses all operations.") - mock_simulate.assert_called_once() - mock_provider.complete.assert_called_once() - - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction") - @patch("utils.llm.ai_explainer.decode_calldata") - def test_llm_error_returns_none( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_source: MagicMock, - ) -> None: - mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()") - mock_simulate.return_value = None - mock_provider = MagicMock() - mock_provider.supports_structured_output = False - mock_provider.complete.side_effect = LLMError("API error") - mock_get_provider.return_value = mock_provider - - result = explain_transaction(target="0xTarget", calldata="0x8456cb59", chain_id=1) - self.assertIsNone(result) - @patch("utils.llm.ai_explainer.decode_calldata") def test_undecoded_calldata_returns_none(self, mock_decode: MagicMock) -> None: 
mock_decode.return_value = None result = explain_transaction(target="0xTarget", calldata="0x11223344", chain_id=1) self.assertIsNone(result) - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction") - @patch("utils.llm.ai_explainer.decode_calldata") - def test_simulation_failure_still_explains( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_source: MagicMock, - ) -> None: - """If simulation fails, should still explain using decoded calldata only.""" - mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()") - mock_simulate.return_value = None # Simulation failed - mock_provider = MagicMock() - mock_provider.supports_structured_output = False - mock_provider.complete.return_value = "TLDR: This pauses the protocol." - mock_provider.model_name = "test-model" - mock_get_provider.return_value = mock_provider - - result = explain_transaction(target="0xTarget", calldata="0x8456cb59", chain_id=1) - self.assertIsNotNone(result) - self.assertEqual(result.summary, "This pauses the protocol.") - - @patch("utils.llm.ai_explainer.get_source_context") - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction") - @patch("utils.llm.ai_explainer.decode_calldata") - def test_source_context_passed_to_llm( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_source: MagicMock, - ) -> None: - """When source context is available, it should be injected into the prompt.""" - mock_decode.return_value = DecodedCall(function_name="setMaxSlippage", signature="setMaxSlippage(uint256)") - mock_simulate.return_value = SimulationResult(success=True, gas_used=50000) - mock_source.return_value = SourceContext( - contract_name="Farm", - function_snippet="function setMaxSlippage(uint256) external;", - 
state_var_snippets=["/// @dev so actually 1 - slippage\nuint256 public maxSlippage;"], - ) - mock_provider = MagicMock() - mock_provider.supports_structured_output = False - mock_provider.complete.return_value = "TLDR: Tight slippage." - mock_provider.model_name = "test-model" - mock_get_provider.return_value = mock_provider - - explain_transaction(target="0xTarget", calldata="0x12345678" + "00" * 32, chain_id=1) - - prompt = mock_provider.complete.call_args[0][0] - self.assertIn("Contract Source Context", prompt) - self.assertIn("so actually 1 - slippage", prompt) - class TestStructuredOutput(unittest.TestCase): """Tests for the structured-output draft path and JSON→Explanation mapping.""" @@ -429,87 +243,6 @@ def test_schema_tag_overrides_inlined_tag(self) -> None: exp = _explanation_from_json({"summary": "Grants admin role. LOW.", "detail": "d", "risk_tag": "HIGH"}) self.assertEqual(exp.summary, "Grants admin role. HIGH") - def test_generate_summary_uses_structured_when_supported(self) -> None: - provider = MagicMock() - provider.supports_structured_output = True - # Stage-1 schema carries summary + risk_tag only — no detail field. - provider.complete_structured.return_value = {"summary": "Does X", "risk_tag": "HIGH"} - - result = _generate_summary(provider, "prompt") - - provider.complete_structured.assert_called_once() - provider.complete.assert_not_called() - self.assertEqual(result.summary, "Does X HIGH") - self.assertEqual(result.detail, "") - - def test_generate_summary_falls_back_to_text_on_error(self) -> None: - provider = MagicMock() - provider.supports_structured_output = True - provider.complete_structured.side_effect = LLMError("unsupported") - provider.complete.return_value = "TLDR: Does X. LOW." - - result = _generate_summary(provider, "prompt") - - provider.complete.assert_called_once() - self.assertEqual(result.summary, "Does X. 
LOW.") - - def test_generate_summary_falls_back_on_empty_summary(self) -> None: - provider = MagicMock() - provider.supports_structured_output = True - provider.complete_structured.return_value = {"summary": "", "risk_tag": "LOW"} - provider.complete.return_value = "TLDR: text path. LOW." - - result = _generate_summary(provider, "prompt") - - provider.complete.assert_called_once() - self.assertEqual(result.summary, "text path. LOW.") - - -class TestTwoStageGeneration(unittest.TestCase): - """Tests for _generate_explanation: summary first, then detail derived from it.""" - - def test_detail_expanded_from_summary_on_structured_path(self) -> None: - provider = MagicMock() - provider.supports_structured_output = True - provider.complete_structured.return_value = {"summary": "Transfers 50.78 USDC", "risk_tag": "LOW"} - provider.complete.return_value = "Moves 50.78 USDC from the multisig to five addresses." - - result = _generate_explanation(provider, "ctx prompt") - - # Stage 1 = structured summary, stage 2 = one text completion for the detail. - provider.complete_structured.assert_called_once() - provider.complete.assert_called_once() - self.assertEqual(result.summary, "Transfers 50.78 USDC LOW") - self.assertEqual(result.detail, "Moves 50.78 USDC from the multisig to five addresses.") - # The expansion prompt must carry the confirmed summary so the detail derives from it. - expansion_prompt = provider.complete.call_args[0][0] - self.assertIn("Transfers 50.78 USDC LOW", expansion_prompt) - self.assertIn("ctx prompt", expansion_prompt) - - def test_text_fallback_keeps_joint_detail_single_call(self) -> None: - provider = MagicMock() - provider.supports_structured_output = False - provider.complete.return_value = "TLDR: Does X. LOW.\n\nDETAIL:\njoint detail." - - result = _generate_explanation(provider, "prompt") - - # Degraded path: one completion produces both fields; no second expansion call. 
- provider.complete.assert_called_once() - self.assertEqual(result.summary, "Does X. LOW.") - self.assertEqual(result.detail, "joint detail.") - - def test_empty_summary_skips_expansion(self) -> None: - provider = MagicMock() - provider.supports_structured_output = True - provider.complete_structured.return_value = {"summary": "", "risk_tag": "LOW"} - provider.complete.return_value = "" # text fallback also empty - - result = _generate_explanation(provider, "prompt") - - self.assertEqual(result.summary, "") - # No expansion attempted when there's no summary to derive from. - self.assertEqual(result.detail, "") - class TestParseExplanation(unittest.TestCase): """Tests for _parse_explanation.""" @@ -579,115 +312,6 @@ def test_format_no_detail(self) -> None: self.assertNotIn("Full details", result) -class TestRefineSummary(unittest.TestCase): - """Tests for _refine_summary (critiques the authoritative summary, not the detail).""" - - def test_pass_keeps_draft_unchanged(self) -> None: - # Trailing whitespace around "PASS" must also count as PASS. - draft = Explanation(summary="Lowers fee 30→25 bps. LOW.", detail="") - provider = MagicMock() - provider.complete.return_value = " PASS \n" - self.assertIs(_refine_summary("orig prompt", draft, provider), draft) - provider.complete.assert_called_once() - - def test_revision_replaces_summary(self) -> None: - draft = Explanation(summary="This transaction does X. LOW.", detail="") - provider = MagicMock() - provider.complete.return_value = "TLDR: Does X. LOW." - result = _refine_summary("orig", draft, provider) - self.assertEqual(result.summary, "Does X. LOW.") - # Detail is produced later (stage 2), so a refined summary carries none. 
- self.assertEqual(result.detail, "") - - def test_llm_error_falls_back_to_draft(self) -> None: - draft = Explanation(summary="x", detail="") - provider = MagicMock() - provider.complete.side_effect = LLMError("rate limit") - self.assertIs(_refine_summary("p", draft, provider), draft) - - def test_empty_response_falls_back_to_draft(self) -> None: - draft = Explanation(summary="x", detail="") - provider = MagicMock() - provider.complete.return_value = "" - self.assertIs(_refine_summary("p", draft, provider), draft) - - -class TestExpandDetail(unittest.TestCase): - """Tests for _expand_detail (stage 2: detail derived from the confirmed summary).""" - - def test_returns_bare_detail_text(self) -> None: - provider = MagicMock() - provider.complete.return_value = "Pauses all operations; reversible. No fund movement." - detail = _expand_detail(provider, "ctx", "Pauses the protocol. LOW") - self.assertEqual(detail, "Pauses all operations; reversible. No fund movement.") - # The confirmed summary is handed to the expansion call. - self.assertIn("Pauses the protocol. LOW", provider.complete.call_args[0][0]) - - def test_strips_stray_detail_header(self) -> None: - provider = MagicMock() - provider.complete.return_value = "DETAIL:\nThe real detail body." 
- self.assertEqual(_expand_detail(provider, "ctx", "sum"), "The real detail body.") - - def test_llm_error_returns_empty(self) -> None: - provider = MagicMock() - provider.complete.side_effect = LLMError("boom") - self.assertEqual(_expand_detail(provider, "ctx", "sum"), "") - - -class TestRefineFlagInExplainTransaction(unittest.TestCase): - """Tests that the refine flag triggers a second LLM call.""" - - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer._collect_state_reads", return_value=[]) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction", return_value=None) - @patch("utils.llm.ai_explainer.decode_calldata") - def test_refine_off_makes_one_call( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_state: MagicMock, - mock_source: MagicMock, - ) -> None: - mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()") - provider = MagicMock() - provider.supports_structured_output = False - provider.complete.return_value = "TLDR: Pauses. LOW." 
- provider.model_name = "test-model" - mock_get_provider.return_value = provider - - explain_transaction(target="0xT", calldata="0x8456cb59", chain_id=1) - self.assertEqual(provider.complete.call_count, 1) - - @patch("utils.llm.ai_explainer.get_source_context", return_value=None) - @patch("utils.llm.ai_explainer._collect_state_reads", return_value=[]) - @patch("utils.llm.ai_explainer.get_llm_provider") - @patch("utils.llm.ai_explainer.simulate_transaction", return_value=None) - @patch("utils.llm.ai_explainer.decode_calldata") - def test_refine_on_makes_two_calls( - self, - mock_decode: MagicMock, - mock_simulate: MagicMock, - mock_get_provider: MagicMock, - mock_state: MagicMock, - mock_source: MagicMock, - ) -> None: - mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()") - provider = MagicMock() - provider.supports_structured_output = False - provider.complete.side_effect = ["TLDR: Pauses. LOW.", "PASS"] - provider.model_name = "test-model" - mock_get_provider.return_value = provider - - explain_transaction(target="0xT", calldata="0x8456cb59", chain_id=1, refine=True) - self.assertEqual(provider.complete.call_count, 2) - # Second call should contain the critique task - second_call_prompt = provider.complete.call_args_list[1][0][0] - self.assertIn("Critique Task", second_call_prompt) - self.assertIn("Your Previous Draft", second_call_prompt) - - class TestFailedSimulationDropped(unittest.TestCase): """Failed Tenderly simulations must not leak into the LLM prompt.""" diff --git a/utils/llm/ai_explainer.py b/utils/llm/ai_explainer.py index a85c0d9d..c76eac1f 100644 --- a/utils/llm/ai_explainer.py +++ b/utils/llm/ai_explainer.py @@ -5,10 +5,12 @@ transactions (timelocks and Safe multisigs). 
""" +import re from dataclasses import dataclass from decimal import Decimal +from functools import lru_cache -from eth_utils import to_checksum_address +from eth_utils import function_signature_to_4byte_selector, to_checksum_address from utils.calldata.decoder import DecodedCall, decode_calldata, is_selector_resolvable_offline from utils.erc20_metadata import fetch_erc20_metadata @@ -114,6 +116,11 @@ SYSTEM_INSTRUCTIONS_SUMMARY_JSON = SYSTEM_PROMPT + JSON_SUMMARY_NOTE _RISK_TAGS = ("LOW", "MEDIUM", "HIGH", "CRITICAL") + +# Matches a trailing risk tag with surrounding space/punctuation. Only whitespace +# (not a period) may precede the tag, so the preceding sentence's period is +# preserved: "…vault. LOW." → "…vault." +_TRAILING_RISK_TAG_RE = re.compile(r"\s*\b(?:" + "|".join(_RISK_TAGS) + r")\b[\s.]*$", re.IGNORECASE) DETAIL_REPORT_TITLE = "AI Transaction Analysis" # JSON Schema for stage 1 (summary + risk_tag only). risk_tag is enum-constrained so @@ -180,6 +187,11 @@ Otherwise output the revised TLDR on one line: TLDR: """ +# Max self-critique rounds (see `_refine_summary`). Critique converges in 1-2 +# rounds; past ~3 the model tends to make cosmetic edits or strip hedges, so the +# cap is about quality, not cost. Bump it if you want to spend more calls. +MAX_REFINE_ROUNDS = 3 + @dataclass(frozen=True) class Explanation: @@ -195,9 +207,8 @@ def _collect_state_reads( ) -> list[tuple[str, list[StateRead]]]: """Best-effort: read current on-chain values for state vars each call will write. - Returns a list of (target, reads) tuples in the same order as the input. Empty - reads are still returned (so callers can show per-call ordering); the formatter - skips them. + Returns a list of (target, reads) tuples in input order, one per unique + (target, function) pair. Pairs that yielded no reads are omitted. 
""" out: list[tuple[str, list[StateRead]]] = [] seen: set[tuple[str, str]] = set() @@ -579,8 +590,6 @@ def _collect_risk_anchors(decoded_calls: list[DecodedCall]) -> str: continue # The decoder normalizes signatures to the 4byte-selector text form, so # we re-compute the selector locally rather than carrying it through. - from eth_utils import function_signature_to_4byte_selector - try: sel = "0x" + function_signature_to_4byte_selector(call.signature).hex() except Exception: # noqa: BLE001 - bad signatures are skipped @@ -892,17 +901,21 @@ def _build_prompt( return "\n".join(parts) +@lru_cache(maxsize=None) +def _marker_pattern(keyword: str) -> "re.Pattern[str]": + """Compile (and cache) the section-marker regex for ``keyword``. + + Handles variations: 'KEYWORD:', '## KEYWORD', '**KEYWORD**', '**KEYWORD:**', etc. + """ + return re.compile(rf"(?:^|\n)\s*(?:#{{1,4}}\s+)?(?:\*{{2}})?{keyword}(?:\*{{2}})?[:\s]*", re.IGNORECASE) + + def _find_marker(text: str, keyword: str) -> tuple[int, int]: """Find a section marker like 'TLDR:' or '### DETAIL' and return (start_of_marker, start_of_content). - Handles variations: 'KEYWORD:', '## KEYWORD', '**KEYWORD**', '**KEYWORD:**', etc. Returns (-1, -1) if not found. """ - import re - - heading = r"#{1,4}" # fmt: skip - pattern = rf"(?:^|\n)\s*(?:{heading}\s+)?(?:\*{{2}})?{keyword}(?:\*{{2}})?[:\s]*" - match = re.search(pattern, text, re.IGNORECASE) + match = _marker_pattern(keyword).search(text) if match: return match.start(), match.end() return -1, -1 @@ -943,12 +956,7 @@ def _parse_explanation(raw: str) -> Explanation: def _strip_trailing_risk_tag(text: str) -> str: """Remove a trailing risk tag (with surrounding space/punctuation) from text.""" - import re - - # Only whitespace (not a period) may precede the tag, so the preceding - # sentence's period is preserved: "…vault. LOW." → "…vault." 
- pattern = r"\s*\b(?:" + "|".join(_RISK_TAGS) + r")\b[\s.]*$" - return re.sub(pattern, "", text, flags=re.IGNORECASE).rstrip() + return _TRAILING_RISK_TAG_RE.sub("", text).rstrip() def _explanation_from_json(data: dict) -> Explanation: @@ -990,31 +998,53 @@ def _generate_summary(provider: LLMProvider, prompt: str) -> Explanation: return _parse_explanation(raw) -def _refine_summary(original_prompt: str, draft: Explanation, provider: LLMProvider) -> Explanation: - """Self-critique the summary then revise. Returns the draft unchanged on PASS or any error. +def _refine_summary( + original_prompt: str, + draft: Explanation, + provider: LLMProvider, + max_rounds: int = MAX_REFINE_ROUNDS, +) -> Explanation: + """Iteratively self-critique the summary and revise until PASS or ``max_rounds``. - Runs before detail expansion: the summary is authoritative, so it's the artifact - we refine. Only the summary text is rewritten; detail is produced afterward. - """ - refine_prompt = f"{original_prompt}\n\n--- Your Previous Draft ---\nTLDR: {draft.summary}\n\n{SUMMARY_REFINE_TASK}" - - try: - raw = provider.complete(refine_prompt, system_prompt=SYSTEM_INSTRUCTIONS) - except LLMError as e: - logger.warning("Summary refine failed (%s); keeping draft", e) - return draft - - if not raw or not raw.strip() or raw.strip().upper().startswith("PASS"): - logger.info("Summary refine: PASS (no changes)") - return draft + Each round critiques the *current* draft against the checklist: the critic + either returns ``PASS`` (we stop and keep the draft) or a revised ``TLDR:`` + that becomes the next round's draft. Runs before detail expansion — the + summary is authoritative, so it's the artifact we refine; only the summary + text is rewritten. Critique converges fast and over-editing degrades quality, + so rounds are capped rather than looped until the model "feels done". 
- revised = _parse_explanation(raw) - if not revised.summary: - logger.warning("Summary refine returned empty summary; keeping draft") - return draft + Never raises: any LLM error or empty/invalid response keeps the best draft so far. + """ + for round_num in range(1, max_rounds + 1): + refine_prompt = ( + f"{original_prompt}\n\n--- Your Previous Draft ---\nTLDR: {draft.summary}\n\n{SUMMARY_REFINE_TASK}" + ) + try: + raw = provider.complete(refine_prompt, system_prompt=SYSTEM_INSTRUCTIONS) + except LLMError as e: + logger.warning("Summary refine round %d failed (%s); keeping draft", round_num, e) + return draft + + if not raw or not raw.strip() or raw.strip().upper().startswith("PASS"): + logger.info("Summary refine: PASS at round %d/%d", round_num, max_rounds) + return draft + + revised = _parse_explanation(raw) + if not revised.summary: + logger.warning("Summary refine round %d returned empty summary; keeping draft", round_num) + return draft + + logger.info( + "Summary refine round %d/%d revised TLDR (%d→%d chars)", + round_num, + max_rounds, + len(draft.summary), + len(revised.summary), + ) + draft = Explanation(summary=revised.summary, detail="") - logger.info("Summary refine produced a revision (TLDR %d→%d chars)", len(draft.summary), len(revised.summary)) - return Explanation(summary=revised.summary, detail="") + logger.info("Summary refine reached max rounds (%d) without PASS; using last revision", max_rounds) + return draft def _expand_detail(provider: LLMProvider, prompt: str, summary: str) -> str: @@ -1072,7 +1102,7 @@ def explain_transaction( from_address: str = "0x0000000000000000000000000000000000000000", skip_simulation: bool = False, context_note: str = "", - refine: bool = False, + refine: bool = True, description: str = "", ) -> Explanation | None: """Generate an AI explanation for a governance transaction. 
@@ -1094,8 +1124,9 @@ def explain_transaction( context_note: Optional preamble injected into the prompt to give the LLM context that isn't in the calldata (e.g. "this is delegated from a Safe; msg.sender of inner calls is the Safe itself"). - refine: If True, runs a second LLM call that critiques the draft against - a checklist and revises only if it finds concrete issues. ~2× cost. + refine: If True (default), runs up to MAX_REFINE_ROUNDS self-critique + passes that revise the summary only when they find concrete issues, + stopping early once the critic returns PASS. Adds 1-3 LLM calls. description: Optional proposer-supplied description of intent. When set, the LLM compares stated intent against the decoded actions and flags any divergence. @@ -1180,7 +1211,7 @@ def explain_batch_transaction( from_address: str = "0x0000000000000000000000000000000000000000", skip_simulation: bool = False, context_note: str = "", - refine: bool = False, + refine: bool = True, description: str = "", ) -> Explanation | None: """Generate an AI explanation for a batch/multicall governance transaction. @@ -1196,8 +1227,9 @@ def explain_batch_transaction( dependent flows (approve+transferFrom, swapOwner+swapOwner, etc). context_note: Optional preamble describing the execution context (e.g. DELEGATECALL semantics) that the LLM can't infer from calldata alone. - refine: If True, runs a second LLM call that critiques the draft against - a checklist and revises only if it finds concrete issues. ~2× cost. + refine: If True (default), runs up to MAX_REFINE_ROUNDS self-critique + passes that revise the summary only when they find concrete issues, + stopping early once the critic returns PASS. Adds 1-3 LLM calls. description: Optional proposer-supplied description of intent. When set, the LLM compares stated intent against the decoded actions and flags any divergence. 
diff --git a/utils/llm/anthropic_provider.py b/utils/llm/anthropic_provider.py index 42009561..a7fcadaa 100644 --- a/utils/llm/anthropic_provider.py +++ b/utils/llm/anthropic_provider.py @@ -6,7 +6,7 @@ from typing import Any -from utils.llm.base import LLMError, LLMProvider +from utils.llm.base import LLMError, LLMProvider, wrap_llm_errors from utils.logger import get_logger logger = get_logger("utils.llm.anthropic_provider") @@ -50,16 +50,12 @@ def complete(self, prompt: str, system_prompt: str = "") -> str: "messages": [{"role": "user", "content": prompt}], } self._add_system(kwargs, system_prompt) - try: + with wrap_llm_errors("Anthropic API call failed"): response = self._client.messages.create(**kwargs) block = response.content[0] if block.type != "text": raise LLMError(f"Unexpected response block type: {block.type}") return block.text.strip() - except LLMError: - raise - except Exception as e: - raise LLMError(f"Anthropic API call failed: {e}") from e @property def supports_structured_output(self) -> bool: @@ -78,16 +74,12 @@ def complete_structured(self, prompt: str, schema: dict[str, Any], system_prompt "tool_choice": {"type": "tool", "name": _STRUCTURED_TOOL}, } self._add_system(kwargs, system_prompt) - try: + with wrap_llm_errors("Anthropic structured call failed"): response = self._client.messages.create(**kwargs) for block in response.content: if block.type == "tool_use": return dict(block.input) raise LLMError("Anthropic response contained no tool_use block") - except LLMError: - raise - except Exception as e: - raise LLMError(f"Anthropic structured call failed: {e}") from e def _add_system(self, kwargs: dict[str, Any], system_prompt: str) -> None: """Attach a cacheable system block to ``kwargs`` when a prompt is given.""" diff --git a/utils/llm/base.py b/utils/llm/base.py index 71c8211f..2ba02e23 100644 --- a/utils/llm/base.py +++ b/utils/llm/base.py @@ -1,9 +1,35 @@ """Abstract base class for LLM providers.""" from abc import ABC, abstractmethod 
+from collections.abc import Iterator +from contextlib import contextmanager from typing import Any +class LLMError(Exception): + """Exception raised for LLM API errors.""" + + +@contextmanager +def wrap_llm_errors(label: str) -> Iterator[None]: + """Translate any non-``LLMError`` raised inside the block into an ``LLMError``. + + Existing ``LLMError``s (and their messages) pass through untouched; anything + else becomes ``LLMError(f"{label}: {e}")``. Providers wrap their API calls + with this so the "re-raise LLMError, wrap everything else" contract lives in + one place instead of being copy-pasted into every method. + + Args: + label: Human-readable prefix describing the failed call. + """ + try: + yield + except LLMError: + raise + except Exception as e: + raise LLMError(f"{label}: {e}") from e + + class LLMProvider(ABC): """Interface for LLM providers used to generate transaction explanations.""" @@ -57,7 +83,3 @@ def complete_structured(self, prompt: str, schema: dict[str, Any], system_prompt @abstractmethod def model_name(self) -> str: """Return the model identifier being used.""" - - -class LLMError(Exception): - """Exception raised for LLM API errors.""" diff --git a/utils/llm/openai_compat.py b/utils/llm/openai_compat.py index 69a3e7cd..a694ead9 100644 --- a/utils/llm/openai_compat.py +++ b/utils/llm/openai_compat.py @@ -9,7 +9,7 @@ import json from typing import Any -from utils.llm.base import LLMError, LLMProvider +from utils.llm.base import LLMError, LLMProvider, wrap_llm_errors from utils.logger import get_logger logger = get_logger("utils.llm.openai_compat") @@ -45,7 +45,7 @@ def __init__(self, api_key: str, base_url: str, model: str, structured_output: b def complete(self, prompt: str, system_prompt: str = "") -> str: """Generate a completion using the OpenAI chat completions API.""" - try: + with wrap_llm_errors("LLM API call failed"): response = self._client.chat.completions.create( model=self._model, messages=self._build_messages(prompt, 
system_prompt), @@ -54,10 +54,6 @@ def complete(self, prompt: str, system_prompt: str = "") -> str: if not content: raise LLMError("Empty response from LLM") return content.strip() - except LLMError: - raise - except Exception as e: - raise LLMError(f"LLM API call failed: {e}") from e @property def supports_structured_output(self) -> bool: @@ -66,7 +62,7 @@ def supports_structured_output(self) -> bool: def complete_structured(self, prompt: str, schema: dict[str, Any], system_prompt: str = "") -> dict[str, Any]: """Request a JSON-schema-constrained response and return it parsed.""" - try: + with wrap_llm_errors("Structured LLM call failed"): response = self._client.chat.completions.create( # type: ignore[call-overload] model=self._model, messages=self._build_messages(prompt, system_prompt), @@ -78,14 +74,11 @@ def complete_structured(self, prompt: str, schema: dict[str, Any], system_prompt content = response.choices[0].message.content if not content: raise LLMError("Empty response from LLM") - parsed: dict[str, Any] = json.loads(content) + try: + parsed: dict[str, Any] = json.loads(content) + except json.JSONDecodeError as e: + raise LLMError(f"Structured response was not valid JSON: {e}") from e return parsed - except LLMError: - raise - except json.JSONDecodeError as e: - raise LLMError(f"Structured response was not valid JSON: {e}") from e - except Exception as e: - raise LLMError(f"Structured LLM call failed: {e}") from e def _build_messages(self, prompt: str, system_prompt: str) -> list[dict[str, str]]: """Assemble chat messages, prepending the system prompt when present."""