Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,23 +546,23 @@ async def test_streaming_agent_run_with_events_force_flush_otel(
async def test_async_create_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = await app.async_create_session(user_id=_TEST_USER_ID)
assert session1.user_id == _TEST_USER_ID
assert session1["user_id"] == _TEST_USER_ID
session2 = await app.async_create_session(
user_id=_TEST_USER_ID, session_id="test_session_id"
)
assert session2.user_id == _TEST_USER_ID
assert session2.id == "test_session_id"
assert session2["user_id"] == _TEST_USER_ID
assert session2["id"] == "test_session_id"

@pytest.mark.asyncio
async def test_async_get_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = await app.async_create_session(user_id=_TEST_USER_ID)
session2 = await app.async_get_session(
user_id=_TEST_USER_ID,
session_id=session1.id,
session_id=session1["id"],
)
assert session2.user_id == _TEST_USER_ID
assert session1.id == session2.id
assert session1["id"] == session2.id

@pytest.mark.asyncio
async def test_async_list_sessions(self, get_project_id_mock: mock.Mock):
Expand All @@ -572,12 +572,12 @@ async def test_async_list_sessions(self, get_project_id_mock: mock.Mock):
session = await app.async_create_session(user_id=_TEST_USER_ID)
response1 = await app.async_list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
assert response1.sessions[0].id == session.id
assert response1.sessions[0].id == session["id"]
session2 = await app.async_create_session(user_id=_TEST_USER_ID)
response2 = await app.async_list_sessions(user_id=_TEST_USER_ID)
assert len(response2.sessions) == 2
assert response2.sessions[0].id == session.id
assert response2.sessions[1].id == session2.id
assert response2.sessions[0].id == session["id"]
assert response2.sessions[1].id == session2["id"]

@pytest.mark.asyncio
async def test_async_delete_session(self, get_project_id_mock: mock.Mock):
Expand All @@ -592,30 +592,30 @@ async def test_async_delete_session(self, get_project_id_mock: mock.Mock):
assert len(response1.sessions) == 1
await app.async_delete_session(
user_id=_TEST_USER_ID,
session_id=session.id,
session_id=session["id"],
)
response0 = await app.async_list_sessions(user_id=_TEST_USER_ID)
assert not response0.sessions

def test_create_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = app.create_session(user_id=_TEST_USER_ID)
assert session1.user_id == _TEST_USER_ID
assert session1["user_id"] == _TEST_USER_ID
session2 = app.create_session(
user_id=_TEST_USER_ID, session_id="test_session_id"
)
assert session2.user_id == _TEST_USER_ID
assert session2.id == "test_session_id"
assert session2["user_id"] == _TEST_USER_ID
assert session2["id"] == "test_session_id"

def test_get_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = app.create_session(user_id=_TEST_USER_ID)
session2 = app.get_session(
user_id=_TEST_USER_ID,
session_id=session1.id,
session_id=session1["id"],
)
assert session2.user_id == _TEST_USER_ID
assert session1.id == session2.id
assert session1["id"] == session2.id

def test_list_sessions(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
Expand All @@ -624,12 +624,12 @@ def test_list_sessions(self, get_project_id_mock: mock.Mock):
session = app.create_session(user_id=_TEST_USER_ID)
response1 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
assert response1.sessions[0].id == session.id
assert response1.sessions[0].id == session["id"]
session2 = app.create_session(user_id=_TEST_USER_ID)
response2 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response2.sessions) == 2
assert response2.sessions[0].id == session.id
assert response2.sessions[1].id == session2.id
assert response2.sessions[0].id == session["id"]
assert response2.sessions[1].id == session2["id"]

def test_delete_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
Expand All @@ -638,7 +638,7 @@ def test_delete_session(self, get_project_id_mock: mock.Mock):
session = app.create_session(user_id=_TEST_USER_ID)
response1 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
app.delete_session(user_id=_TEST_USER_ID, session_id=session.id)
app.delete_session(user_id=_TEST_USER_ID, session_id=session["id"])
response0 = app.list_sessions(user_id=_TEST_USER_ID)
assert not response0.sessions

Expand Down Expand Up @@ -817,7 +817,8 @@ def test_tracing_setup(
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
)
monkeypatch.setattr("os.getpid", lambda: 123123123)
app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT):
app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
app.set_up()

