Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 184 additions & 1 deletion src/tests/test_datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,187 @@ async def test_get_data_sources_handles_missing_repository_ids(mock_get_api_key)
workspace = data_sources[0]
assert workspace["id"] == "workspace-1"
assert workspace["name"] == "Test Workspace"
assert "repositoryIds" not in workspace
assert "repositoryIds" not in workspace


def _ctx_with_response(json_return, headers=None):
"""Builds a mocked Context whose client.get returns a response with the given JSON body."""
mock_ctx = MagicMock(spec=Context)
mock_ctx.info = AsyncMock()
mock_ctx.warning = AsyncMock()
mock_ctx.error = AsyncMock()

mock_response = MagicMock()
mock_response.json.return_value = json_return
mock_response.headers = headers or {}
mock_response.raise_for_status = MagicMock()

mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)

mock_lifespan_context = MagicMock()
mock_lifespan_context.base_url = "https://api.example.com"
mock_lifespan_context.client = mock_client
mock_ctx.request_context.lifespan_context = mock_lifespan_context
return mock_ctx, mock_client


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_with_query_passes_query_param(mock_get_api_key):
"""When a query is supplied, it is forwarded to the listing endpoint as the `query` param."""
mock_get_api_key.return_value = "test-key"
mock_ctx, mock_client = _ctx_with_response([
{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "handles OAuth"},
])

await get_data_sources(mock_ctx, alive_only=True, query="add OAuth to checkout")

call_args = mock_client.get.call_args
assert call_args.args[0] == "/api/datasources/ready"
assert call_args.kwargs["params"] == {"query": "add OAuth to checkout"}


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_without_query_sends_no_query_param(mock_get_api_key):
"""Without a query, no `query` param is sent (legacy behavior unchanged)."""
mock_get_api_key.return_value = "test-key"
mock_ctx, mock_client = _ctx_with_response([
{"id": "repo-1", "name": "Repo", "type": "Repository"},
])

await get_data_sources(mock_ctx, alive_only=True)

call_args = mock_client.get.call_args
assert call_args.kwargs.get("params") is None


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_surfaces_relevance_reason(mock_get_api_key):
"""relevanceReason is preserved per item for the client (wrapped shape when query is set)."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response([
{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "implements the checkout flow"},
])

result = await get_data_sources(mock_ctx, alive_only=True, query="checkout")

payload = result
assert payload["dataSources"][0]["relevanceReason"] == "implements the checkout flow"


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_filtered_hint_reports_total_and_omitted(mock_get_api_key):
"""Filtered success surfaces how many sources exist beyond the shown subset and how to get them."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response(
[{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}],
headers={"X-CodeAlive-Total-Data-Sources": "25"},
)

result = await get_data_sources(mock_ctx, alive_only=True, query="checkout")

payload = result
assert len(payload["dataSources"]) == 1
assert "1 of 25" in payload["message"]
assert "omitted" in payload["message"].lower()
assert "without a query" in payload["message"].lower()


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_filtered_hint_without_total_header(mock_get_api_key):
"""Filtered success without the total header still hints that sources were omitted."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response(
[{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}],
)

result = await get_data_sources(mock_ctx, alive_only=True, query="checkout")

payload = result
assert "omitted" in payload["message"].lower()
assert "without a query" in payload["message"].lower()


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_filtered_hint_with_malformed_total_header(mock_get_api_key):
"""A malformed total header is treated as absent rather than raising."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response(
[{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}],
headers={"X-CodeAlive-Total-Data-Sources": "not-a-number"},
)

result = await get_data_sources(mock_ctx, alive_only=True, query="checkout")

