diff --git a/CLAUDE.md b/CLAUDE.md
index 1791bd7..2c790ec 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -29,29 +29,31 @@ quantmind/

 Key principle: QuantMind does NOT rebuild Agent runtime, lifecycle hooks, tracing, multi-agent handoff, or tool framework. Those come from `openai-agents`.

-## Current Repository State (transitional, after PR #70 / #73 / #74 / PR4)
+## Current Repository State (after PR #70 / #73 / #74 / #75 / PR5)

 | Module | Status | Notes |
 |--------|--------|-------|
 | `quantmind/knowledge/` | landed (PR3) | data standard with three shapes: `FlattenKnowledge` (`News` / `Earnings` / `PaperKnowledgeCard`), `TreeKnowledge` (`Paper`), `GraphKnowledge` (placeholder); shared base = `BaseKnowledge` with typed `SourceRef` / `ExtractionRef` provenance + `embedding_text()` contract |
 | `quantmind/configs/` | landed (PR3) | `BaseFlowCfg` / `BaseInput` + per-flow cfg + discriminated-union input types |
-| `quantmind/preprocess/` | landed (PR4) | `fetch/` (`fetch_arxiv` / `fetch_url` / `resolve_doi` / `read_local_file` returning `Fetched` / `RawPaper` / `CrossrefMetadata` frozen dataclasses) + `format/` (`pdf_to_markdown` via PyMuPDF, `html_to_markdown` via trafilatura) + `clean.py` (`normalize_unicode` / `collapse_whitespace` / `dedupe_lines`) + `time.py` (`to_utc` / `parse_filing_date` / `business_days_between`); leaf module — only depends on `quantmind.utils` |
+| `quantmind/preprocess/` | landed (PR4) | `fetch/` (`fetch_arxiv` / `fetch_url` / `resolve_doi` / `read_local_file` returning `Fetched` / `RawPaper` / `CrossrefMetadata` frozen dataclasses) + `format/` (`pdf_to_markdown` via PyMuPDF, `html_to_markdown` via trafilatura) + `clean.py` + `time.py`; leaf module — only depends on `quantmind.utils` |
+| `quantmind/flows/` | landed (PR5) | apex layer: `paper_flow` (`PaperInput` → `Paper` via SDK Agent), `batch_run` + `BatchResult` (bounded-concurrency fan-out, `memory=` rejected by design), `_runner.run_with_observability` + `_compose_hooks` + `_archive_run_artifacts` (PR6 stub); only depends on configs/knowledge/preprocess/utils + `agents` SDK |
+| `quantmind/magic.py` | landed (PR5) | `resolve_magic_input(natural_language, *, target_flow, ...) -> (input, cfg)` plus `preview_resolve` debug helper; introspects flow signatures and runs a lightweight resolver Agent with `output_type=ResolvedFlowConfig[InputT, CfgT]` |
 | `quantmind/utils/logger.py` | permanent | only general-purpose utility |
-| `quantmind/flow/` | transitional | replaced by `flows/` in PR5 |
-| `quantmind/config/` | transitional | superseded by `quantmind/configs/` (PR3); deletion when consumers (`flow/`, `llm/`) migrate in PR5 |
-| `quantmind/llm/` | transitional | deleted in PR5 (use SDK + `openai` directly) |
-| `quantmind/models/{content,paper,analysis}.py` | transitional | superseded by `quantmind/knowledge/` (PR3); deletion when consumers (`flow/`) migrate in PR5 |
-
-`quantmind/parsers/`, `quantmind/sources/`, and `quantmind/utils/tmp.py`
-were removed in PR4 (replaced by `preprocess/format/`, `preprocess/fetch/`,
-and deletion respectively).
-
-Transitional modules are excluded from `basedpyright` AND from
-`coverage.run` (see `pyproject.toml`) to keep the harness green during
-migration. New modules (`knowledge/`, `configs/`, `preprocess/`, `flows/`,
-`mind/`, `magic.py`) are auto-included at standard mode and gated by
-`import-linter` contracts (4 contracts as of PR4) so they cannot
-accidentally pull in a transitional module.
+
+PR5 removed the transitional packages (`quantmind/{flow,llm,config,models}/`
+and their tests under `tests/{config,models}/`); PR4 had already removed
+`quantmind/parsers/`, `quantmind/sources/`, and `quantmind/utils/tmp.py`.
+The codebase has now converged to the five permanent module roots
+(`flows/`, `configs/`, `knowledge/`, `preprocess/`, `mind/`) plus
+`magic.py` and `utils/`.
+
+`basedpyright` runs in standard mode across the whole `quantmind/`
+package — there are no per-module exclusions left. Five `import-linter`
+contracts pin the dependency graph: `utils` and `knowledge` are leaves,
+`configs` only depends on `knowledge`, `preprocess` only depends on
+`utils`, and `flows + magic` is the apex (cannot import the deleted
+transitional packages, which are listed in the contract as a tripwire
+against accidental re-introduction).

 ## Development Commands

@@ -79,9 +81,8 @@ It runs five steps in fixed order, fast-failing on the first error:
 2. `ruff check` — lint (D, E, F, I, W, B, W505) must pass
 3. `basedpyright` — standard-mode type check on permanent + new modules
 4. `lint-imports` — architectural boundary contracts must hold
-5. `pytest --cov` — tests pass with ≥ 65% branch coverage (will ratchet up
-   to 75%+ in PR5 once `flow/` / `llm/` / `config/` and the transitional
-   `models/*.py` are removed)
+5. `pytest --cov` — tests pass with ≥ 75% branch coverage (raised from 65
+   in PR5 after the transitional packages were deleted)

 Pre-commit hooks (`.pre-commit-config.yaml`):
 - pre-commit stage: trailing whitespace / EOF / ruff / ruff-format (fast)
@@ -148,7 +149,8 @@ issue instead.
 - ❌ Add a CLI (`argparse`/`typer`/`click`); users run Python runbook scripts
 - ❌ Introduce class-based `BaseFlow` / plugin registry / hook discovery
 - ❌ Wrap `from agents import ...` in a QuantMind-side facade — use the SDK directly
-- ❌ Mix `batch_run` and `memory` (they will be mutually exclusive in MVP; see PR5)
+- ❌ Mix `batch_run` and `memory` (mutually exclusive in MVP; `batch_run` rejects
+  `memory=` at the signature layer — design doc §4.3.5)
 - ❌ Use `Dict[str, Any]` in init functions; use Pydantic models
 - ❌ Add hard deps on observability platforms (Langfuse / Logfire / etc.);
   document integration via `add_trace_processor()` in user-facing cookbook only
@@ -170,8 +172,8 @@ issue instead.
 | #70 (merged) | Clean removal of self-built agent runtime |
 | #73 (merged) | Golden Harness — `scripts/verify.sh` with ruff + basedpyright + import-linter + pytest --cov, plus matching CI |
 | #74 (merged) | `knowledge/` data standard (Flatten / Tree / Graph shapes) + `configs/` skeleton; `openai-agents>=0.14` introduced for `BaseFlowCfg.model_settings` |
-| PR4 (this PR) | `preprocess/` (fetch + format two layers); deletes `parsers/` + `sources/` + `utils/tmp.py`; coverage floor 60→65; 4th import-linter contract (`preprocess` is a leaf) |
-| PR5 | `flows/` + `paper_flow` + `batch_run` + `magic.py`; delete `quantmind/flow/`, `quantmind/llm/`, `quantmind/config/`, `quantmind/models/{content,paper,analysis}.py` |
-| PR6 | `mind/memory/filesystem` MVP + trajectory archive |
+| #75 (merged) | `preprocess/` (fetch + format two layers); deletes `parsers/` + `sources/` + `utils/tmp.py`; coverage floor 60→65; 4th import-linter contract |
+| PR5 (this PR) | `flows/` (`paper_flow` + `batch_run` + `BatchResult` + `_runner`) + `magic.py`; deletes `quantmind/{flow,llm,config,models}/`; coverage floor 65→75; 5th import-linter contract pins `flows + magic` as apex |
+| PR6 | `mind/memory/filesystem` MVP + trajectory archive (fills `_archive_run_artifacts` stub) |
 | PR7 | `mind/store/` + SQLite + `sqlite-vec` MVP; introduces `preprocess/chunk.py` with `tiktoken` |
 | PR8+ | Second flow (news/earnings) / observability cookbook / longer-term modules |
