diff --git a/cmd/odek/browser_tool.go b/cmd/odek/browser_tool.go index 9e457fb..98ae972 100644 --- a/cmd/odek/browser_tool.go +++ b/cmd/odek/browser_tool.go @@ -55,6 +55,7 @@ type browserState struct { // ── Browser Tool ────────────────────────────────────────────────────── type browserTool struct { + ctxTool state *browserState client *http.Client dangerousConfig danger.DangerousConfig @@ -196,7 +197,7 @@ func (t *browserTool) doNavigate(rawURL string) (string, error) { return jsonError(fmt.Sprintf("invalid URL %q: must start with http:// or https://", rawURL)) } - req, err := http.NewRequest("GET", rawURL, nil) + req, err := http.NewRequestWithContext(t.toolCtx(), "GET", rawURL, nil) if err != nil { return jsonError(fmt.Sprintf("cannot create request: %v", err)) } diff --git a/cmd/odek/perf_tools.go b/cmd/odek/perf_tools.go index a77e412..457aa77 100644 --- a/cmd/odek/perf_tools.go +++ b/cmd/odek/perf_tools.go @@ -412,6 +412,7 @@ func (t *parallelShellTool) runOne(cmd parallelShellCmd) parallelShellEntry { const maxHTTPBatchURLs = 10 type httpBatchTool struct { + ctxTool dangerousConfig danger.DangerousConfig client *http.Client } @@ -556,7 +557,7 @@ func (t *httpBatchTool) fetchOne(r httpBatchReq) httpBatchEntry { } entry := httpBatchEntry{URL: r.URL} - httpReq, err := http.NewRequest(method, r.URL, nil) + httpReq, err := http.NewRequestWithContext(t.toolCtx(), method, r.URL, nil) if err != nil { entry.Error = err.Error() return entry diff --git a/cmd/odek/shell.go b/cmd/odek/shell.go index 8bb9c92..849c885 100644 --- a/cmd/odek/shell.go +++ b/cmd/odek/shell.go @@ -2,15 +2,26 @@ package main import ( "bytes" + "context" "encoding/json" "fmt" "os/exec" "strings" "sync" + "syscall" + "time" "github.com/BackendStack21/odek/internal/danger" ) +// defaultShellTimeout bounds a single shell command. It is deliberately +// generous — the goal is to stop a genuinely stuck command (a network read +// that never returns, an interactive prompt, an infinite loop) from wedging +// the agent forever, NOT to kill legitimate long builds or test suites. When +// the agent context is cancelled (Ctrl-C, turn timeout) the command is killed +// immediately regardless of this backstop. +const defaultShellTimeout = 30 * time.Minute + // shellTool is odek's built-in tool that lets the agent run shell commands. // // This is the only built-in tool — it's enough for reading files, running @@ -61,6 +72,13 @@ type shellTool struct { // ttyPath is the path to the terminal device for approval prompts. // Overridden in tests to mock user input. Only used when approver is nil. ttyPath string + + // ctxTool provides SetContext/toolCtx so cancelling the agent context + // (Ctrl-C, turn timeout) kills the running command. + ctxTool + + // timeout bounds a single command. Zero falls back to defaultShellTimeout. + timeout time.Duration } func (t *shellTool) Name() string { return "shell" } @@ -113,13 +131,52 @@ func (t *shellTool) Call(args string) (string, error) { return "", err } - cmd := t.buildCmd(input.Command) + // Bound execution: cancel with the agent context (Ctrl-C / turn timeout) + // and a generous backstop timeout so a stuck command can never wedge the + // agent forever. Note: in sandbox mode this kills the host-side + // `docker exec` client, which unblocks the agent, but Docker does not + // propagate the signal to the in-container process — that lingers until the + // container is torn down at session end. + base := t.toolCtx() + timeout := t.timeout + if timeout <= 0 { + timeout = defaultShellTimeout + } + ctx, cancel := context.WithTimeout(base, timeout) + defer cancel() + + cmd := t.buildCmd(ctx, input.Command) + // Run the command in its own process group and, on cancel/timeout, kill the + // WHOLE group — not just the `sh` leader. `sh -c ""` may fork children + // (e.g. `sleep`); killing only the leader leaves them alive holding the + // output pipes, so Run() would block until WaitDelay. Signalling the group + // (negative pid) tears the whole tree down at once. + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + cmd.Cancel = func() error { + if cmd.Process != nil { + // Best-effort group kill; ignore ESRCH if it already exited. + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + } + return nil + } + // WaitDelay is a backstop in case a process somehow outlives the group kill. + cmd.WaitDelay = 3 * time.Second var outBuf, errBuf bytes.Buffer cmd.Stdout = &outBuf cmd.Stderr = &errBuf err := cmd.Run() + + // Surface cancellation/timeout as a clear, actionable error rather than an + // opaque "signal: killed". + if ctxErr := ctx.Err(); ctxErr != nil { + if ctxErr == context.DeadlineExceeded { + return "", fmt.Errorf("shell: command timed out after %s (still running? it was killed): %s", timeout, input.Command) + } + return "", fmt.Errorf("shell: command cancelled: %s", input.Command) + } + output := strings.TrimSpace(outBuf.String()) stderrStr := strings.TrimSpace(errBuf.String()) if stderrStr != "" { @@ -200,9 +257,9 @@ func (t *shellTool) promptUser(cmd, description string) error { // // When running on the host (default), the command executes via "sh -c" // in odek's current working directory. -func (t *shellTool) buildCmd(command string) *exec.Cmd { +func (t *shellTool) buildCmd(ctx context.Context, command string) *exec.Cmd { if t.containerName != "" { - return exec.Command("docker", "exec", "-w", "/workspace", t.containerName, "sh", "-c", command) + return exec.CommandContext(ctx, "docker", "exec", "-w", "/workspace", t.containerName, "sh", "-c", command) } - return exec.Command("sh", "-c", command) + return exec.CommandContext(ctx, "sh", "-c", command) } diff --git a/cmd/odek/shell_test.go b/cmd/odek/shell_test.go index 021c46c..34dd162 100644 --- a/cmd/odek/shell_test.go +++ b/cmd/odek/shell_test.go @@ -1,11 +1,13 @@ package main import ( + "context" "encoding/json" "os" "os/exec" "strings" "testing" + "time" ) func TestShellTool_Name(t *testing.T) { @@ -15,6 +17,60 @@ func TestShellTool_Name(t *testing.T) { } } +// TestShellTool_Timeout verifies a stuck command can no longer wedge the agent: +// a tiny per-tool timeout kills the command and Call returns promptly with a +// clear timeout error instead of blocking forever. +func TestShellTool_Timeout(t *testing.T) { + st := &shellTool{timeout: 200 * time.Millisecond} + done := make(chan struct{}) + var out string + var err error + go func() { + out, err = st.Call(`{"command":"sleep 30 | cat"}`) + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Call did not return after the command timeout — agent would hang") + } + if err == nil { + t.Fatalf("expected a timeout error, got output %q", out) + } + if !strings.Contains(err.Error(), "timed out") { + t.Errorf("error should mention the timeout, got: %v", err) + } +} + +// TestShellTool_ContextCancellation verifies Ctrl-C / turn cancellation kills a +// running command immediately via the agent context. +func TestShellTool_ContextCancellation(t *testing.T) { + st := &shellTool{} + ctx, cancel := context.WithCancel(context.Background()) + st.SetContext(ctx) + + done := make(chan struct{}) + var err error + go func() { + _, err = st.Call(`{"command":"sleep 30 | cat"}`) + close(done) + }() + time.Sleep(100 * time.Millisecond) + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Call did not return after context cancellation — Ctrl-C would not work") + } + if err == nil { + t.Fatal("expected a cancellation error") + } + if !strings.Contains(err.Error(), "cancelled") { + t.Errorf("error should mention cancellation, got: %v", err) + } +} + func TestShellTool_Description(t *testing.T) { st := &shellTool{} desc := st.Description() @@ -130,7 +186,7 @@ func TestShellTool_Call_StdoutAndStderr(t *testing.T) { func TestShellTool_BuildCmd_Local(t *testing.T) { st := &shellTool{} - cmd := st.buildCmd("echo test") + cmd := st.buildCmd(context.Background(), "echo test") args := cmd.Args if args[0] != "sh" || args[1] != "-c" || args[2] != "echo test" { t.Errorf("local cmd args = %v, want [sh -c 'echo test']", args) @@ -139,7 +195,7 @@ func TestShellTool_BuildCmd_Local(t *testing.T) { func TestShellTool_BuildCmd_Docker(t *testing.T) { st := &shellTool{containerName: "odek-12345"} - cmd := st.buildCmd("echo test") + cmd := st.buildCmd(context.Background(), "echo test") args := cmd.Args expected := []string{"docker", "exec", "-w", "/workspace", "odek-12345", "sh", "-c", "echo test"} if !stringSlicesEqual(args, expected) { @@ -303,7 +359,7 @@ func TestShellTool_CheckApproval(t *testing.T) { func TestShellTool_BuildCmd_Default(t *testing.T) { st := &shellTool{} - cmd := st.buildCmd("echo hello") + cmd := st.buildCmd(context.Background(), "echo hello") if cmd.Args[0] != "sh" { t.Errorf("expected sh, got %s", cmd.Args[0]) } diff --git a/cmd/odek/toolctx.go b/cmd/odek/toolctx.go new file mode 100644 index 0000000..c5d65a0 --- /dev/null +++ b/cmd/odek/toolctx.go @@ -0,0 +1,40 @@ +package main + +import ( + "context" + "sync" +) + +// ctxTool is embedded by tools that support agent-context cancellation. The +// agent loop calls SetContext on any tool implementing it (see internal/loop) +// right before invoking the tool, so cancelling the agent context — Ctrl-C, a +// turn timeout — interrupts the tool's in-flight network request or subprocess +// instead of letting it run to completion (or hang) unsupervised. +// +// The mutex matters: when the LLM emits two calls to the SAME tool in one +// turn, the loop runs them in parallel goroutines and calls SetContext on the +// shared instance from each. Without synchronisation that is a data race on the +// context field even though the value is identical for the turn. +type ctxTool struct { + mu sync.Mutex + ctx context.Context +} + +// SetContext records the agent context for the next Call. Safe for concurrent +// use by parallel invocations of the same tool instance. +func (c *ctxTool) SetContext(ctx context.Context) { + c.mu.Lock() + c.ctx = ctx + c.mu.Unlock() +} + +// toolCtx returns the recorded agent context, or context.Background() if none +// was set (e.g. tools invoked directly in tests or outside the agent loop). +func (c *ctxTool) toolCtx() context.Context { + c.mu.Lock() + defer c.mu.Unlock() + if c.ctx == nil { + return context.Background() + } + return c.ctx +} diff --git a/cmd/odek/toolctx_test.go b/cmd/odek/toolctx_test.go new file mode 100644 index 0000000..14b5177 --- /dev/null +++ b/cmd/odek/toolctx_test.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "strings" + "sync" + "testing" +) + +func TestCtxTool_DefaultsToBackground(t *testing.T) { + var c ctxTool + if c.toolCtx() != context.Background() { + t.Error("unset ctxTool should return context.Background()") + } +} + +func TestCtxTool_SetAndGet(t *testing.T) { + var c ctxTool + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c.SetContext(ctx) + if c.toolCtx() != ctx { + t.Error("toolCtx should return the context set via SetContext") + } +} + +// TestCtxTool_ConcurrentSetContext mirrors the loop calling SetContext on a +// shared tool instance from parallel goroutines — it must be race-free. +func TestCtxTool_ConcurrentSetContext(t *testing.T) { + var c ctxTool + ctx := context.Background() + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c.SetContext(ctx) + _ = c.toolCtx() + }() + } + wg.Wait() +} + +// TestHTTPBatch_ContextCancelled verifies the propagated context aborts the +// fetch: a cancelled context yields an error entry instead of a real request. +func TestHTTPBatch_ContextCancelled(t *testing.T) { + tool := newHTTPBatchTool(allowAllDanger()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already cancelled + tool.SetContext(ctx) + + result := callJSON(t, tool, `{"requests":[{"url":"http://example.com/"}]}`) + var r struct { + Results []struct { + Error string `json:"error"` + } `json:"results"` + } + mustUnmarshal(t, result, &r) + if len(r.Results) != 1 { + t.Fatalf("Results = %d, want 1", len(r.Results)) + } + if r.Results[0].Error == "" { + t.Fatal("expected an error from the cancelled context") + } + if !strings.Contains(r.Results[0].Error, "context canceled") { + t.Errorf("error should reflect cancellation, got: %s", r.Results[0].Error) + } +} diff --git a/cmd/odek/transcribe_tool.go b/cmd/odek/transcribe_tool.go index 7d1832d..e945bf6 100644 --- a/cmd/odek/transcribe_tool.go +++ b/cmd/odek/transcribe_tool.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "os" @@ -20,7 +21,7 @@ import ( // Returns the path to the WAV file (may be the same as input if already WAV/MP3/FLAC // or if ffmpeg is unavailable/fails — in which case whisper will produce its own error). // The caller must remove the returned path if it differs from the input path. -func convertToWAV(srcPath string) string { +func convertToWAV(ctx context.Context, srcPath string) string { ext := strings.ToLower(filepath.Ext(srcPath)) // whisper.cpp supports WAV, MP3, FLAC natively via dr_wav/dr_mp3/dr_flac. switch ext { @@ -35,7 +36,7 @@ func convertToWAV(srcPath string) string { // Convert to WAV using ffmpeg — best-effort, fall through on failure. dstPath := srcPath + ".wav" - cmd := exec.Command("ffmpeg", "-y", "-i", srcPath, "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", dstPath) + cmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", srcPath, "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", dstPath) if err := cmd.Run(); err != nil { // If ffmpeg fails (corrupt file, unsupported codec, etc.), // just pass the original path — whisper will produce its own error. @@ -138,6 +139,7 @@ Set "model" in transcription config to change which model is expected.`, // ═════════════════════════════════════════════════════════════════════════ type transcribeTool struct { + ctxTool dangerousConfig danger.DangerousConfig transcriptionCfg config.TranscriptionConfig } @@ -224,7 +226,7 @@ func (t *transcribeTool) Call(argsJSON string) (result string, err error) { f.Close() // Convert to WAV if needed (whisper.cpp doesn't support OGG Opus natively). - wavPath := convertToWAV(args.Path) + wavPath := convertToWAV(t.toolCtx(), args.Path) cleanup := func() { if wavPath != args.Path { os.Remove(wavPath) @@ -263,7 +265,7 @@ func (t *transcribeTool) Call(argsJSON string) (result string, err error) { args2 = append(args2, "--language", lang) } - cmd := exec.Command(binary, args2...) + cmd := exec.CommandContext(t.toolCtx(), binary, args2...) output, err := cmd.Output() if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { diff --git a/cmd/odek/vision_tool.go b/cmd/odek/vision_tool.go index 20b1ab3..73e0bfe 100644 --- a/cmd/odek/vision_tool.go +++ b/cmd/odek/vision_tool.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "os" @@ -89,7 +90,7 @@ Or set models_dir in the vision config.`, mp, dir, dir) // extractVideoFrames samples n evenly-spaced frames from videoPath into a // temporary directory. Returns paths to the JPEG frame files; caller must // remove the directory (filepath.Dir of the first path). -func extractVideoFrames(videoPath string, n int) ([]string, error) { +func extractVideoFrames(ctx context.Context, videoPath string, n int) ([]string, error) { if _, err := exec.LookPath("ffmpeg"); err != nil { return nil, fmt.Errorf("ffmpeg not found — required for video frame extraction") } @@ -98,7 +99,7 @@ func extractVideoFrames(videoPath string, n int) ([]string, error) { } // Get duration with ffprobe - out, err := exec.Command("ffprobe", + out, err := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "csv=p=0", @@ -124,7 +125,7 @@ func extractVideoFrames(videoPath string, n int) ([]string, error) { for i := 1; i <= n; i++ { ts := interval * float64(i) out := filepath.Join(tmpDir, fmt.Sprintf("frame_%02d.jpg", i)) - cmd := exec.Command("ffmpeg", + cmd := exec.CommandContext(ctx, "ffmpeg", "-ss", fmt.Sprintf("%.3f", ts), "-i", videoPath, "-frames:v", "1", @@ -146,7 +147,7 @@ func extractVideoFrames(videoPath string, n int) ([]string, error) { // runLlamaMtmd calls llama-mtmd-cli in single-turn mode with one or more images // and returns the trimmed stdout response. -func runLlamaMtmd(binary, modelPath, mmprojPath, prompt string, imagePaths []string) (string, error) { +func runLlamaMtmd(ctx context.Context, binary, modelPath, mmprojPath, prompt string, imagePaths []string) (string, error) { args := []string{ "-m", modelPath, "--mmproj", mmprojPath, @@ -162,7 +163,7 @@ func runLlamaMtmd(binary, modelPath, mmprojPath, prompt string, imagePaths []str args = append(args, "--image", img) } - cmd := exec.Command(binary, args...) + cmd := exec.CommandContext(ctx, binary, args...) output, err := cmd.Output() if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { @@ -179,6 +180,7 @@ func runLlamaMtmd(binary, modelPath, mmprojPath, prompt string, imagePaths []str // ═════════════════════════════════════════════════════════════════════════ type visionTool struct { + ctxTool dangerousConfig danger.DangerousConfig visionCfg config.VisionConfig } @@ -277,7 +279,7 @@ func (t *visionTool) Call(argsJSON string) (result string, err error) { } func (t *visionTool) analyzeImage(binary, modelPath, mmprojPath, imgPath, prompt, source string) (string, error) { - desc, err := runLlamaMtmd(binary, modelPath, mmprojPath, prompt, []string{imgPath}) + desc, err := runLlamaMtmd(t.toolCtx(), binary, modelPath, mmprojPath, prompt, []string{imgPath}) if err != nil { return jsonResult(visionResult{Error: err.Error()}) } @@ -294,7 +296,7 @@ func (t *visionTool) analyzeVideo(binary, modelPath, mmprojPath, videoPath, prom n = 8 } - frames, err := extractVideoFrames(videoPath, n) + frames, err := extractVideoFrames(t.toolCtx(), videoPath, n) if err != nil { return jsonResult(visionResult{Error: err.Error()}) } @@ -304,7 +306,7 @@ func (t *visionTool) analyzeVideo(binary, modelPath, mmprojPath, videoPath, prom "These are %d frames sampled evenly from a video. %s", len(frames), prompt, ) - desc, err := runLlamaMtmd(binary, modelPath, mmprojPath, videoPrompt, frames) + desc, err := runLlamaMtmd(t.toolCtx(), binary, modelPath, mmprojPath, videoPrompt, frames) if err != nil { return jsonResult(visionResult{Error: err.Error()}) } diff --git a/cmd/odek/web_search_tool.go b/cmd/odek/web_search_tool.go index 908c033..b155034 100644 --- a/cmd/odek/web_search_tool.go +++ b/cmd/odek/web_search_tool.go @@ -34,6 +34,7 @@ var ( // ═════════════════════════════════════════════════════════════════════════ type webSearchTool struct { + ctxTool dangerousConfig danger.DangerousConfig cfg config.WebSearchConfig client *http.Client @@ -235,7 +236,7 @@ func (t *webSearchTool) query(query, category string) (*searxngResponse, error) } endpoint.RawQuery = q.Encode() - req, err := http.NewRequest(http.MethodGet, endpoint.String(), nil) + req, err := http.NewRequestWithContext(t.toolCtx(), http.MethodGet, endpoint.String(), nil) if err != nil { return nil, fmt.Errorf("build request: %v", err) } diff --git a/internal/fsatomic/fsatomic.go b/internal/fsatomic/fsatomic.go new file mode 100644 index 0000000..e9d37cf --- /dev/null +++ b/internal/fsatomic/fsatomic.go @@ -0,0 +1,70 @@ +// Package fsatomic provides a crash-durable atomic file write. +// +// The common "write a temp file then rename over the target" idiom gives +// atomicity (a reader sees either the old or the new file, never a torn one), +// but NOT durability: without an fsync, a power loss or kernel crash can land +// the rename in the directory while the file's data is still only in the page +// cache, leaving an empty or truncated file after reboot. For data the agent +// can't reconstruct — conversation sessions, extracted memories — that is silent +// data loss. +// +// WriteFile closes the gap: it fsyncs the file data before the rename and +// fsyncs the parent directory after, so a successful return means the bytes are +// durably on disk. It also uses a unique temp name, so two concurrent writers +// to the same target can't clobber each other's temp file. +package fsatomic + +import ( + "fmt" + "os" + "path/filepath" +) + +// WriteFile atomically and durably writes data to path with the given perm. +// On success the bytes are fsynced to disk and the rename is durable. +func WriteFile(path string, data []byte, perm os.FileMode) (err error) { + dir := filepath.Dir(path) + + f, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*") + if err != nil { + return fmt.Errorf("fsatomic: create temp: %w", err) + } + tmp := f.Name() + // Remove the temp file on any failure before the rename succeeds. + defer func() { + if tmp != "" { + os.Remove(tmp) + } + }() + + if _, err := f.Write(data); err != nil { + f.Close() + return fmt.Errorf("fsatomic: write: %w", err) + } + if err := f.Chmod(perm); err != nil { + f.Close() + return fmt.Errorf("fsatomic: chmod: %w", err) + } + // Flush the file's data to disk before exposing it via the rename. + if err := f.Sync(); err != nil { + f.Close() + return fmt.Errorf("fsatomic: fsync temp: %w", err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("fsatomic: close temp: %w", err) + } + + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("fsatomic: rename: %w", err) + } + tmp = "" // renamed — no longer ours to remove + + // Make the rename itself durable. Best-effort: some filesystems don't + // support directory fsync, and the data is already synced, so a failure + // here doesn't corrupt anything — it just weakens the crash guarantee. + if d, derr := os.Open(dir); derr == nil { + _ = d.Sync() + d.Close() + } + return nil +} diff --git a/internal/fsatomic/fsatomic_test.go b/internal/fsatomic/fsatomic_test.go new file mode 100644 index 0000000..307ab57 --- /dev/null +++ b/internal/fsatomic/fsatomic_test.go @@ -0,0 +1,111 @@ +package fsatomic + +import ( + "os" + "path/filepath" + "sync" + "testing" +) + +func TestWriteFile_WritesContentAndPerm(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data.json") + + if err := WriteFile(path, []byte("hello"), 0600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "hello" { + t.Errorf("content = %q, want %q", got, "hello") + } + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("perm = %v, want 0600", info.Mode().Perm()) + } +} + +func TestWriteFile_Overwrites(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data") + if err := WriteFile(path, []byte("first"), 0644); err != nil { + t.Fatal(err) + } + if err := WriteFile(path, []byte("second"), 0644); err != nil { + t.Fatal(err) + } + got, _ := os.ReadFile(path) + if string(got) != "second" { + t.Errorf("content = %q, want %q", got, "second") + } +} + +// TestWriteFile_LeavesNoTempOnSuccess verifies the unique temp file is renamed +// away (no litter), which also confirms the temp naming doesn't collide with +// the target. +func TestWriteFile_LeavesNoTempOnSuccess(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data") + if err := WriteFile(path, []byte("x"), 0644); err != nil { + t.Fatal(err) + } + entries, _ := os.ReadDir(dir) + if len(entries) != 1 || entries[0].Name() != "data" { + var names []string + for _, e := range entries { + names = append(names, e.Name()) + } + t.Errorf("expected only [data], got %v", names) + } +} + +// TestWriteFile_ConcurrentSameTarget verifies concurrent writers to the same +// path don't clobber each other's temp file (the old fixed-".tmp" pattern +// could) — the final content must be a complete one of the writes, never torn. +func TestWriteFile_ConcurrentSameTarget(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data") + payloads := []string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + "cccccccccccccccccccccccccccccc", + } + var wg sync.WaitGroup + for _, p := range payloads { + wg.Add(1) + go func(data string) { + defer wg.Done() + for i := 0; i < 20; i++ { + if err := WriteFile(path, []byte(data), 0644); err != nil { + t.Errorf("WriteFile: %v", err) + return + } + } + }(p) + } + wg.Wait() + + got, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + ok := false + for _, p := range payloads { + if string(got) == p { + ok = true + } + } + if !ok { + t.Errorf("final content %q is torn — not a complete write", got) + } + // No temp files left behind. + entries, _ := os.ReadDir(dir) + if len(entries) != 1 { + t.Errorf("expected only the target file, got %d entries", len(entries)) + } +} diff --git a/internal/llm/client.go b/internal/llm/client.go index b9dfafa..62dfe3c 100644 --- a/internal/llm/client.go +++ b/internal/llm/client.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "time" @@ -209,30 +210,12 @@ func (c *Client) SimpleCall(ctx context.Context, systemPrompt, userPrompt string return "", fmt.Errorf("llm: marshal request: %w", err) } - url := c.BaseURL + "/chat/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBytes)) - if err != nil { - return "", fmt.Errorf("llm: create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.APIKey) - - resp, err := c.http.Do(req) + // Share the main loop's retry/backoff so a transient blip doesn't abort + // these best-effort secondary calls (skill matching, memory summaries, + // episode extraction, session titles). + respBytes, err := c.postChatWithRetry(ctx, reqBytes) if err != nil { - return "", fmt.Errorf("llm: %w", err) - } - defer resp.Body.Close() - - respBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize+1)) - if err != nil { - return "", fmt.Errorf("llm: read response: %w", err) - } - if len(respBytes) > maxResponseSize { - return "", fmt.Errorf("llm: response exceeds maximum size (%d bytes)", maxResponseSize) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("llm: %s (status %d)", resp.Status, resp.StatusCode) + return "", err } var raw struct { @@ -300,21 +283,36 @@ func (c *Client) Call(ctx context.Context, messages []Message, systemBlocks []Sy return nil, fmt.Errorf("llm: marshal request: %w", err) } + respBytes, err := c.postChatWithRetry(ctx, reqBytes) + if err != nil { + return nil, err + } + return parseResponse(respBytes) +} + +// postChatWithRetry POSTs reqBytes to /chat/completions and returns the raw 200 +// response body, retrying transient network errors and retryable HTTP statuses +// (429, 502, 503, 504) with exponential backoff. Shared by every chat call so +// the main loop and the lightweight secondary calls (SimpleCall) get identical +// resilience. Respects ctx cancellation during the backoff sleep. +func (c *Client) postChatWithRetry(ctx context.Context, reqBytes []byte) ([]byte, error) { url := c.BaseURL + "/chat/completions" const maxRetries = 3 var lastErr error + var wait time.Duration // how long to sleep before the next attempt for attempt := 0; attempt <= maxRetries; attempt++ { if attempt > 0 { - // Exponential backoff: 1s, 2s, 4s - backoff := time.Duration(1<<(attempt-1)) * time.Second select { case <-ctx.Done(): return nil, ctx.Err() - case <-time.After(backoff): + case <-time.After(wait): } } + // Default backoff for the next attempt if this one fails: 1s, 2s, 4s. + // A Retry-After header on a 429/503 overrides it below. + wait = time.Duration(1< 0 { + wait = ra + } continue } return nil, lastErr } - return parseResponse(respBytes) + return respBytes, nil } return nil, fmt.Errorf("llm: retry exhausted (%d attempts): %w", maxRetries+1, lastErr) } +// maxRetryAfter caps how long we'll honor a server's Retry-After. A pathological +// or hostile value (e.g. "Retry-After: 86400") must not wedge a turn for hours; +// ctx cancellation can still break the wait sooner. +const maxRetryAfter = 60 * time.Second + +// parseRetryAfter interprets an HTTP Retry-After header, which is either an +// integer number of seconds or an HTTP-date. Returns 0 when absent or +// unparseable (callers then fall back to exponential backoff). The result is +// capped at maxRetryAfter. +func parseRetryAfter(v string) time.Duration { + v = strings.TrimSpace(v) + if v == "" { + return 0 + } + var d time.Duration + if secs, err := strconv.Atoi(v); err == nil { + if secs <= 0 { + return 0 + } + d = time.Duration(secs) * time.Second + } else if t, err := http.ParseTime(v); err == nil { + d = time.Until(t) + if d <= 0 { + return 0 + } + } else { + return 0 + } + if d > maxRetryAfter { + d = maxRetryAfter + } + return d +} + // isRetryableHTTPStatus returns true for HTTP status codes that indicate // a transient error safe to retry after a backoff. func isRetryableHTTPStatus(code int) bool { diff --git a/internal/llm/retry_test.go b/internal/llm/retry_test.go index f97d536..933cb10 100644 --- a/internal/llm/retry_test.go +++ b/internal/llm/retry_test.go @@ -9,6 +9,90 @@ import ( "time" ) +func TestParseRetryAfter(t *testing.T) { + if d := parseRetryAfter("2"); d != 2*time.Second { + t.Errorf("parseRetryAfter(\"2\") = %v, want 2s", d) + } + if d := parseRetryAfter(" 5 "); d != 5*time.Second { + t.Errorf("parseRetryAfter trims and parses, got %v", d) + } + if d := parseRetryAfter(""); d != 0 { + t.Errorf("empty header → 0, got %v", d) + } + if d := parseRetryAfter("garbage"); d != 0 { + t.Errorf("unparseable → 0, got %v", d) + } + if d := parseRetryAfter("0"); d != 0 { + t.Errorf("zero/negative → 0, got %v", d) + } + // Capped at maxRetryAfter. + if d := parseRetryAfter("100000"); d != maxRetryAfter { + t.Errorf("huge value should cap at %v, got %v", maxRetryAfter, d) + } +} + +// TestClient_Call_HonorsRetryAfter verifies a 429 with a Retry-After header is +// retried (rather than failed) and ultimately succeeds. The 1s value keeps the +// test fast while exercising the header path. +func TestClient_Call_HonorsRetryAfter(t *testing.T) { + var callCount atomic.Int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if int(callCount.Add(1)) == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"slow down"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"choices":[{"message":{"content":"ok"}}]}`)) + })) + defer ts.Close() + + c := New(ts.URL, "key", "model", "", 0, 10*time.Second) + start := time.Now() + result, err := c.Call(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Content != "ok" { + t.Errorf("content = %q, want ok", result.Content) + } + if elapsed := time.Since(start); elapsed < 900*time.Millisecond { + t.Errorf("expected to wait ~1s for Retry-After, only waited %v", elapsed) + } +} + +// TestClient_SimpleCall_RetryOn429 verifies the lightweight secondary calls +// share the main loop's retry resilience: a transient 429 no longer aborts a +// skill-match / memory / title call on the first failure. +func TestClient_SimpleCall_RetryOn429(t *testing.T) { + var callCount atomic.Int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := int(callCount.Add(1)) + if count <= 2 { + w.WriteHeader(http.StatusTooManyRequests) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"error":{"message":"Rate limited"}}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"choices":[{"message":{"content":"assessed"}}]}`)) + })) + defer ts.Close() + + c := New(ts.URL, "key", "model", "", 0, 10*time.Second) + out, err := c.SimpleCall(context.Background(), "sys", "user") + if err != nil { + t.Fatalf("unexpected error after retries: %v", err) + } + if out != "assessed" { + t.Errorf("content = %q, want %q", out, "assessed") + } + if callCount.Load() != 3 { + t.Errorf("call count = %d, want 3 (SimpleCall should retry)", callCount.Load()) + } +} + func TestClient_Call_RetryOn429(t *testing.T) { var callCount atomic.Int32 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/memory/episodes.go b/internal/memory/episodes.go index 01a2ad6..dcc017e 100644 --- a/internal/memory/episodes.go +++ b/internal/memory/episodes.go @@ -14,6 +14,7 @@ import ( "github.com/BackendStack21/go-vector/pkg/vector" "github.com/BackendStack21/odek/internal/embedding" + "github.com/BackendStack21/odek/internal/fsatomic" "github.com/BackendStack21/odek/internal/session" ) @@ -224,7 +225,7 @@ func (e *EpisodeStore) writeLocked(sessionID, summary string, turns int, prov Ep // Write the summary file. path := filepath.Join(e.dir, sessionID+".md") - if err := os.WriteFile(path, []byte(summary), 0600); err != nil { + if err := fsatomic.WriteFile(path, []byte(summary), 0600); err != nil { return events, fmt.Errorf("memory: write episode: %w", err) } @@ -661,19 +662,14 @@ func trustRank(p EpisodeProvenance) int { // Invalidates the in-memory cache after a successful write so the next // ReadIndex call picks up the new data. func (e *EpisodeStore) writeIndex(idx []EpisodeMeta) error { - // Write to temp + rename for atomicity + // Write atomically and durably (temp → fsync → rename → dir fsync). idxPath := filepath.Join(e.dir, episodeIndexFile) - tmpPath := idxPath + ".tmp" data, err := json.MarshalIndent(idx, "", " ") if err != nil { return fmt.Errorf("memory: marshal index: %w", err) } - if err := os.WriteFile(tmpPath, data, 0600); err != nil { - return err - } - if err := os.Rename(tmpPath, idxPath); err != nil { - os.Remove(tmpPath) // best-effort cleanup + if err := fsatomic.WriteFile(idxPath, data, 0600); err != nil { return err } diff --git a/internal/memory/facts.go b/internal/memory/facts.go index 0af6323..d266379 100644 --- a/internal/memory/facts.go +++ b/internal/memory/facts.go @@ -27,6 +27,8 @@ import ( "path/filepath" "strings" "sync" + + "github.com/BackendStack21/odek/internal/fsatomic" ) // File names for fact targets. @@ -313,27 +315,7 @@ func (f *FactStore) Entries(target string) ([]string, error) { // mutual exclusion that f.mu provides per-instance. func (f *FactStore) writeEntries(target string, entries []string) error { content := strings.Join(entries, entrySep) - path := f.path(target) - - tmp, err := os.CreateTemp(f.dir, filepath.Base(path)+".tmp-*") - if err != nil { - return err - } - tmpName := tmp.Name() - if _, err := tmp.Write([]byte(content)); err != nil { - tmp.Close() - os.Remove(tmpName) - return err - } - if err := tmp.Close(); err != nil { - os.Remove(tmpName) - return err - } - if err := os.Rename(tmpName, path); err != nil { - os.Remove(tmpName) - return err - } - return nil + return fsatomic.WriteFile(f.path(target), []byte(content), 0600) } // parseEntries splits file content into individual entries. diff --git a/internal/session/session.go b/internal/session/session.go index 65b51d7..a5dcb48 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -27,6 +27,7 @@ import ( "time" "github.com/BackendStack21/odek/internal/embedding" + "github.com/BackendStack21/odek/internal/fsatomic" "github.com/BackendStack21/odek/internal/llm" "github.com/BackendStack21/odek/internal/redact" ) @@ -197,15 +198,8 @@ func (s *Store) saveIndexLocked(idx map[string]*IndexEntry) error { if err != nil { return fmt.Errorf("session: marshal index: %w", err) } - target := s.indexPath() - tmp := target + ".tmp" - if err := os.WriteFile(tmp, data, 0600); err != nil { - os.Remove(tmp) - return fmt.Errorf("session: write index tmp: %w", err) - } - if err := os.Rename(tmp, target); err != nil { - os.Remove(tmp) - return fmt.Errorf("session: rename index: %w", err) + if err := fsatomic.WriteFile(s.indexPath(), data, 0600); err != nil { + return fmt.Errorf("session: write index: %w", err) } return nil } @@ -265,9 +259,10 @@ func (s *Store) Append(id string, newMsgs []llm.Message) error { return s.saveLocked(sess) } -// Save writes a session to disk atomically using a temp-file + rename -// strategy. This prevents: +// Save writes a session to disk atomically and durably via fsatomic.WriteFile +// (temp-file → fsync → rename → dir fsync). This prevents: // - Partial writes from crashes (rename is atomic on POSIX) +// - Data loss on power failure (the fsync flushes bytes before the rename) // - Symlink-following TOCTOU attacks (os.Rename replaces the // directory entry itself — it does NOT follow symlinks) func (s *Store) Save(sess *Session) error { @@ -297,15 +292,8 @@ func (s *Store) saveLocked(sess *Session) error { return fmt.Errorf("session: marshal: %w", err) } - target := s.path(sess.ID) - tmp := target + ".tmp" - if err := os.WriteFile(tmp, data, 0600); err != nil { - os.Remove(tmp) - return fmt.Errorf("session: write tmp: %w", err) - } - if err := os.Rename(tmp, target); err != nil { - os.Remove(tmp) - return fmt.Errorf("session: rename: %w", err) + if err := fsatomic.WriteFile(s.path(sess.ID), data, 0600); err != nil { + return fmt.Errorf("session: write: %w", err) } // Update the index atomically.