diff --git a/changelog.d/3440.changed.md b/changelog.d/3440.changed.md index e70faa71b..16905ae65 100644 --- a/changelog.d/3440.changed.md +++ b/changelog.d/3440.changed.md @@ -1 +1,3 @@ Dual-write live simulation and report create/update traffic into the new run tables, keep parent run pointers in sync, and harden report mutations to remain country-scoped and transactionally consistent. + +Preserve explicit report definitions and execution metadata across later syncs, key new report creation and alias validation by canonical report identity, and resolve report reads through canonical parents plus display-run selection instead of recreating current-version parent rows. diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 085f31c0b..4aacc9112 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -135,6 +135,8 @@ CREATE TABLE IF NOT EXISTS report_outputs ( report_spec_json JSON DEFAULT NULL, report_spec_schema_version INT DEFAULT NULL, report_spec_status VARCHAR(32) DEFAULT NULL, + report_identity_hash VARCHAR(64) DEFAULT NULL, + report_identity_schema_version INT DEFAULT NULL, active_run_id CHAR(36) DEFAULT NULL, latest_successful_run_id CHAR(36) DEFAULT NULL ); diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index 53a37b4c8..b8530be65 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -147,6 +147,8 @@ CREATE TABLE IF NOT EXISTS report_outputs ( report_spec_json JSON DEFAULT NULL, report_spec_schema_version INT DEFAULT NULL, report_spec_status VARCHAR(32) DEFAULT NULL, + report_identity_hash VARCHAR(64) DEFAULT NULL, + report_identity_schema_version INT DEFAULT NULL, active_run_id CHAR(36) DEFAULT NULL, latest_successful_run_id CHAR(36) DEFAULT NULL ); diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index b3be7672c..c64bbde7e 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -8,6 +8,25 @@ report_output_bp = Blueprint("report_output", __name__) report_output_service = ReportOutputService() +RUN_METADATA_FIELDS = ( + "country_package_version", + "policyengine_version", + "data_version", + "runtime_app_name", + "resolved_dataset", +) + + +def _parse_report_run_metadata(payload: dict) -> dict[str, str | None]: + metadata: dict[str, str | None] = {} + for field_name in RUN_METADATA_FIELDS: + if field_name not in payload: + continue + value = payload.get(field_name) + if value is not None and not isinstance(value, str): + raise BadRequest(f"{field_name} must be a string or null") + metadata[field_name] = value + return metadata @report_output_bp.route("//report", methods=["POST"]) @@ -33,6 +52,8 @@ def create_report_output(country_id: str) -> Response: simulation_1_id = payload.get("simulation_1_id") simulation_2_id = payload.get("simulation_2_id") # Optional year = payload.get("year", CURRENT_YEAR) # Default to current year as string + report_spec_payload = payload.get("report_spec") + report_spec_schema_version = payload.get("report_spec_schema_version") # Validate required fields if simulation_1_id is None: @@ -43,14 +64,35 @@ def create_report_output(country_id: str) -> Response: raise BadRequest("simulation_2_id must be an integer or null") if not isinstance(year, str): raise BadRequest("year must be a string") + if report_spec_payload is not None and not isinstance(report_spec_payload, dict): + raise BadRequest("report_spec must be an object") + if report_spec_schema_version is not None and not isinstance( + report_spec_schema_version, int + ): + raise BadRequest("report_spec_schema_version must be an integer") + + report_spec = None + if report_spec_payload is not None: + try: + report_spec = report_output_service.parse_report_spec_payload( + report_spec_payload, + ( + report_spec_schema_version + if report_spec_schema_version is not None + else 1 + ), + ) + except ValueError as exc: + raise BadRequest(str(exc)) from exc try: # Check if report already exists with these simulation IDs and year - existing_report = report_output_service.find_existing_report_output( + existing_report = report_output_service.find_existing_report_output_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, ) if existing_report: @@ -58,6 +100,8 @@ def create_report_output(country_id: str) -> Response: report_output_service.ensure_report_output_dual_write_state( existing_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) ) # Report already exists, return it with 200 status @@ -79,6 +123,8 @@ def create_report_output(country_id: str) -> Response: simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) response_body = dict( @@ -156,6 +202,7 @@ def update_report_output(country_id: str) -> Response: report_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + version_manifest_overrides = _parse_report_run_metadata(payload) print(f"Updating report #{report_id} for country {country_id}") # Validate status if provided @@ -181,6 +228,7 @@ def update_report_output(country_id: str) -> Response: status=status, output=output, error_message=error_message, + version_manifest_overrides=version_manifest_overrides, ) if not success: diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 5a16b807e..d38d39e98 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -7,6 +7,24 @@ simulation_bp = Blueprint("simulation", __name__) simulation_service = SimulationService() +RUN_METADATA_FIELDS = ( + "country_package_version", + "policyengine_version", + "data_version", + "runtime_app_name", +) + + +def _parse_simulation_run_metadata(payload: dict) -> dict[str, str | None]: + metadata: dict[str, str | None] = {} + for field_name in RUN_METADATA_FIELDS: + if field_name not in payload: + continue + value = payload.get(field_name) + if value is not None and not isinstance(value, str): + raise BadRequest(f"{field_name} must be a string or null") + metadata[field_name] = value + return metadata @simulation_bp.route("//simulation", methods=["POST"]) @@ -161,6 +179,7 @@ def update_simulation(country_id: str) -> Response: simulation_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + version_manifest_overrides = _parse_simulation_run_metadata(payload) print(f"Updating simulation #{simulation_id} for country {country_id}") # Validate status if provided @@ -186,6 +205,7 @@ def update_simulation(country_id: str) -> Response: status=status, output=output, error_message=error_message, + version_manifest_overrides=version_manifest_overrides, ) if not success: diff --git a/policyengine_api/services/report_output_alias_service.py b/policyengine_api/services/report_output_alias_service.py index 9440cfdfd..6a120ba6c 100644 --- a/policyengine_api/services/report_output_alias_service.py +++ b/policyengine_api/services/report_output_alias_service.py @@ -7,7 +7,7 @@ class ReportOutputAliasService: def _get_report_output_row(self, report_output_id: int) -> dict | None: row: Row | None = database.query( """ - SELECT id, country_id, simulation_1_id, simulation_2_id, year + SELECT id, country_id, report_identity_hash, report_identity_schema_version FROM report_outputs WHERE id = ? """, @@ -15,6 +15,45 @@ def _get_report_output_row(self, report_output_id: int) -> dict | None: ).fetchone() return dict(row) if row is not None else None + def _validate_alias_identity_compatibility( + self, + legacy_report_output: dict, + canonical_report_output: dict, + ) -> None: + if legacy_report_output["country_id"] != canonical_report_output["country_id"]: + raise ValueError( + "Legacy and canonical report outputs must describe the same report" + ) + + if ( + legacy_report_output["report_identity_hash"] is None + or legacy_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Legacy report output must have canonical report identity before " + "aliasing" + ) + + if ( + canonical_report_output["report_identity_hash"] is None + or canonical_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Canonical report output must have canonical report identity before " + "aliasing" + ) + + if ( + legacy_report_output["report_identity_hash"] + != canonical_report_output["report_identity_hash"] + or legacy_report_output["report_identity_schema_version"] + != canonical_report_output["report_identity_schema_version"] + ): + raise ValueError( + "Legacy and canonical report outputs must share canonical report " + "identity" + ) + def get_alias(self, legacy_report_output_id: int) -> dict | None: row: Row | None = database.query( """ @@ -78,14 +117,10 @@ def set_alias( f"#{existing_alias['canonical_report_output_id']}" ) - logical_key = ("country_id", "simulation_1_id", "simulation_2_id", "year") - if any( - legacy_report_output[field] != canonical_report_output[field] - for field in logical_key - ): - raise ValueError( - "Legacy and canonical report outputs must describe the same report" - ) + self._validate_alias_identity_compatibility( + legacy_report_output, + canonical_report_output, + ) database.query( """ INSERT INTO legacy_report_output_aliases diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index c462257c8..b73664ab1 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -4,9 +4,14 @@ from policyengine_api.constants import get_report_output_cache_version from policyengine_api.data import database +from policyengine_api.services.report_output_alias_service import ( + ReportOutputAliasService, +) +from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.report_spec_service import ( ECONOMY_REPORT_KINDS, ReportSpec, + REPORT_SPEC_SCHEMA_VERSION, ReportSpecService, ) from policyengine_api.services.run_sync_utils import ( @@ -21,6 +26,8 @@ class ReportOutputService: def __init__(self): self.report_spec_service = ReportSpecService() self.simulation_service = SimulationService() + self.report_output_alias_service = ReportOutputAliasService() + self.report_run_service = ReportRunService() def _lock_clause(self) -> str: return "" if database.local else " FOR UPDATE" @@ -159,18 +166,30 @@ def _derive_report_country_package_version( return versions[0] return None - def _build_version_manifest( + def _merge_version_manifest_overrides( + self, + version_manifest: dict[str, str | None], + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + merged_manifest = dict(version_manifest) + for key, value in (version_manifest_overrides or {}).items(): + if key in merged_manifest and value is not None: + merged_manifest[key] = value + return merged_manifest + + def _build_bootstrap_version_manifest( self, report_output: dict, report_spec: ReportSpec | None, simulation_1: dict | None = None, simulation_2: dict | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict[str, str | None]: resolved_dataset = None if report_spec is not None and report_spec.report_kind in ECONOMY_REPORT_KINDS: resolved_dataset = report_spec.dataset - return { + version_manifest = { "country_package_version": self._derive_report_country_package_version( simulation_1, simulation_2 ), @@ -183,22 +202,164 @@ def _build_version_manifest( "resolved_dataset": resolved_dataset, "resolved_options_hash": None, } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) + + def _build_existing_run_version_manifest( + self, + run: dict, + report_spec: ReportSpec | None, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + derived_resolved_dataset = ( + report_spec.dataset + if report_spec is not None + and report_spec.report_kind in ECONOMY_REPORT_KINDS + else None + ) + version_manifest = { + "country_package_version": run.get("country_package_version"), + "policyengine_version": run.get("policyengine_version"), + "data_version": run.get("data_version"), + "runtime_app_name": run.get("runtime_app_name"), + "report_cache_version": run.get("report_cache_version"), + "simulation_cache_version": run.get("simulation_cache_version"), + "requested_version_override": run.get("requested_version_override"), + "resolved_dataset": run.get("resolved_dataset") or derived_resolved_dataset, + "resolved_options_hash": run.get("resolved_options_hash"), + } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) def _get_report_spec_status(self, report_spec: ReportSpec) -> str: if report_spec.report_kind in ECONOMY_REPORT_KINDS: return "backfilled_assumed" return "explicit" - def _upsert_report_spec_in_transaction( + def _persist_explicit_report_spec_in_transaction( self, tx, report_output: dict, - simulation_1: dict | None, + simulation_1: dict, + simulation_2: dict | None, + explicit_report_spec: ReportSpec, + report_spec_schema_version: int | None = None, + ) -> ReportSpec: + schema_version = ( + report_spec_schema_version + if report_spec_schema_version is not None + else REPORT_SPEC_SCHEMA_VERSION + ) + self.report_spec_service._validate_schema_version(schema_version) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + explicit_report_spec, + simulation_1, + simulation_2, + ) + report_spec_status = "explicit" + existing_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + existing_spec != explicit_report_spec.model_dump() + or report_output.get("report_kind") != explicit_report_spec.report_kind + or report_output.get("report_spec_schema_version") != schema_version + or report_output.get("report_spec_status") != report_spec_status + ): + tx.query( + """ + UPDATE report_outputs + SET report_kind = ?, report_spec_json = ?, + report_spec_schema_version = ?, report_spec_status = ? + WHERE id = ? + """, + ( + explicit_report_spec.report_kind, + explicit_report_spec.model_dump_json(), + schema_version, + report_spec_status, + report_output["id"], + ), + ) + report_output["report_kind"] = explicit_report_spec.report_kind + report_output["report_spec_json"] = explicit_report_spec.model_dump() + report_output["report_spec_schema_version"] = schema_version + report_output["report_spec_status"] = report_spec_status + return explicit_report_spec + + def _sync_report_identity_in_transaction( + self, + tx, + report_output: dict, + report_spec: ReportSpec | None, + ) -> None: + if report_spec is None: + return + + report_identity_hash, report_identity_schema_version = ( + self.report_spec_service.get_report_identity(report_spec) + ) + if ( + report_output.get("report_identity_hash") == report_identity_hash + and report_output.get("report_identity_schema_version") + == report_identity_schema_version + ): + return + + tx.query( + """ + UPDATE report_outputs + SET report_identity_hash = ?, report_identity_schema_version = ? + WHERE id = ? + """, + ( + report_identity_hash, + report_identity_schema_version, + report_output["id"], + ), + ) + report_output["report_identity_hash"] = report_identity_hash + report_output["report_identity_schema_version"] = report_identity_schema_version + + def _load_existing_explicit_report_spec( + self, + report_output: dict, + simulation_1: dict, simulation_2: dict | None, ) -> ReportSpec | None: - if simulation_1 is None: + if report_output.get("report_spec_status") != "explicit": return None + raw_spec = parse_json_field(report_output.get("report_spec_json")) + if raw_spec is None: + raise ValueError("Stored explicit report spec is missing report_spec_json") + + report_spec = self.report_spec_service.parse_report_spec( + raw_spec, + schema_version=report_output.get("report_spec_schema_version"), + ) + if report_output.get("report_kind") != report_spec.report_kind: + raise ValueError( + "Stored explicit report kind must match stored report spec" + ) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + return report_spec + + def _derive_and_upsert_report_spec_in_transaction( + self, + tx, + report_output: dict, + simulation_1: dict, + simulation_2: dict | None, + ) -> ReportSpec | None: try: report_spec = self.report_spec_service.build_report_spec( report_output=report_output, @@ -242,6 +403,51 @@ def _upsert_report_spec_in_transaction( return report_spec + def _upsert_report_spec_in_transaction( + self, + tx, + report_output: dict, + simulation_1: dict | None, + simulation_2: dict | None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + ) -> ReportSpec | None: + if simulation_1 is None: + if explicit_report_spec is not None: + raise ValueError( + "Explicit report specs require linked simulations to be present" + ) + if report_output.get("report_spec_status") == "explicit": + raise ValueError( + "Stored explicit report specs require linked simulations to be present" + ) + return None + + if explicit_report_spec is not None: + return self._persist_explicit_report_spec_in_transaction( + tx, + report_output, + simulation_1, + simulation_2, + explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, + ) + + stored_explicit_report_spec = self._load_existing_explicit_report_spec( + report_output, + simulation_1, + simulation_2, + ) + if stored_explicit_report_spec is not None: + return stored_explicit_report_spec + + return self._derive_and_upsert_report_spec_in_transaction( + tx, + report_output, + simulation_1, + simulation_2, + ) + def _run_matches_parent( self, run: dict, @@ -391,6 +597,9 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output_id: int, *, country_id: str | None = None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: report_output = self._get_report_output_row( report_output_id, @@ -408,6 +617,8 @@ def _ensure_report_output_dual_write_state_in_transaction( bootstrap_dual_write_state=True, ) except ValueError as exc: + if explicit_report_spec is not None: + raise print( "Skipping linked simulation sync for report output " f"#{report_output_id}. Details: {str(exc)}" @@ -419,17 +630,21 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output, simulation_1, simulation_2, + explicit_report_spec=explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, ) - version_manifest = self._build_version_manifest( - report_output, - report_spec=report_spec, - simulation_1=simulation_1, - simulation_2=simulation_2, - ) + self._sync_report_identity_in_transaction(tx, report_output, report_spec) runs_descending = self._list_report_runs_descending( report_output_id, queryer=tx ) if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) self._insert_bootstrap_report_run( tx, report_output, @@ -441,6 +656,21 @@ def _ensure_report_output_dual_write_state_in_transaction( ) else: mutable_run = self._select_mutable_run(report_output, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + mutable_run, + report_spec=report_spec, + version_manifest_overrides=version_manifest_overrides, + ) + if mutable_run is not None + else self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + ) if mutable_run is not None and not self._run_matches_parent( mutable_run, report_output, @@ -472,15 +702,31 @@ def ensure_report_output_dual_write_state( self, report_output_id: int, country_id: str | None = None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: return database.transaction( lambda tx: self._ensure_report_output_dual_write_state_in_transaction( tx, report_output_id, country_id=country_id, + explicit_report_spec=explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, + version_manifest_overrides=version_manifest_overrides, ) ) + def parse_report_spec_payload( + self, + raw_report_spec: dict, + schema_version: int = REPORT_SPEC_SCHEMA_VERSION, + ) -> ReportSpec: + return self.report_spec_service.parse_report_spec( + raw_report_spec, + schema_version=schema_version, + ) + def get_stored_report_output( self, country_id: str, report_output_id: int ) -> dict | None: @@ -491,11 +737,6 @@ def get_stored_report_output( """ return self._get_report_output_row(report_output_id, country_id=country_id) - def _is_current_report_output(self, report_output: dict) -> bool: - return report_output.get("api_version") == get_report_output_cache_version( - report_output["country_id"] - ) - def _find_existing_report_output_row( self, *, @@ -522,28 +763,226 @@ def _find_existing_report_output_row( row = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None - def _get_or_create_current_report_output(self, report_output: dict) -> dict: - current_report = self.find_existing_report_output( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], + def _find_existing_report_output_row_by_identity( + self, + *, + country_id: str, + report_identity_hash: str, + report_identity_schema_version: int, + queryer=None, + ) -> dict | None: + queryer = queryer or database + row = queryer.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND report_identity_hash = ? + AND report_identity_schema_version = ? + ORDER BY id DESC + """, + ( + country_id, + report_identity_hash, + report_identity_schema_version, + ), + ).fetchone() + return dict(row) if row is not None else None + + def _list_report_output_rows_by_legacy_key( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + queryer=None, + ) -> list[dict]: + queryer = queryer or database + query = """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """ + params: list[int | str] = [country_id, simulation_1_id, year] + if simulation_2_id is not None: + query += " AND simulation_2_id = ?" + params.append(simulation_2_id) + else: + query += " AND simulation_2_id IS NULL" + query += " ORDER BY id DESC" + + rows = queryer.query(query, tuple(params)).fetchall() + return [dict(row) for row in rows] + + def _build_report_spec_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + queryer=None, + ) -> ReportSpec | None: + queryer = queryer or database + simulation_1 = self.simulation_service._get_simulation_row( + simulation_1_id, + queryer=queryer, + country_id=country_id, ) - if current_report is not None: - return current_report - - return self.create_report_output( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], + if simulation_1 is None: + return None + + simulation_2 = None + if simulation_2_id is not None: + simulation_2 = self.simulation_service._get_simulation_row( + simulation_2_id, + queryer=queryer, + country_id=country_id, + ) + if simulation_2 is None: + return None + + try: + return self.report_spec_service.build_report_spec( + report_output={ + "country_id": country_id, + "simulation_1_id": simulation_1_id, + "simulation_2_id": simulation_2_id, + "year": year, + }, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + except ValueError: + return None + + def _get_report_spec_for_identity_matching( + self, + report_output: dict, + *, + queryer=None, + ) -> ReportSpec | None: + queryer = queryer or database + try: + simulation_1, simulation_2 = self._get_linked_simulations( + report_output, + queryer=queryer, + ) + except ValueError: + return None + + raw_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + raw_spec is not None + and report_output.get("report_spec_schema_version") is not None + ): + try: + report_spec = self.report_spec_service.parse_report_spec( + raw_spec, + schema_version=report_output["report_spec_schema_version"], + ) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + return report_spec + except ValueError: + return None + + try: + return self.report_spec_service.build_report_spec( + report_output=report_output, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + except ValueError: + return None + + def _find_existing_report_output_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + report_spec: ReportSpec | None = None, + queryer=None, + ) -> dict | None: + queryer = queryer or database + identity_report_spec = report_spec or self._build_report_spec_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, ) + if identity_report_spec is None: + return self._find_existing_report_output_row( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + + report_identity_hash, report_identity_schema_version = ( + self.report_spec_service.get_report_identity(identity_report_spec) + ) + existing_report = self._find_existing_report_output_row_by_identity( + country_id=country_id, + report_identity_hash=report_identity_hash, + report_identity_schema_version=report_identity_schema_version, + queryer=queryer, + ) + if existing_report is not None: + return existing_report + + candidate_rows = self._list_report_output_rows_by_legacy_key( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + for candidate_row in candidate_rows: + candidate_report_spec = self._get_report_spec_for_identity_matching( + candidate_row, + queryer=queryer, + ) + if candidate_report_spec is None: + continue + candidate_identity_hash, candidate_identity_schema_version = ( + self.report_spec_service.get_report_identity(candidate_report_spec) + ) + if ( + candidate_identity_hash == report_identity_hash + and candidate_identity_schema_version == report_identity_schema_version + ): + return candidate_row + + return None def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict: aliased_report = dict(report_output) aliased_report["id"] = report_output_id return aliased_report + def _merge_display_run_into_report_output( + self, + report_output: dict, + display_run: dict | None, + ) -> dict: + if display_run is None: + return dict(report_output) + + result = dict(report_output) + result["status"] = display_run["status"] + result["output"] = display_run.get("output") + result["error_message"] = display_run.get("error_message") + if display_run.get("report_cache_version") is not None: + result["api_version"] = display_run["report_cache_version"] + return result + def find_existing_report_output( self, country_id: str, @@ -571,12 +1010,43 @@ def find_existing_report_output( print(f"Error checking for existing report output. Details: {str(e)}") raise e + def find_existing_report_output_for_create( + self, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None = None, + year: str = "2025", + report_spec: ReportSpec | None = None, + ) -> dict | None: + try: + existing_report = self._find_existing_report_output_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + report_spec=report_spec, + ) + if existing_report is not None: + print( + "Found existing report output for create with ID: " + f"{existing_report['id']}" + ) + return existing_report + except Exception as e: + print( + "Error checking for existing report output by canonical identity. " + f"Details: {str(e)}" + ) + raise e + def create_report_output( self, country_id: str, simulation_1_id: int, simulation_2_id: int | None = None, year: str = "2025", + report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, ) -> dict: """ Create a new report output record with pending status. @@ -587,11 +1057,12 @@ def create_report_output( try: def tx_callback(tx): - existing_report = self._find_existing_report_output_row( + existing_report = self._find_existing_report_output_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, queryer=tx, ) if existing_report is not None: @@ -602,6 +1073,8 @@ def tx_callback(tx): tx, existing_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) self._require_simulation_exists( @@ -663,6 +1136,8 @@ def tx_callback(tx): tx, created_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) return database.transaction(tx_callback) @@ -683,18 +1158,34 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No f"Invalid report output ID: {report_output_id}. Must be a positive integer." ) - report_output = self._get_report_output_row( - report_output_id, - country_id=country_id, + canonical_report_output_id = ( + self.report_output_alias_service.resolve_canonical_report_output_id( + report_output_id + ) ) - if report_output is None: + if canonical_report_output_id is None: return None - if self._is_current_report_output(report_output): - return report_output + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, + country_id=country_id, + ) + if canonical_report_output is None: + return None - current_report = self._get_or_create_current_report_output(report_output) - return self._alias_report_output(report_output_id, current_report) + display_run = self.report_run_service.select_display_run( + canonical_report_output + ) + resolved_report_output = self._merge_display_run_into_report_output( + canonical_report_output, + display_run, + ) + if report_output_id != canonical_report_output_id: + return self._alias_report_output( + report_output_id, + resolved_report_output, + ) + return resolved_report_output except Exception as e: print( @@ -709,6 +1200,7 @@ def update_report_output( status: str | None = None, output: str | None = None, error_message: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ Update a report output record with results or error. @@ -731,7 +1223,7 @@ def update_report_output( update_fields.append("error_message = ?") update_values.append(error_message) - if not update_fields: + if not update_fields and not version_manifest_overrides: print("No fields to update") return False @@ -745,14 +1237,16 @@ def tx_callback(tx): if requested_report is None: raise ValueError(f"Report output #{report_id} not found") - tx.query( - f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, report_id, country_id), - ) + if update_fields: + tx.query( + f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", + (*update_values, report_id, country_id), + ) self._ensure_report_output_dual_write_state_in_transaction( tx, report_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) database.transaction(tx_callback) diff --git a/policyengine_api/services/report_spec_service.py b/policyengine_api/services/report_spec_service.py index b81cc566f..457d0dc0e 100644 --- a/policyengine_api/services/report_spec_service.py +++ b/policyengine_api/services/report_spec_service.py @@ -1,12 +1,15 @@ import json +import hashlib from typing import Any, Literal from pydantic import BaseModel, Field from sqlalchemy.engine.row import Row from policyengine_api.data import database +from policyengine_api.data.congressional_districts import normalize_us_region REPORT_SPEC_SCHEMA_VERSION = 1 +REPORT_IDENTITY_SCHEMA_VERSION = 1 REPORT_SPEC_STATUSES = {"explicit", "backfilled_assumed"} HOUSEHOLD_REPORT_KINDS = {"household_single", "household_comparison"} ECONOMY_REPORT_KINDS = {"economy_single", "economy_comparison"} @@ -48,6 +51,14 @@ def _validate_schema_version(self, schema_version: int | None) -> None: f"Unsupported report spec schema version: {schema_version}" ) + def _validate_report_identity_schema_version( + self, schema_version: int | None + ) -> None: + if schema_version != REPORT_IDENTITY_SCHEMA_VERSION: + raise ValueError( + f"Unsupported report identity schema version: {schema_version}" + ) + def _get_report_output_row(self, report_output_id: int) -> dict | None: row: Row | None = database.query( "SELECT * FROM report_outputs WHERE id = ?", @@ -211,6 +222,20 @@ def _validate_report_spec_matches_row( self, report_output: dict, report_spec: ReportSpec ) -> None: simulation_1, simulation_2 = self._get_linked_simulations(report_output) + self.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + + def validate_report_spec_matches_context( + self, + report_output: dict, + report_spec: ReportSpec, + simulation_1: dict, + simulation_2: dict | None = None, + ) -> None: inferred_report_kind = self.infer_report_kind(simulation_1, simulation_2) if report_spec.country_id != report_output["country_id"]: raise ValueError("Report spec country must match report output country") @@ -268,6 +293,17 @@ def _validate_report_spec_matches_row( "Report spec reform_policy_id must match linked simulations" ) + def parse_report_spec( + self, + raw_spec: dict, + schema_version: int = REPORT_SPEC_SCHEMA_VERSION, + ) -> ReportSpec: + self._validate_schema_version(schema_version) + report_kind = raw_spec.get("report_kind") + if report_kind is None: + raise ValueError("Report spec is missing report_kind") + return self._parse_report_spec(report_kind, raw_spec) + def infer_report_kind( self, simulation_1: dict, @@ -339,6 +375,60 @@ def _parse_json_field(self, value: str | dict | None) -> dict | None: return json.loads(value) return value + def canonicalize_report_spec_for_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> dict[str, Any]: + self._validate_report_identity_schema_version(schema_version) + + canonical_spec = report_spec.model_dump() + if ( + isinstance(report_spec, EconomyReportSpec) + and report_spec.country_id == "us" + ): + canonical_spec["region"] = normalize_us_region(canonical_spec["region"]) + return canonical_spec + + def serialize_canonical_report_spec_for_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> str: + canonical_spec = self.canonicalize_report_spec_for_identity( + report_spec, + schema_version=schema_version, + ) + return json.dumps( + canonical_spec, + sort_keys=True, + separators=(",", ":"), + ) + + def get_report_identity_hash( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> str: + canonical_json = self.serialize_canonical_report_spec_for_identity( + report_spec, + schema_version=schema_version, + ) + return hashlib.sha256(canonical_json.encode("utf-8")).hexdigest() + + def get_report_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> tuple[str, int]: + return ( + self.get_report_identity_hash( + report_spec, + schema_version=schema_version, + ), + schema_version, + ) + def _parse_report_spec(self, report_kind: str, raw_spec: dict) -> ReportSpec: if report_kind in HOUSEHOLD_REPORT_KINDS: return HouseholdReportSpec.model_validate(raw_spec) diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index dfb208db2..1147fdcb1 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -62,14 +62,50 @@ def _find_existing_simulation_row( ).fetchone() return dict(row) if row is not None else None - def _build_version_manifest(self, simulation: dict) -> dict[str, str | None]: - return { + def _merge_version_manifest_overrides( + self, + version_manifest: dict[str, str | None], + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + merged_manifest = dict(version_manifest) + for key, value in (version_manifest_overrides or {}).items(): + if key in merged_manifest and value is not None: + merged_manifest[key] = value + return merged_manifest + + def _build_bootstrap_version_manifest( + self, + simulation: dict, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + version_manifest = { "country_package_version": simulation.get("api_version"), "policyengine_version": None, "data_version": None, "runtime_app_name": None, "simulation_cache_version": None, } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) + + def _build_existing_run_version_manifest( + self, + run: dict, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + version_manifest = { + "country_package_version": run.get("country_package_version"), + "policyengine_version": run.get("policyengine_version"), + "data_version": run.get("data_version"), + "runtime_app_name": run.get("runtime_app_name"), + "simulation_cache_version": run.get("simulation_cache_version"), + } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) def _list_simulation_runs_descending( self, simulation_id: int, *, queryer=None @@ -134,8 +170,8 @@ def _run_matches_parent( run: dict, simulation: dict, simulation_spec: SimulationSpec, + version_manifest: dict[str, str | None], ) -> bool: - version_manifest = self._build_version_manifest(simulation) return ( run["status"] == simulation["status"] and run.get("output") == simulation.get("output") @@ -152,9 +188,12 @@ def _run_matches_parent( ) def _insert_bootstrap_run( - self, tx, simulation: dict, simulation_spec: SimulationSpec + self, + tx, + simulation: dict, + simulation_spec: SimulationSpec, + version_manifest: dict[str, str | None], ) -> None: - version_manifest = self._build_version_manifest(simulation) tx.query( """ INSERT INTO simulation_runs ( @@ -194,8 +233,8 @@ def _update_simulation_run_in_transaction( run_id: str, simulation: dict, simulation_spec: SimulationSpec, + version_manifest: dict[str, str | None], ) -> None: - version_manifest = self._build_version_manifest(simulation) tx.query( """ UPDATE simulation_runs @@ -253,6 +292,7 @@ def _ensure_simulation_dual_write_state_in_transaction( simulation_id: int, *, country_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: simulation = self._get_simulation_row( simulation_id, @@ -268,22 +308,44 @@ def _ensure_simulation_dual_write_state_in_transaction( simulation_id, queryer=tx ) if not runs_descending: - self._insert_bootstrap_run(tx, simulation, simulation_spec) + version_manifest = self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + self._insert_bootstrap_run( + tx, + simulation, + simulation_spec, + version_manifest=version_manifest, + ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx ) else: mutable_run = self._select_mutable_run(simulation, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + mutable_run, + version_manifest_overrides=version_manifest_overrides, + ) + if mutable_run is not None + else self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + ) if mutable_run is not None and not self._run_matches_parent( mutable_run, simulation, simulation_spec, + version_manifest=version_manifest, ): self._update_simulation_run_in_transaction( tx, run_id=mutable_run["id"], simulation=simulation, simulation_spec=simulation_spec, + version_manifest=version_manifest, ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx @@ -300,13 +362,17 @@ def _ensure_simulation_dual_write_state_in_transaction( return refreshed_simulation def ensure_simulation_dual_write_state( - self, simulation_id: int, country_id: str | None = None + self, + simulation_id: int, + country_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: return database.transaction( lambda tx: self._ensure_simulation_dual_write_state_in_transaction( tx, simulation_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) ) @@ -459,6 +525,7 @@ def update_simulation( status: str | None = None, output: str | None = None, error_message: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ Update a simulation record with results or error. @@ -495,7 +562,7 @@ def update_simulation( update_fields.append("api_version = ?") update_values.append(api_version) - if not update_fields: + if not update_fields and not version_manifest_overrides: print("No fields to update") return False @@ -509,14 +576,16 @@ def tx_callback(tx): if simulation is None: raise ValueError(f"Simulation #{simulation_id} not found") - tx.query( - f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, simulation_id, country_id), - ) + if update_fields: + tx.query( + f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", + (*update_values, simulation_id, country_id), + ) self._ensure_simulation_dual_write_state_in_transaction( tx, simulation_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) database.transaction(tx_callback) diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index 2bcba1eff..72b0ba6c0 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -15,6 +15,8 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "report_spec_json", "report_spec_schema_version", "report_spec_status", + "report_identity_hash", + "report_identity_schema_version", "active_run_id", "latest_successful_run_id", }.issubset(report_output_columns) @@ -77,6 +79,8 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", "report_spec_json", "report_spec_status", + "report_identity_hash", + "report_identity_schema_version", "simulation_spec_json", "active_run_id", "latest_successful_run_id", diff --git a/tests/unit/services/test_report_output_alias_service.py b/tests/unit/services/test_report_output_alias_service.py index e4e28c916..d1d65002a 100644 --- a/tests/unit/services/test_report_output_alias_service.py +++ b/tests/unit/services/test_report_output_alias_service.py @@ -18,12 +18,15 @@ def _insert_legacy_report_output( legacy_report_output_id: int, canonical_report: dict, api_version: str = "legacy-version", + report_identity_hash: str | None = None, + report_identity_schema_version: int | None = None, ) -> None: test_db.query( """ INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?, ?, ?) + id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( legacy_report_output_id, @@ -33,6 +36,10 @@ def _insert_legacy_report_output( api_version, canonical_report["status"], canonical_report["year"], + report_identity_hash or canonical_report.get("report_identity_hash"), + report_identity_schema_version + if report_identity_schema_version is not None + else canonical_report.get("report_identity_schema_version"), ), ) @@ -194,9 +201,73 @@ def test_rejects_alias_when_legacy_report_output_is_missing(self, test_db): assert "Legacy report output #10030 not found" in str(exc_info.value) - def test_rejects_alias_when_legacy_and_canonical_reports_do_not_match( + def test_rejects_alias_when_reports_do_not_share_canonical_identity( self, test_db ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=34, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=35, + ) + default_report_spec = report_output_service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + cliff_report_spec = report_output_service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=default_report_spec, + report_spec_schema_version=1, + ) + distinct_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=cliff_report_spec, + report_spec_schema_version=1, + ) + + with pytest.raises(ValueError) as exc_info: + alias_service.set_alias( + legacy_report_output_id=distinct_report["id"], + canonical_report_output_id=canonical_report["id"], + ) + + assert "must share canonical report identity" in str(exc_info.value) + + def test_rejects_alias_when_legacy_report_output_has_no_identity(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_4b", @@ -209,20 +280,29 @@ def test_rejects_alias_when_legacy_and_canonical_reports_do_not_match( simulation_2_id=None, year="2025", ) - mismatched_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2026", + self._insert_legacy_report_output( + test_db, + legacy_report_output_id=10031, + canonical_report=canonical_report, + report_identity_hash=None, + report_identity_schema_version=None, + ) + test_db.query( + """ + UPDATE report_outputs + SET report_identity_hash = NULL, report_identity_schema_version = NULL + WHERE id = ? + """, + (10031,), ) with pytest.raises(ValueError) as exc_info: alias_service.set_alias( - legacy_report_output_id=mismatched_report["id"], + legacy_report_output_id=10031, canonical_report_output_id=canonical_report["id"], ) - assert "must describe the same report" in str(exc_info.value) + assert "must have canonical report identity" in str(exc_info.value) def test_rejects_alias_when_legacy_and_canonical_ids_match(self, test_db): simulation = simulation_service.create_simulation( diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index c1b6709a5..52048ee68 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -2,6 +2,9 @@ import json from policyengine_api.constants import get_report_output_cache_version +from policyengine_api.services.report_output_alias_service import ( + ReportOutputAliasService, +) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.simulation_service import SimulationService @@ -11,6 +14,7 @@ service = ReportOutputService() simulation_service = SimulationService() +alias_service = ReportOutputAliasService() class TestFindExistingReportOutput: @@ -421,20 +425,138 @@ def test_create_report_output_populates_economy_comparison_report_spec( if isinstance(report_spec, str): report_spec = json.loads(report_spec) assert report_spec["region"] == "state/ca" - assert report_spec["baseline_policy_id"] == 30 - assert report_spec["reform_policy_id"] == 31 - assert report_spec["dataset"] == "default" - run = test_db.query( - "SELECT * FROM report_output_runs WHERE report_output_id = ?", - (created_report["id"],), + def test_create_report_output_reuses_same_explicit_economy_spec(self, test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=32, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=33, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ny", + "baseline_policy_id": 32, + "reform_policy_id": 33, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + first_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + second_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + assert first_report["id"] == second_report["id"] + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (first_report["id"],), ).fetchone() - assert run is not None - snapshot = run["report_spec_snapshot_json"] - if isinstance(snapshot, str): - snapshot = json.loads(snapshot) - assert snapshot["report_kind"] == "economy_comparison" - assert snapshot["region"] == "state/ca" + assert stored_report["report_identity_hash"] is not None + assert stored_report["report_identity_schema_version"] == 1 + + def test_create_report_output_distinguishes_explicit_economy_specs_by_identity( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=34, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=35, + ) + default_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + cliff_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + first_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=default_report_spec, + report_spec_schema_version=1, + ) + second_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=cliff_report_spec, + report_spec_schema_version=1, + ) + + assert first_report["id"] != second_report["id"] + stored_reports = test_db.query( + """ + SELECT id, report_identity_hash, report_spec_json + FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND simulation_2_id = ? AND year = ? + ORDER BY id + """, + ( + "us", + baseline_simulation["id"], + reform_simulation["id"], + "2026", + ), + ).fetchall() + assert len(stored_reports) == 2 + assert ( + stored_reports[0]["report_identity_hash"] + != stored_reports[1]["report_identity_hash"] + ) class TestGetReportOutput: @@ -500,74 +622,114 @@ def test_get_report_output_with_json_output(self, test_db): assert result["year"] == "2025" # Frontend will parse this string - def test_get_report_output_resolves_stale_id_to_current_runtime_row(self, test_db): - stale_output = { - "budget": {"budgetary_impact": 1}, - "congressional_district_impact": { - "districts": [ - { - "district": "AL-01", - "average_household_income_change": 120, - "relative_household_income_change": 0.01, - } - ] - }, - } + def test_get_report_output_uses_selected_display_run_for_canonical_parent( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_display_run", + population_type="household", + policy_id=5, + ) + report_output = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report_output["id"], + status="complete", + output=json.dumps({"budget": {"budgetary_impact": 2}}), + ) test_db.query( - """INSERT INTO report_outputs - (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) - VALUES (?, ?, ?, ?, ?, ?, ?)""", + """ + UPDATE report_outputs + SET status = ?, output = ?, api_version = ? + WHERE id = ? + """, ( - "us", - 2, + "pending", None, - "complete", - json.dumps(stale_output), "r0stale1", - "2025", + report_output["id"], ), ) - stale_record = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() + result = service.get_report_output( + country_id="us", report_output_id=report_output["id"] + ) - current_version = get_report_output_cache_version("us") + assert result is not None + assert result["id"] == report_output["id"] + assert result["status"] == "complete" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": 2}}) + assert result["api_version"] == get_report_output_cache_version("us") + + def test_get_report_output_resolves_alias_to_canonical_parent_and_display_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_alias_display_run", + population_type="household", + policy_id=6, + ) + canonical_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"budget": {"budgetary_impact": 3}}), + ) test_db.query( - """INSERT INTO report_outputs - (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) - VALUES (?, ?, ?, ?, ?, ?, ?)""", + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, status, output, api_version, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + 999, "us", - 2, + simulation["id"], None, - "complete", - json.dumps({"budget": {"budgetary_impact": 2}}), - current_version, + "error", + json.dumps({"legacy": True}), + "r0legacy1", "2025", + canonical_report["report_identity_hash"], + canonical_report["report_identity_schema_version"], ), ) + alias_service.set_alias( + legacy_report_output_id=999, + canonical_report_output_id=canonical_report["id"], + ) - current_record = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() + result = service.get_report_output(country_id="us", report_output_id=999) - result = service.get_report_output( - country_id="us", report_output_id=stale_record["id"] - ) assert result is not None - assert result["id"] == stale_record["id"] - assert result["api_version"] == current_record["api_version"] - assert result["output"] == current_record["output"] + assert result["id"] == 999 + assert result["status"] == "complete" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": 3}}) + assert result["api_version"] == get_report_output_cache_version("us") - def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_db): + def test_get_report_output_does_not_create_current_runtime_row_for_stale_id( + self, test_db + ): stale_version = "r0stale1" - current_version = get_report_output_cache_version("us") simulation = simulation_service.create_simulation( country_id="us", - population_id="household_stale_runtime_create", + population_id="household_stale_runtime_read", population_type="household", - policy_id=5, + policy_id=7, ) test_db.query( @@ -587,17 +749,15 @@ def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_d assert result is not None assert result["id"] == stale_record["id"] - assert result["api_version"] == current_version - assert result["status"] == "pending" + assert result["api_version"] == stale_version + assert result["status"] == "complete" assert result["output"] is None - current_rows = test_db.query( + rows = test_db.query( "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? ORDER BY id ASC", ("us", simulation["id"], "2025"), ).fetchall() - assert len(current_rows) == 2 - assert current_rows[0]["api_version"] == stale_version - assert current_rows[1]["api_version"] == current_version + assert len(rows) == 1 def test_get_report_output_invalid_id(self, test_db): """Test that invalid report IDs are handled properly.""" @@ -778,6 +938,163 @@ def test_update_report_output_updates_dual_write_state(self, test_db): assert run["output"] == output_json assert run["id"] == stored_report["latest_successful_run_id"] + def test_update_report_output_preserves_stored_explicit_report_spec(self, test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=61, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=62, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/co", + "baseline_policy_id": 61, + "reform_policy_id": 62, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + success = service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + ) + + assert success is True + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + def test_update_report_output_preserves_existing_run_metadata_without_overrides( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/az", + population_type="geography", + policy_id=63, + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + "data_version": "2026.04.17", + "runtime_app_name": "policyengine-app-v2", + "resolved_dataset": "enhanced_us_household", + }, + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="error", + error_message="later failure", + ) + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run["status"] == "error" + assert run["error_message"] == "later failure" + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.0" + assert run["data_version"] == "2026.04.17" + assert run["runtime_app_name"] == "policyengine-app-v2" + assert run["resolved_dataset"] == "enhanced_us_household" + + def test_update_report_output_allows_explicit_metadata_override_on_existing_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nm", + population_type="geography", + policy_id=64, + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + }, + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + version_manifest_overrides={ + "policyengine_version": "0.95.1", + }, + ) + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.1" + def test_update_report_output_bootstraps_missing_run_state(self, test_db): simulation_1 = simulation_service.create_simulation( country_id="us", @@ -1044,7 +1361,7 @@ def test_create_report_output_rolls_back_parent_insert_on_dual_write_failure( policy_id=34, ) - def fail_dual_write(tx, report_output_id, *, country_id=None): + def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -1086,7 +1403,7 @@ def test_update_report_output_rolls_back_parent_update_on_dual_write_failure( year="2025", ) - def fail_dual_write(tx, report_output_id, *, country_id=None): + def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -1208,3 +1525,146 @@ def test_ensure_report_output_dual_write_state_bootstraps_linked_simulations( ).fetchone() assert simulation_1_run is not None assert simulation_2_run is not None + + def test_ensure_report_output_dual_write_state_reuses_stored_explicit_report_spec( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/il", + population_type="geography", + policy_id=63, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/il", + population_type="geography", + policy_id=64, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/il", + "baseline_policy_id": 63, + "reform_policy_id": 64, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + synced_report = service.ensure_report_output_dual_write_state( + created_report["id"], + country_id="us", + ) + + assert synced_report["report_spec_status"] == "explicit" + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + def test_update_report_output_invalid_stored_explicit_report_spec_fails_closed( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mi", + population_type="geography", + policy_id=65, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mi", + population_type="geography", + policy_id=66, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/mi", + "baseline_policy_id": 65, + "reform_policy_id": 66, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + corrupted_spec = { + **explicit_report_spec.model_dump(), + "region": "state/ca", + } + test_db.query( + """ + UPDATE report_outputs + SET report_spec_json = ? + WHERE id = ? + """, + ( + json.dumps(corrupted_spec), + created_report["id"], + ), + ) + + with pytest.raises( + ValueError, match="Report spec region must match linked simulations" + ): + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "should_rollback"}), + ) + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + assert stored_report["status"] == "pending" + assert stored_report["output"] is None + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run is not None + assert run["status"] == "pending" + assert run["output"] is None diff --git a/tests/unit/services/test_report_spec_service.py b/tests/unit/services/test_report_spec_service.py index f924df8db..0dd98db86 100644 --- a/tests/unit/services/test_report_spec_service.py +++ b/tests/unit/services/test_report_spec_service.py @@ -5,6 +5,7 @@ from policyengine_api.services.report_spec_service import ( EconomyReportSpec, HouseholdReportSpec, + REPORT_IDENTITY_SCHEMA_VERSION, ReportSpecService, ) from policyengine_api.services.simulation_service import SimulationService @@ -467,3 +468,135 @@ def test_rejects_unsupported_schema_version_on_read(self, test_db): report_spec_service.get_report_spec(report_output["id"]) assert "Unsupported report spec schema version" in str(exc_info.value) + + +class TestReportIdentity: + def test_canonical_identity_reuses_normalized_us_region(self): + report_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_single", + "time_period": "2027", + "region": "ca", + "baseline_policy_id": 10, + "reform_policy_id": 10, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + + canonical_spec = report_spec_service.canonicalize_report_spec_for_identity( + report_spec + ) + + assert canonical_spec["region"] == "state/ca" + + def test_equal_specs_produce_equal_hashes_despite_json_key_order(self): + first_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {"b": 2, "a": 1}, + } + ) + second_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {"a": 1, "b": 2}, + } + ) + + assert report_spec_service.get_report_identity_hash( + first_spec + ) == report_spec_service.get_report_identity_hash(second_spec) + + def test_distinct_economy_dataset_changes_identity_hash(self): + first_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + second_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "enhanced_us_household", + "target": "general", + "options": {}, + } + ) + + assert report_spec_service.get_report_identity_hash( + first_spec + ) != report_spec_service.get_report_identity_hash(second_spec) + + def test_report_identity_returns_hash_and_schema_version(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + report_identity_hash, schema_version = report_spec_service.get_report_identity( + report_spec + ) + + assert len(report_identity_hash) == 64 + assert schema_version == REPORT_IDENTITY_SCHEMA_VERSION + + def test_rejects_unsupported_identity_schema_version(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + with pytest.raises(ValueError) as exc_info: + report_spec_service.get_report_identity_hash( + report_spec, + schema_version=2, + ) + + assert "Unsupported report identity schema version" in str(exc_info.value) diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 254c8867c..6050901b6 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -227,7 +227,7 @@ def test_create_simulation_reuses_existing_row_and_bootstraps_dual_write( def test_create_simulation_rolls_back_parent_insert_on_dual_write_failure( self, test_db, monkeypatch ): - def fail_dual_write(tx, simulation_id, *, country_id=None): + def fail_dual_write(tx, simulation_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -454,7 +454,7 @@ def test_update_simulation_rolls_back_parent_update_on_dual_write_failure( policy_id=15, ) - def fail_dual_write(tx, simulation_id, *, country_id=None): + def fail_dual_write(tx, simulation_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -485,3 +485,80 @@ def fail_dual_write(tx, simulation_id, *, country_id=None): assert run is not None assert run["status"] == "pending" assert run["output"] is None + + def test_update_simulation_preserves_existing_run_metadata_without_overrides( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_preserve", + population_type="household", + policy_id=16, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="error", + error_message="later failure", + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["status"] == "error" + assert run["error_message"] == "later failure" + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.94.2" + assert run["data_version"] == "2026.04.16" + assert run["runtime_app_name"] == "policyengine-app-v2" + + def test_update_simulation_allows_explicit_metadata_override_on_existing_run( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_override", + population_type="household", + policy_id=17, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + }, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + version_manifest_overrides={ + "policyengine_version": "0.95.0", + }, + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.95.0" diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 5d0ca4f79..14f5b32b8 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -120,6 +120,182 @@ def test_create_report_output_existing_row_repairs_dual_write_state(test_db): assert snapshot["report_kind"] == "household_single" +def test_create_report_output_with_explicit_spec_persists_it(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=45, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=46, + ) + + client = create_test_client() + response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ny", + "baseline_policy_id": 45, + "reform_policy_id": 46, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + + assert response.status_code == 201 + report_id = response.get_json()["result"]["id"] + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_kind"] == "economy_comparison" + assert stored_report["report_spec_schema_version"] == 1 + assert stored_report["report_spec_status"] == "explicit" + + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + assert stored_report["report_identity_hash"] is not None + assert stored_report["report_identity_schema_version"] == 1 + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + +def test_create_report_output_same_explicit_spec_returns_existing_row(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/va", + population_type="geography", + policy_id=53, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/va", + population_type="geography", + policy_id=54, + ) + payload = { + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/va", + "baseline_policy_id": 53, + "reform_policy_id": 54, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + client = create_test_client() + first_response = client.post("/us/report", json=payload) + second_response = client.post("/us/report", json=payload) + + assert first_response.status_code == 201 + assert second_response.status_code == 200 + assert ( + first_response.get_json()["result"]["id"] + == second_response.get_json()["result"]["id"] + ) + + +def test_create_report_output_distinct_explicit_specs_create_distinct_rows(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/md", + population_type="geography", + policy_id=55, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/md", + population_type="geography", + policy_id=56, + ) + + client = create_test_client() + default_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/md", + "baseline_policy_id": 55, + "reform_policy_id": 56, + "dataset": "default", + "target": "general", + "options": {}, + }, + }, + ) + cliff_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/md", + "baseline_policy_id": 55, + "reform_policy_id": 56, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + + assert default_response.status_code == 201 + assert cliff_response.status_code == 201 + assert ( + default_response.get_json()["result"]["id"] + != cliff_response.get_json()["result"]["id"] + ) + + def test_create_report_output_missing_primary_simulation_returns_bad_request(test_db): client = create_test_client() response = client.post( @@ -211,6 +387,39 @@ def test_patch_simulation_wrong_country_returns_not_found_and_does_not_mutate(te assert stored_simulation["output"] is None +def test_patch_simulation_persists_run_metadata_fields(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_metadata", + population_type="household", + policy_id=47, + ) + + client = create_test_client() + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "status": "complete", + "output": json.dumps({"ok": True}), + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + assert response.status_code == 200 + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (simulation["id"],), + ).fetchone() + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.94.2" + assert run["data_version"] == "2026.04.16" + assert run["runtime_app_name"] == "policyengine-app-v2" + + def test_get_report_output_wrong_country_returns_not_found(test_db): test_db.query( """ @@ -230,6 +439,110 @@ def test_get_report_output_wrong_country_returns_not_found(test_db): assert response.status_code == 404 +def test_get_report_output_alias_resolves_to_canonical_display_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_alias", + population_type="household", + policy_id=57, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + test_db.query( + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 2001, + "us", + simulation["id"], + None, + "r0legacy1", + "error", + json.dumps({"result": "legacy"}), + "2025", + ), + ) + test_db.query( + """ + INSERT INTO legacy_report_output_aliases ( + legacy_report_output_id, canonical_report_output_id + ) VALUES (?, ?) + """, + (2001, canonical_report["id"]), + ) + + client = create_test_client() + response = client.get("/us/report/2001") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 2001 + assert payload["result"]["status"] == "complete" + assert payload["result"]["output"] == json.dumps({"result": "canonical"}) + assert payload["result"]["api_version"] == get_report_output_cache_version("us") + + +def test_get_report_output_reads_malformed_legacy_row_without_runs_or_identity( + test_db, +): + household_simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_legacy_malformed", + population_type="household", + policy_id=58, + ) + geography_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=59, + ) + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "us", + household_simulation["id"], + geography_simulation["id"], + "r0legacy-malformed", + "error", + json.dumps({"result": "legacy-malformed"}), + "2025", + ), + ) + malformed_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.get(f"/us/report/{malformed_report['id']}") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == malformed_report["id"] + assert payload["result"]["status"] == "error" + assert payload["result"]["output"] == json.dumps( + {"result": "legacy-malformed"} + ) + assert payload["result"]["api_version"] == "r0legacy-malformed" + + def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate( test_db, ): @@ -264,3 +577,192 @@ def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate assert stored_report["country_id"] == "us" assert stored_report["status"] == "pending" assert stored_report["output"] is None + + +def test_patch_report_output_persists_run_metadata_fields(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/wa", + population_type="geography", + policy_id=48, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report_output["id"], + "status": "complete", + "output": json.dumps({"result": "ok"}), + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + "data_version": "2026.04.17", + "runtime_app_name": "policyengine-app-v2", + "resolved_dataset": "enhanced_us_household", + }, + ) + + assert response.status_code == 200 + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_output["id"],), + ).fetchone() + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.0" + assert run["data_version"] == "2026.04.17" + assert run["runtime_app_name"] == "policyengine-app-v2" + assert run["resolved_dataset"] == "enhanced_us_household" + + +def test_patch_report_output_preserves_stored_explicit_report_spec(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/or", + population_type="geography", + policy_id=49, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/or", + population_type="geography", + policy_id=50, + ) + + client = create_test_client() + create_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/or", + "baseline_policy_id": 49, + "reform_policy_id": 50, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + report_id = create_response.get_json()["result"]["id"] + + patch_response = client.patch( + "/us/report", + json={ + "id": report_id, + "status": "complete", + "output": json.dumps({"result": "ok"}), + }, + ) + + assert patch_response.status_code == 200 + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + +def test_patch_report_output_metadata_only_preserves_stored_explicit_report_spec( + test_db, +): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nj", + population_type="geography", + policy_id=51, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nj", + population_type="geography", + policy_id=52, + ) + + client = create_test_client() + create_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/nj", + "baseline_policy_id": 51, + "reform_policy_id": 52, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + report_id = create_response.get_json()["result"]["id"] + + patch_response = client.patch( + "/us/report", + json={ + "id": report_id, + "policyengine_version": "0.95.1", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + assert patch_response.status_code == 200 + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + assert run["policyengine_version"] == "0.95.1" + assert run["runtime_app_name"] == "policyengine-app-v2" + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"}