diff --git a/README.md b/README.md
index 33c129c..fb6390f 100644
--- a/README.md
+++ b/README.md
@@ -154,29 +154,80 @@ We use [uv](https://github.com/astral-sh/uv) for fast and reliable Python packag

 ### 📚 Usage Examples

-#### Fetch and format an arXiv paper
+#### Run a single paper through `paper_flow`

 ```python
 import asyncio

-from quantmind.preprocess import fetch_arxiv, pdf_to_markdown
+from quantmind.configs import PaperFlowCfg
+from quantmind.configs.paper import ArxivIdentifier
+from quantmind.flows import paper_flow


 async def main() -> None:
-    raw = await fetch_arxiv("arXiv:2401.12345")
-    markdown = await pdf_to_markdown(raw.bytes)
-    print(f"Title: {raw.title}")
-    print(f"Authors: {', '.join(raw.authors)}")
-    print(markdown[:500])
+    paper = await paper_flow(
+        ArxivIdentifier(id="2401.12345"),
+        cfg=PaperFlowCfg(model="gpt-4o-mini"),
+    )
+    print(paper.model_dump_json(indent=2))
+
+
+asyncio.run(main())
+```
+
+#### Fan out a batch with `batch_run`
+
+```python
+import asyncio
+
+from quantmind.configs import PaperFlowCfg
+from quantmind.configs.paper import ArxivIdentifier
+from quantmind.flows import batch_run, paper_flow
+
+
+async def main() -> None:
+    inputs = [ArxivIdentifier(id=aid) for aid in (
+        "2401.12345", "2401.12346", "2401.12347",
+    )]
+    result = await batch_run(
+        paper_flow,
+        inputs,
+        cfg=PaperFlowCfg(model="gpt-4o-mini"),
+        concurrency=3,
+        on_error="skip",
+        on_progress=lambda done, total: print(f"{done}/{total}"),
+    )
+    print(f"ok={result.success_count} failed={result.failure_count}")
+
+
+asyncio.run(main())
+```
+
+#### Resolve free-form intent with `magic`
+
+```python
+import asyncio
+
+from quantmind.flows import paper_flow
+from quantmind.magic import resolve_magic_input
+
+
+async def main() -> None:
+    inp, cfg = await resolve_magic_input(
+        "Pull arXiv 2401.12345 about cross-sectional momentum; use gpt-4o-mini.",
+        target_flow=paper_flow,
+    )
+    paper = await paper_flow(inp, cfg=cfg)
+    print(paper.model_dump_json(indent=2))


 asyncio.run(main())
 ```

 > **Note**: QuantMind is mid-migration to OpenAI Agents SDK
-> (see [#71](https://github.com/LLMQuant/quant-mind/issues/71)). The high-level
-> flows/storage APIs land in upcoming PRs; for now the `preprocess/` and
-> `knowledge/` layers are stable.
+> (see [#71](https://github.com/LLMQuant/quant-mind/issues/71)). PR5 lands the
+> apex layer (`flows/` + `magic.py`); the remaining work is the `mind/`
+> memory + store layer scheduled for PR6 and PR7.

 ---

@@ -201,13 +252,13 @@ QuantMind is designed with a larger vision: to become a comprehensive intelligen
 The foundation we're building today—starting with papers—will expand to encompass the entire financial information ecosystem.

 > [!NOTE]
-> **Future Conceptual Example:**
+> **Future Conceptual Example (PR6 brings `FilesystemMemory`):**
 >
 > ```python
-> # The future we are building towards
-> from quantmind.flows import paper_flow, batch_run
+> from quantmind.configs.paper import ArxivIdentifier
+> from quantmind.flows import paper_flow
 > from quantmind.knowledge import Paper
-> from quantmind.mind.memory import FilesystemMemory
+> from quantmind.mind.memory import FilesystemMemory  # PR6
 >
 > memory = FilesystemMemory("./mem/factor-research/")
 > for arxiv_id in arxiv_ids:
diff --git a/pyproject.toml b/pyproject.toml
index 04cb42b..50d6fe9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,29 +92,17 @@ convention = "google"
 # ----------------------------------------------------------------------------
 # basedpyright: type checker
 # ----------------------------------------------------------------------------
-# Strategy: explicit include-list of files that survive PR3-PR5 plus the
-# target-architecture module dirs (knowledge/, configs/, preprocess/, flows/,
-# mind/). Transitional modules still scheduled for deletion in PR5
-# (config/, flow/, llm/, models/{analysis,content,paper}.py) are NOT
-# included; they get deleted soon and are not worth modernizing for type
-# checks. parsers/, sources/, utils/tmp.py were removed in PR4.
+# Every module in `quantmind/` is type-checked at "standard" mode. PR5
+# deleted the transitional packages (`config/`, `flow/`, `llm/`,
+# `models/`) so the exclude list is now down to the directories pyright
+# always wants ignored.

 [tool.basedpyright]
 include = ["quantmind"]
-# Transitional modules excluded from type checking — they get deleted in
-# PR5 (per CLAUDE.md "Current Repository State" table). New modules
-# (knowledge/, configs/, preprocess/, flows/, mind/, magic.py) automatically
-# get type-checked at "standard" mode as they land — no further config needed.
 exclude = [
     "**/__pycache__",
     "**/.venv",
     "**/build",
-    "quantmind/config",
-    "quantmind/flow",
-    "quantmind/llm",
-    "quantmind/models/analysis.py",
-    "quantmind/models/content.py",
-    "quantmind/models/paper.py",
 ]
 pythonVersion = "3.10"
 typeCheckingMode = "standard"
@@ -131,8 +119,10 @@ reportIncompatibleVariableOverride = "none"
 # import-linter: architectural boundary contracts
 # ----------------------------------------------------------------------------
 # Encodes the target architecture: utils, knowledge, and preprocess are
-# leaves; configs depends only on knowledge. flows (PR5) and mind (PR6+)
-# get their own contracts when they land.
+# leaves; configs depends only on knowledge; flows + magic is the apex
+# layer (PR5). The transitional packages (config/, flow/, llm/, models/)
+# were deleted in PR5; they remain in the forbidden lists as a tripwire
+# against accidental re-introduction during future refactors.

 [tool.importlinter]
 root_packages = ["quantmind"]
@@ -142,12 +132,10 @@ name = "utils is a leaf (no inbound deps from quantmind packages)"
 type = "forbidden"
 source_modules = ["quantmind.utils"]
 forbidden_modules = [
-    "quantmind.config",
     "quantmind.configs",
-    "quantmind.flow",
+    "quantmind.flows",
     "quantmind.knowledge",
-    "quantmind.llm",
-    "quantmind.models",
+    "quantmind.magic",
     "quantmind.preprocess",
 ]

@@ -156,36 +144,46 @@ name = "knowledge is a leaf (no inbound deps from quantmind packages)"
 type = "forbidden"
 source_modules = ["quantmind.knowledge"]
 forbidden_modules = [
-    "quantmind.config",
     "quantmind.configs",
-    "quantmind.flow",
-    "quantmind.llm",
-    "quantmind.models",
+    "quantmind.flows",
+    "quantmind.magic",
     "quantmind.preprocess",
     "quantmind.utils",
 ]

 [[tool.importlinter.contracts]]
-name = "configs only depends on knowledge (transitional modules forbidden)"
+name = "configs only depends on knowledge"
 type = "forbidden"
 source_modules = ["quantmind.configs"]
 forbidden_modules = [
-    "quantmind.config",
-    "quantmind.flow",
-    "quantmind.llm",
-    "quantmind.models",
+    "quantmind.flows",
+    "quantmind.magic",
     "quantmind.preprocess",
 ]

 [[tool.importlinter.contracts]]
-name = "preprocess only depends on utils (no inbound deps on configs/knowledge/transitional)"
+name = "preprocess only depends on utils (no inbound deps on configs/knowledge/flows)"
 type = "forbidden"
 source_modules = ["quantmind.preprocess"]
 forbidden_modules = [
-    "quantmind.config",
     "quantmind.configs",
-    "quantmind.flow",
+    "quantmind.flows",
     "quantmind.knowledge",
+    "quantmind.magic",
+]
+
+[[tool.importlinter.contracts]]
+name = "flows + magic is apex (no transitional package imports)"
+type = "forbidden"
+source_modules = [
+    "quantmind.flows",
+    "quantmind.magic",
+]
+# These packages were deleted in PR5; the contract guards against any
+# future code re-introducing them under the same names.
+forbidden_modules = [
+    "quantmind.config",
+    "quantmind.flow",
     "quantmind.llm",
     "quantmind.models",
 ]
@@ -196,39 +194,19 @@ forbidden_modules = [

 [tool.pytest.ini_options]
 testpaths = ["tests"]
-# Coverage floor 65% after PR4 deleted parsers/sources (the low-coverage
-# drag) and added preprocess/ at >85% line. Will ratchet to 75% in PR5
-# once flow/, llm/, config/, models/{content,paper,analysis}.py are gone.
+# Coverage floor 75% — PR5 deleted the transitional packages so every
+# remaining module is in the target architecture and is well-tested.
 addopts = [
     "--cov=quantmind",
     "--cov-report=term-missing",
-    "--cov-fail-under=65",
+    "--cov-fail-under=75",
     "-ra",
 ]
 asyncio_mode = "auto"
-filterwarnings = [
-    # Pydantic v1 → v2 transition warnings on transitional model code that
-    # gets removed in PR5. Suppress until those modules are gone.
-    "ignore::DeprecationWarning:pydantic.*",
-    "ignore::pydantic.PydanticDeprecatedSince20",
-]

 [tool.coverage.run]
 source = ["quantmind"]
 branch = true
-# Transitional modules slated for PR5 deletion. They lost their tests in
-# PR4 (parsers/sources tests were dragging the floor down anyway and were
-# the only callers of flow/, llm/, config/parsers, config/sources). Once
-# PR5 deletes them, this omit list goes empty and the floor ratchets to 75.
-omit = [
-    "quantmind/flow/*",
-    "quantmind/llm/*",
-    "quantmind/config/parsers.py",
-    "quantmind/config/sources.py",
-    "quantmind/config/flows.py",
-    "quantmind/config/registry.py",
-    "quantmind/models/analysis.py",
-]

 [tool.coverage.report]
 exclude_lines = [
@@ -240,4 +218,4 @@ exclude_lines = [
 # Branch coverage. Floor moves with the migration:
 # PR3: 60% (parsers/sources still drag the average down)
 # PR4: 65% (parsers/sources gone; preprocess >85%)
-# PR5: 75% target (flow/, llm/, config/, transitional models/ gone)
+# PR5: 75% (flow/, llm/, config/, transitional models/ gone)
diff --git a/quantmind/config/__init__.py b/quantmind/config/__init__.py
deleted file mode 100644
index 51bb433..0000000
--- a/quantmind/config/__init__.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Configuration management for QuantMind."""
-
-from .embedding import EmbeddingConfig
-from .flows import (
-    BaseFlowConfig,
-    SummaryFlowConfig,
-)
-from .llm import LLMConfig
-from .parsers import LlamaParserConfig, PDFParserConfig
-from .settings import (
-    Setting,
-    create_default_config,
-    load_config,
-)
-from .sources import (
-    ArxivSourceConfig,
-    BaseSourceConfig,
-    NewsSourceConfig,
-    WebSourceConfig,
-)
-from .storage import BaseStorageConfig, LocalStorageConfig
-from .taggers import LLMTaggerConfig
-
-__all__ = [
-    # Core Settings
-    "Setting",
-    # LLM Configuration
-    "LLMConfig",
-    "EmbeddingConfig",
-    # Tagger Configurations
-    "LLMTaggerConfig",
-    # Parser Configurations
-    "PDFParserConfig",
-    "LlamaParserConfig",
-    # Source Configurations
-    "BaseSourceConfig",
-    "ArxivSourceConfig",
-    "NewsSourceConfig",
-    "WebSourceConfig",
-    # Storage Configurations
-    "BaseStorageConfig",
-    "LocalStorageConfig",
-    # Flow Configurations
-    "BaseFlowConfig",
-    "SummaryFlowConfig",
-    # Utility Functions
-    "create_default_config",
-    "load_config",
-]
diff --git a/quantmind/config/embedding.py b/quantmind/config/embedding.py
deleted file mode 100644
index 1774d5f..0000000
--- a/quantmind/config/embedding.py
+++ /dev/null
@@ -1,125 +0,0 @@
-"""Embedding configuration for QuantMind."""
-
-from typing import Any, Dict, Optional
-
-from pydantic import BaseModel, Field, field_validator
-
-
-class EmbeddingConfig(BaseModel):
-    """Configuration for EmbeddingBlock."""
-
-    # Model configuration
-    model: str = Field(
-        default="text-embedding-ada-002", description="Embedding model name"
-    )
-
-    # Optional parameters
-    user: Optional[str] = Field(
-        default=None,
-        description="A unique identifier representing your end-user",
-    )
-    dimensions: Optional[int] = Field(
-        default=None,
-        description="The number of dimensions the resulting output embeddings should have. Only supported in OpenAI/Azure text-embedding-3 and later models",
-    )
-    encoding_format: str = Field(
-        default="float",
-        description="The format to return the embeddings in. Can be either 'float' or 'base64'",
-    )
-    timeout: int = Field(
-        default=600,
-        description="The maximum time, in seconds, to wait for the API to respond",
-    )
-    retry_attempts: int = Field(
-        default=3,
-        ge=0,
-        description="The number of retry attempts",
-    )
-    retry_delay: float = Field(
-        default=1.0,
-        ge=0,
-        description="The delay between retries in seconds",
-    )
-    api_base: Optional[str] = Field(
-        default=None,
-        description="The api endpoint you want to call the model with",
-    )
-    api_version: Optional[str] = Field(
-        default=None,
-        description="(Azure-specific) the api version for the call",
-    )
-    api_key: Optional[str] = Field(
-        default=None,
-        description="The API key to authenticate and authorize requests. If not provided, the default API key is used",
-    )
-    api_type: Optional[str] = Field(
-        default=None, description="The type of API to use"
-    )
-
-    @field_validator("model")
-    def validate_model(cls, v):
-        """Validate model name format."""
-        if not v or not isinstance(v, str):
-            raise ValueError("Model name must be a non-empty string")
-        return v.strip()
-
-    @field_validator("api_key")
-    def validate_api_key(cls, v):
-        """Validate API key."""
-        if v is not None and not isinstance(v, str):
-            raise ValueError("API key must be a string")
-        return v
-
-    def get_litellm_params(self) -> Dict[str, Any]:
-        """Get parameters formatted for LiteLLM embedding."""
-        params = {
-            "model": self.model,
-        }
-
-        # Add optional parameters if provided
-        if self.user:
-            params["user"] = self.user
-        if self.dimensions:
-            params["dimensions"] = self.dimensions
-        if self.encoding_format:
-            params["encoding_format"] = self.encoding_format
-        if self.api_base:
-            params["api_base"] = self.api_base
-        if self.api_version:
-            params["api_version"] = self.api_version
-        if self.api_key:
-            params["api_key"] = self.api_key
-        if self.api_type:
-            params["api_type"] = self.api_type
-
-        return params
-
-    def get_provider_type(self) -> str:
-        """Extract provider type from model name."""
-        model_lower = self.model.lower()
-
-        # OpenAI models
-        if model_lower in [
-            "text-embedding-ada-002",
-            "text-embedding-3-small",
-            "text-embedding-3-large",
-        ]:
-            return "openai"
-
-        # Azure models
-        elif "azure" in model_lower:
-            return "azure"
-
-        # Gemini models
-        elif "gemini" in model_lower:
-            return "gemini"
-
-        # Default to openai for unknown models
-        else:
-            return "unknown"
-
-    def create_variant(self, **overrides) -> "EmbeddingConfig":
-        """Create a variant of this config with parameter overrides."""
-        current_dict = self.model_dump()
-        current_dict.update(overrides)
-        return EmbeddingConfig(**current_dict)
diff --git a/quantmind/config/flows.py b/quantmind/config/flows.py
deleted file mode 100644
index 419609d..0000000
--- a/quantmind/config/flows.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""Base flow configuration for QuantMind framework."""
-
-from enum import Enum
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Union
-
-import yaml
-from pydantic import BaseModel, Field
-
-from quantmind.config.llm import LLMConfig
-from quantmind.utils.logger import get_logger
-
-logger = get_logger(__name__)
-
-
-# ===== Base Flow Configuration =====
-class BaseFlowConfig(BaseModel):
-    """Base configuration for flows - only defines resources needed.
-
-    This simplified config focuses on providing resources (LLM blocks and prompt templates)
-    rather than orchestrating flow logic, which is now handled in code.
-    """
-
-    name: str
-    llm_blocks: Dict[str, LLMConfig] = Field(default_factory=dict)
-    prompt_templates: Dict[str, str] = Field(default_factory=dict)
-    prompt_templates_path: Union[str, Path, None] = None
-
-    def model_post_init(self, __context: Any) -> None:
-        """Initialize configuration after dataclass creation."""
-        if self.prompt_templates_path:
-            self._load_prompt_templates()
-
-    def _load_prompt_templates(self):
-        """Load prompt templates from YAML file."""
-        logger.info(
-            f"Loading prompt templates from {self.prompt_templates_path}"
-        )
-
-        path = Path(self.prompt_templates_path)
-        if not path.exists():
-            raise FileNotFoundError(f"Prompt templates file not found: {path}")
-
-        if path.suffix.lower() not in [".yaml", ".yml"]:
-            raise ValueError(
-                f"Prompt templates file must be a YAML file, got: {path.suffix}"
-            )
-
-        with open(path, "r", encoding="utf-8") as f:
-            data = yaml.safe_load(f)
-
-        # Replace current prompt_templates with loaded ones
-        templates = data.get("templates", {})
-        if not templates:
-            raise ValueError(f"No 'templates' section found in {path}")
-
-        self.prompt_templates = templates
-
-
-# ===== Summary Flow Configuration =====
-class ChunkingStrategy(Enum):
-    """Strategy for chunking content.
-
-    Attributes:
-        BY_SIZE: Chunk by size
-        BY_SECTION: Chunk by section
-    """
-
-    BY_SIZE = "by_size"
-    BY_SECTION = "by_section"
-    BY_CUSTOM = "by_custom"
-
-
-class SummaryFlowConfig(BaseFlowConfig):
-    """Configuration for content summary generation flow."""
-
-    use_chunking: bool = True
-    chunk_size: int = 2000
-    chunk_strategy: ChunkingStrategy = ChunkingStrategy.BY_SIZE
-    chunk_custom_strategy: Union[Callable[[str], List[str]], None] = None
-
-    def model_post_init(self, __context: Any) -> None:
-        """Initialize default LLM blocks and templates for summary flow."""
-        # First load prompt templates from path if specified
-        super().model_post_init(__context)
-
-        # Allow BY_SIZE and BY_CUSTOM strategies
-        if self.chunk_strategy not in [
-            ChunkingStrategy.BY_SIZE,
-            ChunkingStrategy.BY_CUSTOM,
-        ]:
-            raise NotImplementedError(
-                f"Chunking strategy {self.chunk_strategy} is not implemented for this flow."
-            )
-
-        if not self.llm_blocks:
-            # Default LLM blocks for the two-stage summary process
-            self.llm_blocks = {
-                "cheap_summarizer": LLMConfig(
-                    model="gpt-4o-mini", temperature=0.3, max_tokens=1000
-                ),
-                "powerful_combiner": LLMConfig(
-                    model="gpt-4o", temperature=0.3, max_tokens=2000
-                ),
-            }
-
-        if not self.prompt_templates:
-            # Default prompt templates
-            self.prompt_templates = {
-                "summarize_chunk_template": (
-                    "You are a financial research expert. Summarize the following content chunk "
-                    "focusing on key insights, methodology, and findings. Keep it concise but comprehensive.\n\n"
-                    "Content:\n{{ chunk_text }}\n\n"
-                    "Summary:"
-                ),
-                "combine_summaries_template": (
-                    "You are a financial research expert. Combine the following chunk summaries "
-                    "into a coherent, comprehensive final summary. Eliminate redundancy and "
-                    "create a well-structured overview.\n\n"
-                    "Chunk Summaries:\n{{ summaries }}\n\n"
-                    "Final Summary:"
-                ),
-            }
-
-
-class PodcastFlowConfig(BaseFlowConfig):
-    """Configuration for podcast generation flow."""
-
-    num_speakers: int = 2
-    speaker_languages: str = "en-us"
-    summary_hint: str
diff --git a/quantmind/config/llm.py b/quantmind/config/llm.py
deleted file mode 100644
index 7e3d491..0000000
--- a/quantmind/config/llm.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""LLM configuration for QuantMind."""
-
-import os
-from typing import Any, Dict, Optional
-
-from pydantic import BaseModel, Field, field_validator
-
-
-class LLMConfig(BaseModel):
-    """Configuration for LLMBlock."""
-
-    # Model configuration
-    model: str = Field(
-        default="gpt-4o", description="LLM model name (LiteLLM format)"
-    )
-    temperature: float = Field(
-        default=0.0, ge=0.0, le=2.0, description="Temperature for generation"
-    )
-    max_tokens: int = Field(
-        default=4000, gt=0, description="Maximum tokens to generate"
-    )
-    top_p: float = Field(
-        default=1.0, ge=0.0, le=1.0, description="Top-p sampling parameter"
-    )
-
-    # API configuration
-    api_key: Optional[str] = Field(
-        default=None, description="API key for the LLM provider"
-    )
-    base_url: Optional[str] = Field(
-        default=None, description="Custom base URL for API"
-    )
-    api_version: Optional[str] = Field(
-        default=None, description="API version (for Azure)"
-    )
-
-    # Request configuration
-    timeout: int = Field(
-        default=60, gt=0, description="Request timeout in seconds"
-    )
-    retry_attempts: int = Field(
-        default=3, ge=0, description="Number of retry attempts"
-    )
-    retry_delay: float = Field(
-        default=1.0, ge=0, description="Delay between retries in seconds"
-    )
-
-    # Additional provider-specific parameters
-    extra_params: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="Additional parameters for the provider",
-    )
-
-    # System configuration
-    system_prompt: Optional[str] = Field(
-        default=None, description="System prompt for the model"
-    )
-    custom_instructions: Optional[str] = Field(
-        default=None, description="Custom instructions to append"
-    )
-
-    @field_validator("model")
-    def validate_model(cls, v):
-        """Validate model name format."""
-        if not v or not isinstance(v, str):
-            raise ValueError("Model name must be a non-empty string")
-        return v.strip()
-
-    @field_validator("api_key")
-    def validate_api_key(cls, v):
-        """Validate API key."""
-        if v is not None and not isinstance(v, str):
-            raise ValueError("API key must be a string")
-        return v
-
-    def get_effective_api_key(self) -> Optional[str]:
-        """Get the effective API key with fallback logic.
-
-        Priority:
-        1. Directly provided api_key
-        2. Environment variable specified in api_key_env_var
-        3. Smart inference based on model provider type
-
-        Returns:
-            The effective API key or None if not found
-        """
-        # Priority 1: Direct API key
-        if self.api_key and "${" not in self.api_key:
-            return self.api_key
-
-        # Priority 2: Smart inference based on provider
-        provider = self.get_provider_type()
-        smart_env_vars = {
-            "openai": ["OPENAI_API_KEY", "OPENAI_KEY"],
-            "anthropic": ["ANTHROPIC_API_KEY", "CLAUDE_API_KEY"],
-            "google": ["GOOGLE_API_KEY", "GEMINI_API_KEY"],
-            "azure": ["AZURE_OPENAI_API_KEY", "AZURE_API_KEY"],
-            "deepseek": ["DEEPSEEK_API_KEY"],
-        }
-
-        if provider in smart_env_vars:
-            for env_var in smart_env_vars[provider]:
-                key = os.getenv(env_var)
-                if key:
-                    return key
-
-        return None
-
-    def get_litellm_params(self) -> Dict[str, Any]:
-        """Get parameters formatted for LiteLLM."""
-        params = {
-            "model": self.model,
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-            "top_p": self.top_p,
-            "timeout": self.timeout,
-        }
-
-        # Add API key if provided
-        effective_api_key = self.get_effective_api_key()
-        if effective_api_key:
-            params["api_key"] = effective_api_key
-
-        # Add base URL if provided
-        if self.base_url:
-            params["base_url"] = self.base_url
-
-        # Add API version if provided (for Azure)
-        if self.api_version:
-            params["api_version"] = self.api_version
-
-        # Add extra parameters
-        params.update(self.extra_params)
-
-        return params
-
-    def get_provider_type(self) -> str:
-        """Extract provider type from model name."""
-        if self.model.startswith("gpt-") or self.model.startswith("openai/"):
-            return "openai"
-        elif self.model.startswith("claude-") or self.model.startswith(
-            "anthropic/"
-        ):
-            return "anthropic"
-        elif self.model.startswith("gemini-") or self.model.startswith(
-            "google/"
-        ):
-            return "google"
-        elif "azure" in self.model.lower():
-            return "azure"
-        elif "ollama" in self.model.lower():
-            return "ollama"
-        elif "deepseek" in self.model.lower():
-            return "deepseek"
-        else:
-            return "unknown"
-
-    def create_variant(self, **overrides) -> "LLMConfig":
-        """Create a variant of this config with parameter overrides."""
-        current_dict = self.model_dump()
-        current_dict.update(overrides)
-        return LLMConfig(**current_dict)
diff --git a/quantmind/config/parsers.py b/quantmind/config/parsers.py
deleted file mode 100644
index 557b50e..0000000
--- a/quantmind/config/parsers.py
+++ /dev/null
@@ -1,195 +0,0 @@
-"""Configuration models for parsers."""
-
-from enum import Enum
-from typing import Any, Dict, List, Optional
-
-from pydantic import BaseModel, Field, field_validator
-
-
-class ParsingMode(str, Enum):
-    """The mode of parsing to use."""
-
-    FAST = "fast"
-    BALANCED = "balanced"
-    PREMIUM = "premium"
-
-
-class ResultType(str, Enum):
-    """The result type for the parser."""
-
-    TXT = "text"
-    MD = "markdown"
-    JSON = "json"
-
-
-class BaseParserConfig(BaseModel):
-    """Base configuration for all parsers."""
-
-    max_file_size_mb: int = Field(default=50, ge=1, le=100)
-    timeout: int = Field(default=120, ge=10, le=600)
-    retry_attempts: int = Field(default=3, ge=0, le=10)
-    enable_caching: bool = Field(default=True)
-
-    class Config:
-        """Pydantic configuration."""
-
-        validate_assignment = True
-        use_enum_values = True
-
-
-class PDFParserConfig(BaseParserConfig):
-    """Configuration for PDF parser."""
-
-    method: str = Field(default="pymupdf")
-    download_pdfs: bool = Field(default=True)
-    extract_images: bool = Field(default=False)
-    extract_tables: bool = Field(default=True)
-
-    @field_validator("method")
-    @classmethod
-    def validate_method(cls, v: str) -> str:
-        """Validate parsing method."""
-        valid_methods = ["pymupdf", "pdfplumber", "marker"]
-        if v not in valid_methods:
-            raise ValueError(f"method must be one of {valid_methods}")
-        return v
-
-
-class LlamaParserConfig(BaseParserConfig):
-    """Configuration for LlamaParser."""
-
-    # API settings
-    api_key: Optional[str] = Field(default=None)
-    result_type: ResultType = Field(default=ResultType.MD)
-    parsing_mode: ParsingMode = Field(default=ParsingMode.FAST)
-
-    # Custom prompts
-    system_prompt: Optional[str] = Field(default=None)
-    system_prompt_append: Optional[str] = Field(default=None)
-
-    # Performance settings
-    num_workers: Optional[int] = Field(default=None, ge=1, le=10)
-    verbose: bool = Field(default=False)
-    language: Optional[str] = Field(default=None)
-
-    # Page selection
-    target_pages: Optional[List[int]] = Field(default=None)
-    split_by_page: bool = Field(default=False)
-
-    # Caching options
-    invalidate_cache: bool = Field(default=False)
-    do_not_cache: bool = Field(default=False)
-
-    # Advanced settings
-    check_interval: Optional[int] = Field(default=None, ge=1, le=60)
-    max_timeout: Optional[int] = Field(default=None, ge=60, le=3600)
-    auto_mode_trigger_on_text_length: Optional[int] = Field(
-        default=None, ge=100, le=10000
-    )
-
-    @field_validator("language")
-    @classmethod
-    def validate_language(cls, v: Optional[str]) -> Optional[str]:
-        """Validate language code."""
-        if v is None:
-            return v
-
-        # Common language codes
-        valid_languages = [
-            "en",
-            "zh",
-            "es",
-            "fr",
-            "de",
-            "ja",
-            "ko",
-            "ru",
-            "ar",
-            "pt",
-            "it",
-        ]
-        if v not in valid_languages:
-            raise ValueError(
-                f"language must be one of {valid_languages} or None"
-            )
-        return v
-
-    @field_validator("target_pages")
-    @classmethod
-    def validate_target_pages(
-        cls, v: Optional[List[int]]
-    ) -> Optional[List[int]]:
-        """Validate target pages."""
-        if v is None:
-            return v
-
-        if not v:
-            raise ValueError("target_pages cannot be empty if provided")
-
-        for page in v:
-            if page == 0:
-                raise ValueError("page numbers start from 1, not 0")
-            if page < -1:
-                raise ValueError("negative page numbers must be -1 or greater")
-
-        return v
-
-    def get_llama_parse_config(self) -> Dict[str, Any]:
-        """Get configuration for LlamaParse initialization.
-
-        Returns:
-            Dictionary with LlamaParse configuration parameters
-        """
-        # Handle enum values properly
-        result_type_value = (
-            self.result_type
-            if isinstance(self.result_type, str)
-            else self.result_type.value
-        )
-        parsing_mode_value = (
-            self.parsing_mode
-            if isinstance(self.parsing_mode, str)
-            else self.parsing_mode.value
-        )
-
-        config = {
-            "result_type": result_type_value,
-        }
-
-        # Set parsing mode
-        if parsing_mode_value == "fast":
-            config["fast_mode"] = True
-            config["premium_mode"] = False
-        elif parsing_mode_value == "balanced":
-            config["fast_mode"] = False
-            config["premium_mode"] = False
-        elif parsing_mode_value == "premium":
-            config["fast_mode"] = False
-            config["premium_mode"] = True
-
-        # Add system prompts if provided
-        if self.system_prompt:
-            config["system_prompt"] = self.system_prompt
-        if self.system_prompt_append:
-            config["system_prompt_append"] = self.system_prompt_append
-
-        # Add optional parameters
-        optional_fields = [
-            "num_workers",
-            "verbose",
-            "language",
-            "target_pages",
-            "split_by_page",
-            "invalidate_cache",
-            "do_not_cache",
-            "check_interval",
-            "max_timeout",
-            "auto_mode_trigger_on_text_length",
-        ]
-
-        for field in optional_fields:
-            value = getattr(self, field)
-            if value is not None:
-                config[field] = value
-
-        return config
diff --git a/quantmind/config/registry.py
b/quantmind/config/registry.py deleted file mode 100644 index 89338a3..0000000 --- a/quantmind/config/registry.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Flow configuration registry for dynamic type resolution.""" - -import importlib -import inspect -from pathlib import Path -from typing import Dict, Optional, Type - -from quantmind.config.flows import BaseFlowConfig -from quantmind.utils.logger import get_logger - -logger = get_logger(__name__) - - -# TODO: Add the corresponding unittests. -class FlowConfigRegistry: - """Registry for flow configuration classes enabling dynamic loading.""" - - _instance: Optional["FlowConfigRegistry"] = None - _registry: Dict[str, Type[BaseFlowConfig]] = {} - - def __new__(cls) -> "FlowConfigRegistry": - """Singleton pattern.""" - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - """Initialize registry with built-in flow types.""" - if not hasattr(self, "_initialized"): - self._register_builtin_flows() - self._initialized = True - - def _register_builtin_flows(self): - """Register built-in flow configuration types.""" - from quantmind.config.flows import ( - BaseFlowConfig, - PodcastFlowConfig, - SummaryFlowConfig, - ) - - self._registry["base"] = BaseFlowConfig - self._registry["summary"] = SummaryFlowConfig - self._registry["podcast"] = PodcastFlowConfig - - logger.debug("Registered built-in flow types: base, summary") - - def register( - self, flow_type: str, config_class: Type[BaseFlowConfig] - ) -> None: - """Register a flow configuration class. 
- - Args: - flow_type: String identifier for the flow type - config_class: Configuration class (must inherit from BaseFlowConfig) - - Raises: - ValueError: If config_class doesn't inherit from BaseFlowConfig - """ - if not issubclass(config_class, BaseFlowConfig): - raise ValueError( - f"Flow config class {config_class.__name__} must inherit from BaseFlowConfig" - ) - - self._registry[flow_type] = config_class - logger.debug( - f"Registered flow type '{flow_type}' -> {config_class.__name__}" - ) - - def get_config_class(self, flow_type: str) -> Type[BaseFlowConfig]: - """Get configuration class for a flow type. - - Args: - flow_type: String identifier for the flow type - - Returns: - Configuration class - - Raises: - KeyError: If flow type is not registered - """ - if flow_type not in self._registry: - raise KeyError(f"Unknown flow type: {flow_type}") - return self._registry[flow_type] - - def list_types(self) -> Dict[str, str]: - """List all registered flow types. - - Returns: - Dictionary mapping flow type to class name - """ - return { - flow_type: config_class.__name__ - for flow_type, config_class in self._registry.items() - } - - def auto_discover_flows(self, search_paths: list[Path]) -> None: - """Auto-discover and register flow configurations from specified paths. - - Args: - search_paths: List of directories to search for flow configurations - """ - for search_path in search_paths: - if not search_path.exists(): - logger.debug(f"Search path does not exist: {search_path}") - continue - - self._discover_flows_in_path(search_path) - - def _discover_flows_in_path(self, path: Path) -> None: - """Discover flows in a specific path.""" - # Look for flow.py files in subdirectories. - # `rglob` on macOS tmpdirs can hit AppTranslocation paths that raise - # OSError mid-iteration; swallow scan errors so registry probing never - # crashes the caller. 
- try: - flow_files = list(path.rglob("flow.py")) - except OSError as e: - logger.warning(f"Failed to scan {path} for flows: {e}") - return - for flow_file in flow_files: - try: - self._load_flow_from_file(flow_file) - except Exception as e: - logger.warning(f"Failed to load flow from {flow_file}: {e}") - - def _load_flow_from_file(self, flow_file: Path) -> None: - """Load flow configuration from a Python file.""" - # Convert path to module name - relative_path = flow_file.relative_to(Path.cwd()) - module_path_parts = list(relative_path.parts[:-1]) + [ - relative_path.stem - ] - module_name = ".".join(module_path_parts) - - try: - # Import the module - module = importlib.import_module(module_name) - - # Look for classes that inherit from BaseFlowConfig - for name, obj in inspect.getmembers(module, inspect.isclass): - if ( - issubclass(obj, BaseFlowConfig) - and obj != BaseFlowConfig - and obj.__module__ == module_name - ): - # Infer flow type from class name (e.g., GreetingFlowConfig -> greeting) - flow_type = self._infer_flow_type(name) - self.register(flow_type, obj) - - except ImportError as e: - logger.warning(f"Could not import module {module_name}: {e}") - - def _infer_flow_type(self, class_name: str) -> str: - """Infer flow type from class name. - - Args: - class_name: Class name (e.g., "GreetingFlowConfig") - - Returns: - Flow type string (e.g., "greeting") - """ - # Remove "FlowConfig" suffix and convert to lowercase - if class_name.endswith("FlowConfig"): - flow_type = class_name[:-10] # Remove "FlowConfig" - elif class_name.endswith("Config"): - flow_type = class_name[:-6] # Remove "Config" - else: - flow_type = class_name - - # Convert from CamelCase to snake_case - import re - - flow_type = re.sub("([A-Z]+)", r"_\1", flow_type).lower().strip("_") - - return flow_type - - -# Global registry instance -flow_registry = FlowConfigRegistry() - - -def register_flow_config(flow_type: str): - """Decorator to register a flow configuration class. 
- - Args: - flow_type: String identifier for the flow type - - Example: - @register_flow_config("greeting") - class GreetingFlowConfig(BaseFlowConfig): - pass - """ - - def decorator(config_class: Type[BaseFlowConfig]): - flow_registry.register(flow_type, config_class) - return config_class - - return decorator diff --git a/quantmind/config/settings.py b/quantmind/config/settings.py deleted file mode 100644 index 9e1170d..0000000 --- a/quantmind/config/settings.py +++ /dev/null @@ -1,418 +0,0 @@ -"""Unified configuration management for QuantMind. - -Simple, type-safe configuration system with YAML loading and environment variable substitution. -""" - -import os -import re -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import yaml -from dotenv import load_dotenv -from pydantic import BaseModel, Field - -from quantmind.config.flows import BaseFlowConfig -from quantmind.config.llm import LLMConfig -from quantmind.config.parsers import LlamaParserConfig, PDFParserConfig -from quantmind.config.registry import flow_registry -from quantmind.config.sources import ( - ArxivSourceConfig, - NewsSourceConfig, - WebSourceConfig, -) -from quantmind.config.storage import LocalStorageConfig -from quantmind.config.taggers import LLMTaggerConfig -from quantmind.utils.logger import get_logger - -logger = get_logger(__name__) - - -class Setting(BaseModel): - """Unified configuration for QuantMind - single instance pattern.""" - - # Component configurations - single instances, not dictionaries - source: Optional[ - Union[ArxivSourceConfig, NewsSourceConfig, WebSourceConfig] - ] = None - parser: Optional[Union[PDFParserConfig, LlamaParserConfig]] = None - tagger: Optional[LLMTaggerConfig] = None - storage: LocalStorageConfig = Field(default_factory=LocalStorageConfig) - flows: Dict[str, BaseFlowConfig] = Field(default_factory=dict) - - # Core configuration - llm: LLMConfig = Field(default_factory=LLMConfig) - - # Global settings - log_level: str = Field( - 
default="INFO", pattern=r"^(DEBUG|INFO|WARNING|ERROR|CRITICAL)$" - ) - - class Config: - """Pydantic model configuration.""" - - validate_assignment = True - extra = "forbid" - - @classmethod - def load_dotenv(cls, dotenv_path: Optional[str] = None) -> bool: - """Load environment variables from .env file. - - Args: - dotenv_path: Path to .env file. If None, auto-discovers .env file. - - Returns: - True if .env file was found and loaded, False otherwise - """ - if dotenv_path: - # Load specific file - env_path = Path(dotenv_path) - if env_path.exists(): - load_dotenv(env_path) - logger.info(f"Loaded environment from {env_path}") - return True - else: - logger.warning(f"Dotenv file not found: {env_path}") - return False - else: - # Auto-discover .env file - current_dir = Path.cwd() - env_paths = [ - current_dir / ".env", - current_dir.parent / ".env", - ] - - for env_path in env_paths: - if env_path.exists(): - load_dotenv(env_path) - logger.info(f"Loaded environment from {env_path}") - return True - - logger.debug("No .env file found") - return False - - @classmethod - def substitute_env_vars(cls, config_dict: Dict[str, Any]) -> Dict[str, Any]: - """Substitute environment variables in configuration values. 
- - Supports syntax: ${ENV_VAR} or ${ENV_VAR:default_value} - """ - - def substitute_value(value: Any) -> Any: - if isinstance(value, str): - # Pattern: ${VAR} or ${VAR:default} - pattern = r"\$\{([^}:]+)(?::([^}]*))?\}" - - def replacer(match): - env_var = match.group(1) - default_val = ( - match.group(2) if match.group(2) is not None else "" - ) - return os.getenv(env_var, default_val) - - return re.sub(pattern, replacer, value) - elif isinstance(value, dict): - return {k: substitute_value(v) for k, v in value.items()} - elif isinstance(value, list): - return [substitute_value(item) for item in value] - else: - return value - - return substitute_value(config_dict) - - @classmethod - def from_yaml( - cls, - config_path: Union[str, Path], - env_file: Optional[str] = None, - auto_discover_flows: bool = True, - ) -> "Setting": - """Load configuration from YAML file with environment variable substitution. - - Args: - config_path: Path to YAML configuration file - env_file: Optional path to .env file - auto_discover_flows: Whether to auto-discover custom flow configurations - - Returns: - Configured Setting instance - - Raises: - FileNotFoundError: If config file doesn't exist - ValueError: If config format is invalid - """ - config_path = Path(config_path) - - if not config_path.exists(): - raise FileNotFoundError( - f"Configuration file not found: {config_path}" - ) - - # Load .env file first - cls.load_dotenv(env_file) - - # Auto-discover custom flows if enabled - if auto_discover_flows: - cls._auto_discover_flows(config_path) - - try: - # Load YAML - with open(config_path, "r", encoding="utf-8") as f: - config_dict = yaml.safe_load(f) - - if not isinstance(config_dict, dict): - raise ValueError("Configuration file must contain a dictionary") - - # Substitute environment variables - config_dict = cls.substitute_env_vars(config_dict) - - # Parse configuration - return cls._parse_config(config_dict) - - except Exception as e: - logger.error( - f"Failed to load 
configuration from {config_path}: {e}" - ) - raise - - @classmethod - def _auto_discover_flows(cls, config_path: Path) -> None: - """Auto-discover custom flow configurations near the config file. - - Args: - config_path: Path to the configuration file - """ - # Search in the same directory as config file and subdirectories - search_paths = [ - config_path.parent, # Same directory as config - config_path.parent / "flows", # flows subdirectory - ] - - # Add additional search paths if they exist - for subdir in ["examples", "custom", "user_flows"]: - potential_path = config_path.parent / subdir - if potential_path.exists(): - search_paths.append(potential_path) - - flow_registry.auto_discover_flows(search_paths) - logger.debug(f"Auto-discovered flows from paths: {search_paths}") - - @classmethod - def _parse_config(cls, config_dict: Dict[str, Any]) -> "Setting": - """Parse configuration dictionary into Setting instance.""" - # Configuration type registry - CONFIG_REGISTRY = { - "source": { - "arxiv": ArxivSourceConfig, - "news": NewsSourceConfig, - "web": WebSourceConfig, - }, - "parser": { - "pdf": PDFParserConfig, - "llama": LlamaParserConfig, - }, - "tagger": { - "llm": LLMTaggerConfig, - }, - "storage": { - "local": LocalStorageConfig, - }, - } - - parsed = {} - - # Parse component configurations - for component_name, type_registry in CONFIG_REGISTRY.items(): - if component_name in config_dict: - component_data = config_dict[component_name] - if isinstance(component_data, dict): - component_type = component_data.get("type") - component_config = component_data.get("config", {}) - - if component_type in type_registry: - config_class = type_registry[component_type] - parsed[component_name] = config_class( - **component_config - ) - else: - logger.warning( - f"Unknown {component_name} type: {component_type}" - ) - - # Parse flows dictionary using registry - if "flows" in config_dict: - flows_dict = {} - flows_config = config_dict["flows"] - if isinstance(flows_config, 
dict): - for flow_name, flow_data in flows_config.items(): - if isinstance(flow_data, dict): - flow_type = flow_data.get("type", "base") - flow_config = flow_data.get("config", {}) - - try: - # Use registry to get config class - config_class = flow_registry.get_config_class( - flow_type - ) - # Add name to config if not present - flow_config.setdefault("name", flow_name) - flows_dict[flow_name] = config_class(**flow_config) - except KeyError: - logger.warning(f"Unknown flow type: {flow_type}") - except Exception as e: - logger.error( - f"Failed to create config for flow '{flow_name}': {e}" - ) - parsed["flows"] = flows_dict - - # Parse other configurations - if "llm" in config_dict: - parsed["llm"] = LLMConfig(**config_dict["llm"]) - - # Copy simple fields - if "log_level" in config_dict: - parsed["log_level"] = config_dict["log_level"] - - return cls(**parsed) - - @classmethod - def create_default(cls) -> "Setting": - """Create default configuration with sensible defaults.""" - return cls( - source=ArxivSourceConfig( - max_results=100, - sort_by="submittedDate", - sort_order="descending", - ), - parser=PDFParserConfig( - method="pymupdf", - download_pdfs=True, - extract_tables=True, - ), - storage=LocalStorageConfig(), - ) - - def save_to_yaml(self, config_path: Union[str, Path]) -> None: - """Save configuration to YAML file. 
- - Args: - config_path: Path to save configuration to - """ - config_path = Path(config_path) - config_dict = self._export_config() - - try: - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config_dict, f, default_flow_style=False, indent=2) - - logger.info(f"Saved configuration to {config_path}") - - except Exception as e: - logger.error(f"Failed to save configuration to {config_path}: {e}") - raise - - def _export_config(self) -> Dict[str, Any]: - """Export configuration to dictionary format suitable for YAML.""" - - def serialize_value(value: Any) -> Any: - """Recursively serialize values.""" - if isinstance(value, Path): - return str(value) - elif isinstance(value, dict): - return {k: serialize_value(v) for k, v in value.items()} - elif isinstance(value, list): - return [serialize_value(item) for item in value] - else: - return value - - def serialize_component(component, component_type_map): - if component is None: - return None - - # Find component type - component_class = type(component) - component_type = None - for type_name, type_class in component_type_map.items(): - if component_class == type_class: - component_type = type_name - break - - if component_type is None: - return None - - # Serialize config, excluding sensitive fields - config_dict = component.model_dump(exclude_none=True) - config_dict.pop("api_key", None) # Remove sensitive data - - # Convert Path objects to strings - config_dict = serialize_value(config_dict) - - return {"type": component_type, "config": config_dict} - - # Type mappings for export - type_maps = { - "source": { - "arxiv": ArxivSourceConfig, - "news": NewsSourceConfig, - "web": WebSourceConfig, - }, - "parser": {"pdf": PDFParserConfig, "llama": LlamaParserConfig}, - "tagger": {"llm": LLMTaggerConfig}, - "storage": {"local": LocalStorageConfig}, - "flow": { - flow_type: flow_registry.get_config_class(flow_type) - for flow_type in flow_registry.list_types() - }, - } - - config_dict = {} - - # Export 
components (excluding flows which are handled separately) - for component_name, type_map in type_maps.items(): - if component_name == "flow": - continue # Handle flows separately - component = getattr(self, component_name, None) - serialized = serialize_component(component, type_map) - if serialized: - config_dict[component_name] = serialized - - # Export flows dictionary - if self.flows: - flows_dict = {} - for flow_name, flow_config in self.flows.items(): - flow_serialized = serialize_component( - flow_config, type_maps["flow"] - ) - if flow_serialized: - flows_dict[flow_name] = flow_serialized - if flows_dict: - config_dict["flows"] = flows_dict - - # Export LLM config (exclude sensitive data) - config_dict["llm"] = self.llm.model_dump(exclude={"api_key"}) - - # Export simple fields - config_dict["log_level"] = self.log_level - - return config_dict - - -# Factory functions for convenience -def load_config( - config_path: Union[str, Path], env_file: Optional[str] = None -) -> Setting: - """Load configuration from YAML file.""" - return Setting.from_yaml(config_path, env_file) - - -def create_default_config() -> Setting: - """Create default configuration.""" - return Setting.create_default() - - -# Export public API -__all__ = [ - "Setting", - "load_config", - "create_default_config", -] diff --git a/quantmind/config/sources.py b/quantmind/config/sources.py deleted file mode 100644 index 41a91dd..0000000 --- a/quantmind/config/sources.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Configuration models for sources.""" - -from pathlib import Path -from typing import Any, Dict, List, Optional - -import arxiv -from pydantic import BaseModel, Field, field_validator - - -class BaseSourceConfig(BaseModel): - """Base configuration for all sources.""" - - max_results: int = Field(default=100, ge=1, le=1000) - timeout: int = Field(default=30, ge=1, le=300) - retry_attempts: int = Field(default=3, ge=0, le=10) - proxies: Optional[dict] = Field(default=None) - - class Config: - 
"""Pydantic configuration.""" - - validate_assignment = True - - -class ArxivSourceConfig(BaseSourceConfig): - """Configuration for ArXiv source.""" - - # API settings - sort_by: str = Field(default="submittedDate") - sort_order: str = Field(default="descending") - - # Download settings - download_pdfs: bool = Field(default=False) - download_dir: Optional[Path] = Field(default=None) - - # Rate limiting - requests_per_second: float = Field(default=1.0, ge=0.1, le=10.0) - - # Content filtering - include_categories: Optional[List[str]] = Field(default=None) - exclude_categories: Optional[List[str]] = Field(default=None) - min_abstract_length: int = Field(default=50, ge=0) - - # Language filtering - languages: Optional[List[str]] = Field(default=None) - - @field_validator("sort_by") - def validate_sort_by(cls, v): - """Validate sort_by field.""" - valid_sorts = ["relevance", "lastUpdatedDate", "submittedDate"] - if v not in valid_sorts: - raise ValueError(f"sort_by must be one of {valid_sorts}") - return v - - @field_validator("sort_order") - def validate_sort_order(cls, v): - """Validate sort_order field.""" - if v not in ["ascending", "descending"]: - raise ValueError("sort_order must be 'ascending' or 'descending'") - return v - - @field_validator("download_dir") - def validate_download_dir(cls, v): - """Validate and create download directory if needed.""" - if v is not None: - v = Path(v) - if not v.exists(): - v.mkdir(parents=True, exist_ok=True) - elif not v.is_dir(): - raise ValueError(f"download_dir must be a directory: {v}") - return v - - @field_validator("include_categories", "exclude_categories") - def validate_categories(cls, v): - """Validate arXiv categories.""" - if v is not None: - # Common arXiv categories for validation - valid_categories = [ - "cs.AI", - "cs.CL", - "cs.CV", - "cs.LG", - "cs.MA", - "cs.NE", - "stat.ML", - "stat.AP", - "stat.CO", - "stat.ME", - "stat.TH", - "q-fin.CP", - "q-fin.EC", - "q-fin.GN", - "q-fin.MF", - "q-fin.PM", - 
"q-fin.PR", - "q-fin.RM", - "q-fin.ST", - "q-fin.TR", - "math.PR", - "math.ST", - "math.OC", - "math.NA", - "econ.EM", - "econ.GN", - "econ.TH", - ] - for cat in v: - if cat not in valid_categories: - # Allow but warn about unknown categories - pass - return v - - def get_arxiv_sort_criterion(self) -> arxiv.SortCriterion: - """Get arXiv sort criterion from config.""" - sort_map = { - "relevance": arxiv.SortCriterion.Relevance, - "lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate, - "submittedDate": arxiv.SortCriterion.SubmittedDate, - } - return sort_map[self.sort_by] - - def get_arxiv_sort_order(self) -> arxiv.SortOrder: - """Get arXiv sort order from config.""" - order_map = { - "ascending": arxiv.SortOrder.Ascending, - "descending": arxiv.SortOrder.Descending, - } - return order_map[self.sort_order] - - -class NewsSourceConfig(BaseSourceConfig): - """Configuration for news sources.""" - - # API settings - api_key: Optional[str] = Field(default=None) - base_url: Optional[str] = Field(default=None) - - # Content filtering - sources: Optional[List[str]] = Field(default=None) - domains: Optional[List[str]] = Field(default=None) - exclude_domains: Optional[List[str]] = Field(default=None) - - # Language and location - language: str = Field(default="en") - country: Optional[str] = Field(default=None) - - @field_validator("language") - def validate_language(cls, v): - """Validate language code.""" - # ISO 639-1 language codes - valid_languages = [ - "en", - "es", - "fr", - "de", - "it", - "pt", - "ru", - "ja", - "ko", - "zh", - ] - if v not in valid_languages: - raise ValueError(f"language must be one of {valid_languages}") - return v - - -class WebSourceConfig(BaseSourceConfig): - """Configuration for web scraping sources.""" - - # Request settings - user_agent: str = Field(default="QuantMind/1.0") - headers: Optional[dict] = Field(default=None) - cookies: Optional[dict] = Field(default=None) - - # Scraping settings - follow_redirects: bool = Field(default=True) - 
verify_ssl: bool = Field(default=True) - - # Content extraction - selectors: Optional[dict] = Field(default=None) - - # Rate limiting - delay_between_requests: float = Field(default=1.0, ge=0.0) - - @field_validator("delay_between_requests") - def validate_delay(cls, v): - """Ensure reasonable delay between requests.""" - if v < 0.1: - raise ValueError( - "delay_between_requests should be at least 0.1 seconds" - ) - return v - - -# Source configuration registry -SOURCE_CONFIGS = { - "arxiv": ArxivSourceConfig, - "news": NewsSourceConfig, - "web": WebSourceConfig, -} - - -def get_source_config( - source_type: str, config_data: Dict[str, Any] -) -> BaseSourceConfig: - """Get configured source instance for source type. - - Args: - source_type: Type of source (e.g., 'arxiv', 'news') - config_data: Configuration data dictionary - - Returns: - Configured source instance - - Raises: - ValueError: If source type is not supported - """ - if source_type not in SOURCE_CONFIGS: - raise ValueError(f"Unsupported source type: {source_type}") - - config_class = SOURCE_CONFIGS[source_type] - return config_class(**config_data) diff --git a/quantmind/config/storage.py b/quantmind/config/storage.py deleted file mode 100644 index 31decc3..0000000 --- a/quantmind/config/storage.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Storage configuration models for QuantMind.""" - -from pathlib import Path - -from pydantic import BaseModel, Field - - -class BaseStorageConfig(BaseModel): - """Base configuration for all storage types.""" - - storage_dir: Path = Field( - default=Path("./data"), description="Base storage directory" - ) - - download_timeout: int = Field( - default=30, description="Timeout in seconds for downloading files" - ) - - -class LocalStorageConfig(BaseStorageConfig): - """Configuration for local file-based storage.""" - - def model_post_init(self, __context): - """Ensure storage directory exists.""" - self.storage_dir = Path(self.storage_dir).expanduser().resolve() - 
self.storage_dir.mkdir(parents=True, exist_ok=True) - - # Create subdirectories - (self.storage_dir / "raw_files").mkdir(exist_ok=True) - (self.storage_dir / "knowledges").mkdir(exist_ok=True) - (self.storage_dir / "embeddings").mkdir(exist_ok=True) - (self.storage_dir / "extra").mkdir(exist_ok=True) - - @property - def raw_files_dir(self) -> Path: - """Directory for raw files (PDFs, etc.).""" - return self.storage_dir / "raw_files" - - @property - def knowledges_dir(self) -> Path: - """Directory for knowledge JSONs.""" - return self.storage_dir / "knowledges" - - @property - def embeddings_dir(self) -> Path: - """Directory for embedding arrays.""" - return self.storage_dir / "embeddings" - - @property - def extra_dir(self) -> Path: - """Directory for extra data.""" - return self.storage_dir / "extra" diff --git a/quantmind/config/taggers.py b/quantmind/config/taggers.py deleted file mode 100644 index 66f0820..0000000 --- a/quantmind/config/taggers.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Configuration models for taggers.""" - -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field - -from quantmind.config.llm import LLMConfig - - -class BaseTaggerConfig(BaseModel): - """Base configuration for all taggers.""" - - max_tags: int = Field(default=5, ge=1, le=10) - meta_info: Optional[Dict[str, Any]] = Field(default=None) - - -class LLMTaggerConfig(BaseTaggerConfig): - """Configuration for LLM-based tagger using LLMConfig composition.""" - - # LLM configuration - using composition pattern to avoid field duplication - llm_config: LLMConfig = Field( - default_factory=LLMConfig, description="LLM configuration" - ) - - # Tagger-specific settings - custom_prompt: Optional[str] = Field( - default=None, description="Custom tagging prompt" - ) - - @classmethod - def create( - cls, - model: str = "gpt-4o", - api_key: Optional[str] = None, - temperature: float = 0.3, - max_tokens: int = 5000, - max_tags: int = 5, - custom_instructions: Optional[str] = None, 
- **kwargs, - ) -> "LLMTaggerConfig": - """Create an LLMTaggerConfig with convenient LLM parameter specification. - - Args: - model: LLM model name - api_key: API key for LLM - temperature: LLM temperature - max_tokens: Maximum tokens - max_tags: Maximum number of tags to generate - custom_instructions: Custom instructions to append to prompts - **kwargs: Additional tagger-specific parameters - - Returns: - Configured LLMTaggerConfig instance - """ - llm_config = LLMConfig( - model=model, - api_key=api_key, - temperature=temperature, - max_tokens=max_tokens, - custom_instructions=custom_instructions, - ) - return cls(llm_config=llm_config, max_tags=max_tags, **kwargs) diff --git a/quantmind/flow/__init__.py b/quantmind/flow/__init__.py deleted file mode 100644 index 8b6b4b3..0000000 --- a/quantmind/flow/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Flow framework components.""" - -from quantmind.flow.base import BaseFlow -from quantmind.flow.summary_flow import SummaryFlow - -__all__ = [ - "BaseFlow", - "SummaryFlow", - "PodcastFlow", -] diff --git a/quantmind/flow/base.py b/quantmind/flow/base.py deleted file mode 100644 index 544673f..0000000 --- a/quantmind/flow/base.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Base flow abstract class for QuantMind framework.""" - -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union - -from jinja2 import Template - -from quantmind.config import BaseFlowConfig, LLMConfig -from quantmind.llm import LLMBlock, create_llm_block -from quantmind.models.content import KnowledgeItem -from quantmind.utils.logger import get_logger - -logger = get_logger(__name__) - - -# Type alias for flow input data -FlowInput = Union[ - KnowledgeItem, List[KnowledgeItem], Dict[str, KnowledgeItem], Any -] - - -class BaseFlow(ABC): - """Abstract base class for all flows providing resource access and orchestration framework. 
- - BaseFlow provides: - - Resource management (LLM blocks and prompt templates) - - Helper methods for rendering prompts - - Abstract run() method for subclass-specific business logic - - Flow subclasses implement the run() method with Python code to define - the specific orchestration logic, conditions, loops, and parallel operations. - """ - - def __init__(self, config: BaseFlowConfig): - """Initialize flow with configuration. - - Args: - config: Flow configuration defining resources (LLM blocks and templates) - """ - self.config = config - self._llm_blocks = self._initialize_llm_blocks(config.llm_blocks) - self._templates = { - name: Template(template_str) - for name, template_str in config.prompt_templates.items() - } - - logger.info( - f"Initialized flow '{config.name}' with {len(self._llm_blocks)} LLM blocks" - ) - - def _initialize_llm_blocks( - self, llm_configs: Dict[str, LLMConfig] - ) -> Dict[str, Union[LLMBlock, None]]: - """Initialize LLM blocks from configurations. - - Args: - llm_configs: Dictionary of LLM configurations - - Returns: - Dictionary of initialized LLM blocks - """ - llm_blocks = {} - for identifier, llm_config in llm_configs.items(): - try: - llm_block = create_llm_block(llm_config) - llm_blocks[identifier] = llm_block - logger.debug( - f"Initialized LLM block '{identifier}' with model: {llm_config.model}" - ) - except Exception as e: - logger.error( - f"Failed to initialize LLM block '{identifier}': {e}" - ) - llm_blocks[identifier] = None - - return llm_blocks - - def _render_prompt(self, template_name: str, **kwargs) -> str: - """Render prompt using specified template and variables. 
- - Args: - template_name: Name of template to use - **kwargs: Template variables - - Returns: - Rendered prompt string - - Raises: - KeyError: If template not found - """ - if template_name not in self._templates: - raise KeyError( - f"Template '{template_name}' not found in flow config" - ) - - template = self._templates[template_name] - return template.render(**kwargs) - - @abstractmethod - def run(self, flow_input: FlowInput) -> Any: - """Execute the flow's business logic orchestration. - - This method must be implemented by subclasses to define the specific - workflow steps, using the available LLM blocks and prompt templates. - - Args: - flow_input: Initial input data (typically KnowledgeItem) - - Returns: - Flow execution result (structure depends on flow) - """ - pass diff --git a/quantmind/flow/podcast_flow.py b/quantmind/flow/podcast_flow.py deleted file mode 100644 index 5e45b8a..0000000 --- a/quantmind/flow/podcast_flow.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Podcast Flow - Generate final podcast scripts from summary input. - -This flow takes a summary and generates a podcast script in JSON format. -""" - -from typing import Any, Dict - -from quantmind.config.flows import PodcastFlowConfig -from quantmind.flow.base import BaseFlow -from quantmind.utils.logger import get_logger - -logger = get_logger(__name__) - - -class PodcastFlow(BaseFlow): - """Flow for generating podcast scripts from summary input.""" - - def __init__(self, config: PodcastFlowConfig): - super().__init__(config) - self.config = config - - def run(self, summary: str) -> Dict[str, Any]: - """Execute the podcast script generation flow. - - Args: - summary: Summary of the podcast content to generate the script from. 
- - Returns: - JSON string containing the podcast script - """ - if summary: - self.config.summary_hint = summary - logger.info("Using input summary.") - else: - logger.warning("No summary provided, using default summary hint.") - - logger.info("Starting podcast script generation flow") - # Generate podcast script - script = self._generate_script(self.config.summary_hint) - - logger.info("Podcast script generation completed") - return script - - def _generate_script(self, summary: str) -> Dict[str, Any]: - """Generate podcast script from summary.""" - script = {} - - main_generator = self._llm_blocks["main_generator"] - main_prompt = self._render_prompt("main_prompt", summary_hint=summary) - main_script = main_generator.generate_text(main_prompt) - script["main"] = main_script - - return script diff --git a/quantmind/flow/summary_flow.py b/quantmind/flow/summary_flow.py deleted file mode 100644 index 0efe76a..0000000 --- a/quantmind/flow/summary_flow.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Content summary generation flow using two-stage approach for QuantMind framework.""" - -from typing import List - -from quantmind.config.flows import ChunkingStrategy -from quantmind.flow.base import BaseFlow -from quantmind.models.content import KnowledgeItem -from quantmind.utils.logger import get_logger - -logger = get_logger(__name__) - - -class SummaryFlow(BaseFlow): - """A two-step built-in summary flow: chunk documents, then combine summaries. - - You can set `use_chunking` to False to use the powerful model directly on the full content or - change the chunking strategy to `ChunkingStrategy.BY_CUSTOM` to use a custom chunking strategy. - - This flow demonstrates the new architecture by implementing a cost-effective - approach: use a cheap model for chunk summarization, then a powerful model - for final combination. - """ - - def run(self, document: KnowledgeItem) -> str: - """Execute the two-stage summary process. 
- - Args: - document: KnowledgeItem to summarize - - Returns: - Final combined summary - """ - logger.info(f"Starting summary flow for: {document.title}") - - content = document.content or "" - if not content: - logger.warning("No content to summarize") - return "No content available for summarization." - - # Two different strategies based on chunking configuration - if self.config.use_chunking: - # Strategy 1: Chunking mode - use cheap model for chunks, powerful for combination - logger.debug("Using chunking mode with two-stage summarization") - - chunks = self._chunk_document(content) - if not chunks: - logger.warning("No chunks generated") - return "No content available for summarization." - - # Use cheap model to summarize each chunk - summarizer_llm = self._llm_blocks["cheap_summarizer"] - chunk_summaries = [] - - for i, chunk in enumerate(chunks): - logger.debug(f"Summarizing chunk {i + 1}/{len(chunks)}") - prompt = self._render_prompt( - "summarize_chunk_template", chunk_text=chunk - ) - summary = summarizer_llm.generate_text(prompt) - if summary: - chunk_summaries.append(summary) - - if not chunk_summaries: - logger.error("Failed to generate any chunk summaries") - return "Failed to summarize content." - - # If only one chunk, return its summary directly - if len(chunk_summaries) == 1: - logger.info( - f"Successfully generated summary for: {document.title}" - ) - return chunk_summaries[0] - - # Use powerful model to combine multiple chunk summaries - combiner_llm = self._llm_blocks["powerful_combiner"] - final_prompt = self._render_prompt( - "combine_summaries_template", - summaries="\n\n".join(chunk_summaries), - ) - - final_summary = combiner_llm.generate_text(final_prompt) - - if final_summary: - logger.info( - f"Successfully generated summary for: {document.title}" - ) - return final_summary - else: - logger.error("Failed to generate final summary") - return "Failed to generate final summary." 
- - else: - # Strategy 2: No chunking - use powerful model directly on full content - logger.debug( - "Using non-chunking mode with direct powerful model summarization" - ) - - combiner_llm = self._llm_blocks["powerful_combiner"] - prompt = self._render_prompt( - "summarize_chunk_template", chunk_text=content - ) - - summary = combiner_llm.generate_text(prompt) - - if summary: - logger.info( - f"Successfully generated summary for: {document.title}" - ) - return summary - else: - logger.error("Failed to generate summary") - return "Failed to summarize content." - - def _chunk_document(self, text: str) -> List[str]: - """Split document into chunks for processing. - - Args: - text: Document text to chunk - - Returns: - List of text chunks - """ - if not text: - return [] - if self.config.chunk_strategy == ChunkingStrategy.BY_CUSTOM: - if self.config.chunk_custom_strategy: - return self.config.chunk_custom_strategy(text) - else: - logger.warning( - "Custom chunking strategy specified but no function provided, falling back to BY_SIZE" - ) - elif self.config.chunk_strategy == ChunkingStrategy.BY_SECTION: - raise NotImplementedError( - "Chunking by section is not implemented for this flow." 
- ) - - # Default to BY_SIZE strategy (already validated in config) - chunk_size = self.config.chunk_size - chunks = [] - - # Simple chunking by character count with word boundary preservation - for i in range(0, len(text), chunk_size): - chunk = text[i : i + chunk_size] - - # Try to end at word boundary if not last chunk - if i + chunk_size < len(text): - last_space = chunk.rfind(" ") - if ( - last_space > chunk_size // 2 - ): # Only if we don't lose too much - chunk = chunk[:last_space] - - chunks.append(chunk.strip()) - - logger.debug( - f"Split document into {len(chunks)} chunks using {self.config.chunk_strategy.value} strategy" - ) - return chunks diff --git a/quantmind/flows/__init__.py b/quantmind/flows/__init__.py new file mode 100644 index 0000000..78dd9bf --- /dev/null +++ b/quantmind/flows/__init__.py @@ -0,0 +1,22 @@ +"""Apex layer — composes configs / knowledge / preprocess on the SDK. + +Each flow function (``paper_flow``, future ``news_flow`` / ``earnings_flow``) +takes a typed input and a ``FlowCfg`` and returns a knowledge item. +Cross-flow utilities live alongside: + +- ``batch_run`` runs any flow over a list of inputs with bounded + concurrency and aggregated results. +- ``BatchResult`` is the shape returned by ``batch_run``. +- ``UnsupportedContentTypeError`` is raised when ``paper_flow`` cannot + route fetched bytes through the format layer. +""" + +from quantmind.flows.batch import BatchResult, batch_run +from quantmind.flows.paper import UnsupportedContentTypeError, paper_flow + +__all__ = [ + "BatchResult", + "UnsupportedContentTypeError", + "batch_run", + "paper_flow", +] diff --git a/quantmind/flows/_runner.py b/quantmind/flows/_runner.py new file mode 100644 index 0000000..1f3b507 --- /dev/null +++ b/quantmind/flows/_runner.py @@ -0,0 +1,137 @@ +"""Internal helpers shared by every flow function. 
+ +`run_with_observability` wraps `Runner.run` with `RunConfig` derived from +`BaseFlowCfg`, composes user-supplied `RunHooks` (the SDK accepts only a +single hooks instance per run), and leaves a no-op call site for the +PR6 trajectory archive. Flow modules call this instead of touching the +SDK directly so observability behaviour stays in one place. +""" + +from typing import Any + +from agents import Agent, RunConfig, RunHooks, Runner + +from quantmind.configs import BaseFlowCfg + + +async def run_with_observability( + agent: Agent[Any], + input: str | list[Any], + *, + cfg: BaseFlowCfg, + memory: object | None = None, + extra_run_hooks: list[RunHooks[Any]], +) -> Any: + """Build `RunConfig` + composed hooks, run the agent, return final output. + + Args: + agent: The Agents SDK ``Agent`` to invoke. + input: Prompt string or pre-built input items. + cfg: Flow configuration. Tracing fields and ``max_turns`` are + forwarded to the SDK; ``workflow_name`` falls back to + ``"quantmind."`` when unset. + memory: PR6 ``Memory`` placeholder. Currently unused at runtime; + the value is forwarded to the trajectory-archive stub so PR6 + can wire it in without changing call sites. + extra_run_hooks: User-supplied hooks. Composed with any + memory-derived hooks (none in PR5) into a single + ``RunHooks`` instance. + + Returns: + ``RunResult.final_output`` typed by the agent's ``output_type``. 
+ """ + workflow_name = cfg.workflow_name or f"quantmind.{agent.name}" + run_cfg = RunConfig( + workflow_name=workflow_name, + trace_metadata=cfg.trace_metadata, + trace_include_sensitive_data=cfg.trace_include_sensitive_data, + tracing_disabled=cfg.tracing_disabled, + ) + hooks = _compose_hooks(_collect_hooks(memory, extra_run_hooks)) + result = await Runner.run( + agent, + input, + run_config=run_cfg, + hooks=hooks, + max_turns=cfg.max_turns, + ) + _archive_run_artifacts(cfg, memory, result) + return result.final_output + + +def _collect_hooks( + memory: object | None, + extras: list[RunHooks[Any]], +) -> list[RunHooks[Any]]: + """Return hooks in run order: memory hooks first (PR6), then extras.""" + hooks: list[RunHooks[Any]] = [] + # PR6 will append `memory.run_hooks()` here when `memory` exposes the + # `Memory` Protocol. PR5 keeps `memory` opaque and contributes no hooks. + del memory + hooks.extend(extras) + return hooks + + +def _compose_hooks( + hooks: list[RunHooks[Any]], +) -> RunHooks[Any] | None: + """Merge multiple `RunHooks` into one (the SDK takes a single instance).""" + if not hooks: + return None + if len(hooks) == 1: + return hooks[0] + return _CompositeRunHooks(hooks) + + +class _CompositeRunHooks(RunHooks[Any]): + """Fan out every lifecycle method to each wrapped hook in order. + + Exceptions from earlier hooks short-circuit the rest by design — hooks + are integral to the run, not best-effort. PR6's archive hook should + catch its own exceptions internally if it wants resilience. 
+ """ + + def __init__(self, inner: list[RunHooks[Any]]) -> None: + self._inner = list(inner) + + async def on_llm_start(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_llm_start(*args, **kwargs) + + async def on_llm_end(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_llm_end(*args, **kwargs) + + async def on_agent_start(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_agent_start(*args, **kwargs) + + async def on_agent_end(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_agent_end(*args, **kwargs) + + async def on_handoff(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_handoff(*args, **kwargs) + + async def on_tool_start(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_tool_start(*args, **kwargs) + + async def on_tool_end(self, *args: Any, **kwargs: Any) -> None: + for h in self._inner: + await h.on_tool_end(*args, **kwargs) + + +def _archive_run_artifacts( + cfg: BaseFlowCfg, + memory: object | None, + result: Any, +) -> None: + """No-op stub. PR6 writes a trajectory record under ``/runs/``. + + Kept as a real call site (rather than commented-out) so PR6 changes + one function body, not the runner's public path. + """ + del cfg, memory, result + return None diff --git a/quantmind/flows/batch.py b/quantmind/flows/batch.py new file mode 100644 index 0000000..9e6c659 --- /dev/null +++ b/quantmind/flows/batch.py @@ -0,0 +1,136 @@ +"""Batch runner — fan a flow function out over many inputs. + +`batch_run` is the single concurrency primitive QuantMind ships in MVP. +It does NOT support `memory=`; for memory-accumulating workflows users +write a serial `for` loop themselves (design doc §4.3.5). This keeps the +batch path stateless and free of cross-run race hazards. 
+""" + +import asyncio +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar + +from quantmind.configs import BaseFlowCfg, BaseInput + +InputT = TypeVar("InputT", bound=BaseInput) +OutputT = TypeVar("OutputT") + + +@dataclass(slots=True) +class BatchResult(Generic[OutputT]): + """Aggregate result of running a flow over many inputs. + + ``results[i]`` is the output for ``inputs[i]`` or ``None`` if that + input failed. ``errors`` carries ``(index, exception)`` for every + failure, sorted by index. ``successes`` and ``failures`` are + convenience views derived from these primary fields. + """ + + total: int + success_count: int + failure_count: int + results: list[OutputT | None] + errors: list[tuple[int, Exception]] + duration_seconds: float + tokens_total: dict[str, int] = field(default_factory=dict) + cost_estimate_usd: float = 0.0 + + @property + def successes(self) -> list[tuple[int, OutputT]]: + """``(index, result)`` for every input that succeeded.""" + return [(i, r) for i, r in enumerate(self.results) if r is not None] + + @property + def failures(self) -> list[tuple[int, Exception]]: + """Alias for ``errors`` to mirror ``successes`` for symmetry.""" + return list(self.errors) + + +async def batch_run( + flow_fn: Callable[..., Awaitable[OutputT]], + inputs: list[InputT], + *, + cfg: BaseFlowCfg | None = None, + concurrency: int = 4, + on_error: Literal["raise", "skip"] = "skip", + on_progress: Callable[[int, int], None] | None = None, + **flow_kwargs: Any, +) -> BatchResult[OutputT]: + """Run ``flow_fn`` over ``inputs`` with bounded concurrency. + + Args: + flow_fn: Any flow function with signature + ``(input, *, cfg, **kwargs) -> Awaitable[OutputT]``. + inputs: Inputs to fan out over. Empty list returns an empty + ``BatchResult`` immediately. + cfg: Shared cfg forwarded to every call. ``None`` lets the flow + use its own default. 
+ concurrency: Maximum number of in-flight calls. Must be ≥ 1. + on_error: ``"raise"`` propagates the first failure (siblings are + not cancelled and keep running); ``"skip"`` records every failure into + ``errors`` and returns the batch normally. + on_progress: Called as ``on_progress(done, total)`` after every + completion (success or failure). Must be cheap and + non-blocking — callbacks are invoked synchronously inside + the worker loop. + **flow_kwargs: Forwarded verbatim to ``flow_fn``. ``memory=`` is + **forbidden** in MVP; passing it raises ``ValueError``. + + Returns: + ``BatchResult`` with ``results`` parallel to ``inputs`` (None for + failures) and ``errors`` sorted by index. + + Raises: + ValueError: If ``memory=`` is passed via ``flow_kwargs``, or if + ``concurrency < 1``. + Exception: Re-raised when ``on_error="raise"`` and any input + fails. The exception is the first one raised by a worker; + other workers are not cancelled and may still be running. + """ + if "memory" in flow_kwargs: + raise ValueError( + "batch_run does not support `memory=` in MVP. For " + "memory-accumulating workflows write a serial loop instead: " + "`for inp in inputs: await flow_fn(inp, cfg=cfg, memory=memory)`. " + "See design doc §4.3.5." + ) + if concurrency < 1: + raise ValueError(f"concurrency must be >= 1, got {concurrency}") + + sem = asyncio.Semaphore(concurrency) + results: list[OutputT | None] = [None] * len(inputs) + errors: list[tuple[int, Exception]] = [] + started = time.monotonic() + done_counter = 0 + + async def run_one(i: int, inp: InputT) -> None: + nonlocal done_counter + async with sem: + try: + results[i] = await flow_fn(inp, cfg=cfg, **flow_kwargs) + except Exception as exc: + errors.append((i, exc)) + if on_error == "raise": + raise + finally: + # asyncio is single-threaded; this increment + read + + # callback all happen synchronously between await points.
+ done_counter += 1 + if on_progress is not None: + on_progress(done_counter, len(inputs)) + + # Same call shape for both modes — `run_one` swallows its own + # exception when on_error="skip", and re-raises when on_error="raise" + # (gather propagates it; siblings are not cancelled and keep running). + await asyncio.gather(*(run_one(i, inp) for i, inp in enumerate(inputs))) + + return BatchResult( + total=len(inputs), + success_count=sum(1 for r in results if r is not None), + failure_count=len(errors), + results=results, + errors=sorted(errors, key=lambda t: t[0]), + duration_seconds=time.monotonic() - started, + ) diff --git a/quantmind/flows/paper.py b/quantmind/flows/paper.py new file mode 100644 index 0000000..5b8373b --- /dev/null +++ b/quantmind/flows/paper.py @@ -0,0 +1,203 @@ +"""Paper extraction flow. + +`paper_flow` ingests one of the ``PaperInput`` discriminated-union +variants, fetches and converts the raw payload to markdown via +``preprocess.fetch`` + ``preprocess.format``, then runs an +``Agent(output_type=Paper)`` to produce a typed ``Paper`` +``TreeKnowledge`` object. + +Customization happens through the configured ``PaperFlowCfg`` (Layer 1) +or the keyword arguments on this function (Layer 2). To swap the whole +flow, fork this file (Layer 3 — design doc §9). +""" + +from typing import Any, TypeVar + +from agents import Agent, RunHooks, Tool + +from quantmind.configs import PaperFlowCfg +from quantmind.configs.paper import ( + ArxivIdentifier, + DoiIdentifier, + HttpUrl, + LocalFilePath, + PaperInput, + RawText, +) +from quantmind.flows._runner import run_with_observability +from quantmind.knowledge import Paper +from quantmind.preprocess.fetch import ( + Fetched, + fetch_arxiv, + fetch_url, + read_local_file, +) +from quantmind.preprocess.format import html_to_markdown, pdf_to_markdown + +P = TypeVar("P", bound=Paper) + +_DEFAULT_INSTRUCTIONS = """\ +You are extracting a research paper into a structured QuantMind ``Paper`` +TreeKnowledge object.
Build the section tree top-down: every node has a +title and a short summary; leaf nodes additionally carry the section +markdown content. Cite supporting passages on each node. + +Honour these flags from the run config: +- extract_methodology={extract_methodology}: when true, every methodology + section becomes its own subtree with a per-step summary. +- extract_limitations={extract_limitations}: when true, surface + limitations as a dedicated top-level child rather than inlining them. +- asset_class_hint={asset_class_hint!r}: when set, prefer this asset + class for ``Paper.asset_classes`` if the paper does not state one + explicitly. + +Set ``as_of`` to the publication date when given; otherwise use today's +date. Set the ``source`` provenance ref using the metadata supplied in +the prompt. +""" + + +class UnsupportedContentTypeError(ValueError): + """Fetched bytes have a content type paper_flow cannot route to a parser.""" + + +async def paper_flow( + input: PaperInput, + *, + cfg: PaperFlowCfg | None = None, + extra_tools: list[Tool] | None = None, + extra_instructions: str | None = None, + output_type: type[P] | None = None, + memory: object | None = None, + extra_run_hooks: list[RunHooks[Any]] | None = None, + extra_input_guardrails: list[Any] | None = None, + extra_output_guardrails: list[Any] | None = None, +) -> P | Paper: + """Extract a ``Paper`` from a typed ``PaperInput``. + + See design doc §4.1 for the rationale on each kwarg. ``memory`` is a + PR6 placeholder — non-None values are accepted but unused in PR5. + + Raises: + UnsupportedContentTypeError: When fetched bytes are not PDF / + HTML / markdown / plain-text. + NotImplementedError: When ``input`` is a ``DoiIdentifier`` (the + unpaywall fallback is its own follow-up issue). 
+ """ + cfg = cfg or PaperFlowCfg() + out_type: type[Paper] = output_type or Paper # type: ignore[assignment] + + raw_md, source_meta = await _fetch_and_format(input) + + # Agent's `model_settings` parameter is non-optional (defaults to a + # fresh ``ModelSettings()``); only forward when cfg has one set. + agent_kwargs: dict[str, Any] = { + "name": "paper_extractor", + "instructions": _compose_instructions( + _DEFAULT_INSTRUCTIONS, extra_instructions, cfg + ), + "model": cfg.model, + "tools": list(extra_tools or []), + "output_type": out_type, + "input_guardrails": list(extra_input_guardrails or []), + "output_guardrails": list(extra_output_guardrails or []), + } + if cfg.model_settings is not None: + agent_kwargs["model_settings"] = cfg.model_settings + agent: Agent[Any] = Agent(**agent_kwargs) + return await run_with_observability( + agent, + _format_input(raw_md, source_meta), + cfg=cfg, + memory=memory, + extra_run_hooks=list(extra_run_hooks or []), + ) + + +async def _fetch_and_format( + input: PaperInput, +) -> tuple[str, dict[str, Any]]: + """Dispatch on the input variant; return (markdown, source metadata).""" + if isinstance(input, ArxivIdentifier): + raw = await fetch_arxiv(input.id) + md = await pdf_to_markdown(raw.bytes) + return md, { + "source": "arxiv", + "arxiv_id": raw.arxiv_id, + "title": raw.title, + "authors": list(raw.authors), + } + if isinstance(input, HttpUrl): + raw = await fetch_url(input.url) + md = await _format_by_content_type(raw) + return md, { + "source": "web", + "url": input.url, + "content_type": raw.content_type, + } + if isinstance(input, LocalFilePath): + raw = await read_local_file(input.path) + md = await _format_by_content_type(raw) + return md, { + "source": "local", + "path": str(input.path), + "content_type": raw.content_type, + } + if isinstance(input, RawText): + return input.text, {"source": "inline"} + if isinstance(input, DoiIdentifier): + # PR4's CrossrefMetadata exposes only `primary_url` (publisher + # landing 
page), not a direct PDF link. Adding the unpaywall + # fallback that turns a DOI into an OA PDF URL is its own + # follow-up issue. + raise NotImplementedError( + "DOI inputs require an OA PDF resolver (unpaywall fallback) " + "which is tracked as a PR4 follow-up. Use ArxivIdentifier or " + "HttpUrl for now." + ) + raise TypeError(f"Unsupported PaperInput variant: {type(input)!r}") + + +async def _format_by_content_type(raw: Fetched) -> str: + """Route a ``Fetched`` payload through the right format helper.""" + ct = (raw.content_type or "").lower() + if ct.startswith("application/pdf"): + return await pdf_to_markdown(raw.bytes) + if ct.startswith("text/html"): + return await html_to_markdown( + raw.bytes.decode("utf-8", errors="replace") + ) + if ct.startswith("text/markdown") or ct.startswith("text/plain"): + return raw.bytes.decode("utf-8", errors="replace") + raise UnsupportedContentTypeError( + f"Unsupported content-type for paper input: {ct!r}" + ) + + +def _compose_instructions( + base: str, extra: str | None, cfg: PaperFlowCfg +) -> str: + """Render the system instructions, appending ``extra`` if provided.""" + instructions = base.format( + extract_methodology=cfg.extract_methodology, + extract_limitations=cfg.extract_limitations, + asset_class_hint=cfg.asset_class_hint, + ) + if extra: + instructions = f"{instructions}\n\nAdditional instructions:\n{extra}" + return instructions + + +def _format_input(raw_md: str, source_meta: dict[str, Any]) -> str: + """Concatenate metadata + content into the prompt the agent sees.""" + lines: list[str] = [] + for key, value in source_meta.items(): + if value is None: + continue + if isinstance(value, (list, tuple)): + value = ", ".join(map(str, value)) + lines.append(f"{key}: {value}") + header = "\n".join(lines) + return ( + f"--- Source metadata ---\n{header}\n\n--- Paper content ---\n{raw_md}" + ) diff --git a/quantmind/llm/__init__.py b/quantmind/llm/__init__.py deleted file mode 100644 index 55ad035..0000000 --- 
a/quantmind/llm/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""LLM module for QuantMind - Basic LLM functionality.""" - -from .block import LLMBlock, create_llm_block -from .embedding import EmbeddingBlock, create_embedding_block - -__all__ = [ - "LLMBlock", - "create_llm_block", - "EmbeddingBlock", - "create_embedding_block", -] diff --git a/quantmind/llm/block.py b/quantmind/llm/block.py deleted file mode 100644 index 2b9b69f..0000000 --- a/quantmind/llm/block.py +++ /dev/null @@ -1,347 +0,0 @@ -"""LLMBlock - A reusable LLM function block using LiteLLM.""" - -import os -import time -from contextlib import contextmanager -from typing import Any, Dict, List, Optional - -from quantmind.utils.logger import get_logger - -from ..config import LLMConfig - -logger = get_logger(__name__) - -try: - import litellm - from litellm import completion - - LITELLM_AVAILABLE = True -except ImportError: - LITELLM_AVAILABLE = False - - -class LLMBlock: - """A reusable LLM function block using LiteLLM. - - LLMBlock provides a consistent interface for LLM operations across - different providers (OpenAI, Anthropic, Google, Azure, etc.). - - Unlike workflows, LLMBlock focuses on providing basic LLM capabilities - without business logic. - """ - - def __init__(self, config: LLMConfig): - """Initialize LLMBlock with configuration. - - Args: - config: LLM configuration - - Raises: - ImportError: If LiteLLM is not available - """ - if not LITELLM_AVAILABLE: - raise ImportError( - "LiteLLM is not available. 
Please install it with: pip install litellm" - ) - - self.config = config - self._setup_litellm() - - logger.info(f"Initialized LLMBlock with model: {config.model}") - - def _setup_litellm(self): - """Setup LiteLLM configuration.""" - # Set global LiteLLM settings - litellm.set_verbose = False # Disable verbose logging by default - - # Configure retries - litellm.num_retries = self.config.retry_attempts - litellm.request_timeout = self.config.timeout - - # Set API key as environment variable if provided - if self.config.api_key: - provider_type = self.config.get_provider_type() - if provider_type == "openai": - os.environ["OPENAI_API_KEY"] = self.config.api_key - elif provider_type == "anthropic": - os.environ["ANTHROPIC_API_KEY"] = self.config.api_key - elif provider_type == "google": - os.environ["GOOGLE_API_KEY"] = self.config.api_key - elif provider_type == "deepseek": - os.environ["DEEPSEEK_API_KEY"] = self.config.api_key - - logger.debug( - f"Configured LiteLLM for provider: {self.config.get_provider_type()}" - ) - - def generate_text( - self, prompt: str, system_prompt: Optional[str] = None, **kwargs - ) -> Optional[str]: - """Generate text using the configured LLM. 
- - Args: - prompt: User prompt - system_prompt: Optional system prompt (overrides config) - **kwargs: Additional parameters to override config - - Returns: - Generated text or None if failed - """ - try: - # Build messages - messages = self._build_messages(prompt, system_prompt) - - # Get LiteLLM parameters - params = self.config.get_litellm_params() - params.update(kwargs) # Allow runtime overrides - - # Add messages to parameters - params["messages"] = messages - - # Call LiteLLM with retry logic - response = self._call_with_retry(params) - - if response and response.choices: - content = response.choices[0].message.content - if content: - return content.strip() - - logger.warning("No content received from LLM") - return None - - except Exception as e: - logger.error(f"Error generating text: {e}") - return None - - def generate_structured_output( - self, - prompt: str, - system_prompt: Optional[str] = None, - response_format: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Optional[Dict[str, Any]]: - """Generate structured output (JSON) using the configured LLM. 
- - Args: - prompt: User prompt - system_prompt: Optional system prompt - response_format: JSON schema for structured output - **kwargs: Additional parameters - - Returns: - Parsed JSON response or None if failed - """ - try: - # Build messages - messages = self._build_messages(prompt, system_prompt) - - # Get LiteLLM parameters - params = self.config.get_litellm_params() - params.update(kwargs) - params["messages"] = messages - - # Add response format if provided - # TODO: Refactor the response_format to be more generic - if response_format: - provider_type = self.config.get_provider_type() - if provider_type == "openai": - params["response_format"] = response_format - elif ( - provider_type == "google" - and "response_schema" in response_format - ): - # Gemini specific format - params["response_format"] = response_format - - # Call LiteLLM - response = self._call_with_retry(params) - - if response and response.choices: - content = response.choices[0].message.content - if content: - # Try to parse JSON - import json - - try: - return json.loads(content.strip()) - except json.JSONDecodeError: - # Fallback: try to extract JSON from text - return self._extract_json_from_text(content) - - return None - - except Exception as e: - logger.error(f"Error generating structured output: {e}") - return None - - def _build_messages( - self, prompt: str, system_prompt: Optional[str] = None - ) -> List[Dict[str, str]]: - """Build messages array for LLM call. 
- - Args: - prompt: User prompt - system_prompt: Optional system prompt - - Returns: - Messages array - """ - messages = [] - - # Add system prompt - final_system_prompt = system_prompt or self.config.system_prompt - if final_system_prompt: - messages.append({"role": "system", "content": final_system_prompt}) - - # Add user prompt with custom instructions - final_prompt = prompt - if self.config.custom_instructions: - final_prompt = f"{prompt}\n\nAdditional Instructions:\n{self.config.custom_instructions}" - - messages.append({"role": "user", "content": final_prompt}) - - return messages - - def _call_with_retry(self, params: Dict[str, Any]) -> Optional[Any]: - """Call LiteLLM with retry logic. - - Args: - params: LiteLLM parameters - - Returns: - LiteLLM response or None - """ - last_exception = None - - for attempt in range(self.config.retry_attempts + 1): - try: - logger.debug( - f"LLM call attempt {attempt + 1}/{self.config.retry_attempts + 1}" - ) - - response = completion(**params) - - # Log usage if available - if hasattr(response, "usage") and response.usage: - logger.debug(f"Token usage: {response.usage}") - - return response - - except Exception as e: - last_exception = e - logger.warning(f"LLM call attempt {attempt + 1} failed: {e}") - - if attempt < self.config.retry_attempts: - time.sleep(self.config.retry_delay) - else: - logger.error( - f"All {self.config.retry_attempts + 1} attempts failed" - ) - - # Log final error - if last_exception: - logger.error(f"Final error: {last_exception}") - - return None - - def _extract_json_from_text(self, text: str) -> Optional[Dict[str, Any]]: - """Extract JSON from text response as fallback. 
- - Args: - text: Response text - - Returns: - Parsed JSON or None - """ - import json - import re - - # Try to find JSON objects in the text - json_patterns = [ - r"\{[^{}]*\}", # Simple JSON object - r"\{.*?\}", # JSON object with nested content - r"\[.*?\]", # JSON array - ] - - for pattern in json_patterns: - matches = re.findall(pattern, text, re.DOTALL) - for match in matches: - try: - return json.loads(match) - except json.JSONDecodeError: - continue - - logger.warning("Could not extract JSON from text response") - return None - - def test_connection(self) -> bool: - """Test if the LLM connection is working. - - Returns: - True if connection is working, False otherwise - """ - try: - response = self.generate_text( - "Hello, this is a test. Please respond with 'OK'." - ) - return response is not None and len(response) > 0 - except Exception as e: - logger.error(f"Connection test failed: {e}") - return False - - def get_info(self) -> Dict[str, Any]: - """Get information about the current LLMBlock. - - Returns: - LLMBlock information dictionary - """ - return { - "model": self.config.model, - "provider": self.config.get_provider_type(), - "temperature": self.config.temperature, - "max_tokens": self.config.max_tokens, - "timeout": self.config.timeout, - "retry_attempts": self.config.retry_attempts, - } - - def update_config(self, **kwargs) -> None: - """Update configuration parameters. - - Args: - **kwargs: Configuration parameters to update - """ - # Create new config with overrides - self.config = self.config.create_variant(**kwargs) - - # Re-setup LiteLLM - self._setup_litellm() - - logger.info("Updated LLMBlock configuration") - - @contextmanager - def temporary_config(self, **kwargs): - """Context manager for temporary configuration changes. 
- - Args: - **kwargs: Temporary configuration overrides - """ - original_config = self.config.model_copy() - - try: - self.update_config(**kwargs) - yield self - finally: - self.config = original_config - self._setup_litellm() - - -def create_llm_block(config: LLMConfig) -> LLMBlock: - """Create a new LLMBlock instance. - - Args: - config: LLM configuration - - Returns: - New LLMBlock instance - """ - return LLMBlock(config) diff --git a/quantmind/llm/embedding.py b/quantmind/llm/embedding.py deleted file mode 100644 index 3b1c85f..0000000 --- a/quantmind/llm/embedding.py +++ /dev/null @@ -1,323 +0,0 @@ -"""EmbeddingBlock - A reusable Embedding function block using LiteLLM.""" - -import os -import time -from contextlib import contextmanager -from typing import Any, Dict, List, Optional - -from quantmind.utils.logger import get_logger - -from ..config import EmbeddingConfig - -logger = get_logger(__name__) - -try: - import litellm - from litellm import embedding - - LITELLM_AVAILABLE = True -except ImportError: - LITELLM_AVAILABLE = False - - -class EmbeddingBlock: - """A reusable Embedding function block using LiteLLM. - - EmbeddingBlock provides a consistent interface for generating embeddings across - different providers (OpenAI, Gemini, etc.). - - Unlike workflows, EmbeddingBlock focuses on providing basic embedding capabilities - without business logic. - """ - - def __init__(self, config: EmbeddingConfig): - """Initialize the EmbeddingBlock with configuration. - - Args: - config: Embedding configuration - - Raises: - ImportError: If LiteLLM is not available. - """ - if not LITELLM_AVAILABLE: - raise ImportError( - "litellm is required for EmbeddingBlock but not installed." 
- ) - - self.config = config - self._setup_litellm() - - logger.info(f"Initialized EmbeddingBlock with model: {config.model}") - - def _setup_litellm(self): - """Setup LiteLLM configuration.""" - # Set global LiteLLM settings - litellm.set_verbose = False # Disable verbose logging by default - - # Configure retries - litellm.num_retries = self.config.retry_attempts - litellm.request_timeout = self.config.timeout - - # Set API key as environment variable if provided - if self.config.api_key: - provider_type = self.config.get_provider_type() - if provider_type == "openai": - os.environ["OPENAI_API_KEY"] = self.config.api_key - elif provider_type == "azure": - os.environ["AZURE_API_KEY"] = self.config.api_key - elif provider_type == "gemini": - os.environ["GEMINI_API_KEY"] = self.config.api_key - - logger.debug( - f"Configured LiteLLM for provider: {self.config.get_provider_type()}" - ) - - def generate_embedding(self, text: str, **kwargs) -> Optional[List[float]]: - """Generate embedding using the configured Embedding model. - - Args: - text (str): The input text to embed. - **kwargs: Additional parameters to override config - - Returns: - List[float]: The embedding vector as a list of floats, or None if failed. - """ - try: - # Get LiteLLM parameters - params = self.config.get_litellm_params() - params.update(kwargs) # Allow runtime overrides - - # Add input text - params["input"] = text - - # Call LiteLLM embedding - response = self._call_with_retry(params) - - if response and hasattr(response, "data"): - # Extract embedding from response - embedding_data = ( - response.data[0] - if isinstance(response.data, list) - else response.data - ) - return embedding_data.embedding - - return None - - except Exception as e: - logger.error(f"Failed to generate embedding: {e}") - return None - - def generate_embeddings( - self, texts: List[str], **kwargs - ) -> Optional[List[List[float]]]: - """Generate embeddings for multiple texts. 
- - Args: - texts (List[str]): List of input texts to embed. - **kwargs: Additional parameters to override config - - Returns: - List[List[float]]: List of embedding vectors, or None if failed. - """ - try: - # Get LiteLLM parameters - params = self.config.get_litellm_params() - params.update(kwargs) # Allow runtime overrides - - # Add input texts - params["input"] = texts - - # Call LiteLLM embedding - response = self._call_with_retry(params) - - if response and hasattr(response, "data"): - # Extract embeddings from response - return [item.embedding for item in response.data] - - return None - - except Exception as e: - logger.error(f"Failed to generate embeddings: {e}") - return None - - def _call_with_retry(self, params: Dict[str, Any]) -> Optional[Any]: - """Call LiteLLM embedding with retry logic. - - Args: - params (Dict[str, Any]): The parameters to pass to the embedding function. - - Returns: - Optional[Any]: The embedding result or None if failed. - """ - last_exception = None - for attempt in range(self.config.retry_attempts + 1): - try: - logger.debug( - f"Embedding call attempt {attempt + 1}/{self.config.retry_attempts + 1}" - ) - - # Create a copy of params to avoid mutation - call_params = params.copy() - - # Extract input from params - input_text = call_params.pop("input") - - # Remove model from params if it exists to avoid duplication - call_params.pop("model", None) - - response = embedding( - model=self.config.model, input=input_text, **call_params - ) - - if hasattr(response, "usage") and response.usage: - logger.debug(f"Token usage: {response.usage}") - return response - except Exception as e: - last_exception = e - logger.warning( - f"Embedding call attempt {attempt + 1} failed: {e}" - ) - - if attempt < self.config.retry_attempts: - time.sleep(self.config.retry_delay) - else: - logger.error( - f"All {self.config.retry_attempts + 1} attempts failed" - ) - - # Log final error - if last_exception: - logger.error(f"Final error: {last_exception}") 
- - return None - - def test_connection(self) -> bool: - """Test if the embedding connection is working. - - Returns: - True if connection is working, False otherwise - """ - try: - response = self.generate_embedding("test") - return response is not None and len(response) > 0 - except Exception as e: - logger.error(f"Connection test failed: {e}") - return False - - def get_info(self) -> Dict[str, Any]: - """Get information about the embedding block. - - Returns: - Dictionary with embedding block information - """ - info = { - "model": self.config.model, - "provider": self.config.get_provider_type(), - "timeout": self.config.timeout, - "retry_attempts": self.config.retry_attempts, - } - return info - - def get_embedding_dimension(self) -> Optional[int]: - """Get the dimension of embeddings generated by this model. - - Returns: - Embedding dimension or None if not available - """ - # First check if dimensions is specified in config - if self.config.dimensions: - return self.config.dimensions - - try: - # Try to get dimension by generating a test embedding - test_embedding = self.generate_embedding("test") - return len(test_embedding) if test_embedding else None - except Exception as e: - logger.error(f"Failed to get embedding dimension: {e}") - return None - - def update_config(self, **kwargs) -> None: - """Update the embedding configuration. - - Args: - **kwargs: Configuration parameters to update - """ - for key, value in kwargs.items(): - if hasattr(self.config, key): - setattr(self.config, key, value) - - logger.info(f"Updated embedding configuration: {kwargs}") - - @contextmanager - def temporary_config(self, **kwargs): - """Temporarily modify configuration for a context. 
- - Args: - **kwargs: Temporary configuration parameters - - Yields: - Self with temporary configuration - """ - original_config = {} - for key, value in kwargs.items(): - if hasattr(self.config, key): - original_config[key] = getattr(self.config, key) - setattr(self.config, key, value) - - try: - yield self - finally: - # Restore original configuration - for key, value in original_config.items(): - setattr(self.config, key, value) - - def batch_embed( - self, texts: List[str], batch_size: int = 32, **kwargs - ) -> Optional[List[List[float]]]: - """Generate embeddings in batches for large datasets. - - Args: - texts: List of texts to embed - batch_size: Number of texts to process in each batch - **kwargs: Additional parameters for embedding generation - - Returns: - List of embedding vectors or None if failed - """ - try: - all_embeddings = [] - - for i in range(0, len(texts), batch_size): - batch = texts[i : i + batch_size] - batch_embeddings = self.generate_embeddings(batch, **kwargs) - - if batch_embeddings is None: - logger.error( - f"Failed to generate embeddings for batch {i // batch_size}" - ) - return None - - all_embeddings.extend(batch_embeddings) - - # Add delay between batches if specified - if self.config.retry_delay > 0 and i + batch_size < len(texts): - time.sleep(self.config.retry_delay) - - return all_embeddings - - except Exception as e: - logger.error(f"Batch embedding failed: {e}") - return None - - -def create_embedding_block(config: EmbeddingConfig) -> EmbeddingBlock: - """Create an EmbeddingBlock instance. - - Args: - config: Embedding configuration - - Returns: - Configured EmbeddingBlock instance - """ - return EmbeddingBlock(config) diff --git a/quantmind/magic.py b/quantmind/magic.py new file mode 100644 index 0000000..71e522a --- /dev/null +++ b/quantmind/magic.py @@ -0,0 +1,201 @@ +"""Natural-language → ``(input, cfg)`` resolver. 
+ +``resolve_magic_input`` introspects a flow function's ``input`` and +``cfg`` parameter annotations, builds a parameterized +``ResolvedFlowConfig[InputT, CfgT]``, and runs a lightweight resolver +agent to populate it. The resolver instructions are templated with the +JSON-schema rendering of both types so the model sees exactly which +fields are valid. + +This module sits at the top level (not under ``flows/``) because its +output — a ``(input_obj, cfg_obj)`` tuple — is flow-agnostic. The same +resolver works for any future flow that follows the +``(input, *, cfg, ...)`` signature convention. +""" + +import inspect +import json +import types +from collections.abc import Awaitable, Callable +from typing import Any, Generic, TypeVar, Union, get_args, get_origin + +from agents import Agent, Runner +from pydantic import BaseModel + +from quantmind.configs.base import BaseFlowCfg + +InputT = TypeVar("InputT", bound=BaseModel) +CfgT = TypeVar("CfgT", bound=BaseModel) + + +class ResolvedFlowConfig(BaseModel, Generic[InputT, CfgT]): + """Output schema returned by the resolver agent.""" + + input_obj: InputT + cfg_obj: CfgT + + +_RESOLVER_INSTRUCTIONS = """\ +You are a configuration resolver for the QuantMind {flow_name} flow. +Given a natural-language description of intent, produce a +``ResolvedFlowConfig`` with two fields: + +- ``input_obj`` — one variant of the input discriminated union. +- ``cfg_obj`` — the flow configuration. + +Rules: +- Set fields conservatively. Leave unspecified fields at their defaults + rather than inventing values. +- The ``input_obj.type`` discriminator decides which variant you produce. +- Never invent file paths or URLs. If the description does not give a + concrete identifier, prefer the ``RawText`` variant (when available) + with the description's content. 
+ +Input schema: +{input_schema} + +Cfg schema: +{cfg_schema} +""" + + +async def resolve_magic_input( + natural_language: str, + *, + target_flow: Callable[..., Awaitable[Any]], + resolver_model: str = "gpt-4o-mini", + resolver_instructions: str | None = None, +) -> tuple[Any, Any]: + """Parse ``natural_language`` into ``(input_obj, cfg_obj)`` for ``target_flow``. + + Args: + natural_language: User-supplied free-form description of intent. + target_flow: The flow function to resolve for. Must accept + ``input`` (positional) and ``cfg`` (keyword) parameters. + resolver_model: LLM used by the resolver agent. + resolver_instructions: Optional override for the resolver's + system prompt template. Receives ``flow_name``, + ``input_schema``, and ``cfg_schema`` via ``str.format``. + + Returns: + Tuple of ``(input_obj, cfg_obj)`` populated by the resolver. + """ + input_type, cfg_type = _introspect_flow_signature(target_flow) + template = resolver_instructions or _RESOLVER_INSTRUCTIONS + instructions = template.format( + flow_name=target_flow.__name__, + input_schema=_pydantic_schema_str(input_type), + cfg_schema=_pydantic_schema_str(cfg_type), + ) + resolver: Agent[Any] = Agent( + name=f"magic_resolver_{target_flow.__name__}", + instructions=instructions, + model=resolver_model, + output_type=ResolvedFlowConfig[input_type, cfg_type], # type: ignore[valid-type] + ) + result = await Runner.run(resolver, natural_language) + out = result.final_output + return out.input_obj, out.cfg_obj + + +async def preview_resolve( + natural_language: str, + *, + target_flow: Callable[..., Awaitable[Any]], + resolver_model: str = "gpt-4o-mini", +) -> tuple[Any, Any]: + """Resolve and pretty-print the result without invoking the flow.""" + inp, cfg = await resolve_magic_input( + natural_language, + target_flow=target_flow, + resolver_model=resolver_model, + ) + print("input_obj:", inp.model_dump_json(indent=2)) + print("cfg_obj:", cfg.model_dump_json(indent=2)) + return inp, cfg + + +def 
_introspect_flow_signature( + flow_fn: Callable[..., Any], +) -> tuple[Any, type[BaseFlowCfg]]: + """Return ``(input_annotation, cfg_type)`` for a flow function. + + ``input_annotation`` is returned as-is — it may be a discriminated- + union alias such as ``Annotated[Union[...], Field(discriminator=...)]``. + Pydantic accepts both plain ``BaseModel`` subclasses and discriminated + aliases as generic parameters. + + ``cfg_type`` strips an outer ``T | None`` so the resolver instantiates + the concrete cfg subclass. The result must be a ``BaseFlowCfg`` + subclass; anything else means the flow's signature is misshapen. + """ + sig = inspect.signature(flow_fn) + if "input" not in sig.parameters: + raise TypeError( + f"Flow {flow_fn.__name__!r} must accept an `input` parameter" + ) + if "cfg" not in sig.parameters: + raise TypeError( + f"Flow {flow_fn.__name__!r} must accept a `cfg` keyword parameter" + ) + input_anno = sig.parameters["input"].annotation + cfg_anno = sig.parameters["cfg"].annotation + cfg_type = _strip_optional(cfg_anno) + if not (isinstance(cfg_type, type) and issubclass(cfg_type, BaseFlowCfg)): + raise TypeError( + f"Flow {flow_fn.__name__!r} `cfg` annotation must resolve to " + f"a BaseFlowCfg subclass (got {cfg_anno!r})" + ) + return input_anno, cfg_type + + +def _strip_optional(anno: Any) -> Any: + """Peel ``T | None`` / ``Optional[T]`` to return the inner T.""" + origin = get_origin(anno) + if origin in (Union, types.UnionType): + non_none = [a for a in get_args(anno) if a is not type(None)] + if len(non_none) == 1: + return non_none[0] + return anno + + +def _pydantic_schema_str(t: Any) -> str: + """Render a JSON-schema-ish description for resolver instructions. + + Cases handled: + + 1. ``Annotated[X, ...]`` — peel via ``__metadata__`` and recurse on X. + 2. Plain ``BaseModel`` subclass — use ``model_json_schema()``. + 3. ``Union[...]`` / ``T | U`` — recurse on each variant; emit + ``{"oneOf": [...]}``. + 4. 
Anything else — fall back to ``repr`` so the resolver still gets + *some* hint. Should not happen for the supported flows. + """ + if hasattr(t, "__metadata__"): + inner = get_args(t)[0] + return _pydantic_schema_str(inner) + + if isinstance(t, type) and hasattr(t, "model_json_schema"): + try: + return json.dumps(t.model_json_schema(), indent=2) + except Exception: + # Some Pydantic models (e.g. those holding callable fields + # like ``ModelSettings`` from the agents SDK) cannot render + # a full JSON schema. Fall back to a name+fields summary + # so the resolver still has something to work with. + fields = { + name: repr(field.annotation) + for name, field in t.model_fields.items() + } + return json.dumps({"title": t.__name__, "fields": fields}, indent=2) + + origin = get_origin(t) + if origin in (Union, types.UnionType): + variants = get_args(t) + schemas = [ + json.loads(_pydantic_schema_str(v)) + for v in variants + if isinstance(v, type) and hasattr(v, "model_json_schema") + ] + return json.dumps({"oneOf": schemas}, indent=2) + return repr(t) diff --git a/quantmind/models/__init__.py b/quantmind/models/__init__.py deleted file mode 100644 index 94151f5..0000000 --- a/quantmind/models/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Data models for QuantMind knowledge representation.""" - -from .content import BaseContent, KnowledgeItem -from .paper import Paper - -__all__ = ["Paper", "BaseContent", "KnowledgeItem"] diff --git a/quantmind/models/analysis.py b/quantmind/models/analysis.py deleted file mode 100644 index eebcb53..0000000 --- a/quantmind/models/analysis.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Analysis and research-specific data models for QuantMind.""" - -from datetime import datetime, timezone -from typing import List, Optional -from uuid import uuid4 - -from pydantic import BaseModel, Field - - -class QuestionAnswer(BaseModel): - """Question and answer pair.""" - - question: str = Field(..., description="Generated question") - answer: str = Field(..., 
description="Generated answer") - difficulty: str = Field(default="medium", description="Question difficulty") - difficulty_level: str = Field( - default="medium", - description="Question difficulty level (backward compatibility)", - ) - category: str = Field(default="general", description="Question category") - confidence: float = Field( - ge=0.0, le=1.0, default=0.8, description="Answer confidence" - ) - generated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) - - -class PaperAnalysis(BaseModel): - """Comprehensive paper analysis results.""" - - paper_id: str - analysis_id: str = Field(default_factory=lambda: str(uuid4())) - # Tags - primary_tags: List[str] = Field(default_factory=list) - secondary_tags: List[str] = Field(default_factory=list) - # Q&A - questions_answers: List[QuestionAnswer] = Field(default_factory=list) - # Summary - key_insights: List[str] = Field(default_factory=list) - methodology_summary: Optional[str] = None - results_summary: Optional[str] = None - # Metadata - analysis_version: str = "1.0" - analysis_timestamp: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) - analysis_duration: Optional[float] = None # seconds - - class Config: - """Pydantic model configuration.""" - - json_encoders = {datetime: lambda v: v.isoformat()} - - -class AnalysisConfig(BaseModel): - """Configuration for paper analysis.""" - - # Tag analysis - enable_tag_analysis: bool = True - tag_confidence_threshold: float = 0.7 - max_primary_tags: int = 5 - max_secondary_tags: int = 10 - # Q&A generation - enable_qa_generation: bool = True - num_questions: int = 5 - include_different_difficulties: bool = True - focus_on_insights: bool = True - # Visual extraction - enable_visual_extraction: bool = True - extract_framework_only: bool = False - min_importance_score: float = 0.6 - # LLM settings - llm_model: str = "gpt-4o" - max_tokens: int = 4000 - temperature: float = 0.3 - # Processing - parallel_processing: bool = 
True - cache_results: bool = True diff --git a/quantmind/models/content.py b/quantmind/models/content.py deleted file mode 100644 index 0b0ac20..0000000 --- a/quantmind/models/content.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Generic content model for QuantMind knowledge representation.""" - -from abc import ABC, abstractmethod -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional -from uuid import uuid4 - -from pydantic import BaseModel, Field - - -class BaseContent(BaseModel, ABC): - """Abstract base class for all content types in QuantMind. - - This serves as the foundation for different knowledge entities - like papers, articles, reports, etc. - """ - - # Core identifiers - id: str = Field(default_factory=lambda: str(uuid4())) - source_id: Optional[str] = None # ID from original source - - # Core content - title: str = Field(..., min_length=1) - abstract: Optional[str] = None - content: Optional[str] = None # Full content text - - # Metadata - authors: List[str] = Field(default_factory=list) - published_date: Optional[datetime] = None - categories: List[str] = Field(default_factory=list) - tags: List[str] = Field(default_factory=list) - - # Source information - url: Optional[str] = None - source: Optional[str] = None # e.g., "arxiv", "pubmed", "news" - extraction_method: Optional[str] = None # e.g., "api", "scraping" - processed_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) - - # Flexible metadata storage - meta_info: Dict[str, Any] = Field(default_factory=dict) - - class Config: - """Pydantic configuration.""" - - json_encoders = {datetime: lambda v: v.isoformat()} - - @abstractmethod - def get_text_for_embedding(self) -> str: - """Get text content for embedding generation. - - Returns: - String representation of content for vectorization - """ - pass - - def get_primary_id(self) -> str: - """Get the primary identifier for the content. 
- - Returns: - Source ID if available, otherwise internal ID - """ - return self.source_id or self.id - - def add_tag(self, tag: str) -> None: - """Add a tag if not already present.""" - if tag not in self.tags: - self.tags.append(tag) - - def add_category(self, category: str) -> None: - """Add a category if not already present.""" - if category not in self.categories: - self.categories.append(category) - - -class KnowledgeItem(BaseContent): - """Generic knowledge item implementation. - - Can represent various types of content like papers, articles, - reports, news items, etc. - """ - - # Content-specific fields - content_type: str = Field(default="generic") # paper, article, report, etc. - language: Optional[str] = None - - # Additional URLs - pdf_url: Optional[str] = None - code_url: Optional[str] = None - - # Vector representation - embedding: Optional[List[float]] = None - embedding_model: Optional[str] = None - - def get_text_for_embedding(self) -> str: - """Get concatenated text for embedding generation.""" - parts = [] - if self.title: - parts.append(self.title) - if self.abstract: - parts.append(self.abstract) - if self.content: - parts.append(self.content[:1000]) # Limit content length - - return "\n\n".join(parts) - - def has_full_content(self) -> bool: - """Check if item has full content.""" - return bool(self.content and len(self.content.strip()) > 0) - - def set_embedding(self, embedding: List[float], model: str = None) -> None: - """Set the content's embedding vector.""" - self.embedding = embedding - if model: - self.embedding_model = model diff --git a/quantmind/models/paper.py b/quantmind/models/paper.py deleted file mode 100644 index 774bf00..0000000 --- a/quantmind/models/paper.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Paper model for QuantMind knowledge representation.""" - -import json -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -from pydantic import Field, field_validator - -from 
quantmind.models.content import KnowledgeItem - - -class Paper(KnowledgeItem): - """Research paper entity with structured metadata and validation. - - Core knowledge unit in the QuantMind system, representing a research paper - with comprehensive metadata, content, and processing information. - """ - - # Paper-specific identifiers - paper_id: Optional[str] = None - arxiv_id: Optional[str] = None - doi: Optional[str] = None - - # Override content type - content_type: str = Field(default="paper") - - # Paper-specific content (inherits title, abstract, content from parent) - # Additional paper URLs (inherit url, pdf_url from parent) - code_url: Optional[str] = None - - @field_validator("categories", "tags", mode="before") - def ensure_list(cls, v): - """Ensure categories and tags are always lists.""" - if isinstance(v, str): - return [v] - return v or [] - - @field_validator("authors", mode="before") - def parse_authors(cls, v): - """Parse authors from various formats.""" - if isinstance(v, str): - # Handle comma-separated authors - return [author.strip() for author in v.split(",")] - return v or [] - - def get_text_for_embedding(self) -> str: - """Get concatenated text for embedding generation. - - Returns: - Combined title and abstract text - """ - return f"{self.title}\n\n{self.abstract}" - - def add_tag(self, tag: str) -> None: - """Add a tag if not already present. - - Args: - tag: Tag to add - """ - if tag not in self.tags: - self.tags.append(tag) - - def add_category(self, category: str) -> None: - """Add a category if not already present. - - Args: - category: Category to add - """ - if category not in self.categories: - self.categories.append(category) - - def set_embedding(self, embedding: List[float], model: str = None) -> None: - """Set the paper's embedding vector. 
- - Args: - embedding: Vector representation - model: Name of the embedding model used - """ - self.embedding = embedding - if model: - self.embedding_model = model - - def has_content(self) -> bool: - """Check if paper has parsed content available. - - Returns: - True if content is available - """ - return bool(self.content and len(self.content.strip()) > 0) - - def get_primary_id(self) -> str: - """Get the primary identifier for the paper. - - Returns: - ArXiv ID if available, otherwise paper_id, otherwise id - """ - return self.arxiv_id or self.paper_id or self.id - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Paper": - """Create Paper from dictionary with flexible field mapping. - - Args: - data: Dictionary with paper data - - Returns: - Paper instance - """ - # Handle datetime fields - if "published_date" in data and isinstance(data["published_date"], str): - try: - data["published_date"] = datetime.fromisoformat( - data["published_date"] - ) - except ValueError: - data["published_date"] = None - - if "processed_at" in data and isinstance(data["processed_at"], str): - try: - data["processed_at"] = datetime.fromisoformat( - data["processed_at"] - ) - except ValueError: - data["processed_at"] = datetime.utcnow() - - return cls(**data) - - @classmethod - def load_from_file(cls, file_path: Union[str, Path]) -> "Paper": - """Load paper from JSON file. - - Args: - file_path: Path to JSON file - - Returns: - Paper instance - """ - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - return cls.from_dict(data) - - @classmethod - def load_from_files( - cls, file_paths: List[Union[str, Path]] - ) -> List["Paper"]: - """Load multiple papers from JSON files. - - Args: - file_paths: List of file paths - - Returns: - List of Paper instances - """ - return [cls.load_from_file(path) for path in file_paths] - - def save_to_file(self, file_path: Union[str, Path]) -> None: - """Save paper to JSON file. 
- - Args: - file_path: Output file path - """ - with open(file_path, "w", encoding="utf-8") as f: - json.dump( - self.model_dump(), f, ensure_ascii=False, indent=2, default=str - ) - - def __str__(self) -> str: - """String representation.""" - return f"Paper({self.get_primary_id()}): {self.title[:50]}..." - - def __repr__(self) -> str: - """Detailed representation.""" - return f"Paper(id='{self.id}', title='{self.title[:30]}...', source='{self.source}')" diff --git a/tests/config/test_embedding.py b/tests/config/test_embedding.py deleted file mode 100644 index 6c2fd14..0000000 --- a/tests/config/test_embedding.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Tests for Embedding configuration.""" - -import unittest - -from quantmind.config.embedding import EmbeddingConfig - - -class TestEmbeddingConfig(unittest.TestCase): - """Test cases for EmbeddingConfig.""" - - def test_default_config(self): - """Test default configuration values.""" - config = EmbeddingConfig() - - # Test default values - self.assertEqual(config.model, "text-embedding-ada-002") - self.assertIsNone(config.user) - self.assertIsNone(config.dimensions) - self.assertEqual(config.encoding_format, "float") - self.assertEqual(config.timeout, 600) - self.assertIsNone(config.api_base) - self.assertIsNone(config.api_version) - self.assertIsNone(config.api_key) - self.assertIsNone(config.api_type) - - def test_custom_config(self): - """Test custom configuration values.""" - config = EmbeddingConfig( - model="text-embedding-3-small", - user="test_user_123", - dimensions=512, - encoding_format="base64", - timeout=1, - api_key="test-key", - api_base="https://api.example.com", - api_version="2023-05-15", - api_type="azure", - ) - - self.assertEqual(config.model, "text-embedding-3-small") - self.assertEqual(config.user, "test_user_123") - self.assertEqual(config.dimensions, 512) - self.assertEqual(config.encoding_format, "base64") - self.assertEqual(config.timeout, 1) - self.assertEqual(config.api_key, "test-key") - 
self.assertEqual(config.api_base, "https://api.example.com") - self.assertEqual(config.api_version, "2023-05-15") - self.assertEqual(config.api_type, "azure") - - def test_validation_model(self): - """Test model validation.""" - # Valid model - config = EmbeddingConfig(model="text-embedding-ada-002") - self.assertEqual(config.model, "text-embedding-ada-002") - - # Empty model should raise error - with self.assertRaises(ValueError): - EmbeddingConfig(model="") - - # None model should raise error - with self.assertRaises(ValueError): - EmbeddingConfig(model=None) - - # Whitespace should be stripped - config = EmbeddingConfig(model=" text-embedding-ada-002 ") - self.assertEqual(config.model, "text-embedding-ada-002") - - def test_validation_api_key(self): - """Test API key validation.""" - # Valid API key - config = EmbeddingConfig(api_key="test-key") - self.assertEqual(config.api_key, "test-key") - - # None API key is valid - config = EmbeddingConfig(api_key=None) - self.assertIsNone(config.api_key) - - # Invalid API key type should raise error - with self.assertRaises(ValueError): - EmbeddingConfig(api_key=123) - - with self.assertRaises(ValueError): - EmbeddingConfig(api_key=[]) - - def test_get_provider_type(self): - """Test provider type detection.""" - # OpenAI models - config = EmbeddingConfig(model="text-embedding-ada-002") - self.assertEqual(config.get_provider_type(), "openai") - - config = EmbeddingConfig(model="text-embedding-3-small") - self.assertEqual(config.get_provider_type(), "openai") - - config = EmbeddingConfig(model="text-embedding-3-large") - self.assertEqual(config.get_provider_type(), "openai") - - # Azure models - config = EmbeddingConfig(model="azure/text-embedding-ada-002") - self.assertEqual(config.get_provider_type(), "azure") - - config = EmbeddingConfig(model="text-embedding-ada-002-azure") - self.assertEqual(config.get_provider_type(), "azure") - - # Gemini models - config = EmbeddingConfig(model="gemini/embed-multilingual-v3.0") - 
self.assertEqual(config.get_provider_type(), "gemini") - - # Unknown models - config = EmbeddingConfig(model="unknown-model") - self.assertEqual(config.get_provider_type(), "unknown") - - def test_get_litellm_params_minimal(self): - """Test get_litellm_params with minimal configuration.""" - config = EmbeddingConfig(model="text-embedding-ada-002") - params = config.get_litellm_params() - - self.assertEqual(params["model"], "text-embedding-ada-002") - self.assertIn("encoding_format", params) - self.assertEqual(len(params), 2) # Only model and encoding_format - - def test_get_litellm_params_full(self): - """Test get_litellm_params with full configuration.""" - config = EmbeddingConfig( - model="text-embedding-3-small", - user="test_user", - dimensions=512, - encoding_format="base64", - timeout=1, - api_key="test-key", - api_base="https://api.example.com", - api_version="2023-05-15", - api_type="azure", - ) - params = config.get_litellm_params() - - expected_params = { - "model": "text-embedding-3-small", - "user": "test_user", - "dimensions": 512, - "encoding_format": "base64", - "api_base": "https://api.example.com", - "api_version": "2023-05-15", - "api_key": "test-key", - "api_type": "azure", - } - - self.assertEqual(params, expected_params) - - def test_get_litellm_params_partial(self): - """Test get_litellm_params with partial configuration.""" - config = EmbeddingConfig( - model="text-embedding-ada-002", - user="test_user", - dimensions=1536, - api_key="test-key", - ) - params = config.get_litellm_params() - - expected_params = { - "model": "text-embedding-ada-002", - "user": "test_user", - "dimensions": 1536, - "encoding_format": "float", - "api_key": "test-key", - } - - self.assertEqual(params, expected_params) - - def test_create_variant(self): - """Test creating configuration variants.""" - base_config = EmbeddingConfig( - model="text-embedding-ada-002", - timeout=1, - api_key="base-key", - ) - - # Create variant with overrides - variant = 
base_config.create_variant( - timeout=1, - api_key="variant-key", - user="test_user", - ) - - # Original config should be unchanged - self.assertEqual(base_config.timeout, 1) - self.assertEqual(base_config.api_key, "base-key") - self.assertIsNone(base_config.user) - - # Variant should have new values - self.assertEqual(variant.timeout, 1) - self.assertEqual(variant.api_key, "variant-key") - self.assertEqual(variant.user, "test_user") - self.assertEqual(variant.model, "text-embedding-ada-002") # Unchanged - - def test_create_variant_empty(self): - """Test creating variant with no overrides.""" - base_config = EmbeddingConfig( - model="text-embedding-ada-002", - timeout=1, - ) - - variant = base_config.create_variant() - - # Should be identical to base config - self.assertEqual(variant.model, base_config.model) - self.assertEqual(variant.timeout, base_config.timeout) - self.assertEqual(variant.encoding_format, base_config.encoding_format) - - def test_encoding_format_validation(self): - """Test encoding format validation.""" - # Valid encoding formats - config = EmbeddingConfig(encoding_format="float") - self.assertEqual(config.encoding_format, "float") - - config = EmbeddingConfig(encoding_format="base64") - self.assertEqual(config.encoding_format, "base64") - - def test_dimensions_validation(self): - """Test dimensions validation.""" - # Valid dimensions - config = EmbeddingConfig(dimensions=512) - self.assertEqual(config.dimensions, 512) - - config = EmbeddingConfig(dimensions=1536) - self.assertEqual(config.dimensions, 1536) - - config = EmbeddingConfig(dimensions=3072) - self.assertEqual(config.dimensions, 3072) - - # None is valid - config = EmbeddingConfig(dimensions=None) - self.assertIsNone(config.dimensions) - - # Zero and negative dimensions should be allowed (validation handled by API) - config = EmbeddingConfig(dimensions=0) - self.assertEqual(config.dimensions, 0) - - config = EmbeddingConfig(dimensions=-1) - self.assertEqual(config.dimensions, -1) - - 
def test_timeout_validation(self): - """Test timeout validation.""" - # Valid timeouts - config = EmbeddingConfig(timeout=1) - self.assertEqual(config.timeout, 1) - - config = EmbeddingConfig(timeout=1) - self.assertEqual(config.timeout, 1) - - config = EmbeddingConfig(timeout=1) - self.assertEqual(config.timeout, 1) - - # Zero and negative timeouts should be allowed (validation handled by API) - config = EmbeddingConfig(timeout=0) - self.assertEqual(config.timeout, 0) - - config = EmbeddingConfig(timeout=-1) - self.assertEqual(config.timeout, -1) - - def test_equality(self): - """Test config equality.""" - config1 = EmbeddingConfig( - model="text-embedding-ada-002", - user="test_user", - dimensions=512, - ) - - config2 = EmbeddingConfig( - model="text-embedding-ada-002", - user="test_user", - dimensions=512, - ) - - config3 = EmbeddingConfig( - model="text-embedding-3-small", - user="test_user", - dimensions=512, - ) - - self.assertEqual(config1, config2) - self.assertNotEqual(config1, config3) - - def test_repr(self): - """Test config string representation.""" - config = EmbeddingConfig( - model="text-embedding-ada-002", - user="test_user", - dimensions=512, - ) - - repr_str = repr(config) - self.assertIn("text-embedding-ada-002", repr_str) - self.assertIn("test_user", repr_str) - self.assertIn("512", repr_str) - - def test_str(self): - """Test config string representation.""" - config = EmbeddingConfig( - model="text-embedding-ada-002", - user="test_user", - dimensions=512, - ) - - str_repr = str(config) - self.assertIn("text-embedding-ada-002", str_repr) - self.assertIn("test_user", str_repr) - self.assertIn("512", str_repr) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/config/test_flow.py b/tests/config/test_flow.py deleted file mode 100644 index 9aeb054..0000000 --- a/tests/config/test_flow.py +++ /dev/null @@ -1,205 +0,0 @@ -"""Tests for flow configuration.""" - -import unittest - -from quantmind.config.flows import ( - BaseFlowConfig, 
- ChunkingStrategy, - SummaryFlowConfig, -) -from quantmind.config.llm import LLMConfig - - -class TestBaseFlowConfig(unittest.TestCase): - """Test cases for BaseFlowConfig.""" - - def test_init_basic(self): - """Test basic initialization.""" - config = BaseFlowConfig(name="test_flow") - - self.assertEqual(config.name, "test_flow") - self.assertEqual(config.llm_blocks, {}) - self.assertEqual(config.prompt_templates, {}) - self.assertIsNone(config.prompt_templates_path) - - def test_direct_llm_block_assignment(self): - """Test direct assignment of LLM blocks.""" - config = BaseFlowConfig(name="test_flow") - llm_config = LLMConfig(model="gpt-4o", temperature=0.5) - - config.llm_blocks["test_llm"] = llm_config - - self.assertIn("test_llm", config.llm_blocks) - self.assertEqual(config.llm_blocks["test_llm"], llm_config) - - def test_direct_prompt_template_assignment(self): - """Test direct assignment of prompt templates.""" - config = BaseFlowConfig(name="test_flow") - template = "Hello {{ name }}, how are you?" 
- - config.prompt_templates["greeting"] = template - - self.assertIn("greeting", config.prompt_templates) - self.assertEqual(config.prompt_templates["greeting"], template) - - def test_config_initialization_with_resources(self): - """Test initialization with resources.""" - llm_config = LLMConfig(model="gpt-4o") - template = "Test template" - - config = BaseFlowConfig( - name="test_flow", - llm_blocks={"test_llm": llm_config}, - prompt_templates={"test": template}, - ) - - self.assertEqual(config.llm_blocks["test_llm"], llm_config) - self.assertEqual(config.prompt_templates["test"], template) - - def test_empty_config(self): - """Test accessing non-existent items raises KeyError.""" - config = BaseFlowConfig(name="test_flow") - - with self.assertRaises(KeyError): - _ = config.llm_blocks["nonexistent"] - - with self.assertRaises(KeyError): - _ = config.prompt_templates["nonexistent"] - - -class TestSummaryFlowConfig(unittest.TestCase): - """Test cases for SummaryFlowConfig.""" - - def test_init_with_defaults(self): - """Test initialization with default values.""" - config = SummaryFlowConfig(name="summary_flow") - - self.assertEqual(config.name, "summary_flow") - self.assertEqual(config.chunk_size, 2000) - self.assertEqual(config.use_chunking, True) - self.assertEqual(config.chunk_strategy, ChunkingStrategy.BY_SIZE) - self.assertIsNone(config.chunk_custom_strategy) - - # Check default LLM blocks are created - self.assertIn("cheap_summarizer", config.llm_blocks) - self.assertIn("powerful_combiner", config.llm_blocks) - - # Check default templates are created - self.assertIn("summarize_chunk_template", config.prompt_templates) - self.assertIn("combine_summaries_template", config.prompt_templates) - - def test_init_with_custom_chunk_size(self): - """Test initialization with custom chunk size.""" - config = SummaryFlowConfig(name="summary_flow", chunk_size=1000) - - self.assertEqual(config.chunk_size, 1000) - - def test_default_llm_blocks_configuration(self): - """Test 
default LLM block configurations.""" - config = SummaryFlowConfig(name="summary_flow") - - cheap_config = config.llm_blocks["cheap_summarizer"] - self.assertEqual(cheap_config.model, "gpt-4o-mini") - self.assertEqual(cheap_config.temperature, 0.3) - self.assertEqual(cheap_config.max_tokens, 1000) - - powerful_config = config.llm_blocks["powerful_combiner"] - self.assertEqual(powerful_config.model, "gpt-4o") - self.assertEqual(powerful_config.temperature, 0.3) - self.assertEqual(powerful_config.max_tokens, 2000) - - def test_default_prompt_templates(self): - """Test default prompt templates are properly set.""" - config = SummaryFlowConfig(name="summary_flow") - - chunk_template = config.prompt_templates["summarize_chunk_template"] - self.assertIn("chunk_text", chunk_template) - self.assertIn("financial research expert", chunk_template.lower()) - - combine_template = config.prompt_templates["combine_summaries_template"] - self.assertIn("summaries", combine_template) - self.assertIn("coherent", combine_template.lower()) - - def test_custom_llm_blocks_preserved(self): - """Test that custom LLM blocks are preserved.""" - custom_llm_blocks = { - "custom_llm": LLMConfig(model="custom-model", temperature=0.7) - } - - config = SummaryFlowConfig( - name="summary_flow", llm_blocks=custom_llm_blocks - ) - - # Custom blocks should be preserved, defaults not added - self.assertEqual(len(config.llm_blocks), 1) - self.assertIn("custom_llm", config.llm_blocks) - self.assertNotIn("cheap_summarizer", config.llm_blocks) - - def test_custom_templates_preserved(self): - """Test that custom templates are preserved.""" - custom_templates = {"custom_template": "Custom template content"} - - config = SummaryFlowConfig( - name="summary_flow", prompt_templates=custom_templates - ) - - # Custom templates should be preserved, defaults not added - self.assertEqual(len(config.prompt_templates), 1) - self.assertIn("custom_template", config.prompt_templates) - 
self.assertNotIn("summarize_chunk_template", config.prompt_templates) - - def test_mixed_custom_and_default_initialization(self): - """Test initialization with some custom configs.""" - custom_llm_blocks = {"custom_llm": LLMConfig(model="custom-model")} - - config = SummaryFlowConfig( - name="summary_flow", llm_blocks=custom_llm_blocks, chunk_size=1500 - ) - - # Should have custom LLM blocks, not defaults - self.assertEqual(len(config.llm_blocks), 1) - self.assertIn("custom_llm", config.llm_blocks) - - # Should have default templates since none provided - self.assertIn("summarize_chunk_template", config.prompt_templates) - self.assertIn("combine_summaries_template", config.prompt_templates) - - # Custom chunk size should be preserved - self.assertEqual(config.chunk_size, 1500) - - def test_chunking_configuration_options(self): - """Test various chunking configuration options.""" - # Test disabling chunking - config = SummaryFlowConfig(name="summary_flow", use_chunking=False) - self.assertEqual(config.use_chunking, False) - - # Test custom chunk strategy - def custom_chunker(text): - return text.split("\n\n") - - config = SummaryFlowConfig( - name="summary_flow", - chunk_strategy=ChunkingStrategy.BY_CUSTOM, - chunk_custom_strategy=custom_chunker, - ) - self.assertEqual(config.chunk_strategy, ChunkingStrategy.BY_CUSTOM) - self.assertEqual(config.chunk_custom_strategy, custom_chunker) - - def test_unsupported_chunk_strategy_raises_error(self): - """Test that unsupported chunk strategies raise NotImplementedError.""" - with self.assertRaises(NotImplementedError) as context: - SummaryFlowConfig( - name="summary_flow", chunk_strategy=ChunkingStrategy.BY_SECTION - ) - - self.assertIn("not implemented", str(context.exception)) - - def test_chunking_strategy_enum_values(self): - """Test ChunkingStrategy enum values.""" - self.assertEqual(ChunkingStrategy.BY_SIZE.value, "by_size") - self.assertEqual(ChunkingStrategy.BY_SECTION.value, "by_section") - 
self.assertEqual(ChunkingStrategy.BY_CUSTOM.value, "by_custom") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/config/test_llm.py b/tests/config/test_llm.py deleted file mode 100644 index a5b5d9c..0000000 --- a/tests/config/test_llm.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Tests for LLM configuration.""" - -import unittest - -from quantmind.config.llm import LLMConfig - - -class TestLLMConfig(unittest.TestCase): - """Test cases for LLMConfig.""" - - def test_default_config(self): - """Test default configuration values.""" - config = LLMConfig() - - # Test default values - self.assertEqual(config.model, "gpt-4o") - self.assertEqual(config.temperature, 0.0) - self.assertEqual(config.max_tokens, 4000) - self.assertEqual(config.top_p, 1.0) - self.assertEqual(config.timeout, 60) - self.assertEqual(config.retry_attempts, 3) - self.assertEqual(config.retry_delay, 1.0) - self.assertIsNone(config.api_key) - self.assertIsNone(config.base_url) - self.assertIsNone(config.system_prompt) - self.assertEqual(config.extra_params, {}) - - def test_custom_config(self): - """Test custom configuration values.""" - config = LLMConfig( - model="claude-3-5-sonnet-20241022", - temperature=0.7, - max_tokens=2000, - api_key="test-key", - base_url="https://api.example.com", - system_prompt="You are a helpful assistant.", - extra_params={"frequency_penalty": 0.1}, - ) - - self.assertEqual(config.model, "claude-3-5-sonnet-20241022") - self.assertEqual(config.temperature, 0.7) - self.assertEqual(config.max_tokens, 2000) - self.assertEqual(config.api_key, "test-key") - self.assertEqual(config.base_url, "https://api.example.com") - self.assertEqual(config.system_prompt, "You are a helpful assistant.") - self.assertEqual(config.extra_params, {"frequency_penalty": 0.1}) - - def test_validation_model(self): - """Test model validation.""" - # Valid model - config = LLMConfig(model="gpt-4o") - self.assertEqual(config.model, "gpt-4o") - - # Empty model should raise error - with 
self.assertRaises(ValueError): - LLMConfig(model="") - - # Whitespace should be stripped - config = LLMConfig(model=" gpt-4o ") - self.assertEqual(config.model, "gpt-4o") - - def test_validation_temperature(self): - """Test temperature validation.""" - # Valid temperatures - LLMConfig(temperature=0.0) - LLMConfig(temperature=1.0) - LLMConfig(temperature=2.0) - - # Invalid temperatures - with self.assertRaises(ValueError): - LLMConfig(temperature=-0.1) - - with self.assertRaises(ValueError): - LLMConfig(temperature=2.1) - - def test_validation_max_tokens(self): - """Test max_tokens validation.""" - # Valid max_tokens - LLMConfig(max_tokens=1) - LLMConfig(max_tokens=4000) - - # Invalid max_tokens - with self.assertRaises(ValueError): - LLMConfig(max_tokens=0) - - with self.assertRaises(ValueError): - LLMConfig(max_tokens=-1) - - def test_get_provider_type(self): - """Test provider type detection.""" - # OpenAI - config = LLMConfig(model="gpt-4o") - self.assertEqual(config.get_provider_type(), "openai") - - config = LLMConfig(model="openai/gpt-4o") - self.assertEqual(config.get_provider_type(), "openai") - - # Anthropic - config = LLMConfig(model="claude-3-5-sonnet-20241022") - self.assertEqual(config.get_provider_type(), "anthropic") - - config = LLMConfig(model="anthropic/claude-3-5-sonnet") - self.assertEqual(config.get_provider_type(), "anthropic") - - # Google - config = LLMConfig(model="gemini-pro") - self.assertEqual(config.get_provider_type(), "google") - - config = LLMConfig(model="google/gemini-pro") - self.assertEqual(config.get_provider_type(), "google") - - # Azure - config = LLMConfig(model="azure/gpt-4o") - self.assertEqual(config.get_provider_type(), "azure") - - # Ollama - config = LLMConfig(model="ollama/llama2") - self.assertEqual(config.get_provider_type(), "ollama") - - # Unknown - config = LLMConfig(model="unknown-model") - self.assertEqual(config.get_provider_type(), "unknown") - - def test_get_litellm_params(self): - """Test LiteLLM parameters 
generation.""" - config = LLMConfig( - model="gpt-4o", - temperature=0.7, - max_tokens=2000, - api_key="test-key", - base_url="https://api.example.com", - extra_params={"frequency_penalty": 0.1}, - ) - - params = config.get_litellm_params() - - expected_params = { - "model": "gpt-4o", - "temperature": 0.7, - "max_tokens": 2000, - "top_p": 1.0, - "timeout": 60, - "api_key": "test-key", - "base_url": "https://api.example.com", - "frequency_penalty": 0.1, - } - - self.assertEqual(params, expected_params) - - def test_get_litellm_params_minimal(self): - """Test LiteLLM parameters with minimal config.""" - config = LLMConfig() - params = config.get_litellm_params() - - # Since we will automatically resolve the API key, we should remove it from the parameters. - if "api_key" in params: - params.pop("api_key") - - expected_params = { - "model": "gpt-4o", - "temperature": 0.0, - "max_tokens": 4000, - "top_p": 1.0, - "timeout": 60, - } - - self.assertEqual(params, expected_params) - - def test_create_variant(self): - """Test creating configuration variants.""" - base_config = LLMConfig( - model="gpt-4o", temperature=0.0, api_key="base-key" - ) - - # Create variant with overrides - variant = base_config.create_variant( - temperature=0.7, max_tokens=2000, api_key="variant-key" - ) - - # Check variant has overrides - self.assertEqual(variant.temperature, 0.7) - self.assertEqual(variant.max_tokens, 2000) - self.assertEqual(variant.api_key, "variant-key") - - # Check variant keeps non-overridden values - self.assertEqual(variant.model, "gpt-4o") - - # Check original config is unchanged - self.assertEqual(base_config.temperature, 0.0) - self.assertEqual(base_config.max_tokens, 4000) - self.assertEqual(base_config.api_key, "base-key") - - def test_api_key_validation(self): - """Test API key validation.""" - # Valid API key - config = LLMConfig(api_key="test-key") - self.assertEqual(config.api_key, "test-key") - - # None API key is valid - config = LLMConfig(api_key=None) - 
self.assertIsNone(config.api_key) - - # Non-string API key should raise error - with self.assertRaises(ValueError): - LLMConfig(api_key=123) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/config/test_settings.py b/tests/config/test_settings.py deleted file mode 100644 index 1bf69f5..0000000 --- a/tests/config/test_settings.py +++ /dev/null @@ -1,336 +0,0 @@ -"""Unit tests for settings configuration system.""" - -import os -import shutil -import tempfile -import unittest -from unittest.mock import patch - -from quantmind.config.llm import LLMConfig -from quantmind.config.parsers import LlamaParserConfig, PDFParserConfig -from quantmind.config.settings import ( - Setting, - create_default_config, - load_config, -) -from quantmind.config.sources import ArxivSourceConfig -from quantmind.config.storage import LocalStorageConfig -from quantmind.config.taggers import LLMTaggerConfig - - -# TODO (whisper): No tests for flows integration. -class TestSetting(unittest.TestCase): - """Comprehensive test cases for Setting configuration system.""" - - def test_default_setting(self): - """Test creating Setting with default values.""" - setting = Setting() - - self.assertEqual(setting.log_level, "INFO") - self.assertIsNone(setting.source) - self.assertIsNone(setting.parser) - self.assertIsNone(setting.tagger) - self.assertIsInstance(setting.storage, LocalStorageConfig) - self.assertIsInstance(setting.llm, LLMConfig) - - def test_setting_with_components(self): - """Test creating Setting with component configurations.""" - # Create component configs - source_config = ArxivSourceConfig(max_results=50) - parser_config = PDFParserConfig(method="pdfplumber") - tagger_config = LLMTaggerConfig(max_tags=10) - - setting = Setting( - source=source_config, parser=parser_config, tagger=tagger_config - ) - - self.assertIsInstance(setting.source, ArxivSourceConfig) - self.assertEqual(setting.source.max_results, 50) - - self.assertIsInstance(setting.parser, PDFParserConfig) 
- self.assertEqual(setting.parser.method, "pdfplumber") - - self.assertIsInstance(setting.tagger, LLMTaggerConfig) - self.assertEqual(setting.tagger.max_tags, 10) - - def test_parse_config_with_components(self): - """Test parsing configuration dictionary with various components.""" - config_dict = { - "source": { - "type": "arxiv", - "config": {"max_results": 50, "sort_by": "relevance"}, - }, - "parser": { - "type": "pdf", - "config": { - "method": "pdfplumber", - "download_pdfs": True, - "max_file_size_mb": 25, - }, - }, - "tagger": { - "type": "llm", - "config": {"max_tags": 8, "model": "gpt-4o"}, - }, - "log_level": "DEBUG", - } - - setting = Setting._parse_config(config_dict) - - # Test source parsing - self.assertIsInstance(setting.source, ArxivSourceConfig) - self.assertEqual(setting.source.max_results, 50) - self.assertEqual(setting.source.sort_by, "relevance") - - # Test parser parsing - self.assertIsInstance(setting.parser, PDFParserConfig) - self.assertEqual(setting.parser.method, "pdfplumber") - self.assertTrue(setting.parser.download_pdfs) - self.assertEqual(setting.parser.max_file_size_mb, 25) - - # Test tagger parsing - self.assertIsInstance(setting.tagger, LLMTaggerConfig) - self.assertEqual(setting.tagger.max_tags, 8) - self.assertEqual(setting.tagger.llm_config.model, "gpt-4o") - - # Test simple fields - self.assertEqual(setting.log_level, "DEBUG") - - if setting.storage.storage_dir.exists(): - shutil.rmtree(setting.storage.storage_dir) - - def test_parse_config_unknown_types(self): - """Test parsing configuration with unknown component types.""" - config_dict = { - "source": { - "type": "unknown_source", - "config": {"some_param": "value"}, - }, - "parser": {"type": "llama", "config": {"result_type": "markdown"}}, - } - - setting = Setting._parse_config(config_dict) - - # Unknown source should be ignored - self.assertIsNone(setting.source) - - # Known parser should be parsed - self.assertIsInstance(setting.parser, LlamaParserConfig) - 
self.assertEqual(setting.parser.result_type, "markdown") - - def test_create_default_config(self): - """Test creating default configuration.""" - setting = create_default_config() - - # Test default source - self.assertIsInstance(setting.source, ArxivSourceConfig) - self.assertEqual(setting.source.max_results, 100) - self.assertEqual(setting.source.sort_by, "submittedDate") - self.assertEqual(setting.source.sort_order, "descending") - - # Test default parser - self.assertIsInstance(setting.parser, PDFParserConfig) - self.assertEqual(setting.parser.method, "pymupdf") - self.assertTrue(setting.parser.download_pdfs) - self.assertTrue(setting.parser.extract_tables) - - # Test default storage - self.assertIsInstance(setting.storage, LocalStorageConfig) - - # Test default values - self.assertEqual(setting.log_level, "INFO") - self.assertIsInstance(setting.llm, LLMConfig) - - def test_load_config_yaml_file(self): - """Test loading configuration from YAML file.""" - config_dict = { - "source": { - "type": "arxiv", - "config": {"max_results": 25, "sort_by": "relevance"}, - }, - "parser": { - "type": "pdf", - "config": {"method": "pdfplumber", "download_pdfs": False}, - }, - "log_level": "WARNING", - } - - # Test the actual YAML loading with temporary file - with tempfile.NamedTemporaryFile( - mode="w", suffix=".yaml", delete=False - ) as f: - import yaml - - yaml.dump(config_dict, f) - temp_path = f.name - - try: - setting = load_config(temp_path) - - # Verify loaded configuration - self.assertIsInstance(setting.source, ArxivSourceConfig) - self.assertEqual(setting.source.max_results, 25) - self.assertEqual(setting.source.sort_by, "relevance") - - self.assertIsInstance(setting.parser, PDFParserConfig) - self.assertEqual(setting.parser.method, "pdfplumber") - self.assertFalse(setting.parser.download_pdfs) - - self.assertEqual(setting.log_level, "WARNING") - - finally: - # Clean up - os.unlink(temp_path) - - def test_load_config_nonexistent_file(self): - """Test loading 
configuration from nonexistent file.""" - with self.assertRaises(FileNotFoundError): - load_config("nonexistent.yaml") - - def test_substitute_env_vars(self): - """Test environment variable substitution in configuration.""" - # Set up environment variables - os.environ["TEST_VAR"] = "test_value" - os.environ["API_KEY"] = "secret_key" - - config_dict = { - "source": { - "type": "arxiv", - "config": { - "api_key": "${API_KEY}", - "max_results": "${MAX_RESULTS:50}", # with default - }, - }, - } - - result = Setting.substitute_env_vars(config_dict) - - # Test substitution - self.assertEqual(result["source"]["config"]["api_key"], "secret_key") - self.assertEqual( - result["source"]["config"]["max_results"], "50" - ) # default used - - # Clean up - del os.environ["TEST_VAR"] - del os.environ["API_KEY"] - - def test_substitute_env_vars_nested(self): - """Test environment variable substitution in nested structures.""" - os.environ["NESTED_VAR"] = "nested_value" - - config_dict = { - "components": { - "parser": { - "config": { - "nested_list": ["${NESTED_VAR}", "static_value"], - "nested_dict": {"key": "${NESTED_VAR}"}, - } - } - } - } - - result = Setting.substitute_env_vars(config_dict) - - self.assertEqual( - result["components"]["parser"]["config"]["nested_list"][0], - "nested_value", - ) - self.assertEqual( - result["components"]["parser"]["config"]["nested_dict"]["key"], - "nested_value", - ) - - # Clean up - del os.environ["NESTED_VAR"] - - @patch.dict(os.environ, {}, clear=True) - def test_substitute_env_vars_defaults(self): - """Test environment variable substitution with defaults when vars don't exist.""" - config_dict = { - "api_key": "${MISSING_KEY:default_key}", - "no_default": "${MISSING_NO_DEFAULT}", - "mixed": "prefix_${MISSING_WITH_DEFAULT:default}_suffix", - } - - result = Setting.substitute_env_vars(config_dict) - - self.assertEqual(result["api_key"], "default_key") - self.assertEqual( - result["no_default"], "" - ) # empty string when no default - 
self.assertEqual(result["mixed"], "prefix_default_suffix") - - def test_export_config(self): - """Test exporting configuration to dictionary.""" - setting = Setting( - source=ArxivSourceConfig(max_results=30), - parser=PDFParserConfig(method="pdfplumber", download_pdfs=True), - tagger=LLMTaggerConfig(max_tags=5), - log_level="DEBUG", - ) - - config_dict = setting._export_config() - - # Test component export - self.assertEqual(config_dict["source"]["type"], "arxiv") - self.assertEqual(config_dict["source"]["config"]["max_results"], 30) - - self.assertEqual(config_dict["parser"]["type"], "pdf") - self.assertEqual( - config_dict["parser"]["config"]["method"], "pdfplumber" - ) - self.assertTrue(config_dict["parser"]["config"]["download_pdfs"]) - - self.assertEqual(config_dict["tagger"]["type"], "llm") - self.assertEqual(config_dict["tagger"]["config"]["max_tags"], 5) - - # Test simple fields - self.assertEqual(config_dict["log_level"], "DEBUG") - - # Test sensitive data exclusion - self.assertNotIn("api_key", config_dict["llm"]) - - assert setting.storage.storage_dir.exists() - if setting.storage.storage_dir.exists(): - shutil.rmtree(setting.storage.storage_dir) - assert not setting.storage.storage_dir.exists() - - def test_save_to_yaml(self): - """Test saving configuration to YAML file.""" - setting = Setting( - source=ArxivSourceConfig(max_results=20), - parser=PDFParserConfig(method="pymupdf"), - ) - - # Test saving to temporary file - with tempfile.NamedTemporaryFile( - mode="w", suffix=".yaml", delete=False - ) as f: - temp_path = f.name - - try: - setting.save_to_yaml(temp_path) - - # Verify file was created and contains expected content - with open(temp_path, "r") as f: - import yaml - - saved_config = yaml.safe_load(f) - - self.assertEqual(saved_config["source"]["type"], "arxiv") - self.assertEqual( - saved_config["source"]["config"]["max_results"], 20 - ) - self.assertEqual(saved_config["parser"]["type"], "pdf") - self.assertEqual( - 
saved_config["parser"]["config"]["method"], "pymupdf" - ) - - finally: - # Clean up - os.unlink(temp_path) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/config/test_storage.py b/tests/config/test_storage.py deleted file mode 100644 index ffc433c..0000000 --- a/tests/config/test_storage.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Tests for storage configuration models.""" - -import shutil -import tempfile -import unittest -from pathlib import Path - -from quantmind.config.storage import LocalStorageConfig - - -class TestLocalStorageConfig(unittest.TestCase): - """Test LocalStorageConfig functionality.""" - - def setUp(self): - """Set up test environment with temporary directory.""" - self.temp_dir = Path(tempfile.mkdtemp()) - self.addCleanup(self._cleanup_temp_dir) - - def _cleanup_temp_dir(self): - """Clean up temporary directory after test.""" - if self.temp_dir.exists(): - shutil.rmtree(self.temp_dir) - - def test_default_configuration(self): - """Test default configuration values.""" - config = LocalStorageConfig() - - # Check default storage directory - self.assertEqual( - config.storage_dir, Path("./data").expanduser().resolve() - ) - - # Check that directory was created - self.assertTrue(config.storage_dir.exists()) - - # Clean up default directory - if config.storage_dir.exists() and config.storage_dir.name == "data": - shutil.rmtree(config.storage_dir) - - def test_custom_storage_directory(self): - """Test custom storage directory configuration.""" - custom_dir = self.temp_dir / "custom_storage" - - config = LocalStorageConfig(storage_dir=custom_dir) - - # Check that custom directory is set and resolved - self.assertEqual(config.storage_dir, custom_dir.resolve()) - self.assertTrue(config.storage_dir.exists()) - - def test_model_post_init_creates_directories(self): - """Test that model_post_init creates all required subdirectories.""" - storage_dir = self.temp_dir / "test_storage" - - LocalStorageConfig(storage_dir=storage_dir) - - # Check 
that main directory exists - self.assertTrue(storage_dir.exists()) - - # Check that all subdirectories were created - self.assertTrue((storage_dir / "raw_files").exists()) - self.assertTrue((storage_dir / "knowledges").exists()) - self.assertTrue((storage_dir / "embeddings").exists()) - self.assertTrue((storage_dir / "extra").exists()) - - def test_directory_properties(self): - """Test directory property methods.""" - storage_dir = self.temp_dir / "property_test" - - config = LocalStorageConfig(storage_dir=storage_dir) - - # Test raw_files_dir property - use resolve() for consistent path comparison - expected_raw_files = (storage_dir / "raw_files").resolve() - self.assertEqual(config.raw_files_dir.resolve(), expected_raw_files) - self.assertTrue(config.raw_files_dir.exists()) - - # Test knowledges_dir property - expected_knowledges = (storage_dir / "knowledges").resolve() - self.assertEqual(config.knowledges_dir.resolve(), expected_knowledges) - self.assertTrue(config.knowledges_dir.exists()) - - # Test embeddings_dir property - expected_embeddings = (storage_dir / "embeddings").resolve() - self.assertEqual(config.embeddings_dir.resolve(), expected_embeddings) - self.assertTrue(config.embeddings_dir.exists()) - - # Test extra_dir property - expected_extra = (storage_dir / "extra").resolve() - self.assertEqual(config.extra_dir.resolve(), expected_extra) - self.assertTrue(config.extra_dir.exists()) - - def test_path_expansion_and_resolution(self): - """Test that paths are properly expanded and resolved.""" - # Test with relative path - relative_path = Path("./relative_storage") - config = LocalStorageConfig(storage_dir=relative_path) - - # Should be converted to absolute path - self.assertTrue(config.storage_dir.is_absolute()) - self.assertEqual(config.storage_dir.name, "relative_storage") - - # Clean up - if config.storage_dir.exists(): - shutil.rmtree(config.storage_dir) - - def test_home_directory_expansion(self): - """Test that ~ in paths is properly expanded.""" 
- # Test with home directory path - home_path = Path("~/test_quantmind_storage") - config = LocalStorageConfig(storage_dir=home_path) - - # Should expand ~ to actual home directory - self.assertFalse(str(config.storage_dir).startswith("~")) - self.assertTrue(config.storage_dir.is_absolute()) - self.assertTrue( - str(config.storage_dir).endswith("test_quantmind_storage") - ) - - # Clean up - if config.storage_dir.exists(): - shutil.rmtree(config.storage_dir) - - def test_existing_directory_handling(self): - """Test behavior when storage directory already exists.""" - storage_dir = self.temp_dir / "existing_storage" - storage_dir.mkdir(parents=True, exist_ok=True) - - # Create some existing subdirectories - (storage_dir / "raw_files").mkdir(exist_ok=True) - (storage_dir / "custom_subdir").mkdir(exist_ok=True) - - # Initialize config with existing directory - config = LocalStorageConfig(storage_dir=storage_dir) - - # Should not fail and should create missing subdirectories - self.assertTrue(config.storage_dir.exists()) - self.assertTrue(config.raw_files_dir.exists()) - self.assertTrue(config.knowledges_dir.exists()) - self.assertTrue(config.embeddings_dir.exists()) - self.assertTrue(config.extra_dir.exists()) - - # Custom subdirectory should still exist - self.assertTrue((storage_dir / "custom_subdir").exists()) - - def test_nested_path_creation(self): - """Test creation of deeply nested storage paths.""" - nested_dir = self.temp_dir / "level1" / "level2" / "level3" / "storage" - - config = LocalStorageConfig(storage_dir=nested_dir) - - # Should create all parent directories - self.assertTrue(nested_dir.exists()) - self.assertTrue(config.raw_files_dir.exists()) - self.assertTrue(config.knowledges_dir.exists()) - self.assertTrue(config.embeddings_dir.exists()) - self.assertTrue(config.extra_dir.exists()) - - def test_string_path_input(self): - """Test that string paths are properly converted to Path objects.""" - storage_dir_str = str(self.temp_dir / "string_input") - 
- config = LocalStorageConfig(storage_dir=storage_dir_str) - - # Should convert string to Path and work properly - self.assertIsInstance(config.storage_dir, Path) - self.assertTrue(config.storage_dir.exists()) - self.assertTrue(config.raw_files_dir.exists()) - - def test_directory_permissions(self): - """Test that created directories have proper permissions.""" - storage_dir = self.temp_dir / "permission_test" - - config = LocalStorageConfig(storage_dir=storage_dir) - - # Check that directories are readable and writable - self.assertTrue(storage_dir.is_dir()) - self.assertTrue(config.raw_files_dir.is_dir()) - self.assertTrue(config.knowledges_dir.is_dir()) - self.assertTrue(config.embeddings_dir.is_dir()) - self.assertTrue(config.extra_dir.is_dir()) - - # Test that we can write to directories - test_file = config.raw_files_dir / "test.txt" - test_file.write_text("test content") - self.assertTrue(test_file.exists()) - self.assertEqual(test_file.read_text(), "test content") - - def test_model_dump_functionality(self): - """Test that model can be serialized properly.""" - storage_dir = self.temp_dir / "dump_test" - - config = LocalStorageConfig(storage_dir=storage_dir) - - # Test model_dump - dumped = config.model_dump() - self.assertIn("storage_dir", dumped) - self.assertEqual(Path(dumped["storage_dir"]), storage_dir.resolve()) - - def test_multiple_config_instances(self): - """Test that multiple config instances work independently.""" - dir1 = self.temp_dir / "storage1" - dir2 = self.temp_dir / "storage2" - - config1 = LocalStorageConfig(storage_dir=dir1) - config2 = LocalStorageConfig(storage_dir=dir2) - - # Both should exist and be different - self.assertTrue(config1.storage_dir.exists()) - self.assertTrue(config2.storage_dir.exists()) - self.assertNotEqual(config1.storage_dir, config2.storage_dir) - - # Both should have their own subdirectories - self.assertTrue(config1.raw_files_dir.exists()) - self.assertTrue(config2.raw_files_dir.exists()) - 
self.assertNotEqual(config1.raw_files_dir, config2.raw_files_dir) - - def test_reconfiguration(self): - """Test that config can be updated after creation.""" - storage_dir1 = self.temp_dir / "original" - storage_dir2 = self.temp_dir / "updated" - - # Create initial config - config = LocalStorageConfig(storage_dir=storage_dir1) - self.assertTrue(storage_dir1.exists()) - - # Update storage directory - config.storage_dir = storage_dir2 - config.model_post_init(None) # Manually trigger post_init - - # New directory should be created - self.assertTrue(storage_dir2.exists()) - self.assertTrue((storage_dir2 / "raw_files").exists()) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/flows/__init__.py b/tests/flows/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/flows/test_batch.py b/tests/flows/test_batch.py new file mode 100644 index 0000000..6df8c5c --- /dev/null +++ b/tests/flows/test_batch.py @@ -0,0 +1,160 @@ +"""Tests for ``quantmind.flows.batch``.""" + +import asyncio +import unittest +from typing import Any + +from quantmind.configs import PaperFlowCfg +from quantmind.configs.paper import RawText +from quantmind.flows.batch import BatchResult, batch_run + + +class BatchResultPropertiesTests(unittest.TestCase): + def test_successes_and_failures_views(self) -> None: + result: BatchResult[str] = BatchResult( + total=4, + success_count=2, + failure_count=2, + results=["a", None, "c", None], + errors=[(1, ValueError("b")), (3, RuntimeError("d"))], + duration_seconds=0.0, + ) + self.assertEqual(result.successes, [(0, "a"), (2, "c")]) + self.assertEqual( + [(i, type(e).__name__) for i, e in result.failures], + [(1, "ValueError"), (3, "RuntimeError")], + ) + + +class BatchRunTests(unittest.IsolatedAsyncioTestCase): + async def test_happy_path(self) -> None: + async def flow(input: RawText, *, cfg: Any = None) -> str: + return f"ok:{input.text}" + + inputs = [RawText(text=str(i)) for i in range(5)] + result = await 
batch_run(flow, inputs, concurrency=3) + self.assertEqual(result.total, 5) + self.assertEqual(result.success_count, 5) + self.assertEqual(result.failure_count, 0) + self.assertEqual(result.results, [f"ok:{i}" for i in range(5)]) + self.assertEqual(result.errors, []) + self.assertGreaterEqual(result.duration_seconds, 0) + self.assertEqual(result.tokens_total, {}) + self.assertEqual(result.cost_estimate_usd, 0.0) + + async def test_empty_inputs(self) -> None: + async def flow(input: RawText, *, cfg: Any = None) -> str: + return "x" + + result = await batch_run(flow, [], concurrency=4) + self.assertEqual(result.total, 0) + self.assertEqual(result.success_count, 0) + self.assertEqual(result.results, []) + + async def test_on_error_skip_collects(self) -> None: + async def flow(input: RawText, *, cfg: Any = None) -> str: + if input.text in ("2", "4"): + raise ValueError(f"bad:{input.text}") + return f"ok:{input.text}" + + inputs = [RawText(text=str(i)) for i in range(5)] + result = await batch_run(flow, inputs, on_error="skip") + self.assertEqual(result.success_count, 3) + self.assertEqual(result.failure_count, 2) + self.assertEqual(result.results[0], "ok:0") + self.assertIsNone(result.results[2]) + self.assertIsNone(result.results[4]) + self.assertEqual([i for i, _ in result.errors], [2, 4]) + + async def test_on_error_raise_propagates(self) -> None: + boom = RuntimeError("boom") + + async def flow(input: RawText, *, cfg: Any = None) -> str: + raise boom + + with self.assertRaises(RuntimeError) as ctx: + await batch_run( + flow, + [RawText(text="x")], + concurrency=1, + on_error="raise", + ) + self.assertIs(ctx.exception, boom) + + async def test_memory_kwarg_rejected(self) -> None: + async def flow(input: RawText, *, cfg: Any = None) -> str: + return "x" + + with self.assertRaises(ValueError) as ctx: + await batch_run(flow, [RawText(text="x")], memory=object()) + self.assertIn("memory", str(ctx.exception)) + + async def test_concurrency_must_be_at_least_one(self) -> 
None: + async def flow(input: RawText, *, cfg: Any = None) -> str: + return "x" + + with self.assertRaises(ValueError): + await batch_run(flow, [RawText(text="x")], concurrency=0) + + async def test_concurrency_cap_honoured(self) -> None: + in_flight = 0 + peak = 0 + + async def flow(input: RawText, *, cfg: Any = None) -> str: + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + await asyncio.sleep(0.01) # let scheduler interleave + in_flight -= 1 + return "ok" + + await batch_run( + flow, + [RawText(text=str(i)) for i in range(10)], + concurrency=3, + ) + self.assertLessEqual(peak, 3) + self.assertGreater(peak, 0) + + async def test_on_progress_called_per_completion(self) -> None: + calls: list[tuple[int, int]] = [] + + async def flow(input: RawText, *, cfg: Any = None) -> str: + return "ok" + + inputs = [RawText(text=str(i)) for i in range(5)] + await batch_run( + flow, + inputs, + concurrency=2, + on_progress=lambda done, total: calls.append((done, total)), + ) + self.assertEqual(len(calls), 5) + # `done` strictly increases, ends at total. 
+ dones = [c[0] for c in calls] + self.assertEqual(dones, sorted(dones)) + self.assertEqual(dones[-1], 5) + self.assertTrue(all(t == 5 for _, t in calls)) + + async def test_cfg_forwarded(self) -> None: + seen_cfg: list[Any] = [] + + async def flow(input: RawText, *, cfg: Any = None) -> str: + seen_cfg.append(cfg) + return "ok" + + cfg = PaperFlowCfg(model="sentinel-model") + await batch_run(flow, [RawText(text="x")], cfg=cfg, concurrency=1) + self.assertIs(seen_cfg[0], cfg) + + async def test_extra_kwargs_forwarded(self) -> None: + seen: list[Any] = [] + + async def flow( + input: RawText, *, cfg: Any = None, marker: str = "" + ) -> str: + seen.append(marker) + return "ok" + + await batch_run(flow, [RawText(text="x")], concurrency=1, marker="here") + self.assertEqual(seen, ["here"]) diff --git a/tests/flows/test_paper.py b/tests/flows/test_paper.py new file mode 100644 index 0000000..2a143c7 --- /dev/null +++ b/tests/flows/test_paper.py @@ -0,0 +1,359 @@ +"""Tests for ``quantmind.flows.paper``.""" + +import unittest +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +from agents import RunHooks + +from quantmind.configs import PaperFlowCfg +from quantmind.configs.paper import ( + ArxivIdentifier, + DoiIdentifier, + HttpUrl, + LocalFilePath, + RawText, +) +from quantmind.flows.paper import ( + UnsupportedContentTypeError, + _compose_instructions, + _fetch_and_format, + _format_by_content_type, + _format_input, + paper_flow, +) +from quantmind.knowledge import Paper, SourceRef, TreeNode +from quantmind.preprocess.fetch import Fetched, RawPaper + + +def _stub_paper() -> Paper: + root_id = uuid4() + root = TreeNode(node_id=root_id, title="root", summary="stub") + return Paper( + as_of=datetime(2026, 5, 7, tzinfo=timezone.utc), + source=SourceRef( + kind="arxiv", + uri="arxiv:2604.12345", + fetched_at=datetime(2026, 5, 7, tzinfo=timezone.utc), + ), + 
root_node_id=root_id, + nodes={root_id: root}, + ) + + +def _patch_runner(return_value: Any) -> Any: + return patch( + "quantmind.flows.paper.run_with_observability", + new=AsyncMock(return_value=return_value), + ) + + +class FormatByContentTypeTests(unittest.IsolatedAsyncioTestCase): + async def test_pdf_dispatches_to_pdf_to_markdown(self) -> None: + raw = Fetched(bytes=b"%PDF-x", content_type="application/pdf") + with patch( + "quantmind.flows.paper.pdf_to_markdown", + new=AsyncMock(return_value="MD"), + ) as pdf_mock: + md = await _format_by_content_type(raw) + pdf_mock.assert_awaited_once_with(b"%PDF-x") + self.assertEqual(md, "MD") + + async def test_html_dispatches_to_html_to_markdown(self) -> None: + raw = Fetched( + bytes="hi".encode("utf-8"), + content_type="text/html; charset=utf-8", + ) + with patch( + "quantmind.flows.paper.html_to_markdown", + new=AsyncMock(return_value="HTML-MD"), + ) as html_mock: + md = await _format_by_content_type(raw) + html_mock.assert_awaited_once_with("hi") + self.assertEqual(md, "HTML-MD") + + async def test_markdown_passes_through(self) -> None: + raw = Fetched( + bytes=b"# heading\n\nbody", + content_type="text/markdown", + ) + md = await _format_by_content_type(raw) + self.assertEqual(md, "# heading\n\nbody") + + async def test_plain_text_passes_through(self) -> None: + raw = Fetched(bytes=b"plain", content_type="text/plain") + md = await _format_by_content_type(raw) + self.assertEqual(md, "plain") + + async def test_unsupported_content_type_raises(self) -> None: + raw = Fetched(bytes=b"\x00\x00", content_type="application/zip") + with self.assertRaises(UnsupportedContentTypeError): + await _format_by_content_type(raw) + + +class FetchAndFormatTests(unittest.IsolatedAsyncioTestCase): + async def test_arxiv_branch(self) -> None: + raw_paper = RawPaper( + bytes=b"%PDF", + content_type="application/pdf", + source_url="http://arxiv.org/pdf/2604.12345.pdf", + arxiv_id="2604.12345", + title="Momentum", + authors=("Alice", 
"Bob"), + ) + with ( + patch( + "quantmind.flows.paper.fetch_arxiv", + new=AsyncMock(return_value=raw_paper), + ) as fetch_mock, + patch( + "quantmind.flows.paper.pdf_to_markdown", + new=AsyncMock(return_value="MARKDOWN"), + ) as fmt_mock, + ): + md, meta = await _fetch_and_format(ArxivIdentifier(id="2604.12345")) + fetch_mock.assert_awaited_once_with("2604.12345") + fmt_mock.assert_awaited_once_with(b"%PDF") + self.assertEqual(md, "MARKDOWN") + self.assertEqual(meta["source"], "arxiv") + self.assertEqual(meta["arxiv_id"], "2604.12345") + self.assertEqual(meta["title"], "Momentum") + self.assertEqual(meta["authors"], ["Alice", "Bob"]) + + async def test_http_pdf_branch(self) -> None: + raw = Fetched( + bytes=b"%PDF", + content_type="application/pdf", + source_url="http://example/x.pdf", + ) + with ( + patch( + "quantmind.flows.paper.fetch_url", + new=AsyncMock(return_value=raw), + ) as fetch_mock, + patch( + "quantmind.flows.paper.pdf_to_markdown", + new=AsyncMock(return_value="PDFMD"), + ), + ): + md, meta = await _fetch_and_format( + HttpUrl(url="http://example/x.pdf") + ) + fetch_mock.assert_awaited_once_with("http://example/x.pdf") + self.assertEqual(md, "PDFMD") + self.assertEqual(meta["source"], "web") + self.assertEqual(meta["content_type"], "application/pdf") + + async def test_local_file_branch(self) -> None: + raw = Fetched( + bytes=b"## body", + content_type="text/markdown", + source_url="file:///tmp/p.md", + ) + with patch( + "quantmind.flows.paper.read_local_file", + new=AsyncMock(return_value=raw), + ) as read_mock: + md, meta = await _fetch_and_format( + LocalFilePath(path=Path("/tmp/p.md")) + ) + read_mock.assert_awaited_once_with(Path("/tmp/p.md")) + self.assertEqual(md, "## body") + self.assertEqual(meta["source"], "local") + self.assertEqual(meta["path"], "/tmp/p.md") + self.assertEqual(meta["content_type"], "text/markdown") + + async def test_raw_text_branch(self) -> None: + md, meta = await _fetch_and_format(RawText(text="hello")) + 
self.assertEqual(md, "hello") + self.assertEqual(meta, {"source": "inline"}) + + async def test_doi_branch_raises_not_implemented(self) -> None: + with self.assertRaises(NotImplementedError) as ctx: + await _fetch_and_format(DoiIdentifier(doi="10.1234/abcd")) + self.assertIn("DOI", str(ctx.exception)) + + +class ComposeInstructionsTests(unittest.TestCase): + def test_default_renders_cfg_flags(self) -> None: + cfg = PaperFlowCfg( + extract_methodology=False, + extract_limitations=True, + asset_class_hint="equities", + ) + out = _compose_instructions( + "go {extract_methodology} {extract_limitations} " + "{asset_class_hint!r}", + None, + cfg, + ) + self.assertEqual(out, "go False True 'equities'") + + def test_extra_appended(self) -> None: + cfg = PaperFlowCfg() + out = _compose_instructions("BASE", "USER", cfg) + self.assertIn("BASE", out) + self.assertIn("Additional instructions:", out) + self.assertIn("USER", out) + + +class FormatInputTests(unittest.TestCase): + def test_tuple_authors_join_as_csv(self) -> None: + out = _format_input( + "BODY", + {"authors": ("Alice", "Bob"), "title": "x"}, + ) + self.assertIn("authors: Alice, Bob", out) + self.assertIn("title: x", out) + self.assertIn("BODY", out) + + def test_none_values_skipped(self) -> None: + out = _format_input("BODY", {"a": "1", "b": None}) + self.assertIn("a: 1", out) + self.assertNotIn("b:", out) + + +class PaperFlowTests(unittest.IsolatedAsyncioTestCase): + async def test_happy_path_arxiv(self) -> None: + raw_paper = RawPaper( + bytes=b"%PDF", + content_type="application/pdf", + arxiv_id="2604.12345", + ) + stub = _stub_paper() + with ( + patch( + "quantmind.flows.paper.fetch_arxiv", + new=AsyncMock(return_value=raw_paper), + ), + patch( + "quantmind.flows.paper.pdf_to_markdown", + new=AsyncMock(return_value="MD"), + ), + _patch_runner(stub) as runner, + ): + out = await paper_flow(ArxivIdentifier(id="2604.12345")) + self.assertIs(out, stub) + runner.assert_awaited_once() + + async def 
test_extra_instructions_passed_to_agent(self) -> None: + seen: dict[str, Any] = {} + + def _capture_agent(*_a: Any, **kwargs: Any) -> Any: + seen.update(kwargs) + return MagicMock(name="agent", _name="paper_extractor") + + stub = _stub_paper() + with ( + patch("quantmind.flows.paper.Agent", side_effect=_capture_agent), + _patch_runner(stub), + ): + await paper_flow( + RawText(text="hello"), + extra_instructions="EXTRA-USER-DIRECTIVE", + ) + self.assertIn("EXTRA-USER-DIRECTIVE", seen["instructions"]) + self.assertIn("structured QuantMind", seen["instructions"]) + + async def test_output_type_override_propagated(self) -> None: + seen: dict[str, Any] = {} + + class MyPaper(Paper): + pass + + def _capture_agent(*_a: Any, **kwargs: Any) -> Any: + seen.update(kwargs) + return MagicMock() + + with ( + patch("quantmind.flows.paper.Agent", side_effect=_capture_agent), + _patch_runner(_stub_paper()), + ): + await paper_flow(RawText(text="x"), output_type=MyPaper) + self.assertIs(seen["output_type"], MyPaper) + + async def test_extra_tools_and_guardrails_forwarded(self) -> None: + seen: dict[str, Any] = {} + + def _capture_agent(*_a: Any, **kwargs: Any) -> Any: + seen.update(kwargs) + return MagicMock() + + sentinel_tool = MagicMock(name="tool") + in_g = MagicMock() + out_g = MagicMock() + with ( + patch("quantmind.flows.paper.Agent", side_effect=_capture_agent), + _patch_runner(_stub_paper()), + ): + await paper_flow( + RawText(text="x"), + extra_tools=[sentinel_tool], + extra_input_guardrails=[in_g], + extra_output_guardrails=[out_g], + ) + self.assertEqual(seen["tools"], [sentinel_tool]) + self.assertEqual(seen["input_guardrails"], [in_g]) + self.assertEqual(seen["output_guardrails"], [out_g]) + + async def test_memory_accepted_as_no_op(self) -> None: + with ( + patch( + "quantmind.flows.paper.Agent", + return_value=MagicMock(), + ), + _patch_runner(_stub_paper()) as runner, + ): + await paper_flow(RawText(text="x"), memory=object()) + # The runner sees the memory 
placeholder forwarded. + self.assertIsNotNone(runner.await_args.kwargs["memory"]) + + async def test_extra_run_hooks_forwarded(self) -> None: + class _H(RunHooks[Any]): + pass + + hook = _H() + with ( + patch( + "quantmind.flows.paper.Agent", + return_value=MagicMock(), + ), + _patch_runner(_stub_paper()) as runner, + ): + await paper_flow(RawText(text="x"), extra_run_hooks=[hook]) + self.assertEqual(runner.await_args.kwargs["extra_run_hooks"], [hook]) + + async def test_model_settings_forwarded_when_set(self) -> None: + seen: dict[str, Any] = {} + + def _capture_agent(*_a: Any, **kwargs: Any) -> Any: + seen.update(kwargs) + return MagicMock() + + from agents import ModelSettings + + ms = ModelSettings(temperature=0.42) + cfg = PaperFlowCfg(model_settings=ms) + with ( + patch("quantmind.flows.paper.Agent", side_effect=_capture_agent), + _patch_runner(_stub_paper()), + ): + await paper_flow(RawText(text="x"), cfg=cfg) + self.assertIs(seen["model_settings"], ms) + + async def test_model_settings_omitted_when_none(self) -> None: + seen: dict[str, Any] = {} + + def _capture_agent(*_a: Any, **kwargs: Any) -> Any: + seen.update(kwargs) + return MagicMock() + + with ( + patch("quantmind.flows.paper.Agent", side_effect=_capture_agent), + _patch_runner(_stub_paper()), + ): + await paper_flow(RawText(text="x")) + self.assertNotIn("model_settings", seen) diff --git a/tests/flows/test_runner.py b/tests/flows/test_runner.py new file mode 100644 index 0000000..4ece594 --- /dev/null +++ b/tests/flows/test_runner.py @@ -0,0 +1,199 @@ +"""Tests for ``quantmind.flows._runner``.""" + +import unittest +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +from agents import RunHooks + +from quantmind.configs import PaperFlowCfg +from quantmind.flows._runner import ( + _archive_run_artifacts, + _collect_hooks, + _compose_hooks, + _CompositeRunHooks, + run_with_observability, +) + + +class _RecordingHooks(RunHooks[Any]): + """Test hook that records every 
lifecycle call on a shared list.""" + + def __init__(self, label: str, log: list[tuple[str, str]]) -> None: + self.label = label + self.log = log + + async def on_llm_start(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_llm_start")) + + async def on_llm_end(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_llm_end")) + + async def on_agent_start(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_agent_start")) + + async def on_agent_end(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_agent_end")) + + async def on_handoff(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_handoff")) + + async def on_tool_start(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_tool_start")) + + async def on_tool_end(self, *_: Any, **__: Any) -> None: + self.log.append((self.label, "on_tool_end")) + + +class ComposeHooksTests(unittest.TestCase): + def test_empty_returns_none(self) -> None: + self.assertIsNone(_compose_hooks([])) + + def test_single_returns_same_instance(self) -> None: + hook = _RecordingHooks("a", []) + self.assertIs(_compose_hooks([hook]), hook) + + def test_multiple_returns_composite(self) -> None: + a = _RecordingHooks("a", []) + b = _RecordingHooks("b", []) + composed = _compose_hooks([a, b]) + self.assertIsInstance(composed, _CompositeRunHooks) + + +class CompositeRunHooksTests(unittest.IsolatedAsyncioTestCase): + async def test_fan_out_in_registration_order(self) -> None: + log: list[tuple[str, str]] = [] + a = _RecordingHooks("a", log) + b = _RecordingHooks("b", log) + composite = _CompositeRunHooks([a, b]) + await composite.on_llm_start() + await composite.on_llm_end() + await composite.on_agent_start() + await composite.on_agent_end() + await composite.on_handoff() + await composite.on_tool_start() + await composite.on_tool_end() + # Each method fires for both hooks in registration order. 
+ for method in ( + "on_llm_start", + "on_llm_end", + "on_agent_start", + "on_agent_end", + "on_handoff", + "on_tool_start", + "on_tool_end", + ): + self.assertEqual( + [entry for entry in log if entry[1] == method], + [("a", method), ("b", method)], + ) + + async def test_earlier_hook_exception_short_circuits(self) -> None: + class _Boom(RunHooks[Any]): + async def on_llm_start(self, *_: Any, **__: Any) -> None: + raise RuntimeError("boom") + + log: list[tuple[str, str]] = [] + composite = _CompositeRunHooks([_Boom(), _RecordingHooks("b", log)]) + with self.assertRaises(RuntimeError): + await composite.on_llm_start() + self.assertEqual(log, []) + + +class CollectHooksTests(unittest.TestCase): + def test_memory_contributes_nothing_in_pr5(self) -> None: + extra = _RecordingHooks("a", []) + # PR5: memory is forwarded but unused. + self.assertEqual(_collect_hooks(None, [extra]), [extra]) + self.assertEqual(_collect_hooks(object(), [extra]), [extra]) + + def test_no_extras_returns_empty(self) -> None: + self.assertEqual(_collect_hooks(None, []), []) + + +class ArchiveStubTests(unittest.TestCase): + def test_archive_is_no_op(self) -> None: + cfg = PaperFlowCfg() + result = MagicMock() + # Must not raise, must return None, must not touch result. 
+ self.assertIsNone(_archive_run_artifacts(cfg, None, result)) + result.assert_not_called() + + +class RunWithObservabilityTests(unittest.IsolatedAsyncioTestCase): + async def test_run_config_built_from_cfg(self) -> None: + cfg = PaperFlowCfg( + model="gpt-test", + max_turns=7, + workflow_name="custom-name", + trace_metadata={"k": "v"}, + trace_include_sensitive_data=False, + tracing_disabled=True, + ) + agent = MagicMock(name="agent") + agent.name = "paper_extractor" + fake_result = MagicMock() + fake_result.final_output = "OUT" + with patch( + "quantmind.flows._runner.Runner.run", + new=AsyncMock(return_value=fake_result), + ) as run_mock: + out = await run_with_observability( + agent, + "prompt", + cfg=cfg, + memory=None, + extra_run_hooks=[], + ) + self.assertEqual(out, "OUT") + run_mock.assert_awaited_once() + call = run_mock.await_args + self.assertIs(call.args[0], agent) + self.assertEqual(call.args[1], "prompt") + self.assertEqual(call.kwargs["max_turns"], 7) + run_cfg = call.kwargs["run_config"] + self.assertEqual(run_cfg.workflow_name, "custom-name") + self.assertEqual(run_cfg.trace_metadata, {"k": "v"}) + self.assertFalse(run_cfg.trace_include_sensitive_data) + self.assertTrue(run_cfg.tracing_disabled) + # No hooks supplied -> Runner.run sees None. 
+ self.assertIsNone(call.kwargs["hooks"]) + + async def test_workflow_name_falls_back_to_agent_name(self) -> None: + cfg = PaperFlowCfg() # workflow_name = None + agent = MagicMock() + agent.name = "paper_extractor" + fake_result = MagicMock() + fake_result.final_output = None + with patch( + "quantmind.flows._runner.Runner.run", + new=AsyncMock(return_value=fake_result), + ) as run_mock: + await run_with_observability( + agent, "x", cfg=cfg, memory=None, extra_run_hooks=[] + ) + self.assertEqual( + run_mock.await_args.kwargs["run_config"].workflow_name, + "quantmind.paper_extractor", + ) + + async def test_extra_hooks_forwarded(self) -> None: + cfg = PaperFlowCfg() + agent = MagicMock() + agent.name = "x" + fake_result = MagicMock() + fake_result.final_output = None + hook = _RecordingHooks("a", []) + with patch( + "quantmind.flows._runner.Runner.run", + new=AsyncMock(return_value=fake_result), + ) as run_mock: + await run_with_observability( + agent, + "x", + cfg=cfg, + memory=object(), # PR6 placeholder + extra_run_hooks=[hook], + ) + # Single hook -> passed through as-is, not wrapped in composite. 
+ self.assertIs(run_mock.await_args.kwargs["hooks"], hook) diff --git a/tests/models/test_paper.py b/tests/models/test_paper.py deleted file mode 100644 index 54a0ad9..0000000 --- a/tests/models/test_paper.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Tests for Paper model.""" - -from datetime import datetime - -import pytest - -from quantmind.models.paper import Paper - - -class TestPaper: - """Test cases for the Paper model.""" - - def test_paper_creation(self): - """Test basic paper creation.""" - paper = Paper(title="Test Paper", abstract="This is a test abstract") - - assert paper.title == "Test Paper" - assert paper.abstract == "This is a test abstract" - assert isinstance(paper.id, str) - assert len(paper.categories) == 0 - assert len(paper.tags) == 0 - - def test_paper_with_metadata(self): - """Test paper creation with metadata.""" - published_date = datetime(2023, 1, 15) - - paper = Paper( - title="Advanced ML Paper", - abstract="This paper discusses advanced machine learning techniques", - authors=["John Doe", "Jane Smith"], - published_date=published_date, - categories=["Machine Learning"], - tags=["deep learning", "neural networks"], - arxiv_id="2301.12345", - ) - - assert paper.title == "Advanced ML Paper" - assert len(paper.authors) == 2 - assert paper.published_date == published_date - assert "Machine Learning" in paper.categories - assert "deep learning" in paper.tags - assert paper.arxiv_id == "2301.12345" - - def test_add_category(self): - """Test adding categories.""" - paper = Paper(title="Test", abstract="Test abstract") - - paper.add_category("Finance") - paper.add_category("Machine Learning") - paper.add_category("Finance") # Duplicate - - assert len(paper.categories) == 2 - assert "Finance" in paper.categories - assert "Machine Learning" in paper.categories - - def test_add_tag(self): - """Test adding tags.""" - paper = Paper(title="Test", abstract="Test abstract") - - paper.add_tag("lstm") - paper.add_tag("trading") - paper.add_tag("lstm") # 
Duplicate - - assert len(paper.tags) == 2 - assert "lstm" in paper.tags - assert "trading" in paper.tags - - def test_get_text_for_embedding(self): - """Test text extraction for embedding.""" - paper = Paper( - title="ML in Finance", - abstract="Machine learning applications in financial markets", - ) - - text = paper.get_text_for_embedding() - expected = "ML in Finance\n\nMachine learning applications in financial markets" - assert text == expected - - def test_set_embedding(self): - """Test setting embedding.""" - paper = Paper(title="Test", abstract="Test abstract") - embedding = [0.1, 0.2, 0.3, 0.4, 0.5] - - paper.set_embedding(embedding, "text-embedding-ada-002") - - assert paper.embedding == embedding - assert paper.embedding_model == "text-embedding-ada-002" - - def test_has_content(self): - """Test content availability check.""" - paper1 = Paper(title="Test", abstract="Test abstract") - paper2 = Paper( - title="Test", - abstract="Test abstract", - content="Full paper content", - ) - paper3 = Paper(title="Test", abstract="Test abstract", content=" ") - - assert not paper1.has_content() - assert paper2.has_content() - assert not paper3.has_content() - - def test_content_field(self): - """Test content field functionality.""" - paper = Paper(title="Test", abstract="Test abstract") - - # Test setting content - paper.content = "This is the full content" - assert paper.content == "This is the full content" - assert paper.has_content() - - # Test content modification - paper.content = "Updated content" - assert paper.content == "Updated content" - assert paper.has_content() - - # Test content is None initially - paper2 = Paper(title="Test 2", abstract="Test abstract 2") - assert paper2.content is None - assert not paper2.has_content() - - def test_get_primary_id(self): - """Test primary ID extraction.""" - paper1 = Paper( - title="Test", abstract="Test abstract", arxiv_id="2301.12345" - ) - paper2 = Paper( - title="Test", abstract="Test abstract", 
paper_id="custom_id" - ) - paper3 = Paper(title="Test", abstract="Test abstract") - - assert paper1.get_primary_id() == "2301.12345" - assert paper2.get_primary_id() == "custom_id" - assert paper3.get_primary_id() == paper3.id - - def test_from_dict(self): - """Test creating paper from dictionary.""" - data = { - "title": "Test Paper", - "abstract": "Test abstract", - "authors": ["Author 1", "Author 2"], - "categories": ["AI", "ML"], - "arxiv_id": "2301.12345", - "published_date": "2023-01-15T00:00:00", - } - - paper = Paper.from_dict(data) - - assert paper.title == "Test Paper" - assert len(paper.authors) == 2 - assert len(paper.categories) == 2 - assert paper.arxiv_id == "2301.12345" - assert isinstance(paper.published_date, datetime) - - def test_dict_conversion(self): - """Test paper to dictionary conversion.""" - paper = Paper( - title="Test Paper", - abstract="Test abstract", - authors=["Author 1"], - categories=["AI"], - tags=["test"], - ) - - data = paper.dict() - - assert data["title"] == "Test Paper" - assert data["abstract"] == "Test abstract" - assert data["authors"] == ["Author 1"] - assert data["categories"] == ["AI"] - assert data["tags"] == ["test"] - - def test_authors_parsing(self): - """Test author parsing from various formats.""" - # String format - paper1 = Paper( - title="Test", abstract="Test", authors="John Doe, Jane Smith" - ) - assert len(paper1.authors) == 2 - assert "John Doe" in paper1.authors - - # List format - paper2 = Paper( - title="Test", abstract="Test", authors=["John Doe", "Jane Smith"] - ) - assert len(paper2.authors) == 2 - - # Empty - paper3 = Paper(title="Test", abstract="Test", authors=None) - assert len(paper3.authors) == 0 - - def test_validation(self): - """Test paper validation.""" - # Valid paper - paper1 = Paper( - title="Valid Title", - abstract="Valid abstract with sufficient length", - ) - assert len(paper1.title) >= 1 - assert len(paper1.abstract) >= 1 - - # Test minimum requirements are enforced by Pydantic - 
with pytest.raises(ValueError): - Paper(title="", abstract="Valid abstract") - - # Test empty abstract is allowed - Paper(title="Valid title", abstract="") - - def test_string_representations(self): - """Test string representations.""" - paper = Paper( - title="Test Paper", abstract="Test abstract", arxiv_id="2301.12345" - ) - - str_repr = str(paper) - repr_repr = repr(paper) - - assert "2301.12345" in str_repr - assert "Test Paper" in str_repr - assert "Test Paper" in repr_repr diff --git a/tests/test_magic.py b/tests/test_magic.py new file mode 100644 index 0000000..426f6cd --- /dev/null +++ b/tests/test_magic.py @@ -0,0 +1,222 @@ +"""Tests for ``quantmind.magic``.""" + +import io +import json +import unittest +from contextlib import redirect_stdout +from typing import Optional, Union +from unittest.mock import AsyncMock, MagicMock, patch + +from pydantic import BaseModel + +from quantmind.configs import PaperFlowCfg +from quantmind.configs.paper import ArxivIdentifier, PaperInput +from quantmind.flows import paper_flow +from quantmind.magic import ( + ResolvedFlowConfig, + _introspect_flow_signature, + _pydantic_schema_str, + _strip_optional, + preview_resolve, + resolve_magic_input, +) + + +class StripOptionalTests(unittest.TestCase): + def test_optional_t(self) -> None: + self.assertIs(_strip_optional(Optional[int]), int) + + def test_pep604_union_with_none(self) -> None: + self.assertIs(_strip_optional(int | None), int) + + def test_plain_t_unchanged(self) -> None: + self.assertIs(_strip_optional(int), int) + + def test_union_without_none_unchanged(self) -> None: + anno = Union[int, str] + self.assertEqual(_strip_optional(anno), anno) + + +class IntrospectFlowSignatureTests(unittest.TestCase): + def test_paper_flow_returns_paper_input_and_cfg(self) -> None: + input_type, cfg_type = _introspect_flow_signature(paper_flow) + self.assertIs(cfg_type, PaperFlowCfg) + # PaperInput is the Annotated[Union[...]] alias; pass through. 
+        self.assertEqual(input_type, PaperInput)
+
+    def test_missing_input_param_raises(self) -> None:
+        async def bad(*, cfg: PaperFlowCfg | None = None) -> None:
+            return None
+
+        with self.assertRaises(TypeError):
+            _introspect_flow_signature(bad)
+
+    def test_missing_cfg_param_raises(self) -> None:
+        async def bad(input: ArxivIdentifier) -> None:
+            return None
+
+        with self.assertRaises(TypeError):
+            _introspect_flow_signature(bad)
+
+    def test_cfg_annotation_must_be_baseflowcfg(self) -> None:
+        async def bad(input: ArxivIdentifier, *, cfg: int = 0) -> None:
+            return None
+
+        with self.assertRaises(TypeError):
+            _introspect_flow_signature(bad)
+
+
+class PydanticSchemaStrTests(unittest.TestCase):
+    def test_basemodel_renders_json_schema(self) -> None:
+        # PaperFlowCfg embeds ModelSettings which has callable fields;
+        # the renderer falls back to a fields summary in that case.
+        out = _pydantic_schema_str(PaperFlowCfg)
+        parsed = json.loads(out)
+        self.assertEqual(parsed.get("title"), "PaperFlowCfg")
+        self.assertIn("model", parsed["fields"])
+
+    def test_basemodel_with_clean_schema(self) -> None:
+        class Clean(BaseModel):
+            x: int = 0
+
+        out = _pydantic_schema_str(Clean)
+        parsed = json.loads(out)
+        # Clean schema path -> emits standard "properties" key.
+        self.assertIn("properties", parsed)
+
+    def test_annotated_union_emits_one_of(self) -> None:
+        out = _pydantic_schema_str(PaperInput)
+        parsed = json.loads(out)
+        self.assertIn("oneOf", parsed)
+        # PaperInput has 5 variants; not all need schema-rendering, but
+        # the rendered list should be non-empty.
+        self.assertGreater(len(parsed["oneOf"]), 0)
+
+    def test_baseinput_subclass_directly(self) -> None:
+        out = _pydantic_schema_str(ArxivIdentifier)
+        parsed = json.loads(out)
+        self.assertEqual(parsed["properties"]["type"]["default"], "arxiv")
+
+    def test_fallback_for_unknown_type(self) -> None:
+        out = _pydantic_schema_str(int)
+        # Falls back to repr.
+        self.assertEqual(out, repr(int))
+
+
+class ResolveMagicInputTests(unittest.IsolatedAsyncioTestCase):
+    async def test_happy_path_returns_tuple(self) -> None:
+        captured: dict[str, object] = {}
+
+        def _capture_agent(*_a: object, **kwargs: object) -> object:
+            captured.update(kwargs)
+            return MagicMock(name="agent")
+
+        # Build a fake resolver result whose final_output is a populated
+        # ResolvedFlowConfig.
+        resolved = ResolvedFlowConfig[PaperInput, PaperFlowCfg](
+            input_obj=ArxivIdentifier(id="2604.12345"),
+            cfg_obj=PaperFlowCfg(model="gpt-test"),
+        )
+        fake_result = MagicMock()
+        fake_result.final_output = resolved
+        with (
+            patch("quantmind.magic.Agent", side_effect=_capture_agent),
+            patch(
+                "quantmind.magic.Runner.run",
+                new=AsyncMock(return_value=fake_result),
+            ),
+        ):
+            inp, cfg = await resolve_magic_input(
+                "fetch arxiv 2604.12345 about momentum",
+                target_flow=paper_flow,
+            )
+        self.assertIs(inp, resolved.input_obj)
+        self.assertIs(cfg, resolved.cfg_obj)
+        # Resolver agent was given a name derived from the flow.
+        self.assertEqual(captured["name"], "magic_resolver_paper_flow")
+        self.assertEqual(captured["model"], "gpt-4o-mini")
+
+    async def test_custom_resolver_instructions(self) -> None:
+        captured: dict[str, object] = {}
+
+        def _capture_agent(*_a: object, **kwargs: object) -> object:
+            captured.update(kwargs)
+            return MagicMock()
+
+        resolved = ResolvedFlowConfig[PaperInput, PaperFlowCfg](
+            input_obj=ArxivIdentifier(id="x"),
+            cfg_obj=PaperFlowCfg(),
+        )
+        fake_result = MagicMock()
+        fake_result.final_output = resolved
+        template = "FLOW={flow_name} INPUT={input_schema} CFG={cfg_schema}"
+        with (
+            patch("quantmind.magic.Agent", side_effect=_capture_agent),
+            patch(
+                "quantmind.magic.Runner.run",
+                new=AsyncMock(return_value=fake_result),
+            ),
+        ):
+            await resolve_magic_input(
+                "x",
+                target_flow=paper_flow,
+                resolver_instructions=template,
+            )
+        instructions = captured["instructions"]
+        assert isinstance(instructions, str)
+        self.assertTrue(instructions.startswith("FLOW=paper_flow"))
+        self.assertIn("INPUT=", instructions)
+        self.assertIn("CFG=", instructions)
+
+    async def test_custom_resolver_model_used(self) -> None:
+        captured: dict[str, object] = {}
+
+        def _capture_agent(*_a: object, **kwargs: object) -> object:
+            captured.update(kwargs)
+            return MagicMock()
+
+        resolved = ResolvedFlowConfig[PaperInput, PaperFlowCfg](
+            input_obj=ArxivIdentifier(id="x"),
+            cfg_obj=PaperFlowCfg(),
+        )
+        fake_result = MagicMock()
+        fake_result.final_output = resolved
+        with (
+            patch("quantmind.magic.Agent", side_effect=_capture_agent),
+            patch(
+                "quantmind.magic.Runner.run",
+                new=AsyncMock(return_value=fake_result),
+            ),
+        ):
+            await resolve_magic_input(
+                "x",
+                target_flow=paper_flow,
+                resolver_model="claude-3-5-sonnet",
+            )
+        self.assertEqual(captured["model"], "claude-3-5-sonnet")
+
+
+class PreviewResolveTests(unittest.IsolatedAsyncioTestCase):
+    async def test_prints_and_returns_tuple(self) -> None:
+        resolved = ResolvedFlowConfig[PaperInput, PaperFlowCfg](
+            input_obj=ArxivIdentifier(id="2604.12345"),
+            cfg_obj=PaperFlowCfg(),
+        )
+        fake_result = MagicMock()
+        fake_result.final_output = resolved
+        with (
+            patch("quantmind.magic.Agent", return_value=MagicMock()),
+            patch(
+                "quantmind.magic.Runner.run",
+                new=AsyncMock(return_value=fake_result),
+            ),
+        ):
+            buf = io.StringIO()
+            with redirect_stdout(buf):
+                inp, cfg = await preview_resolve("x", target_flow=paper_flow)
+        self.assertIs(inp, resolved.input_obj)
+        self.assertIs(cfg, resolved.cfg_obj)
+        out = buf.getvalue()
+        self.assertIn("input_obj:", out)
+        self.assertIn("cfg_obj:", out)
+        self.assertIn("2604.12345", out)