otlp_span_exporter_mock.assert_called_once_with(
Expand All @@ -826,7 +827,7 @@ def test_tracing_setup(
headers=mock.ANY,
)

get_project_id_mock.assert_called_with(_TEST_PROJECT_ID)
get_project_id_mock.assert_called_with(_TEST_PROJECT)

user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"]
assert (
Expand Down
74 changes: 54 additions & 20 deletions tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,18 @@ def adk_version_mock():
yield adk_version_mock


@pytest.fixture
@pytest.fixture(autouse=True)
def get_project_id_mock():
with mock.patch(
"google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id"
) as get_project_id_mock:
get_project_id_mock.return_value = _TEST_PROJECT_ID
yield get_project_id_mock
with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT):
with mock.patch(
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.AdkApp._warn_if_telemetry_api_disabled",
return_value=None,
):
yield get_project_id_mock


class _MockRunner:
Expand Down Expand Up @@ -376,7 +381,7 @@ def test_initialization(self):
app = reasoning_engines.AdkApp(
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL),
)
assert app._tmpl_attrs.get("project") == _TEST_PROJECT
assert app._tmpl_attrs.get("project") == _TEST_PROJECT_ID
assert app._tmpl_attrs.get("location") == _TEST_LOCATION
assert app._tmpl_attrs.get("runner") is None

