From 809711511708f5e9345affaff44c4e3b442c27dd Mon Sep 17 00:00:00 2001 From: Elizabeth Worstell Date: Thu, 25 Jun 2026 17:20:36 -0700 Subject: [PATCH] feat(client): parallel git snapshot download APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add ParallelGet — a concurrent chunked range download that fetches an object in chunkSize-byte chunks (up to concurrency requests in flight) and hands each chunk to a ChunkSink, which owns where the bytes land: - StreamSink reassembles the in-order byte stream for a streaming consumer, buffering out-of-order chunks in a fixed arena of 2*concurrency reusable slots. A slow reader applies backpressure to the fetchers, so peak memory is bounded at O(concurrency*chunkSize) regardless of object size — letting a decompress/extract pipeline overlap the download over a non-seekable sink without staging to disk or RAM. - DiskSink scatters each chunk straight to its offset in an io.WriterAt (e.g. *os.File), concurrently and unordered — the faster path for seekable sinks such as cache-to-cache backfill. ETag-pinned chunks reject mid-download rewrites; over/under-length chunks are rejected; a missing ETag, a range-ignoring backend, an object that fits in the first chunk, or concurrency 1 all fall back to a single full read. Add OpenGitSnapshotParallel: streaming git-snapshot helper returning a GitSnapshot whose Commit/BundleURL are available immediately while bytes stream in the background via a StreamSink; closing it cancels the download. The cachew CLI now extracts directly from the snapshot body (no temp file). Drops the redundant DownloadGitSnapshot. Co-authored-by: Amp Amp-Thread-ID: https://ampcode.com/threads/T-019ef6a9-a407-7389-bc43-001405e3ae9e --- client/chunk_sink.go | 275 +++++++++++++++++++++++++++ client/git.go | 122 +++++++++--- client/git_test.go | 103 ++++++++-- client/parallel_get.go | 188 +++++++++--------- client/parallel_get_test.go | 285 ++++++++++++++++++++++++---- cmd/cachew/git.go | 83 ++------ internal/cache/parallel_get.go | 9 +- internal/cache/parallel_get_test.go | 49 ++--- 8 files changed, 836 insertions(+), 278 deletions(-) create mode 100644 client/chunk_sink.go diff --git a/client/chunk_sink.go b/client/chunk_sink.go new file mode 100644 index 00000000..0ba85080 --- /dev/null +++ b/client/chunk_sink.go @@ -0,0 +1,275 @@ +package client + +import ( + "context" + "io" + "sync" + + "github.com/alecthomas/errors" +) + +// ChunkSink is the destination ParallelGet places fetched chunks into. The +// engine calls Place once per chunk, concurrently from up to `concurrency` +// goroutines, with the chunk's absolute byte offset and an open body holding +// exactly length bytes (length < 0 means "read the whole body", used for the +// single-stream fallback when the object cannot be chunked). Place must read the +// chunk from body and close body. Implementations own where the bytes land and +// may block in Place to bound memory; a blocked Place must abort when ctx is +// cancelled. +// +// Two implementations cover the cases in this package: StreamSink reassembles +// the in-order byte stream for a streaming consumer, and DiskSink scatters +// chunks to their offsets in a file. +type ChunkSink interface { + Place(ctx context.Context, off, length int64, body io.ReadCloser) error +} + +// StreamSink is a ChunkSink that reorders concurrently-fetched chunks back into +// the original byte stream, exposed via Read. Chunks land in a fixed arena of +// 2*concurrency reusable slots indexed by chunk number, so a slow consumer +// applies backpressure to the fetchers (capping memory) instead of letting +// fetched-but-unread chunks pile up. The doubled slot count lets the fetchers +// run a full window ahead of the consumer rather than stalling on it. +// +// A StreamSink must be read concurrently while ParallelGet runs — the fetchers +// block once they get a window ahead of the reader, so a caller that does not +// read will deadlock. After the download finishes the caller signals completion +// with Done; Read then drains the remaining buffered chunks and returns io.EOF, +// or the download error. +type StreamSink struct { + chunkSize int64 + n int // slot count = 2*concurrency + + mu sync.Mutex + cond *sync.Cond // signals Read that a chunk was deposited (or Done) + advance chan struct{} // closed and replaced when readSeq advances, waking blocked Place + bufs [][]byte // n reusable backing buffers, indexed by seq%n (nil until first use) + ready []bool // ready[slot] => bufs[slot] holds the chunk for its current seq + readSeq int64 // sequence number of the chunk Read is emitting next + cur []byte // chunk currently being emitted (aliases bufs[readSeq%n]) + curPos int + + passthru io.ReadCloser // set in single-stream fallback mode (length < 0) + done bool + err error + closed bool +} + +// NewStreamSink returns a StreamSink sized for the given chunk size and download +// concurrency. It holds up to 2*concurrency chunk buffers, giving the fetchers a +// full window of run-ahead over the consumer while capping peak memory at +// 2*concurrency*chunkSize. Buffers are allocated lazily, so a small object never +// reserves the full window. +func NewStreamSink(chunkSize int64, concurrency int) *StreamSink { + n := 2 * max(concurrency, 1) + s := &StreamSink{ + chunkSize: chunkSize, + n: n, + bufs: make([][]byte, n), + ready: make([]bool, n), + advance: make(chan struct{}), + } + s.cond = sync.NewCond(&s.mu) + return s +} + +// Place reads the chunk into its slot and queues it for in-order delivery to +// Read. It blocks until the chunk is within one window of the read cursor +// (backpressure from a slow consumer) and aborts if ctx is cancelled. A negative +// length switches to pass-through mode: the whole body is handed to Read +// directly, since a single-stream fallback has unknown size and must not be +// buffered. +func (s *StreamSink) Place(ctx context.Context, off, length int64, body io.ReadCloser) error { + if length < 0 { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return errors.Join(errors.New("stream sink closed"), body.Close()) + } + s.passthru = body + s.cond.Broadcast() + s.mu.Unlock() + return nil + } + + seq := off / s.chunkSize + slot := int(seq % int64(s.n)) + + // Admission: a chunk may only occupy its slot once the previous occupant + // (seq-n) has been read, i.e. once seq is within n of the read cursor. This + // bounds run-ahead and guarantees no other in-flight chunk maps to this slot, + // so the in-order chunk's slot is always reserved for it. + s.mu.Lock() + for seq >= s.readSeq+int64(s.n) { + if s.closed { + s.mu.Unlock() + return errors.Join(errors.New("stream sink closed"), body.Close()) + } + ch := s.advance + s.mu.Unlock() + select { + case <-ch: + case <-ctx.Done(): + return errors.Join(errors.WithStack(ctx.Err()), body.Close()) + } + s.mu.Lock() + } + buf := s.bufs[slot] + s.mu.Unlock() + + if int64(cap(buf)) < length { + buf = make([]byte, s.chunkSize) + } + buf = buf[:length] + if err := readChunk(off, buf, body); err != nil { + return err + } + + s.mu.Lock() + // A Close racing the body read above leaves no reader to drain this slot; + // drop the chunk rather than mark it ready. readChunk already closed body. + if s.closed { + s.mu.Unlock() + return errors.New("stream sink closed") + } + s.bufs[slot] = buf + s.ready[slot] = true + s.cond.Broadcast() + s.mu.Unlock() + return nil +} + +// Read emits the reassembled object in order. It blocks until the next chunk is +// available, returning io.EOF once every chunk has been read and Done has been +// called, or the download error reported to Done. +func (s *StreamSink) Read(p []byte) (int, error) { + s.mu.Lock() + for { + if s.passthru != nil { + body := s.passthru + s.mu.Unlock() + return body.Read(p) //nolint:wrapcheck // must return io.EOF verbatim for io.ReadAll + } + if s.cur != nil { + n := copy(p, s.cur[s.curPos:]) + s.curPos += n + if s.curPos >= len(s.cur) { + // Chunk fully emitted: free its slot and advance, waking any Place + // blocked waiting for this slot's window to open. + slot := int(s.readSeq % int64(s.n)) + s.ready[slot] = false + s.readSeq++ + s.cur = nil + s.curPos = 0 + close(s.advance) + s.advance = make(chan struct{}) + } + s.mu.Unlock() + return n, nil + } + slot := int(s.readSeq % int64(s.n)) + if s.ready[slot] { + s.cur = s.bufs[slot] + s.curPos = 0 + continue + } + if s.err != nil { + err := s.err + s.mu.Unlock() + return 0, err + } + if s.done { + s.mu.Unlock() + return 0, io.EOF + } + // Closed mid-download with no terminal status: stop rather than block + // forever on cond, since the fetchers are being torn down. + if s.closed { + s.mu.Unlock() + return 0, errors.WithStack(io.ErrClosedPipe) + } + s.cond.Wait() + } +} + +// Done signals that no further chunks will be placed. err is the download +// outcome (nil on success); it is surfaced to Read after the buffered chunks +// drain. +func (s *StreamSink) Done(err error) { + s.mu.Lock() + s.done = true + if err != nil && s.err == nil { + s.err = err + } + s.cond.Broadcast() + s.mu.Unlock() +} + +// Close releases the sink, unblocking any in-flight Place and closing the +// pass-through body if one is set. The arena buffers are released to the garbage +// collector. Cancelling the download itself is the caller's responsibility (see +// OpenGitSnapshotParallel). +func (s *StreamSink) Close() error { + s.mu.Lock() + s.closed = true + body := s.passthru + s.passthru = nil + close(s.advance) + s.advance = make(chan struct{}) + s.cond.Broadcast() + s.mu.Unlock() + if body != nil { + return errors.WithStack(body.Close()) + } + return nil +} + +// DiskSink is a ChunkSink that writes each chunk straight to its offset in an +// io.WriterAt such as an *os.File. io.WriterAt permits concurrent +// non-overlapping writes, so chunks are scattered to disk as they arrive with no +// reordering and negligible memory — the right sink for seekable destinations +// such as cache-to-cache backfill. Unlike StreamSink it needs no concurrent +// reader, so ParallelGet may run to completion synchronously. On error the +// destination is left partially written and must be discarded by the caller. +type DiskSink struct{ W io.WriterAt } + +// Place streams the chunk straight to its offset in the underlying WriterAt. +func (d DiskSink) Place(_ context.Context, off, length int64, body io.ReadCloser) error { + dst := io.NewOffsetWriter(d.W, off) + if length < 0 { + _, err := io.Copy(dst, body) + return errors.Join(errors.Wrap(err, "write chunk"), body.Close()) + } + n, err := io.Copy(dst, io.LimitReader(body, length)) + if err != nil { + return errors.Join(errors.Errorf("write chunk at offset %d: %w", off, err), body.Close()) + } + if n != length { + return errors.Join(errors.Errorf("chunk at offset %d: wrote %d of %d bytes", off, n, length), body.Close()) + } + if overlong(body) { + return errors.Join(errors.Errorf("chunk at offset %d: read more than the expected %d bytes", off, length), body.Close()) + } + return errors.WithStack(body.Close()) +} + +// readChunk fills buf from body (reading exactly len(buf) bytes) and closes +// body. A body shorter than buf (a truncated chunk) or longer than buf (a +// backend that ignored the range) is reported as an error. +func readChunk(off int64, buf []byte, body io.ReadCloser) error { + if _, err := io.ReadFull(body, buf); err != nil { + return errors.Join(errors.Errorf("read chunk at offset %d: %w", off, err), body.Close()) + } + if overlong(body) { + return errors.Join(errors.Errorf("chunk at offset %d: read more than the expected %d bytes", off, len(buf)), body.Close()) + } + return errors.WithStack(body.Close()) +} + +// overlong reports whether r has any bytes left, used to detect a body longer +// than the requested chunk without buffering the excess. +func overlong(r io.Reader) bool { + var probe [1]byte + n, _ := io.ReadFull(r, probe[:]) //nolint:errcheck // any byte past the chunk is overlong, regardless of the error + return n > 0 +} diff --git a/client/git.go b/client/git.go index 999509d9..f1d3dcaa 100644 --- a/client/git.go +++ b/client/git.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "os" + "strconv" "strings" "sync" @@ -184,32 +185,76 @@ func (c *Client) openGitArtifact(ctx context.Context, repoURL, suffix string) (* }, nil } -// GitSnapshotMetadata carries the freshen metadata returned alongside a -// parallel snapshot download. Commit is the mirror's HEAD SHA at snapshot time -// (empty for cold serves); BundleURL, when non-empty, points at a delta bundle -// that brings the snapshot up to the mirror's current HEAD. -type GitSnapshotMetadata struct { - Commit string - BundleURL string -} - -// DownloadGitSnapshot fetches the working-tree snapshot for repoURL into dst, -// using up to concurrency concurrent range requests of chunkSize bytes each. -// When concurrency is 1, or the server does not support ranges, it transparently -// falls back to a single full download. dst is written at non-overlapping -// offsets via WriteAt (e.g. an *os.File) and the caller owns its lifecycle. It -// returns the snapshot's freshen metadata, read from the discovery response. -// Returns os.ErrNotExist when the server has no snapshot available. -func (c *Client) DownloadGitSnapshot(ctx context.Context, repoURL string, dst io.WriterAt, chunkSize int64, concurrency int) (GitSnapshotMetadata, error) { +// OpenGitSnapshotParallel downloads the working-tree snapshot for repoURL using +// up to concurrency concurrent range requests of chunkSize bytes each, and +// returns as soon as the discovery response's freshen metadata (Commit, +// BundleURL) is available, with the in-order bytes exposed on the returned +// GitSnapshot's Body. The download proceeds in the background while the caller +// reads Body, so a decompress/extract pipeline overlaps with the transfer. +// +// The caller must Close the returned GitSnapshot; doing so cancels any +// in-flight download. A concurrency of 1, or a server without range support, +// transparently falls back to a single full download. Returns os.ErrNotExist +// when the server has no snapshot available. +func (c *Client) OpenGitSnapshotParallel(ctx context.Context, repoURL string, chunkSize int64, concurrency int) (*GitSnapshot, error) { endpoint, err := gitEndpointURL(c.baseURL, repoURL, "snapshot.tar.zst") if err != nil { - return GitSnapshotMetadata{}, err + return nil, err } - reader := &gitArtifactRangeReader{client: c, endpoint: endpoint} - if err := ParallelGet(ctx, reader, NewKey(repoURL), dst, chunkSize, concurrency); err != nil { - return GitSnapshotMetadata{}, errors.Wrap(err, "download snapshot") + reader := &gitArtifactRangeReader{client: c, endpoint: endpoint, discovered: make(chan struct{})} + + ctx, cancel := context.WithCancel(ctx) + sink := NewStreamSink(chunkSize, concurrency) + done := make(chan error, 1) + go func() { + err := ParallelGet(ctx, reader, NewKey(repoURL), sink, chunkSize, concurrency) + sink.Done(err) + done <- errors.Wrap(err, "download snapshot") + }() + + // Block only until the discovery response lands (metadata available) or the + // download terminates first (e.g. os.ErrNotExist before any headers). The + // remaining bytes stream through Body in the background. + // + // A small object's download can finish (and signal done) before we observe + // discovered, leaving both ready when select runs. select would then pick a + // branch at random and the done branch would mistake a completed download for + // a missing snapshot. So whenever the download finished after recording + // discovery, fall through to the streaming path and let Body drain the + // buffered bytes (or surface the download error); only treat done as + // authoritative when no discovery ever happened. + select { + case <-reader.discovered: + case err := <-done: + if !reader.didDiscover() { + cancel() + _ = sink.Close() //nolint:errcheck + if err == nil { + return nil, errors.WithStack(os.ErrNotExist) + } + return nil, err + } } - return reader.metadata(), nil + headers := reader.discoveryHeaders() + return &GitSnapshot{ + Body: &cancelReadCloser{ReadCloser: sink, cancel: cancel}, + Headers: headers, + Commit: headers.Get(SnapshotCommitHeader), + BundleURL: headers.Get(BundleURLHeader), + }, nil +} + +// cancelReadCloser cancels the supplied context when Closed, in addition to +// closing the wrapped reader, so closing a streaming download stops its +// background goroutine promptly. +type cancelReadCloser struct { + io.ReadCloser + cancel context.CancelFunc +} + +func (c *cancelReadCloser) Close() error { + c.cancel() + return errors.WithStack(c.ReadCloser.Close()) } // gitArtifactRangeReader adapts a git artifact endpoint to the RangeReader @@ -221,6 +266,11 @@ type gitArtifactRangeReader struct { client *Client endpoint string + // discovered, when non-nil, is closed once the first response's headers are + // recorded, letting a streaming caller surface the freshen metadata before + // the download completes. + discovered chan struct{} + mu sync.Mutex discovery http.Header } @@ -253,6 +303,15 @@ func (g *gitArtifactRangeReader) Open(ctx context.Context, _ Key, opts ...Reques } } +// didDiscover reports whether the first response's headers have been recorded. +// It lets OpenGitSnapshotParallel distinguish a download that completed after +// discovery (stream the buffered bytes) from one that never found the object. +func (g *gitArtifactRangeReader) didDiscover() bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.discovery != nil +} + // recordDiscovery stores the first response's headers so the freshen metadata // they carry survives after the bodies are consumed. func (g *gitArtifactRangeReader) recordDiscovery(h http.Header) { @@ -260,16 +319,27 @@ func (g *gitArtifactRangeReader) recordDiscovery(h http.Header) { defer g.mu.Unlock() if g.discovery == nil { g.discovery = h.Clone() + if g.discovered != nil { + close(g.discovered) + } } } -func (g *gitArtifactRangeReader) metadata() GitSnapshotMetadata { +// discoveryHeaders returns a copy of the first response's headers describing the +// full snapshot, with transport-layer headers stripped, matching +// OpenGitSnapshot's Headers field. The discovery response is a 206 for only the +// first chunk, so its Content-Range/Content-Length describe that chunk, not the +// reassembled body streamed on GitSnapshot.Body; rewrite them to the full object +// size so callers don't mistake the body for partial content. +func (g *gitArtifactRangeReader) discoveryHeaders() http.Header { g.mu.Lock() defer g.mu.Unlock() - return GitSnapshotMetadata{ - Commit: g.discovery.Get(SnapshotCommitHeader), - BundleURL: g.discovery.Get(BundleURLHeader), + headers := filterHeaders(g.discovery, transportHeaders...) + if total, ok := parseContentRangeTotal(headers.Get("Content-Range")); ok { + headers.Del("Content-Range") + headers.Set("Content-Length", strconv.FormatInt(total, 10)) } + return headers } // gitEndpointURL builds a /git/{host}/{repoPath}/{suffix} URL from a cachew diff --git a/client/git_test.go b/client/git_test.go index 0e2ac435..8bb4a878 100644 --- a/client/git_test.go +++ b/client/git_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "os" + "strconv" "strings" "sync/atomic" "testing" @@ -175,7 +176,7 @@ func TestOpenGitBundleNotFound(t *testing.T) { assert.True(t, errors.Is(err, os.ErrNotExist)) } -func TestDownloadGitSnapshotParallel(t *testing.T) { +func TestOpenGitSnapshotParallel(t *testing.T) { body := make([]byte, 1000) for i := range body { body[i] = byte(i % 251) @@ -190,29 +191,63 @@ func TestDownloadGitSnapshotParallel(t *testing.T) { w.Header().Set("ETag", etag) w.Header().Set(client.SnapshotCommitHeader, "deadbeef") w.Header().Set(client.BundleURLHeader, "/git/github.com/org/repo/snapshot.bundle?base=deadbeef") - // ServeContent honours Range/If-Range against the ETag set above, so it - // returns 206 + Content-Range for the chunked requests ParallelGet makes. http.ServeContent(w, r, "snapshot.tar.zst", time.Time{}, bytes.NewReader(body)) })) defer srv.Close() api := client.NewWithHTTPClient(srv.URL, srv.Client()) - var dst bufferAt - // A 128-byte chunk over a 1000-byte body forces multiple chunks, exercising - // concurrent range reassembly. - meta, err := api.DownloadGitSnapshot(context.Background(), "https://github.com/org/repo", &dst, 128, 4) + snap, err := api.OpenGitSnapshotParallel(context.Background(), "https://github.com/org/repo", 128, 4) assert.NoError(t, err) - assert.Equal(t, body, dst.buf) - assert.Equal(t, "deadbeef", meta.Commit) - assert.Equal(t, "/git/github.com/org/repo/snapshot.bundle?base=deadbeef", meta.BundleURL) + defer snap.Close() + + // Metadata is available before the body is fully read. + assert.Equal(t, "deadbeef", snap.Commit) + assert.Equal(t, "/git/github.com/org/repo/snapshot.bundle?base=deadbeef", snap.BundleURL) + + // Headers must describe the full reassembled body, not the discovery chunk: + // the 206's Content-Range is dropped and Content-Length reflects the total. + assert.Equal(t, "", snap.Headers.Get("Content-Range")) + assert.Equal(t, strconv.Itoa(len(body)), snap.Headers.Get("Content-Length")) + + got, err := io.ReadAll(snap.Body) + assert.NoError(t, err) + assert.Equal(t, body, got) assert.True(t, requests.Load() > 1, "expected multiple range requests, got %d", requests.Load()) } -func TestDownloadGitSnapshotFallsBackWithoutRange(t *testing.T) { +// A small snapshot can finish downloading before OpenGitSnapshotParallel +// observes the discovery signal, leaving the discovered and done channels both +// ready. The select must still return the snapshot rather than mistaking a +// completed download for a missing one. Looping exercises the random select +// choice between the two ready channels. +func TestOpenGitSnapshotParallelSmallObject(t *testing.T) { + body := []byte("a tiny snapshot that fits in the first chunk") + const etag = `"snap-small"` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/zstd") + w.Header().Set("ETag", etag) + w.Header().Set(client.SnapshotCommitHeader, "beef") + http.ServeContent(w, r, "snapshot.tar.zst", time.Time{}, bytes.NewReader(body)) + })) + defer srv.Close() + + api := client.NewWithHTTPClient(srv.URL, srv.Client()) + for range 200 { + // chunkSize > len(body) so the whole object arrives on the discovery chunk + // and the download completes immediately. + snap, err := api.OpenGitSnapshotParallel(context.Background(), "https://github.com/org/repo", 4096, 4) + assert.NoError(t, err) + assert.Equal(t, "beef", snap.Commit) + got, err := io.ReadAll(snap.Body) + assert.NoError(t, err) + assert.Equal(t, body, got) + assert.NoError(t, snap.Close()) + } +} + +func TestOpenGitSnapshotParallelFallsBackWithoutRange(t *testing.T) { body := []byte("full body, server ignores ranges") srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - // No ETag and no range handling: always answer the full object with 200, - // mimicking an older server. ParallelGet must fall back to a single read. w.Header().Set("Content-Type", "application/zstd") w.Header().Set(client.SnapshotCommitHeader, "cafe") _, _ = w.Write(body) //nolint:errcheck @@ -220,21 +255,49 @@ func TestDownloadGitSnapshotFallsBackWithoutRange(t *testing.T) { defer srv.Close() api := client.NewWithHTTPClient(srv.URL, srv.Client()) - var dst bufferAt - meta, err := api.DownloadGitSnapshot(context.Background(), "https://github.com/org/repo", &dst, 8, 4) + snap, err := api.OpenGitSnapshotParallel(context.Background(), "https://github.com/org/repo", 8, 4) assert.NoError(t, err) - assert.Equal(t, body, dst.buf) - assert.Equal(t, "cafe", meta.Commit) + defer snap.Close() + + assert.Equal(t, "cafe", snap.Commit) + got, err := io.ReadAll(snap.Body) + assert.NoError(t, err) + assert.Equal(t, body, got) } -func TestDownloadGitSnapshotNotFound(t *testing.T) { +func TestOpenGitSnapshotParallelNotFound(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) })) defer srv.Close() api := client.NewWithHTTPClient(srv.URL, srv.Client()) - var dst bufferAt - _, err := api.DownloadGitSnapshot(context.Background(), "https://github.com/org/repo", &dst, 8, 4) + _, err := api.OpenGitSnapshotParallel(context.Background(), "https://github.com/org/repo", 8, 4) assert.True(t, errors.Is(err, os.ErrNotExist)) } + +func TestOpenGitSnapshotParallelCloseStopsDownload(t *testing.T) { + body := make([]byte, 1<<20) + for i := range body { + body[i] = byte(i % 251) + } + const etag = `"snap-v1"` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/zstd") + w.Header().Set("ETag", etag) + http.ServeContent(w, r, "snapshot.tar.zst", time.Time{}, bytes.NewReader(body)) + })) + defer srv.Close() + + api := client.NewWithHTTPClient(srv.URL, srv.Client()) + snap, err := api.OpenGitSnapshotParallel(context.Background(), "https://github.com/org/repo", 4096, 4) + assert.NoError(t, err) + + // Read a little, then close before draining: Close must return without + // hanging on the background download. + buf := make([]byte, 16) + _, err = io.ReadFull(snap.Body, buf) + assert.NoError(t, err) + assert.NoError(t, snap.Close()) +} diff --git a/client/parallel_get.go b/client/parallel_get.go index be7c9aa5..55a51b3c 100644 --- a/client/parallel_get.go +++ b/client/parallel_get.go @@ -6,6 +6,7 @@ import ( "net/http" "strconv" "strings" + "sync/atomic" "github.com/alecthomas/errors" "golang.org/x/sync/errgroup" @@ -19,44 +20,41 @@ type RangeReader interface { Open(ctx context.Context, key Key, opts ...RequestOption) (io.ReadCloser, http.Header, error) } -// ParallelGet downloads an object from any Range-capable RangeReader into dst, -// fetching it in chunkSize-byte chunks concurrently (up to concurrency requests -// in flight) and writing each chunk at its offset via dst.WriteAt. Latency-bound -// backends such as a remote cache can saturate bandwidth with overlapping reads. +// ParallelGet downloads an object from a Range-capable RangeReader, fetching it +// in chunkSize-byte chunks concurrently (up to concurrency requests in flight) +// and handing each chunk to sink, which decides where the bytes land. Latency- +// bound backends such as a remote cache can saturate bandwidth with overlapping +// reads. // -// The first chunk is fetched with a ranged Open, whose response yields both the -// total size (from Content-Range) and the object's ETag; every remaining chunk -// is then requested with IfRange pinned to that ETag. If the object changes -// mid-download, a chunk's ETag will differ and ParallelGet returns an error -// rather than splicing bytes from two revisions. A missing or truncated chunk -// is likewise reported as an error, so a partially written dst must be discarded -// by the caller on failure. An object with no ETag to pin to (e.g. one stored -// before ETags were recorded) cannot be kept revision-safe across chunks, so it -// falls back to a single full read instead of parallelising. A concurrency of -// 1 likewise reads the whole object in one request, since chunking a single -// worker would only serialise ranged GETs for no benefit. +// Pass a StreamSink to reassemble the in-order byte stream for a streaming +// consumer (run ParallelGet in a goroutine and read the sink concurrently), or a +// DiskSink to scatter chunks to their offsets in a file (ParallelGet may then run +// synchronously). The engine itself is agnostic to the destination. // -// dst is written via concurrent WriteAt calls at non-overlapping offsets; the -// caller owns dst's lifecycle (open, close, cleanup) and need not pre-size it, -// as WriteAt extends the destination. -func ParallelGet(ctx context.Context, c RangeReader, key Key, dst io.WriterAt, chunkSize int64, concurrency int) error { - if chunkSize <= 0 { - return errors.Errorf("parallel get: chunk size must be positive, got %d", chunkSize) - } - concurrency = max(concurrency, 1) - +// The first chunk is fetched with a ranged Open, whose response yields the total +// size (from Content-Range) and the object's ETag; every later chunk is then +// requested with IfRange pinned to that ETag. If the object changes mid-download +// a chunk's ETag differs and ParallelGet returns an error rather than splicing +// bytes from two revisions. An object with no ETag to pin to, a backend that +// ignores ranges, an object that fits within the first chunk, or a concurrency +// of 1 all fall back to a single full read handed to the sink as one stream +// (offset 0, length < 0). +func ParallelGet(ctx context.Context, c RangeReader, key Key, sink ChunkSink, chunkSize int64, concurrency int) error { // A single worker gains nothing from chunking — it would only serialise // ranged GETs — so skip discovery entirely and read the object in one - // revision-consistent request. - if concurrency == 1 { - return fullRead(ctx, c, key, dst) + // revision-consistent request. chunkSize is unused on this path. + if max(concurrency, 1) == 1 { + return fullRead(ctx, c, key, sink) + } + if chunkSize <= 0 { + return errors.Errorf("parallel get: chunk size must be positive, got %d", chunkSize) } // Discovery: the first ranged Open delivers chunk zero and reveals the total // size and ETag used to pin the rest. rc, headers, err := c.Open(ctx, key, Range(0, chunkSize)) if errors.Is(err, ErrRangeNotSatisfiable) { - return nil // Empty object: nothing to write. + return nil // Empty object: nothing to place. } if err != nil { return errors.Wrap(err, "parallel get: open first chunk") @@ -65,90 +63,96 @@ func ParallelGet(ctx context.Context, c RangeReader, key Key, dst io.WriterAt, c etag := headers.Get(ETagKey) total, hasRange := parseContentRangeTotal(headers.Get("Content-Range")) - // A backend that ignored the range (no Content-Range), or an object that - // fits within the first chunk, is delivered entirely by this response: copy - // it and return, as there is nothing to parallelise. A negative want skips - // the length check when the total size is unknown. - firstLen := min(chunkSize, total) + // A backend that ignored the range, or an object that fits in the first + // chunk, is delivered whole by this response: hand it to the sink as a single + // stream. A negative length tells the sink the size is unknown and the body + // must be streamed, not buffered. if !hasRange { - firstLen = -1 + return errors.Wrap(sink.Place(ctx, 0, -1, rc), "parallel get") } - if !hasRange || total <= chunkSize { - return errors.Wrap(writeChunkAt(dst, 0, firstLen, rc), "parallel get") + if total <= chunkSize { + return errors.Wrap(sink.Place(ctx, 0, total, rc), "parallel get") } - // Subsequent chunks are pinned to the discovery ETag via IfRange. Without a - // validator there is nothing to pin to (IfRange("") is a no-op and an empty - // ETag matches an empty ETag), so chunks could be spliced across a rewrite - // undetected. Objects stored before ETags were recorded fall here, so fall - // back to a single, revision-consistent read rather than parallelising. + // Subsequent chunks pin the discovery ETag via IfRange. Without a validator + // there is nothing to pin to (IfRange("") is a no-op), so chunks could be + // spliced across a rewrite undetected. Objects stored before ETags were + // recorded fall here, so fall back to a single, revision-consistent read. if etag == "" { if err := rc.Close(); err != nil { return errors.Wrap(err, "parallel get: close discovery reader") } - return fullRead(ctx, c, key, dst) + return fullRead(ctx, c, key, sink) } - // Multiple chunks: copy the already-open first chunk concurrently with the - // rest rather than blocking on it here. The first goroutine is scheduled - // before the limit can be reached, so it never stalls holding an open body. - numChunks := int((total + chunkSize - 1) / chunkSize) + numChunks := (total + chunkSize - 1) / chunkSize eg, egCtx := errgroup.WithContext(ctx) - eg.SetLimit(concurrency) - eg.Go(func() error { return writeChunkAt(dst, 0, firstLen, rc) }) - for seq := 1; seq < numChunks; seq++ { - // Stop scheduling once a chunk has failed and cancelled the group. - if egCtx.Err() != nil { - break + + // Workers pull sequence numbers and fetch the remaining chunks. The sink + // bounds how far ahead they run — a StreamSink blocks Place once it is a + // window ahead of the reader — so peak memory stays bounded regardless of + // object size. + var nextSeq atomic.Int64 + nextSeq.Store(1) + worker := func() error { + for { + if egCtx.Err() != nil { + return errors.WithStack(egCtx.Err()) + } + seq := nextSeq.Add(1) - 1 + if seq >= numChunks { + return nil + } + start := seq * chunkSize + end := min(start+chunkSize, total) + body, h, err := c.Open(egCtx, key, Range(start, end), IfRange(etag)) + if err != nil { + return errors.Errorf("parallel get: open range %d-%d: %w", start, end, err) + } + if got := h.Get(ETagKey); got != etag { + return errors.Join( + errors.Errorf("parallel get: object changed during read at offset %d: etag %q != %q", start, got, etag), + body.Close(), + ) + } + if err := sink.Place(egCtx, start, end-start, body); err != nil { + return errors.WithStack(err) + } + } + } + + // One worker places chunk zero from the already-open discovery body, then + // joins the pool, so the number of concurrent range requests stays at + // concurrency rather than concurrency+1. The discovery body was opened with + // the parent ctx and so would not unblock if a sibling chunk fails and + // cancels egCtx, so close it on egCtx cancellation; the Place path closes rc + // on success, and a double Close is harmless. + eg.Go(func() error { + stop := context.AfterFunc(egCtx, func() { _ = rc.Close() }) //nolint:errcheck + err := sink.Place(egCtx, 0, min(chunkSize, total), rc) + stop() + if err != nil { + return errors.WithStack(err) } - start := int64(seq) * chunkSize - end := min(start+chunkSize, total) - eg.Go(func() error { return fetchChunk(egCtx, c, key, dst, start, end, etag) }) + return worker() + }) + for range concurrency - 1 { + eg.Go(worker) } return errors.Wrap(eg.Wait(), "parallel get") } -// fullRead downloads the entire object in a single request and writes it at -// offset zero. It is used when chunking would add no value (a single worker) or -// cannot be made revision-safe (no ETag to pin). The body is a single -// consistent revision, but its length is unknown up front, so writeChunkAt's -// length check is skipped (-1). -func fullRead(ctx context.Context, c RangeReader, key Key, dst io.WriterAt) error { +// fullRead downloads the entire object in a single request and hands it to the +// sink as one stream. It is used when chunking adds no value (a single worker) or +// cannot be made revision-safe (no ETag, or a backend that ignores ranges). The +// body is a single consistent revision whose length is unknown up front, so it is +// placed with a negative length. +func fullRead(ctx context.Context, c RangeReader, key Key, sink ChunkSink) error { rc, _, err := c.Open(ctx, key) if err != nil { return errors.Wrap(err, "parallel get: full read") } - return errors.Wrap(writeChunkAt(dst, 0, -1, rc), "parallel get") -} - -// fetchChunk opens the [start, end) range pinned to etag and writes it at start. -// An ETag change (the object was rewritten mid-download) or a short read is -// reported as an error. -func fetchChunk(ctx context.Context, c RangeReader, key Key, dst io.WriterAt, start, end int64, etag string) error { - rc, headers, err := c.Open(ctx, key, Range(start, end), IfRange(etag)) - if err != nil { - return errors.Errorf("open range %d-%d: %w", start, end, err) - } - if got := headers.Get(ETagKey); got != etag { - return errors.Join( - errors.Errorf("object changed during read at offset %d: etag %q != %q", start, got, etag), - rc.Close(), - ) - } - return writeChunkAt(dst, start, end-start, rc) -} - -// writeChunkAt streams src into dst at off and closes src. It fails if fewer -// than want bytes arrive; a negative want skips that check (total size unknown). -func writeChunkAt(dst io.WriterAt, off, want int64, src io.ReadCloser) error { - n, copyErr := io.Copy(io.NewOffsetWriter(dst, off), src) - if err := errors.Join(copyErr, src.Close()); err != nil { - return errors.Errorf("write chunk at offset %d: %w", off, err) - } - if want >= 0 && n != want { - return errors.Errorf("short chunk at offset %d: wrote %d of %d bytes", off, n, want) - } - return nil + return errors.Wrap(sink.Place(ctx, 0, -1, rc), "parallel get") } // parseContentRangeTotal extracts the total size from a Content-Range value of diff --git a/client/parallel_get_test.go b/client/parallel_get_test.go index b9d920f5..f92ae986 100644 --- a/client/parallel_get_test.go +++ b/client/parallel_get_test.go @@ -8,9 +8,12 @@ import ( "net/http" "strconv" "sync" + "sync/atomic" "testing" + "time" "github.com/alecthomas/assert/v2" + "github.com/alecthomas/errors" "github.com/block/cachew/client" ) @@ -19,21 +22,35 @@ import ( // interface ParallelGet drives. var _ client.RangeReader = (*client.Client)(nil) -// bufferAt is an in-memory io.WriterAt that extends like a file, zero-filling -// any gap, so tests can assert reassembly without touching disk. -type bufferAt struct { - mu sync.Mutex - buf []byte +func patternBytes(n int) []byte { + data := make([]byte, n) + for i := range data { + data[i] = byte(i % 251) + } + return data } -func (b *bufferAt) WriteAt(p []byte, off int64) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - if end := int(off) + len(p); end > len(b.buf) { - b.buf = append(b.buf, make([]byte, end-len(b.buf))...) +// collect runs ParallelGet into a StreamSink and returns the reassembled bytes, +// reading the sink concurrently as the engine requires. The download error (if +// any) takes precedence over the read error. +func collect(c client.RangeReader, key client.Key, chunkSize int64, concurrency int) ([]byte, error) { + sink := client.NewStreamSink(chunkSize, concurrency) + type result struct { + data []byte + err error } - copy(b.buf[off:], p) - return len(p), nil + rc := make(chan result, 1) + go func() { + data, err := io.ReadAll(sink) + rc <- result{data: data, err: err} + }() + err := client.ParallelGet(context.Background(), c, key, sink, chunkSize, concurrency) + sink.Done(err) + res := <-rc + if err != nil { + return res.data, err + } + return res.data, res.err } // rangeFlipReader serves correct byte ranges but reports a different ETag for @@ -67,8 +84,7 @@ func (f *rangeFlipReader) Open(_ context.Context, _ client.Key, opts ...client.R func TestParallelGetETagMismatch(t *testing.T) { c := &rangeFlipReader{data: make([]byte, 1000), firstETag: `"v1"`, restETag: `"v2"`} - var dst bufferAt - err := client.ParallelGet(context.Background(), c, client.NewKey("k"), &dst, 100, 4) + _, err := collect(c, client.NewKey("k"), 100, 4) assert.Error(t, err) assert.Contains(t, err.Error(), "object changed during read") } @@ -97,15 +113,11 @@ func (n *noETagReader) Open(_ context.Context, _ client.Key, opts ...client.Requ func TestParallelGetNoETagMultiChunk(t *testing.T) { // A multi-chunk object with no ETag can't be pinned, so it falls back to a // single full read (backwards compatible with objects stored before ETags). - data := make([]byte, 1000) - for i := range data { - data[i] = byte(i % 251) - } + data := patternBytes(1000) c := &noETagReader{data: data} - var dst bufferAt - err := client.ParallelGet(context.Background(), c, client.NewKey("k"), &dst, 100, 4) + got, err := collect(c, client.NewKey("k"), 100, 4) assert.NoError(t, err) - assert.Equal(t, data, dst.buf) + assert.Equal(t, data, got) } func TestParallelGetNoETagSingleChunk(t *testing.T) { @@ -113,10 +125,9 @@ func TestParallelGetNoETagSingleChunk(t *testing.T) { // revision, so it succeeds without pinning. data := []byte("0123456789") c := &noETagReader{data: data} - var dst bufferAt - err := client.ParallelGet(context.Background(), c, client.NewKey("k"), &dst, 100, 4) + got, err := collect(c, client.NewKey("k"), 100, 4) assert.NoError(t, err) - assert.Equal(t, data, dst.buf) + assert.Equal(t, data, got) } // changingSizeReader serves a multi-chunk body with no ETag on the ranged @@ -147,6 +158,17 @@ func (c *changingSizeReader) Open(_ context.Context, _ client.Key, opts ...clien return io.NopCloser(bytes.NewReader(c.discovery[start : start+length])), headers, nil } +func TestParallelGetNoETagSizeChangedBetweenRequests(t *testing.T) { + // A no-ETag multi-chunk object falls back to a single full read. If it is + // rewritten to a different size between discovery and that read, the + // discovery total must not be used to validate the full body: the full read + // is itself a consistent revision and should be accepted in its entirety. + c := &changingSizeReader{discovery: make([]byte, 1000), rewritten: []byte("changed")} + got, err := collect(c, client.NewKey("k"), 100, 4) + assert.NoError(t, err) + assert.Equal(t, c.rewritten, got) +} + // recordingReader serves byte ranges and records the Range option of every // Open call ("" for a full, non-ranged read), so tests can assert how the // object was fetched. @@ -181,29 +203,212 @@ func (r *recordingReader) Open(_ context.Context, _ client.Key, opts ...client.R return io.NopCloser(bytes.NewReader(r.data[start : start+length])), headers, nil } +func TestParallelGetReassembly(t *testing.T) { + // A multi-chunk object must be emitted to the writer as the original, + // in-order byte stream despite being fetched concurrently. + data := patternBytes(10_000) + c := &recordingReader{data: data, etag: `"v1"`} + got, err := collect(c, client.NewKey("k"), 1000, 4) + assert.NoError(t, err) + assert.Equal(t, data, got) +} + func TestParallelGetSingleWorkerFullRead(t *testing.T) { // A concurrency of 1 gains nothing from chunking, so it must issue a single // non-ranged read rather than discovering and serialising ranged GETs. - data := make([]byte, 1000) - for i := range data { - data[i] = byte(i % 251) - } + data := patternBytes(1000) c := &recordingReader{data: data, etag: `"v1"`} - var dst bufferAt - err := client.ParallelGet(context.Background(), c, client.NewKey("k"), &dst, 100, 1) + got, err := collect(c, client.NewKey("k"), 100, 1) assert.NoError(t, err) - assert.Equal(t, data, dst.buf) + assert.Equal(t, data, got) assert.Equal(t, []string{""}, c.opens) } -func TestParallelGetNoETagSizeChangedBetweenRequests(t *testing.T) { - // A no-ETag multi-chunk object falls back to a single full read. If it is - // rewritten to a different size between discovery and that read, the - // discovery total must not be used to validate the full body: the full read - // is itself a consistent revision and should be accepted in its entirety. - c := &changingSizeReader{discovery: make([]byte, 1000), rewritten: []byte("changed")} - var dst bufferAt - err := client.ParallelGet(context.Background(), c, client.NewKey("k"), &dst, 100, 4) +func TestParallelGetEmptyObject(t *testing.T) { + c := &recordingReader{data: nil, etag: `"v1"`} + got, err := collect(c, client.NewKey("k"), 100, 4) + assert.NoError(t, err) + assert.Equal(t, 0, len(got)) +} + +func TestParallelGetServerIgnoresRange(t *testing.T) { + // A backend that ignores the range header delivers the whole object on the + // discovery request; it must be streamed in full. + data := patternBytes(1000) + c := &ignoreRangeReader{data: data} + got, err := collect(c, client.NewKey("k"), 100, 4) + assert.NoError(t, err) + assert.Equal(t, data, got) +} + +func TestParallelGetOutOfOrderCompletion(t *testing.T) { + // Chunks deliberately complete in reverse order within the in-flight window; + // the writer must still emit a correctly ordered stream. + data := patternBytes(10_000) + c := &reorderReader{data: data, etag: `"v1"`, chunkSize: 1000} + got, err := collect(c, client.NewKey("k"), 1000, 4) + assert.NoError(t, err) + assert.Equal(t, data, got) +} + +func TestParallelGetPropagatesOpenError(t *testing.T) { + // An error opening a non-first chunk must surface and cancel the download. + c := &failingChunkReader{data: patternBytes(10_000), etag: `"v1"`, failAtStart: 5000} + _, err := collect(c, client.NewKey("k"), 1000, 4) + assert.Error(t, err) + assert.Contains(t, err.Error(), "boom") +} + +func TestParallelGetRejectsOverlongChunk(t *testing.T) { + // A backend that honours the discovery range but ignores a later chunk's + // range — returning the whole object with the same ETag — must be detected + // rather than emitting truncated, mis-aligned bytes. + c := &fullBodyOnChunkReader{data: patternBytes(10_000), etag: `"v1"`} + _, err := collect(c, client.NewKey("k"), 1000, 4) + assert.Error(t, err) + assert.Contains(t, err.Error(), "more than the expected 1000 bytes") +} + +func TestParallelGetWriterAtReassembly(t *testing.T) { + // A DiskSink scatters chunks to their offsets via concurrent WriteAt; a + // multi-chunk object must still reassemble correctly. + data := patternBytes(10_000) + c := &recordingReader{data: data, etag: `"v1"`} + dst := &bufferAt{} + err := client.ParallelGet(context.Background(), c, client.NewKey("k"), client.DiskSink{W: dst}, 1000, 4) assert.NoError(t, err) - assert.Equal(t, c.rewritten, dst.buf) + assert.Equal(t, data, dst.buf) +} + +func TestParallelGetWriterAtOutOfOrder(t *testing.T) { + // Chunks complete out of order; DiskSink places each at its offset, so the + // result is correct with no reordering needed. + data := patternBytes(10_000) + c := &reorderReader{data: data, etag: `"v1"`, chunkSize: 1000} + dst := &bufferAt{} + err := client.ParallelGet(context.Background(), c, client.NewKey("k"), client.DiskSink{W: dst}, 1000, 4) + assert.NoError(t, err) + assert.Equal(t, data, dst.buf) +} + +func TestParallelGetWriterAtRejectsOverlongChunk(t *testing.T) { + // The overlong-chunk guard must hold on the DiskSink path too. + c := &fullBodyOnChunkReader{data: patternBytes(10_000), etag: `"v1"`} + dst := &bufferAt{} + err := client.ParallelGet(context.Background(), c, client.NewKey("k"), client.DiskSink{W: dst}, 1000, 4) + assert.Error(t, err) + assert.Contains(t, err.Error(), "more than the expected 1000 bytes") +} + +// bufferAt is an in-memory io.WriterAt that extends like a file, zero-filling +// gaps, so tests can exercise DiskSink's concurrent WriteAt path without +// touching disk. +type bufferAt struct { + mu sync.Mutex + buf []byte +} + +func (b *bufferAt) WriteAt(p []byte, off int64) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + if end := int(off) + len(p); end > len(b.buf) { + b.buf = append(b.buf, make([]byte, end-len(b.buf))...) + } + copy(b.buf[off:], p) + return len(p), nil +} + +// fullBodyOnChunkReader honours the discovery range (start 0) with a proper 206 +// but ignores the range on any later chunk, returning the entire object with the +// same ETag — modelling a backend that degrades to full responses mid-download. +type fullBodyOnChunkReader struct { + data []byte + etag string +} + +func (r *fullBodyOnChunkReader) Open(_ context.Context, _ client.Key, opts ...client.RequestOption) (io.ReadCloser, http.Header, error) { + size := int64(len(r.data)) + start, length, outcome := client.NewRequestOptions(opts...).ResolveRange(size, r.etag) + headers := http.Header{} + headers.Set(client.ETagKey, r.etag) + if outcome == client.RangePartial && start == 0 { + headers.Set("Content-Length", strconv.FormatInt(length, 10)) + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size)) + return io.NopCloser(bytes.NewReader(r.data[start : start+length])), headers, nil + } + headers.Set("Content-Length", strconv.FormatInt(size, 10)) + return io.NopCloser(bytes.NewReader(r.data)), headers, nil +} + +// ignoreRangeReader returns the whole object with no Content-Range regardless of +// the requested range, modelling a backend that doesn't honour ranges. +type ignoreRangeReader struct{ data []byte } + +func (r *ignoreRangeReader) Open(_ context.Context, _ client.Key, _ ...client.RequestOption) (io.ReadCloser, http.Header, error) { + headers := http.Header{} + headers.Set("Content-Length", strconv.Itoa(len(r.data))) + return io.NopCloser(bytes.NewReader(r.data)), headers, nil +} + +// reorderReader serves correct byte ranges but delays earlier offsets longer +// than later ones, so within the in-flight window chunks complete out of order +// and the writer must buffer and reorder them. +type reorderReader struct { + data []byte + etag string + chunkSize int64 +} + +func (r *reorderReader) Open(_ context.Context, _ client.Key, opts ...client.RequestOption) (io.ReadCloser, http.Header, error) { + size := int64(len(r.data)) + o := client.NewRequestOptions(opts...) + start, length, outcome := o.ResolveRange(size, r.etag) + headers := http.Header{} + if outcome == client.RangeNotSatisfiable { + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + return nil, headers, client.ErrRangeNotSatisfiable + } + // Earlier chunks within a window sleep longer, so higher offsets finish + // first and the writer is forced to reorder. + if outcome == client.RangePartial { + chunks := (size - start) / r.chunkSize + time.Sleep(time.Duration(chunks) * time.Millisecond) + } + headers.Set(client.ETagKey, r.etag) + headers.Set("Content-Length", strconv.FormatInt(length, 10)) + if outcome == client.RangePartial { + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size)) + } + return io.NopCloser(bytes.NewReader(r.data[start : start+length])), headers, nil +} + +// failingChunkReader serves ranges normally but errors when the requested range +// starts at failAtStart, modelling a mid-download fetch failure. +type failingChunkReader struct { + data []byte + etag string + failAtStart int64 + + opens atomic.Int64 +} + +func (r *failingChunkReader) Open(_ context.Context, _ client.Key, opts ...client.RequestOption) (io.ReadCloser, http.Header, error) { + r.opens.Add(1) + size := int64(len(r.data)) + o := client.NewRequestOptions(opts...) + start, length, outcome := o.ResolveRange(size, r.etag) + if outcome == client.RangePartial && start == r.failAtStart { + return nil, nil, errors.New("boom") + } + headers := http.Header{} + if outcome == client.RangeNotSatisfiable { + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + return nil, headers, client.ErrRangeNotSatisfiable + } + headers.Set(client.ETagKey, r.etag) + headers.Set("Content-Length", strconv.FormatInt(length, 10)) + if outcome == client.RangePartial { + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size)) + } + return io.NopCloser(bytes.NewReader(r.data[start : start+length])), headers, nil } diff --git a/cmd/cachew/git.go b/cmd/cachew/git.go index 0774314c..b3c60951 100644 --- a/cmd/cachew/git.go +++ b/cmd/cachew/git.go @@ -6,7 +6,6 @@ import ( "io" "os" "os/exec" - "path/filepath" "strings" "time" @@ -116,27 +115,22 @@ func (c *GitRestoreCmd) Run(ctx context.Context, api *client.Client) error { return nil } -// fetchAndExtractSnapshot downloads the snapshot and extracts it into the target -// directory, returning its freshen metadata (commit and bundle URL). With a -// download concurrency above 1 it downloads in parallel into a temp file, since -// ParallelGet needs a WriterAt; otherwise it streams the single response -// directly into extraction. +// fetchAndExtractSnapshot downloads the snapshot and pipes it straight into +// extraction, overlapping download and extraction, and returns its freshen +// metadata (commit and bundle URL). A DownloadConcurrency above 1 fetches the +// snapshot with that many concurrent range requests reassembled in order; 1 (or +// a server without range support) streams a single request. func (c *GitRestoreCmd) fetchAndExtractSnapshot(ctx context.Context, api *client.Client) (commit, bundleURL string, err error) { - if c.DownloadConcurrency > 1 { - return c.parallelFetchAndExtract(ctx, api) - } - return c.streamFetchAndExtract(ctx, api) -} - -// streamFetchAndExtract downloads the snapshot in a single request and pipes the -// response body straight into extraction, overlapping download and extraction. -func (c *GitRestoreCmd) streamFetchAndExtract(ctx context.Context, api *client.Client) (string, string, error) { var snap *client.GitSnapshot if err := inSpan(ctx, "cachew.download_snapshot", - []attribute.KeyValue{attribute.String("cachew.repo_url", c.RepoURL)}, + []attribute.KeyValue{ + attribute.String("cachew.repo_url", c.RepoURL), + attribute.Int("cachew.download_concurrency", c.DownloadConcurrency), + attribute.Int("cachew.download_chunk_size_mb", c.DownloadChunkSizeMB), + }, func(ctx context.Context) error { downloadStart := time.Now() - s, err := api.OpenGitSnapshot(ctx, c.RepoURL) + s, err := api.OpenGitSnapshotParallel(ctx, c.RepoURL, int64(c.DownloadChunkSizeMB)<<20, c.DownloadConcurrency) if err != nil { return err //nolint:wrapcheck // wrapped by caller } @@ -158,61 +152,6 @@ func (c *GitRestoreCmd) streamFetchAndExtract(ctx context.Context, api *client.C return snap.Commit, snap.BundleURL, nil } -// parallelFetchAndExtract downloads the snapshot into a temp file using bounded -// concurrent range requests, then extracts from the file. ParallelGet writes via -// WriteAt so it cannot stream into extraction; the temp file is removed on -// return. -func (c *GitRestoreCmd) parallelFetchAndExtract(ctx context.Context, api *client.Client) (string, string, error) { - // Stage the temp snapshot on the same filesystem as the restore target so a - // small or separate /tmp can't fail a restore the target directory has room - // for. The parent of c.Directory shares its filesystem and is created by - // extraction anyway. - tmpDir := filepath.Dir(c.Directory) - if err := os.MkdirAll(tmpDir, 0o750); err != nil { - return "", "", errors.Wrap(err, "create snapshot temp dir") - } - tmp, err := os.CreateTemp(tmpDir, ".cachew-snapshot-*.tar.zst") - if err != nil { - return "", "", errors.Wrap(err, "create snapshot temp file") - } - defer func() { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) //nolint:gosec // name is from os.CreateTemp, not external input - }() - - var meta client.GitSnapshotMetadata - if err := inSpan(ctx, "cachew.download_snapshot", - []attribute.KeyValue{ - attribute.String("cachew.repo_url", c.RepoURL), - attribute.Int("cachew.download_concurrency", c.DownloadConcurrency), - attribute.Int("cachew.download_chunk_size_mb", c.DownloadChunkSizeMB), - }, - func(ctx context.Context) error { - downloadStart := time.Now() - m, err := api.DownloadGitSnapshot(ctx, c.RepoURL, tmp, int64(c.DownloadChunkSizeMB)<<20, c.DownloadConcurrency) - if err != nil { - return err //nolint:wrapcheck // wrapped by caller - } - meta = m - trace.SpanFromContext(ctx).SetAttributes( - attribute.String("cachew.snapshot_commit", m.Commit), - attribute.String("cachew.bundle_url", m.BundleURL), - attribute.Float64("cachew.elapsed_seconds", time.Since(downloadStart).Seconds()), - ) - return nil - }); err != nil { - return "", "", err - } - - if _, err := tmp.Seek(0, io.SeekStart); err != nil { - return "", "", errors.Wrap(err, "rewind snapshot temp file") - } - if err := c.extract(ctx, tmp); err != nil { - return "", "", err - } - return meta.Commit, meta.BundleURL, nil -} - // extract decompresses and unpacks the snapshot body into the target directory. func (c *GitRestoreCmd) extract(ctx context.Context, body io.Reader) error { fmt.Fprintf(os.Stderr, "Extracting to %s...\n", c.Directory) //nolint:forbidigo,gosec // c.Directory is an operator-supplied CLI path diff --git a/internal/cache/parallel_get.go b/internal/cache/parallel_get.go index 55458d07..0276e94d 100644 --- a/internal/cache/parallel_get.go +++ b/internal/cache/parallel_get.go @@ -2,16 +2,15 @@ package cache import ( "context" - "io" "github.com/alecthomas/errors" "github.com/block/cachew/client" ) -// ParallelGet downloads an object from any Range-capable Cache into dst, -// fetching it in chunkSize-byte chunks concurrently. It delegates to +// ParallelGet downloads an object from any Range-capable Cache, fetching it in +// chunkSize-byte chunks concurrently and handing each to sink. It delegates to // [client.ParallelGet]; see that function for the full semantics. -func ParallelGet(ctx context.Context, c Cache, key Key, dst io.WriterAt, chunkSize int64, concurrency int) error { - return errors.WithStack(client.ParallelGet(ctx, c, key, dst, chunkSize, concurrency)) +func ParallelGet(ctx context.Context, c Cache, key Key, sink client.ChunkSink, chunkSize int64, concurrency int) error { + return errors.WithStack(client.ParallelGet(ctx, c, key, sink, chunkSize, concurrency)) } diff --git a/internal/cache/parallel_get_test.go b/internal/cache/parallel_get_test.go index 1baa3bb8..fd396272 100644 --- a/internal/cache/parallel_get_test.go +++ b/internal/cache/parallel_get_test.go @@ -5,31 +5,36 @@ import ( "io" "log/slog" "os" - "sync" "testing" "time" "github.com/alecthomas/assert/v2" + "github.com/block/cachew/client" "github.com/block/cachew/internal/cache" "github.com/block/cachew/internal/logging" ) -// bufferAt is an in-memory io.WriterAt that extends like a file, zero-filling -// any gap, so tests can assert reassembly without touching disk. -type bufferAt struct { - mu sync.Mutex - buf []byte -} - -func (b *bufferAt) WriteAt(p []byte, off int64) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - if end := int(off) + len(p); end > len(b.buf) { - b.buf = append(b.buf, make([]byte, end-len(b.buf))...) +// collect runs cache.ParallelGet into a StreamSink and returns the reassembled +// bytes, reading the sink concurrently as the engine requires. +func collect(ctx context.Context, c cache.Cache, key cache.Key, chunkSize int64, concurrency int) ([]byte, error) { + sink := client.NewStreamSink(chunkSize, concurrency) + type result struct { + data []byte + err error + } + rc := make(chan result, 1) + go func() { + data, err := io.ReadAll(sink) + rc <- result{data: data, err: err} + }() + err := cache.ParallelGet(ctx, c, key, sink, chunkSize, concurrency) + sink.Done(err) + res := <-rc + if err != nil { + return res.data, err } - copy(b.buf[off:], p) - return len(p), nil + return res.data, res.err } func TestParallelGet(t *testing.T) { @@ -61,10 +66,9 @@ func TestParallelGet(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var dst bufferAt - err := cache.ParallelGet(ctx, c, key, &dst, tt.chunkSize, tt.concurrency) + got, err := collect(ctx, c, key, tt.chunkSize, tt.concurrency) assert.NoError(t, err) - assert.Equal(t, content, dst.buf) + assert.Equal(t, content, got) }) } } @@ -83,9 +87,9 @@ func TestParallelGetEmptyObject(t *testing.T) { // concurrency 4 takes the ranged discovery path (ErrRangeNotSatisfiable), // concurrency 1 takes the up-front full-read path; both must yield nothing. for _, concurrency := range []int{4, 1} { - var dst bufferAt - assert.NoError(t, cache.ParallelGet(ctx, c, key, &dst, 100, concurrency)) - assert.Equal(t, 0, len(dst.buf)) + got, err := collect(ctx, c, key, 100, concurrency) + assert.NoError(t, err) + assert.Equal(t, 0, len(got)) } } @@ -95,7 +99,6 @@ func TestParallelGetNotFound(t *testing.T) { assert.NoError(t, err) defer c.Close() - var dst bufferAt - err = cache.ParallelGet(ctx, c, cache.NewKey("missing"), &dst, 100, 4) + _, err = collect(ctx, c, cache.NewKey("missing"), 100, 4) assert.IsError(t, err, os.ErrNotExist) }