diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9da40d6..8e5a945 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,3 +15,12 @@ repos: - id: ruff args: [--fix] - id: ruff-format + + - repo: local + hooks: + - id: mypy-protocols + name: mypy (protocols + monitor) + language: system + entry: uv run mypy + args: [src/paperscout/protocols.py, src/paperscout/monitor.py, src/paperscout/__main__.py] + pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index 1db4818..82c229b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ packages = ["src/paperscout"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +norecursedirs = [".git", ".tox", "dist", "build", "*.egg", "typing"] markers = [ "benchmark: ISO probe cycle performance regression (run: pytest benchmarks/ -m benchmark)", ] @@ -93,3 +94,11 @@ module = [ "fitz.*", ] ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = [ + "paperscout.protocols", + "paperscout.monitor", + "paperscout.__main__", +] +strict = true diff --git a/src/paperscout/__main__.py b/src/paperscout/__main__.py index dca6078..c0a42a6 100644 --- a/src/paperscout/__main__.py +++ b/src/paperscout/__main__.py @@ -8,13 +8,16 @@ import signal import sys import threading +from collections.abc import Callable from datetime import datetime, timezone from pathlib import Path +from typing import Any, cast from .config import settings from .db import init_db, init_pool from .health import start_health_server -from .monitor import Scheduler +from .monitor import PollResult, Scheduler +from .protocols import DataSource, OpsAlertFn from .scout import ( MessageQueue, create_app, @@ -40,7 +43,7 @@ ) -def _mq_health_fields(mq: MessageQueue) -> dict: +def _mq_health_fields(mq: MessageQueue) -> dict[str, Any]: """MQ metrics for /health; from health_fields() when present, else depth only.""" if hasattr(mq, "health_fields"): try: @@ -72,13 +75,13 @@ def _mq_health_fields(mq: MessageQueue) -> dict: def _merge_extra_health_fields( - scheduler_snap: dict, - mq_extra: dict, - db_pool: dict, -) -> dict: + scheduler_snap: dict[str, Any], + mq_extra: dict[str, Any], + db_pool: dict[str, Any], +) -> dict[str, Any]: """Merge health JSON with scheduler winning on key conflicts.""" scheduler_keys = set(scheduler_snap) - mq_filtered: dict = {} + mq_filtered: dict[str, Any] = {} for key, value in mq_extra.items(): if key in _MQ_HEALTH_FIELD_NAMES: if key in scheduler_keys: @@ -137,17 +140,27 @@ def _register_shutdown_signals( ) -> None: """Register SIGTERM/SIGINT handlers that set *shutdown_event*.""" - def _on_signal(signame: str) -> None: - """Record the first shutdown signal and wake the scheduler.""" - if shutdown_reason[0] is None: - shutdown_reason[0] = signame - shutdown_event.set() + def _bind_shutdown_handler(signame: str) -> Callable[[], None]: + def handler() -> None: + if shutdown_reason[0] is None: + shutdown_reason[0] = signame + shutdown_event.set() + + return handler + + def _bind_sys_signal_handler(signame: str) -> Callable[..., None]: + def handler(*_a: object) -> None: + if shutdown_reason[0] is None: + shutdown_reason[0] = signame + shutdown_event.set() + + return handler for sig, name in ((signal.SIGTERM, "SIGTERM"), (signal.SIGINT, "SIGINT")): try: - loop.add_signal_handler(sig, lambda n=name: _on_signal(n)) + loop.add_signal_handler(sig, _bind_shutdown_handler(name)) except NotImplementedError: - signal.signal(sig, lambda *_a, n=name: _on_signal(n)) + signal.signal(sig, _bind_sys_signal_handler(name)) async def _async_main() -> None: @@ -222,7 +235,7 @@ async def _async_main() -> None: user_watchlist = UserWatchlist(pool) index = WG21Index(pool, cfg=settings) prober = ISOProber(index, state, user_watchlist) - sources: list = [index, prober] + sources: list[DataSource] = [index, prober] if settings.enable_open_std: sources.append(OpenStdSource()) app = create_app() @@ -232,7 +245,7 @@ async def _async_main() -> None: def paper_count_fn() -> int: return len(index.papers) - def _on_poll_result(result): + def _on_poll_result(result: PollResult) -> None: notify_channel(app, result, mq) notify_users(app, result, mq) @@ -243,9 +256,9 @@ def _ops_alert(msg: str) -> None: f":rotating_light: PaperScout alert: {msg}", ) - def _pool_status(p) -> dict: + def _pool_status(p: Any) -> dict[str, Any]: """Best-effort pool stats (psycopg2 ThreadedConnectionPool uses private attrs).""" - status: dict = {"max": getattr(p, "maxconn", None)} + status: dict[str, Any] = {"max": getattr(p, "maxconn", None)} try: status["in_use"] = len(p._used) status["available"] = len(p._pool) @@ -260,10 +273,10 @@ def _pool_status(p) -> dict: state=state, cfg=settings, notify_callback=_on_poll_result, - ops_alert_fn=_ops_alert, + ops_alert_fn=cast(OpsAlertFn, _ops_alert), ) - def _extra_health_fields() -> dict: + def _extra_health_fields() -> dict[str, Any]: return _merge_extra_health_fields( scheduler.health_snapshot(), _mq_health_fields(mq), diff --git a/src/paperscout/monitor.py b/src/paperscout/monitor.py index 4c6ccec..643bd0b 100644 --- a/src/paperscout/monitor.py +++ b/src/paperscout/monitor.py @@ -8,7 +8,7 @@ import logging import threading import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime, timezone from types import MappingProxyType @@ -20,7 +20,14 @@ from .config import Settings, settings from .errors import ConfigurationError, FailureCategory from .models import CycleResult, CycleStatus, Paper, PerUserMatches, ProbeHit -from .protocols import SOURCE_ISO_PROBE, SOURCE_OPEN_STD, SOURCE_WG21_INDEX, DataSource +from .protocols import ( + SOURCE_ISO_PROBE, + SOURCE_OPEN_STD, + SOURCE_WG21_INDEX, + DataSource, + NotifyCallback, + OpsAlertFn, +) from .sources import ISOProber, OpenStdEntry, WG21Index from .storage import ProbeState, UserWatchlist @@ -163,15 +170,15 @@ def __init__( user_watchlist: UserWatchlist, state: ProbeState, cfg: Settings | None = None, - notify_callback=None, - ops_alert_fn: Callable[[str], None] | None = None, + notify_callback: NotifyCallback | None = None, + ops_alert_fn: OpsAlertFn | None = None, ): self.sources = list(sources) self.user_watchlist = user_watchlist self.state = state self.cfg = cfg or settings - self.notify_callback = notify_callback - self.ops_alert_fn = ops_alert_fn + self.notify_callback: NotifyCallback | None = notify_callback + self.ops_alert_fn: OpsAlertFn | None = ops_alert_fn self._snapshots: dict[str, Any] = {} self._seeded = False self._poll_count = 0 diff --git a/src/paperscout/protocols.py b/src/paperscout/protocols.py index 601b077..5c4a30c 100644 --- a/src/paperscout/protocols.py +++ b/src/paperscout/protocols.py @@ -1,4 +1,4 @@ -"""Structural typing contracts for pluggable data sources. +"""Structural typing contracts for pluggable data sources and callbacks. Known ``source_id`` values: @@ -9,7 +9,10 @@ from __future__ import annotations -from typing import Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from .monitor import PollResult # Well-known source identifiers (stable across releases). SOURCE_WG21_INDEX = "wg21_index" @@ -31,3 +34,17 @@ async def fetch(self) -> Any: def diff(self, previous: Any, current: Any) -> Any: """Compare *previous* and *current* snapshots; return source-specific diff.""" ... + + +@runtime_checkable +class NotifyCallback(Protocol): + """Deliver a completed poll result (channel + per-user notifications).""" + + def __call__(self, result: PollResult) -> None: ... + + +@runtime_checkable +class OpsAlertFn(Protocol): + """Surface operational errors to operators. Caller catches exceptions.""" + + def __call__(self, message: str) -> None: ... diff --git a/tests/test_callback_protocols.py b/tests/test_callback_protocols.py new file mode 100644 index 0000000..bca724e --- /dev/null +++ b/tests/test_callback_protocols.py @@ -0,0 +1,85 @@ +"""Tests for NotifyCallback and OpsAlertFn protocol conformance.""" + +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +from paperscout.models import CycleResult, CycleStatus +from paperscout.monitor import DiffResult, PollResult, Scheduler +from paperscout.protocols import SOURCE_ISO_PROBE, SOURCE_WG21_INDEX, NotifyCallback, OpsAlertFn +from paperscout.sources import ISOProber, WG21Index +from paperscout.storage import ProbeState, UserWatchlist +from tests.conftest import make_test_settings + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_INVALID_CALLBACKS = _REPO_ROOT / "tests" / "typing" / "invalid_callbacks.py" + + +def test_notify_callback_protocol_isinstance() -> None: + def ok(result: PollResult) -> None: + del result + + assert isinstance(ok, NotifyCallback) + + +def test_ops_alert_fn_protocol_isinstance() -> None: + def ok(message: str) -> None: + del message + + assert isinstance(ok, OpsAlertFn) + + +async def test_scheduler_accepts_protocol_callbacks(fake_pool) -> None: + notified: list[PollResult] = [] + alerts: list[str] = [] + + def on_notify(result: PollResult) -> None: + notified.append(result) + + def on_ops_alert(message: str) -> None: + alerts.append(message) + + wg21 = MagicMock(spec=WG21Index) + wg21.source_id = SOURCE_WG21_INDEX + wg21.fetch = AsyncMock(return_value={}) + wg21.diff = MagicMock(return_value=DiffResult([], [])) + + iso = MagicMock(spec=ISOProber) + iso.source_id = SOURCE_ISO_PROBE + iso.fetch = AsyncMock(return_value=CycleResult(CycleStatus.EMPTY)) + iso.diff = MagicMock(return_value=[]) + iso.snapshot_stats = MagicMock(return_value={}) + + scheduler = Scheduler( + sources=[wg21, iso], + user_watchlist=MagicMock(spec=UserWatchlist), + state=ProbeState(fake_pool), + cfg=make_test_settings(), + notify_callback=on_notify, + ops_alert_fn=on_ops_alert, + ) + assert isinstance(scheduler.notify_callback, NotifyCallback) + assert isinstance(scheduler.ops_alert_fn, OpsAlertFn) + + wg21.papers = {} + await scheduler.poll_once() + assert notified == [] + + +def test_mypy_rejects_invalid_callbacks() -> None: + env = {**os.environ, "MYPYPATH": str(_REPO_ROOT / "src")} + result = subprocess.run( + [sys.executable, "-m", "mypy", str(_INVALID_CALLBACKS)], + cwd=_REPO_ROOT, + capture_output=True, + text=True, + check=False, + env=env, + ) + assert result.returncode != 0 + combined = result.stdout + result.stderr + assert "Incompatible types" in combined diff --git a/tests/typing/invalid_callbacks.py b/tests/typing/invalid_callbacks.py new file mode 100644 index 0000000..74326eb --- /dev/null +++ b/tests/typing/invalid_callbacks.py @@ -0,0 +1,21 @@ +"""Deliberate mypy failures for callback protocol conformance. + +Not collected by pytest (see pyproject.toml norecursedirs). Verified by +``test_mypy_rejects_invalid_callbacks`` in tests/test_callback_protocols.py. +""" + +from __future__ import annotations + +from paperscout.protocols import NotifyCallback, OpsAlertFn + + +def bad_notify(x: str) -> None: + del x + + +def bad_ops_alert(n: int) -> None: + del n + + +_reject_notify: NotifyCallback = bad_notify +_reject_ops: OpsAlertFn = bad_ops_alert