diff --git a/src/tests/test_datasources.py b/src/tests/test_datasources.py index 1adefa6..c79b640 100644 --- a/src/tests/test_datasources.py +++ b/src/tests/test_datasources.py @@ -170,4 +170,187 @@ async def test_get_data_sources_handles_missing_repository_ids(mock_get_api_key) workspace = data_sources[0] assert workspace["id"] == "workspace-1" assert workspace["name"] == "Test Workspace" - assert "repositoryIds" not in workspace \ No newline at end of file + assert "repositoryIds" not in workspace + + +def _ctx_with_response(json_return, headers=None): + """Builds a mocked Context whose client.get returns a response with the given JSON body.""" + mock_ctx = MagicMock(spec=Context) + mock_ctx.info = AsyncMock() + mock_ctx.warning = AsyncMock() + mock_ctx.error = AsyncMock() + + mock_response = MagicMock() + mock_response.json.return_value = json_return + mock_response.headers = headers or {} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + + mock_lifespan_context = MagicMock() + mock_lifespan_context.base_url = "https://api.example.com" + mock_lifespan_context.client = mock_client + mock_ctx.request_context.lifespan_context = mock_lifespan_context + return mock_ctx, mock_client + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_with_query_passes_query_param(mock_get_api_key): + """When a query is supplied, it is forwarded to the listing endpoint as the `query` param.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, mock_client = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "handles OAuth"}, + ]) + + await get_data_sources(mock_ctx, alive_only=True, query="add OAuth to checkout") + + call_args = mock_client.get.call_args + assert call_args.args[0] == "/api/datasources/ready" + assert call_args.kwargs["params"] == {"query": "add OAuth to checkout"} + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_without_query_sends_no_query_param(mock_get_api_key): + """Without a query, no `query` param is sent (legacy behavior unchanged).""" + mock_get_api_key.return_value = "test-key" + mock_ctx, mock_client = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository"}, + ]) + + await get_data_sources(mock_ctx, alive_only=True) + + call_args = mock_client.get.call_args + assert call_args.kwargs.get("params") is None + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_surfaces_relevance_reason(mock_get_api_key): + """relevanceReason is preserved per item for the client (wrapped shape when query is set).""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "implements the checkout flow"}, + ]) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert payload["dataSources"][0]["relevanceReason"] == "implements the checkout flow" + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_filtered_hint_reports_total_and_omitted(mock_get_api_key): + """Filtered success surfaces how many sources exist beyond the shown subset and how to get them.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + headers={"X-CodeAlive-Total-Data-Sources": "25"}, + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert len(payload["dataSources"]) == 1 + assert "1 of 25" in payload["message"] + assert "omitted" in payload["message"].lower() + assert "without a query" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_filtered_hint_without_total_header(mock_get_api_key): + """Filtered success without the total header still hints that sources were omitted.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert "omitted" in payload["message"].lower() + assert "without a query" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_filtered_hint_with_malformed_total_header(mock_get_api_key): + """A malformed total header is treated as absent rather than raising.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + headers={"X-CodeAlive-Total-Data-Sources": "not-a-number"}, + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert "omitted" in payload["message"].lower() + assert "without a query" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_all_relevant_hint_reports_no_omission(mock_get_api_key): + """When every available source is relevant, the hint says so instead of claiming omissions.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response( + [{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}], + headers={"X-CodeAlive-Total-Data-Sources": "1"}, + ) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert "all 1" in payload["message"].lower() + assert "omitted" not in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_failopen_hint_when_no_reasons_present(mock_get_api_key): + """Query supplied but no item carries relevanceReason → the filter did not run (fail-open, + disabled, or an older backend); the hint must say the FULL list is returned.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([ + {"id": "repo-1", "name": "Repo", "type": "Repository"}, + {"id": "repo-2", "name": "Other", "type": "Repository"}, + ]) + + result = await get_data_sources(mock_ctx, alive_only=True, query="checkout") + + payload = result + assert len(payload["dataSources"]) == 2 + assert "unavailable" in payload["message"].lower() + assert "full" in payload["message"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_empty_with_query_returns_no_relevant_hint(mock_get_api_key): + """Empty result WITH a query returns a 'no relevant' hint, not 'add a repository'.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([]) + + result = await get_data_sources(mock_ctx, alive_only=True, query="something unrelated") + + assert result["dataSources"] == [] + assert "relevant" in result["hint"].lower() + assert "add a repository" not in result["hint"].lower() + + +@pytest.mark.asyncio +@patch('tools.datasources.get_api_key_from_context') +async def test_get_data_sources_empty_without_query_keeps_add_repository_hint(mock_get_api_key): + """Empty result WITHOUT a query keeps the existing 'add a repository' hint.""" + mock_get_api_key.return_value = "test-key" + mock_ctx, _ = _ctx_with_response([]) + + result = await get_data_sources(mock_ctx, alive_only=True) + + assert result["dataSources"] == [] + assert "add a repository" in result["hint"].lower() \ No newline at end of file diff --git a/src/tools/datasources.py b/src/tools/datasources.py index 2886577..672bf7a 100644 --- a/src/tools/datasources.py +++ b/src/tools/datasources.py @@ -6,12 +6,55 @@ import httpx from fastmcp import Context -from core import CodeAliveContext, get_api_key_from_context, log_api_request, log_api_response +from core import ( + CodeAliveContext, + get_api_key_from_context, + log_api_request, + log_api_response, +) from utils import handle_api_error # MCP tool/method name surfaced in every error/log message from this module. _TOOL_NAME = "get_data_sources" +# Pre-filter scoped candidate count, emitted by the backend only on relevance-filtered requests. +_TOTAL_HEADER = "X-CodeAlive-Total-Data-Sources" + + +def _relevance_message(data_sources: list, response) -> str: + """Builds the hint accompanying a query'd (relevance-filtered) result. + + The backend guarantees every relevance-selected item carries a non-empty `relevanceReason`, + so a query'd response where NO item has one means the filter did not run (fail-open on error, + disabled by config, or an older backend ignoring `query`) and the FULL list was returned — + the model must be told, instead of mistaking the full dump for a relevant shortlist. + """ + filtered = any(ds.get("relevanceReason") for ds in data_sources) + if not filtered: + return ( + "Relevance filtering was unavailable for this request (it may have failed or be " + "disabled), so the FULL unfiltered list of data sources is returned." + ) + + shown = len(data_sources) + try: + total = int(response.headers.get(_TOTAL_HEADER)) + except (TypeError, ValueError): + # Header absent (TypeError on int(None)) or malformed (ValueError). + total = None + if total is not None and total > shown: + return ( + f"{shown} of {total} available data sources are relevant to this query; the other " + f"{total - shown} were omitted. Call get_data_sources without a query to get the full list." + ) + if total is not None: + return f"All {total} available data sources are relevant to this query." + return ( + "Only the data sources relevant to this query are shown; non-relevant sources were " + "omitted. Call get_data_sources without a query to get the full list." + ) + + # Hint embedded in every successful response. Mirrors the convention used by # the search tools (see _SEARCH_HINT in utils/response_transformer.py): the # response is always in front of the model when it picks the next step, so we @@ -31,9 +74,19 @@ "being indexed." ) +# Empty result WITH a query means "nothing relevant to this intent" (sources DO exist) — +# a distinct hint from the no-sources-at-all case, so the model doesn't tell the user +# to add a repository. +_DATASOURCES_EMPTY_QUERY_HINT = ( + "No data sources are relevant to this query. Try a broader query, or call " + "get_data_sources without a query to see the full list." +) + # alive_only refers to ready_only. leaved as is for backward compatibility. -async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, Any]: +async def get_data_sources( + ctx: Context, alive_only: bool = True, query: str | None = None +) -> Dict[str, Any]: """ **CALL THIS FIRST**: Gets all available data sources (repositories and workspaces) for the user's account. @@ -47,10 +100,20 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A Args: alive_only: If True (default), returns only data sources that are fully processed and ready for use. If False, returns all data sources regardless of processing state. + query: Optional. The user's initial intent/task in natural language (e.g. "add OAuth to + checkout"). When provided, the backend runs an agentic relevance filter and returns + ONLY the data sources relevant to that intent, each with a `relevanceReason` + explaining why. This is the user's GOAL — distinct from `searchTerm` (a substring + name filter). Omit it to get the full list. Pass it whenever you + know what the user is trying to accomplish, to keep the returned list focused. Returns: {"dataSources": [...], "hint": "..."} + With `query`, the object also carries a `message` field telling you whether sources + were omitted as non-relevant (and how many of the total), that every available source + was relevant, or that relevance filtering was unavailable and the FULL list is returned. + Each entry in `dataSources` carries: - id: Unique identifier for the data source - name: Human-readable name - CRITICAL for matching with current working directory name @@ -59,6 +122,7 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A - type: The type of data source ("Repository" or "Workspace") - url: Repository URL (for Repository type only) - useful for matching with git remote - state: The processing state of the data source (if alive_only=false) + - relevanceReason: Why this source is relevant to `query` (present ONLY when `query` was supplied) The `hint` field reminds you how to use the result and how to distinguish the CURRENT repository from EXTERNAL ones. @@ -117,12 +181,17 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A "X-CodeAlive-Client": "fastmcp", } + # Thread the user's intent as the `query` param when present so the backend relevance + # filter runs. Omitted entirely otherwise, so the request is unchanged for legacy callers + # (and an older backend that ignores `query` simply returns the full list). + params = {"query": query} if query else None + # Log the request full_url = urljoin(context.base_url, endpoint) request_id = log_api_request("GET", full_url, headers) # Make API request - response = await context.client.get(endpoint, headers=headers) + response = await context.client.get(endpoint, headers=headers, params=params) # Log the response log_api_response(response, request_id) @@ -133,19 +202,29 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A data_sources = response.json() if not data_sources or len(data_sources) == 0: - return {"dataSources": [], "hint": _DATASOURCES_EMPTY_HINT} + hint = _DATASOURCES_EMPTY_QUERY_HINT if query else _DATASOURCES_EMPTY_HINT + return {"dataSources": [], "hint": hint} # Remove repositoryIds from workspace data sources for data_source in data_sources: - if data_source.get("type") == "Workspace" and "repositoryIds" in data_source: + if ( + data_source.get("type") == "Workspace" + and "repositoryIds" in data_source + ): del data_source["repositoryIds"] # FastMCP serializes via pydantic_core.to_json, which preserves UTF-8. - return {"dataSources": data_sources, "hint": _DATASOURCES_HINT} + result: Dict[str, Any] = {"dataSources": data_sources, "hint": _DATASOURCES_HINT} + if query: + result["message"] = _relevance_message(data_sources, response) + return result except (httpx.HTTPStatusError, Exception) as e: await handle_api_error( - ctx, e, "retrieving data sources", method=_TOOL_NAME, + ctx, + e, + "retrieving data sources", + method=_TOOL_NAME, recovery_hints={ # 422 means *some* sources are still indexing — surface alive_only=false as the next step 422: (