Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ repos:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: local
hooks:
- id: mypy-protocols
name: mypy (protocols + monitor)
language: system
entry: uv run mypy
args: [src/paperscout/protocols.py, src/paperscout/monitor.py, src/paperscout/__main__.py]
pass_filenames: false
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ packages = ["src/paperscout"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
norecursedirs = [".git", ".tox", "dist", "build", "*.egg", "typing"]
markers = [
"benchmark: ISO probe cycle performance regression (run: pytest benchmarks/ -m benchmark)",
]
Expand Down Expand Up @@ -93,3 +94,11 @@ module = [
"fitz.*",
]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"paperscout.protocols",
"paperscout.monitor",
"paperscout.__main__",
]
strict = true
53 changes: 33 additions & 20 deletions src/paperscout/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
import signal
import sys
import threading
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, cast

from .config import settings
from .db import init_db, init_pool
from .health import start_health_server
from .monitor import Scheduler
from .monitor import PollResult, Scheduler
from .protocols import DataSource, OpsAlertFn
from .scout import (
MessageQueue,
create_app,
Expand All @@ -40,7 +43,7 @@
)


def _mq_health_fields(mq: MessageQueue) -> dict:
def _mq_health_fields(mq: MessageQueue) -> dict[str, Any]:
"""MQ metrics for /health; from health_fields() when present, else depth only."""
if hasattr(mq, "health_fields"):
try:
Expand Down Expand Up @@ -72,13 +75,13 @@ def _mq_health_fields(mq: MessageQueue) -> dict:


def _merge_extra_health_fields(
scheduler_snap: dict,
mq_extra: dict,
db_pool: dict,
) -> dict:
scheduler_snap: dict[str, Any],
mq_extra: dict[str, Any],
db_pool: dict[str, Any],
) -> dict[str, Any]:
"""Merge health JSON with scheduler winning on key conflicts."""
scheduler_keys = set(scheduler_snap)
mq_filtered: dict = {}
mq_filtered: dict[str, Any] = {}
for key, value in mq_extra.items():
if key in _MQ_HEALTH_FIELD_NAMES:
if key in scheduler_keys:
Expand Down Expand Up @@ -137,17 +140,27 @@ def _register_shutdown_signals(
) -> None:
"""Register SIGTERM/SIGINT handlers that set *shutdown_event*."""

def _on_signal(signame: str) -> None:
"""Record the first shutdown signal and wake the scheduler."""
if shutdown_reason[0] is None:
shutdown_reason[0] = signame
shutdown_event.set()
def _bind_shutdown_handler(signame: str) -> Callable[[], None]:
def handler() -> None:
if shutdown_reason[0] is None:
shutdown_reason[0] = signame
shutdown_event.set()

return handler

def _bind_sys_signal_handler(signame: str) -> Callable[..., None]:
def handler(*_a: object) -> None:
if shutdown_reason[0] is None:
shutdown_reason[0] = signame
shutdown_event.set()

return handler

for sig, name in ((signal.SIGTERM, "SIGTERM"), (signal.SIGINT, "SIGINT")):
try:
loop.add_signal_handler(sig, lambda n=name: _on_signal(n))
loop.add_signal_handler(sig, _bind_shutdown_handler(name))
except NotImplementedError:
signal.signal(sig, lambda *_a, n=name: _on_signal(n))
signal.signal(sig, _bind_sys_signal_handler(name))


async def _async_main() -> None:
Expand Down Expand Up @@ -222,7 +235,7 @@ async def _async_main() -> None:
user_watchlist = UserWatchlist(pool)
index = WG21Index(pool, cfg=settings)
prober = ISOProber(index, state, user_watchlist)
sources: list = [index, prober]
sources: list[DataSource] = [index, prober]
if settings.enable_open_std:
sources.append(OpenStdSource())
app = create_app()
Expand All @@ -232,7 +245,7 @@ async def _async_main() -> None:
def paper_count_fn() -> int:
return len(index.papers)

def _on_poll_result(result):
def _on_poll_result(result: PollResult) -> None:
notify_channel(app, result, mq)
notify_users(app, result, mq)

Expand All @@ -243,9 +256,9 @@ def _ops_alert(msg: str) -> None:
f":rotating_light: PaperScout alert: {msg}",
)

def _pool_status(p) -> dict:
def _pool_status(p: Any) -> dict[str, Any]:
"""Best-effort pool stats (psycopg2 ThreadedConnectionPool uses private attrs)."""
status: dict = {"max": getattr(p, "maxconn", None)}
status: dict[str, Any] = {"max": getattr(p, "maxconn", None)}
try:
status["in_use"] = len(p._used)
status["available"] = len(p._pool)
Expand All @@ -260,10 +273,10 @@ def _pool_status(p) -> dict:
state=state,
cfg=settings,
notify_callback=_on_poll_result,
ops_alert_fn=_ops_alert,
ops_alert_fn=cast(OpsAlertFn, _ops_alert),
)

