diff --git a/cmd/odek/file_tool.go b/cmd/odek/file_tool.go index 45c5dff..7b4568a 100644 --- a/cmd/odek/file_tool.go +++ b/cmd/odek/file_tool.go @@ -21,6 +21,10 @@ import ( const maxLines = 2000 +// maxReadBytes caps the content returned by read_file / batch_read to prevent +// memory exhaustion from huge files. +const maxReadBytes = 1 << 20 // 1 MiB + type readFileTool struct { dangerousConfig danger.DangerousConfig } @@ -226,6 +230,14 @@ func (t *writeFileTool) Call(argsJSON string) (string, error) { } } + // Preserve the original file's mode when overwriting, so a temp file + // created with default permissions does not change the accessibility + // of an existing file (e.g., making a 0640 file world-readable). + var origMode os.FileMode = 0644 + if st, err := os.Stat(args.Path); err == nil { + origMode = st.Mode().Perm() + } + // Atomic write via temp file + rename to prevent TOCTOU symlink races. // os.CreateTemp creates the file in the same directory (same filesystem), // and os.Rename atomically replaces the directory entry without following @@ -241,6 +253,11 @@ func (t *writeFileTool) Call(argsJSON string) (string, error) { os.Remove(tmpPath) return jsonError(fmt.Sprintf("cannot write %q: %v", args.Path, err)) } + if err := tmpFile.Chmod(origMode); err != nil { + tmpFile.Close() + os.Remove(tmpPath) + return jsonError(fmt.Sprintf("cannot set permissions %q: %v", args.Path, err)) + } if err := tmpFile.Close(); err != nil { os.Remove(tmpPath) return jsonError(fmt.Sprintf("cannot close temp file: %v", err)) @@ -722,6 +739,8 @@ func isBinary(data []byte) bool { // readLinesWithCount reads lines from an open file, returning content // and total line count in a single pass. offset is 1-based, limit caps lines. +// The returned content is capped at maxReadBytes to avoid unbounded memory +// consumption from huge lines or huge limits. func readLinesWithCount(f *os.File, offset, limit int) (string, int, error) { var out strings.Builder scanner := bufio.NewScanner(f) @@ -729,6 +748,7 @@ func readLinesWithCount(f *os.File, offset, limit int) (string, int, error) { lineNum := 0 start := offset end := offset + limit - 1 + truncated := false for scanner.Scan() { lineNum++ @@ -738,7 +758,17 @@ func readLinesWithCount(f *os.File, offset, limit int) (string, int, error) { if lineNum > end { continue // count total even beyond limit } - out.WriteString(fmt.Sprintf("%d|%s\n", lineNum, scanner.Text())) + line := scanner.Text() + formatted := fmt.Sprintf("%d|%s\n", lineNum, line) + if !truncated && out.Len()+len(formatted) > maxReadBytes { + out.WriteString("... [truncated]\n") + truncated = true + // Continue scanning only to count total lines. + continue + } + if !truncated { + out.WriteString(formatted) + } } // If no limit was set (limit=0), continue counting past start @@ -759,6 +789,10 @@ func confineToCWD(path string) (string, error) { if err != nil { return "", fmt.Errorf("cannot determine working directory: %v", err) } + cwdResolved, err := filepath.EvalSymlinks(cwd) + if err != nil { + return "", fmt.Errorf("cannot resolve working directory: %v", err) + } // Resolve to absolute path var abs string @@ -768,6 +802,34 @@ func confineToCWD(path string) (string, error) { abs = filepath.Join(cwd, path) } + // Resolve symlinks so a path that is lexically under CWD but traverses a + // symlink cannot escape (e.g., cwd/link -> /etc, cwd/link/file would + // resolve to /etc/file). If the full path or an intermediate directory + // does not exist yet (common for write_file), walk up to the deepest + // existing ancestor, resolve that, and re-attach the missing suffix. + // Missing directories cannot be symlinks, so they cannot be used to escape. + absResolved := abs + resolved := false + cur := abs + for cur != "/" && cur != "" { + if r, err := filepath.EvalSymlinks(cur); err == nil { + suffix := strings.TrimPrefix(abs, cur) + if suffix == "" { + absResolved = r + } else { + absResolved = r + suffix + } + resolved = true + break + } + cur = filepath.Dir(cur) + } + if !resolved { + // Nothing resolvable along the path (should not happen in practice, + // since / always exists). Fall back to lexical path. + absResolved = abs + } + // Allow paths under ~/.odek/ even when outside CWD — the agent // frequently writes memory and other state to this directory. The // carve-out deliberately EXCLUDES odek's trust anchors (config.json, @@ -779,8 +841,8 @@ func confineToCWD(path string) (string, error) { home, homeErr := os.UserHomeDir() if homeErr == nil { odekPrefix := home + "/.odek/" - if strings.HasPrefix(abs, odekPrefix) { - if isProtectedOdekPath(strings.TrimPrefix(abs, odekPrefix)) { + if strings.HasPrefix(absResolved, odekPrefix) { + if isProtectedOdekPath(strings.TrimPrefix(absResolved, odekPrefix)) { return "", fmt.Errorf("path %q is a protected odek configuration path and cannot be written by file tools", path) } return abs, nil @@ -788,7 +850,7 @@ func confineToCWD(path string) (string, error) { } // Check that the resolved path is within CWD - if !strings.HasPrefix(abs, cwd+string(filepath.Separator)) && abs != cwd { + if !strings.HasPrefix(absResolved, cwdResolved+string(filepath.Separator)) && absResolved != cwdResolved { return "", fmt.Errorf("path %q escapes the working directory", path) } @@ -986,7 +1048,7 @@ func (t *batchReadTool) readSingle(arg batchReadFileArg) batchReadFileResult { return batchReadFileResult{ Path: arg.Path, - Content: content, + Content: wrapUntrusted(arg.Path, content), TotalLines: totalLines, } } @@ -1169,10 +1231,14 @@ func (t *globTool) Call(argsJSON string) (result string, err error) { return jsonError(fmt.Sprintf("invalid glob %q: %v", args.Pattern, err)) } for _, p := range gm { - info, err := os.Stat(p) + // Use Lstat so symlinks are not followed to their targets. + info, err := os.Lstat(p) if err != nil { continue } + if info.Mode()&os.ModeSymlink != 0 { + continue + } matches = append(matches, globMatch{ Path: p, Size: info.Size(), diff --git a/cmd/odek/security_vulnerabilities_test.go b/cmd/odek/security_vulnerabilities_test.go new file mode 100644 index 0000000..de31451 --- /dev/null +++ b/cmd/odek/security_vulnerabilities_test.go @@ -0,0 +1,274 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +// skipIfSymlinksUnsupported skips the test on platforms where creating +// symlinks is unreliable (Windows without dev mode / admin). +func skipIfSymlinksUnsupported(t *testing.T) { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("symlink tests skipped on Windows") + } +} + +// ── 1. Symlink directory traversal in write_file / patch / batch_patch ─── + +func TestWriteFile_SymlinkDirectoryTraversal(t *testing.T) { + skipIfSymlinksUnsupported(t) + + cwd := t.TempDir() + outsideDir := t.TempDir() + outsideFile := filepath.Join(outsideDir, "escaped.txt") + + link := filepath.Join(cwd, "link") + if err := os.Symlink(outsideDir, link); err != nil { + t.Fatalf("create symlink: %v", err) + } + + tool := &writeFileTool{restrictToCWD: true} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"content":"escaped"}`, filepath.Join(link, "escaped.txt"))) + var r struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + + if r.Success { + t.Fatalf("write_file should reject symlink directory traversal, but succeeded") + } + if _, err := os.Stat(outsideFile); !os.IsNotExist(err) { + t.Fatalf("write_file escaped CWD via directory symlink; file exists at %s", outsideFile) + } +} + +func TestPatch_SymlinkDirectoryTraversal(t *testing.T) { + skipIfSymlinksUnsupported(t) + + cwd := t.TempDir() + outsideDir := t.TempDir() + outsideFile := filepath.Join(outsideDir, "target.txt") + os.WriteFile(outsideFile, []byte("old content"), 0644) + + link := filepath.Join(cwd, "link") + if err := os.Symlink(outsideDir, link); err != nil { + t.Fatalf("create symlink: %v", err) + } + + tool := &patchTool{restrictToCWD: true} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"old_string":"old content","new_string":"new content"}`, filepath.Join(link, "target.txt"))) + var r struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + + if r.Success { + t.Fatalf("patch should reject symlink directory traversal, but succeeded") + } + data, _ := os.ReadFile(outsideFile) + if string(data) != "old content" { + t.Fatalf("patch escaped CWD and modified outside file: %q", string(data)) + } +} + +func TestBatchPatch_SymlinkDirectoryTraversal(t *testing.T) { + skipIfSymlinksUnsupported(t) + + cwd := t.TempDir() + outsideDir := t.TempDir() + outsideFile := filepath.Join(outsideDir, "target.txt") + os.WriteFile(outsideFile, []byte("old content"), 0644) + + link := filepath.Join(cwd, "link") + if err := os.Symlink(outsideDir, link); err != nil { + t.Fatalf("create symlink: %v", err) + } + + tool := &batchPatchTool{restrictToCWD: true} + args := fmt.Sprintf(`{"patches":[{"path":%q,"old_string":"old content","new_string":"new content"}]}`, filepath.Join(link, "target.txt")) + result := callJSON(t, tool, args) + var r struct { + Results []struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + + if len(r.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(r.Results)) + } + if r.Results[0].Success { + t.Fatalf("batch_patch should reject symlink directory traversal, but succeeded") + } + data, _ := os.ReadFile(outsideFile) + if string(data) != "old content" { + t.Fatalf("batch_patch escaped CWD and modified outside file: %q", string(data)) + } +} + +// ── 2. glob must not follow symlinks in simple mode ───────────────────── + +func TestGlob_SymlinkFileTraversal(t *testing.T) { + skipIfSymlinksUnsupported(t) + + cwd := t.TempDir() + outsideFile := filepath.Join(t.TempDir(), "secret.txt") + os.WriteFile(outsideFile, []byte("secret"), 0644) + + link := filepath.Join(cwd, "link.txt") + if err := os.Symlink(outsideFile, link); err != nil { + t.Fatalf("create symlink: %v", err) + } + + tool := &globTool{} + result := callJSON(t, tool, `{"pattern":"*.txt","path":"`+cwd+`"}`) + var r struct { + Matches []globMatch `json:"matches"` + } + mustUnmarshal(t, result, &r) + + for _, m := range r.Matches { + if m.Path == link || strings.HasPrefix(m.Path, filepath.Dir(outsideFile)) { + t.Fatalf("glob followed file symlink to outside path: %s", m.Path) + } + } +} + +func TestGlob_SymlinkDirectoryTraversal(t *testing.T) { + skipIfSymlinksUnsupported(t) + + cwd := t.TempDir() + outsideDir := t.TempDir() + os.WriteFile(filepath.Join(outsideDir, "secret.txt"), []byte("secret"), 0644) + + link := filepath.Join(cwd, "link") + if err := os.Symlink(outsideDir, link); err != nil { + t.Fatalf("create symlink: %v", err) + } + + tool := &globTool{} + result := callJSON(t, tool, `{"pattern":"*","path":"`+cwd+`"}`) + var r struct { + Matches []globMatch `json:"matches"` + } + mustUnmarshal(t, result, &r) + + for _, m := range r.Matches { + if m.Path == link || strings.HasPrefix(m.Path, outsideDir) { + t.Fatalf("glob listed directory symlink that points outside: %s", m.Path) + } + } +} + +// ── 3. batch_read must wrap content with untrusted_content ─────────────── + +func TestBatchRead_WrapsContent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + os.WriteFile(path, []byte("hello world"), 0644) + + tool := &batchReadTool{} + result := callJSON(t, tool, `{"files":[{"path":"`+path+`"}]}`) + var r struct { + Results []batchReadFileResult `json:"results"` + } + mustUnmarshal(t, result, &r) + + if len(r.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(r.Results)) + } + if !strings.HasPrefix(r.Results[0].Content, " 1024*1024 { + t.Fatalf("read_file returned %d bytes, expected cap at 1 MiB", len(body)) + } +} + +func TestBatchRead_CapsTotalSize(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "big.txt") + + var lines []string + for i := 0; i < 10; i++ { + lines = append(lines, strings.Repeat("x", 500*1024)) + } + os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) + + tool := &batchReadTool{} + result := callJSON(t, tool, `{"files":[{"path":"`+path+`","limit":10}]}`) + var r struct { + Results []batchReadFileResult `json:"results"` + } + mustUnmarshal(t, result, &r) + + if len(r.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(r.Results)) + } + body := r.Results[0].Content + if len(body) > 1024*1024 { + t.Fatalf("batch_read returned %d bytes, expected cap at 1 MiB", len(body)) + } +} + +// ── 5. write_file must preserve original file mode on overwrite ────────── + +func TestWriteFile_PreservesFileMode(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "secret.txt") + // Start with a specific mode (e.g., group-readable). write_file's temp+rename + // currently drops it to the temp-file default (0600), leaking/changing mode. + if err := os.WriteFile(path, []byte("old"), 0640); err != nil { + t.Fatalf("write initial file: %v", err) + } + + tool := &writeFileTool{} + result := callJSON(t, tool, fmt.Sprintf(`{"path":%q,"content":"new"}`, path)) + var r struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + } + mustUnmarshal(t, result, &r) + if !r.Success { + t.Fatalf("write_file failed: %s", r.Error) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat file: %v", err) + } + if info.Mode().Perm() != 0640 { + t.Fatalf("write_file changed mode from 0640 to %04o", info.Mode().Perm()) + } +}