From bed9850c4281e359f6d3e68b92c181ab9cd7bd7c Mon Sep 17 00:00:00 2001 From: sciapanCA Date: Sat, 6 Jun 2026 18:35:34 +0200 Subject: [PATCH 1/2] Add datasource relevance filter --- src/tests/test_datasources.py | 170 +++++++++++++++++++++++++++++++++- src/tools/datasources.py | 102 +++++++++++++++++--- 2 files changed, 259 insertions(+), 13 deletions(-) diff --git a/src/tests/test_datasources.py b/src/tests/test_datasources.py index 21e490b..635f320 100644 --- a/src/tests/test_datasources.py +++ b/src/tests/test_datasources.py @@ -177,4 +177,172 @@ async def test_get_data_sources_handles_missing_repository_ids(mock_get_api_key) workspace = data_sources[0] assert workspace["id"] == "workspace-1" assert workspace["name"] == "Test Workspace" - assert "repositoryIds" not in workspace \ No newline at end of file + assert "repositoryIds" not in workspace + + +def _ctx_with_response(json_return, headers=None): + """Builds a mocked Context whose client.get returns a response with the given JSON body.""" + mock_ctx = MagicMock(spec=Context) + mock_ctx.info = AsyncMock() + mock_ctx.warning = AsyncMock() + mock_ctx.error = AsyncMock() + + mock_response = MagicMock() + mock_response.json.return_value = json_return + mock_response.headers = headers or {} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + + mock_lifespan_context = MagicMock() + mock_lifespan_context.base_url = "https://api.example.com" + mock_lifespan_context.client = mock_client + mock_ctx.request_context.lifespan_context = mock_lifespan_context + return mock_ctx, mock_client + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_with_query_passes_query_param(mock_get_api_key): + """When a query is supplied, it is forwarded to the listing endpoint as the `query` param.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, mock_client = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "handles OAuth"}, + ]) + + await get_data_sources(mock_ctx, alive_only=True, query="add OAuth to checkout") + + call_args = mock_client.get.call_args + assert call_args.args[0] == "/api/datasources/ready" + assert call_args.kwargs["params"] == {"query": "add OAuth to checkout"} + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_without_query_sends_no_query_param(mock_get_api_key): + """Without a query, no `query` param is sent (legacy behavior unchanged).""" + mock_get_api_key.return_value = "test-key" + mock_ctx, mock_client = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository"}, + ]) + + await get_data_sources(mock_ctx, alive_only=True) + + call_args = mock_client.get.call_args + assert call_args.kwargs.get("params") is None + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_surfaces_relevance_reason(mock_get_api_key): + """relevanceReason is preserved per item for the client (wrapped shape when query is set).""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "implements the checkout flow"}, + ]) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = json.loads(result) + assert payload["dataSources"][0]["relevanceReason"] == "implements the checkout flow" + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_filtered_hint_reports_total_and_omitted(mock_get_api_key): + """Filtered success surfaces how many sources exist beyond the shown subset and how to get them.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + headers={"X-CodeAlive-Total-Data-Sources": "25"}, + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = json.loads(result) + assert len(payload["dataSources"]) == 1 + assert "1 of 25" in payload["message"] + assert "omitted" in payload["message"].lower() + assert "without a query" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_filtered_hint_without_total_header(mock_get_api_key): + """Filtered success without the total header still hints that sources were omitted.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = json.loads(result) + assert "omitted" in payload["message"].lower() + assert "without a query" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_all_relevant_hint_reports_no_omission(mock_get_api_key): + """When every available source is relevant, the hint says so instead of claiming omissions.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + headers={"X-CodeAlive-Total-Data-Sources": "1"}, + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = json.loads(result) + assert "all 1" in payload["message"].lower() + assert "omitted" not in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_failopen_hint_when_no_reasons_present(mock_get_api_key): + """Query supplied but no item carries relevanceReason → the filter did not run (fail-open, + disabled, or an older backend); the hint must say the FULL list is returned.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository"}, + {"id": "repo-2", "name": "Other", "type": "Repository"}, + ]) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = json.loads(result) + assert len(payload["dataSources"]) == 2 + assert "unavailable" in payload["message"].lower() + assert "full" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_empty_with_query_returns_no_relevant_message(mock_get_api_key): + """Empty result WITH a query returns a 'no relevant' message, not 'add a repository'.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([]) + + result = await get_data_sources(mock_ctx, alive_only=True, query="something unrelated") + + payload = json.loads(result) + assert payload["dataSources"] == [] + assert "relevant" in payload["message"].lower() + assert "add a repository" not in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_empty_without_query_keeps_add_repository_message(mock_get_api_key): + """Empty result WITHOUT a query keeps the existing 'add a repository' message.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([]) + + result = await get_data_sources(mock_ctx, alive_only=True) + + payload = json.loads(result) + assert payload["dataSources"] == [] + assert "add a repository" in payload["message"].lower() \ No newline at end of file diff --git a/src/tools/datasources.py b/src/tools/datasources.py index fd6875f..3a473e9 100644 --- a/src/tools/datasources.py +++ b/src/tools/datasources.py @@ -6,14 +6,56 @@ import httpx from fastmcp import Context -from core import CodeAliveContext, get_api_key_from_context, log_api_request, log_api_response +from core import ( + CodeAliveContext, + get_api_key_from_context, + log_api_request, + log_api_response, +) from utils import handle_api_error # MCP tool/method name surfaced in every error/log message from this module. _TOOL_NAME = "get_data_sources" +# Pre-filter scoped candidate count, emitted by the backend only on relevance-filtered requests. +_TOTAL_HEADER = "X-CodeAlive-Total-Data-Sources" + + +def _relevance_message(data_sources: list, response) -> str: + """Builds the hint accompanying a query'd (relevance-filtered) result. + + The backend guarantees every relevance-selected item carries a non-empty `relevanceReason`, + so a query'd response where NO item has one means the filter did not run (fail-open on error, + disabled by config, or an older backend ignoring `query`) and the FULL list was returned — + the model must be told, instead of mistaking the full dump for a relevant shortlist. + """ + filtered = any(ds.get("relevanceReason") for ds in data_sources) + if not filtered: + return ( + "Relevance filtering was unavailable for this request (it may have failed or be " + "disabled), so the FULL unfiltered list of data sources is returned." + ) + + shown = len(data_sources) + total_header = response.headers.get(_TOTAL_HEADER) + total = int(total_header) if total_header and total_header.isdigit() else None + if total is not None and total > shown: + return ( + f"{shown} of {total} available data sources are relevant to this query; the other " + f"{total - shown} were omitted. Call get_data_sources without a query to get the full list." + ) + if total is not None: + return f"All {total} available data sources are relevant to this query." + return ( + "Only the data sources relevant to this query are shown; non-relevant sources were " + "omitted. Call get_data_sources without a query to get the full list." + ) + + # alive_only refers to ready_only. leaved as is for backward compatibility. -async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: +async def get_data_sources( + ctx: Context, alive_only: bool = True, query: str | None = None +) -> str: """ **CALL THIS FIRST**: Gets all available data sources (repositories and workspaces) for the user's account. @@ -27,9 +69,19 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: Args: alive_only: If True (default), returns only data sources that are fully processed and ready for use. If False, returns all data sources regardless of processing state. + query: Optional. The user's initial intent/task in natural language (e.g. "add OAuth to + checkout"). When provided, the backend runs an agentic relevance filter and returns + ONLY the data sources relevant to that intent, each with a `relevanceReason` + explaining why. This is the user's GOAL — distinct from `searchTerm` (a substring + name filter). Omit it to get the full list. Pass it whenever you + know what the user is trying to accomplish, to keep the returned list focused. Returns: - A compact JSON array of available data sources with the following fields for each: + Without `query`: a compact JSON array of available data sources. + With `query`: a JSON object {"dataSources": [...], "message": "..."} where `message` tells + you whether sources were omitted as non-relevant (and how many of the total), that every + available source was relevant, or that relevance filtering was unavailable and the FULL + list is returned. Each data source has the following fields: - id: Unique identifier for the data source - name: Human-readable name - CRITICAL for matching with current working directory name - description: Summary of codebase contents - CRITICAL for identifying if this matches your @@ -37,6 +89,7 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: - type: The type of data source ("Repository" or "Workspace") - url: Repository URL (for Repository type only) - useful for matching with git remote - state: The processing state of the data source (if alive_only=false) + - relevanceReason: Why this source is relevant to `query` (present ONLY when `query` was supplied) Use name + description + url together to determine if a repository is the CURRENT one you're working in versus an EXTERNAL repository. @@ -92,12 +145,17 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: "X-CodeAlive-Client": "fastmcp", } + # Thread the user's intent as the `query` param when present so the backend relevance + # filter runs. Omitted entirely otherwise, so the request is unchanged for legacy callers + # (and an older backend that ignores `query` simply returns the full list). + params = {"query": query} if query else None + # Log the request full_url = urljoin(context.base_url, endpoint) request_id = log_api_request("GET", full_url, headers) # Make API request - response = await context.client.get(endpoint, headers=headers) + response = await context.client.get(endpoint, headers=headers, params=params) # Log the response log_api_response(response, request_id) @@ -107,27 +165,47 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: # Parse and format the response data_sources = response.json() - # If no data sources found, return an empty JSON array with a hint + # If no data sources found, return an empty JSON array with a hint. With a `query`, an empty + # result means "nothing relevant to this intent" (sources DO exist) — a distinct message from + # the no-sources-at-all case, so the model doesn't tell the user to add a repository. if not data_sources or len(data_sources) == 0: + message = ( + "No data sources are relevant to this query. Try a broader query, or call " + "get_data_sources without a query to see the full list." + if query + else "No data sources found. Please add a repository or workspace to CodeAlive before using this API." + ) return json.dumps( - { - "dataSources": [], - "message": "No data sources found. Please add a repository or workspace to CodeAlive before using this API.", - }, + {"dataSources": [], "message": message}, separators=(",", ":"), ) # Remove repositoryIds from workspace data sources for data_source in data_sources: - if data_source.get("type") == "Workspace" and "repositoryIds" in data_source: + if ( + data_source.get("type") == "Workspace" + and "repositoryIds" in data_source + ): del data_source["repositoryIds"] - # Return compact JSON + if query: + return json.dumps( + { + "dataSources": data_sources, + "message": _relevance_message(data_sources, response), + }, + separators=(",", ":"), + ) + + # Return compact JSON (no query → legacy bare array, byte-for-byte unchanged) return json.dumps(data_sources, separators=(",", ":")) except (httpx.HTTPStatusError, Exception) as e: await handle_api_error( - ctx, e, "retrieving data sources", method=_TOOL_NAME, + ctx, + e, + "retrieving data sources", + method=_TOOL_NAME, recovery_hints={ # 422 means *some* sources are still indexing — surface alive_only=false as the next step 422: ( From eba206e3582f9284df9e7206a2cb37400d716f6b Mon Sep 17 00:00:00 2001 From: sciapanCA Date: Sat, 6 Jun 2026 21:37:55 +0200 Subject: [PATCH 2/2] Improve parsing of total header --- src/tests/test_datasources.py | 17 +++++++++++++++++ src/tools/datasources.py | 7 +++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/tests/test_datasources.py b/src/tests/test_datasources.py index d82c9c3..c79b640 100644 --- a/src/tests/test_datasources.py +++ b/src/tests/test_datasources.py @@ -276,6 +276,23 @@ async def test_get_data_sources_filtered_hint_without_total_header(mock_get_api_ assert "without a query" in payload["message"].lower() +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_filtered_hint_with_malformed_total_header(mock_get_api_key): + """A malformed total header is treated as absent rather than raising.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + headers={"X-CodeAlive-Total-Data-Sources": "not-a-number"}, + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert "omitted" in payload["message"].lower() + assert "without a query" in payload["message"].lower() + + @pytest.mark.asyncio @patch('tools.datasources.get_api_key_from_context') async def test_get_data_sources_all_relevant_hint_reports_no_omission(mock_get_api_key): diff --git a/src/tools/datasources.py b/src/tools/datasources.py index 729de0d..672bf7a 100644 --- a/src/tools/datasources.py +++ b/src/tools/datasources.py @@ -37,8 +37,11 @@ def _relevance_message(data_sources: list, response) -> str: ) shown = len(data_sources) - total_header = response.headers.get(_TOTAL_HEADER) - total = int(total_header) if total_header and total_header.isdigit() else None + try: + total = int(response.headers.get(_TOTAL_HEADER)) + except (TypeError, ValueError): + # Header absent (TypeError on int(None)) or malformed (ValueError). + total = None if total is not None and total > shown: return ( f"{shown} of {total} available data sources are relevant to this query; the other "