Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions src/webwright/agents/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import Any

from jinja2 import StrictUndefined, Template
from pydantic import BaseModel
from pydantic import BaseModel, Field

from webwright import Environment, Model, __version__
from webwright.cache import ScriptCache
from webwright.config import CacheConfig
from webwright.exceptions import FormatError, InterruptAgentFlow, LimitsExceeded
from webwright.utils.serialize import recursive_merge

Expand Down Expand Up @@ -40,6 +42,9 @@ class AgentConfig(BaseModel):
summary_every_n_steps: int = 0
summary_user_prompt: str = DEFAULT_SUMMARY_USER_PROMPT
output_path: Path | None = None
cache: CacheConfig = Field(default_factory=CacheConfig)
cache_fingerprint: str | None = None
cache_metadata: dict[str, Any] = Field(default_factory=dict)


def _sanitize_message_for_disk(message: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -263,6 +268,56 @@ def _tool_gate_error(self) -> str | None:
)
return None

def _latest_final_run_script(self) -> Path | None:
workspace_dir = self.get_template_vars().get("workspace_dir")
if not workspace_dir:
return None
final_runs_dir = Path(workspace_dir) / "final_runs"
if not final_runs_dir.is_dir():
return None

run_dirs: list[tuple[int, Path]] = []
for entry in final_runs_dir.iterdir():
if not entry.is_dir() or not entry.name.startswith("run_"):
continue
suffix = entry.name[len("run_"):]
try:
run_id = int(suffix)
except ValueError:
continue
run_dirs.append((run_id, entry))

for _, run_dir in sorted(run_dirs, key=lambda item: item[0], reverse=True):
final_script_path = run_dir / "final_script.py"
if final_script_path.is_file():
return final_script_path
return None

def _cache_final_script_path(self) -> Path | None:
latest_run_script = self._latest_final_run_script()
if latest_run_script is not None:
return latest_run_script

final_script_path = self.get_template_vars().get("final_script_path")
if not final_script_path:
return None
path = Path(final_script_path)
return path if path.is_file() else None

def _write_cache_entry(self) -> None:
    """Store the final script and saved trajectory in the script cache.

    Does nothing unless caching is enabled, a cache fingerprint is set, an
    output path is configured, and a final script file can be located.

    Caching is an optimisation for later runs, so a failure to write the
    entry is logged and swallowed rather than allowed to fail a run that
    has already completed.
    """
    import logging

    config = self.config
    if not config.cache.enabled:
        return
    if not config.cache_fingerprint:
        return
    if config.output_path is None:
        return

    final_script_path = self._cache_final_script_path()
    if final_script_path is None:
        return

    try:
        ScriptCache(config.cache).put(
            config.cache_fingerprint,
            final_script_path,
            config.output_path,
            metadata=config.cache_metadata,
        )
    except (OSError, TypeError, ValueError):
        # OSError: cache directory not writable, disk full, copy failure.
        # TypeError/ValueError: metadata that cannot be serialised to JSON.
        logging.getLogger(__name__).warning(
            "Could not write script cache entry %s",
            config.cache_fingerprint,
            exc_info=True,
        )

def add_messages(self, *messages: dict[str, Any]) -> list[dict[str, Any]]:
    """Append ``messages`` to the history and return them as a new list."""
    appended = list(messages)
    self.messages.extend(appended)
    return appended
Expand Down Expand Up @@ -345,7 +400,10 @@ def run(self, task: str = "", **kwargs) -> dict[str, Any]:
):
self._compact_history()
self.save(self.config.output_path)
return self.messages[-1].get("extra", {})
result = self.messages[-1].get("extra", {})
if result.get("exit_status") == "Submitted":
self._write_cache_entry()
return result

def step(self) -> list[dict[str, Any]]:
    """Query once, execute the result, and return what execution yields."""
    response = self.query()
    return self.execute_actions(response)
Expand Down
9 changes: 9 additions & 0 deletions src/webwright/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from webwright.cache.script_cache import CachedScript, ScriptCache, make_fingerprint

# Names re-exported from webwright.cache.script_cache as this package's public API.
__all__ = [
"CachedScript",
"ScriptCache",
"make_fingerprint",
]
160 changes: 160 additions & 0 deletions src/webwright/cache/script_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

import hashlib
import json
import shutil
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import httpx

from webwright.config import CacheConfig


@dataclass(frozen=True)
class CachedScript:
    """Immutable description of one cache entry on disk."""

    # Key the entry is stored under.
    fingerprint: str
    # Directory holding the entry's files.
    directory: Path
    # Cached copy of the final script.
    script_path: Path
    # Cached copy of the trajectory file.
    trajectory_path: Path
    # Contents of the entry's metadata file.
    metadata: dict[str, Any]


