diff --git a/AGENTS.md b/AGENTS.md index 331bb5f..dfccb21 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -91,7 +91,7 @@ System prompt is loaded by priority: `--system` flag > `~/.odek/IDENTITY.md` > c ### Security Architecture Layered prompt-injection / approval-fatigue defenses. Full reference: [docs/SECURITY.md](docs/SECURITY.md). -- **Untrusted-content wrapper** (`cmd/odek/untrusted.go`) — every tool whose output sources from outside the trust boundary (`browser`, `read_file`, `shell`, `search_files`, `multi_grep`, `transcribe`, any MCP tool) wraps results in ` source="...">…>`. Per-call nonce defeats wrapper-escape via literal close tag. +- **Untrusted-content wrapper** (`cmd/odek/untrusted.go`) — every tool whose output sources from outside the trust boundary (`browser`, `read_file`, `shell`, `search_files`, `multi_grep`, `transcribe`, `head_tail`, `diff`, `tr`, `sort`, `json_query`, `batch_patch`, `glob`, `file_info`, `tree`, `base64` file mode, `session_search`, `@-resources`, `--ctx` files, any MCP tool) wraps results in ` source="...">…>`. Browser page title and interactive-element text are wrapped in addition to the main content. Per-call nonce defeats wrapper-escape via literal close tag. - **Audit log** (`cmd/odek/audit.go` + `internal/session/audit.go`) — every `wrapUntrusted` call records source + content-hash + turn into `/audit/.json`. After each turn a divergence heuristic flags `suspicious_divergence=true` when the agent ingested untrusted content AND its tool calls referenced resources the user did not mention. Inspect with `odek audit ` / `odek audit --list`. - **Memory taint** (`internal/memory/provenance.go`) — `EpisodeProvenance` tracks Untrusted/Sources/UserApproved. Tainted episodes are stored but `Search()` filters them out, so a one-shot injection cannot persist via the episode pipeline. User must explicitly promote. - **Skill provenance gate** (`internal/skills/loader.go` + `cache.go`) — `Skill.Provenance{Untrusted, Sources, NeedsReview}`. 
NeedsReview skills pin to Lazy regardless of `auto_load`. `odek skill promote ` clears the flag after user review. @@ -100,6 +100,35 @@ Layered prompt-injection / approval-fatigue defenses. Full reference: [docs/SECU - **Approver friction** (`internal/danger/approver.go`, `cmd/odek/wsapprover.go`) — both TTYApprover and WSApprover engage friction mode after 3 approvals of the same class in 60s: require typing literal `approve`, 1.5s pause. Trust-class shortcut disabled for `destructive` + `blocked` regardless. - **Danger classifier bypass resistance** (`internal/danger/classifier.go`) — `normalize()` pre-processes: expand `$IFS` / `${IFS}`, extract `$(...)` / `` `...` `` substitutions, strip `command` / `exec` / `builtin` wrappers, collapse unquoted backslashes, basename absolute paths. Regression suite in `classifier_bypass_test.go`. - **WS Origin allowlist** (`cmd/odek/serve.go::checkLocalOrigin`) — rejects non-localhost upgrades. Closes CSRF-on-localhost. +- **REST API CSRF protection** (`cmd/odek/serve.go::requireLocalOrigin`) — state-changing HTTP endpoints (POST/PUT/PATCH/DELETE) require a localhost origin or no Origin header, and static responses set `X-Frame-Options: DENY` + `Content-Security-Policy: frame-ancestors 'none'` to block clickjacking. +- **Browser history cap** (`cmd/odek/browser_tool.go`) — navigation history is capped at 50 snapshots to prevent memory DoS from repeated `browser_navigate` calls. +- **Browser element cap** (`cmd/odek/browser_tool.go`) — the number of interactive elements extracted per page is capped at 500 so a hostile page cannot OOM the agent with thousands of links or buttons. +- **Search result bounds** (`cmd/odek/file_tool.go`, `cmd/odek/perf_tools.go`) — `search_files` and `multi_grep` enforce a max match limit (500) and a total returned-content cap (1 MiB) to avoid unbounded result JSON. 
+- **Perf-tool file-size cap** (`cmd/odek/perf_tools.go`) — `diff`, `base64`, `tr`, `sort`, `json_query`, and `batch_patch` reject files larger than 10 MiB to avoid loading multi-gigabyte files into memory. +- **Shell output cap** (`cmd/odek/shell.go`, `cmd/odek/perf_tools.go`) — `shell` and `parallel_shell` cap captured stdout/stderr at 1 MiB per stream to prevent memory DoS from commands that dump huge files. +- **Browser request timeout** (`cmd/odek/browser_tool.go`) — the browser HTTP client enforces a 30-second request timeout so a slow/malicious server cannot hang the agent turn. +- **Transcribe input/output guard** (`cmd/odek/transcribe_tool.go`) — rejects audio files larger than 10 MiB, caps whisper stdout at 10 MiB, and writes ffmpeg output to a temp file so it cannot clobber an existing `.wav` next to the source path. +- **Tree width cap** (`cmd/odek/perf_tools.go`) — the `tree` tool limits each directory listing to 1,000 entries to avoid OOM from directories with millions of files. +- **patch tool hardening** (`cmd/odek/file_tool.go`) — `patch` rejects files larger than 10 MiB and preserves the original file mode instead of resetting it to 0644. +- **glob tool hardening** (`cmd/odek/file_tool.go`) — `glob` caps results at 1,000 matches and wraps returned paths as untrusted content. +- **Sub-agent task-file cap** (`cmd/odek/subagent.go`) — `odek subagent --task <file>` rejects task files larger than 10 MiB before loading them into memory. +- **session_search hardening** (`cmd/odek/session_search_tool.go`) — the `get` action returns at most the 100 most recent messages and wraps each message content, task, and buffer entry as untrusted; `list`/`search`/`find` also wrap session tasks. +- **@-resource / --ctx prompt wrapping** (`cmd/odek/refs.go`, `cmd/odek/serve.go`) — content resolved from `@file` references and `--ctx` files is wrapped as untrusted before being inserted into the prompt.
+- **Config file size cap** (`internal/config/loader.go`) — `~/.odek/config.json` and `./odek.json` are rejected if larger than 5 MiB to prevent OOM from a malicious or broken config at startup. +- **Resource resolver size cap** (`internal/resource/resource.go`) — `@-resource` file loads are capped at 1 MiB to prevent OOM from `@hugefile` references. +- **Resource resolver symlink hardening** (`internal/resource/resource.go`) — `FileResolver.Search` uses `os.Lstat` (not `os.Stat`) for search-result metadata, so symlinks cannot leak the size of arbitrary targets outside the workspace. +- **Sub-agent summary cap** (`cmd/odek/subagent_tool.go`) — each sub-agent result included in the `delegate_tasks` summary is truncated to 100 KiB to prevent memory DoS. +- **Tree path wrapping** (`cmd/odek/perf_tools.go`) — the `tree` tool wraps every filesystem-derived path as untrusted content. +- **head_tail output cap** (`cmd/odek/perf_tools.go`) — `head_tail` truncates returned lines so total content stays within 1 MiB, preventing multi-file/multi-line memory DoS. +- **search_files symlink hardening** (`cmd/odek/file_tool.go`) — the `files` target uses `Lstat` (not `Stat`) and skips symlinks in the glob branch, closing metadata disclosure via symlinked paths. +- **AGENTS.md size cap** (`odek.go`) — project-level `AGENTS.md` is ignored if larger than 256 KiB to prevent OOM/prompt stuffing from a malicious repo. +- **IDENTITY.md size cap** (`cmd/odek/main.go`) — `~/.odek/IDENTITY.md` is ignored if larger than 256 KiB, falling back to the default identity. +- **patch / batch_patch output expansion cap** (`cmd/odek/file_tool.go`, `cmd/odek/perf_tools.go`) — the post-replacement result is capped at 10 MiB so `ReplaceAll` cannot explode memory. +- **write_file content cap** (`cmd/odek/file_tool.go`) — the `content` argument is capped at 1 MiB to prevent disk exhaustion and memory pressure from a single enormous tool call. 
+- **file_info confinement + wrapping** (`cmd/odek/file_tool.go`) — `file_info` respects the same `restrictToCWD` path confinement as `write_file`/`patch`, and the returned path is wrapped as untrusted content. +- **WebSocket message-size cap** (`cmd/odek/serve.go`) — `odek serve` sets `MaxPayloadBytes` on every WebSocket connection so a local client cannot OOM the server with a huge frame. +- **Session file size cap** (`internal/session/session.go`) — session files larger than 32 MiB are rejected by `Load()` to prevent OOM from tampered or corrupted transcripts. +- **Skill file size cap** (`internal/skills/loader.go`) — `SKILL.md` files larger than 1 MiB are skipped so a malicious project cannot OOM the process at startup or bloat the system prompt. - **Serve sandbox default-on** — `odek serve` enables `--sandbox` automatically unless `--no-sandbox` is passed. - **Secret redaction** (`internal/redact/redact.go`) — 20+ patterns: OpenAI, Anthropic, GitHub PAT, AWS, PEM, JWT, Vault, Google OAuth, SendGrid, Discord, DB URLs, etc. diff --git a/cmd/odek/browser_tool.go b/cmd/odek/browser_tool.go index 98ae972..9c25f31 100644 --- a/cmd/odek/browser_tool.go +++ b/cmd/odek/browser_tool.go @@ -9,6 +9,7 @@ import ( "regexp" "strings" "sync" + "time" "github.com/BackendStack21/odek/internal/danger" ) @@ -44,6 +45,15 @@ type browserSnapshot struct { Elements []clickableRef `json:"elements,omitempty"` } +// maxBrowserHistory caps the number of snapshots retained in browser state to +// prevent memory DoS from repeated navigate actions. +const maxBrowserHistory = 50 + +// maxBrowserElements caps the number of interactive elements extracted from a +// page to prevent a hostile page from OOMing the agent with thousands of links +// or buttons. +const maxBrowserElements = 500 + // browserState holds the shared state for one browser session. 
type browserState struct { mu sync.Mutex @@ -62,12 +72,17 @@ type browserTool struct { trustedClasses map[danger.RiskClass]bool } +// browserRequestTimeout bounds each browser HTTP request. Tests may lower it to +// verify timeout behavior. +var browserRequestTimeout = 30 * time.Second + func newBrowserTool(dc danger.DangerousConfig) *browserTool { t := &browserTool{ state: &browserState{nextRef: 1}, dangerousConfig: dc, } t.client = &http.Client{ + Timeout: browserRequestTimeout, CheckRedirect: t.checkRedirect, Transport: ssrfGuardedTransport(), } @@ -167,6 +182,7 @@ func (t *browserTool) Call(argsJSON string) (string, error) { } if t.client == nil { t.client = &http.Client{ + Timeout: browserRequestTimeout, CheckRedirect: t.checkRedirect, Transport: ssrfGuardedTransport(), } @@ -226,10 +242,15 @@ func (t *browserTool) doNavigate(rawURL string) (string, error) { html := string(body) snap := parseHTML(html, rawURL, resp.StatusCode) - // Store in state + // Store in state. Keep a persistent copy of the snapshot for current; the + // local variable's address would otherwise escape to the heap implicitly. 
t.state.mu.Lock() t.state.history = append(t.state.history, snap) - t.state.current = &snap + if len(t.state.history) > maxBrowserHistory { + t.state.history = t.state.history[len(t.state.history)-maxBrowserHistory:] + } + snapCopy := snap + t.state.current = &snapCopy t.state.nextRef = len(snap.Elements) + 1 t.state.mu.Unlock() @@ -363,6 +384,9 @@ func parseHTML(html, pageURL string, status int) browserSnapshot { // Extract links for _, m := range reLink.FindAllStringSubmatch(html, -1) { + if len(elements) >= maxBrowserElements { + break + } href := strings.TrimSpace(m[1]) text := strings.TrimSpace(m[2]) if href == "" || text == "" || href == "#" || strings.HasPrefix(href, "javascript:") { @@ -388,6 +412,9 @@ func parseHTML(html, pageURL string, status int) browserSnapshot { // Extract buttons and inputs for _, m := range reButton.FindAllStringSubmatch(html, -1) { + if len(elements) >= maxBrowserElements { + break + } text := strings.TrimSpace(m[1]) if text == "" { text = "button" @@ -403,6 +430,9 @@ func parseHTML(html, pageURL string, status int) browserSnapshot { } for _, m := range reInput.FindAllStringSubmatch(html, -1) { + if len(elements) >= maxBrowserElements { + break + } tag := m[0] text := "" if vm := reInputVal.FindStringSubmatch(tag); len(vm) > 1 { @@ -426,6 +456,12 @@ func parseHTML(html, pageURL string, status int) browserSnapshot { snap.Content = strings.Join(contentParts, "\n") snap.Elements = elements + // Title and element text come from the page — wrap them as untrusted content. 
+ snap.Title = wrapUntrusted(pageURL, snap.Title) + for i := range snap.Elements { + snap.Elements[i].Text = wrapUntrusted(pageURL, snap.Elements[i].Text) + } + return snap } diff --git a/cmd/odek/browser_tool_test.go b/cmd/odek/browser_tool_test.go index 8e47494..df3cba7 100644 --- a/cmd/odek/browser_tool_test.go +++ b/cmd/odek/browser_tool_test.go @@ -34,8 +34,8 @@ func TestBrowser_Navigate(t *testing.T) { if r.Error != "" { t.Fatalf("navigate error: %s", r.Error) } - if r.Title != "Test Page" { - t.Errorf("title = %q, want %q", r.Title, "Test Page") + if unwrapUntrusted(r.Title) != "Test Page" { + t.Errorf("title = %q, want %q", unwrapUntrusted(r.Title), "Test Page") } if !strings.Contains(r.Content, "Hello World") { t.Errorf("content missing 'Hello World': %q", r.Content) diff --git a/cmd/odek/file_tool.go b/cmd/odek/file_tool.go index 7b4568a..194d4f9 100644 --- a/cmd/odek/file_tool.go +++ b/cmd/odek/file_tool.go @@ -25,6 +25,22 @@ const maxLines = 2000 // memory exhaustion from huge files. const maxReadBytes = 1 << 20 // 1 MiB +// maxWriteFileContentBytes caps the content argument of write_file to prevent +// disk exhaustion and memory pressure from a single enormous tool call. +const maxWriteFileContentBytes = maxReadBytes // 1 MiB + +// maxSearchLimit caps the number of matches returned by search_files to +// prevent unbounded result JSON from exhausting memory. +const maxSearchLimit = 500 + +// maxSearchResultBytes caps the total returned content bytes for a single +// search_files / multi_grep content query. +const maxSearchResultBytes = maxReadBytes + +// maxGlobMatches caps the number of paths returned by the glob tool to prevent +// unbounded JSON responses from broad patterns. 
+const maxGlobMatches = 1000 + type readFileTool struct { dangerousConfig danger.DangerousConfig } @@ -203,6 +219,9 @@ func (t *writeFileTool) Call(argsJSON string) (string, error) { if args.Path == "" { return jsonError("path is required") } + if len(args.Content) > maxWriteFileContentBytes { + return jsonError(fmt.Sprintf("content too large (%d bytes, max %d)", len(args.Content), maxWriteFileContentBytes)) + } // Path confinement: when restrictToCWD is enabled, reject paths that // escape the working directory via ".." traversal or absolute paths. @@ -371,6 +390,9 @@ func (t *searchFilesTool) Call(argsJSON string) (string, error) { if args.Limit <= 0 { args.Limit = maxMatches } + if args.Limit > maxSearchLimit { + args.Limit = maxSearchLimit + } // Security: check search path risk := danger.ClassifyPath(args.Path) @@ -398,6 +420,7 @@ func (t *searchFilesTool) searchContent(args searchFilesArgs) (string, error) { var matches []searchMatch limit := args.Limit + resultBytes := 0 err = filepath.Walk(args.Path, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -455,10 +478,16 @@ func (t *searchFilesTool) searchContent(args searchFilesArgs) (string, error) { lineNum++ line := scanner.Text() if re.MatchString(line) { + trimmed := strings.TrimSpace(line) + if resultBytes+len(trimmed) > maxSearchResultBytes { + limit = len(matches) + break + } + resultBytes += len(trimmed) matches = append(matches, searchMatch{ Path: path, Line: lineNum, - Content: wrapUntrusted(fmt.Sprintf("%s:%d", path, lineNum), strings.TrimSpace(line)), + Content: wrapUntrusted(fmt.Sprintf("%s:%d", path, lineNum), trimmed), }) if len(matches) >= limit { break @@ -493,12 +522,18 @@ func (t *searchFilesTool) searchFiles(args searchFilesArgs) (string, error) { return jsonError(fmt.Sprintf("invalid glob %q: %v", pattern, err)) } for _, p := range globMatches { - info, err := os.Stat(p) - if err == nil && !info.IsDir() { - matches = append(matches, searchMatch{Path: p}) - if 
len(matches) >= limit { - break - } + // Lstat so symlinks are not followed to their targets for metadata. + info, err := os.Lstat(p) + if err != nil { + continue + } + // Skip directories and symlinks — same policy as the walk branch. + if info.IsDir() || info.Mode()&os.ModeSymlink != 0 { + continue + } + matches = append(matches, searchMatch{Path: wrapUntrusted("search_files:"+p, p)}) + if len(matches) >= limit { + break } } } else { @@ -522,7 +557,7 @@ func (t *searchFilesTool) searchFiles(args searchFilesArgs) (string, error) { } match, _ := filepath.Match(pattern, info.Name()) if match { - matches = append(matches, searchMatch{Path: path}) + matches = append(matches, searchMatch{Path: wrapUntrusted("search_files:"+path, path)}) if len(matches) >= limit { return filepath.SkipAll } @@ -531,12 +566,13 @@ func (t *searchFilesTool) searchFiles(args searchFilesArgs) (string, error) { }) } - // Sort by modification time (newest first) + // Sort by modification time (newest first). Use Lstat so symlinks are not + // followed and their own metadata is used for sorting. sort.Slice(matches, func(i, j int) bool { - fi, _ := os.Stat(matches[i].Path) - fj, _ := os.Stat(matches[j].Path) + fi, _ := os.Lstat(unwrapUntrusted(matches[i].Path)) + fj, _ := os.Lstat(unwrapUntrusted(matches[j].Path)) if fi == nil || fj == nil { - return matches[i].Path < matches[j].Path + return unwrapUntrusted(matches[i].Path) < unwrapUntrusted(matches[j].Path) } return fi.ModTime().After(fj.ModTime()) }) @@ -638,6 +674,16 @@ func (t *patchTool) Call(argsJSON string) (string, error) { } defer f.Close() + // Reject files that would exhaust memory during the read/edit/write cycle. 
+ info, err := f.Stat() + if err != nil { + return jsonError(fmt.Sprintf("cannot stat %q: %v", args.Path, err)) + } + if info.Size() > maxFileReadBytes { + return jsonError(fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), maxFileReadBytes)) + } + origMode := info.Mode().Perm() + // Read content through the opened fd (not re-opening the path) var sb strings.Builder _, err = io.Copy(&sb, f) @@ -657,6 +703,9 @@ func (t *patchTool) Call(argsJSON string) (string, error) { } else { modified = strings.Replace(original, args.OldString, args.NewString, 1) } + if len(modified) > maxFileReadBytes { + return jsonError(fmt.Sprintf("patch result too large (%d bytes, max %d)", len(modified), maxFileReadBytes)) + } // Generate a simple diff diff := fmt.Sprintf("--- a/%s\n+++ b/%s\n@@ -1 +1 @@\n-%s\n+%s\n", @@ -681,7 +730,7 @@ func (t *patchTool) Call(argsJSON string) (string, error) { os.Remove(tmpPath) return jsonError(fmt.Sprintf("cannot write %q: %v", args.Path, err)) } - if err := tmpFile.Chmod(0644); err != nil { + if err := tmpFile.Chmod(origMode); err != nil { tmpFile.Close() os.Remove(tmpPath) return jsonError(fmt.Sprintf("cannot set permissions %q: %v", args.Path, err)) @@ -1137,6 +1186,9 @@ func (t *globTool) Call(argsJSON string) (result string, err error) { if args.Limit <= 0 { args.Limit = maxMatches } + if args.Limit > maxGlobMatches { + args.Limit = maxGlobMatches + } // Security: classify search root path risk := danger.ClassifyPath(args.Path) @@ -1260,6 +1312,10 @@ func (t *globTool) Call(argsJSON string) (result string, err error) { return fi.ModTime().After(fj.ModTime()) }) + for i := range matches { + matches[i].Path = wrapUntrusted("glob:"+args.Path, matches[i].Path) + } + return jsonResult(globResult{Matches: matches}) } @@ -1272,6 +1328,7 @@ func (t *globTool) Call(argsJSON string) (result string, err error) { type fileInfoTool struct { dangerousConfig danger.DangerousConfig + restrictToCWD bool // when true, reject paths escaping the working 
directory } func (t *fileInfoTool) Name() string { return "file_info" } @@ -1326,6 +1383,16 @@ func (t *fileInfoTool) Call(argsJSON string) (result string, err error) { return jsonError("path is required") } + // Path confinement: when restrictToCWD is enabled, reject paths that + // escape the working directory via ".." traversal or absolute paths. + if t.restrictToCWD { + resolved, err := confineToCWD(args.Path) + if err != nil { + return jsonError(err.Error()) + } + args.Path = resolved + } + // Security: classify path risk := danger.ClassifyPath(args.Path) if err := t.dangerousConfig.CheckOperation(danger.ToolOperation{ @@ -1359,6 +1426,10 @@ func (t *fileInfoTool) Call(argsJSON string) (result string, err error) { IsRegular: lInfo.Mode().IsRegular(), } + // file_info output originates from the filesystem trust boundary, so + // mark the returned path as untrusted. + fi.Path = wrapUntrusted("file_info:"+args.Path, fi.Path) + return jsonResult(fi) } diff --git a/cmd/odek/file_tool_test.go b/cmd/odek/file_tool_test.go index 691d4f5..fc99d62 100644 --- a/cmd/odek/file_tool_test.go +++ b/cmd/odek/file_tool_test.go @@ -402,7 +402,7 @@ func TestSearchFiles_FindByName(t *testing.T) { } for _, m := range r.Matches { - name := filepath.Base(m.Path) + name := filepath.Base(unwrapUntrusted(m.Path)) if name != "main.go" && name != "main_test.go" { t.Errorf("unexpected match: %s", name) } @@ -939,8 +939,9 @@ func TestSearchFiles_GlobWithPathSeparator(t *testing.T) { if len(r.Matches) != 1 { t.Fatalf("expected 1 match for 'subdir/*.txt', got %d", len(r.Matches)) } - if !strings.HasSuffix(r.Matches[0].Path, "subdir/result.txt") && !strings.HasSuffix(r.Matches[0].Path, "subdir\\result.txt") { - t.Errorf("unexpected match path: %s", r.Matches[0].Path) + p := unwrapUntrusted(r.Matches[0].Path) + if !strings.HasSuffix(p, "subdir/result.txt") && !strings.HasSuffix(p, "subdir\\result.txt") { + t.Errorf("unexpected match path: %s", p) } } @@ -1491,8 +1492,8 @@ func 
TestSearchFiles_FilesTargetWithPathSeparator(t *testing.T) { if len(r.Matches) != 1 { t.Fatalf("expected 1 match for 'sub/*.txt', got %d", len(r.Matches)) } - if !strings.Contains(r.Matches[0].Path, "nested.txt") { - t.Errorf("expected nested.txt match, got: %s", r.Matches[0].Path) + if !strings.Contains(unwrapUntrusted(r.Matches[0].Path), "nested.txt") { + t.Errorf("expected nested.txt match, got: %s", unwrapUntrusted(r.Matches[0].Path)) } } @@ -1512,8 +1513,8 @@ func TestSearchFiles_FilesTargetHiddenDirSkipped(t *testing.T) { } mustUnmarshal(t, result, &r) for _, m := range r.Matches { - if strings.Contains(m.Path, ".hidden") { - t.Errorf("should not include hidden dir contents: %s", m.Path) + if strings.Contains(unwrapUntrusted(m.Path), ".hidden") { + t.Errorf("should not include hidden dir contents: %s", unwrapUntrusted(m.Path)) } } if len(r.Matches) != 1 { diff --git a/cmd/odek/main.go b/cmd/odek/main.go index 58cd73d..3bcd20b 100644 --- a/cmd/odek/main.go +++ b/cmd/odek/main.go @@ -158,6 +158,11 @@ func buildSystemPrompt(resolved config.ResolvedConfig) string { return base } +// maxIdentityFileBytes caps the size of ~/.odek/IDENTITY.md that will be +// loaded into the system prompt. A tampered or corrupted identity file could +// otherwise OOM the process or stuff every prompt. +const maxIdentityFileBytes = 256 * 1024 // 256 KiB + // loadIdentityFile reads ~/.odek/IDENTITY.md and returns its content. // Returns defaultSystem if the file does not exist or cannot be read. 
func loadIdentityFile() string { @@ -166,6 +171,14 @@ func loadIdentityFile() string { return defaultSystem } path := filepath.Join(home, ".odek", "IDENTITY.md") + info, err := os.Stat(path) + if err != nil { + return defaultSystem + } + if info.Size() > maxIdentityFileBytes { + fmt.Fprintf(os.Stderr, "odek: warning: IDENTITY.md is too large (%d bytes, max %d) — using default identity\n", info.Size(), maxIdentityFileBytes) + return defaultSystem + } data, err := os.ReadFile(path) if err != nil { return defaultSystem @@ -1140,7 +1153,7 @@ func builtinTools(dc danger.DangerousConfig, sm *skills.SkillManager, approver d &patchTool{dangerousConfig: dc, restrictToCWD: true}, &batchReadTool{dangerousConfig: dc}, &globTool{dangerousConfig: dc}, - &fileInfoTool{dangerousConfig: dc}, + &fileInfoTool{dangerousConfig: dc, restrictToCWD: true}, &batchPatchTool{dangerousConfig: dc, restrictToCWD: true}, ¶llelShellTool{dangerousConfig: dc, approver: approver}, newHTTPBatchTool(dc), diff --git a/cmd/odek/next_security_vulnerabilities_test.go b/cmd/odek/next_security_vulnerabilities_test.go new file mode 100644 index 0000000..ef13bd7 --- /dev/null +++ b/cmd/odek/next_security_vulnerabilities_test.go @@ -0,0 +1,1140 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/BackendStack21/odek/internal/config" + "github.com/BackendStack21/odek/internal/danger" + "github.com/BackendStack21/odek/internal/llm" + "github.com/BackendStack21/odek/internal/resource" + "github.com/BackendStack21/odek/internal/session" + "github.com/BackendStack21/odek/internal/skills" +) + +// ── 1. 
Browser history must be capped to avoid memory DoS ──────────────── + +func TestBrowser_HistoryCap(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "page") + })) + defer srv.Close() + + tool := newBrowserTool(danger.DangerousConfig{}) + for i := 0; i < 55; i++ { + callJSON(t, tool, fmt.Sprintf(`{"action":"navigate","url":%q}`, srv.URL)) + } + + if len(tool.state.history) > 50 { + t.Fatalf("browser history grew unbounded: got %d snapshots (max expected 50)", len(tool.state.history)) + } +} + +// ── 2. search_files / multi_grep must cap limit and result size ────────── + +func TestSearchFiles_LimitCap(t *testing.T) { + dir := t.TempDir() + var lines []string + for i := 0; i < 600; i++ { + lines = append(lines, "match") + } + os.WriteFile(filepath.Join(dir, "data.txt"), []byte(strings.Join(lines, "\n")), 0644) + + tool := &searchFilesTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"pattern":"match","path":%q,"limit":10000}`, dir)) + var r struct { + Matches []any `json:"matches"` + } + mustUnmarshal(t, result, &r) + if len(r.Matches) > 500 { + t.Fatalf("search_files limit was not capped: got %d matches", len(r.Matches)) + } +} + +func TestSearchFiles_ResultByteCap(t *testing.T) { + dir := t.TempDir() + line := strings.Repeat("x", 500*1024) + " MATCH" + os.WriteFile(filepath.Join(dir, "big.txt"), []byte(line+"\n"+line+"\n"+line+"\n"), 0644) + + tool := &searchFilesTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"pattern":"MATCH","path":%q,"limit":10}`, dir)) + var r struct { + Matches []struct { + Content string `json:"content"` + } `json:"matches"` + } + mustUnmarshal(t, result, &r) + + total := 0 + for _, m := range r.Matches { + total += len(unwrapUntrusted(m.Content)) + } + if total > 1024*1024 { + t.Fatalf("search_files returned %d bytes of content, expected cap near 1 MiB", total) + } +} + +func TestMultiGrep_LimitCap(t *testing.T) { + dir := t.TempDir() + var lines []string + for i 
:= 0; i < 600; i++ { + lines = append(lines, "match") + } + os.WriteFile(filepath.Join(dir, "data.txt"), []byte(strings.Join(lines, "\n")), 0644) + + tool := &multiGrepTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"patterns":["match"],"path":%q,"limit":10000}`, dir)) + var r struct { + Results []struct { + Matches []any `json:"matches"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) != 1 { + t.Fatalf("expected 1 pattern result, got %d", len(r.Results)) + } + if len(r.Results[0].Matches) > 500 { + t.Fatalf("multi_grep limit was not capped: got %d matches", len(r.Results[0].Matches)) + } +} + +// ── 3. perf tools must reject (not load) huge files ────────────────────── + +func TestBase64_RejectsHugeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "huge.bin") + os.WriteFile(path, make([]byte, 15*1024*1024), 0644) + + tool := &base64Tool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q}`, path)) + var r struct { + Encoded string `json:"encoded,omitempty"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Encoded != "" { + t.Fatalf("base64 should reject a 15 MiB file, but returned encoded data") + } + if r.Error == "" { + t.Fatalf("base64 should return an error for a 15 MiB file") + } +} + +func TestDiff_RejectsHugeFile(t *testing.T) { + dir := t.TempDir() + pathA := filepath.Join(dir, "a.txt") + pathB := filepath.Join(dir, "b.txt") + os.WriteFile(pathA, make([]byte, 15*1024*1024), 0644) + os.WriteFile(pathB, []byte("small"), 0644) + + tool := &diffTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"path_a":%q,"path_b":%q}`, pathA, pathB)) + var r struct { + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Error == "" { + t.Fatalf("diff should return an error for a 15 MiB file") + } +} + +func TestJsonQuery_RejectsHugeFile(t 
*testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "big.json") + // Build a ~15 MB JSON object without newlines so it looks like one value. + big := `{"x":"` + strings.Repeat("a", 15*1024*1024) + `"}` + os.WriteFile(path, []byte(big), 0644) + + tool := &jsonQueryTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"query":"x"}`, path)) + var r struct { + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Error == "" { + t.Fatalf("json_query should return an error for a 15 MiB file") + } +} + +// ── 4. serve state-changing endpoints must require a local origin ──────── + +func TestServe_CSRF_RejectForeignOrigin(t *testing.T) { + base := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := requireLocalOrigin(base) + + req := httptest.NewRequest(http.MethodPost, "/api/cancel", nil) + req.Header.Set("Origin", "http://evil.example.com") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusForbidden { + t.Fatalf("foreign origin POST should be rejected (403), got %d", rr.Code) + } +} + +func TestServe_CSRF_AllowsEmptyOrigin(t *testing.T) { + base := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + handler := requireLocalOrigin(base) + + req := httptest.NewRequest(http.MethodPost, "/api/cancel", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusNoContent { + t.Fatalf("empty-origin POST should be allowed, got %d", rr.Code) + } +} + +func TestServe_CSRF_AllowsLocalhostOrigin(t *testing.T) { + base := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + handler := requireLocalOrigin(base) + + for _, origin := range []string{"http://localhost:8080", "http://127.0.0.1:8080"} { + req := httptest.NewRequest(http.MethodPost, "/api/cancel", nil) 
+ req.Header.Set("Origin", origin) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusNoContent { + t.Fatalf("localhost origin %q should be allowed, got %d", origin, rr.Code) + } + } +} + +func TestServe_StaticSecurityHeaders(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handleStatic().ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("static handler returned %d", rr.Code) + } + if rr.Header().Get("X-Frame-Options") == "" { + t.Error("static handler missing X-Frame-Options") + } + if rr.Header().Get("Content-Security-Policy") == "" { + t.Error("static handler missing Content-Security-Policy") + } +} + +// ── 5. file-reading perf tools must wrap content as untrusted ──────────── + +func TestHeadTail_WrapsContent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + os.WriteFile(path, []byte("hello world\n"), 0644) + + tool := &headTailTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"files":[{"path":%q}],"lines":10}`, path)) + var r struct { + Results []struct { + Lines []string `json:"lines"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) == 0 || len(r.Results[0].Lines) == 0 { + t.Fatal("expected at least one line") + } + if !strings.HasPrefix(r.Results[0].Lines[0], " 1024*1024+200 { + t.Fatalf("shell returned %d bytes, expected cap near 1 MiB", len(body)) + } +} + +func TestParallelShell_CapsOutputSize(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "huge.txt") + os.WriteFile(path, []byte(strings.Repeat("x", 15*1024*1024)), 0644) + + tool := ¶llelShellTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"commands":[{"command":"cat %s"}]}`, path)) + var r struct { + Results []struct { + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + Error string `json:"error,omitempty"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if 
len(r.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(r.Results)) + } + out := r.Results[0].Stdout + r.Results[0].Stderr + if len(out) > 1024*1024+200 { + t.Fatalf("parallel_shell returned %d bytes, expected cap near 1 MiB", len(out)) + } +} + +// ── 7. Browser must enforce an HTTP request timeout ────────────────────── + +func TestBrowser_NavigateTimeout(t *testing.T) { + orig := browserRequestTimeout + browserRequestTimeout = 100 * time.Millisecond + defer func() { browserRequestTimeout = orig }() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(300 * time.Millisecond) + fmt.Fprint(w, "page") + })) + defer srv.Close() + + tool := newBrowserTool(danger.DangerousConfig{}) + result := callJSON(t, tool, fmt.Sprintf(`{"action":"navigate","url":%q}`, srv.URL)) + var r struct { + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Error == "" || !strings.Contains(strings.ToLower(r.Error), "timeout") { + t.Fatalf("browser should time out on a slow server, got: %q", r.Error) + } +} + +// ── 8. 
batch_patch must reject huge files and wrap diff output ─────────── + +func TestBatchPatch_RejectsHugeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "huge.txt") + os.WriteFile(path, []byte(strings.Repeat("x", 15*1024*1024)), 0644) + + tool := &batchPatchTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"patches":[{"path":%q,"old_string":"xxx","new_string":"yyy"}]}`, path)) + var r struct { + Results []struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(r.Results)) + } + if r.Results[0].Success { + t.Fatal("batch_patch should reject a 15 MiB file") + } + if r.Results[0].Error == "" { + t.Fatal("batch_patch should return an error for a 15 MiB file") + } +} + +func TestBatchPatch_WrapsDiff(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + os.WriteFile(path, []byte("hello world\n"), 0644) + + tool := &batchPatchTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"patches":[{"path":%q,"old_string":"hello","new_string":"goodbye"}]}`, path)) + var r struct { + Results []struct { + Diff string `json:"diff"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) == 0 || !strings.HasPrefix(r.Results[0].Diff, " 1000 { + t.Fatalf("tree did not cap directory width: got %d children", len(r.Tree.Children)) + } +} + + +// ── 11. 
patch must reject huge files and preserve original permissions ─── + +func TestPatch_RejectsHugeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "huge.txt") + os.WriteFile(path, []byte(strings.Repeat("x", 15*1024*1024)), 0644) + + tool := &patchTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"old_string":"xxx","new_string":"yyy"}`, path)) + var r struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Success { + t.Fatal("patch should reject a 15 MiB file") + } + if !strings.Contains(r.Error, "too large") { + t.Fatalf("patch should reject huge file with a size error, got: %q", r.Error) + } +} + +func TestPatch_PreservesFileMode(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "script.sh") + os.WriteFile(path, []byte("#!/bin/sh\necho hello\n"), 0755) + + tool := &patchTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"old_string":"hello","new_string":"world"}`, path)) + var r struct { + Success bool `json:"success"` + } + mustUnmarshal(t, result, &r) + if !r.Success { + t.Fatal("patch failed") + } + + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.Mode().Perm() != 0755 { + t.Fatalf("patch changed mode from 0755 to %04o", info.Mode().Perm()) + } +} + +// ── 12. 
glob must cap match count and wrap paths as untrusted ──────────── + +func TestGlob_CapsMatchCount(t *testing.T) { + dir := t.TempDir() + for i := 0; i < 1500; i++ { + os.WriteFile(filepath.Join(dir, fmt.Sprintf("file%d.txt", i)), []byte("x"), 0644) + } + + tool := &globTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"pattern":"*","path":%q,"limit":10000}`, dir)) + var r struct { + Matches []struct { + Path string `json:"path"` + } `json:"matches"` + } + mustUnmarshal(t, result, &r) + if len(r.Matches) > 1000 { + t.Fatalf("glob did not cap match count: got %d", len(r.Matches)) + } + if len(r.Matches) == 0 { + t.Fatal("expected at least one match") + } + if !strings.HasPrefix(r.Matches[0].Path, " 100 { + t.Fatalf("session_search get did not cap messages: got %d", len(r.SessionMessages)) + } + if len(r.SessionMessages) == 0 { + t.Fatal("expected at least one message") + } + if !strings.HasPrefix(r.SessionMessages[0].Content, " 1024*1024+500 { + t.Fatalf("delegate_tasks summary returned %d bytes, expected cap near 1 MiB", len(result)) + } +} + +// ── 20. patch / batch_patch must cap ReplaceAll expansion ──────────────── + +func TestPatch_RejectsOutputExpansion(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data.txt") + // 2,000 'a' chars. Replacing each with 10,000 'x' => ~20M chars. 
+ os.WriteFile(path, []byte(strings.Repeat("a", 2000)), 0644) + + tool := &patchTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"old_string":"a","new_string":%q,"replace_all":true}`, path, strings.Repeat("x", 10000))) + var r struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Success { + t.Fatal("patch should reject a ReplaceAll that explodes output size") + } + if !strings.Contains(r.Error, "too large") { + t.Fatalf("expected size error, got: %q", r.Error) + } +} + +func TestBatchPatch_RejectsOutputExpansion(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data.txt") + os.WriteFile(path, []byte(strings.Repeat("a", 2000)), 0644) + + tool := &batchPatchTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"patches":[{"path":%q,"old_string":"a","new_string":%q,"replace_all":true}]}`, path, strings.Repeat("x", 10000))) + var r struct { + Results []struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) != 1 || r.Results[0].Success { + t.Fatal("batch_patch should reject a ReplaceAll that explodes output size") + } + if !strings.Contains(r.Results[0].Error, "too large") { + t.Fatalf("expected size error, got: %q", r.Results[0].Error) + } +} + +// ── 21. 
write_file must cap content size to prevent DoS / disk exhaustion ─ + +func TestWriteFile_CapsContentSize(t *testing.T) { + path := filepath.Join(t.TempDir(), "out.txt") + huge := strings.Repeat("x", maxWriteFileContentBytes+1) + + tool := &writeFileTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"content":%q}`, path, huge)) + var r struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Success { + t.Fatal("write_file should reject content above maxWriteFileContentBytes") + } + if !strings.Contains(r.Error, "too large") { + t.Fatalf("expected size error, got: %q", r.Error) + } +} + +// ── 22. file_info must respect restrictToCWD and wrap its output ────────── + +func TestFileInfo_RestrictToCWD(t *testing.T) { + tool := &fileInfoTool{restrictToCWD: true} + result := callJSON(t, tool, `{"path":"/etc/passwd"}`) + var r struct { + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Error == "" { + t.Fatal("file_info with restrictToCWD=true should reject paths outside CWD") + } +} + +func TestFileInfo_WrapsPath(t *testing.T) { + t.Chdir(t.TempDir()) + os.WriteFile("target.txt", []byte("hello"), 0644) + + tool := &fileInfoTool{restrictToCWD: true} + result := callJSON(t, tool, `{"path":"target.txt"}`) + var r struct { + Path string `json:"path"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Error != "" { + t.Fatalf("unexpected error: %s", r.Error) + } + if !strings.HasPrefix(r.Path, " 0") + } + if !strings.HasPrefix(r.Encoded, "Evil Titleclick me`) + })) + defer srv.Close() + + tool := newBrowserTool(danger.DangerousConfig{}) + result := callJSON(t, tool, fmt.Sprintf(`{"action":"navigate","url":%q}`, srv.URL)) + var r struct { + Title string `json:"title"` + Elements []struct { + Text string `json:"text"` + URL string `json:"url"` + } `json:"elements"` + } + mustUnmarshal(t, result, &r) + if !strings.HasPrefix(r.Title, "")) + 
for i := 0; i < 1500; i++ { + fmt.Fprintf(w, `link %d`, i, i) + } + w.Write([]byte("")) + })) + defer srv.Close() + + tool := newBrowserTool(danger.DangerousConfig{}) + result := callJSON(t, tool, fmt.Sprintf(`{"action":"navigate","url":%q}`, srv.URL)) + var r struct { + Elements []any `json:"elements"` + } + mustUnmarshal(t, result, &r) + if len(r.Elements) > 1000 { + t.Fatalf("browser did not cap element count: got %d", len(r.Elements)) + } +} + +// ── 25. tree must wrap filesystem-derived paths as untrusted ────────────── + +func TestTree_WrapsPaths(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "note.txt"), []byte("hello"), 0644) + + tool := &treeTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"max_depth":1}`, dir)) + var r struct { + Tree struct { + Path string `json:"path"` + Children []struct { + Path string `json:"path"` + } `json:"children"` + } `json:"tree"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if r.Error != "" { + t.Fatalf("tree error: %s", r.Error) + } + if !strings.HasPrefix(r.Tree.Path, " 2 MB of content, exceeding the 1 MiB cap. 
+ var lines []string + for i := 0; i < 10; i++ { + lines = append(lines, fmt.Sprintf("line-%d-%s", i, strings.Repeat("x", 200*1024))) + } + os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) + + tool := &headTailTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"files":[{"path":%q}],"lines":100}`, path)) + var r struct { + Results []struct { + Lines []string `json:"lines"` + Total int `json:"total"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(r.Results)) + } + + total := 0 + for _, line := range r.Results[0].Lines { + total += len(unwrapUntrusted(line)) + } + if total > maxHeadTailTotalBytes+200 { + t.Fatalf("head_tail returned %d bytes of content, expected cap near %d", total, maxHeadTailTotalBytes) + } +} + +// TestHeadTail_CapsOutputSizeMultiFile locks the aggregate bound: the per-file +// cap (maxHeadTailTotalBytes) combined with the 10-file-per-call limit means a +// single head_tail response stays within ~10 files × the per-file cap, even +// when every file is individually oversized. 
+func TestHeadTail_CapsOutputSizeMultiFile(t *testing.T) { + dir := t.TempDir() + const nFiles = 10 + var paths []string + for f := 0; f < nFiles; f++ { + path := filepath.Join(dir, fmt.Sprintf("big-%d.txt", f)) + var lines []string + for i := 0; i < 10; i++ { + lines = append(lines, fmt.Sprintf("line-%d-%s", i, strings.Repeat("x", 200*1024))) + } + os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) + paths = append(paths, path) + } + + var fileArgs []string + for _, p := range paths { + fileArgs = append(fileArgs, fmt.Sprintf("{\"path\":%q}", p)) + } + tool := &headTailTool{dangerousConfig: danger.DangerousConfig{}} + result := callJSON(t, tool, fmt.Sprintf(`{"files":[%s],"lines":100}`, strings.Join(fileArgs, ","))) + var r struct { + Results []struct { + Lines []string `json:"lines"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) != nFiles { + t.Fatalf("expected %d results, got %d", nFiles, len(r.Results)) + } + + total := 0 + for _, res := range r.Results { + fileTotal := 0 + for _, line := range res.Lines { + fileTotal += len(unwrapUntrusted(line)) + } + if fileTotal > maxHeadTailTotalBytes+200 { + t.Fatalf("per-file content %d bytes exceeds per-file cap %d", fileTotal, maxHeadTailTotalBytes) + } + total += fileTotal + } + if total > nFiles*(maxHeadTailTotalBytes+200) { + t.Fatalf("aggregate head_tail content %d bytes exceeds bound %d", total, nFiles*maxHeadTailTotalBytes) + } +} + +// ── 27. search_files target=files must not follow symlinks for metadata ─── + +func TestSearchFiles_TargetFiles_NoSymlinkFollow(t *testing.T) { + dir := t.TempDir() + sub := filepath.Join(dir, "sub") + os.MkdirAll(sub, 0755) + // Regular file + os.WriteFile(filepath.Join(sub, "real.txt"), []byte("hello"), 0644) + // Symlink to a non-existent target — old os.Stat would skip it; Lstat lets us detect and skip it ourselves. 
+ os.Symlink("/nonexistent/odek-test", filepath.Join(sub, "link.txt")) + + tool := &searchFilesTool{dangerousConfig: danger.DangerousConfig{}} + // Pattern with a separator forces the filepath.Glob branch. + result := callJSON(t, tool, fmt.Sprintf(`{"pattern":"**/*.txt","path":%q,"target":"files"}`, dir)) + var r struct { + Matches []struct { + Path string `json:"path"` + } `json:"matches"` + } + mustUnmarshal(t, result, &r) + if len(r.Matches) != 1 { + t.Fatalf("expected 1 regular file match, got %d", len(r.Matches)) + } + if !strings.Contains(r.Matches[0].Path, "real.txt") { + t.Fatalf("expected real.txt match, got: %q", r.Matches[0].Path) + } + if !strings.HasPrefix(r.Matches[0].Path, " maxFileReadBytes { + return nil, fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), maxFileReadBytes) + } + + return io.ReadAll(io.LimitReader(f, maxFileReadBytes+1)) } // ═════════════════════════════════════════════════════════════════════════ @@ -158,6 +178,20 @@ func (t *batchPatchTool) Call(argsJSON string) (result string, err error) { continue } + info, err := f.Stat() + if err != nil { + f.Close() + entry.Error = fmt.Sprintf("cannot stat %q: %v", p.Path, err) + results[idx] = entry + continue + } + if info.Size() > maxFileReadBytes { + f.Close() + entry.Error = fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), maxFileReadBytes) + results[idx] = entry + continue + } + var sb strings.Builder _, err = io.Copy(&sb, f) f.Close() @@ -180,6 +214,11 @@ func (t *batchPatchTool) Call(argsJSON string) (result string, err error) { } else { modified = strings.Replace(original, p.OldString, p.NewString, 1) } + if len(modified) > maxFileReadBytes { + entry.Error = fmt.Sprintf("patch result too large (%d bytes, max %d)", len(modified), maxFileReadBytes) + results[idx] = entry + continue + } diff := fmt.Sprintf("--- a/%s\n+++ b/%s\n@@ -1 +1 @@\n-%s\n+%s\n", p.Path, p.Path, truncateDiff(original, 100), truncateDiff(modified, 100)) @@ -228,7 +267,7 @@ func (t 
*batchPatchTool) Call(argsJSON string) (result string, err error) { } entry.Success = true - entry.Diff = diff + entry.Diff = wrapUntrusted("batch_patch:"+p.Path, diff) results[idx] = entry } @@ -365,9 +404,11 @@ func (t *parallelShellTool) runOne(cmd parallelShellCmd) parallelShellEntry { } else { shCmd = exec.Command("sh", "-c", cmd.Command) } - var stdout, stderr strings.Builder - shCmd.Stdout = &stdout - shCmd.Stderr = &stderr + var stdout, stderr bytes.Buffer + outW := &limitWriter{buf: &stdout, limit: maxShellOutputBytes} + errW := &limitWriter{buf: &stderr, limit: maxShellOutputBytes} + shCmd.Stdout = outW + shCmd.Stderr = errW // Kill on timeout via goroutine, with mutex to avoid Process race var procMu sync.Mutex @@ -787,6 +828,12 @@ func (t *diffTool) Call(argsJSON string) (result string, err error) { linesB = strings.Split(string(data), "\n") } else if args.Path != "" { pathA, pathB = args.Path, "" + if len(args.Content) > maxFileReadBytes { + return jsonResult(diffResult{ + Error: fmt.Sprintf("inline content too large (%d bytes, max %d)", len(args.Content), maxFileReadBytes), + PathA: pathA, PathB: pathB, + }) + } if err := t.dangerousConfig.CheckOperation(danger.ToolOperation{ Name: "diff", Resource: args.Path, Risk: danger.ClassifyPath(args.Path), }, nil); err != nil { @@ -822,6 +869,12 @@ func (t *diffTool) Call(argsJSON string) (result string, err error) { } hunks := computeDiff(linesA, linesB) + src := fmt.Sprintf("diff:%s|%s", pathA, pathB) + for i := range hunks { + for j := range hunks[i].Lines { + hunks[i].Lines[j].Content = wrapUntrusted(src, hunks[i].Lines[j].Content) + } + } return jsonResult(diffResult{Hunks: hunks, PathA: pathA, PathB: pathB}) } @@ -1099,6 +1152,9 @@ func (t *multiGrepTool) Call(argsJSON string) (string, error) { if args.Limit <= 0 { args.Limit = 50 } + if args.Limit > maxSearchLimit { + args.Limit = maxSearchLimit + } if err := t.dangerousConfig.CheckOperation(danger.ToolOperation{ Name: "multi_grep", Resource: args.Path, 
Risk: danger.ClassifyPath(args.Path), @@ -1129,6 +1185,7 @@ func (t *multiGrepTool) searchPattern(pattern, root, fileGlob string, limit int) } var matches []grepMatch + resultBytes := 0 filepath.Walk(root, func(path string, info os.FileInfo, err error) error { if err != nil || info == nil { @@ -1174,10 +1231,15 @@ func (t *multiGrepTool) searchPattern(pattern, root, fileGlob string, limit int) lineNum++ line := scanner.Text() if re.MatchString(line) { + trimmed := strings.TrimSpace(line) + if resultBytes+len(trimmed) > maxSearchResultBytes { + return filepath.SkipAll + } + resultBytes += len(trimmed) matches = append(matches, grepMatch{ Path: path, Line: lineNum, - Content: wrapUntrusted(fmt.Sprintf("%s:%d", path, lineNum), strings.TrimSpace(line)), + Content: wrapUntrusted(fmt.Sprintf("%s:%d", path, lineNum), trimmed), }) if len(matches) >= limit { return filepath.SkipAll @@ -1261,6 +1323,14 @@ func (t *jsonQueryTool) Call(argsJSON string) (result string, err error) { } defer f.Close() + info, err := f.Stat() + if err != nil { + return jsonResult(jsonQueryResult{Path: args.Path, Error: fmt.Sprintf("cannot stat %q: %v", args.Path, err)}) + } + if info.Size() > maxFileReadBytes { + return jsonResult(jsonQueryResult{Path: args.Path, Error: fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), maxFileReadBytes)}) + } + var data interface{} if err := json.NewDecoder(f).Decode(&data); err != nil { return jsonResult(jsonQueryResult{Path: args.Path, Error: fmt.Sprintf("invalid JSON: %v", err)}) @@ -1268,7 +1338,7 @@ func (t *jsonQueryTool) Call(argsJSON string) (result string, err error) { if args.Query == "" { vt := fmt.Sprintf("%T", data) - return jsonResult(jsonQueryResult{Path: args.Path, Query: "", Value: data, ValueType: vt}) + return jsonResult(jsonQueryResult{Path: args.Path, Query: "", Value: wrapJSONStrings(args.Path, data), ValueType: vt}) } value, err := jsonPathQuery(data, args.Query) @@ -1277,7 +1347,27 @@ func (t *jsonQueryTool) Call(argsJSON 
string) (result string, err error) { } vt := fmt.Sprintf("%T", value) - return jsonResult(jsonQueryResult{Path: args.Path, Query: args.Query, Value: value, ValueType: vt}) + return jsonResult(jsonQueryResult{Path: args.Path, Query: args.Query, Value: wrapJSONStrings(args.Path, value), ValueType: vt}) +} + +// wrapJSONStrings recursively wraps string values inside decoded JSON so that +// file content returned by json_query is treated as untrusted. +func wrapJSONStrings(source string, v interface{}) interface{} { + switch x := v.(type) { + case string: + return wrapUntrusted(source, x) + case map[string]interface{}: + for k, val := range x { + x[k] = wrapJSONStrings(source, val) + } + return x + case []interface{}: + for i, val := range x { + x[i] = wrapJSONStrings(source, val) + } + return x + } + return v } func jsonPathQuery(data interface{}, query string) (interface{}, error) { @@ -1423,7 +1513,7 @@ func (t *treeTool) Call(argsJSON string) (result string, err error) { func buildTree(root, path string, depth, maxDepth int, includeHidden bool) (treeEntry, error) { info, err := os.Lstat(path) if err != nil { - return treeEntry{Path: path, ErrMsg: err.Error()}, nil + return treeEntry{Path: wrapUntrusted("tree:"+root, path), ErrMsg: err.Error()}, nil } entry := treeEntry{ @@ -1436,6 +1526,10 @@ func buildTree(root, path string, depth, maxDepth int, includeHidden bool) (tree entry.Path = path } + // Tree paths come from the filesystem trust boundary, so mark them as + // untrusted before returning them to the model. 
+ entry.Path = wrapUntrusted("tree:"+root, entry.Path) + if !info.IsDir() || depth >= maxDepth { if !info.IsDir() { entry.FileCount = 1 @@ -1449,10 +1543,21 @@ func buildTree(root, path string, depth, maxDepth int, includeHidden bool) (tree return entry, nil } + totalEntries := len(entries) + truncated := false + if totalEntries > maxTreeEntries { + entries = entries[:maxTreeEntries] + truncated = true + } + sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + if truncated { + entry.ErrMsg = fmt.Sprintf("directory truncated (%d entries shown, %d total)", maxTreeEntries, totalEntries) + } + entry.Children = make([]treeEntry, 0, len(entries)) for _, e := range entries { if !includeHidden && strings.HasPrefix(e.Name(), ".") { @@ -1693,6 +1798,17 @@ func (t *sortTool) Call(argsJSON string) (result string, err error) { results = append(results, sortEntry{File: p, Error: fmt.Sprintf("cannot open %q: %v", p, err)}) continue } + info, err := f.Stat() + if err != nil { + f.Close() + results = append(results, sortEntry{File: p, Error: fmt.Sprintf("cannot stat %q: %v", p, err)}) + continue + } + if info.Size() > maxFileReadBytes { + f.Close() + results = append(results, sortEntry{File: p, Error: fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), maxFileReadBytes)}) + continue + } data, err := io.ReadAll(f) f.Close() if err != nil { @@ -1759,7 +1875,7 @@ func (t *sortTool) Call(argsJSON string) (result string, err error) { output := strings.Join(allLines, "\n") return jsonResult(sortResult{ Results: results, - Output: output, + Output: wrapUntrusted("sort:"+strings.Join(paths, ","), output), Total: len(allLines), }) } @@ -1768,6 +1884,12 @@ func (t *sortTool) Call(argsJSON string) (result string, err error) { // 12. 
head_tail — Quick file preview (first/last N lines) // ═════════════════════════════════════════════════════════════════════════ +// maxHeadTailTotalBytes caps the content returned by head_tail for a single +// file. Combined with the 10-file-per-call limit (see Call), this bounds a +// head_tail response to ~10 MiB. Without it, 10 files × 100 lines × 1 MiB +// lines could allocate roughly 1 GB in a single tool call. +const maxHeadTailTotalBytes = maxReadBytes // 1 MiB per file + type headTailTool struct { dangerousConfig danger.DangerousConfig } @@ -1876,14 +1998,19 @@ func (t *headTailTool) readPreview(path string, n int, mode string) (result head func (t *headTailTool) readHead(f *os.File, path string, n int) headTailFileResult { scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 1024*1024), 1024*1024) - var lines []string + var rawLines []string total := 0 for scanner.Scan() { total++ - if len(lines) < n { - lines = append(lines, scanner.Text()) + if len(rawLines) < n { + rawLines = append(rawLines, scanner.Text()) } } + rawLines = truncateHeadTailLines(rawLines) + lines := make([]string, len(rawLines)) + for i, l := range rawLines { + lines[i] = wrapUntrusted(path, l) + } return headTailFileResult{Path: path, Lines: lines, Count: len(lines), Total: total} } @@ -1900,17 +2027,39 @@ func (t *headTailTool) readTail(f *os.File, path string, n int) headTailFileResu total++ } // Extract in correct order - var lines []string + var rawLines []string start := 0 if written >= n { start = written % n } for i := 0; i < n && i < written; i++ { - lines = append(lines, buf[(start+i)%n]) + rawLines = append(rawLines, buf[(start+i)%n]) + } + rawLines = truncateHeadTailLines(rawLines) + lines := make([]string, len(rawLines)) + for i, l := range rawLines { + lines[i] = wrapUntrusted(path, l) } return headTailFileResult{Path: path, Lines: lines, Count: len(lines), Total: total} } +// truncateHeadTailLines truncates a slice of raw lines so the total byte +// count stays 
within maxHeadTailTotalBytes. It preserves leading lines and +// appends a marker when truncation occurs. +func truncateHeadTailLines(lines []string) []string { + total := 0 + for i, l := range lines { + if total+len(l) > maxHeadTailTotalBytes { + if i == 0 { + return []string{"... [truncated]"} + } + return append(lines[:i], "... [truncated]") + } + total += len(l) + } + return lines +} + // ═════════════════════════════════════════════════════════════════════════ // 13. base64 — Encode/decode base64 // ═════════════════════════════════════════════════════════════════════════ @@ -1998,7 +2147,7 @@ func (t *base64Tool) Call(argsJSON string) (result string, err error) { return jsonResult(base64Result{Error: fmt.Sprintf("cannot read %q: %v", args.Path, err)}) } encoded := base64.StdEncoding.EncodeToString(data) - return jsonResult(base64Result{Encoded: encoded, Size: len(data)}) + return jsonResult(base64Result{Encoded: wrapUntrusted(args.Path, encoded), Size: len(data)}) } // ═════════════════════════════════════════════════════════════════════════ @@ -2131,6 +2280,9 @@ func (t *trTool) Call(argsJSON string) (result string, err error) { } } + if fromFile { + text = wrapUntrusted(args.Path, text) + } return jsonResult(trResult{Result: text, FromFile: fromFile}) } diff --git a/cmd/odek/perf_tools_edge2_test.go b/cmd/odek/perf_tools_edge2_test.go index d96eab8..46b8568 100644 --- a/cmd/odek/perf_tools_edge2_test.go +++ b/cmd/odek/perf_tools_edge2_test.go @@ -103,7 +103,7 @@ func TestSort_AlreadySorted(t *testing.T) { if r.Error != "" { t.Fatalf("error: %s", r.Error) } - if r.Output != "a\nb\nc\nd" { + if unwrapUntrusted(r.Output) != "a\nb\nc\nd" { t.Errorf("output = %q, want a\nb\nc\nd", r.Output) } } @@ -123,7 +123,7 @@ func TestSort_Descending(t *testing.T) { if r.Error != "" { t.Fatalf("error: %s", r.Error) } - if r.Output != "c\nb\na" { + if unwrapUntrusted(r.Output) != "c\nb\na" { t.Errorf("desc sort = %q, want c\nb\na", r.Output) } } @@ -272,8 +272,8 @@ func 
TestTR_FileInput(t *testing.T) { if !r.FromFile { t.Errorf("expected from_file=true") } - if r.Result != "hello world\n" { - t.Errorf("result = %q, want 'hello world\\n'", r.Result) + if unwrapUntrusted(r.Result) != "hello world" { + t.Errorf("result = %q, want 'hello world' (unwrapped)", r.Result) } } diff --git a/cmd/odek/perf_tools_edge_test.go b/cmd/odek/perf_tools_edge_test.go index ac20d5a..6bee6c0 100644 --- a/cmd/odek/perf_tools_edge_test.go +++ b/cmd/odek/perf_tools_edge_test.go @@ -310,7 +310,7 @@ func TestJSONQuery_NestedObjects(t *testing.T) { if r.Error != "" { t.Fatalf("error: %s", r.Error) } - if r.Value != "deep" { + if got, ok := r.Value.(string); !ok || unwrapUntrusted(got) != "deep" { t.Errorf("value = %v, want 'deep'", r.Value) } } diff --git a/cmd/odek/perf_tools_test.go b/cmd/odek/perf_tools_test.go index 6007def..103c4bd 100644 --- a/cmd/odek/perf_tools_test.go +++ b/cmd/odek/perf_tools_test.go @@ -515,7 +515,7 @@ func TestJSONQuery_Basic(t *testing.T) { if r.Error != "" { t.Fatalf("error: %s", r.Error) } - if r.Value != "Alice" { + if got, ok := r.Value.(string); !ok || unwrapUntrusted(got) != "Alice" { t.Errorf("value = %v, want 'Alice'", r.Value) } } @@ -534,7 +534,7 @@ func TestJSONQuery_ArrayIndex(t *testing.T) { } mustUnmarshal(t, result, &r) - if r.Value != "Bob" { + if got, ok := r.Value.(string); !ok || unwrapUntrusted(got) != "Bob" { t.Errorf("value = %v, want 'Bob'", r.Value) } } @@ -713,7 +713,7 @@ func TestSort_Basic(t *testing.T) { if r.Total != 3 { t.Errorf("total = %d, want 3", r.Total) } - if r.Output != "a\nb\nc" { + if unwrapUntrusted(r.Output) != "a\nb\nc" { t.Errorf("output = %q, want a\\nb\\nc", r.Output) } } @@ -732,7 +732,7 @@ func TestSort_Desc(t *testing.T) { } mustUnmarshal(t, result, &r) - if r.Output != "c\nb\na" { + if unwrapUntrusted(r.Output) != "c\nb\na" { t.Errorf("output = %q, want c\\nb\\na", r.Output) } } @@ -787,7 +787,7 @@ func TestHeadTail_Head(t *testing.T) { if r.Results[0].Total != 5 { 
t.Errorf("total = %d, want 5", r.Results[0].Total) } - if r.Results[0].Lines[0] != "a" { + if unwrapUntrusted(r.Results[0].Lines[0]) != "a" { t.Errorf("first line = %q, want 'a'", r.Results[0].Lines[0]) } } @@ -809,7 +809,7 @@ func TestHeadTail_Tail(t *testing.T) { } mustUnmarshal(t, result, &r) - if r.Results[0].Count != 2 || r.Results[0].Lines[0] != "d" { + if r.Results[0].Count != 2 || unwrapUntrusted(r.Results[0].Lines[0]) != "d" { t.Errorf("tail(2) = %v, want [d e]", r.Results[0].Lines) } } @@ -1406,8 +1406,8 @@ func TestTr_FileInput(t *testing.T) { if !r.FromFile { t.Error("should indicate from_file=true") } - if r.Result != "HELLO WORLD\n" { - t.Errorf("result = %q, want 'HELLO WORLD\\n'", r.Result) + if unwrapUntrusted(r.Result) != "HELLO WORLD" { + t.Errorf("result = %q, want 'HELLO WORLD' (unwrapped)", r.Result) } } @@ -1424,7 +1424,7 @@ func TestSort_IgnoreCase(t *testing.T) { var r struct{ Output string } mustUnmarshal(t, result, &r) - if r.Output != "alpha\nBeta\nGamma" { + if unwrapUntrusted(r.Output) != "alpha\nBeta\nGamma" { t.Errorf("case-insensitive sort = %q", r.Output) } } @@ -1442,7 +1442,7 @@ func TestSort_MultipleFiles(t *testing.T) { var r struct{ Output string } mustUnmarshal(t, result, &r) - if r.Output != "a\nb\nc" { + if unwrapUntrusted(r.Output) != "a\nb\nc" { t.Errorf("merged sort = %q, want 'a\\nb\\nc'", r.Output) } } @@ -1762,7 +1762,7 @@ func TestSort_Numeric(t *testing.T) { if r.Total != 4 { t.Errorf("total = %d, want 4", r.Total) } - if r.Output != "1\n2\n10\n30" { + if unwrapUntrusted(r.Output) != "1\n2\n10\n30" { t.Errorf("numeric sort = %q, want '1\\n2\\n10\\n30'", r.Output) } } diff --git a/cmd/odek/refs.go b/cmd/odek/refs.go index 8476593..4e069b2 100644 --- a/cmd/odek/refs.go +++ b/cmd/odek/refs.go @@ -39,7 +39,7 @@ func enrichTask(task string, ctxFiles []string, cwd string) (string, error) { // Leave unresolved refs as-is continue } - resolved[ref.Raw] = content + resolved[ref.Raw] = wrapUntrusted("resource:"+ref.Raw, 
content) } enriched = resource.ReplaceRefs(task, resolved) } @@ -56,7 +56,7 @@ func enrichTask(task string, ctxFiles []string, cwd string) (string, error) { if err != nil { return "", fmt.Errorf("ctx file %q: %w", f, err) } - blocks = append(blocks, fmt.Sprintf("--- %s ---\n%s\n--- end %s ---", f, content, f)) + blocks = append(blocks, fmt.Sprintf("--- %s ---\n%s\n--- end %s ---", f, wrapUntrusted("ctx:"+f, content), f)) } if len(blocks) > 0 { // Log attached files to stderr diff --git a/cmd/odek/serve.go b/cmd/odek/serve.go index 2c36925..a831048 100644 --- a/cmd/odek/serve.go +++ b/cmd/odek/serve.go @@ -32,6 +32,11 @@ import ( //go:embed ui var uiFS embed.FS +// maxWSMessageBytes caps the size of an incoming WebSocket text message. +// This prevents a local client from exhausting server memory by sending a +// multi-gigabyte frame. +const maxWSMessageBytes = 8 * 1024 * 1024 // 8 MiB + // currentPromptCancel holds the cancel function for the currently executing // prompt. Used by the POST /api/cancel endpoint to abort a running agent. 
var currentPromptCancel atomic.Value @@ -150,11 +155,11 @@ func serveCmd(args []string) error { handleWS(store, resourceReg, resolved, systemMessage, conn) }, }) - mux.HandleFunc("/api/resources", handleResourceSearch(resourceReg)) - mux.HandleFunc("/api/sessions", handleSessionList(store)) - mux.HandleFunc("/api/sessions/", handleSessionByID(store)) - mux.HandleFunc("/api/models", handleModelList(resolved.Model)) - mux.HandleFunc("/api/cancel", handleCancel) + mux.Handle("/api/resources", requireLocalOrigin(handleResourceSearch(resourceReg))) + mux.Handle("/api/sessions", requireLocalOrigin(handleSessionList(store))) + mux.Handle("/api/sessions/", requireLocalOrigin(handleSessionByID(store))) + mux.Handle("/api/models", requireLocalOrigin(handleModelList(resolved.Model))) + mux.Handle("/api/cancel", requireLocalOrigin(http.HandlerFunc(handleCancel))) listener, err := net.Listen("tcp", addr) if err != nil { @@ -457,6 +462,10 @@ func handleWS(store *session.Store, resources *resource.Registry, resolved confi }() defer conn.Close() + // Cap incoming message size to prevent a local client from exhausting + // server memory with a single huge frame. + conn.MaxPayloadBytes = maxWSMessageBytes + // Create ONE agent per WebSocket connection — provides buffer // continuity across turns within the same session. agent, sandboxCleanup, mcpCleanup, approver, err := newServeAgent(resolved, system, func(v any) error { @@ -682,7 +691,7 @@ func handlePrompt( if err != nil { continue } - resolvedRefs[ref.Raw] = content + resolvedRefs[ref.Raw] = wrapUntrusted("resource:"+ref.Raw, content) } enrichedPrompt := resource.ReplaceRefs(prompt, resolvedRefs) @@ -945,6 +954,38 @@ func checkLocalOrigin(_ *golangws.Config, req *http.Request) error { return fmt.Errorf("Origin %q not allowed (only localhost is accepted)", origin) } +// requireLocalOrigin rejects cross-origin state-changing requests to the REST +// API. It is the HTTP counterpart to checkLocalOrigin. 
+func requireLocalOrigin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isStateChangingMethod(r.Method) { + origin := r.Header.Get("Origin") + if origin != "" { + u, err := url.Parse(origin) + if err != nil { + http.Error(w, "invalid Origin", http.StatusForbidden) + return + } + host := u.Hostname() + if host != "localhost" && host != "127.0.0.1" && host != "::1" { + http.Error(w, "Origin not allowed", http.StatusForbidden) + return + } + } + } + w.Header().Set("Vary", "Origin") + next.ServeHTTP(w, r) + }) +} + +func isStateChangingMethod(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return false + } + return true +} + func writeWSJSON(conn *golangws.Conn, data any) { payload, err := json.Marshal(data) if err != nil { @@ -1144,6 +1185,8 @@ func handleStatic() http.HandlerFunc { w.Header().Set("Content-Type", entry[1]) w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Referrer-Policy", "no-referrer") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Content-Security-Policy", "frame-ancestors 'none'") w.Write(data) } } diff --git a/cmd/odek/serve_test.go b/cmd/odek/serve_test.go index bb3270d..828454d 100644 --- a/cmd/odek/serve_test.go +++ b/cmd/odek/serve_test.go @@ -113,6 +113,7 @@ func (s *testServer) handleResourceSearch() http.HandlerFunc { } func (s *testServer) handleWebSocket(conn *golangws.Conn) { + conn.MaxPayloadBytes = maxWSMessageBytes defer conn.Close() for { @@ -1717,3 +1718,41 @@ func TestServe_E2E_CancelWithMockLLM(t *testing.T) { t.Log("No terminal event (connection may have been closed by cancel)") } } + +// TestServe_WebSocketMaxPayload verifies that the WebSocket endpoint rejects +// incoming messages larger than maxWSMessageBytes to prevent memory DoS. 
+func TestServe_WebSocketMaxPayload(t *testing.T) { + s := startTestServer(t) + defer s.Close() + + conn, err := golangws.Dial(s.wsURL+"/ws", "", "http://localhost") + if err != nil { + t.Fatalf("Dial(): %v", err) + } + defer conn.Close() + + // First confirm a normal message works. + if err := golangws.Message.Send(conn, `{"type":"prompt","content":"hi"}`); err != nil { + t.Fatalf("Send small message: %v", err) + } + var small map[string]any + if err := readJSON(conn, &small); err != nil { + t.Fatalf("Receive small message response: %v", err) + } + if small["type"] != "session" { + t.Fatalf("expected session event, got %v", small["type"]) + } + + // Send a message that exceeds the payload cap. + huge := `{"type":"prompt","content":"` + strings.Repeat("x", int(maxWSMessageBytes)+1024) + `"}` + if err := golangws.Message.Send(conn, huge); err != nil { + // Some transports close the connection on send; that also satisfies the test. + return + } + + // The next receive must fail because the server closed the connection. 
+ var data []byte + if err := golangws.Message.Receive(conn, &data); err == nil { + t.Fatalf("expected connection to be closed after oversized message, but received: %s", string(data)) + } +} diff --git a/cmd/odek/session_search_tool.go b/cmd/odek/session_search_tool.go index af526d1..542ba35 100644 --- a/cmd/odek/session_search_tool.go +++ b/cmd/odek/session_search_tool.go @@ -142,7 +142,7 @@ func (t *sessionSearchTool) handleList(limit int) (string, error) { Count: 0, }) } - results := toSummaries(sessions) + results := toSummaries("session_search:list", sessions) return jsonResult(sessionSearchResult{ Action: "list", Sessions: results, @@ -191,7 +191,7 @@ func (t *sessionSearchTool) handleSearch(query string, limit int) (string, error scoreLabel := fmt.Sprintf("(score: %.3f)", vr.Score) results = append(results, sessionSummary{ ID: sess.ID, - Task: sess.Task + " " + scoreLabel, + Task: wrapUntrusted("session_search:search", sess.Task+" "+scoreLabel), Turns: sess.Turns, CreatedAt: sess.CreatedAt.UTC().Format(time.RFC3339), UpdatedAt: sess.UpdatedAt.UTC().Format(time.RFC3339), @@ -248,17 +248,18 @@ func (t *sessionSearchTool) handleSearch(query string, limit int) (string, error results := make([]sessionSummary, len(matches)) for i, m := range matches { + task := m.session.Task + if m.snippet != "" && m.snippet != m.session.Task { + task = m.session.Task + " — " + m.snippet + } results[i] = sessionSummary{ ID: m.session.ID, - Task: m.session.Task, + Task: wrapUntrusted("session_search:search", task), Turns: m.session.Turns, CreatedAt: m.session.CreatedAt.UTC().Format(time.RFC3339), UpdatedAt: m.session.UpdatedAt.UTC().Format(time.RFC3339), Model: m.session.Model, } - if m.snippet != "" && m.snippet != m.session.Task { - results[i].Task = m.session.Task + " — " + m.snippet - } } return jsonResult(sessionSearchResult{ @@ -362,7 +363,9 @@ func (t *sessionSearchTool) handleGet(id string) (string, error) { }) } - // Build session messages for the LLM to read. 
+ // Build session messages for the LLM to read. Cap how many are returned + // and treat the content as untrusted because it includes prior tool outputs. + const maxSessionGetMessages = 100 var sessionMessages []sessionMessage for _, m := range sess.Messages { if m.Role == "user" || m.Role == "assistant" { @@ -372,16 +375,26 @@ func (t *sessionSearchTool) handleGet(id string) (string, error) { }) } } + if len(sessionMessages) > maxSessionGetMessages { + sessionMessages = sessionMessages[len(sessionMessages)-maxSessionGetMessages:] + } + for i := range sessionMessages { + sessionMessages[i].Content = wrapUntrusted("session_search:"+sess.ID, sessionMessages[i].Content) + } msgCount := len(sessionMessages) + wrappedBuffer := make([]string, len(sess.Buffer)) + for i, b := range sess.Buffer { + wrappedBuffer[i] = wrapUntrusted("session_search:get:buffer", b) + } return jsonResult(sessionSearchResult{ Action: "get", ID: sess.ID, - Task: sess.Task, + Task: wrapUntrusted("session_search:get", sess.Task), Turns: sess.Turns, CreatedAt: sess.CreatedAt.UTC().Format(time.RFC3339), UpdatedAt: sess.UpdatedAt.UTC().Format(time.RFC3339), Model: sess.Model, - Buffer: sess.Buffer, + Buffer: wrappedBuffer, Messages: msgCount, SessionMessages: sessionMessages, }) @@ -418,7 +431,7 @@ func (t *sessionSearchTool) handleFind(query string, limit int) (string, error) return jsonResult(sessionSearchResult{ Action: "find", - Sessions: toSummaries(matched), + Sessions: toSummaries("session_search:find", matched), Count: len(matched), }) } @@ -441,12 +454,12 @@ func matchTokens(tokens []string, text string) int { } // toSummaries converts session.Session slices to sessionSummary (metadata only). 
-func toSummaries(sessions []session.Session) []sessionSummary { +func toSummaries(source string, sessions []session.Session) []sessionSummary { results := make([]sessionSummary, len(sessions)) for i, s := range sessions { results[i] = sessionSummary{ ID: s.ID, - Task: s.Task, + Task: wrapUntrusted(source, s.Task), Turns: s.Turns, CreatedAt: s.CreatedAt.UTC().Format(time.RFC3339), UpdatedAt: s.UpdatedAt.UTC().Format(time.RFC3339), diff --git a/cmd/odek/session_search_tool_test.go b/cmd/odek/session_search_tool_test.go index 5112e98..bc503ab 100644 --- a/cmd/odek/session_search_tool_test.go +++ b/cmd/odek/session_search_tool_test.go @@ -186,7 +186,7 @@ func TestSessionSearch_Get(t *testing.T) { if r.ID != "20260520-auth-fix" { t.Errorf("id = %q, want '20260520-auth-fix'", r.ID) } - if r.Task != "fix O_NOFOLLOW in file_tool.go" { + if unwrapUntrusted(r.Task) != "fix O_NOFOLLOW in file_tool.go" { t.Errorf("task = %q", r.Task) } if r.Turns != 8 { @@ -712,7 +712,7 @@ func TestSessionSearch_GetReturnsSessionMessages(t *testing.T) { if resp.SessionMessages[i].Role != c.role { t.Errorf("msg[%d] role = %q, want %q", i, resp.SessionMessages[i].Role, c.role) } - if resp.SessionMessages[i].Content != c.content { + if unwrapUntrusted(resp.SessionMessages[i].Content) != c.content { t.Errorf("msg[%d] content = %q, want %q", i, resp.SessionMessages[i].Content, c.content) } } diff --git a/cmd/odek/shell.go b/cmd/odek/shell.go index 849c885..b53bfea 100644 --- a/cmd/odek/shell.go +++ b/cmd/odek/shell.go @@ -22,6 +22,34 @@ import ( // immediately regardless of this backstop. const defaultShellTimeout = 30 * time.Minute +// maxShellOutputBytes caps the stdout + stderr captured from a single shell +// command to prevent memory DoS from commands that dump huge files. +const maxShellOutputBytes = 1 << 20 // 1 MiB + +// limitWriter wraps a bytes.Buffer and drops further writes once the total +// size would exceed limit, recording that output was truncated. 
+type limitWriter struct { + buf *bytes.Buffer + limit int + truncated bool +} + +func (w *limitWriter) Write(p []byte) (int, error) { + if w.truncated { + return len(p), nil + } + if w.buf.Len()+len(p) > w.limit { + w.truncated = true + room := w.limit - w.buf.Len() + if room > 0 { + w.buf.Write(p[:room]) + } + w.buf.WriteString("\n... [output truncated]") + return len(p), nil + } + return w.buf.Write(p) +} + // shellTool is odek's built-in tool that lets the agent run shell commands. // // This is the only built-in tool — it's enough for reading files, running @@ -163,8 +191,10 @@ func (t *shellTool) Call(args string) (string, error) { cmd.WaitDelay = 3 * time.Second var outBuf, errBuf bytes.Buffer - cmd.Stdout = &outBuf - cmd.Stderr = &errBuf + outW := &limitWriter{buf: &outBuf, limit: maxShellOutputBytes} + errW := &limitWriter{buf: &errBuf, limit: maxShellOutputBytes} + cmd.Stdout = outW + cmd.Stderr = errW err := cmd.Run() diff --git a/cmd/odek/subagent.go b/cmd/odek/subagent.go index 741237a..228933d 100644 --- a/cmd/odek/subagent.go +++ b/cmd/odek/subagent.go @@ -224,6 +224,13 @@ func subagentCmd(args []string) error { var taskTrust string // "trusted" or "untrusted" (from parent agent) var taskMaxRisk string if hasTaskFile { + info, err := os.Stat(cfg.taskFile) + if err != nil { + return fmt.Errorf("stat task file: %w", err) + } + if info.Size() > maxFileReadBytes { + return fmt.Errorf("task file too large (%d bytes, max %d)", info.Size(), maxFileReadBytes) + } data, err := os.ReadFile(cfg.taskFile) if err != nil { return fmt.Errorf("read task file: %w", err) diff --git a/cmd/odek/subagent_tool.go b/cmd/odek/subagent_tool.go index d765c4f..1451246 100644 --- a/cmd/odek/subagent_tool.go +++ b/cmd/odek/subagent_tool.go @@ -157,12 +157,18 @@ func (t *delegateTasksTool) Call(args string) (string, error) { sem <- struct{}{} } - // Build summary for the calling agent + // Build summary for the calling agent. 
Cap each sub-agent result so the + // summary cannot grow without bound. var buf strings.Builder buf.WriteString("📋 Sub-agent results:\n\n") for i, r := range results { buf.WriteString(fmt.Sprintf("─── Task %d: %s ───\n", i+1, truncate(input.Tasks[i].Goal, 60))) - buf.WriteString(r) + if len(r) > maxSubagentSummaryResultBytes { + buf.WriteString(r[:maxSubagentSummaryResultBytes]) + buf.WriteString("\n... [result truncated]") + } else { + buf.WriteString(r) + } buf.WriteString("\n\n") } return buf.String(), nil @@ -284,6 +290,11 @@ func (t *delegateTasksTool) runTask(taskIdx int, goal, taskContext, guidance, tr return `{"error":"no result from sub-agent"}` } +// maxSubagentSummaryResultBytes caps how much of each sub-agent result is +// included in the parent delegate_tasks summary, preventing memory DoS from +// huge sub-agent outputs. +const maxSubagentSummaryResultBytes = 100 << 10 // 100 KiB + // maxSubagentLine caps a single NDJSON line read from a sub-agent's stdout. // Streamed tool_call events embed full tool arguments (e.g. a large write_file // or patch), which routinely exceed bufio.Scanner's default 64KB token cap. diff --git a/cmd/odek/transcribe_tool.go b/cmd/odek/transcribe_tool.go index e945bf6..3bd758d 100644 --- a/cmd/odek/transcribe_tool.go +++ b/cmd/odek/transcribe_tool.go @@ -35,11 +35,20 @@ func convertToWAV(ctx context.Context, srcPath string) string { } // Convert to WAV using ffmpeg — best-effort, fall through on failure. - dstPath := srcPath + ".wav" + // Write the output to a temp file in the system temp directory so we never + // clobber an existing .wav file next to the source path. 
+ dstFile, err := os.CreateTemp("", "odek-transcribe-*.wav") + if err != nil { + return srcPath + } + dstPath := dstFile.Name() + dstFile.Close() + cmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", srcPath, "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", dstPath) if err := cmd.Run(); err != nil { // If ffmpeg fails (corrupt file, unsupported codec, etc.), // just pass the original path — whisper will produce its own error. + os.Remove(dstPath) return srcPath } return dstPath @@ -216,14 +225,27 @@ func (t *transcribeTool) Call(argsJSON string) (result string, err error) { return jsonError(err.Error()) } - // Check the audio file exists (O_NOFOLLOW to prevent symlink attacks) + // Check the audio file exists (O_NOFOLLOW to prevent symlink attacks) and + // reject inputs that would exhaust memory during conversion / transcription. f, err := os.OpenFile(args.Path, os.O_RDONLY|syscall.O_NOFOLLOW, 0) if err != nil { return jsonResult(transcribeResult{ Error: fmt.Sprintf("cannot open audio file %q: %v", args.Path, err), }) } + info, err := f.Stat() f.Close() + if err != nil { + return jsonResult(transcribeResult{ + Error: fmt.Sprintf("cannot stat audio file %q: %v", args.Path, err), + }) + } + const maxAudioFileBytes = maxFileReadBytes // 10 MiB — same cap as other file-reading tools + if info.Size() > maxAudioFileBytes { + return jsonResult(transcribeResult{ + Error: fmt.Sprintf("audio file too large (%d bytes, max %d)", info.Size(), maxAudioFileBytes), + }) + } // Convert to WAV if needed (whisper.cpp doesn't support OGG Opus natively). wavPath := convertToWAV(t.toolCtx(), args.Path) @@ -265,8 +287,14 @@ func (t *transcribeTool) Call(argsJSON string) (result string, err error) { args2 = append(args2, "--language", lang) } + const maxWhisperOutputBytes = 10 << 20 // 10 MiB cmd := exec.CommandContext(t.toolCtx(), binary, args2...) 
output, err := cmd.Output() + if err == nil && len(output) > maxWhisperOutputBytes { + return jsonResult(transcribeResult{ + Error: fmt.Sprintf("whisper output too large (%d bytes, max %d)", len(output), maxWhisperOutputBytes), + }) + } if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { return jsonResult(transcribeResult{ diff --git a/cmd/odek/untrusted.go b/cmd/odek/untrusted.go index 50cf370..9982f83 100644 --- a/cmd/odek/untrusted.go +++ b/cmd/odek/untrusted.go @@ -183,7 +183,10 @@ func unwrapUntrusted(s string) string { if len(m) < 2 { return s } - return m[1] + body := m[1] + body = strings.TrimPrefix(body, "\n") + body = strings.TrimSuffix(body, "\n") + return body } // hasUntrustedWrapper reports whether s contains a complete nonce'd diff --git a/docs/SECURITY.md b/docs/SECURITY.md index 942c1dd..3e86db2 100644 --- a/docs/SECURITY.md +++ b/docs/SECURITY.md @@ -53,7 +53,7 @@ Tools that wrap: | Tool | Source attribute | |---|---| -| `browser` (navigate / snapshot / back) | the URL | +| `browser` (navigate / snapshot / back) | the URL; page title and interactive-element text are wrapped too | | `read_file` | the absolute path | | `search_files`, `multi_grep` | `:` per match | | `shell` | `$ ` | @@ -61,6 +61,9 @@ Tools that wrap: | `vision` | `vision:` (full description) | | `web_search` | `web_search:` (results + answers from SearXNG) | | `session_search` | `session_search` (whole result — past sessions may be tainted) | +| `file_info` | `file_info:` (metadata about an external file) | +| `tree` | `tree:` (directory/file names from the filesystem) | +| `base64` (file/path mode) | `base64:` (the encoded bytes are wrapped) | | any MCP tool | `mcp::` | `session_search` is wrapped because it can surface content from arbitrary past sessions — including sessions that ingested untrusted content. 
Wrapping its whole output keeps that content from re-entering as trusted instructions and records the retrieval in the audit log, closing a path that otherwise bypassed the memory taint gate (defense 5). @@ -69,6 +72,8 @@ The MCP wrapper guards a tool's **output**. The server-supplied tool **descripti The model is instructed (via the default system prompt) to treat the wrapped region as data, not instructions. A model trained on prompt-injection resistance (Claude Sonnet 4.6+ does this well) honours the boundary. Older models or aggressively fine-tuned ones may not. +Two additional boundaries keep filesystem-derived metadata from leaking as "trusted" context. First, the `base64` tool wraps encoded output when reading from a file path, so even transformed filesystem bytes stay inside an untrusted boundary. Second, the `@`-resource resolver (`FileResolver.Search`) uses `os.Lstat` when building search-result metadata, which prevents a symlink inside the workspace from leaking the size (or other `stat` metadata) of an arbitrary target outside it. + ### 3. Danger classifier (shell) The `shell` tool tokenises commands and classifies each into one of 9 risk classes (`safe`, `local_write`, `system_write`, `destructive`, `network_egress`, `code_execution`, `install`, `unknown`, `blocked`). Per-class policy (allow / prompt / deny) is configurable. @@ -255,6 +260,10 @@ See [CLI.md — Dangerous Operations](CLI.md#dangerous-operations) for the full } ``` +### 15. Configuration file size cap + +`~/.odek/config.json` and `./odek.json` are rejected if they exceed 5 MiB. This prevents a malicious, truncated, or accidentally-generated config file from causing an out-of-memory condition at startup. 
+ ### YOLO mode ```json diff --git a/internal/config/loader.go b/internal/config/loader.go index 96a42b6..1bcafb1 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -30,6 +30,10 @@ import ( "github.com/BackendStack21/odek/internal/telegram" ) +// maxConfigFileBytes caps how large a config file may be before it is rejected. +// This prevents a malicious or broken config from OOMing startup. +const maxConfigFileBytes = 5 << 20 // 5 MiB + // ── Types ────────────────────────────────────────────────────────────── // CLIFlags holds values parsed from the CLI. Zero/nil values mean the @@ -408,10 +412,18 @@ func loadFile(path string) FileConfig { if path == "" { return FileConfig{} } - data, err := os.ReadFile(path) + info, err := os.Stat(path) if err != nil { return FileConfig{} // missing or unreadable = empty } + if info.Size() > maxConfigFileBytes { + fmt.Fprintf(os.Stderr, "odek: warning: config %s: file exceeds maximum size %d bytes — ignoring file\n", path, maxConfigFileBytes) + return FileConfig{} + } + data, err := os.ReadFile(path) + if err != nil { + return FileConfig{} // unreadable = empty + } var cfg FileConfig if err := json.Unmarshal(data, &cfg); err != nil { fmt.Fprintf(os.Stderr, "odek: warning: config %s: invalid JSON — ignoring file: %v\n", path, err) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 6132ec6..3f5b72c 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" "github.com/BackendStack21/odek/internal/memory" @@ -1015,3 +1016,17 @@ func TestLoadConfig_EmbeddingOverrides(t *testing.T) { t.Errorf("explicit skills timeout = %d, want 7 (respected as-is)", cfg.Skills.Embedding.TimeoutSeconds) } } + +// TestLoadFile_CapsSize verifies that config files larger than maxConfigFileBytes +// are ignored to prevent OOM from a malicious or broken config file. 
+func TestLoadFile_CapsSize(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "odek.json") + if err := os.WriteFile(path, []byte(strings.Repeat("x", maxConfigFileBytes+1)), 0644); err != nil { + t.Fatal(err) + } + cfg := loadFile(path) + if cfg.Model != "" { + t.Fatalf("loadFile should reject a huge config file, got Model=%q", cfg.Model) + } +} diff --git a/internal/resource/resource.go b/internal/resource/resource.go index 1d48a8d..7aea221 100644 --- a/internal/resource/resource.go +++ b/internal/resource/resource.go @@ -21,6 +21,11 @@ import ( "github.com/BackendStack21/odek/internal/session" ) +// maxResourceFileBytes caps how much of a file the @-resource resolver will +// read into memory. It still truncates the returned content to 50 KB, but this +// guard prevents OOM from files larger than 1 MiB. +const maxResourceFileBytes = 1 << 20 // 1 MiB + // Resource is a discovered resource returned by a Resolver. type Resource struct { ID string `json:"id"` // Full @ reference (e.g. "@src/main.go") @@ -218,13 +223,27 @@ func (f *FileResolver) Search(ctx context.Context, query string, limit int) ([]R matches = f.walkAndMatch(query) } + // Resolve the root once so every match can be confined to it. A query + // such as "../../etc/passwd" makes filepath.Join above clean to a path + // outside root, and filepath.Glob would then match files the workspace + // must not expose. Skip any match that escapes root before touching the + // filesystem (closes CodeQL "uncontrolled data in path expression"). + absRoot, err := filepath.Abs(f.root) + if err != nil { + return nil, nil + } + var resources []Resource for _, match := range matches { if len(resources) >= limit { break } + if !withinRoot(absRoot, match) { + continue + } rel, _ := filepath.Rel(f.root, match) - info, err := os.Stat(match) + // Use Lstat so that symlinks do not leak metadata from their targets. 
+ info, err := os.Lstat(match) if err != nil { continue } @@ -251,11 +270,7 @@ func (f *FileResolver) Load(ctx context.Context, id string) (string, error) { if err != nil { return "", err } - absRoot, err := filepath.Abs(f.root) - if err != nil { - return "", err - } - if !strings.HasPrefix(absTarget, absRoot) { + if !withinRoot(f.root, absTarget) { return "", fmt.Errorf("resource: path %q is outside root", id) } @@ -268,6 +283,14 @@ func (f *FileResolver) Load(ctx context.Context, id string) (string, error) { } defer fd.Close() + info, err := fd.Stat() + if err != nil { + return "", err + } + if info.Size() > maxResourceFileBytes { + return "", fmt.Errorf("resource: file too large (%d bytes, max %d)", info.Size(), maxResourceFileBytes) + } + data, err := io.ReadAll(fd) if err != nil { return "", err @@ -326,6 +349,25 @@ func skipDir(name string) bool { return false } +// withinRoot reports whether candidate resolves to a path inside root. +// root may be relative; candidate may be relative or absolute. The check is +// separator-aware so "/foo" does not match "/foobar". It is the single +// confinement primitive for the file resolver (Search metadata + Load reads). 
+func withinRoot(root, candidate string) bool { + absRoot, err := filepath.Abs(root) + if err != nil { + return false + } + absCandidate, err := filepath.Abs(candidate) + if err != nil { + return false + } + if absCandidate == absRoot { + return true + } + return strings.HasPrefix(absCandidate, absRoot+string(os.PathSeparator)) +} + func describeFile(info os.FileInfo) string { size := info.Size() switch { diff --git a/internal/resource/resource_test.go b/internal/resource/resource_test.go index b920130..97219de 100644 --- a/internal/resource/resource_test.go +++ b/internal/resource/resource_test.go @@ -1,6 +1,7 @@ package resource import ( + "bytes" "context" "fmt" "os" @@ -221,6 +222,30 @@ func TestFileResolver_SearchRecursive(t *testing.T) { } } +func TestFileResolver_SearchOutsideRoot(t *testing.T) { + // Parent holds a sentinel file; root is a subdirectory of it. A traversal + // query must not surface metadata for files outside root. + parent := t.TempDir() + if err := os.WriteFile(filepath.Join(parent, "secret.txt"), []byte("top secret"), 0644); err != nil { + t.Fatal(err) + } + root := filepath.Join(parent, "workspace") + if err := os.MkdirAll(root, 0755); err != nil { + t.Fatal(err) + } + res := NewFileResolver(root) + + results, err := res.Search(context.Background(), "../secret", 10) + if err != nil { + t.Fatalf("Search() error: %v", err) + } + for _, r := range results { + if strings.Contains(r.Label, "secret") || strings.Contains(r.ID, "secret") { + t.Fatalf("traversal query leaked file outside root: %+v", r) + } + } +} + func TestFileResolver_Load(t *testing.T) { dir := newTestDir(t) res := NewFileResolver(dir) @@ -649,3 +674,42 @@ func TestSessionResolverLoad_PathTraversal(t *testing.T) { } } } + +// ── Bug #30: FileResolver.Search follows symlinks via os.Stat ─────────────── + +func TestFileResolverSearch_DoesNotFollowSymlinksForMetadata(t *testing.T) { + dir := t.TempDir() + base := filepath.Join(dir, "base") + if err := os.MkdirAll(base, 0755); err 
!= nil { + t.Fatal(err) + } + + secret := filepath.Join(dir, "secret.txt") + // 2000 bytes produces "2.0 KB" in Detail if os.Stat follows the symlink. + if err := os.WriteFile(secret, bytes.Repeat([]byte("x"), 2000), 0600); err != nil { + t.Fatal(err) + } + link := filepath.Join(base, "leak.txt") + if err := os.Symlink(secret, link); err != nil { + t.Fatal(err) + } + + resolver := NewFileResolver(base) + results, err := resolver.Search(context.Background(), "leak", 10) + if err != nil { + t.Fatalf("Search failed: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + res := results[0] + // With os.Stat the target's size (2.0 KB) would leak through Detail. + // With os.Lstat we get the symlink's own metadata, never the target size. + if strings.Contains(res.Detail, "2.0") { + t.Errorf("symlink leaked target file size through Detail: %s", res.Detail) + } + // The returned resource reference must point to the symlink inside the base. + if res.ID != "@leak.txt" { + t.Errorf("expected resource ID %q, got %q", "@leak.txt", res.ID) + } +} diff --git a/internal/session/session.go b/internal/session/session.go index a5dcb48..e1bc25b 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -32,6 +32,11 @@ import ( "github.com/BackendStack21/odek/internal/redact" ) +// MaxSessionFileBytes caps the on-disk size of a session file that Load will +// read into memory. This prevents a tampered or corrupted multi-gigabyte +// session file from causing an OOM when any caller loads it. +const MaxSessionFileBytes = 32 * 1024 * 1024 // 32 MiB + // ── Types ────────────────────────────────────────────────────────────── // Session represents a single multi-turn conversation with the agent. 
@@ -319,6 +324,13 @@ func (s *Store) Load(id string) (*Session, error) { if err := ValidateSessionID(id); err != nil { return nil, err } + info, err := os.Stat(s.path(id)) + if err != nil { + return nil, fmt.Errorf("session: load %q: %w", id, err) + } + if info.Size() > MaxSessionFileBytes { + return nil, fmt.Errorf("session: load %q: file too large (%d bytes, max %d)", id, info.Size(), MaxSessionFileBytes) + } data, err := os.ReadFile(s.path(id)) if err != nil { return nil, fmt.Errorf("session: load %q: %w", id, err) diff --git a/internal/skills/loader.go b/internal/skills/loader.go index 3a6decb..f8e7803 100644 --- a/internal/skills/loader.go +++ b/internal/skills/loader.go @@ -8,6 +8,11 @@ import ( "strings" ) +// MaxSkillFileBytes caps the size of a single SKILL.md file that the loader +// will read into memory. A maliciously huge skill file could otherwise OOM +// the process at startup or bloat the system prompt. +const MaxSkillFileBytes = 1 * 1024 * 1024 // 1 MiB + // ── Frontmatter Parsing ─────────────────────────────────────────────── // // Manual YAML frontmatter parser for the SKILL.md subset: @@ -20,6 +25,13 @@ import ( // parseSkillFile reads and parses a single SKILL.md file. // Returns nil if the file doesn't exist or can't be parsed. func parseSkillFile(path string) *Skill { + info, err := os.Stat(path) + if err != nil { + return nil + } + if info.Size() > MaxSkillFileBytes { + return nil + } data, err := os.ReadFile(path) if err != nil { return nil diff --git a/odek.go b/odek.go index 84f7ef5..99799f3 100644 --- a/odek.go +++ b/odek.go @@ -298,6 +298,11 @@ func ProfileLabel(model string) string { // that odek automatically loads from the working directory. const ProjectFileName = "AGENTS.md" +// maxProjectFileBytes caps the size of AGENTS.md that will be loaded into the +// system prompt. A maliciously huge project file could otherwise OOM the +// process at startup or bloat every prompt. 
+const maxProjectFileBytes = 256 * 1024 // 256 KiB + // LoadProjectFile reads ProjectFileName from the current working directory. // Returns the file content (trimmed) if it exists and is readable. // Returns empty string if the file doesn't exist or can't be read. @@ -315,6 +320,10 @@ func LoadProjectFile() string { fmt.Fprintf(os.Stderr, "odek: warning: %s is a symlink — refusing to follow for security\n", ProjectFileName) return "" } + if info.Size() > maxProjectFileBytes { + fmt.Fprintf(os.Stderr, "odek: warning: %s is too large (%d bytes, max %d) — ignoring\n", ProjectFileName, info.Size(), maxProjectFileBytes) + return "" + } data, err := os.ReadFile(ProjectFileName) if err != nil { return "" diff --git a/odek_test.go b/odek_test.go index 6adad30..412570c 100644 --- a/odek_test.go +++ b/odek_test.go @@ -16,6 +16,17 @@ import ( "github.com/BackendStack21/odek/internal/skills" ) +func TestLoadProjectFile_CapsSize(t *testing.T) { + dir := t.TempDir() + t.Chdir(dir) + + os.WriteFile(ProjectFileName, []byte(strings.Repeat("x", maxProjectFileBytes+1)), 0644) + got := LoadProjectFile() + if got != "" { + t.Fatalf("LoadProjectFile should reject a huge %s, got length %d", ProjectFileName, len(got)) + } +} + func TestConfigDefaults(t *testing.T) { os.Unsetenv("DEEPSEEK_API_KEY") os.Unsetenv("OPENAI_API_KEY")