Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,23 @@
_COMPACTION_CUSTOM_METADATA_KEY = '_compaction'
_USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata'

_SESSION_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]+$')


def _validate_session_id(session_id: str) -> None:
"""Rejects session IDs that could escape the URL path segment.

Vertex AI session IDs are interpolated into the REST URL path. Without
validation, values like '..' or '../foo' resolve to sibling resources at
HTTP-client level (path traversal), allowing cross-resource reads/deletes.
"""
if not isinstance(session_id, str) or not _SESSION_ID_PATTERN.fullmatch(
session_id
):
raise ValueError(
f'Invalid session_id {session_id!r}: must match {_SESSION_ID_PATTERN.pattern}.'
)


def _quote_filter_literal(value: str) -> str:
"""Quotes filter values so embedded metacharacters stay inside the literal."""
Expand Down Expand Up @@ -127,6 +144,7 @@ async def create_session(

config = {'session_state': state} if state else {}
if session_id:
_validate_session_id(session_id)
config['session_id'] = session_id
config.update(kwargs)
async with self._get_api_client() as api_client:
Expand Down Expand Up @@ -157,6 +175,7 @@ async def get_session(
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
_validate_session_id(session_id)
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
session_resource_name = (
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
Expand Down Expand Up @@ -256,14 +275,32 @@ async def list_sessions(
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
_validate_session_id(session_id)
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
session_resource_name = (
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
)

async with self._get_api_client() as api_client:
# Ownership check: ensure the session belongs to the caller before
# deleting. Without this, delete_session ignores user_id entirely and
# any caller who knows a session_id can delete it.
try:
existing = await api_client.agent_engines.sessions.get(
name=session_resource_name
)
except ClientError as e:
if e.code == 404:
return
raise
if existing.user_id != user_id:
raise ValueError(
f'Session {session_id} does not belong to user {user_id}.'
)

try:
await api_client.agent_engines.sessions.delete(
name=(
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
),
name=session_resource_name,
)
except Exception as e:
logger.error('Error deleting session %s: %s', session_id, e)
Expand All @@ -274,6 +311,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
await super().append_event(session=session, event=event)

_validate_session_id(session.id)
reasoning_engine_id = self._get_reasoning_engine_id(session.app_name)

# Build config (Monolithic approach)
Expand Down
39 changes: 39 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,45 @@ async def test_get_and_delete_session():
assert str(excinfo.value) == '404 Session not found: 1'


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_delete_session_rejects_other_users_session():
"""delete_session must not delete a session owned by a different user."""
session_service = mock_vertex_ai_session_service()

# session '1' belongs to 'user'; 'user2' must not be allowed to delete it.
with pytest.raises(ValueError) as excinfo:
await session_service.delete_session(
app_name='123', user_id='user2', session_id='1'
)
assert 'does not belong to user user2' in str(excinfo.value)

# Session must still exist.
assert (
await session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
== MOCK_SESSION
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_session_id_path_traversal_rejected():
"""Session IDs containing path-traversal characters must be rejected."""
session_service = mock_vertex_ai_session_service()

for bad_id in ['..', '../foo', '..?force=true', 'a/b', '']:
with pytest.raises(ValueError):
await session_service.delete_session(
app_name='123', user_id='user', session_id=bad_id
)
with pytest.raises(ValueError):
await session_service.get_session(
app_name='123', user_id='user', session_id=bad_id
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_session_with_page_token():
Expand Down