def _field_value(config: dict[str, Any], field: str) -> Any:
    """Look up one fingerprint field in a nested config mapping.

    Args:
        config: Nested configuration mapping.
        field: Dotted path such as ``"model.model_name"``. The bare names
            ``"task"`` and ``"start_url"`` are shorthand for ``"run.task"``
            and ``"run.start_url"``.

    Returns:
        The value found at the path, or ``""`` when a key is missing or an
        intermediate value is not a mapping (including a null ``run`` section).
    """
    # Route the shorthand names through the same guarded walk as every other
    # path, so a ``run`` key that is present but null cannot raise.
    path = f"run.{field}" if field in ("task", "start_url") else field

    value: Any = config
    for part in path.split("."):
        if not isinstance(value, dict):
            return ""
        value = value.get(part, "")
    return value


def make_fingerprint(config: dict[str, Any]) -> str:
    """Derive a short, deterministic cache key from selected config fields.

    The fields named by the cache config's ``fingerprint_fields`` are looked
    up in ``config``, serialised as compact JSON with sorted keys, and hashed
    with SHA-256.

    Args:
        config: Nested configuration mapping; its optional ``"cache"`` section
            supplies the ``CacheConfig`` settings.

    Returns:
        The first 16 hexadecimal characters of the digest.
    """
    # A ``cache`` key that is present but null must fall back to the
    # defaults rather than fail on ``**None``.
    cache_config = CacheConfig(**(config.get("cache") or {}))
    payload = {
        field: _field_value(config, field)
        for field in cache_config.fingerprint_fields
    }
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
    return hashlib.sha256(encoded).hexdigest()[:16]


