Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""

import time
from typing import Annotated, Any, Dict, List, Set, Tuple
from typing import Annotated, Any, Dict, FrozenSet, List, Set, Tuple

from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphRunner, ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult, AIGraphMetrics
from ldai.providers.types import AgentGraphRunnerResult, AIGraphMetrics, EvalRequest

from ldai_langchain.langchain_helper import (
build_structured_tools,
Expand All @@ -17,6 +17,62 @@
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler


def _message_content_to_str(content: Any) -> str:
"""Normalize a LangChain message ``content`` (string or list of parts) to a string."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get('text')
if isinstance(text, str):
parts.append(text)
return '\r\n'.join(parts)
return str(content)


def _maybe_record_eval_request(
eval_requests: List[EvalRequest],
node_key: str,
msgs: List[Any],
response: Any,
handoff_tool_names: FrozenSet[str],
) -> None:
"""
Append an :class:`EvalRequest` to ``eval_requests`` when ``response``
represents the agent's final output for this activation.

Skips emission when the response only requests further tool calls (still
working in a tool loop) or when there is no content to evaluate. Tool
calls limited to handoff tools are treated as the agent terminating with
a transfer, so the response is still emitted.
"""
tool_calls = getattr(response, 'tool_calls', None) or []
if tool_calls:
# If every tool call is a handoff, the agent is terminating with a
# transfer; otherwise it is still working through a tool loop.
for tc in tool_calls:
name = tc.get('name') if isinstance(tc, dict) else getattr(tc, 'name', None)
if name not in handoff_tool_names:
return

response_content = getattr(response, 'content', response)
output_text = _message_content_to_str(response_content)
if not output_text or not output_text.strip():
return

input_text = '\r\n'.join(
_message_content_to_str(getattr(m, 'content', m)) for m in msgs
) if msgs else ''

eval_requests.append(
EvalRequest(node_key=node_key, input=input_text, output=output_text)
)


def _make_handoff_tool(child_key: str, description: str) -> Any:
"""
Create a tool that transfers control to ``child_key``.
Expand Down Expand Up @@ -81,19 +137,10 @@ def __init__(
"""
self._graph = graph
self._tools = tools
self._compiled: Any = None
self._fn_name_to_config_key: Dict[str, str] = {}
self._node_keys: Set[str] = set()

def _ensure_compiled(self) -> None:
"""Build and cache the compiled graph if not already done."""
if self._compiled is None:
compiled, fn_name_to_config_key, node_keys = self._build_graph()
self._compiled = compiled
self._fn_name_to_config_key = fn_name_to_config_key
self._node_keys = node_keys

def _build_graph(self) -> Tuple[Any, Dict[str, str], Set[str]]:

def _build_graph(
self, eval_requests: List[EvalRequest]
) -> Tuple[Any, Dict[str, str], Set[str]]:
"""
Build and compile the LangGraph StateGraph from the AgentGraphDefinition.

Expand Down Expand Up @@ -169,20 +216,46 @@ def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
else:
model = lc_model

def make_node_fn(bound_model: Any, node_instructions: Any, nk: str):
# Names of the handoff tools attached to this node. Tool calls
# against these are control-flow signals, not the agent doing work,
# so they must not block emission of an EvalRequest.
handoff_tool_names: FrozenSet[str] = frozenset(
getattr(t, 'name', '') for t in handoff_fns
)

# Whether this node has at least one judge configured. Nodes without
# judges contribute zero EvalRequest entries.
jc = getattr(node_config, 'judge_configuration', None)
node_has_judges = bool(jc is not None and getattr(jc, 'judges', None))

def make_node_fn(
bound_model: Any,
node_instructions: Any,
nk: str,
ht_names: FrozenSet[str],
emit_eval: bool,
):
async def invoke(state: WorkflowState) -> dict:
if not bound_model:
return {'messages': []}
msgs = list(state['messages'])
if node_instructions:
msgs = [SystemMessage(content=node_instructions)] + msgs
response = await bound_model.ainvoke(msgs)

if emit_eval:
_maybe_record_eval_request(
eval_requests, nk, msgs, response, ht_names
)

return {'messages': [response]}

invoke.__name__ = nk
return invoke

invoke_fn = make_node_fn(model, instructions, node_key)
invoke_fn = make_node_fn(
model, instructions, node_key, handoff_tool_names, node_has_judges
)
agent_builder.add_node(node_key, invoke_fn)

if node_key == root_key:
Expand Down Expand Up @@ -287,14 +360,16 @@ async def run(self, input: str) -> AgentGraphRunnerResult:
:return: AgentGraphRunnerResult with the final content and AIGraphMetrics
"""
start_ns = time.perf_counter_ns()
# Per-run state — kept local so concurrent run() calls do not share it.
eval_requests: List[EvalRequest] = []

try:
from langchain_core.messages import HumanMessage

self._ensure_compiled()
handler = LDMetricsCallbackHandler(self._node_keys, self._fn_name_to_config_key)
compiled, fn_name_to_config_key, node_keys = self._build_graph(eval_requests)
handler = LDMetricsCallbackHandler(node_keys, fn_name_to_config_key)

result = await self._compiled.ainvoke( # type: ignore[call-overload]
result = await compiled.ainvoke( # type: ignore[call-overload]
{'messages': [HumanMessage(content=input)]},
config={'callbacks': [handler], 'recursion_limit': 25},
)
Expand All @@ -316,6 +391,7 @@ async def run(self, input: str) -> AgentGraphRunnerResult:
tokens=total_usage if (total_usage is not None and total_usage.total > 0) else None,
node_metrics=node_metrics,
),
eval_requests=eval_requests if eval_requests else None,
)

except Exception as exc:
Expand All @@ -334,4 +410,5 @@ async def run(self, input: str) -> AgentGraphRunnerResult:
success=False,
duration_ms=duration_ms,
),
eval_requests=eval_requests if eval_requests else None,
)
Loading
Loading