def _extra_health_fields() -> dict:
def _extra_health_fields() -> dict[str, Any]:
return _merge_extra_health_fields(
scheduler.health_snapshot(),
_mq_health_fields(mq),
Expand Down
19 changes: 13 additions & 6 deletions src/paperscout/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import threading
import time
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from types import MappingProxyType
Expand All @@ -20,7 +20,14 @@
from .config import Settings, settings
from .errors import ConfigurationError, FailureCategory
from .models import CycleResult, CycleStatus, Paper, PerUserMatches, ProbeHit
from .protocols import SOURCE_ISO_PROBE, SOURCE_OPEN_STD, SOURCE_WG21_INDEX, DataSource
from .protocols import (
SOURCE_ISO_PROBE,
SOURCE_OPEN_STD,
SOURCE_WG21_INDEX,
DataSource,
NotifyCallback,
OpsAlertFn,
)
from .sources import ISOProber, OpenStdEntry, WG21Index
from .storage import ProbeState, UserWatchlist

Expand Down Expand Up @@ -163,15 +170,15 @@ def __init__(
user_watchlist: UserWatchlist,
state: ProbeState,
cfg: Settings | None = None,
notify_callback=None,
ops_alert_fn: Callable[[str], None] | None = None,
notify_callback: NotifyCallback | None = None,
ops_alert_fn: OpsAlertFn | None = None,
):
self.sources = list(sources)
self.user_watchlist = user_watchlist
self.state = state
self.cfg = cfg or settings
self.notify_callback = notify_callback
self.ops_alert_fn = ops_alert_fn
self.notify_callback: NotifyCallback | None = notify_callback
self.ops_alert_fn: OpsAlertFn | None = ops_alert_fn
self._snapshots: dict[str, Any] = {}
self._seeded = False
self._poll_count = 0
Expand Down
21 changes: 19 additions & 2 deletions src/paperscout/protocols.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Structural typing contracts for pluggable data sources.
"""Structural typing contracts for pluggable data sources and callbacks.

Known ``source_id`` values:

Expand All @@ -9,7 +9,10 @@

from __future__ import annotations

from typing import Any, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

if TYPE_CHECKING:
from .monitor import PollResult

# Well-known source identifiers (stable across releases).
SOURCE_WG21_INDEX = "wg21_index"
Expand All @@ -31,3 +34,17 @@ async def fetch(self) -> Any:
def diff(self, previous: Any, current: Any) -> Any:
"""Compare *previous* and *current* snapshots; return source-specific diff."""
...


@runtime_checkable
class NotifyCallback(Protocol):
"""Deliver a completed poll result (channel + per-user notifications)."""

def __call__(self, result: PollResult) -> None: ...


@runtime_checkable
class OpsAlertFn(Protocol):
"""Surface operational errors to operators. Caller catches exceptions."""

def __call__(self, message: str) -> None: ...
85 changes: 85 additions & 0 deletions tests/test_callback_protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Tests for NotifyCallback and OpsAlertFn protocol conformance."""

from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

from paperscout.models import CycleResult, CycleStatus
from paperscout.monitor import DiffResult, PollResult, Scheduler
from paperscout.protocols import SOURCE_ISO_PROBE, SOURCE_WG21_INDEX, NotifyCallback, OpsAlertFn
from paperscout.sources import ISOProber, WG21Index
from paperscout.storage import ProbeState, UserWatchlist
from tests.conftest import make_test_settings

_REPO_ROOT = Path(__file__).resolve().parents[1]
_INVALID_CALLBACKS = _REPO_ROOT / "tests" / "typing" / "invalid_callbacks.py"


def test_notify_callback_protocol_isinstance() -> None:
def ok(result: PollResult) -> None:
del result

assert isinstance(ok, NotifyCallback)


def test_ops_alert_fn_protocol_isinstance() -> None:
def ok(message: str) -> None:
del message

assert isinstance(ok, OpsAlertFn)


async def test_scheduler_accepts_protocol_callbacks(fake_pool) -> None:
notified: list[PollResult] = []
alerts: list[str] = []

def on_notify(result: PollResult) -> None:
notified.append(result)

def on_ops_alert(message: str) -> None:
alerts.append(message)

wg21 = MagicMock(spec=WG21Index)
wg21.source_id = SOURCE_WG21_INDEX
wg21.fetch = AsyncMock(return_value={})
wg21.diff = MagicMock(return_value=DiffResult([], []))

iso = MagicMock(spec=ISOProber)
iso.source_id = SOURCE_ISO_PROBE
iso.fetch = AsyncMock(return_value=CycleResult(CycleStatus.EMPTY))
iso.diff = MagicMock(return_value=[])
iso.snapshot_stats = MagicMock(return_value={})

scheduler = Scheduler(
sources=[wg21, iso],
user_watchlist=MagicMock(spec=UserWatchlist),
state=ProbeState(fake_pool),
cfg=make_test_settings(),
notify_callback=on_notify,
ops_alert_fn=on_ops_alert,
)
assert isinstance(scheduler.notify_callback, NotifyCallback)
assert isinstance(scheduler.ops_alert_fn, OpsAlertFn)

wg21.papers = {}
await scheduler.poll_once()
assert notified == []


def test_mypy_rejects_invalid_callbacks() -> None:
env = {**os.environ, "MYPYPATH": str(_REPO_ROOT / "src")}
result = subprocess.run(
[sys.executable, "-m", "mypy", str(_INVALID_CALLBACKS)],
cwd=_REPO_ROOT,
capture_output=True,
text=True,
check=False,
env=env,
)
assert result.returncode != 0
combined = result.stdout + result.stderr
assert "Incompatible types" in combined
21 changes: 21 additions & 0 deletions tests/typing/invalid_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Deliberate mypy failures for callback protocol conformance.

Not collected by pytest (see pyproject.toml norecursedirs). Verified by
``test_mypy_rejects_invalid_callbacks`` in tests/test_callback_protocols.py.
"""

from __future__ import annotations

from paperscout.protocols import NotifyCallback, OpsAlertFn


def bad_notify(x: str) -> None:
del x


def bad_ops_alert(n: int) -> None:
del n


_reject_notify: NotifyCallback = bad_notify
_reject_ops: OpsAlertFn = bad_ops_alert