diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 91fb568cd3..da831dce3c 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -439,7 +439,15 @@ async def _handle_before_agent_callback( Returns: Optional[Event]: an event if callback provides content or changed state. """ - callback_context = CallbackContext(ctx) + last_event = None + if ctx.session.events: + last_event = ctx.session.events[-1] + + event_actions = None + if last_event and last_event.actions: + event_actions = last_event.actions + + callback_context = CallbackContext(ctx, event_actions=event_actions) # Run callbacks from the plugins. before_agent_callback_content = ( @@ -499,7 +507,15 @@ async def _handle_after_agent_callback( Optional[Event]: an event if callback provides content or changed state. """ - callback_context = CallbackContext(invocation_context) + last_event = None + if invocation_context.session.events: + last_event = invocation_context.session.events[-1] + + event_actions = None + if last_event and last_event.actions: + event_actions = last_event.actions + + callback_context = CallbackContext(invocation_context, event_actions=event_actions) # Run callbacks from the plugins. after_agent_callback_content = ( diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index cd9e88f718..108f2581ce 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -684,6 +684,92 @@ async def test_run_async_with_async_after_agent_callback_append_reply( ) +@pytest.mark.asyncio +async def test_run_async_after_agent_callback_state_visibility( + request: pytest.FixtureRequest, +): + class _StateTestingAgent(_TestingAgent): + @override + async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text='State change event')]), + actions=EventActions(state_delta={"test_key": "test_val"}), + ) + + agent = _StateTestingAgent( + name=f'{request.function.__name__}_test_agent', + ) + + def verify_state_callback(ctx: CallbackContext) -> None: + assert ctx.state['test_key'] == 'test_val' + + agent.after_agent_callback = verify_state_callback + + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, agent + ) + + from google.adk.runners import InMemoryRunner + runner = InMemoryRunner(agent=agent, app_name='test') + + events = [e async for e in runner.run_async( + user_id='user1', + session_id=parent_ctx.session.id, + new_message=types.Content(parts=[types.Part(text='hello')]) + )] + + assert len(events) == 1 + assert events[0].author == agent.name + + +@pytest.mark.asyncio +async def test_run_async_before_agent_callback_state_visibility( + request: pytest.FixtureRequest, +): + from google.adk.agents.sequential_agent import SequentialAgent + + class _StateTestingAgent(_TestingAgent): + @override + async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text='Agent 1 event')]), + actions=EventActions(state_delta={"shared_key": "shared_val"}), + ) + + agent1 = _StateTestingAgent(name='agent1') + + def verify_state_callback(ctx: CallbackContext) -> None: + assert ctx.state['shared_key'] == 'shared_val' + + agent2 = _TestingAgent( + name='agent2', + before_agent_callback=verify_state_callback + ) + + seq_agent = SequentialAgent( + name='seq_agent', + sub_agents=[agent1, agent2] + ) + + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, seq_agent + ) + + from google.adk.runners import InMemoryRunner + runner = InMemoryRunner(agent=seq_agent, app_name='test') + + events = [e async for e in runner.run_async( + user_id='user1', + session_id=parent_ctx.session.id, + new_message=types.Content(parts=[types.Part(text='hello')]) + )] + + assert len(events) == 2 + @pytest.mark.asyncio async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')