diff --git a/client/git.go b/client/git.go
index 94d0f6a0..999509d9 100644
--- a/client/git.go
+++ b/client/git.go
@@ -10,6 +10,7 @@ import (
 	"net/url"
 	"os"
 	"strings"
+	"sync"
 
 	"github.com/alecthomas/errors"
 )
@@ -183,6 +184,94 @@ func (c *Client) openGitArtifact(ctx context.Context, repoURL, suffix string) (*
 	}, nil
 }
 
+// GitSnapshotMetadata is the freshen metadata that accompanies a parallel
+// snapshot download. Commit holds the mirror's HEAD SHA at snapshot time and is
+// empty for cold serves. BundleURL, when set, locates a delta bundle that
+// advances the snapshot to the mirror's current HEAD.
+type GitSnapshotMetadata struct {
+	Commit    string
+	BundleURL string
+}
+
+// DownloadGitSnapshot downloads the working-tree snapshot of repoURL into dst.
+// The transfer is split into range requests of chunkSize bytes, with at most
+// concurrency of them in flight; with a concurrency of 1, or against a server
+// without range support, it transparently degrades to one full download. Data
+// lands in dst through WriteAt at non-overlapping offsets (an *os.File works),
+// and dst stays owned by the caller. The returned freshen metadata is taken
+// from the discovery response. When the server has no snapshot the error
+// matches os.ErrNotExist.
+func (c *Client) DownloadGitSnapshot(ctx context.Context, repoURL string, dst io.WriterAt, chunkSize int64, concurrency int) (GitSnapshotMetadata, error) {
+	var none GitSnapshotMetadata
+
+	target, err := gitEndpointURL(c.baseURL, repoURL, "snapshot.tar.zst")
+	if err != nil {
+		return none, err
+	}
+
+	// The range reader keeps the first response's headers, which is where the
+	// freshen metadata is read from once the transfer has finished.
+	ranged := &gitArtifactRangeReader{client: c, endpoint: target}
+	err = ParallelGet(ctx, ranged, NewKey(repoURL), dst, chunkSize, concurrency)
+	if err != nil {
+		return none, errors.Wrap(err, "download snapshot")
+	}
+	return ranged.metadata(), nil
+}
+
+// gitArtifactRangeReader adapts a git artifact endpoint to the RangeReader
+// interface so ParallelGet can fetch it with concurrent range requests.
+// Because the endpoint URL already identifies the object, the Key argument is
+// unused. The reader keeps the headers of the first response it sees: they
+// hold the snapshot's freshen metadata (sent with the discovery chunk), which
+// ParallelGet itself does not hand back.
+type gitArtifactRangeReader struct {
+	client   *Client
+	endpoint string
+
+	mu        sync.Mutex  // guards discovery
+	discovery http.Header // headers of the first recorded response; nil until then
+}
+
+// Open issues a GET against the artifact endpoint with opts applied to the
+// request. A 200 or 206 hands back the live body and headers. Every other
+// status has its body drained and closed: 404 maps to os.ErrNotExist, 416 to
+// ErrRangeNotSatisfiable (with the response headers still returned), and
+// anything else to an *HTTPStatusError.
+func (g *gitArtifactRangeReader) Open(ctx context.Context, _ Key, opts ...RequestOption) (io.ReadCloser, http.Header, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, g.endpoint, nil)
+	if err != nil {
+		return nil, nil, errors.Wrap(err, "create request")
+	}
+	NewRequestOptions(opts...).applyToRequest(req)
+
+	resp, err := g.client.http.Do(req)
+	if err != nil {
+		return nil, nil, errors.Wrap(err, "execute request")
+	}
+
+	status := resp.StatusCode
+	if status == http.StatusOK || status == http.StatusPartialContent {
+		g.recordDiscovery(resp.Header)
+		return resp.Body, resp.Header, nil
+	}
+
+	// All remaining statuses are failures: drain the body before closing it,
+	// then fold any close error into the result.
+	_, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec
+	closeErr := resp.Body.Close()
+
+	switch status {
+	case http.StatusNotFound:
+		return nil, nil, errors.Join(os.ErrNotExist, closeErr)
+	case http.StatusRequestedRangeNotSatisfiable:
+		g.recordDiscovery(resp.Header)
+		return nil, resp.Header, errors.Join(ErrRangeNotSatisfiable, closeErr)
+	default:
+		return nil, nil, errors.Join(errors.WithStack(&HTTPStatusError{StatusCode: status}), closeErr)
+	}
+}
+
+// recordDiscovery stores the first response's headers so the freshen metadata
+// they carry survives after the bodies are consumed.
+func (g *gitArtifactRangeReader) recordDiscovery(h http.Header) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	// Only the first caller wins; the clone detaches the stored headers from
+	// the response they came from.
+	if g.discovery == nil {
+		g.discovery = h.Clone()
+	}
+}
+
+// metadata extracts the freshen metadata from the recorded discovery headers.
+// If no response was recorded the header map is nil, Get yields "", and the
+// result is the zero GitSnapshotMetadata.
+func (g *gitArtifactRangeReader) metadata() GitSnapshotMetadata {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	return GitSnapshotMetadata{
+		Commit:    g.discovery.Get(SnapshotCommitHeader),
+		BundleURL: g.discovery.Get(BundleURLHeader),
+	}
+}
+
 // gitEndpointURL builds a /git/{host}/{repoPath}/{suffix} URL from a cachew
 // base URL and an upstream repository URL (e.g. https://github.com/org/repo).
 func gitEndpointURL(baseURL, repoURL, suffix string) (string, error) {
diff --git a/client/git_test.go b/client/git_test.go
index 18197489..0e2ac435 100644
--- a/client/git_test.go
+++ b/client/git_test.go
@@ -1,6 +1,7 @@
 package client_test
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"io"
@@ -8,7 +9,9 @@ import (
 	"net/http/httptest"
 	"os"
 	"strings"
+	"sync/atomic"
 	"testing"
+	"time"
 
 	"github.com/alecthomas/assert/v2"
 	"github.com/alecthomas/errors"
@@ -171,3 +174,67 @@ func TestOpenGitBundleNotFound(t *testing.T) {
 	_, err := api.OpenGitBundle(context.Background(), "/git/x/y/snapshot.bundle")
 	assert.True(t, errors.Is(err, os.ErrNotExist))
 }
+
+// TestDownloadGitSnapshotParallel downloads a 1000-byte snapshot in 128-byte
+// chunks and checks the reassembled bytes, the freshen metadata, and that more
+// than one request reached the server.
+func TestDownloadGitSnapshotParallel(t *testing.T) {
+	body := make([]byte, 1000)
+	for i := range body {
+		body[i] = byte(i % 251)
+	}
+	const etag = `"snap-v1"`
+
+	var requests atomic.Int64
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// The handler runs on the server's goroutines, not the test's, so the
+		// path check reports through t.Errorf, which is safe to call from any
+		// goroutine, instead of a fail-fast assertion.
+		if got, want := r.URL.Path, "/git/github.com/org/repo/snapshot.tar.zst"; got != want {
+			t.Errorf("unexpected request path: got %q, want %q", got, want)
+		}
+		requests.Add(1)
+		w.Header().Set("Content-Type", "application/zstd")
+		w.Header().Set("ETag", etag)
+		w.Header().Set(client.SnapshotCommitHeader, "deadbeef")
+		w.Header().Set(client.BundleURLHeader, "/git/github.com/org/repo/snapshot.bundle?base=deadbeef")
+		// ServeContent honours Range/If-Range against the ETag set above, so it
+		// returns 206 + Content-Range for the chunked requests ParallelGet makes.
+		// The zero modtime makes ServeContent omit Last-Modified, so the ETag
+		// is the only validator the client can use.
+		http.ServeContent(w, r, "snapshot.tar.zst", time.Time{}, bytes.NewReader(body))
+	}))
+	defer srv.Close()
+
+	api := client.NewWithHTTPClient(srv.URL, srv.Client())
+	var dst bufferAt
+	// A 128-byte chunk over a 1000-byte body forces multiple chunks, exercising
+	// concurrent range reassembly.
+	meta, err := api.DownloadGitSnapshot(context.Background(), "https://github.com/org/repo", &dst, 128, 4)
+	assert.NoError(t, err)
+	// The reassembled bytes must match the served body exactly.
+	assert.Equal(t, body, dst.buf)
+	// Freshen metadata comes from the response headers set by the handler.
+	assert.Equal(t, "deadbeef", meta.Commit)
+	assert.Equal(t, "/git/github.com/org/repo/snapshot.bundle?base=deadbeef", meta.BundleURL)
+	assert.True(t, requests.Load() > 1, "expected multiple range requests, got %d", requests.Load())
+}
+
+// TestDownloadGitSnapshotFallsBackWithoutRange checks the download against a
+// server that always answers with the full body, whatever the request asks.
+func TestDownloadGitSnapshotFallsBackWithoutRange(t *testing.T) {
+	body := []byte("full body, server ignores ranges")
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		// No ETag and no range handling: always answer the full object with 200,
+		// mimicking an older server. ParallelGet must fall back to a single read.
+		w.Header().Set("Content-Type", "application/zstd")
+		w.Header().Set(client.SnapshotCommitHeader, "cafe")
+		_, _ = w.Write(body) //nolint:errcheck
+	}))
+	defer srv.Close()
+
+	api := client.NewWithHTTPClient(srv.URL, srv.Client())
+	var dst bufferAt
+	meta, err := api.DownloadGitSnapshot(context.Background(), "https://github.com/org/repo", &dst, 8, 4)
+	assert.NoError(t, err)
+	assert.Equal(t, body, dst.buf)
+	assert.Equal(t, "cafe", meta.Commit)
+}
+
+// TestDownloadGitSnapshotNotFound checks that a 404 from the server surfaces
+// as an error matching os.ErrNotExist.
+func TestDownloadGitSnapshotNotFound(t *testing.T) {
+	// Every request is answered with 404, i.e. the server has no snapshot.
+	srv := httptest.NewServer(http.NotFoundHandler())
+	defer srv.Close()
+
+	var sink bufferAt
+	api := client.NewWithHTTPClient(srv.URL, srv.Client())
+	_, err := api.DownloadGitSnapshot(context.Background(), "https://github.com/org/repo", &sink, 8, 4)
+	notExist := errors.Is(err, os.ErrNotExist)
+	assert.True(t, notExist)
+}
diff --git a/cmd/cachew/git.go b/cmd/cachew/git.go
index 4e30e013..af3285cd 100644
--- a/cmd/cachew/git.go
+++ b/cmd/cachew/git.go
@@ -52,6 +52,11 @@ type GitRestoreCmd struct {
 	Commit []string `help:"Required commit SHAs that must exist on the server, regardless of which ref points at them. May be repeated."`
 	NoBundle bool `help:"Skip applying delta bundle."`
 	ZstdThreads int `help:"Threads for zstd decompression (0 = all CPU cores)." default:"0"`
+	// DownloadConcurrency > 1 fetches the snapshot with that many concurrent
+	// range requests (requires server range support; falls back to a single
+	// request otherwise). 1 keeps the streaming single-request download.
+	DownloadConcurrency int `help:"Concurrent range requests for the snapshot download (1 = single streaming request)." default:"1"`
+	DownloadChunkSizeMB int `help:"Chunk size in MiB for parallel snapshot downloads."
default:"8"` } func (c *GitRestoreCmd) Run(ctx context.Context, api *client.Client) error { @@ -72,56 +77,22 @@ func (c *GitRestoreCmd) Run(ctx context.Context, api *client.Client) error { fmt.Fprintf(os.Stderr, "Fetching snapshot for %s\n", c.RepoURL) //nolint:forbidigo - var snap *client.GitSnapshot - if err := inSpan(ctx, "cachew.download_snapshot", - []attribute.KeyValue{attribute.String("cachew.repo_url", c.RepoURL)}, - func(ctx context.Context) error { - downloadStart := time.Now() - s, err := api.OpenGitSnapshot(ctx, c.RepoURL) - if err != nil { - return err //nolint:wrapcheck // wrapped by caller - } - snap = s - trace.SpanFromContext(ctx).SetAttributes( - attribute.String("cachew.snapshot_commit", s.Commit), - attribute.String("cachew.bundle_url", s.BundleURL), - attribute.Float64("cachew.elapsed_seconds", time.Since(downloadStart).Seconds()), - ) - return nil - }); err != nil { + commit, bundleURL, err := c.fetchAndExtractSnapshot(ctx, api) + if err != nil { if errors.Is(err, os.ErrNotExist) { return errors.Errorf("no snapshot available for %s", c.RepoURL) } span.RecordError(err) span.SetStatus(codes.Error, err.Error()) - return errors.Wrap(err, "fetch snapshot") - } - defer snap.Close() - span.SetAttributes(attribute.String("cachew.snapshot_commit", snap.Commit)) - - fmt.Fprintf(os.Stderr, "Extracting to %s...\n", c.Directory) //nolint:forbidigo - if err := inSpan(ctx, "cachew.extract", - []attribute.KeyValue{attribute.String("cachew.directory", c.Directory)}, - func(ctx context.Context) error { - extractStart := time.Now() - if err := snapshot.Extract(ctx, snap.Body, c.Directory, c.ZstdThreads); err != nil { - return err //nolint:wrapcheck // wrapped by caller - } - elapsed := time.Since(extractStart) - trace.SpanFromContext(ctx).SetAttributes(attribute.Float64("cachew.elapsed_seconds", elapsed.Seconds())) - fmt.Fprintf(os.Stderr, "Snapshot extracted in %s\n", elapsed) //nolint:forbidigo - return nil - }); err != nil { - span.RecordError(err) - 
span.SetStatus(codes.Error, err.Error()) - return errors.Wrap(err, "extract snapshot") + return errors.Wrap(err, "restore snapshot") } + span.SetAttributes(attribute.String("cachew.snapshot_commit", commit)) fmt.Fprintf(os.Stderr, "Snapshot restored to %s\n", c.Directory) //nolint:forbidigo - if snap.BundleURL != "" && !c.NoBundle { + if bundleURL != "" && !c.NoBundle { fmt.Fprintf(os.Stderr, "Applying delta bundle...\n") //nolint:forbidigo bundleStart := time.Now() - if err := applyBundle(ctx, api, snap.BundleURL, c.Directory); err != nil { + if err := applyBundle(ctx, api, bundleURL, c.Directory); err != nil { fmt.Fprintf(os.Stderr, "Warning: failed to apply delta bundle: %v\n", err) //nolint:forbidigo span.RecordError(err) } else { @@ -144,6 +115,112 @@ func (c *GitRestoreCmd) Run(ctx context.Context, api *client.Client) error { return nil } +// fetchAndExtractSnapshot downloads the snapshot and extracts it into the target +// directory, returning its freshen metadata (commit and bundle URL). With a +// download concurrency above 1 it downloads in parallel into a temp file, since +// ParallelGet needs a WriterAt; otherwise it streams the single response +// directly into extraction. +func (c *GitRestoreCmd) fetchAndExtractSnapshot(ctx context.Context, api *client.Client) (commit, bundleURL string, err error) { + if c.DownloadConcurrency > 1 { + return c.parallelFetchAndExtract(ctx, api) + } + return c.streamFetchAndExtract(ctx, api) +} + +// streamFetchAndExtract downloads the snapshot in a single request and pipes the +// response body straight into extraction, overlapping download and extraction. 
+func (c *GitRestoreCmd) streamFetchAndExtract(ctx context.Context, api *client.Client) (string, string, error) {
+	var opened *client.GitSnapshot
+	spanAttrs := []attribute.KeyValue{attribute.String("cachew.repo_url", c.RepoURL)}
+	openSnapshot := func(ctx context.Context) error {
+		began := time.Now()
+		snap, err := api.OpenGitSnapshot(ctx, c.RepoURL)
+		if err != nil {
+			return err //nolint:wrapcheck // wrapped by caller
+		}
+		opened = snap
+		// The elapsed time recorded here ends once the snapshot has been
+		// opened; its body is handed to extract after this span closes.
+		trace.SpanFromContext(ctx).SetAttributes(
+			attribute.String("cachew.snapshot_commit", snap.Commit),
+			attribute.String("cachew.bundle_url", snap.BundleURL),
+			attribute.Float64("cachew.elapsed_seconds", time.Since(began).Seconds()),
+		)
+		return nil
+	}
+	if err := inSpan(ctx, "cachew.download_snapshot", spanAttrs, openSnapshot); err != nil {
+		return "", "", err
+	}
+	defer opened.Close()
+
+	if err := c.extract(ctx, opened.Body); err != nil {
+		return "", "", err
+	}
+	return opened.Commit, opened.BundleURL, nil
+}
+
+// parallelFetchAndExtract downloads the snapshot into a temp file using bounded
+// concurrent range requests, then extracts from the file. ParallelGet writes via
+// WriteAt so it cannot stream into extraction; the temp file is removed on
+// return.
+func (c *GitRestoreCmd) parallelFetchAndExtract(ctx context.Context, api *client.Client) (string, string, error) {
+	spool, err := os.CreateTemp("", "cachew-snapshot-*.tar.zst")
+	if err != nil {
+		return "", "", errors.Wrap(err, "create snapshot temp file")
+	}
+	// Best-effort cleanup: the spool file is only needed until extraction ends.
+	defer func() {
+		_ = spool.Close()
+		_ = os.Remove(spool.Name()) //nolint:gosec // name is from os.CreateTemp, not external input
+	}()
+
+	// The flag is expressed in MiB; the client API takes bytes.
+	// NOTE(review): a non-positive chunk size flag reaches DownloadGitSnapshot
+	// unchanged - confirm the download path tolerates it.
+	chunkBytes := int64(c.DownloadChunkSizeMB) << 20
+	spanAttrs := []attribute.KeyValue{
+		attribute.String("cachew.repo_url", c.RepoURL),
+		attribute.Int("cachew.download_concurrency", c.DownloadConcurrency),
+		attribute.Int("cachew.download_chunk_size_mb", c.DownloadChunkSizeMB),
+	}
+
+	var meta client.GitSnapshotMetadata
+	download := func(ctx context.Context) error {
+		began := time.Now()
+		got, err := api.DownloadGitSnapshot(ctx, c.RepoURL, spool, chunkBytes, c.DownloadConcurrency)
+		if err != nil {
+			return err //nolint:wrapcheck // wrapped by caller
+		}
+		meta = got
+		trace.SpanFromContext(ctx).SetAttributes(
+			attribute.String("cachew.snapshot_commit", got.Commit),
+			attribute.String("cachew.bundle_url", got.BundleURL),
+			attribute.Float64("cachew.elapsed_seconds", time.Since(began).Seconds()),
+		)
+		return nil
+	}
+	if err := inSpan(ctx, "cachew.download_snapshot", spanAttrs, download); err != nil {
+		return "", "", err
+	}
+
+	// Rewind to the start of the spool file before reading it back for
+	// extraction.
+	if _, err := spool.Seek(0, io.SeekStart); err != nil {
+		return "", "", errors.Wrap(err, "rewind snapshot temp file")
+	}
+	if err := c.extract(ctx, spool); err != nil {
+		return "", "", err
+	}
+	return meta.Commit, meta.BundleURL, nil
+}
+
+// extract decompresses and unpacks the snapshot body into the target directory.
+func (c *GitRestoreCmd) extract(ctx context.Context, body io.Reader) error {
+	fmt.Fprintf(os.Stderr, "Extracting to %s...\n", c.Directory) //nolint:forbidigo,gosec // c.Directory is an operator-supplied CLI path
+
+	spanAttrs := []attribute.KeyValue{attribute.String("cachew.directory", c.Directory)}
+	unpack := func(ctx context.Context) error {
+		began := time.Now()
+		if err := snapshot.Extract(ctx, body, c.Directory, c.ZstdThreads); err != nil {
+			return err //nolint:wrapcheck // wrapped by caller
+		}
+		took := time.Since(began)
+		// Report the duration both on the span and to the operator.
+		trace.SpanFromContext(ctx).SetAttributes(attribute.Float64("cachew.elapsed_seconds", took.Seconds()))
+		fmt.Fprintf(os.Stderr, "Snapshot extracted in %s\n", took) //nolint:forbidigo
+		return nil
+	}
+	return inSpan(ctx, "cachew.extract", spanAttrs, unpack)
+}
+
 // satisfyRefs ensures the working tree contains every requested ref and
 // commit. It short-circuits whenever the local clone already has what was
 // asked for, avoiding both /ensure-refs and git pull when the snapshot+bundle
diff --git a/internal/strategy/git/snapshot.go b/internal/strategy/git/snapshot.go
index 0418941d..adb289b5 100644
--- a/internal/strategy/git/snapshot.go
+++ b/internal/strategy/git/snapshot.go
@@ -265,48 +265,37 @@ func (s *Strategy) handleSnapshotRequest(w http.ResponseWriter, r *http.Request,
 		if existing, loaded := s.coldSnapshotMu.LoadOrStore(upstreamURL, entry); loaded {
 			winner := existing.(*coldSnapshotEntry)
 			<-winner.done
-			reader, _, openErr := s.cache.Open(ctx, cacheKey)
-			if openErr == nil && reader != nil {
+			reader, headers, openErr := s.cache.Open(ctx, cacheKey, httputil.ConditionalOptions(r)...)
+ if !errors.Is(openErr, os.ErrNotExist) { winner.serving.Add(1) - defer func() { - _ = reader.Close() - winner.serving.Done() - }() - logger.InfoContext(ctx, "Serving locally cached snapshot after waiting for in-flight fill", "upstream", upstreamURL) - w.Header().Set("Content-Type", "application/zstd") - n, err := serveReaderFast(w, r, reader) - s.metrics.recordSnapshotServe(ctx, "cold_cache", repoName, n, time.Since(start)) - span.SetAttributes(attribute.String("cachew.source", "cold_cache"), attribute.Int64("cachew.bytes", n)) - if err != nil { - logger.WarnContext(ctx, "Failed to stream locally cached snapshot", "upstream", upstreamURL, "error", err) - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + served, serveErr := s.serveOpenedSnapshot(ctx, w, reader, headers, openErr, repoName, "cold_cache", start) + winner.serving.Done() + if serveErr != nil { + logger.WarnContext(ctx, "Failed to serve locally cached snapshot after waiting for in-flight fill", "upstream", upstreamURL, "error", serveErr) + span.RecordError(serveErr) + } + if served { + logger.InfoContext(ctx, "Served locally cached snapshot after waiting for in-flight fill", "upstream", upstreamURL) + return } - return } } else { defer func() { close(entry.done) s.coldSnapshotMu.Delete(upstreamURL) }() - reader, _, openErr := s.cache.Open(ctx, cacheKey) - if openErr == nil && reader != nil { - logger.InfoContext(ctx, "Serving cached snapshot while mirror warms up", "upstream", upstreamURL) - w.Header().Set("Content-Type", "application/zstd") - n, err := serveReaderFast(w, r, reader) - s.metrics.recordSnapshotServe(ctx, "cold_cache", repoName, n, time.Since(start)) - span.SetAttributes(attribute.String("cachew.source", "cold_cache"), attribute.Int64("cachew.bytes", n)) - if err != nil { - logger.WarnContext(ctx, "Failed to stream cached snapshot", "upstream", upstreamURL, "error", err) - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + reader, headers, openErr := 
s.cache.Open(ctx, cacheKey, httputil.ConditionalOptions(r)...) + if !errors.Is(openErr, os.ErrNotExist) { + served, serveErr := s.serveOpenedSnapshot(ctx, w, reader, headers, openErr, repoName, "cold_cache", start) + if serveErr != nil { + logger.WarnContext(ctx, "Failed to serve cached snapshot while mirror warms up", "upstream", upstreamURL, "error", serveErr) + span.RecordError(serveErr) + } + if served { + logger.InfoContext(ctx, "Served cached snapshot while mirror warms up", "upstream", upstreamURL) + s.scheduleDeferredMirrorRestore(ctx, repo, entry) + return } - _ = reader.Close() - s.scheduleDeferredMirrorRestore(ctx, repo, entry) - return - } - if reader != nil { - _ = reader.Close() } } } @@ -556,10 +545,7 @@ func (s *Strategy) serveSnapshotWithBundle(ctx context.Context, w http.ResponseW return errors.Wrap(openErr, "serve snapshot") } - source := "cache" - if headers.Get("Content-Range") != "" { - source = "cache_range" - } + source := snapshotServeSource("cache", headers) s.metrics.recordSnapshotServe(ctx, source, repoName, n, time.Since(start)) if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() { span.SetAttributes(attribute.String("cachew.source", source), attribute.Int64("cachew.bytes", n)) @@ -567,6 +553,42 @@ func (s *Strategy) serveSnapshotWithBundle(ctx context.Context, w http.ResponseW return errors.Wrap(err, "serve snapshot") } +// serveOpenedSnapshot writes an already-opened cached snapshot to w, honouring +// Range and conditional requests and forcing the snapshot Content-Type. It is +// the cold-start serve path (no mirror, so no bundle negotiation); source labels +// the serve in metrics and traces. reader/headers/openErr are the cache Open +// results; callers must not pass an os.ErrNotExist miss. It returns served=false +// when the cache returned an unexpected error, so the caller can fall through to +// generation, and closes reader on every path. 
+func (s *Strategy) serveOpenedSnapshot(ctx context.Context, w http.ResponseWriter, reader io.ReadCloser, headers http.Header, openErr error, repoName, source string, start time.Time) (served bool, err error) {
+	// Snapshots are always served as zstd, whatever the cached headers say.
+	forceZstd := httputil.WithResponseDecorator(func(rw http.ResponseWriter, _ http.Header) {
+		rw.Header().Set("Content-Type", "application/zstd")
+	})
+	handled, written, serveErr := httputil.ServeCacheHit(w, headers, reader, openErr, forceZstd)
+
+	if handled {
+		// NOTE(review): the reader is not closed here on the handled path -
+		// presumably ServeCacheHit takes ownership of it; confirm.
+		label := snapshotServeSource(source, headers)
+		s.metrics.recordSnapshotServe(ctx, label, repoName, written, time.Since(start))
+		span := trace.SpanFromContext(ctx)
+		if span.SpanContext().IsValid() {
+			span.SetAttributes(attribute.String("cachew.source", label), attribute.Int64("cachew.bytes", written))
+		}
+		return true, errors.Wrap(serveErr, "serve cached snapshot")
+	}
+
+	// Not handled: release the reader if one was returned and report both the
+	// open and the serve error to the caller.
+	if reader != nil {
+		serveErr = errors.Join(serveErr, reader.Close())
+	}
+	return false, errors.Wrap(errors.Join(openErr, serveErr), "open cached snapshot")
+}
+
+// snapshotServeSource returns the metric source label for a serve: base as is
+// for a full response, or base with a "_range" suffix when headers carry a
+// Content-Range, so full and partial serves can be told apart.
+func snapshotServeSource(base string, headers http.Header) string {
+	suffix := ""
+	if headers.Get("Content-Range") != "" {
+		suffix = "_range"
+	}
+	return base + suffix
+}
+
 // pregenerateBundle builds and caches the delta bundle for snapshotCommit in the
 // background so any pod can later serve it without regenerating.
 func (s *Strategy) pregenerateBundle(ctx context.Context, repo *gitclone.Repository, upstreamURL, snapshotCommit string) {