payload = result
assert "omitted" in payload["message"].lower()
assert "without a query" in payload["message"].lower()


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_all_relevant_hint_reports_no_omission(mock_get_api_key):
"""When every available source is relevant, the hint says so instead of claiming omissions."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response(
[{"id": "repo-1", "name": "Repo", "type": "Repository", "relevanceReason": "checkout flow"}],
headers={"X-CodeAlive-Total-Data-Sources": "1"},
)

result = await get_data_sources(mock_ctx, alive_only=True, query="checkout")

payload = result
assert "all 1" in payload["message"].lower()
assert "omitted" not in payload["message"].lower()


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_failopen_hint_when_no_reasons_present(mock_get_api_key):
"""Query supplied but no item carries relevanceReason → the filter did not run (fail-open,
disabled, or an older backend); the hint must say the FULL list is returned."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response([
{"id": "repo-1", "name": "Repo", "type": "Repository"},
{"id": "repo-2", "name": "Other", "type": "Repository"},
])

result = await get_data_sources(mock_ctx, alive_only=True, query="checkout")

payload = result
assert len(payload["dataSources"]) == 2
assert "unavailable" in payload["message"].lower()
assert "full" in payload["message"].lower()


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_empty_with_query_returns_no_relevant_hint(mock_get_api_key):
"""Empty result WITH a query returns a 'no relevant' hint, not 'add a repository'."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response([])

result = await get_data_sources(mock_ctx, alive_only=True, query="something unrelated")

assert result["dataSources"] == []
assert "relevant" in result["hint"].lower()
assert "add a repository" not in result["hint"].lower()


@pytest.mark.asyncio
@patch('tools.datasources.get_api_key_from_context')
async def test_get_data_sources_empty_without_query_keeps_add_repository_hint(mock_get_api_key):
"""Empty result WITHOUT a query keeps the existing 'add a repository' hint."""
mock_get_api_key.return_value = "test-key"
mock_ctx, _ = _ctx_with_response([])

result = await get_data_sources(mock_ctx, alive_only=True)

assert result["dataSources"] == []
assert "add a repository" in result["hint"].lower()
93 changes: 86 additions & 7 deletions src/tools/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,55 @@
import httpx
from fastmcp import Context

from core import CodeAliveContext, get_api_key_from_context, log_api_request, log_api_response
from core import (
Comment thread
sciapanCA marked this conversation as resolved.
CodeAliveContext,
get_api_key_from_context,
log_api_request,
log_api_response,
)
from utils import handle_api_error

# MCP tool/method name surfaced in every error/log message from this module.
_TOOL_NAME = "get_data_sources"

# Pre-filter scoped candidate count, emitted by the backend only on relevance-filtered requests.
_TOTAL_HEADER = "X-CodeAlive-Total-Data-Sources"


def _relevance_message(data_sources: list, response) -> str:
"""Builds the hint accompanying a query'd (relevance-filtered) result.

The backend guarantees every relevance-selected item carries a non-empty `relevanceReason`,
so a query'd response where NO item has one means the filter did not run (fail-open on error,
disabled by config, or an older backend ignoring `query`) and the FULL list was returned —
the model must be told, instead of mistaking the full dump for a relevant shortlist.
"""
filtered = any(ds.get("relevanceReason") for ds in data_sources)
if not filtered:
return (
"Relevance filtering was unavailable for this request (it may have failed or be "
"disabled), so the FULL unfiltered list of data sources is returned."
)

shown = len(data_sources)
try:
total = int(response.headers.get(_TOTAL_HEADER))
except (TypeError, ValueError):
# Header absent (TypeError on int(None)) or malformed (ValueError).
total = None
if total is not None and total > shown:
return (
f"{shown} of {total} available data sources are relevant to this query; the other "
f"{total - shown} were omitted. Call get_data_sources without a query to get the full list."
)
if total is not None:
return f"All {total} available data sources are relevant to this query."
return (
"Only the data sources relevant to this query are shown; non-relevant sources were "
"omitted. Call get_data_sources without a query to get the full list."
)


# Hint embedded in every successful response. Mirrors the convention used by
# the search tools (see _SEARCH_HINT in utils/response_transformer.py): the
# response is always in front of the model when it picks the next step, so we
Expand All @@ -31,9 +74,19 @@
"being indexed."
)

# Empty result WITH a query means "nothing relevant to this intent" (sources DO exist) —
# a distinct hint from the no-sources-at-all case, so the model doesn't tell the user
# to add a repository.
_DATASOURCES_EMPTY_QUERY_HINT = (
"No data sources are relevant to this query. Try a broader query, or call "
"get_data_sources without a query to see the full list."
)


# alive_only refers to ready_only. leaved as is for backward compatibility.
async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, Any]:
async def get_data_sources(
ctx: Context, alive_only: bool = True, query: str | None = None
) -> Dict[str, Any]:
Comment thread
sciapanCA marked this conversation as resolved.
"""
**CALL THIS FIRST**: Gets all available data sources (repositories and workspaces) for the user's account.