Expand Down Expand Up @@ -568,7 +573,17 @@ def test_streaming_agent_run_with_events(self):
"artifacts": [
{
"file_name": "test_file_name",
"versions": [{"version": "v1", "data": "v1data"}],
"versions": [
{
"version": "v1",
"data": {
"inline_data": {
"data": "djFkYXRh",
"mime_type": "text/plain",
}
},
}
],
}
],
"authorizations": {
Expand Down Expand Up @@ -606,7 +621,17 @@ async def test_streaming_agent_run_with_events_force_flush_otel(
"artifacts": [
{
"file_name": "test_file_name",
"versions": [{"version": "v1", "data": "v1data"}],
"versions": [
{
"version": "v1",
"data": {
"inline_data": {
"data": "djFkYXRh",
"mime_type": "text/plain",
}
},
}
],
}
],
"authorizations": {
Expand Down Expand Up @@ -682,12 +707,12 @@ def test_create_session(self):
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
)
session1 = app.create_session(user_id=_TEST_USER_ID)
assert session1.user_id == _TEST_USER_ID
assert session1["user_id"] == _TEST_USER_ID
session2 = app.create_session(
user_id=_TEST_USER_ID, session_id="test_session_id"
)
assert session2.user_id == _TEST_USER_ID
assert session2.id == "test_session_id"
assert session2["user_id"] == _TEST_USER_ID
assert session2["id"] == "test_session_id"

def test_get_session(self):
app = reasoning_engines.AdkApp(
Expand All @@ -696,10 +721,10 @@ def test_get_session(self):
session1 = app.create_session(user_id=_TEST_USER_ID)
session2 = app.get_session(
user_id=_TEST_USER_ID,
session_id=session1.id,
session_id=session1["id"],
)
assert session2.user_id == _TEST_USER_ID
assert session1.id == session2.id
assert session1["id"] == session2.id

def test_list_sessions(self):
app = reasoning_engines.AdkApp(
Expand All @@ -710,12 +735,12 @@ def test_list_sessions(self):
session = app.create_session(user_id=_TEST_USER_ID)
response1 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
assert response1.sessions[0].id == session.id
assert response1.sessions[0].id == session["id"]
session2 = app.create_session(user_id=_TEST_USER_ID)
response2 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response2.sessions) == 2
assert response2.sessions[0].id == session.id
assert response2.sessions[1].id == session2.id
assert response2.sessions[0].id == session["id"]
assert response2.sessions[1].id == session2["id"]

def test_delete_session(self):
app = reasoning_engines.AdkApp(
Expand All @@ -726,7 +751,7 @@ def test_delete_session(self):
session = app.create_session(user_id=_TEST_USER_ID)
response1 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
app.delete_session(user_id=_TEST_USER_ID, session_id=session.id)
app.delete_session(user_id=_TEST_USER_ID, session_id=session["id"])
response0 = app.list_sessions(user_id=_TEST_USER_ID)
assert not response0.sessions

Expand All @@ -740,14 +765,14 @@ async def test_async_add_session_to_memory(self):
list(
app.stream_query(
user_id=_TEST_USER_ID,
session_id=session.id,
session_id=session["id"],
message="My cat's name is Garfield",
)
)
await app.async_add_session_to_memory(
session=app.get_session(
user_id=_TEST_USER_ID,
session_id=session.id,
session_id=session["id"],
)
)
response = await app.async_search_memory(
Expand Down Expand Up @@ -838,7 +863,7 @@ def test_default_instrumentor_enablement(

# Assert
default_instrumentor_builder_mock.assert_called_once_with(
_TEST_PROJECT,
_TEST_PROJECT_ID,
enable_tracing=want_tracing_setup,
enable_logging=want_logging_setup,
)
Expand All @@ -863,11 +888,16 @@ def test_tracing_setup(
monkeypatch.setattr("os.getpid", lambda: 123123123)
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
app._warn_if_telemetry_api_disabled = lambda: None
app.set_up()
with mock.patch(
"google.cloud.aiplatform.vertexai.agent_engines._utils.is_noop_or_proxy_tracer_provider",
return_value=True,
):
app.set_up()

expected_attributes = {
"cloud.account.id": _TEST_PROJECT_ID,
"cloud.platform": "gcp.agent_engine",
"cloud.provider": "gcp",
"cloud.region": "us-central1",
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project-id/locations/us-central1/reasoningEngines/test_agent_id",
"gcp.project_id": _TEST_PROJECT_ID,
Expand All @@ -876,7 +906,7 @@ def test_tracing_setup(
"some-attribute": "some-value",
"telemetry.sdk.language": "python",
"telemetry.sdk.name": "opentelemetry",
"telemetry.sdk.version": "1.36.0",
"telemetry.sdk.version": "1.39.0",
"some-attribute": "some-value",
}

Expand All @@ -886,7 +916,11 @@ def test_tracing_setup(
headers=mock.ANY,
)

get_project_id_mock.assert_called_once_with(_TEST_PROJECT)
calls = [
mock.call(project_number=_TEST_PROJECT_ID, credentials=mock.ANY),
mock.call(_TEST_PROJECT_ID),
]
get_project_id_mock.assert_has_calls(calls)

user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"]
assert (
Expand Down
18 changes: 15 additions & 3 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,18 @@ def __init__(
),
}

def _serialize(self, obj: Any) -> Any:
"""Serializes an object to be JSON compatible."""
if hasattr(obj, "model_dump"):
return obj.model_dump(mode="json")
elif hasattr(obj, "dict"):
return self._serialize(obj.dict())
elif isinstance(obj, dict):
return {k: self._serialize(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._serialize(v) for v in obj]
return obj

def _app_name(self) -> str:
"""Returns the app name."""
app = self._tmpl_attrs.get("app")
Expand Down Expand Up @@ -1062,7 +1074,7 @@ async def async_stream_query(
)
if not session_id:
session = await self.async_create_session(user_id=user_id)
session_id = session.id
session_id = session["id"]
if session_events is not None:
# We allow for session_events to be an empty list.
from google.adk.events.event import Event
Expand Down Expand Up @@ -1163,7 +1175,7 @@ def stream_query(
self.set_up()
if not session_id:
session = self.create_session(user_id=user_id)
session_id = session.id
session_id = session["id"]
run_config = _validate_run_config(run_config)
if run_config:
for event in self._tmpl_attrs.get("runner").run(
Expand Down Expand Up @@ -1469,7 +1481,7 @@ async def async_create_session(
state=state,
**kwargs,
)
return session
return self._serialize(session)

def create_session(
self,
Expand Down
Loading
Loading