diff --git a/acceptance/experimental/air/help/output.txt b/acceptance/experimental/air/help/output.txt index 8eac074581..ee89e778d6 100644 --- a/acceptance/experimental/air/help/output.txt +++ b/acceptance/experimental/air/help/output.txt @@ -12,7 +12,7 @@ Usage: Available Commands: cancel Cancel one or more runs get Show status, configuration, and timing details for a specific run - list List your recent runs (active and completed) for the current profile + list List your active runs for the current profile (use --all-status for finished runs) logs Stream or fetch logs for a run register-image Mirror a Docker image into the workspace registry run Submit a training workload from a YAML config @@ -30,13 +30,13 @@ Use "databricks experimental air [command] --help" for more information about a === list help >>> [CLI] experimental air list --help -List your recent runs (active and completed) for the current profile +List your active runs for the current profile (use --all-status for finished runs) Usage: databricks experimental air list [flags] Flags: - --active Show only active runs + --all-status Show runs in all states (default: active only) --all-users Show runs from all users --filter stringArray Filter runs, e.g. experiment=foo* (repeatable) -h, --help help for list diff --git a/acceptance/experimental/air/list/output.txt b/acceptance/experimental/air/list/output.txt index 3f84d281ad..e27cfb0bf3 100644 --- a/acceptance/experimental/air/list/output.txt +++ b/acceptance/experimental/air/list/output.txt @@ -22,3 +22,27 @@ ] } } + +=== list --all-status (text, via AiTrainingService index) +>>> [CLI] experimental air list --all-status + Run ID Experiment Status Started Duration MLflow User Accelerators + [NUMID] qwen-train ● SUCCESS [TIMESTAMP] 12s …/runs/run1 [USERNAME] 8x H100 + +=== list --all-status (json) +>>> [CLI] experimental air list --all-status -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "runs": [ + { + "run_id": "[NUMID]", + "run_name": "qwen-train", + "user": "[USERNAME]", + "status": "SUCCESS", + "started_at": "[TIMESTAMP]", + "is_sweep": false + } + ] + } +} diff --git a/acceptance/experimental/air/list/script b/acceptance/experimental/air/list/script index 14702b283e..df547794c6 100644 --- a/acceptance/experimental/air/list/script +++ b/acceptance/experimental/air/list/script @@ -3,3 +3,9 @@ trace $CLI experimental air list title "list (json)" trace $CLI experimental air list -o json + +title "list --all-status (text, via AiTrainingService index)" +trace $CLI experimental air list --all-status + +title "list --all-status (json)" +trace $CLI experimental air list --all-status -o json diff --git a/acceptance/experimental/air/list/test.toml b/acceptance/experimental/air/list/test.toml index 23dd0f6ba5..82f1f829d8 100644 --- a/acceptance/experimental/air/list/test.toml +++ b/acceptance/experimental/air/list/test.toml @@ -2,6 +2,10 @@ [EnvMatrix] DATABRICKS_BUNDLE_ENGINE = [] +# Disable the on-disk run cache so --all-status output is deterministic across runs. +[Env] +DATABRICKS_CACHE_ENABLED = "false" + # The SDK occasionally probes host reachability with a HEAD request; stub it so # the test is deterministic. [[Server]] @@ -51,3 +55,33 @@ Pattern = "GET /api/2.2/jobs/runs/get-output" Response.Body = ''' {"ai_runtime_task_output": {"mlflow_experiment_id": "exp1", "mlflow_run_id": "run1"}} ''' + +# `air list --all-status` scoped to the current user is served by the +# AiTrainingService index: it returns cheap (job_run_id, submit_time) pairs, which +# the CLI orders and then hydrates into full runs via runs/get. +[[Server]] +Pattern = "GET /api/2.0/ai-training/workflows" +Response.Body = ''' +{"training_workflows": [{"job_run_id": "334747067049496", "submit_time": "2024-06-05T17:32:39Z"}]} +''' + +# runs/get hydrates one index id into the same shape as a runs/list element. +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get" +Response.Body = ''' +{ + "run_id": 334747067049496, + "run_name": "qwen-train", + "creator_user_name": "tester@databricks.com", + "start_time": 1717608759000, + "end_time": 1717608771000, + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}, + "tasks": [{ + "run_id": 334747067049497, + "ai_runtime_task": { + "experiment": "/Users/tester@databricks.com/qwen-train", + "deployments": [{"compute": {"accelerator_type": "GPU_8xH100", "accelerator_count": 8}}] + } + }] +} +''' diff --git a/experimental/air/cmd/aitraining.go b/experimental/air/cmd/aitraining.go new file mode 100644 index 0000000000..202a7048db --- /dev/null +++ b/experimental/air/cmd/aitraining.go @@ -0,0 +1,105 @@ +package aircmd + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" +) + +// aiTrainingWorkflowsPath is the AiTrainingService index of the caller's own AIR +// runs. It returns cheap (job_run_id, submit_time) pairs, letting `air list` +// order and page without scanning the Jobs runs/list firehose. +const aiTrainingWorkflowsPath = "/api/2.0/ai-training/workflows" + +// workflowRef is one run from the index: its Jobs run id and submission time. +type workflowRef struct { + jobRunID int64 + submitTimeMs int64 +} + +type aiTrainingWorkflow struct { + // job_run_id is a Jobs run id; tolerate it arriving as a JSON number or string. + JobRunID json.Number `json:"job_run_id"` + // submit_time is a proto Timestamp, serialized over HTTP as either an RFC3339 + // string or a {seconds, nanos} object. + SubmitTime json.RawMessage `json:"submit_time"` +} + +type aiTrainingWorkflowsResponse struct { + TrainingWorkflows []aiTrainingWorkflow `json:"training_workflows"` + NextPageToken string `json:"next_page_token"` +} + +// listAiTrainingWorkflows pages the index and returns every workflow ref the +// caller owns. Pagination stops at the end or when a page token repeats, which +// guards against a stuck or cycling cursor without an arbitrary page cap. +func listAiTrainingWorkflows(ctx context.Context, w *databricks.WorkspaceClient, activeOnly bool) ([]workflowRef, error) { + apiClient, err := client.New(w.Config) + if err != nil { + return nil, fmt.Errorf("failed to create API client: %w", err) + } + + var refs []workflowRef + seen := map[string]bool{} + var pageToken string + for { + query := map[string]any{} + if activeOnly { + query["active_only"] = true + } + if pageToken != "" { + query["page_token"] = pageToken + } + + var resp aiTrainingWorkflowsResponse + err = apiClient.Do(ctx, http.MethodGet, aiTrainingWorkflowsPath, nil, nil, query, &resp) + if err != nil { + return nil, fmt.Errorf("failed to list training workflows: %w", err) + } + + for _, wf := range resp.TrainingWorkflows { + id, err := wf.JobRunID.Int64() + if err != nil || id == 0 { + continue + } + refs = append(refs, workflowRef{jobRunID: id, submitTimeMs: parseSubmitTimeMs(wf.SubmitTime)}) + } + + if resp.NextPageToken == "" || seen[resp.NextPageToken] { + break + } + seen[resp.NextPageToken] = true + pageToken = resp.NextPageToken + } + return refs, nil +} + +// parseSubmitTimeMs converts a proto Timestamp (RFC3339 string or {seconds, nanos} +// object) to epoch milliseconds, or 0 when absent or unparseable (so it sorts last). +func parseSubmitTimeMs(raw json.RawMessage) int64 { + if len(raw) == 0 { + return 0 + } + + var s string + if json.Unmarshal(raw, &s) == nil { + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t.UnixMilli() + } + return 0 + } + + var obj struct { + Seconds int64 `json:"seconds"` + Nanos int64 `json:"nanos"` + } + if json.Unmarshal(raw, &obj) == nil { + return obj.Seconds*1000 + obj.Nanos/1_000_000 + } + return 0 +} diff --git a/experimental/air/cmd/aitraining_test.go b/experimental/air/cmd/aitraining_test.go new file mode 100644 index 0000000000..45bce27983 --- /dev/null +++ b/experimental/air/cmd/aitraining_test.go @@ -0,0 +1,75 @@ +package aircmd + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseSubmitTimeMs(t *testing.T) { + cases := []struct { + name string + raw string + want int64 + }{ + {"rfc3339", `"2023-11-14T22:13:20Z"`, 1700000000000}, + {"rfc3339 offset", `"2023-11-14T22:13:20+00:00"`, 1700000000000}, + {"seconds and nanos", `{"seconds": 1700000000, "nanos": 500000000}`, 1700000000500}, + {"seconds only", `{"seconds": 1700000000}`, 1700000000000}, + {"empty", ``, 0}, + {"garbage string", `"not-a-time"`, 0}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, parseSubmitTimeMs(json.RawMessage(tc.raw))) + }) + } +} + +// indexServer serves paginated AiTrainingService responses, one body per call, +// tracking whether the index was hit. +func indexServer(t *testing.T, hit *bool, bodies ...string) *httptest.Server { + t.Helper() + call := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == aiTrainingWorkflowsPath { + *hit = true + body := bodies[min(call, len(bodies)-1)] + call++ + _, _ = w.Write([]byte(body)) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestListAiTrainingWorkflowsPaginates(t *testing.T) { + page1 := `{"training_workflows":[{"job_run_id":"1","submit_time":"2023-11-14T22:13:20Z"}],"next_page_token":"tok"}` + page2 := `{"training_workflows":[{"job_run_id":2,"submit_time":{"seconds":1700000100}}]}` + var hit bool + srv := indexServer(t, &hit, page1, page2) + + refs, err := listAiTrainingWorkflows(t.Context(), newTestWorkspaceClient(t, srv.URL), false) + require.NoError(t, err) + require.Len(t, refs, 2) + assert.Equal(t, int64(1), refs[0].jobRunID) + assert.Equal(t, int64(1700000000000), refs[0].submitTimeMs) + assert.Equal(t, int64(2), refs[1].jobRunID) +} + +func TestListAiTrainingWorkflowsStopsOnRepeatedToken(t *testing.T) { + // A cursor that always returns the same token must not loop forever. + page := `{"training_workflows":[{"job_run_id":1,"submit_time":"2023-11-14T22:13:20Z"}],"next_page_token":"tok"}` + var hit bool + srv := indexServer(t, &hit, page) + + refs, err := listAiTrainingWorkflows(t.Context(), newTestWorkspaceClient(t, srv.URL), false) + require.NoError(t, err) + require.Len(t, refs, 2) // first page + one more before the repeated token stops it +} diff --git a/experimental/air/cmd/joblist.go b/experimental/air/cmd/joblist.go index 8a50921264..c84f87c9df 100644 --- a/experimental/air/cmd/joblist.go +++ b/experimental/air/cmd/joblist.go @@ -2,17 +2,27 @@ package aircmd import ( "context" + "errors" "fmt" "net/http" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/client" + "golang.org/x/sync/errgroup" ) -// jobsRunsListPath is the Jobs runs/list endpoint. We call it directly (rather -// than via the typed SDK) because the SDK's RunTask omits ai_runtime_task, the -// task type the AI runtime now submits. -const jobsRunsListPath = "/api/2.2/jobs/runs/list" +// jobsRunsListPath and jobsRunsGetPath are the Jobs endpoints we call directly +// (rather than via the typed SDK) because the SDK's RunTask omits +// ai_runtime_task, the task type the AI runtime now submits. +const ( + jobsRunsListPath = "/api/2.2/jobs/runs/list" + jobsRunsGetPath = "/api/2.2/jobs/runs/get" +) + +// hydrateConcurrency bounds the parallel runs/get calls when hydrating a batch +// of run ids from the AiTrainingService index. +const hydrateConcurrency = 16 type jobsRunsListResponse struct { Runs []jobRun `json:"runs"` @@ -155,6 +165,16 @@ func jobTiming(r *jobRun) (startMillis, endMillis int64) { return startMillis, endMillis } +// isTerminal reports whether a run has finished and its details are immutable, +// so its row is safe to cache. +func isTerminal(r *jobRun) bool { + switch r.State.LifeCycleState { + case "TERMINATED", "INTERNAL_ERROR", "SKIPPED": + return true + } + return false +} + // fetchJobRunsPage fetches one page of Jobs runs/list. query carries the request // params (and page_token across calls). func fetchJobRunsPage(ctx context.Context, w *databricks.WorkspaceClient, query map[string]any) (*jobsRunsListResponse, error) { @@ -170,3 +190,55 @@ func fetchJobRunsPage(ctx context.Context, w *databricks.WorkspaceClient, query } return &resp, nil } + +// fetchJobRun fetches a single run via runs/get. The response mirrors one +// runs/list element, so it deserializes into jobRun directly. +func fetchJobRun(ctx context.Context, w *databricks.WorkspaceClient, runID int64) (*jobRun, error) { + apiClient, err := client.New(w.Config) + if err != nil { + return nil, fmt.Errorf("failed to create API client: %w", err) + } + + var run jobRun + query := map[string]any{"run_id": runID, "expand_tasks": true} + err = apiClient.Do(ctx, http.MethodGet, jobsRunsGetPath, nil, nil, query, &run) + if err != nil { + return nil, err + } + return &run, nil +} + +// hydrateJobRuns fetches the given run ids concurrently via runs/get, preserving +// input order. runs/get enforces per-run view ACLs, so an id the caller can't +// view (403) or that has been purged (404) is dropped; any other error is +// systemic and fails the whole batch. +func hydrateJobRuns(ctx context.Context, w *databricks.WorkspaceClient, ids []int64) ([]*jobRun, error) { + runs := make([]*jobRun, len(ids)) + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(hydrateConcurrency) + for i, id := range ids { + g.Go(func() error { + run, err := fetchJobRun(gctx, w, id) + if err != nil { + if apiErr, ok := errors.AsType[*apierr.APIError](err); ok && + (apiErr.StatusCode == http.StatusForbidden || apiErr.StatusCode == http.StatusNotFound) { + return nil // not viewable or purged: drop this id + } + return fmt.Errorf("failed to get run %d: %w", id, err) + } + runs[i] = run + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, err + } + + hydrated := make([]*jobRun, 0, len(runs)) + for _, run := range runs { + if run != nil { + hydrated = append(hydrated, run) + } + } + return hydrated, nil +} diff --git a/experimental/air/cmd/list.go b/experimental/air/cmd/list.go index 8c716f5556..d5e5e48bcc 100644 --- a/experimental/air/cmd/list.go +++ b/experimental/air/cmd/list.go @@ -60,29 +60,32 @@ type listedRun struct { // listQuery holds the resolved inputs to a runFetcher. type listQuery struct { activeOnly bool + allUsers bool userFilter string + currentUser string filters listFilters fetchMLflow bool + limit int } func newListCommand() *cobra.Command { var ( - limit int - active bool - allUsers bool - filters []string + limit int + allStatus bool + allUsers bool + filters []string ) cmd := &cobra.Command{ Use: "list", Args: root.NoArgs, - Short: "List your recent runs (active and completed) for the current profile", + Short: "List your active runs for the current profile (use --all-status for finished runs)", } cmd.PreRunE = root.MustWorkspaceClient cmd.Flags().IntVar(&limit, "limit", 20, "Maximum number of runs to show") - cmd.Flags().BoolVar(&active, "active", false, "Show only active runs") + cmd.Flags().BoolVar(&allStatus, "all-status", false, "Show runs in all states (default: active only)") cmd.Flags().BoolVar(&allUsers, "all-users", false, "Show runs from all users") cmd.Flags().StringArrayVar(&filters, "filter", nil, "Filter runs, e.g. experiment=foo* (repeatable)") @@ -103,19 +106,24 @@ func newListCommand() *cobra.Command { // unless --all-users is set. runs/list has no creator param, so the // creator is matched while scanning. userFilter := f.User + var currentUser string if userFilter == "" && !allUsers { me, err := w.CurrentUser.Me(ctx, iam.MeRequest{}) if err != nil { return fmt.Errorf("failed to resolve current user: %w", err) } - userFilter = me.UserName + currentUser = me.UserName + userFilter = currentUser } fetcher := newRunFetcher(ctx, w, listQuery{ - activeOnly: active, + activeOnly: !allStatus, + allUsers: allUsers, userFilter: userFilter, + currentUser: currentUser, filters: f, fetchMLflow: root.OutputType(cmd) == flags.OutputText, + limit: limit, }) // JSON prints the newest `limit` runs once. Text renders the table: @@ -135,23 +143,96 @@ func newListCommand() *cobra.Command { return cmd } -// runFetcher pages Jobs runs/list on demand, yielding AIR runs that match the -// user and filters. It buffers a page's leftover runs so successive next() calls -// resume where the last stopped — driving both one-shot output and lazy paging. +// listStrategy is a source of matching runs, pulled in batches. Two implement it: +// jobsScanStrategy pages runs/list; indexStrategy hydrates the AiTrainingService +// index. The fetcher wraps whichever is chosen. +type listStrategy interface { + // next returns up to want more matching runs (already row-built + task id). + next(want int) ([]listedRun, error) + // done reports whether the source has no more runs to yield. + done() bool + // truncated reports whether a safety cap stopped the scan short of the end. + truncated() bool +} + +// runFetcher yields matching rows in batches, driving both one-shot output and +// the interactive table's lazy paging. It wraps a listStrategy and adds the +// shared tail: MLflow enrichment (text only) and row projection. type runFetcher struct { ctx context.Context w *databricks.WorkspaceClient - query map[string]any - userFilter string - filters listFilters fetchMLflow bool + strategy listStrategy - pending []jobRun // runs from the last page not yet inspected - scanned int exhausted bool } func newRunFetcher(ctx context.Context, w *databricks.WorkspaceClient, q listQuery) *runFetcher { + return &runFetcher{ + ctx: ctx, + w: w, + fetchMLflow: q.fetchMLflow, + strategy: newListStrategy(ctx, w, q), + } +} + +// newListStrategy picks the fetch source. The AiTrainingService index serves only +// the caller's own runs, so it's used for an all-status self-scoped list; if the +// index load fails (e.g. endpoint unavailable in this workspace), we fall back to +// the Jobs scan so the command still returns. Everything else — the default +// active list, --all-users, and --all-status for another user — uses the scan. +func newListStrategy(ctx context.Context, w *databricks.WorkspaceClient, q listQuery) listStrategy { + useIndex := !q.activeOnly && !q.allUsers && (q.userFilter == "" || q.userFilter == q.currentUser) + if !useIndex { + return newJobsScanStrategy(ctx, w, q) + } + idx := newIndexStrategy(ctx, w, q, q.limit) + if err := idx.load(); err != nil { + log.Debugf(ctx, "air list: AiTrainingService index unavailable, falling back to Jobs scan: %v", err) + return newJobsScanStrategy(ctx, w, q) + } + return idx +} + +// next pulls the next batch from the strategy, enriches it with MLflow links for +// text output, and projects it to rows. It sets exhausted once the strategy is +// drained so the interactive table knows to stop paging. +func (f *runFetcher) next(want int) ([]listRow, error) { + entries, err := f.strategy.next(want) + if err != nil { + return nil, err + } + f.exhausted = f.strategy.done() + + // MLflow links appear only in the text table, so the per-run get-output + // lookups are skipped for JSON output (which omits the column anyway). + if f.fetchMLflow { + setMLflowLinks(f.ctx, f.w, entries) + } + + rows := make([]listRow, len(entries)) + for i, e := range entries { + rows[i] = e.row + } + return rows, nil +} + +// jobsScanStrategy pages Jobs runs/list, keeping the AIR runs that match the user +// and filters. It buffers a page's leftover runs so successive next() calls +// resume where the last stopped. +type jobsScanStrategy struct { + ctx context.Context + w *databricks.WorkspaceClient + query map[string]any + userFilter string + filters listFilters + + pending []jobRun // runs from the last page not yet inspected + scanned int + drained bool +} + +func newJobsScanStrategy(ctx context.Context, w *databricks.WorkspaceClient, q listQuery) *jobsScanStrategy { query := map[string]any{ "run_type": "SUBMIT_RUN", "expand_tasks": true, @@ -160,84 +241,74 @@ func newRunFetcher(ctx context.Context, w *databricks.WorkspaceClient, q listQue if q.activeOnly { query["active_only"] = true } - return &runFetcher{ - ctx: ctx, - w: w, - query: query, - userFilter: q.userFilter, - filters: q.filters, - fetchMLflow: q.fetchMLflow, + return &jobsScanStrategy{ + ctx: ctx, + w: w, + query: query, + userFilter: q.userFilter, + filters: q.filters, } } -// next returns up to want more matching rows, paging runs/list (and buffering the -// leftover runs of a page) until it has enough, the server has no more pages, or -// it has scanned maxListScan runs. MLflow links are filled in for text output. -func (f *runFetcher) next(want int) ([]listRow, error) { +func (s *jobsScanStrategy) next(want int) ([]listedRun, error) { var entries []listedRun for len(entries) < want { - if len(f.pending) == 0 { - if f.exhausted || f.scanned >= maxListScan { + if len(s.pending) == 0 { + if s.drained || s.scanned >= maxListScan { break } - if err := f.fetchPage(); err != nil { + if err := s.fetchPage(); err != nil { return nil, err } continue } - run := &f.pending[0] - f.pending = f.pending[1:] - f.scanned++ + run := &s.pending[0] + s.pending = s.pending[1:] + s.scanned++ if !isAirRun(run) { continue } - if f.userFilter != "" && run.CreatorUserName != f.userFilter { + if s.userFilter != "" && run.CreatorUserName != s.userFilter { continue } - if !f.filters.matches(run) { + if !s.filters.matches(run) { continue } entries = append(entries, listedRun{row: buildListRow(run), taskRunID: taskRunID(run)}) } - if f.scanned >= maxListScan { - f.exhausted = true - } + return entries, nil +} - // MLflow links appear only in the text table, so the per-run get-output - // lookups are skipped for JSON output (which omits the column anyway). - if f.fetchMLflow { - setMLflowLinks(f.ctx, f.w, entries) - } +func (s *jobsScanStrategy) done() bool { + return (s.drained && len(s.pending) == 0) || s.scanned >= maxListScan +} - rows := make([]listRow, len(entries)) - for i, e := range entries { - rows[i] = e.row - } - return rows, nil +func (s *jobsScanStrategy) truncated() bool { + return s.scanned >= maxListScan } // fetchPage loads the next runs/list page into the pending buffer, marking the -// fetcher exhausted once the server reports no further pages. -func (f *runFetcher) fetchPage() error { - resp, err := fetchJobRunsPage(f.ctx, f.w, f.query) +// strategy drained once the server reports no further pages. +func (s *jobsScanStrategy) fetchPage() error { + resp, err := fetchJobRunsPage(s.ctx, s.w, s.query) if err != nil { return err } - f.pending = resp.Runs + s.pending = resp.Runs if resp.NextPageToken == "" { - f.exhausted = true + s.drained = true } else { - f.query["page_token"] = resp.NextPageToken + s.query["page_token"] = resp.NextPageToken } return nil } -// warnIfTruncated logs when the scan hit maxListScan, so one-shot output signals +// warnIfTruncated logs when a scan hit its safety cap, so one-shot output signals // its results may be incomplete. func warnIfTruncated(ctx context.Context, f *runFetcher) { - if f.scanned >= maxListScan { + if f.strategy.truncated() { log.Warnf(ctx, "air list: stopped after scanning %d runs; results may be incomplete", maxListScan) } } diff --git a/experimental/air/cmd/list_cache.go b/experimental/air/cmd/list_cache.go new file mode 100644 index 0000000000..fb4442a112 --- /dev/null +++ b/experimental/air/cmd/list_cache.go @@ -0,0 +1,79 @@ +package aircmd + +import ( + "context" + "time" + + "github.com/databricks/cli/libs/cache" +) + +// The AiTrainingService index path caches hydrated terminal runs on disk: +// terminal runs are immutable, so once we've paid for runs/get + get-output + +// MLflow we persist the finished row and skip those round-trips next time. The +// TTL matches AICM's ~60-day retention, after which the run drops out of the +// index anyway. +const ( + listCacheComponent = "air-list-runs" + listCacheTTL = 60 * 24 * time.Hour +) + +// listCacheKey fingerprints a cached run. Host isolates workspaces (a Jobs run +// id is unique only within one), matching how libs/cache namespaces entries. +type listCacheKey struct { + Host string `json:"host"` + RunID int64 `json:"run_id"` +} + +// cachedRun is the persisted value: every listRow field (including the +// table-only columns, which listRow tags json:"-" and so wouldn't survive a +// direct marshal) plus the submit time. +type cachedRun struct { + RunID string `json:"run_id"` + RunName string `json:"run_name"` + User string `json:"user"` + Status string `json:"status"` + StartedAt *string `json:"started_at"` + IsSweep bool `json:"is_sweep"` + Experiment string `json:"experiment"` + Duration string `json:"duration"` + MLflowURL string `json:"mlflow_url"` + Accelerators string `json:"accelerators"` + SubmitTimeMs int64 `json:"submit_time_ms"` +} + +func (c cachedRun) toRow() listRow { + return listRow{ + RunID: c.RunID, RunName: c.RunName, User: c.User, Status: c.Status, + StartedAt: c.StartedAt, IsSweep: c.IsSweep, Experiment: c.Experiment, + Duration: c.Duration, MLflowURL: c.MLflowURL, Accelerators: c.Accelerators, + } +} + +func cachedRunFromRow(r listRow, submitTimeMs int64) cachedRun { + return cachedRun{ + RunID: r.RunID, RunName: r.RunName, User: r.User, Status: r.Status, + StartedAt: r.StartedAt, IsSweep: r.IsSweep, Experiment: r.Experiment, + Duration: r.Duration, MLflowURL: r.MLflowURL, Accelerators: r.Accelerators, + SubmitTimeMs: submitTimeMs, + } +} + +// newListCache builds the cache for the index path. It fails open, so a nil +// return (or any cache error) just means every run is hydrated from the API. +func newListCache(ctx context.Context) *cache.Cache { + return cache.NewCache(ctx, listCacheComponent, listCacheTTL, nil) +} + +// cachedRow returns the cached row for a run, or (zero, false) on miss. +func cachedRow(ctx context.Context, c *cache.Cache, host string, runID int64) (listRow, bool) { + entry, ok := cache.Get[cachedRun](ctx, c, listCacheKey{Host: host, RunID: runID}) + if !ok { + return listRow{}, false + } + return entry.toRow(), true +} + +// putRow caches a terminal run's finished row under its submit time. +func putRow(ctx context.Context, c *cache.Cache, host string, runID, submitTimeMs int64, row listRow) { + cache.Put(ctx, c, listCacheKey{Host: host, RunID: runID}, cachedRunFromRow(row, submitTimeMs)) +} diff --git a/experimental/air/cmd/list_cache_test.go b/experimental/air/cmd/list_cache_test.go new file mode 100644 index 0000000000..3ac695e8d0 --- /dev/null +++ b/experimental/air/cmd/list_cache_test.go @@ -0,0 +1,56 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListCacheRoundTrip(t *testing.T) { + t.Setenv("DATABRICKS_CACHE_DIR", t.TempDir()) + ctx := t.Context() + c := newListCache(ctx) + + _, ok := cachedRow(ctx, c, "https://host.test", 42) + require.False(t, ok, "miss before write") + + row := listRow{RunID: "42", Experiment: "exp", Status: "SUCCESS"} + putRow(ctx, c, "https://host.test", 42, 1700000000000, row) + + got, ok := cachedRow(ctx, c, "https://host.test", 42) + require.True(t, ok, "hit after write") + assert.Equal(t, row, got) + + // Different host is a different key. + _, ok = cachedRow(ctx, c, "https://other.test", 42) + assert.False(t, ok) +} + +func TestIndexStrategyServesCachedRowWithoutFetch(t *testing.T) { + t.Setenv("DATABRICKS_CACHE_DIR", t.TempDir()) + + refs := []workflowRef{{jobRunID: 7, submitTimeMs: 1000_000}} + srv, hits := indexAndGetServer(t, refs, map[int64]jobRun{7: indexRun(7, 1000_000)}, nil, nil) + host := srv.URL + + // Pre-seed the cache for run 7 so hydration should skip runs/get entirely. + ctx := t.Context() + putRow(ctx, newListCache(ctx), host, 7, 1000_000, listRow{RunID: "7", Status: "SUCCESS"}) + + f := newRunFetcher(ctx, newTestWorkspaceClient(t, host), listQuery{ + userFilter: "me@example.com", currentUser: "me@example.com", limit: 10, + }) + rows, err := f.next(10) + require.NoError(t, err) + require.Len(t, rows, 1) + assert.Equal(t, "7", rows[0].RunID) + assert.Equal(t, 0, hits.get, "cached run must not hit runs/get") +} + +func TestIsTerminal(t *testing.T) { + assert.True(t, isTerminal(&jobRun{State: jobState{LifeCycleState: "TERMINATED"}})) + assert.True(t, isTerminal(&jobRun{State: jobState{LifeCycleState: "INTERNAL_ERROR"}})) + assert.False(t, isTerminal(&jobRun{State: jobState{LifeCycleState: "RUNNING"}})) + assert.False(t, isTerminal(&jobRun{State: jobState{LifeCycleState: "PENDING"}})) +} diff --git a/experimental/air/cmd/list_filter.go b/experimental/air/cmd/list_filter.go index 6ac47c74b0..272141e80d 100644 --- a/experimental/air/cmd/list_filter.go +++ b/experimental/air/cmd/list_filter.go @@ -10,6 +10,14 @@ import ( // supportedFilterKeys are the keys accepted by `air list --filter KEY=VALUE`. var supportedFilterKeys = []string{"accelerator_type", "experiment", "num_accelerators", "user"} +// hasTaskFilter reports whether any filter is applied to a run's task fields +// (experiment or accelerators), i.e. matched after a run is fetched rather than +// while scanning. The index path uses this to skip its newest-N truncation, so a +// dropped match doesn't shrink the result below --limit. +func (f listFilters) hasTaskFilter() bool { + return f.Experiment != "" || f.AcceleratorType != "" || f.NumAccelerators != nil +} + // listFilters holds the parsed `--filter` values for `air list`. type listFilters struct { // User is an exact creator-email match diff --git a/experimental/air/cmd/list_index.go b/experimental/air/cmd/list_index.go new file mode 100644 index 0000000000..097859ba40 --- /dev/null +++ b/experimental/air/cmd/list_index.go @@ -0,0 +1,141 @@ +package aircmd + +import ( + "cmp" + "context" + "slices" + + "github.com/databricks/cli/libs/cache" + "github.com/databricks/databricks-sdk-go" +) + +// indexStrategy serves the caller's own runs from the AiTrainingService index: +// it fetches every run id up front (cheap id+timestamp pairs), orders them +// newest-first, keeps the newest `limit`, then hydrates them into full rows in +// want-sized batches via Jobs runs/get. Terminal rows are cached so repeat calls +// skip the network. Unlike the Jobs scan it can't lazy-page (it must sort the +// whole id set first), but it still yields in batches so the table paints early. +type indexStrategy struct { + ctx context.Context + w *databricks.WorkspaceClient + activeOnly bool + filters listFilters + limit int + cache *cache.Cache + + ids []int64 // newest-first run ids to hydrate, resolved on first next() + pos int + loaded bool +} + +func newIndexStrategy(ctx context.Context, w *databricks.WorkspaceClient, q listQuery, limit int) *indexStrategy { + return &indexStrategy{ + ctx: ctx, + w: w, + activeOnly: q.activeOnly, + filters: q.filters, + limit: limit, + cache: newListCache(ctx), + } +} + +// load fetches and orders the index once. It returns an error only when the +// index endpoint itself fails, letting the caller fall back to the Jobs scan. +func (s *indexStrategy) load() error { + refs, err := listAiTrainingWorkflows(s.ctx, s.w, s.activeOnly) + if err != nil { + return err + } + slices.SortFunc(refs, func(a, b workflowRef) int { return cmp.Compare(b.submitTimeMs, a.submitTimeMs) }) + // Keep only the newest `limit` ids so hydration is bounded — but skip that when + // a task filter is active, since it drops matches post-hydration and we'd + // otherwise return fewer than `limit`. The caller stops pulling at `limit`. + if s.limit > 0 && len(refs) > s.limit && !s.filters.hasTaskFilter() { + refs = refs[:s.limit] + } + s.ids = make([]int64, len(refs)) + for i, r := range refs { + s.ids[i] = r.jobRunID + } + s.loaded = true + return nil +} + +func (s *indexStrategy) next(want int) ([]listedRun, error) { + if !s.loaded { + if err := s.load(); err != nil { + return nil, err + } + } + + var entries []listedRun + for len(entries) < want && s.pos < len(s.ids) { + end := min(s.pos+want-len(entries), len(s.ids)) + batch := s.ids[s.pos:end] + s.pos = end + + rows, err := s.hydrate(batch) + if err != nil { + return nil, err + } + entries = append(entries, rows...) + } + return entries, nil +} + +func (s *indexStrategy) done() bool { + return s.loaded && s.pos >= len(s.ids) +} + +// truncated is always false: the index path is bounded by limit, not a scan cap. +func (s *indexStrategy) truncated() bool { return false } + +// hydrate turns a batch of run ids into rows, serving cached terminal rows +// without a network call and fetching the rest via runs/get. Freshly hydrated +// terminal runs are cached. Results keep the input (newest-first) order, then +// the batch is re-sorted by start time since concurrent hydration reorders it. +func (s *indexStrategy) hydrate(ids []int64) ([]listedRun, error) { + host := s.w.Config.Host + + rows := make([]listedRun, 0, len(ids)) + var toFetch []int64 + for _, id := range ids { + if row, ok := cachedRow(s.ctx, s.cache, host, id); ok { + rows = append(rows, listedRun{row: row, taskRunID: id}) + continue + } + toFetch = append(toFetch, id) + } + + runs, err := hydrateJobRuns(s.ctx, s.w, toFetch) + if err != nil { + return nil, err + } + for _, run := range runs { + if !s.filters.matches(run) { + continue + } + row := buildListRow(run) + rows = append(rows, listedRun{row: row, taskRunID: taskRunID(run)}) + if isTerminal(run) { + start, _ := jobTiming(run) + putRow(s.ctx, s.cache, host, run.RunID, start, row) + } + } + + // Concurrent hydration reorders runs, so re-sort the batch newest-first. The + // ISO start timestamp sorts lexicographically; a missing time ("") sorts last. + slices.SortStableFunc(rows, func(a, b listedRun) int { + return cmp.Compare(rowStartKey(b.row), rowStartKey(a.row)) + }) + return rows, nil +} + +// rowStartKey returns a row's ISO start timestamp for ordering, or "" when the +// run hasn't started (which sorts last under descending comparison). +func rowStartKey(r listRow) string { + if r.StartedAt == nil { + return "" + } + return *r.StartedAt +} diff --git a/experimental/air/cmd/list_index_test.go b/experimental/air/cmd/list_index_test.go new file mode 100644 index 0000000000..ec8d082e3a --- /dev/null +++ b/experimental/air/cmd/list_index_test.go @@ -0,0 +1,212 @@ +package aircmd + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// indexRun is a terminal AIR run with a start time, for index-path hydration. +func indexRun(id, startMillis int64) jobRun { + r := airJobRun(id, "me@example.com", "GPU_1xH100", 1, "/Users/me@example.com/exp") + r.State = jobState{LifeCycleState: "TERMINATED", ResultState: "SUCCESS"} + r.Tasks[0].StartTime = startMillis + r.Tasks[0].EndTime = startMillis + 1000 + return r +} + +// indexAndGetServer serves the AiTrainingService index (a single page of the +// given refs) and runs/get for each id, recording hit counts per endpoint. A +// runID in forbidden returns 403; in missing returns 404. +type indexHits struct{ index, get int } + +func indexAndGetServer(t *testing.T, refs []workflowRef, runs map[int64]jobRun, forbidden, missing map[int64]bool) (*httptest.Server, *indexHits) { + t.Helper() + hits := &indexHits{} + wfs := make([]map[string]any, len(refs)) + for i, r := range refs { + wfs[i] = map[string]any{"job_run_id": strconv.FormatInt(r.jobRunID, 10), "submit_time": map[string]any{"seconds": r.submitTimeMs / 1000}} + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case aiTrainingWorkflowsPath: + hits.index++ + _ = json.NewEncoder(w).Encode(map[string]any{"training_workflows": wfs}) + case jobsRunsGetPath: + hits.get++ + id, _ := strconv.ParseInt(r.URL.Query().Get("run_id"), 10, 64) + if forbidden[id] { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message":"forbidden"}`)) + return + } + if missing[id] { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message":"not found"}`)) + return + } + run := runs[id] + _ = json.NewEncoder(w).Encode(run) + default: + _, _ = w.Write([]byte(`{}`)) + } + })) + t.Cleanup(srv.Close) + return srv, hits +} + +func TestIndexStrategyOrdersAndLimits(t *testing.T) { + // Three runs, out of submit-time order; newest two should win, newest-first. + refs := []workflowRef{ + {jobRunID: 1, submitTimeMs: 1000_000}, + {jobRunID: 2, submitTimeMs: 3000_000}, + {jobRunID: 3, submitTimeMs: 2000_000}, + } + runs := map[int64]jobRun{ + 1: indexRun(1, 1000_000), + 2: indexRun(2, 3000_000), + 3: indexRun(3, 2000_000), + } + srv, _ := indexAndGetServer(t, refs, runs, nil, nil) + t.Setenv("DATABRICKS_CACHE_ENABLED", "false") + + f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ + userFilter: "me@example.com", currentUser: "me@example.com", limit: 2, + }) + rows, err := f.next(10) + require.NoError(t, err) + require.Len(t, rows, 2) + assert.Equal(t, "2", rows[0].RunID) // submit 3000 + assert.Equal(t, "3", rows[1].RunID) // submit 2000 + assert.True(t, f.exhausted) +} + +func TestIndexStrategyOverFetchesWithTaskFilter(t *testing.T) { + // With a task filter and limit 1, the newest run doesn't match; the strategy + // must keep hydrating past `limit` to find the match rather than truncating. + refs := []workflowRef{ + {jobRunID: 1, submitTimeMs: 3000_000}, + {jobRunID: 2, submitTimeMs: 2000_000}, + } + run1 := indexRun(1, 3000_000) + run1.Tasks[0].AiRuntimeTask.Experiment = "/Users/me@example.com/llama" + run2 := indexRun(2, 2000_000) + run2.Tasks[0].AiRuntimeTask.Experiment = "/Users/me@example.com/qwen" + srv, _ := indexAndGetServer(t, refs, map[int64]jobRun{1: run1, 2: run2}, nil, nil) + t.Setenv("DATABRICKS_CACHE_ENABLED", "false") + + f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ + userFilter: "me@example.com", currentUser: "me@example.com", limit: 1, + filters: listFilters{Experiment: "qwen"}, + }) + rows, err := f.next(1) + require.NoError(t, err) + require.Len(t, rows, 1) + assert.Equal(t, "2", rows[0].RunID) // found despite being the older, second id +} + +func TestIndexStrategyDropsForbiddenAndMissing(t *testing.T) { + refs := []workflowRef{ + {jobRunID: 1, submitTimeMs: 3000_000}, + {jobRunID: 2, submitTimeMs: 2000_000}, + {jobRunID: 3, submitTimeMs: 1000_000}, + } + runs := map[int64]jobRun{1: indexRun(1, 3000_000), 3: indexRun(3, 1000_000)} + srv, _ := indexAndGetServer(t, refs, runs, map[int64]bool{2: true}, nil) + t.Setenv("DATABRICKS_CACHE_ENABLED", "false") + + f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ + userFilter: "me@example.com", currentUser: "me@example.com", limit: 10, + }) + rows, err := f.next(10) + require.NoError(t, err) + require.Len(t, rows, 2) // run 2 (403) dropped + assert.Equal(t, "1", rows[0].RunID) + assert.Equal(t, "3", rows[1].RunID) +} + +func TestIndexStrategyPropagatesServerError(t *testing.T) { + refs := []workflowRef{{jobRunID: 1, submitTimeMs: 1000_000}} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case aiTrainingWorkflowsPath: + wfs := []map[string]any{{"job_run_id": "1", "submit_time": map[string]any{"seconds": int64(1000)}}} + _ = json.NewEncoder(w).Encode(map[string]any{"training_workflows": wfs}) + case jobsRunsGetPath: + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message":"boom"}`)) + default: + _, _ = w.Write([]byte(`{}`)) + } + })) + t.Cleanup(srv.Close) + _ = refs + t.Setenv("DATABRICKS_CACHE_ENABLED", "false") + + f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ + userFilter: "me@example.com", currentUser: "me@example.com", limit: 10, + }) + _, err := f.next(10) + require.Error(t, err) // 500 is systemic, not an ACL drop +} + +func TestNewListStrategyGate(t *testing.T) { + // --all-users and other-user filters must NOT touch the index endpoint. + cases := []struct { + name string + q listQuery + wantIndex bool + }{ + {"active default → scan", listQuery{activeOnly: true, userFilter: "me@example.com", currentUser: "me@example.com"}, false}, + {"all-status self → index", listQuery{userFilter: "me@example.com", currentUser: "me@example.com", limit: 5}, true}, + {"all-status all-users → scan", listQuery{allUsers: true}, false}, + {"all-status other user → scan", listQuery{userFilter: "other@example.com", currentUser: "me@example.com"}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var indexHit bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == aiTrainingWorkflowsPath { + indexHit = true + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + t.Setenv("DATABRICKS_CACHE_ENABLED", "false") + + newListStrategy(t.Context(), newTestWorkspaceClient(t, srv.URL), tc.q) + assert.Equal(t, tc.wantIndex, indexHit) + }) + } +} + +func TestNewListStrategyFallsBackWhenIndexFails(t *testing.T) { + // Index 500 must silently fall back to the Jobs scan, not fail the command. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == aiTrainingWorkflowsPath { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message":"boom"}`)) + return + } + if r.URL.Path == jobsRunsListPath { + _, _ = fmt.Fprint(w, runsListBody(t, "", airJobRun(1, "me@example.com", "GPU_1xH100", 1, "exp"))) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + t.Setenv("DATABRICKS_CACHE_ENABLED", "false") + + f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ + userFilter: "me@example.com", currentUser: "me@example.com", limit: 10, + }) + rows, err := f.next(10) + require.NoError(t, err) + require.Len(t, rows, 1) // served by the Jobs scan fallback +} diff --git a/experimental/air/cmd/list_test.go b/experimental/air/cmd/list_test.go index 11185528a2..8d186eb048 100644 --- a/experimental/air/cmd/list_test.go +++ b/experimental/air/cmd/list_test.go @@ -67,6 +67,7 @@ func TestListAirRunsFiltersUserAndType(t *testing.T) { srv := runsServer(t, runsListBody(t, "", runs...)) rows, err := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ + activeOnly: true, userFilter: "me@example.com", }).next(10) require.NoError(t, err) @@ -83,7 +84,8 @@ func TestListAirRunsExperimentFilter(t *testing.T) { srv := runsServer(t, runsListBody(t, "", runs...)) rows, err := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{ - filters: listFilters{Experiment: "qwen*"}, + activeOnly: true, + filters: listFilters{Experiment: "qwen*"}, }).next(10) require.NoError(t, err) require.Len(t, rows, 1) @@ -98,7 +100,7 @@ func TestListAirRunsLimitTruncates(t *testing.T) { } srv := runsServer(t, runsListBody(t, "", runs...)) - rows, err := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{}).next(2) + rows, err := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{activeOnly: true}).next(2) require.NoError(t, err) require.Len(t, rows, 2) assert.Equal(t, "1", rows[0].RunID) @@ -110,7 +112,7 @@ func TestListAirRunsPaginates(t *testing.T) { page2 := runsListBody(t, "", airJobRun(2, "me@example.com", "GPU_1xH100", 1, "exp-b")) srv := runsServer(t, page1, page2) - rows, err := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{}).next(10) + rows, err := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{activeOnly: true}).next(10) require.NoError(t, err) require.Len(t, rows, 2) assert.Equal(t, "1", rows[0].RunID) @@ -127,7 +129,7 @@ func TestRunFetcherResumesAcrossCalls(t *testing.T) { airJobRun(3, "me@example.com", "GPU_1xH100", 1, "exp-c"), } srv := runsServer(t, runsListBody(t, "", runs...)) - f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{}) + f := newRunFetcher(t.Context(), newTestWorkspaceClient(t, srv.URL), listQuery{activeOnly: true}) first, err := f.next(2) require.NoError(t, err) diff --git a/experimental/air/cmd/list_tui.go b/experimental/air/cmd/list_tui.go index 5a2b1e128e..6126d1fbe8 100644 --- a/experimental/air/cmd/list_tui.go +++ b/experimental/air/cmd/list_tui.go @@ -113,25 +113,25 @@ type moreRowsMsg struct { err error } -// fetchCmd pulls the next batch of rows in the background; guarded by loading so -// only one runs at a time. -func (m *listModel) fetchCmd() tea.Cmd { +// fetchCmd returns the model with loading set and a command that pulls the next +// batch of rows in the background. +func (m listModel) fetchCmd() (listModel, tea.Cmd) { m.loading = true f := m.fetcher - return func() tea.Msg { + return m, func() tea.Msg { rows, err := f.next(listPageRows) return moreRowsMsg{rows: rows, err: err} } } // maybeFetch starts a fetch when the cursor nears the end of the loaded rows and -// more runs may still exist. -func (m *listModel) maybeFetch() tea.Cmd { +// more runs may still exist, returning the (possibly loading) model and command. +func (m listModel) maybeFetch() (listModel, tea.Cmd) { if m.fetcher == nil || m.loading || m.loadErr != nil || m.fetcher.exhausted { - return nil + return m, nil } if m.cursor < len(m.rows)-m.visibleCount() { - return nil + return m, nil } return m.fetchCmd() } @@ -154,7 +154,8 @@ func (m listModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: m.height = msg.Height m.offset = m.clampedOffset() - return m, m.maybeFetch() + m, cmd := m.maybeFetch() + return m, cmd case moreRowsMsg: m.loading = false @@ -168,7 +169,8 @@ func (m listModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // A page with no matches but more to scan: keep paging so the cursor isn't // stuck at the end of the loaded rows. if len(msg.rows) == 0 && !m.fetcher.exhausted { - return m, m.fetchCmd() + m, cmd := m.fetchCmd() + return m, cmd } return m, nil @@ -199,7 +201,8 @@ func (m listModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } m.offset = m.clampedOffset() - return m, m.maybeFetch() + m, cmd := m.maybeFetch() + return m, cmd } return m, nil }