Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions scripts/gen_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,31 @@
),
)

# Classes that need a field_validator injected after generation.
# Each entry: (class_name, field_name, validator_body)
# The validator_body is the full method source (indented 4 spaces inside the class).
CLASS_VALIDATOR_INJECTIONS: tuple[tuple[str, str, str], ...] = (
(
"InitializeRequest",
"protocol_version",
textwrap.dedent("""\
@field_validator("protocol_version", mode="before")
@classmethod
def _coerce_protocol_version(cls, v: Any) -> int:
# Some clients (e.g. Zed) send a date string like "2024-11-05" instead
# of an integer. The Rust SDK treats any string as version 0; we map it
# to 1 (the current stable version) so the connection is not rejected.
# See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs
if isinstance(v, int):
return v
try:
return int(v)
except (TypeError, ValueError):
return 1
"""),
),
)


@dataclass(frozen=True)
class _ProcessingStep:
Expand Down Expand Up @@ -182,6 +207,7 @@ def postprocess_generated_schema(output_path: Path) -> list[str]:
_ProcessingStep("apply default overrides", _apply_default_overrides),
_ProcessingStep("attach description comments", _add_description_comments),
_ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model),
_ProcessingStep("inject field validators", _inject_field_validators),
)

for step in processing_steps:
Expand Down Expand Up @@ -338,6 +364,51 @@ def __getattr__(self, item: str) -> Any:
return "\n".join(lines) + "\n"


def _ensure_pydantic_import(content: str, name: str) -> str:
"""Add *name* to the ``from pydantic import ...`` line if not already present."""
lines = content.splitlines()
for idx, line in enumerate(lines):
if not line.startswith("from pydantic import "):
continue
imports = [part.strip() for part in line[len("from pydantic import "):].split(",")]
if name not in imports:
imports.append(name)
lines[idx] = "from pydantic import " + ", ".join(imports)
return "\n".join(lines) + "\n"
return content


def _inject_field_validators(content: str) -> str:
"""Inject field_validator methods into classes listed in CLASS_VALIDATOR_INJECTIONS."""
for class_name, _field_name, validator_body in CLASS_VALIDATOR_INJECTIONS:
# Ensure field_validator is imported from pydantic.
content = _ensure_pydantic_import(content, "field_validator")

# Find the end of the class body and append the validator before the next class.
class_pattern = re.compile(
rf"(class {class_name}\(BaseModel\):)(.*?)(?=\nclass |\Z)",
re.DOTALL,
)

def _append_validator(
match: re.Match[str],
_body: str = validator_body,
_class: str = class_name,
) -> str:
header, block = match.group(1), match.group(2)
# Indent the validator body by 4 spaces to sit inside the class.
indented = "\n" + textwrap.indent(_body.rstrip(), " ")
return header + block + indented + "\n"

content, count = class_pattern.subn(_append_validator, content, count=1)
if count == 0:
print(
f"Warning: class {class_name} not found for validator injection",
file=sys.stderr,
)
return content


def _apply_field_overrides(content: str) -> str:
for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES:
if optional:
Expand Down
16 changes: 15 additions & 1 deletion src/acp/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict
from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict, field_validator

PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"]
PlanEntryPriority = Literal["high", "medium", "low"]
Expand Down Expand Up @@ -1588,6 +1588,20 @@ class InitializeRequest(BaseModel):
),
]

@field_validator("protocol_version", mode="before")
@classmethod
def _coerce_protocol_version(cls, v: Any) -> int:
# Some clients (e.g. Zed) send a date string like "2024-11-05" instead
# of an integer. The Rust SDK treats any string as version 0; we map it
# to 1 (the current stable version) so the connection is not rejected.
# See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs
if isinstance(v, int):
return v
try:
return int(v)
except (TypeError, ValueError):
return 1


class KillTerminalRequest(BaseModel):
# The _meta property is reserved by ACP to allow clients and agents to attach additional
Expand Down