Expand All @@ -47,10 +100,20 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A
Args:
alive_only: If True (default), returns only data sources that are fully processed and ready for use.
If False, returns all data sources regardless of processing state.
query: Optional. The user's initial intent/task in natural language (e.g. "add OAuth to
checkout"). When provided, the backend runs an agentic relevance filter and returns
ONLY the data sources relevant to that intent, each with a `relevanceReason`
explaining why. This is the user's GOAL — distinct from `searchTerm` (a substring
name filter). Omit it to get the full list. Pass it whenever you
know what the user is trying to accomplish, to keep the returned list focused.

Returns:
{"dataSources": [...], "hint": "..."}

With `query`, the object also carries a `message` field telling you whether sources
were omitted as non-relevant (and how many of the total), that every available source
was relevant, or that relevance filtering was unavailable and the FULL list is returned.

Each entry in `dataSources` carries:
- id: Unique identifier for the data source
- name: Human-readable name - CRITICAL for matching with current working directory name
Expand All @@ -59,6 +122,7 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A
- type: The type of data source ("Repository" or "Workspace")
- url: Repository URL (for Repository type only) - useful for matching with git remote
- state: The processing state of the data source (if alive_only=false)
- relevanceReason: Why this source is relevant to `query` (present ONLY when `query` was supplied)

The `hint` field reminds you how to use the result and how to distinguish
the CURRENT repository from EXTERNAL ones.
Expand Down Expand Up @@ -117,12 +181,17 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A
"X-CodeAlive-Client": "fastmcp",
}

# Thread the user's intent as the `query` param when present so the backend relevance
# filter runs. Omitted entirely otherwise, so the request is unchanged for legacy callers
# (and an older backend that ignores `query` simply returns the full list).
params = {"query": query} if query else None

# Log the request
full_url = urljoin(context.base_url, endpoint)
request_id = log_api_request("GET", full_url, headers)

# Make API request
response = await context.client.get(endpoint, headers=headers)
response = await context.client.get(endpoint, headers=headers, params=params)

# Log the response
log_api_response(response, request_id)
Expand All @@ -133,19 +202,29 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> Dict[str, A
data_sources = response.json()

if not data_sources or len(data_sources) == 0:
return {"dataSources": [], "hint": _DATASOURCES_EMPTY_HINT}
hint = _DATASOURCES_EMPTY_QUERY_HINT if query else _DATASOURCES_EMPTY_HINT
return {"dataSources": [], "hint": hint}

# Remove repositoryIds from workspace data sources
for data_source in data_sources:
if data_source.get("type") == "Workspace" and "repositoryIds" in data_source:
if (
data_source.get("type") == "Workspace"
and "repositoryIds" in data_source
):
del data_source["repositoryIds"]

# FastMCP serializes via pydantic_core.to_json, which preserves UTF-8.
return {"dataSources": data_sources, "hint": _DATASOURCES_HINT}
result: Dict[str, Any] = {"dataSources": data_sources, "hint": _DATASOURCES_HINT}
if query:
result["message"] = _relevance_message(data_sources, response)
return result

except (httpx.HTTPStatusError, Exception) as e:
await handle_api_error(
ctx, e, "retrieving data sources", method=_TOOL_NAME,
ctx,
e,
"retrieving data sources",
method=_TOOL_NAME,
recovery_hints={
# 422 means *some* sources are still indexing — surface alive_only=false as the next step
422: (
Expand Down
Loading