From 78f799dc61bb67cff9336a3539d83edb0b3ece97 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Mon, 15 Jun 2026 13:36:11 +0800 Subject: [PATCH 1/3] refactor: modularize agent-to-agent message routing and fix formatting --- backend/app/services/agent_tools.py | 645 +++++++++++----------------- backend/app/services/llm/caller.py | 6 + backend/app/services/llm/client.py | 2 +- backend/tests/test_a2a_msg_type.py | 514 ++++++++++++---------- 4 files changed, 564 insertions(+), 603 deletions(-) diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index e0058e8b7..e05b5cd9e 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -2001,6 +2001,7 @@ def _patch_computer_tool_descriptions(tools: list[dict], os_type: str) -> list[d # Build the OS-aware description for agentbay_file_transfer new_file_transfer_desc = ( + ( "Transfer a file between any two endpoints: the agent workspace, " "the AgentBay browser environment, the cloud desktop (computer), or the code sandbox.\n\n" f"COMPUTER ENVIRONMENT OS: {computer_os_label}\n" @@ -2015,7 +2016,9 @@ def _patch_computer_tool_descriptions(tools: list[dict], os_type: str) -> list[d "- workspace -> env: upload a workspace file into a cloud environment\n" "- env -> workspace: download a file from a cloud environment into the workspace\n" "- env A -> env B: transfer between environments (transparent backend temp)" - ) if os_type == "windows" else ( + ) + if os_type == "windows" + else ( "Transfer a file between any two endpoints: the agent workspace, " "the AgentBay browser environment, the cloud desktop (computer), or the code sandbox.\n\n" f"COMPUTER ENVIRONMENT OS: {computer_os_label}\n" @@ -2030,6 +2033,7 @@ def _patch_computer_tool_descriptions(tools: list[dict], os_type: str) -> list[d "- workspace -> env: upload a workspace file into a cloud environment\n" "- env -> workspace: download a file from a cloud environment into the workspace\n" "- env A -> env B: transfer between environments (transparent backend temp)" + ) ) patched = [] @@ -7032,22 +7036,33 @@ async def _wake_agent_async(agent_id: uuid.UUID, reason_context: str, *, from_ag await wake_agent_with_context(agent_id, reason_context, **kwargs) -async def _send_message_to_agent( +from dataclasses import dataclass, field + + +@dataclass +class A2AContext: + source_agent: AgentModel + target_agent: AgentModel + chat_session_id: str + session_agent_id: uuid.UUID + owner_id: uuid.UUID + src_participant_id: uuid.UUID | None + tgt_participant_id: uuid.UUID | None + msg_type: str + message_text: str + origin_source_channel: str + origin_session_id: str | None + primary_model: Optional["LLMModel"] = None + fallback_model: Optional["LLMModel"] = None + conversation_history: list[dict] = field(default_factory=list) + + +async def _build_a2a_context( from_agent_id: uuid.UUID, args: dict, user_id: uuid.UUID | None = None, origin_session_id: str | None = None, -) -> str: - """Send a message to another digital employee. - - Behaviour depends on ``msg_type``: - - notify: fire-and-forget — message is saved, target is woken asynchronously. - Returns immediately. - - task_delegate: async with callback — message is saved, source agent sets up - a focus item + on_message trigger so it is notified when the - target completes the task. Returns immediately. - - consult: synchronous request-response (original behaviour). - """ +) -> A2AContext | str: agent_name = args.get("agent_name", "").strip() message_text = args.get("message", "").strip() msg_type = args.get("msg_type", "notify").strip().lower() @@ -7061,7 +7076,6 @@ async def _send_message_to_agent( from app.models.llm import LLMModel from app.services.llm.utils import get_model_api_key - # Phase 1: Setup and database queries under a short-lived session origin_source_channel = "web" async with async_session() as db: @@ -7162,50 +7176,6 @@ async def _send_message_to_agent( await db.flush() session_id = str(chat_session.id) - target_id = target.id - target_name = target.name - target_agent_type = getattr(target, "agent_type", "native") - target_openclaw_last_seen = target.openclaw_last_seen - target_role_description = target.role_description - target_max_tool_rounds = target.max_tool_rounds or 50 - - # ── OpenClaw target: queue message for gateway poll ── - if target_agent_type == "openclaw": - # 1. Save the source message to the chat session - db.add(ChatMessage( - agent_id=session_agent_id, - user_id=owner_id, - role="user", - content=message_text, - conversation_id=session_id, - participant_id=src_participant_id, - )) - chat_session.last_message_at = datetime.now(timezone.utc) - - # 2. Queue for Gateway - from app.models.gateway_message import GatewayMessage as GMsg - gw_msg = GMsg( - agent_id=target_id, - sender_agent_id=from_agent_id, - sender_user_id=owner_id, - content=f"[From {source_name}] {message_text}", - status="pending", - conversation_id=session_id, - ) - db.add(gw_msg) - await db.commit() - - # 3. Log activity - from app.services.activity_logger import log_activity - await log_activity( - from_agent_id, "agent_msg_sent", - f"Sent message to {target_name} (queued)", - detail={"partner": target_name, "message": message_text[:200]}, - ) - - online = target_openclaw_last_seen and (datetime.now(timezone.utc) - target_openclaw_last_seen).total_seconds() < 300 - status_hint = "online" if online else "offline (message will be delivered on next heartbeat)" - return f"✅ Message sent to {target_name} (OpenClaw agent, currently {status_hint}). The message has been queued and will be delivered when the agent polls for updates." # Save source message (common to all paths) db.add(ChatMessage( @@ -7219,6 +7189,21 @@ async def _send_message_to_agent( chat_session.last_message_at = datetime.now(timezone.utc) await db.commit() + if getattr(target, "agent_type", "native") == "openclaw": + return A2AContext( + source_agent=source_agent, + target_agent=target, + chat_session_id=session_id, + session_agent_id=session_agent_id, + owner_id=owner_id, + src_participant_id=src_participant_id, + tgt_participant_id=tgt_participant_id, + msg_type=msg_type, + message_text=message_text, + origin_source_channel=origin_source_channel, + origin_session_id=origin_session_id, + ) + # ── Feature flag: async A2A (tenant-level) ── _a2a_async = False if source_tenant_id: @@ -7234,38 +7219,23 @@ async def _send_message_to_agent( if msg_type in ("notify", "task_delegate"): msg_type = "consult" - # If consult, we need target LLM model details inside the session - target_model_provider = None - target_model_base_url = None - target_model_name = None - target_model_temperature = None - target_model_request_timeout = 120.0 - target_api_key = "" - conversation_messages: list[dict] = [] + primary_model = None + fallback_model = None + conversation_history: list[dict] = [] if msg_type == "consult": # Load primary model - target_model = None if target.primary_model_id: model_r = await db.execute(select(LLMModel).where(LLMModel.id == target.primary_model_id)) - target_model = model_r.scalar_one_or_none() + primary_model = model_r.scalar_one_or_none() # Fallback model - if not target_model and target.fallback_model_id: + if target.fallback_model_id: fb_r = await db.execute(select(LLMModel).where(LLMModel.id == target.fallback_model_id)) - target_model = fb_r.scalar_one_or_none() - if target_model: - logger.warning(f"[A2A] Primary model unavailable for {target_name}, using fallback: {target_model.model}") + fallback_model = fb_r.scalar_one_or_none() - if not target_model: - return f"⚠️ {target_name} has no LLM model configured" - - target_model_provider = target_model.provider - target_model_base_url = target_model.base_url - target_model_name = target_model.model - target_model_temperature = target_model.temperature - target_model_request_timeout = float(getattr(target_model, 'request_timeout', None) or 120.0) - target_api_key = get_model_api_key(target_model) + if not primary_model and not fallback_model: + return f"⚠️ {target.name} has no LLM model configured" # Load recent history for context hist_result = await db.execute( @@ -7282,106 +7252,160 @@ async def _send_message_to_agent( role = "user" else: role = "assistant" - conversation_messages.append({"role": role, "content": m.content}) + conversation_history.append({"role": role, "content": m.content}) - # ── notify: fire-and-forget ── - if msg_type == "notify": - try: - from app.services.activity_logger import log_activity - await log_activity( - from_agent_id, "agent_msg_sent", - f"Sent notification to {target_name}", - detail={"partner": target_name, "message": message_text[:200], "msg_type": "notify"}, - ) - except Exception: - pass + return A2AContext( + source_agent=source_agent, + target_agent=target, + chat_session_id=session_id, + session_agent_id=session_agent_id, + owner_id=owner_id, + src_participant_id=src_participant_id, + tgt_participant_id=tgt_participant_id, + msg_type=msg_type, + message_text=message_text, + origin_source_channel=origin_source_channel, + origin_session_id=origin_session_id, + primary_model=primary_model, + fallback_model=fallback_model, + conversation_history=conversation_history, + ) + except Exception as e: + logger.exception(f"[A2A] _build_a2a_context failed: from={from_agent_id}") + return f"❌ A2A context error ({type(e).__name__}): {str(e)[:200]}" - try: - await _wake_agent_async( - target_id, - f"[From {source_name}] {message_text}", - from_agent_id=from_agent_id, - skip_dedup=True, - a2a_session_id=session_id, - ) - except Exception as e: - logger.warning(f"[A2A] Failed to wake {target_name} for notify: {e}") - return f"✅ Notification sent to {target_name}. They will process it asynchronously." +async def _a2a_handle_openclaw(ctx: A2AContext) -> str: + try: + async with async_session() as db: + # 2. Queue for Gateway + from app.models.gateway_message import GatewayMessage as GMsg + gw_msg = GMsg( + agent_id=ctx.target_agent.id, + sender_agent_id=ctx.source_agent.id, + sender_user_id=ctx.owner_id, + content=f"[From {ctx.source_agent.name}] {ctx.message_text}", + status="pending", + conversation_id=ctx.chat_session_id, + ) + db.add(gw_msg) + await db.commit() + + # 3. Log activity + from app.services.activity_logger import log_activity + await log_activity( + ctx.source_agent.id, "agent_msg_sent", + f"Sent message to {ctx.target_agent.name} (queued)", + detail={"partner": ctx.target_agent.name, "message": ctx.message_text[:200]}, + ) - # ── task_delegate: async with callback ── - if msg_type == "task_delegate": - focus_id = f"wait_{target_name.lower().replace(' ', '_')}_task" - focus_desc = f"Waiting for {target_name} to complete delegated task: {message_text[:100]}" + online = ctx.target_agent.openclaw_last_seen and (datetime.now(timezone.utc) - ctx.target_agent.openclaw_last_seen).total_seconds() < 300 + status_hint = "online" if online else "offline (message will be delivered on next heartbeat)" + return f"✅ Message sent to {ctx.target_agent.name} (OpenClaw agent, currently {status_hint}). The message has been queued and will be delivered when the agent polls for updates." + except Exception as e: + logger.exception(f"[A2A] _a2a_handle_openclaw failed: from={ctx.source_agent.id}, to={ctx.target_agent.id}") + return f"❌ OpenClaw send error ({type(e).__name__}): {str(e)[:200]}" - try: - await _append_focus_item(from_agent_id, focus_id, focus_desc) - except Exception as e: - logger.warning(f"[A2A] Failed to write focus for delegate: {e}") - - trigger_name = f"a2a_wait_{target_name.lower().replace(' ', '_')}" - trigger_reason = ( - f"{target_name} has replied with the result of a delegated task. " - f"Original task: {message_text[:200]}. " - f"Steps: 1) Process {target_name}'s reply. " - f"2) Mark focus item '{focus_id}' as completed. " - f"3) Cancel this trigger. " - f"USER-FACING OUTPUT RULES: Your reply goes directly to the user's chat. " - f"Write in natural, conversational language as if talking to a colleague. " - f"NEVER use technical terms like: trigger name, focus item, a2a_wait, " - f"task_delegate, focus_ref, or any internal identifier. " - f"NEVER mention your internal operations (canceling triggers, updating focus, " - f"marking items complete, trigger status, etc.). " - f"Just summarize the task result in plain language." + +async def _a2a_handle_notify(ctx: A2AContext) -> str: + try: + try: + from app.services.activity_logger import log_activity + await log_activity( + ctx.source_agent.id, "agent_msg_sent", + f"Sent notification to {ctx.target_agent.name}", + detail={"partner": ctx.target_agent.name, "message": ctx.message_text[:200], "msg_type": "notify"}, ) - try: - await _create_on_message_trigger( - agent_id=from_agent_id, - trigger_name=trigger_name, - from_agent_name=target_name, - reason=trigger_reason, - focus_ref=focus_id, - notification_summary=f"等待{target_name}完成任务并回复", - origin_session_id=origin_session_id, - origin_user_id=str(owner_id) if owner_id else None, - origin_source_channel=origin_source_channel, - ) - except Exception as e: - logger.warning(f"[A2A] Failed to create trigger for delegate: {e}") + except Exception: + pass - try: - from app.services.activity_logger import log_activity - await log_activity( - from_agent_id, "agent_msg_sent", - f"Delegated task to {target_name}", - detail={"partner": target_name, "message": message_text[:200], "msg_type": "task_delegate"}, - ) - except Exception: - pass + try: + await _wake_agent_async( + ctx.target_agent.id, + f"[From {ctx.source_agent.name}] {ctx.message_text}", + from_agent_id=ctx.source_agent.id, + skip_dedup=True, + a2a_session_id=ctx.chat_session_id, + ) + except Exception as e: + logger.warning(f"[A2A] Failed to wake {ctx.target_agent.name} for notify: {e}") - try: - await _wake_agent_async( - target_id, - f"[From {source_name}] {message_text}", - from_agent_id=from_agent_id, - skip_dedup=True, - a2a_session_id=session_id, - ) - except Exception as e: - logger.warning(f"[A2A] Failed to wake {target_name} for delegate: {e}") + return f"✅ Notification sent to {ctx.target_agent.name}. They will process it asynchronously." + except Exception as e: + logger.exception(f"[A2A] _a2a_handle_notify failed: from={ctx.source_agent.id}, to={ctx.target_agent.id}") + return f"❌ Notification error ({type(e).__name__}): {str(e)[:200]}" - return f"✅ Task delegated to {target_name}. You will be notified when they complete it." - # ── consult (default): synchronous request-response ── - # Build target system prompt - from app.services.agent_context import build_agent_context - target_static, target_dynamic = await build_agent_context( - target_id, - target_name, - target_role_description or "", - current_user_name=source_name +async def _a2a_handle_task_delegate(ctx: A2AContext) -> str: + try: + focus_id = f"wait_{ctx.target_agent.name.lower().replace(' ', '_')}_task" + focus_desc = f"Waiting for {ctx.target_agent.name} to complete delegated task: {ctx.message_text[:100]}" + + try: + await _append_focus_item(ctx.source_agent.id, focus_id, focus_desc) + except Exception as e: + logger.warning(f"[A2A] Failed to write focus for delegate: {e}") + + trigger_name = f"a2a_wait_{ctx.target_agent.name.lower().replace(' ', '_')}" + trigger_reason = ( + f"{ctx.target_agent.name} has replied with the result of a delegated task. " + f"Original task: {ctx.message_text[:200]}. " + f"Steps: 1) Process {ctx.target_agent.name}'s reply. " + f"2) Mark focus item '{focus_id}' as completed. " + f"3) Cancel this trigger. " + f"USER-FACING OUTPUT RULES: Your reply goes directly to the user's chat. " + f"Write in natural, conversational language as if talking to a colleague. " + f"NEVER use technical terms like: trigger name, focus item, a2a_wait, " + f"task_delegate, focus_ref, or any internal identifier. " + f"NEVER mention your internal operations (canceling triggers, updating focus, " + f"marking items complete, trigger status, etc.). " + f"Just summarize the task result in plain language." ) - target_dynamic += ( + try: + await _create_on_message_trigger( + agent_id=ctx.source_agent.id, + trigger_name=trigger_name, + from_agent_name=ctx.target_agent.name, + reason=trigger_reason, + focus_ref=focus_id, + notification_summary=f"等待{ctx.target_agent.name}完成任务并回复", + origin_session_id=ctx.origin_session_id, + origin_user_id=str(ctx.owner_id) if ctx.owner_id else None, + origin_source_channel=ctx.origin_source_channel, + ) + except Exception as e: + logger.warning(f"[A2A] Failed to create trigger for delegate: {e}") + + try: + from app.services.activity_logger import log_activity + await log_activity( + ctx.source_agent.id, "agent_msg_sent", + f"Delegated task to {ctx.target_agent.name}", + detail={"partner": ctx.target_agent.name, "message": ctx.message_text[:200], "msg_type": "task_delegate"}, + ) + except Exception: + pass + + try: + await _wake_agent_async( + ctx.target_agent.id, + f"[From {ctx.source_agent.name}] {ctx.message_text}", + from_agent_id=ctx.source_agent.id, + skip_dedup=True, + a2a_session_id=ctx.chat_session_id, + ) + except Exception as e: + logger.warning(f"[A2A] Failed to wake {ctx.target_agent.name} for delegate: {e}") + + return f"✅ Task delegated to {ctx.target_agent.name}. You will be notified when they complete it." + except Exception as e: + logger.exception(f"[A2A] _a2a_handle_task_delegate failed: from={ctx.source_agent.id}, to={ctx.target_agent.id}") + return f"❌ Task delegation error ({type(e).__name__}): {str(e)[:200]}" + + +async def _a2a_handle_consult(ctx: A2AContext) -> str: + try: + suffix = ( "\n\n--- Agent-to-Agent Message ---\n" "You are receiving a message from another digital employee. " "Reply concisely and helpfully. Focus on the request and provide a clear answer.\n" @@ -7390,212 +7414,44 @@ async def _send_message_to_agent( "Do NOT output plain text without calling `finish`. " "Plain text responses will be REJECTED and you will be asked to redo.\n" "\n** CRITICAL FILE DELIVERY RULE **\n" - "After you write any file (report, document, analysis, etc.) that the requesting agent needs, " - "you MUST call `send_file_to_agent(agent_name=\"\", file_path=\"\")` " - "to deliver it. The other agent CANNOT access your workspace. " - "Never just tell them the path — always deliver explicitly.\n" - ) - - conversation_messages.append({"role": "user", "content": f"[From {source_name}] {message_text}"}) - - import random - import httpx - from app.services.llm import ( - get_provider_base_url, - create_llm_client, - LLMMessage, - LLMError, + f"After you write any file (report, document, analysis, etc.) that the requesting agent needs, " + f"you MUST call `send_file_to_agent(agent_name=\"{ctx.source_agent.name}\", file_path=\"\")` " + f"to deliver it. The other agent CANNOT access your workspace. " + f"Never just tell them the path — always deliver explicitly.\n" ) - base_url = get_provider_base_url(target_model_provider, target_model_base_url) - if not base_url: - return f"⚠️ {target_name}'s model has no API base URL configured" - - full_msgs: list[LLMMessage] = [LLMMessage(role="system", content=target_static, dynamic_content=target_dynamic)] + [ - LLMMessage(role=m["role"], content=m["content"]) for m in conversation_messages - ] - - # Load tools for target agent - tools_for_llm = await get_agent_tools_for_llm(target_id) - - target_reply = "" - _a2a_accumulated_usage = None - from app.services.token_tracker import ( - TokenUsage, - record_token_usage, - extract_token_usage, - estimate_token_usage_from_chars, + conversation_messages = list(ctx.conversation_history) + conversation_messages.append({"role": "user", "content": f"[From {ctx.source_agent.name}] {ctx.message_text}"}) + + from app.services.llm.caller import call_llm_with_failover + + target_reply = await call_llm_with_failover( + primary_model=ctx.primary_model, + fallback_model=ctx.fallback_model, + messages=conversation_messages, + agent_name=ctx.target_agent.name, + role_description=ctx.target_agent.role_description or "", + agent_id=ctx.target_agent.id, + user_id=ctx.owner_id, + session_id=ctx.chat_session_id, + current_user_name_override=ctx.source_agent.name, + system_prompt_suffix=suffix, ) - _a2a_accumulated_usage = TokenUsage() - - llm_client = create_llm_client( - provider=target_model_provider, - api_key=target_api_key, - model=target_model_name, - base_url=base_url, - timeout=target_model_request_timeout, - ) - _A2A_RETRYABLE_MARKERS = ( - "http 408", "http 429", "http 500", "http 502", "http 503", "http 504", - "timeout", "timed out", "connection failed", "temporarily unavailable", "rate limit", - ) - _A2A_MAX_RETRIES = 3 - - def _is_retryable_llm_error(exc: Exception) -> bool: - """Determine whether an LLM exception is transient and worth retrying.""" - if isinstance(exc, (httpx.TimeoutException, httpx.TransportError)): - return True - if isinstance(exc, LLMError): - lowered = (str(exc) or "").lower() - return any(m in lowered for m in _A2A_RETRYABLE_MARKERS) - return False - - try: - for _round in range(target_max_tool_rounds): - response = None - for attempt in range(1, _A2A_MAX_RETRIES + 1): - try: - response = await llm_client.complete( - messages=full_msgs, - tools=tools_for_llm if tools_for_llm else None, - temperature=target_model_temperature, - max_tokens=4096, - ) - break - except Exception as llm_exc: - if not _is_retryable_llm_error(llm_exc) or attempt >= _A2A_MAX_RETRIES: - raise - - err_text = str(llm_exc) or type(llm_exc).__name__ - backoff = (2 ** (attempt - 1)) + random.uniform(0, 0.5) - logger.warning( - f"[A2A] LLM call failed for {target_name} (round={_round + 1}, " - f"attempt={attempt}/{_A2A_MAX_RETRIES}): {err_text[:200]}. " - f"Retrying in {backoff:.1f}s" - ) - await asyncio.sleep(backoff) - - if response is None: - raise RuntimeError("A2A LLM response is unexpectedly empty after retries") - # Track tokens from API response - usage = extract_token_usage(response.usage) - if usage: - _a2a_accumulated_usage.add(usage) - else: - round_chars = sum(len(m.content or '') for m in full_msgs if isinstance(m.content, str)) - _a2a_accumulated_usage.add(estimate_token_usage_from_chars(round_chars)) - - # Check for tool calls - if response.tool_calls: - # Add assistant message with tool calls to conversation - full_msgs.append(LLMMessage( - role="assistant", - content=response.content or None, - tool_calls=[{ - "id": tc.get("id", ""), - "type": "function", - "function": tc.get("function", {}), - } for tc in response.tool_calls], - reasoning_content=response.reasoning_content, - )) - - finish_call = find_finish_call(response.tool_calls) - if finish_call: - if finish_call.valid: - target_reply = finish_call.content - break - full_msgs.append(LLMMessage( - role="tool", - tool_call_id=finish_call.call_id, - content=finish_call.error or "`finish` was invalid.", - )) - continue - - # Execute each tool call - for tc in response.tool_calls: - fn = tc.get("function", {}) - tool_name = fn.get("name", "") - raw_args = fn.get("arguments", "{}") - try: - tool_args = parse_tool_arguments(raw_args) - except Exception as parse_exc: - logger.warning(f"[A2A] Invalid tool arguments for {tool_name}: {parse_exc}") - tool_result = ( - f"❌ Invalid JSON arguments for `{tool_name}`: {parse_exc}. " - "DO NOT retry with the same content. Please fix the JSON encoding: " - "escape all double quotes inside string values as \\\" and all newlines as \\n." - ) - # Add tool result to conversation - full_msgs.append(LLMMessage( - role="tool", - tool_call_id=tc.get("id", ""), - content=str(tool_result), - )) - continue - - tool_result = await execute_tool(tool_name, tool_args, target_id, owner_id) - - # Nudge: after write_file in A2A, remind to deliver via send_file_to_agent - if tool_name == "write_file" and isinstance(tool_result, str) and tool_result.startswith("\u2705"): - wrote_path = tool_args.get("path", "") - tool_result += ( - f"\n\n⚠️ REMINDER: The requesting agent ({source_name}) cannot access your workspace. " - f"You MUST now call `send_file_to_agent(agent_name=\"{source_name}\", file_path=\"{wrote_path}\")` " - f"to deliver this file to them." - ) - - # Save tool_call to DB so it appears in chat history - try: - async with async_session() as _tc_db: - _tc_db.add(ChatMessage( - agent_id=session_agent_id, - user_id=owner_id, - role="tool_call", - content=json.dumps({ - "name": tool_name, - "args": tool_args, - "status": "done", - "result": str(tool_result)[:500], - }, ensure_ascii=False), - conversation_id=session_id, - participant_id=tgt_participant_id, - )) - await _tc_db.commit() - except Exception as _tc_err: - logger.error(f"[A2A] Failed to save tool_call: {_tc_err}") - - # Add tool result to conversation - full_msgs.append(LLMMessage( - role="tool", - tool_call_id=tc.get("id", ""), - content=str(tool_result)[:4000], - )) - continue # Next LLM round - - if response.content: - full_msgs.append(LLMMessage(role="assistant", content=response.content)) - full_msgs.append(LLMMessage(role="user", content=FINISH_PROTOCOL_REMINDER)) - finally: - await llm_client.close() - - # Record accumulated A2A tokens for the target agent - if _a2a_accumulated_usage and _a2a_accumulated_usage.total_tokens > 0: - await record_token_usage(target_id, _a2a_accumulated_usage) - - if not target_reply: - return f"⚠️ {target_name} did not respond (LLM returned empty)" + if not target_reply or target_reply.startswith("⚠️") or target_reply.startswith("[Error]") or target_reply.startswith("[LLM Error]") or target_reply.startswith("[LLM call error]"): + return target_reply or f"⚠️ {ctx.target_agent.name} did not respond (LLM returned empty)" # Save target reply async with async_session() as db2: - part_r = await db2.execute(select(Participant).where(Participant.type == "agent", Participant.ref_id == target_id)) + from app.models.participant import Participant + part_r = await db2.execute(select(Participant).where(Participant.type == "agent", Participant.ref_id == ctx.target_agent.id)) tgt_part = part_r.scalar_one_or_none() db2.add(ChatMessage( - agent_id=session_agent_id, - user_id=owner_id, + agent_id=ctx.session_agent_id, + user_id=ctx.owner_id, role="assistant", content=target_reply, - conversation_id=session_id, + conversation_id=ctx.chat_session_id, participant_id=tgt_part.id if tgt_part else None, )) await db2.commit() @@ -7603,31 +7459,52 @@ def _is_retryable_llm_error(exc: Exception) -> bool: # Log activity from app.services.activity_logger import log_activity await log_activity( - target_id, "agent_msg_sent", - f"Replied to message from {source_name}", - detail={"partner": source_name, "message": message_text[:200], "reply": target_reply[:200]}, + ctx.target_agent.id, "agent_msg_sent", + f"Replied to message from {ctx.source_agent.name}", + detail={"partner": ctx.source_agent.name, "message": ctx.message_text[:200], "reply": target_reply[:200]}, ) await log_activity( - from_agent_id, "agent_msg_sent", - f"Sent message to {target_name} and received reply", - detail={"partner": target_name, "message": message_text[:200], "reply": target_reply[:200]}, + ctx.source_agent.id, "agent_msg_sent", + f"Sent message to {ctx.target_agent.name} and received reply", + detail={"partner": ctx.target_agent.name, "message": ctx.message_text[:200], "reply": target_reply[:200]}, ) - return f"💬 {target_name} replied:\n{target_reply}" + return f"💬 {ctx.target_agent.name} replied:\n{target_reply}" except Exception as e: - logger.exception( - f"[A2A] send_message_to_agent failed: from={from_agent_id}, to={args.get('agent_name', '')}" - ) - error_type = type(e).__name__ - error_detail = (str(e) or "").strip() - if not error_detail: - timeout_types = {"ReadTimeout", "ConnectTimeout", "TimeoutException"} - if error_type in timeout_types: - error_detail = "LLM request timed out while waiting for target agent response" - else: - error_detail = "No detailed error message returned from upstream" - return f"❌ Message send error ({error_type}): {error_detail[:200]}" + logger.exception(f"[A2A] _a2a_handle_consult failed: from={ctx.source_agent.id}, to={ctx.target_agent.id}") + return f"❌ Consult request error ({type(e).__name__}): {str(e)[:200]}" + + +async def _send_message_to_agent( + from_agent_id: uuid.UUID, + args: dict, + user_id: uuid.UUID | None = None, + origin_session_id: str | None = None, +) -> str: + """Send a message to another digital employee. + + Behaviour depends on ``msg_type``: + - notify: fire-and-forget — message is saved, target is woken asynchronously. + Returns immediately. + - task_delegate: async with callback — message is saved, source agent sets up + a focus item + on_message trigger so it is notified when the + target completes the task. Returns immediately. + - consult: synchronous request-response (original behaviour). + """ + ctx_or_err = await _build_a2a_context(from_agent_id, args, user_id, origin_session_id) + if isinstance(ctx_or_err, str): + return ctx_or_err + ctx = ctx_or_err + + if ctx.target_agent.agent_type == "openclaw": + return await _a2a_handle_openclaw(ctx) + if ctx.msg_type == "notify": + return await _a2a_handle_notify(ctx) + if ctx.msg_type == "task_delegate": + return await _a2a_handle_task_delegate(ctx) + return await _a2a_handle_consult(ctx) + diff --git a/backend/app/services/llm/caller.py b/backend/app/services/llm/caller.py index 2dc66faa8..ebf4774e0 100644 --- a/backend/app/services/llm/caller.py +++ b/backend/app/services/llm/caller.py @@ -425,6 +425,7 @@ async def call_llm( skip_tools: bool = False, on_code_output=None, current_user_name_override: str | None = None, + system_prompt_suffix: str | None = None, ) -> str: """Call LLM via unified client with function-calling tool loop.""" # Get agent config for tool rounds @@ -462,6 +463,8 @@ async def _default_on_tool_call(data: dict): from app.services.agent_context import build_agent_context # Look up current user's display name so the agent knows who it's talking to static_prompt, dynamic_prompt = await build_agent_context(agent_id, agent_name, role_description, current_user_name=_user_name) + if system_prompt_suffix: + dynamic_prompt += system_prompt_suffix # Load tools dynamically from DB. `skip_tools=True` is set by the WS # handler on the onboarding greeting turn; keep the runtime-level `finish` @@ -657,6 +660,7 @@ async def call_llm_with_failover( skip_tools: bool = False, on_code_output=None, current_user_name_override: str | None = None, + system_prompt_suffix: str | None = None, ) -> str: """Call LLM with automatic failover support.""" guard = FailoverGuard() @@ -699,6 +703,7 @@ async def _wrapped_on_tool_call(data: dict): skip_tools=skip_tools, on_code_output=on_code_output, current_user_name_override=current_user_name_override, + system_prompt_suffix=system_prompt_suffix, ) # Check if we need to failover @@ -763,6 +768,7 @@ async def _fallback_on_tool_call(data: dict): skip_tools=skip_tools, on_code_output=on_code_output, current_user_name_override=current_user_name_override, + system_prompt_suffix=system_prompt_suffix, ) # Combine error messages if fallback also failed diff --git a/backend/app/services/llm/client.py b/backend/app/services/llm/client.py index 56701d713..1cdcfeda4 100644 --- a/backend/app/services/llm/client.py +++ b/backend/app/services/llm/client.py @@ -2130,7 +2130,7 @@ def get_max_tokens(provider: str, model: str | None = None, max_output_tokens: i model_limits = spec.model_max_tokens if spec else MAX_TOKENS_BY_MODEL # Highest priority: per-model DB override - if max_output_tokens and max_output_tokens > 0: + if isinstance(max_output_tokens, int) and max_output_tokens > 0: return max_output_tokens # Check model-specific limits diff --git a/backend/tests/test_a2a_msg_type.py b/backend/tests/test_a2a_msg_type.py index 2df3b337b..6cefcb4d6 100644 --- a/backend/tests/test_a2a_msg_type.py +++ b/backend/tests/test_a2a_msg_type.py @@ -66,8 +66,9 @@ async def flush(self): self.flushed = True -def _make_agent(agent_id=None, name="TestAgent", tenant_id=None, agent_type="native", - expired=False, primary_model_id=None): +def _make_agent( + agent_id=None, name="TestAgent", tenant_id=None, agent_type="native", expired=False, primary_model_id=None +): agent = MagicMock() agent.id = agent_id or uuid.uuid4() agent.name = name @@ -99,6 +100,7 @@ def _make_tenant(a2a_async_enabled=True): # ── Tests ──────────────────────────────────────────────────────────── + @pytest.mark.asyncio async def test_notify_returns_immediately(): """notify msg_type should return immediately without calling LLM.""" @@ -117,27 +119,33 @@ async def test_notify_returns_immediately(): session.id = session_id session.last_message_at = None - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - DummyResult(scalar_value=_make_tenant()), - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake: + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + DummyResult(scalar_value=_make_tenant()), + ] + ) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake, + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "Please review the document", - "msg_type": "notify", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "Please review the document", + "msg_type": "notify", + }, + ) assert "Notification sent to Bob" in result assert "asynchronously" in result @@ -162,29 +170,35 @@ async def test_task_delegate_creates_focus_and_trigger(): session.id = session_id session.last_message_at = None - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - DummyResult(scalar_value=_make_tenant()), - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools._append_focus_item", new_callable=AsyncMock) as mock_focus, \ - patch("app.services.agent_tools._create_on_message_trigger", new_callable=AsyncMock) as mock_trigger, \ - patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake: + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + DummyResult(scalar_value=_make_tenant()), + ] + ) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools._append_focus_item", new_callable=AsyncMock) as mock_focus, + patch("app.services.agent_tools._create_on_message_trigger", new_callable=AsyncMock) as mock_trigger, + patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake, + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "Please prepare the Q3 report", - "msg_type": "task_delegate", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "Please prepare the Q3 report", + "msg_type": "task_delegate", + }, + ) assert "Task delegated to Bob" in result assert "notified when they complete" in result @@ -230,35 +244,42 @@ async def test_consult_calls_llm_synchronously(): response = MagicMock() response.content = "" - response.tool_calls = [{ - "id": "call_finish", - "type": "function", - "function": { - "name": "finish", - "arguments": json.dumps({"content": "Here is the answer"}), - }, - }] + response.tool_calls = [ + { + "id": "call_finish", + "type": "function", + "function": { + "name": "finish", + "arguments": json.dumps({"content": "Here is the answer"}), + }, + } + ] response.usage = None mock_llm_client = AsyncMock() mock_llm_client.complete = AsyncMock(return_value=response) + mock_llm_client.stream = AsyncMock(return_value=response) mock_llm_client.close = AsyncMock() - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - DummyResult(scalar_value=_make_tenant()), - DummyResult(scalar_value=model), - DummyResult(scalars_list=[]), - ]) - - db2 = RecordingDB(responses=[ - DummyResult(scalar_value=tgt_participant), - ]) + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + DummyResult(scalar_value=_make_tenant()), + DummyResult(scalar_value=model), + DummyResult(scalars_list=[]), + ] + ) + + db2 = RecordingDB( + responses=[ + DummyResult(scalar_value=tgt_participant), + ] + ) call_count = 0 session_dbs = [db, db2] @@ -269,29 +290,37 @@ async def mock_session_enter(self): call_count += 1 return result - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_context.build_agent_context", new_callable=AsyncMock, return_value=("static", "dynamic")), \ - patch("app.services.llm.create_llm_client", return_value=mock_llm_client), \ - patch("app.services.agent_tools.get_agent_tools_for_llm", new_callable=AsyncMock, return_value=[]), \ - patch("app.services.llm.get_provider_base_url", return_value="https://api.openai.com/v1"), \ - patch("app.services.token_tracker.record_token_usage", new_callable=AsyncMock), \ - patch("app.services.activity_logger.log_activity", new_callable=AsyncMock): - - mock_session_ctx.return_value.__aenter__ = AsyncMock(side_effect=[ - db, - db2, - ]) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch( + "app.services.agent_context.build_agent_context", new_callable=AsyncMock, return_value=("static", "dynamic") + ), + patch("app.services.llm.caller.create_llm_client", return_value=mock_llm_client), + patch("app.services.agent_tools.get_agent_tools_for_llm", new_callable=AsyncMock, return_value=[]), + patch("app.services.llm.get_provider_base_url", return_value="https://api.openai.com/v1"), + patch("app.services.token_tracker.record_token_usage", new_callable=AsyncMock), + patch("app.services.activity_logger.log_activity", new_callable=AsyncMock), + ): + mock_session_ctx.return_value.__aenter__ = AsyncMock( + side_effect=[ + db, + db2, + ] + ) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "What is 2+2?", - "msg_type": "consult", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "What is 2+2?", + "msg_type": "consult", + }, + ) assert "Bob replied" in result assert "Here is the answer" in result - mock_llm_client.complete.assert_awaited() + mock_llm_client.stream.assert_awaited() @pytest.mark.asyncio @@ -312,26 +341,32 @@ async def test_default_msg_type_is_notify(): session.id = session_id session.last_message_at = None - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - DummyResult(scalar_value=_make_tenant()), - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake: + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + DummyResult(scalar_value=_make_tenant()), + ] + ) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake, + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "Heads up about the meeting", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "Heads up about the meeting", + }, + ) assert "Notification sent" in result mock_wake.assert_awaited_once() @@ -342,10 +377,13 @@ async def test_missing_agent_name_returns_error(): """Missing agent_name should return an error.""" from app.services.agent_tools import _send_message_to_agent - result = await _send_message_to_agent(uuid.uuid4(), { - "agent_name": "", - "message": "Hello", - }) + result = await _send_message_to_agent( + uuid.uuid4(), + { + "agent_name": "", + "message": "Hello", + }, + ) assert "❌" in result @@ -362,23 +400,28 @@ async def test_no_relationship_returns_error(): src_participant = _make_participant(ref_id=from_agent_id) tgt_participant = _make_participant(ref_id=target_id) - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=None), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - ]) + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=None), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + ] + ) with patch("app.services.agent_tools.async_session") as mock_session_ctx: mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "Hello", - "msg_type": "notify", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "Hello", + "msg_type": "notify", + }, + ) assert "do not have a relationship" in result @@ -401,12 +444,16 @@ async def test_create_on_message_trigger(): agent_id = uuid.uuid4() - snap_db = RecordingDB(responses=[ - DummyResult(scalar_value=None), - ]) - trigger_db = RecordingDB(responses=[ - DummyResult(scalar_value=None), - ]) + snap_db = RecordingDB( + responses=[ + DummyResult(scalar_value=None), + ] + ) + trigger_db = RecordingDB( + responses=[ + DummyResult(scalar_value=None), + ] + ) enter_count = 0 dbs = [snap_db, trigger_db] @@ -417,8 +464,10 @@ async def _enter(): enter_count += 1 return db - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools.ensure_focus_item", new_callable=AsyncMock) as mock_ensure: + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools.ensure_focus_item", new_callable=AsyncMock) as mock_ensure, + ): mock_ensure.return_value = "test_focus" mock_session_ctx.return_value.__aenter__ = AsyncMock(side_effect=_enter) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) @@ -462,12 +511,16 @@ async def test_create_on_message_trigger_resets_fire_count(): max_fires=1, ) - snap_db = RecordingDB(responses=[ - DummyResult(scalar_value=None), - ]) - trigger_db = RecordingDB(responses=[ - DummyResult(scalar_value=existing_trigger), - ]) + snap_db = RecordingDB( + responses=[ + DummyResult(scalar_value=None), + ] + ) + trigger_db = RecordingDB( + responses=[ + DummyResult(scalar_value=existing_trigger), + ] + ) enter_count = 0 dbs = [snap_db, trigger_db] @@ -478,8 +531,10 @@ async def _enter(): enter_count += 1 return db - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools.ensure_focus_item", new_callable=AsyncMock) as mock_ensure: + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools.ensure_focus_item", new_callable=AsyncMock) as mock_ensure, + ): mock_ensure.return_value = "new_focus" mock_session_ctx.return_value.__aenter__ = AsyncMock(side_effect=_enter) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) @@ -531,26 +586,32 @@ async def test_openclaw_target_still_queues(): session.id = session_id session.last_message_at = None - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.activity_logger.log_activity", new_callable=AsyncMock): + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + ] + ) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.activity_logger.log_activity", new_callable=AsyncMock), + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "OpenClawBot", - "message": "Hello", - "msg_type": "notify", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "OpenClawBot", + "message": "Hello", + "msg_type": "notify", + }, + ) assert "OpenClaw agent" in result assert "queued" in result @@ -589,52 +650,63 @@ async def test_feature_flag_off_falls_back_to_consult(): response = MagicMock() response.content = "" - response.tool_calls = [{ - "id": "call_finish", - "type": "function", - "function": { - "name": "finish", - "arguments": json.dumps({"content": "Got it"}), - }, - }] + response.tool_calls = [ + { + "id": "call_finish", + "type": "function", + "function": { + "name": "finish", + "arguments": json.dumps({"content": "Got it"}), + }, + } + ] response.usage = None mock_llm_client = AsyncMock() mock_llm_client.complete = AsyncMock(return_value=response) + mock_llm_client.stream = AsyncMock(return_value=response) mock_llm_client.close = AsyncMock() - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - DummyResult(scalar_value=tenant), - DummyResult(scalar_value=model), - DummyResult(scalars_list=[]), - ]) - - db2 = RecordingDB(responses=[ - DummyResult(scalar_value=tgt_participant), - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_context.build_agent_context", new_callable=AsyncMock, return_value=("s", "d")), \ - patch("app.services.llm.create_llm_client", return_value=mock_llm_client), \ - patch("app.services.agent_tools.get_agent_tools_for_llm", new_callable=AsyncMock, return_value=[]), \ - patch("app.services.llm.get_provider_base_url", return_value="https://api.openai.com/v1"), \ - patch("app.services.token_tracker.record_token_usage", new_callable=AsyncMock), \ - patch("app.services.activity_logger.log_activity", new_callable=AsyncMock): + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + DummyResult(scalar_value=tenant), + DummyResult(scalar_value=model), + DummyResult(scalars_list=[]), + ] + ) + db2 = RecordingDB( + responses=[ + DummyResult(scalar_value=tgt_participant), + ] + ) + + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_context.build_agent_context", new_callable=AsyncMock, return_value=("s", "d")), + patch("app.services.llm.caller.create_llm_client", return_value=mock_llm_client), + patch("app.services.agent_tools.get_agent_tools_for_llm", new_callable=AsyncMock, return_value=[]), + patch("app.services.llm.get_provider_base_url", return_value="https://api.openai.com/v1"), + patch("app.services.token_tracker.record_token_usage", new_callable=AsyncMock), + patch("app.services.activity_logger.log_activity", new_callable=AsyncMock), + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(side_effect=[db, db2]) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "Hello", - "msg_type": "notify", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "Hello", + "msg_type": "notify", + }, + ) assert "Bob replied" in result assert "Got it" in result @@ -662,27 +734,33 @@ async def test_feature_flag_on_uses_notify(): session.id = session_id session.last_message_at = None - db = RecordingDB(responses=[ - DummyResult(scalar_value=source_agent), - DummyResult(scalars_list=[target_agent]), - DummyResult(scalar_value=rel_id), - DummyResult(scalar_value=src_participant), - DummyResult(scalar_value=tgt_participant), - DummyResult(scalar_value=session), - DummyResult(scalar_value=tenant), - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake: + db = RecordingDB( + responses=[ + DummyResult(scalar_value=source_agent), + DummyResult(scalars_list=[target_agent]), + DummyResult(scalar_value=rel_id), + DummyResult(scalar_value=src_participant), + DummyResult(scalar_value=tgt_participant), + DummyResult(scalar_value=session), + DummyResult(scalar_value=tenant), + ] + ) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools._wake_agent_async", new_callable=AsyncMock) as mock_wake, + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - result = await _send_message_to_agent(from_agent_id, { - "agent_name": "Bob", - "message": "Hello", - "msg_type": "notify", - }) + result = await _send_message_to_agent( + from_agent_id, + { + "agent_name": "Bob", + "message": "Hello", + "msg_type": "notify", + }, + ) assert "Notification sent" in result mock_wake.assert_awaited_once() @@ -710,14 +788,18 @@ async def test_handle_set_trigger_resets_fire_count(): max_fires=1, ) - db = RecordingDB(responses=[ - DummyResult(scalar_value=agent_mock), # Load agent to get per-agent trigger limit - DummyResult(scalar_value=0), # Check max triggers (count) - DummyResult(scalar_value=existing_trigger), # Check for duplicate name - ]) + db = RecordingDB( + responses=[ + DummyResult(scalar_value=agent_mock), # Load agent to get per-agent trigger limit + DummyResult(scalar_value=0), # Check max triggers (count) + DummyResult(scalar_value=existing_trigger), # Check for duplicate name + ] + ) - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.agent_tools.ensure_focus_item", new_callable=AsyncMock) as mock_ensure: + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.agent_tools.ensure_focus_item", new_callable=AsyncMock) as mock_ensure, + ): mock_ensure.return_value = "new_focus" mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) @@ -748,14 +830,17 @@ async def test_execute_tool_failure_writes_system_message(): session_id = str(uuid.uuid4()) tenant_id = uuid.uuid4() - db = RecordingDB(responses=[ - DummyResult(scalar_value=tenant_id), # tenant_id - DummyResult(scalar_value=None), # query in _send_channel_message (returns empty -> fails) - ]) - - with patch("app.services.agent_tools.async_session") as mock_session_ctx, \ - patch("app.services.activity_logger.log_activity", new_callable=AsyncMock): + db = RecordingDB( + responses=[ + DummyResult(scalar_value=tenant_id), # tenant_id + DummyResult(scalar_value=None), # query in _send_channel_message (returns empty -> fails) + ] + ) + with ( + patch("app.services.agent_tools.async_session") as mock_session_ctx, + patch("app.services.activity_logger.log_activity", new_callable=AsyncMock), + ): mock_session_ctx.return_value.__aenter__ = AsyncMock(return_value=db) mock_session_ctx.return_value.__aexit__ = AsyncMock(return_value=False) @@ -775,16 +860,9 @@ async def test_execute_tool_failure_writes_system_message(): assert result.startswith("❌") assert db.committed assert len(db.added) == 1 - + error_msg = db.added[0] assert error_msg.conversation_id == session_id assert error_msg.role == "assistant" assert "系统提示" in error_msg.content assert "send_channel_message" in error_msg.content - - - - - - - From 13e8d27d09b89c59291c16cfbe09bccaa646e39e Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Mon, 15 Jun 2026 15:49:27 +0800 Subject: [PATCH 2/3] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=89=80=E6=9C=89?= =?UTF-8?q?channel=E7=9A=84DB=E4=BC=9A=E8=AF=9D=E9=9A=94=E7=A6=BB=E4=B8=8E?= =?UTF-8?q?Identity=20FK=E8=BF=9D=E5=8F=8D=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 问题1:ForeignKeyViolationError(users.identity_id → identities.id) 根因:channel_user_service._create_channel_user() 调用 registration_service.find_or_create_identity(),后者通过 identity_dao.create_identity() 开启独立的 async with self.session()。 由于后台channel handler 直接使用 async_session()(未经 Depends(get_db)), _session_ctx 未设置,DAO 打开新 session 执行 flush 后退出时被 rollback, Identity 行消失,随后 User INSERT 的 FK 引用失效。 修复:_create_channel_user 和 get_platform_user_by_org_member 改为 直接在传入的 db session 上创建 Identity(db.add + flush), 确保 Identity 和 User 在同一事务内,彻底解决 FK 违反。 ## 问题2:LLM调用期间长期占用DB连接 所有channel在调用LLM时仍持有数据库连接(30~180s),耗尽连接池。 修复:所有channel统一采用三段式事务模式: - Phase 1:短事务加载数据、保存用户消息 → commit + close - Phase 2:调用LLM(_call_llm_with_config,无DB session) - Phase 3:新短事务(async_session)保存回复 ## 修复文件清单 - database.py:get_db 正确设置/重置 _session_ctx ContextVar - tenants.py:修复 bind_org_member 调用签名 - enterprise.py:修复 get_email_templates 调用签名 - feishu.py:修复图片处理路径的会话泄漏 - teams.py:修复webhook持有Depends(get_db)连接跨LLM调用 - whatsapp.py:修复LLM调用期间持有连接 - wechat_channel.py:修复LLM调用期间持有连接 - wecom_stream.py:修复LLM调用期间持有连接 - discord_gateway.py:修复LLM调用期间持有连接 - channel_user_service.py:根治Identity FK违反,同session创建Identity+User --- backend/app/api/enterprise.py | 2 +- backend/app/api/feishu.py | 33 +++---- backend/app/api/teams.py | 34 +++++--- backend/app/api/tenants.py | 8 +- backend/app/api/whatsapp.py | 30 +++++-- backend/app/database.py | 3 + backend/app/services/channel_user_service.py | 91 +++++++++++++++----- backend/app/services/discord_gateway.py | 46 ++++++---- backend/app/services/wechat_channel.py | 75 +++++++++------- backend/app/services/wecom_stream.py | 57 +++++++----- 10 files changed, 249 insertions(+), 130 deletions(-) diff --git a/backend/app/api/enterprise.py b/backend/app/api/enterprise.py index b886ba81c..788ee1261 100644 --- a/backend/app/api/enterprise.py +++ b/backend/app/api/enterprise.py @@ -637,7 +637,7 @@ async def get_email_templates_endpoint( DEFAULT_EMAIL_TEMPLATES, ) - templates = await get_email_templates(db=db) + templates = await get_email_templates() return { "templates": templates, "variables": EMAIL_TEMPLATE_VARIABLES, diff --git a/backend/app/api/feishu.py b/backend/app/api/feishu.py index f789a2b78..303b6cd76 100644 --- a/backend/app/api/feishu.py +++ b/backend/app/api/feishu.py @@ -1405,8 +1405,11 @@ async def _handle_feishu_file( ) _history = convert_chat_messages_to_llm_format(reversed(_hist_r.scalars().all())) - await db.commit() + # Pre-load agent/model for LLM call before releasing DB connection + _agent_model_img, _llm_model_img, _fallback_model_img = await _load_agent_and_model(db, agent_id) + await db.commit() + # ── Phase 1 complete: release connection before slow LLM/HTTP work ── # For images: call LLM so vision models can actually see the image if msg_type == "image": import json as _json_card_img @@ -1499,20 +1502,20 @@ async def _img_heartbeat(): _img_heartbeat_task = asyncio.create_task(_img_heartbeat()) # Call LLM with image marker — vision models will parse it - async with _async_session() as _db_img: - try: - reply_text = await _call_agent_llm( - _db_img, agent_id, user_msg_content, history=_history, - user_id=platform_user_id, session_id=session_conv_id, on_chunk=_img_on_chunk, - ) - finally: - _img_llm_done = True - if _img_heartbeat_task: - _img_heartbeat_task.cancel() - try: - await _img_heartbeat_task - except Exception: - pass + try: + reply_text = await _call_llm_with_config( + _agent_model_img, _llm_model_img, _fallback_model_img, + agent_id, user_msg_content, history=_history, + user_id=platform_user_id, session_id=session_conv_id, on_chunk=_img_on_chunk, + ) + finally: + _img_llm_done = True + if _img_heartbeat_task: + _img_heartbeat_task.cancel() + try: + await _img_heartbeat_task + except Exception: + pass logger.info(f"[Feishu] Image LLM reply: {reply_text[:100]}") diff --git a/backend/app/api/teams.py b/backend/app/api/teams.py index ce7bc1b7e..4853cde43 100644 --- a/backend/app/api/teams.py +++ b/backend/app/api/teams.py @@ -17,14 +17,14 @@ from app.config import get_settings from app.core.permissions import check_agent_access, is_agent_creator from app.core.security import get_current_user -from app.database import get_db +from app.database import async_session as _async_session, get_db from app.models.agent import Agent as AgentModel from app.models.audit import ChatMessage from app.models.channel_config import ChannelConfig from app.models.user import User from app.schemas.schemas import ChannelConfigOut from app.services.channel_session import find_or_create_channel_session -from app.api.feishu import _call_agent_llm +from app.api.feishu import _call_llm_with_config, _load_agent_and_model from app.services.agent_tools import channel_file_sender as _cfs_s from app.core.security import hash_password as _hp from pathlib import Path as _Path @@ -483,7 +483,13 @@ async def teams_event_webhook( # Save user message db.add(ChatMessage(agent_id=agent_id, user_id=platform_user_id, role="user", content=user_text, conversation_id=session_conv_id)) sess.last_message_at = datetime.now(timezone.utc) + + # Pre-load agent/model for LLM call before releasing DB connection + _agent_model, _llm_model, _fallback_model = await _load_agent_and_model(db, agent_id) + await db.commit() + # ── Phase 1 complete: release connection before slow LLM call ── + await db.close() # Set channel_file_sender contextvar for agent → user file delivery async def _teams_file_sender(file_path, msg: str = ""): @@ -503,32 +509,38 @@ async def _teams_file_sender(file_path, msg: str = ""): _cfs_s_token = _cfs_s.set(_teams_file_sender) - # Call LLM + # Call LLM (no DB session needed) try: - reply_text = await _call_agent_llm( - db, + reply_text = await _call_llm_with_config( + _agent_model, _llm_model, _fallback_model, agent_id, user_text, history=history, user_id=platform_user_id, session_id=session_conv_id, ) - _cfs_s.reset(_cfs_s_token) logger.info(f"Teams: LLM reply generated: {reply_text[:80]}") except Exception as e: logger.exception(f"Teams: Failed to call LLM for agent {agent_id}: {e}") reply_text = "Sorry, I encountered an error processing your message." + finally: _cfs_s.reset(_cfs_s_token) - # Save reply + # Save reply (new short transaction) try: - db.add(ChatMessage(agent_id=agent_id, user_id=platform_user_id, role="assistant", content=reply_text, conversation_id=session_conv_id)) - sess.last_message_at = datetime.now(timezone.utc) - await db.commit() + async with _async_session() as _save_db: + _save_db.add(ChatMessage(agent_id=agent_id, user_id=platform_user_id, role="assistant", content=reply_text, conversation_id=session_conv_id)) + from app.models.chat_session import ChatSession + _sess_r = await _save_db.execute( + select(ChatSession).where(ChatSession.id == uuid.UUID(session_conv_id)) + ) + _sess_fresh = _sess_r.scalar_one_or_none() + if _sess_fresh: + _sess_fresh.last_message_at = datetime.now(timezone.utc) + await _save_db.commit() logger.info(f"Teams: Saved reply to database for conversation {conversation_id}") except Exception as e: logger.exception(f"Teams: Failed to save reply to database: {e}") - await db.rollback() # Send to Teams use_managed_identity = config.extra_config.get("use_managed_identity", False) diff --git a/backend/app/api/tenants.py b/backend/app/api/tenants.py index 2bfbf9cfd..c96d60f27 100644 --- a/backend/app/api/tenants.py +++ b/backend/app/api/tenants.py @@ -213,7 +213,7 @@ async def self_create_company( avatar_url=new_user.avatar_url, )) await db.flush() - await registration_service.bind_org_member(db, new_user) + await registration_service.bind_org_member(new_user) # Generate token scoped to the new user so frontend can switch context access_token = create_access_token(str(new_user.id), new_user.role) @@ -227,7 +227,7 @@ async def self_create_company( current_user.quota_max_agents = tenant.default_max_agents current_user.quota_agent_ttl_hours = tenant.default_agent_ttl_hours await db.flush() - await registration_service.bind_org_member(db, current_user) + await registration_service.bind_org_member(current_user) await db.commit() @@ -341,7 +341,7 @@ async def join_company( avatar_url=new_user.avatar_url, )) await db.flush() - await registration_service.bind_org_member(db, new_user) + await registration_service.bind_org_member(new_user) # Generate token scoped to the new user so frontend can switch context access_token = create_access_token(str(new_user.id), new_user.role) @@ -358,7 +358,7 @@ async def join_company( current_user.quota_agent_ttl_hours = tenant.default_agent_ttl_hours final_role = current_user.role await db.flush() - await registration_service.bind_org_member(db, current_user) + await registration_service.bind_org_member(current_user) # Increment invitation code usage code_obj.used_count += 1 diff --git a/backend/app/api/whatsapp.py b/backend/app/api/whatsapp.py index e51bd9280..f1a7eb526 100644 --- a/backend/app/api/whatsapp.py +++ b/backend/app/api/whatsapp.py @@ -265,11 +265,12 @@ async def whatsapp_event_webhook( if not user_text or not sender_phone: continue - from app.api.feishu import _call_agent_llm + from app.api.feishu import _call_llm_with_config, _load_agent_and_model from app.models.agent import Agent as AgentModel, DEFAULT_CONTEXT_WINDOW_SIZE from app.models.audit import ChatMessage from app.services.channel_session import find_or_create_channel_session from app.services.channel_user_service import channel_user_service + from app.database import async_session as _async_session agent_r = await db.execute(select(AgentModel).where(AgentModel.id == agent_id)) agent_obj = agent_r.scalar_one_or_none() @@ -305,11 +306,17 @@ async def whatsapp_event_webhook( db.add(ChatMessage(agent_id=agent_id, user_id=platform_user_id, role="user", content=user_text, conversation_id=session_conv_id)) sess.last_message_at = datetime.now(timezone.utc) + + # Pre-load agent/model before releasing connection + _agent_model, _llm_model, _fallback_model = await _load_agent_and_model(db, agent_id) + await db.commit() + await db.close() + # ── Phase 1 complete: release connection before slow LLM call ── try: - reply_text = await _call_agent_llm( - db, + reply_text = await _call_llm_with_config( + _agent_model, _llm_model, _fallback_model, agent_id, user_text, history=history, @@ -322,13 +329,18 @@ async def whatsapp_event_webhook( try: await _send_whatsapp_messages(config, sender_phone, reply_text) - config.is_connected = True - db.add(ChatMessage(agent_id=agent_id, user_id=platform_user_id, role="assistant", content=reply_text, conversation_id=session_conv_id)) - sess.last_message_at = datetime.now(timezone.utc) - await db.commit() + async with _async_session() as _save_db: + _save_db.add(ChatMessage(agent_id=agent_id, user_id=platform_user_id, role="assistant", content=reply_text, conversation_id=session_conv_id)) + from app.models.chat_session import ChatSession + _sess_r = await _save_db.execute( + select(ChatSession).where(ChatSession.id == uuid.UUID(session_conv_id)) + ) + _sess_fresh = _sess_r.scalar_one_or_none() + if _sess_fresh: + _sess_fresh.last_message_at = datetime.now(timezone.utc) + await _save_db.commit() except Exception as exc: logger.exception(f"[WhatsApp] Send failed for agent {agent_id}: {exc}") - config.is_connected = False - await db.commit() + return {"ok": True} diff --git a/backend/app/database.py b/backend/app/database.py index 8c9ad630c..df0caba1c 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -30,12 +30,15 @@ class Base(DeclarativeBase): async def get_db() -> AsyncGenerator[AsyncSession, None]: """Dependency for getting async database sessions.""" async with async_session() as session: + token = _session_ctx.set(session) try: yield session await session.commit() except Exception: await session.rollback() raise + finally: + _session_ctx.reset(token) _session_ctx: ContextVar[AsyncSession | None] = ContextVar("db_session_ctx", default=None) diff --git a/backend/app/services/channel_user_service.py b/backend/app/services/channel_user_service.py index 002a89a45..2484a8369 100644 --- a/backend/app/services/channel_user_service.py +++ b/backend/app/services/channel_user_service.py @@ -406,11 +406,22 @@ async def _create_channel_user( """Create a new Identity + User for channel identity (lazy registration). Creates a global Identity first, then a tenant-scoped User linked to it. - This ensures compatibility with the Phase 2 user model where username, - email, and password_hash live on the Identity table. + Both objects are added to the SAME ``db`` session so the FK constraint + (users.identity_id → identities.id) is satisfied within one transaction. + + The previous implementation called ``registration_service.find_or_create_identity`` + which delegates to ``identity_dao.create_identity``. That DAO method uses its own + ``async with self.session()`` context. When ``_session_ctx`` is not set (all + background channel handlers use raw ``async_session()`` directly, not FastAPI + ``Depends(get_db)``), the DAO opens a **separate** session, flushes the Identity + there, and exits — but SQLAlchemy does NOT auto-commit on session close, so the + Identity is **rolled back**. The User INSERT that follows on the outer ``db`` + session then violates the FK constraint. """ - # Generate username and email + import re as _re + email = extra_info.get("email") + mobile = extra_info.get("mobile") identity_seed = ( external_user_id or (extra_info.get("open_id") or "").strip() @@ -438,17 +449,36 @@ async def _create_channel_user( email = email or f"{username}@{channel_type}.local" - # Step 1: Find or create global Identity using unified registration service - from app.services.registration_service import registration_service - identity = await registration_service.find_or_create_identity( - email=email, - phone=extra_info.get("mobile"), - username=username, - password=uuid.uuid4().hex, - ) + # ── Step 1: Find or create Identity on the SAME session ────────────── + # First try to find an existing Identity by email / phone so we don't + # create duplicate identities for the same person across channels. + from sqlalchemy import or_ + identity: Identity | None = None + + lookup_conditions = [Identity.email == email] + if mobile: + normalized_mobile = _re.sub(r"[\s\-\+]", "", mobile) + lookup_conditions.append(Identity.phone == normalized_mobile) + id_result = await db.execute( + select(Identity).where(or_(*lookup_conditions)).limit(1) + ) + identity = id_result.scalar_one_or_none() + + if not identity: + normalized_phone = _re.sub(r"[\s\-\+]", "", mobile) if mobile else None + identity = Identity( + email=email, + phone=normalized_phone, + username=username, + password_hash=None, + is_platform_admin=False, + email_verified=True, # auto-verify channel users + ) + db.add(identity) + await db.flush() # assigns identity.id within this transaction - # Step 2: Create tenant-scoped User linked to Identity + # ── Step 2: Create tenant-scoped User linked to Identity ───────────── user = User( identity_id=identity.id, display_name=name, @@ -463,6 +493,7 @@ async def _create_channel_user( return user + # Global service instance channel_user_service = ChannelUserService() @@ -552,16 +583,36 @@ async def get_platform_user_by_org_member( email = email or f"{username}@{channel_type}.local" - # Step 3: Create new User and link to OrgMember - from app.services.registration_service import registration_service - # Use unified find_or_create_identity with dual lookup (email/phone) - identity = await registration_service.find_or_create_identity( - email=email, - phone=org_member.phone, - username=username, - password=uuid.uuid4().hex, + # Step 3: Create new Identity on the SAME session, then User + link OrgMember. + # Using registration_service.find_or_create_identity would route through + # identity_dao which opens its own session (no _session_ctx here), causing + # the Identity to be rolled back before the User FK reference is resolved. + from sqlalchemy import or_ + import re as _re_pu + + identity: Identity | None = None + lookup_conditions = [Identity.email == email] + if org_member.phone: + normalized_ph = _re_pu.sub(r"[\s\-\+]", "", org_member.phone) + lookup_conditions.append(Identity.phone == normalized_ph) + + id_result = await db.execute( + select(Identity).where(or_(*lookup_conditions)).limit(1) ) + identity = id_result.scalar_one_or_none() + if not identity: + normalized_phone = _re_pu.sub(r"[\s\-\+]", "", org_member.phone) if org_member.phone else None + identity = Identity( + email=email, + phone=normalized_phone, + username=username, + password_hash=None, + is_platform_admin=False, + email_verified=True, + ) + db.add(identity) + await db.flush() user = User( identity=identity, diff --git a/backend/app/services/discord_gateway.py b/backend/app/services/discord_gateway.py index c8384c777..70d2d2053 100644 --- a/backend/app/services/discord_gateway.py +++ b/backend/app/services/discord_gateway.py @@ -146,7 +146,7 @@ async def _handle_message( try: from app.models.audit import ChatMessage from app.models.agent import Agent as AgentModel - from app.api.feishu import _call_agent_llm + from app.api.feishu import _call_llm_with_config, _load_agent_and_model from app.services.channel_session import find_or_create_channel_session from app.models.user import User as _User from app.core.security import hash_password as _hp @@ -227,31 +227,43 @@ async def _handle_message( conversation_id=session_conv_id, )) sess.last_message_at = datetime.now(timezone.utc) - await db.commit() - # Call LLM - reply_text = await _call_agent_llm( - db, - agent_id, - user_text, - history=history, - user_id=platform_user_id, - session_id=session_conv_id, - ) - logger.info(f"[Discord GW] LLM reply for {agent_id}: {reply_text[:80]}") + # Pre-load agent/model before releasing connection + _agent_model, _llm_model, _fallback_model = await _load_agent_and_model(db, agent_id) - # Save reply - db.add(ChatMessage( + await db.commit() + # ── Phase 1 complete: release connection before slow LLM call ── + + # ── Phase 2: LLM call (no DB session) ── + reply_text = await _call_llm_with_config( + _agent_model, _llm_model, _fallback_model, + agent_id, + user_text, + history=history, + user_id=platform_user_id, + session_id=session_conv_id, + ) + logger.info(f"[Discord GW] LLM reply for {agent_id}: {reply_text[:80]}") + + # ── Phase 3: Save reply (new short transaction) ── + async with async_session() as _save_db: + _save_db.add(ChatMessage( agent_id=agent_id, user_id=platform_user_id, role="assistant", content=reply_text, conversation_id=session_conv_id, )) - sess.last_message_at = datetime.now(timezone.utc) - await db.commit() + from app.models.chat_session import ChatSession + _sess_r = await _save_db.execute( + select(ChatSession).where(ChatSession.id == _uuid.UUID(session_conv_id)) + ) + _sess_fresh = _sess_r.scalar_one_or_none() + if _sess_fresh: + _sess_fresh.last_message_at = datetime.now(timezone.utc) + await _save_db.commit() - return reply_text + return reply_text except Exception as e: logger.exception( diff --git a/backend/app/services/wechat_channel.py b/backend/app/services/wechat_channel.py index 2fd114781..64f1449a2 100644 --- a/backend/app/services/wechat_channel.py +++ b/backend/app/services/wechat_channel.py @@ -186,7 +186,7 @@ def _extract_wechat_text(item_list: list[dict[str, Any]] | None) -> str: async def _process_wechat_message(agent_id: uuid.UUID, msg: dict[str, Any], config: ChannelConfig) -> None: - from app.api.feishu import _call_agent_llm + from app.api.feishu import _call_llm_with_config, _load_agent_and_model from app.services.activity_logger import log_activity from_user_id = str(msg.get("from_user_id") or "").strip() @@ -258,30 +258,39 @@ async def _process_wechat_message(agent_id: uuid.UUID, msg: dict[str, Any], conf ) ) sess.last_message_at = datetime.now(timezone.utc) - await db.commit() - reply_text = await _call_agent_llm( - db=db, - agent_id=agent_id, - user_text=user_text, - history=history, - user_id=platform_user_id, - session_id=session_conv_id, - ) + # Pre-load agent/model before releasing the connection + _agent_model, _llm_model, _fallback_model = await _load_agent_and_model(db, agent_id) - token = str((config.extra_config or {}).get("bot_token") or "").strip() - base_url = str((config.extra_config or {}).get("baseurl") or WECHAT_ILINK_BASE_URL).strip() - route_tag = str((config.extra_config or {}).get("route_tag") or "").strip() or None - await send_wechat_text_message( - token=token, - base_url=base_url, - to_user_id=from_user_id, - context_token=context_token, - text=reply_text, - route_tag=route_tag, - ) + await db.commit() + # ── Phase 1 complete: release connection before slow LLM call ── + + # ── Phase 2: LLM call (no DB session) ── + token = str((config.extra_config or {}).get("bot_token") or "").strip() + base_url = str((config.extra_config or {}).get("baseurl") or WECHAT_ILINK_BASE_URL).strip() + route_tag = str((config.extra_config or {}).get("route_tag") or "").strip() or None + + reply_text = await _call_llm_with_config( + _agent_model, _llm_model, _fallback_model, + agent_id, + user_text, + history=history, + user_id=platform_user_id, + session_id=session_conv_id, + ) - db.add( + await send_wechat_text_message( + token=token, + base_url=base_url, + to_user_id=from_user_id, + context_token=context_token, + text=reply_text, + route_tag=route_tag, + ) + + # ── Phase 3: Save reply (new short transaction) ── + async with async_session() as _save_db: + _save_db.add( ChatMessage( agent_id=agent_id, user_id=platform_user_id, @@ -290,15 +299,21 @@ async def _process_wechat_message(agent_id: uuid.UUID, msg: dict[str, Any], conf conversation_id=session_conv_id, ) ) - sess.last_message_at = datetime.now(timezone.utc) - await db.commit() - - await log_activity( - agent_id, - "chat_reply", - f"Replied to WeChat message: {reply_text[:80]}", - detail={"channel": "wechat", "user_text": user_text[:200], "reply": reply_text[:500]}, + from app.models.chat_session import ChatSession + _sess_r = await _save_db.execute( + select(ChatSession).where(ChatSession.id == uuid.UUID(session_conv_id)) ) + _sess_fresh = _sess_r.scalar_one_or_none() + if _sess_fresh: + _sess_fresh.last_message_at = datetime.now(timezone.utc) + await _save_db.commit() + + await log_activity( + agent_id, + "chat_reply", + f"Replied to WeChat message: {reply_text[:80]}", + detail={"channel": "wechat", "user_text": user_text[:200], "reply": reply_text[:500]}, + ) class WeChatPollManager: diff --git a/backend/app/services/wecom_stream.py b/backend/app/services/wecom_stream.py index 80230c638..5fb197648 100644 --- a/backend/app/services/wecom_stream.py +++ b/backend/app/services/wecom_stream.py @@ -339,7 +339,7 @@ async def _process_wecom_stream_message( from app.models.audit import ChatMessage from app.services.channel_session import find_or_create_channel_session from app.services.channel_user_service import channel_user_service - from app.api.feishu import _call_agent_llm + from app.api.feishu import _call_llm_with_config, _load_agent_and_model async with async_session() as db: # Load agent @@ -355,9 +355,6 @@ async def _process_wecom_stream_message( conv_id = _build_wecom_conv_id(sender_id, chat_id, normalized_chat_type) # Resolve or create platform user via unified channel user service. - # This correctly handles the User/Identity model relationship - # (email/username/password_hash are AssociationProxy fields — cannot be - # set directly in UserModel constructor). platform_user = await channel_user_service.resolve_channel_user( db=db, agent=agent_obj, @@ -398,32 +395,46 @@ async def _process_wecom_stream_message( conversation_id=session_conv_id, )) sess.last_message_at = datetime.now(timezone.utc) - await db.commit() - # Call LLM - reply_text = await _call_agent_llm( - db, agent_id, user_text, - history=history, user_id=platform_user_id, - session_id=session_conv_id, - ) - logger.info(f"[WeCom Stream] LLM reply: {reply_text[:100]}") + # Pre-load agent/model before releasing connection + _agent_model, _llm_model, _fallback_model = await _load_agent_and_model(db, agent_id) - # Save assistant reply - db.add(ChatMessage( + await db.commit() + # ── Phase 1 complete: release connection before slow LLM call ── + + # ── Phase 2: LLM call (no DB session) ── + reply_text = await _call_llm_with_config( + _agent_model, _llm_model, _fallback_model, + agent_id, user_text, + history=history, user_id=platform_user_id, + session_id=session_conv_id, + ) + logger.info(f"[WeCom Stream] LLM reply: {reply_text[:100]}") + + # ── Phase 3: Save assistant reply (new short transaction) ── + async with async_session() as _save_db: + _save_db.add(ChatMessage( agent_id=agent_id, user_id=platform_user_id, role="assistant", content=reply_text, conversation_id=session_conv_id, )) - sess.last_message_at = datetime.now(timezone.utc) - await db.commit() - - # Log activity - from app.services.activity_logger import log_activity - await log_activity( - agent_id, "chat_reply", - f"Replied to WeCom message: {reply_text[:80]}", - detail={"channel": "wecom", "user_text": user_text[:200], "reply": reply_text[:500]}, + from app.models.chat_session import ChatSession + import uuid as _uuid_ws + _sess_r = await _save_db.execute( + _select(ChatSession).where(ChatSession.id == _uuid_ws.UUID(session_conv_id)) ) + _sess_fresh = _sess_r.scalar_one_or_none() + if _sess_fresh: + _sess_fresh.last_message_at = datetime.now(timezone.utc) + await _save_db.commit() + + # Log activity + from app.services.activity_logger import log_activity + await log_activity( + agent_id, "chat_reply", + f"Replied to WeCom message: {reply_text[:80]}", + detail={"channel": "wecom", "user_text": user_text[:200], "reply": reply_text[:500]}, + ) return reply_text From 2c76cc8712cff3577387cfc296b81e7b04c49240 Mon Sep 17 00:00:00 2001 From: yaojin Date: Mon, 15 Jun 2026 16:14:48 +0800 Subject: [PATCH 3/3] update session --- backend/app/dao/base.py | 20 ++++++- backend/tests/test_base_dao.py | 104 +++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 backend/tests/test_base_dao.py diff --git a/backend/app/dao/base.py b/backend/app/dao/base.py index 5ca484d38..c79668207 100644 --- a/backend/app/dao/base.py +++ b/backend/app/dao/base.py @@ -24,7 +24,17 @@ async def session(self) -> AsyncGenerator[AsyncSession, None]: yield context_session else: async with async_session() as session: - yield session + token = _session_ctx.set(session) + try: + yield session + if hasattr(session, "commit"): + await session.commit() + except Exception: + if hasattr(session, "rollback"): + await session.rollback() + raise + finally: + _session_ctx.reset(token) async def get(self, id: Any) -> ModelType | None: """Fetch a single record by its primary key ID.""" @@ -71,10 +81,14 @@ async def update(self, *, db_obj: ModelType, obj_in: dict[str, Any]) -> ModelTyp async def delete(self, *, id: Any) -> ModelType | None: """Delete a record by ID.""" async with self.session() as db: - obj = await self.get(id) + if hasattr(db, "get"): + obj = await db.get(self.model, id) + else: + stmt = select(self.model).where(self.model.id == id) + result = await db.execute(stmt) + obj = result.scalar_one_or_none() if obj: if hasattr(db, "delete"): await db.delete(obj) await db.flush() return obj - diff --git a/backend/tests/test_base_dao.py b/backend/tests/test_base_dao.py new file mode 100644 index 000000000..c54d922c5 --- /dev/null +++ b/backend/tests/test_base_dao.py @@ -0,0 +1,104 @@ +from types import SimpleNamespace + +import pytest + +from app.dao.base import BaseDAO +from app.database import _session_ctx + + +class DummyModel: + id = "id" + + +class RecordingSession: + def __init__(self): + self.added = [] + self.deleted = [] + self.flushed = False + self.committed = False + self.rolled_back = False + self.get_calls = [] + self.execute_calls = 0 + self.object_to_get = SimpleNamespace(id="row-1") + + def add(self, obj): + self.added.append(obj) + + async def flush(self): + self.flushed = True + + async def commit(self): + self.committed = True + + async def rollback(self): + self.rolled_back = True + + async def get(self, model, id): + self.get_calls.append((model, id)) + return self.object_to_get + + async def delete(self, obj): + self.deleted.append(obj) + + +class SessionFactory: + def __init__(self, session): + self.session = session + + def __call__(self): + return self + + async def __aenter__(self): + return self.session + + async def __aexit__(self, exc_type, exc, tb): + return False + + +@pytest.mark.asyncio +async def test_standalone_dao_session_sets_context_and_commits(monkeypatch): + session = RecordingSession() + monkeypatch.setattr("app.dao.base.async_session", SessionFactory(session)) + + dao = BaseDAO(DummyModel) + + async with dao.session() as db: + assert db is session + assert _session_ctx.get() is session + + assert session.committed is True + assert session.rolled_back is False + assert _session_ctx.get() is None + + +@pytest.mark.asyncio +async def test_standalone_dao_session_rolls_back_on_error(monkeypatch): + session = RecordingSession() + monkeypatch.setattr("app.dao.base.async_session", SessionFactory(session)) + + dao = BaseDAO(DummyModel) + + with pytest.raises(RuntimeError): + async with dao.session(): + raise RuntimeError("boom") + + assert session.committed is False + assert session.rolled_back is True + assert _session_ctx.get() is None + + +@pytest.mark.asyncio +async def test_delete_uses_current_session_without_nested_lookup(monkeypatch): + session = RecordingSession() + monkeypatch.setattr("app.dao.base.async_session", SessionFactory(session)) + + dao = BaseDAO(DummyModel) + + deleted = await dao.delete(id="row-1") + + assert deleted is session.object_to_get + assert session.get_calls == [(DummyModel, "row-1")] + assert session.execute_calls == 0 + assert session.deleted == [session.object_to_get] + assert session.flushed is True + assert session.committed is True