class ScriptCache:
    """Directory-backed cache of final scripts and trajectories.

    Each entry lives in ``<cache directory>/<fingerprint>/`` and consists of
    ``final_script.py``, ``trajectory.json`` and ``metadata.json``.
    """

    def __init__(self, config: dict[str, Any] | CacheConfig | None = None):
        """Build the cache from a ``CacheConfig``, a plain mapping, or defaults."""
        if isinstance(config, CacheConfig):
            self.config = config
        else:
            self.config = CacheConfig(**(config or {}))
        self.directory = self.config.directory.expanduser()

    @property
    def enabled(self) -> bool:
        """Whether the cache is switched on in the configuration."""
        return self.config.enabled

    def _entry_dir(self, fingerprint: str) -> Path:
        """Return the directory for ``fingerprint``.

        Raises:
            ValueError: If ``fingerprint`` is empty, ``"."``, ``".."`` or
                contains a path separator. Such values would resolve to the
                cache root or to a location outside it, and ``invalidate``
                deletes whatever this method returns.
        """
        if fingerprint in ("", ".", "..") or Path(fingerprint).name != fingerprint:
            raise ValueError(f"Invalid cache fingerprint: {fingerprint!r}")
        return self.directory / fingerprint

    def get(self, fingerprint: str) -> CachedScript | None:
        """Return the entry for ``fingerprint``, or ``None`` if unusable.

        An entry is unusable when the cache is disabled, when any of its
        three files is missing, when its metadata cannot be read as a JSON
        object, or when it has expired. Unreadable and expired entries are
        removed from disk.
        """
        if not self.enabled:
            return None

        entry_dir = self._entry_dir(fingerprint)
        metadata_path = entry_dir / "metadata.json"
        script_path = entry_dir / "final_script.py"
        trajectory_path = entry_dir / "trajectory.json"
        if not metadata_path.is_file() or not script_path.is_file() or not trajectory_path.is_file():
            return None

        try:
            metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised for bytes that are not valid UTF-8.
            self.invalidate(fingerprint)
            return None

        # Valid JSON that is not an object (a list, a number, ...) is just as
        # corrupt as unparsable text; the expiry check needs a mapping.
        if not isinstance(metadata, dict) or self._is_expired(metadata):
            self.invalidate(fingerprint)
            return None

        return CachedScript(
            fingerprint=fingerprint,
            directory=entry_dir,
            script_path=script_path,
            trajectory_path=trajectory_path,
            metadata=metadata,
        )

    def put(
        self,
        fingerprint: str,
        final_script_path: str | Path,
        trajectory_path: str | Path,
        *,
        metadata: dict[str, Any] | None = None,
    ) -> CachedScript | None:
        """Copy a script and trajectory into the cache under ``fingerprint``.

        Args:
            fingerprint: Key to store the entry under.
            final_script_path: Script file to copy.
            trajectory_path: Trajectory file to copy.
            metadata: Extra JSON-serialisable values merged into the entry's
                metadata; the bookkeeping keys written here take precedence.

        Returns:
            The stored entry, or ``None`` when the cache is disabled or
            either source file does not exist.

        Raises:
            TypeError: If ``metadata`` is not JSON-serialisable. This is
                raised before anything is written to the cache.
        """
        if not self.enabled:
            return None

        source_script = Path(final_script_path).expanduser()
        source_trajectory = Path(trajectory_path).expanduser()
        if not source_script.is_file() or not source_trajectory.is_file():
            return None

        entry_dir = self._entry_dir(fingerprint)
        entry_metadata = {
            **(metadata or {}),
            "fingerprint": fingerprint,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "final_script_source": str(source_script),
            "trajectory_source": str(source_trajectory),
        }
        # Serialise before touching the disk so that bad metadata cannot
        # leave a half-written entry behind.
        serialized_metadata = json.dumps(entry_metadata, indent=2)

        entry_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = entry_dir / "metadata.json"
        script_path = entry_dir / "final_script.py"
        trajectory_copy_path = entry_dir / "trajectory.json"

        # get() requires metadata.json, so removing the old one first hides
        # the entry while it is being replaced; an interrupted update can
        # then never pair old metadata with new files.
        metadata_path.unlink(missing_ok=True)
        shutil.copy2(source_script, script_path)
        shutil.copy2(source_trajectory, trajectory_copy_path)
        metadata_path.write_text(serialized_metadata, encoding="utf-8")

        return CachedScript(
            fingerprint=fingerprint,
            directory=entry_dir,
            script_path=script_path,
            trajectory_path=trajectory_copy_path,
            metadata=entry_metadata,
        )

    def invalidate(self, fingerprint: str) -> None:
        """Delete the entry for ``fingerprint``; a missing entry is not an error."""
        shutil.rmtree(self._entry_dir(fingerprint), ignore_errors=True)

    def validate_url(self, start_url: str | None) -> bool:
        """Check that ``start_url`` still answers a HEAD request.

        Returns ``True`` without any network access when URL validation is
        turned off or no URL is given. Otherwise returns ``True`` only if the
        request, following redirects, ends with a status code below 400.
        """
        if not self.config.validate_url or not start_url:
            return True
        try:
            response = httpx.head(start_url, follow_redirects=True, timeout=10.0)
        except (httpx.HTTPError, httpx.InvalidURL):
            # InvalidURL is not a subclass of HTTPError, so a malformed URL
            # would otherwise escape instead of counting as "not reachable".
            return False
        return response.status_code < 400

    def _is_expired(self, metadata: dict[str, Any]) -> bool:
        """Tell whether an entry is older than the configured TTL.

        A TTL of zero or less disables expiry. A missing or unparsable
        ``created_at`` counts as expired. Timestamps without a timezone are
        interpreted as UTC.
        """
        if self.config.ttl_seconds <= 0:
            return False
        raw_created_at = metadata.get("created_at")
        if not isinstance(raw_created_at, str):
            return True
        try:
            created_at = datetime.fromisoformat(raw_created_at)
        except ValueError:
            return True
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=timezone.utc)
        age = datetime.now(timezone.utc) - created_at
        return age.total_seconds() > self.config.ttl_seconds
16 changes: 16 additions & 0 deletions src/webwright/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,29 @@
from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field
import yaml

from webwright import package_dir

builtin_config_dir = package_dir / "config"


# Config paths hashed into the cache fingerprint when none are configured.
_DEFAULT_FINGERPRINT_FIELDS = (
    "task",
    "start_url",
    "model.model_name",
    "environment.environment_class",
)


class CacheConfig(BaseModel):
    # Settings for the on-disk script cache; caching is off by default.
    enabled: bool = False
    # Root directory for cache entries.
    directory: Path = Path("~/.cache/webwright")
    # Entry lifetime in seconds; the default is seven days.
    ttl_seconds: int = 7 * 24 * 60 * 60
    validate_url: bool = True
    # A fresh list per instance, so instances never share the default.
    fingerprint_fields: list[str] = Field(
        default_factory=lambda: list(_DEFAULT_FINGERPRINT_FIELDS)
    )


def _nest_key_value(key: str, value: Any) -> dict[str, Any]:
parts = key.split(".")
nested: dict[str, Any] = value
Expand Down
11 changes: 11 additions & 0 deletions src/webwright/config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ run:
task_id:
start_url:

# Script cache settings; these values mirror the CacheConfig defaults.
cache:
enabled: false
directory: ~/.cache/webwright
# 604800 seconds = 7 days.
ttl_seconds: 604800
validate_url: true
# Dotted config paths hashed into the cache fingerprint;
# "task" and "start_url" are read from the "run" section.
fingerprint_fields:
- task
- start_url
- model.model_name
- environment.environment_class

agent:
agent_class: default
debug_log: true
Expand Down
Loading