diff --git a/src/webwright/agents/default.py b/src/webwright/agents/default.py index 4fb77f1..464dd54 100644 --- a/src/webwright/agents/default.py +++ b/src/webwright/agents/default.py @@ -6,9 +6,11 @@ from typing import Any from jinja2 import StrictUndefined, Template -from pydantic import BaseModel +from pydantic import BaseModel, Field from webwright import Environment, Model, __version__ +from webwright.cache import ScriptCache +from webwright.config import CacheConfig from webwright.exceptions import FormatError, InterruptAgentFlow, LimitsExceeded from webwright.utils.serialize import recursive_merge @@ -40,6 +42,9 @@ class AgentConfig(BaseModel): summary_every_n_steps: int = 0 summary_user_prompt: str = DEFAULT_SUMMARY_USER_PROMPT output_path: Path | None = None + cache: CacheConfig = Field(default_factory=CacheConfig) + cache_fingerprint: str | None = None + cache_metadata: dict[str, Any] = Field(default_factory=dict) def _sanitize_message_for_disk(message: dict[str, Any]) -> dict[str, Any]: @@ -263,6 +268,56 @@ def _tool_gate_error(self) -> str | None: ) return None + def _latest_final_run_script(self) -> Path | None: + workspace_dir = self.get_template_vars().get("workspace_dir") + if not workspace_dir: + return None + final_runs_dir = Path(workspace_dir) / "final_runs" + if not final_runs_dir.is_dir(): + return None + + run_dirs: list[tuple[int, Path]] = [] + for entry in final_runs_dir.iterdir(): + if not entry.is_dir() or not entry.name.startswith("run_"): + continue + suffix = entry.name[len("run_"):] + try: + run_id = int(suffix) + except ValueError: + continue + run_dirs.append((run_id, entry)) + + for _, run_dir in sorted(run_dirs, key=lambda item: item[0], reverse=True): + final_script_path = run_dir / "final_script.py" + if final_script_path.is_file(): + return final_script_path + return None + + def _cache_final_script_path(self) -> Path | None: + latest_run_script = self._latest_final_run_script() + if latest_run_script is not None: + return latest_run_script + + final_script_path = self.get_template_vars().get("final_script_path") + if not final_script_path: + return None + path = Path(final_script_path) + return path if path.is_file() else None + + def _write_cache_entry(self) -> None: + if not self.config.cache.enabled or not self.config.cache_fingerprint or self.config.output_path is None: + return + final_script_path = self._cache_final_script_path() + if final_script_path is None: + return + cache = ScriptCache(self.config.cache) + cache.put( + self.config.cache_fingerprint, + final_script_path, + self.config.output_path, + metadata=self.config.cache_metadata, + ) + def add_messages(self, *messages: dict[str, Any]) -> list[dict[str, Any]]: self.messages.extend(messages) return list(messages) @@ -345,7 +400,10 @@ def run(self, task: str = "", **kwargs) -> dict[str, Any]: ): self._compact_history() self.save(self.config.output_path) - return self.messages[-1].get("extra", {}) + result = self.messages[-1].get("extra", {}) + if result.get("exit_status") == "Submitted": + self._write_cache_entry() + return result def step(self) -> list[dict[str, Any]]: return self.execute_actions(self.query()) diff --git a/src/webwright/cache/__init__.py b/src/webwright/cache/__init__.py new file mode 100644 index 0000000..8d917f9 --- /dev/null +++ b/src/webwright/cache/__init__.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from webwright.cache.script_cache import CachedScript, ScriptCache, make_fingerprint + +__all__ = [ + "CachedScript", + "ScriptCache", + "make_fingerprint", +] diff --git a/src/webwright/cache/script_cache.py b/src/webwright/cache/script_cache.py new file mode 100644 index 0000000..9a4e488 --- /dev/null +++ b/src/webwright/cache/script_cache.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import hashlib +import json +import shutil +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import httpx + +from webwright.config import CacheConfig + + +@dataclass(frozen=True) +class CachedScript: + fingerprint: str + directory: Path + script_path: Path + trajectory_path: Path + metadata: dict[str, Any] + + +def _field_value(config: dict[str, Any], field: str) -> Any: + if field == "task": + return config.get("run", {}).get("task", "") + if field == "start_url": + return config.get("run", {}).get("start_url", "") + + value: Any = config + for part in field.split("."): + if not isinstance(value, dict): + return "" + value = value.get(part, "") + return value + + +def make_fingerprint(config: dict[str, Any]) -> str: + cache_config = CacheConfig(**config.get("cache", {})) + payload = { + field: _field_value(config, field) + for field in cache_config.fingerprint_fields + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode() + return hashlib.sha256(encoded).hexdigest()[:16] + + +class ScriptCache: + def __init__(self, config: dict[str, Any] | CacheConfig | None = None): + if isinstance(config, CacheConfig): + self.config = config + else: + self.config = CacheConfig(**(config or {})) + self.directory = self.config.directory.expanduser() + + @property + def enabled(self) -> bool: + return self.config.enabled + + def _entry_dir(self, fingerprint: str) -> Path: + return self.directory / fingerprint + + def get(self, fingerprint: str) -> CachedScript | None: + if not self.enabled: + return None + + entry_dir = self._entry_dir(fingerprint) + metadata_path = entry_dir / "metadata.json" + script_path = entry_dir / "final_script.py" + trajectory_path = entry_dir / "trajectory.json" + if not metadata_path.is_file() or not script_path.is_file() or not trajectory_path.is_file(): + return None + + try: + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + self.invalidate(fingerprint) + return None + + if self._is_expired(metadata): + self.invalidate(fingerprint) + return None + + return CachedScript( + fingerprint=fingerprint, + directory=entry_dir, + script_path=script_path, + trajectory_path=trajectory_path, + metadata=metadata, + ) + + def put( + self, + fingerprint: str, + final_script_path: str | Path, + trajectory_path: str | Path, + *, + metadata: dict[str, Any] | None = None, + ) -> CachedScript | None: + if not self.enabled: + return None + + source_script = Path(final_script_path).expanduser() + source_trajectory = Path(trajectory_path).expanduser() + if not source_script.is_file() or not source_trajectory.is_file(): + return None + + entry_dir = self._entry_dir(fingerprint) + entry_dir.mkdir(parents=True, exist_ok=True) + script_path = entry_dir / "final_script.py" + trajectory_copy_path = entry_dir / "trajectory.json" + shutil.copy2(source_script, script_path) + shutil.copy2(source_trajectory, trajectory_copy_path) + + entry_metadata = { + **(metadata or {}), + "fingerprint": fingerprint, + "created_at": datetime.now(timezone.utc).isoformat(), + "final_script_source": str(source_script), + "trajectory_source": str(source_trajectory), + } + (entry_dir / "metadata.json").write_text( + json.dumps(entry_metadata, indent=2), + encoding="utf-8", + ) + return CachedScript( + fingerprint=fingerprint, + directory=entry_dir, + script_path=script_path, + trajectory_path=trajectory_copy_path, + metadata=entry_metadata, + ) + + def invalidate(self, fingerprint: str) -> None: + shutil.rmtree(self._entry_dir(fingerprint), ignore_errors=True) + + def validate_url(self, start_url: str | None) -> bool: + if not self.config.validate_url or not start_url: + return True + try: + response = httpx.head(start_url, follow_redirects=True, timeout=10.0) + except httpx.HTTPError: + return False + return response.status_code < 400 + + def _is_expired(self, metadata: dict[str, Any]) -> bool: + if self.config.ttl_seconds <= 0: + return False + raw_created_at = metadata.get("created_at") + if not isinstance(raw_created_at, str): + return True + try: + created_at = datetime.fromisoformat(raw_created_at) + except ValueError: + return True + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + age = datetime.now(timezone.utc) - created_at + return age.total_seconds() > self.config.ttl_seconds diff --git a/src/webwright/config/__init__.py b/src/webwright/config/__init__.py index 3b93c43..d91cb9e 100644 --- a/src/webwright/config/__init__.py +++ b/src/webwright/config/__init__.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any +from pydantic import BaseModel, Field import yaml from webwright import package_dir @@ -12,6 +13,21 @@ builtin_config_dir = package_dir / "config" +class CacheConfig(BaseModel): + enabled: bool = False + directory: Path = Path("~/.cache/webwright") + ttl_seconds: int = 604800 + validate_url: bool = True + fingerprint_fields: list[str] = Field( + default_factory=lambda: [ + "task", + "start_url", + "model.model_name", + "environment.environment_class", + ] + ) + + def _nest_key_value(key: str, value: Any) -> dict[str, Any]: parts = key.split(".") nested: dict[str, Any] = value diff --git a/src/webwright/config/base.yaml b/src/webwright/config/base.yaml index 3857517..c27229c 100644 --- a/src/webwright/config/base.yaml +++ b/src/webwright/config/base.yaml @@ -81,6 +81,17 @@ run: task_id: start_url: +cache: + enabled: false + directory: ~/.cache/webwright + ttl_seconds: 604800 + validate_url: true + fingerprint_fields: + - task + - start_url + - model.model_name + - environment.environment_class + agent: agent_class: default debug_log: true diff --git a/src/webwright/run/cli.py b/src/webwright/run/cli.py index cc27632..22ae49d 100644 --- a/src/webwright/run/cli.py +++ b/src/webwright/run/cli.py @@ -1,5 +1,8 @@ from __future__ import annotations +import json +import shlex +import sys from datetime import datetime from pathlib import Path from typing import Any @@ -8,6 +11,7 @@ from rich.console import Console from webwright.agents import get_agent +from webwright.cache import CachedScript, ScriptCache, make_fingerprint from webwright.config import get_config_from_spec, snapshot_config_specs from webwright.environments import get_environment from webwright.models import get_model @@ -16,7 +20,7 @@ DEFAULT_CONFIGS = ["base.yaml", "model_openai.yaml"] -app = typer.Typer(no_args_is_help=True) +app = typer.Typer(no_args_is_help=True, context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) console = Console(highlight=False) @@ -27,6 +31,137 @@ def _timestamped_output_dir(base_dir: str | Path | None, task_id: str | None) -> return base / f"{suffix}_{stamp}" +def _extra_config_specs(args: list[str]) -> list[str]: + specs: list[str] = [] + index = 0 + while index < len(args): + raw_arg = args[index] + if not raw_arg.startswith("--") or "." not in raw_arg: + raise ValueError(f"Unsupported CLI override: {raw_arg!r}") + + spec = raw_arg[2:] + if "=" not in spec: + if index + 1 < len(args) and not args[index + 1].startswith("--"): + spec = f"{spec}={args[index + 1]}" + index += 1 + else: + spec = f"{spec}=true" + specs.append(spec) + index += 1 + return specs + + +def _cache_metadata(config: dict[str, Any], fingerprint: str) -> dict[str, Any]: + return { + "fingerprint": fingerprint, + "task": config.get("run", {}).get("task", ""), + "start_url": config.get("run", {}).get("start_url", ""), + "model": { + "model_class": config.get("model", {}).get("model_class", ""), + "model_name": config.get("model", {}).get("model_name", ""), + }, + "environment": { + "environment_class": config.get("environment", {}).get("environment_class", ""), + }, + } + + +def _result_from_cached_trajectory(cached: CachedScript, trajectory: dict[str, Any]) -> dict[str, Any]: + info = trajectory.get("info", {}) + return { + "exit_status": info.get("exit_status", "Submitted"), + "submission": info.get("submission", ""), + "final_response": info.get("submission", ""), + "cached": True, + "cache_fingerprint": cached.fingerprint, + } + + +def _write_cached_trajectory( + *, + cached: CachedScript, + config: dict[str, Any], + replay_output: dict[str, Any], +) -> dict[str, Any]: + try: + trajectory = json.loads(cached.trajectory_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + trajectory = { + "info": { + "exit_status": "Submitted", + "submission": "Cached script completed.", + "api_calls": 0, + "format_errors": 0, + }, + "messages": [], + "trajectory_format": "mini-swe-webagent-0.1", + } + + trajectory["cached"] = True + trajectory["cache"] = { + "fingerprint": cached.fingerprint, + "source": str(cached.directory), + "script_path": str(cached.script_path), + "replay_returncode": replay_output.get("returncode"), + } + trajectory["replay_observation"] = replay_output.get("observation", {}) + trajectory.setdefault("info", {})["cached"] = True + trajectory["info"]["api_calls"] = 0 + + output_path = Path(config.get("agent", {}).get("output_path", "trajectory.json")).expanduser() + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(trajectory, indent=2), encoding="utf-8") + return _result_from_cached_trajectory(cached, trajectory) + + +def _try_replay_cache( + *, + config: dict[str, Any], + fingerprint: str, + task: str, + task_id: str | None, + start_url: str | None, +) -> dict[str, Any] | None: + cache = ScriptCache(config.get("cache", {})) + if not cache.enabled: + return None + + cached = cache.get(fingerprint) + if cached is None: + console.print("Cache miss: running agent") + return None + + if not cache.validate_url(start_url): + cache.invalidate(fingerprint) + console.print("Cache miss: running agent") + return None + + console.print("Cache hit: skipping model loop") + env = get_environment(config.get("environment", {})) + try: + env.prepare( + task=task, + task_id=task_id, + start_url=start_url, + ) + output = env.execute( + {"bash_command": f"{shlex.quote(sys.executable)} {shlex.quote(str(cached.script_path))}"} + ) + finally: + env.close() + + if output.get("returncode") != 0 or output.get("exception_info"): + cache.invalidate(fingerprint) + console.print("Cache miss: running agent") + return None + + return _write_cached_trajectory( + cached=cached, + config=config, + replay_output=output, + ) + + def run_one( *, task: str | None = None, @@ -82,6 +217,30 @@ def run_one( }, }, ) + fingerprint = make_fingerprint(config) + cache_metadata = _cache_metadata(config, fingerprint) + config = recursive_merge( + config, + { + "agent": { + "cache": config.get("cache", {}), + "cache_fingerprint": fingerprint, + "cache_metadata": cache_metadata, + } + }, + ) + + cached_result = _try_replay_cache( + config=config, + fingerprint=fingerprint, + task=resolved_task, + task_id=resolved_task_id, + start_url=resolved_start_url, + ) + if cached_result is not None: + cached_result["_output_dir"] = str(resolved_output_dir) + console.print(cached_result.get("final_response") or cached_result.get("submission") or "Task finished.") + return cached_result model = get_model(config.get("model", {})) env = get_environment(config.get("environment", {})) @@ -133,6 +292,7 @@ def run_one( @app.command() def main( + ctx: typer.Context, task: str = typer.Option(..., "-t", "--task", help="Natural language task description."), task_id: str | None = typer.Option(None, "--task-id", help="Optional identifier used in the output directory name."), start_url: str | None = typer.Option(None, "--start-url", help="Optional starting URL for the task."), @@ -140,11 +300,12 @@ def main( output_dir: Path | None = typer.Option(None, "-o", "--output-dir"), debug: bool = typer.Option(False, "--debug", help="Launch headed local Playwright with devtools and keep it open for inspection."), ) -> Any: + resolved_config_spec = list(config_spec) + _extra_config_specs(list(ctx.args)) return run_one( task=task, task_id=task_id, start_url=start_url, - config_spec=config_spec, + config_spec=resolved_config_spec, output_dir=output_dir, debug=debug, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2426b37 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" +sys.path.insert(0, str(SRC)) diff --git a/tests/unit/test_script_cache.py b/tests/unit/test_script_cache.py new file mode 100644 index 0000000..99e21a2 --- /dev/null +++ b/tests/unit/test_script_cache.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import json + +import httpx + +from webwright.cache import ScriptCache, make_fingerprint +from webwright.run.cli import _try_replay_cache + + +def _fingerprint_config() -> dict: + return { + "cache": {"enabled": True}, + "run": { + "task": "Find red shoes", + "start_url": "https://example.com", + }, + "model": { + "model_class": "openai", + "model_name": "gpt-4.1", + }, + "environment": { + "environment_class": "local_workspace", + }, + } + + +def test_make_fingerprint_is_stable_and_changes_on_input_change() -> None: + config = _fingerprint_config() + assert make_fingerprint(config) == make_fingerprint(_fingerprint_config()) + + changed = _fingerprint_config() + changed["run"]["task"] = "Find blue shoes" + assert make_fingerprint(config) != make_fingerprint(changed) + + +def test_replay_script_error_invalidates_cache_entry(tmp_path) -> None: + cache_dir = tmp_path / "cache" + source_dir = tmp_path / "source" + source_dir.mkdir() + final_script_path = source_dir / "final_script.py" + final_script_path.write_text("raise SystemExit(7)\n", encoding="utf-8") + trajectory_path = source_dir / "trajectory.json" + trajectory_path.write_text( + json.dumps( + { + "info": { + "exit_status": "Submitted", + "submission": "done", + "api_calls": 1, + "format_errors": 0, + }, + "messages": [], + "trajectory_format": "mini-swe-webagent-0.1", + } + ), + encoding="utf-8", + ) + + config = _fingerprint_config() + config["cache"] = { + "enabled": True, + "directory": str(cache_dir), + "validate_url": False, + } + config["environment"] = { + "environment_class": "local_workspace", + "output_dir": str(tmp_path / "replay"), + } + config["agent"] = { + "output_path": str(tmp_path / "replay" / "trajectory.json"), + } + fingerprint = make_fingerprint(config) + + cache = ScriptCache(config["cache"]) + cache.put(fingerprint, final_script_path, trajectory_path) + + result = _try_replay_cache( + config=config, + fingerprint=fingerprint, + task=config["run"]["task"], + task_id=None, + start_url=config["run"]["start_url"], + ) + + assert result is None + assert not (cache_dir / fingerprint).exists() + + +def test_start_url_500_invalidates_cache_entry(tmp_path, monkeypatch) -> None: + cache_dir = tmp_path / "cache" + source_dir = tmp_path / "source" + source_dir.mkdir() + final_script_path = source_dir / "final_script.py" + final_script_path.write_text("print('should not run')\n", encoding="utf-8") + trajectory_path = source_dir / "trajectory.json" + trajectory_path.write_text( + json.dumps( + { + "info": { + "exit_status": "Submitted", + "submission": "done", + "api_calls": 1, + "format_errors": 0, + }, + "messages": [], + "trajectory_format": "mini-swe-webagent-0.1", + } + ), + encoding="utf-8", + ) + + config = _fingerprint_config() + config["cache"] = { + "enabled": True, + "directory": str(cache_dir), + "validate_url": True, + } + config["environment"] = { + "environment_class": "local_workspace", + "output_dir": str(tmp_path / "replay"), + } + config["agent"] = { + "output_path": str(tmp_path / "replay" / "trajectory.json"), + } + fingerprint = make_fingerprint(config) + + cache = ScriptCache(config["cache"]) + cache.put(fingerprint, final_script_path, trajectory_path) + + def fake_head(*args, **kwargs) -> httpx.Response: + return httpx.Response(status_code=500) + + monkeypatch.setattr(httpx, "head", fake_head) + + result = _try_replay_cache( + config=config, + fingerprint=fingerprint, + task=config["run"]["task"], + task_id=None, + start_url=config["run"]["start_url"], + ) + + assert result is None + assert not (cache_dir / fingerprint).exists() + assert not (tmp_path / "replay").exists()