diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index ad9c6fb..fc8d93d 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -120,6 +120,31 @@ ), ) +# Classes that need a field_validator injected after generation. +# Each entry: (class_name, field_name, validator_body) +# The validator_body is the full method source (indented 4 spaces inside the class). +CLASS_VALIDATOR_INJECTIONS: tuple[tuple[str, str, str], ...] = ( + ( + "InitializeRequest", + "protocol_version", + textwrap.dedent("""\ + @field_validator("protocol_version", mode="before") + @classmethod + def _coerce_protocol_version(cls, v: Any) -> int: + # Some clients (e.g. Zed) send a date string like "2024-11-05" instead + # of an integer. The Rust SDK treats any string as version 0; we map it + # to 1 (the current stable version) so the connection is not rejected. + # See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs + if isinstance(v, int): + return v + try: + return int(v) + except (TypeError, ValueError): + return 1 + """), + ), +) + @dataclass(frozen=True) class _ProcessingStep: @@ -182,6 +207,7 @@ def postprocess_generated_schema(output_path: Path) -> list[str]: _ProcessingStep("apply default overrides", _apply_default_overrides), _ProcessingStep("attach description comments", _add_description_comments), _ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model), + _ProcessingStep("inject field validators", _inject_field_validators), ) for step in processing_steps: @@ -338,6 +364,51 @@ def __getattr__(self, item: str) -> Any: return "\n".join(lines) + "\n" +def _ensure_pydantic_import(content: str, name: str) -> str: + """Add *name* to the ``from pydantic import ...`` line if not already present.""" + lines = content.splitlines() + for idx, line in enumerate(lines): + if not line.startswith("from pydantic import "): + continue + imports = [part.strip() for part in line[len("from pydantic import "):].split(",")] + if name not in imports: + imports.append(name) + lines[idx] = "from pydantic import " + ", ".join(imports) + return "\n".join(lines) + "\n" + return content + + +def _inject_field_validators(content: str) -> str: + """Inject field_validator methods into classes listed in CLASS_VALIDATOR_INJECTIONS.""" + for class_name, _field_name, validator_body in CLASS_VALIDATOR_INJECTIONS: + # Ensure field_validator is imported from pydantic. + content = _ensure_pydantic_import(content, "field_validator") + + # Find the end of the class body and append the validator before the next class. + class_pattern = re.compile( + rf"(class {class_name}\(BaseModel\):)(.*?)(?=\nclass |\Z)", + re.DOTALL, + ) + + def _append_validator( + match: re.Match[str], + _body: str = validator_body, + _class: str = class_name, + ) -> str: + header, block = match.group(1), match.group(2) + # Indent the validator body by 4 spaces to sit inside the class. + indented = "\n" + textwrap.indent(_body.rstrip(), " ") + return header + block + indented + "\n" + + content, count = class_pattern.subn(_append_validator, content, count=1) + if count == 0: + print( + f"Warning: class {class_name} not found for validator injection", + file=sys.stderr, + ) + return content + + def _apply_field_overrides(content: str) -> str: for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES: if optional: diff --git a/src/acp/schema.py b/src/acp/schema.py index 32031c4..e8b9c54 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict +from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict, field_validator PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"] PlanEntryPriority = Literal["high", "medium", "low"] @@ -1588,6 +1588,20 @@ class InitializeRequest(BaseModel): ), ] + @field_validator("protocol_version", mode="before") + @classmethod + def _coerce_protocol_version(cls, v: Any) -> int: + # Some clients (e.g. Zed) send a date string like "2024-11-05" instead + # of an integer. The Rust SDK treats any string as version 0; we map it + # to 1 (the current stable version) so the connection is not rejected. + # See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs + if isinstance(v, int): + return v + try: + return int(v) + except (TypeError, ValueError): + return 1 + class KillTerminalRequest(BaseModel): # The _meta property is reserved by ACP to allow clients and agents to attach additional