diff --git a/libs/localenv/target.go b/libs/localenv/target.go new file mode 100644 index 0000000000..b3ccb3171e --- /dev/null +++ b/libs/localenv/target.go @@ -0,0 +1,149 @@ +package localenv + +import ( + "context" + "fmt" + "strings" +) + +// ComputeClient is a narrow seam over the SDK so tests can stub it. +type ComputeClient interface { + // GetClusterSparkVersion returns the Spark version string for a cluster. + GetClusterSparkVersion(ctx context.Context, clusterID string) (string, error) + // GetJobSparkVersion returns either a Spark version (isServerless=false) or a + // serverless marker (isServerless=true) for a job, plus a recorded version string. + GetJobSparkVersion(ctx context.Context, jobID string) (sparkVersion string, isServerless bool, version string, err error) +} + +// TargetFlags holds the mutually-exclusive compute target flags from the CLI. +type TargetFlags struct { + Cluster string + Serverless string + Job string +} + +// BundleTarget is the three-state result of reading the bundle's configured +// target. Selected=false means nothing was configured. +type BundleTarget struct { + ClusterID string + Serverless bool + Selected bool +} + +// noTargetMessage is the actionable message shown when no target is selected, +// matching spec §2.3. +const noTargetMessage = "No compute target is selected. Select a cluster or serverless target, or pass --cluster / --serverless / --job" + +// ValidateTargetFlags returns an error if more than one of the three flags is set. +// Cobra marks them mutually exclusive too; this guards the library path. +func ValidateTargetFlags(f TargetFlags) error { + var set []string + if f.Cluster != "" { + set = append(set, "--cluster") + } + if f.Serverless != "" { + set = append(set, "--serverless") + } + if f.Job != "" { + set = append(set, "--job") + } + if len(set) > 1 { + return fmt.Errorf("flags %s are mutually exclusive; specify at most one", strings.Join(set, " and ")) + } + return nil +} + +// ResolveTarget resolves the compute target using ordered precedence: +// --cluster flag → --serverless flag → --job flag → bundle target. +// PythonVersion is left empty; it is filled later from constraint data. +// +// Incompatible flags are rejected up front: without this a library caller that +// bypasses Cobra (which also marks the flags mutually exclusive) and passes more +// than one target flag would have all but the first precedence branch silently +// ignored, resolving a different target than requested. +func ResolveTarget(ctx context.Context, f TargetFlags, c ComputeClient, bt BundleTarget) (*TargetInfo, error) { + if err := ValidateTargetFlags(f); err != nil { + return nil, NewError(ErrResolve, err, "invalid compute target flags") + } + + if f.Cluster != "" { + v, err := c.GetClusterSparkVersion(ctx, f.Cluster) + if err != nil { + return nil, NewError(ErrResolve, err, "resolving cluster %s", f.Cluster) + } + return &TargetInfo{ + Source: "cluster", + ClusterID: f.Cluster, + SparkVersion: v, + EnvKey: EnvKeyForSparkVersion(v), + }, nil + } + + if f.Serverless != "" { + return &TargetInfo{ + Source: "serverless", + ServerlessVersion: NormalizeServerless(f.Serverless), + EnvKey: EnvKeyForServerless(f.Serverless), + }, nil + } + + if f.Job != "" { + sparkVersion, isServerless, version, err := c.GetJobSparkVersion(ctx, f.Job) + if err != nil { + return nil, NewError(ErrResolve, err, "resolving job %s", f.Job) + } + if isServerless { + // Default to v4 when the job is serverless; the serverless env version + // is not recorded in the bundle/project (documented stand-in from the + // original script, spec §4.3). + v := version + if v == "" { + v = "v4" + } + return &TargetInfo{ + Source: "job", + ServerlessVersion: NormalizeServerless(v), + EnvKey: EnvKeyForServerless(v), + }, nil + } + // Classic compute: the Spark version is the first return per the + // GetJobSparkVersion contract, not the recorded-version third return. + return &TargetInfo{ + Source: "job", + SparkVersion: sparkVersion, + EnvKey: EnvKeyForSparkVersion(sparkVersion), + }, nil + } + + // Fall back to bundle target. + if !bt.Selected { + return nil, NewError(ErrNoTarget, nil, "%s", noTargetMessage) + } + + if bt.Serverless { + // Default to serverless-v4: the serverless env version is not recorded + // in the bundle/project (documented stand-in from the original script). + return &TargetInfo{ + Source: "bundle", + ServerlessVersion: "v4", + EnvKey: EnvKeyForServerless("v4"), + }, nil + } + + if bt.ClusterID != "" { + v, err := c.GetClusterSparkVersion(ctx, bt.ClusterID) + if err != nil { + return nil, NewError(ErrResolve, err, "resolving bundle cluster %s", bt.ClusterID) + } + return &TargetInfo{ + Source: "bundle", + ClusterID: bt.ClusterID, + SparkVersion: v, + EnvKey: EnvKeyForSparkVersion(v), + }, nil + } + + // Bundle target is selected but has neither serverless nor a cluster ID — + // treat this the same as nothing selected so the user gets a clear message. + return nil, NewError(ErrNoTarget, nil, "%s", noTargetMessage) +} diff --git a/libs/localenv/target_test.go b/libs/localenv/target_test.go new file mode 100644 index 0000000000..22beaee338 --- /dev/null +++ b/libs/localenv/target_test.go @@ -0,0 +1,106 @@ +package localenv + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type stubCompute struct { + clusterVersion string + clusterErr error +} + +func (s stubCompute) GetClusterSparkVersion(_ context.Context, _ string) (string, error) { + return s.clusterVersion, s.clusterErr +} + +func (s stubCompute) GetJobSparkVersion(_ context.Context, _ string) (string, bool, string, error) { + return "", false, "", nil +} + +func TestResolveServerlessFlag(t *testing.T) { + ti, err := ResolveTarget(t.Context(), TargetFlags{Serverless: "v4"}, stubCompute{}, BundleTarget{}) + require.NoError(t, err) + assert.Equal(t, "serverless", ti.Source) + assert.Equal(t, "v4", ti.ServerlessVersion) + assert.Equal(t, "serverless/serverless-v4", ti.EnvKey) +} + +func TestResolveClusterFlag(t *testing.T) { + c := stubCompute{clusterVersion: "15.4.x-scala2.12"} + ti, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "abc"}, c, BundleTarget{}) + require.NoError(t, err) + assert.Equal(t, "cluster", ti.Source) + assert.Equal(t, "15.4.x-scala2.12", ti.SparkVersion) + assert.Equal(t, "dbr/15.4.x-scala2.12", ti.EnvKey) + assert.Equal(t, "abc", ti.ClusterID) +} + +func TestResolveClusterFlagError(t *testing.T) { + c := stubCompute{clusterErr: errors.New("cluster not found")} + _, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "abc"}, c, BundleTarget{}) + var pe *PipelineError + require.ErrorAs(t, err, &pe) + assert.Equal(t, ErrResolve, pe.Code) +} + +func TestResolveBundleNothingSelected(t *testing.T) { + _, err := ResolveTarget(t.Context(), TargetFlags{}, stubCompute{}, BundleTarget{Selected: false}) + var pe *PipelineError + require.ErrorAs(t, err, &pe) + assert.Equal(t, ErrNoTarget, pe.Code) +} + +func TestResolveBundleServerless(t *testing.T) { + ti, err := ResolveTarget(t.Context(), TargetFlags{}, stubCompute{}, BundleTarget{Selected: true, Serverless: true}) + require.NoError(t, err) + assert.Equal(t, "bundle", ti.Source) + assert.Equal(t, "serverless/serverless-v4", ti.EnvKey) +} + +// jobStubCompute returns distinct values for the first (sparkVersion) and third +// (recorded version) results of GetJobSparkVersion so the classic-compute branch +// can be checked against the documented contract (it must use the first). +type jobStubCompute struct { + sparkVersion string + isServerless bool + version string +} + +func (jobStubCompute) GetClusterSparkVersion(_ context.Context, _ string) (string, error) { + return "", nil +} + +func (s jobStubCompute) GetJobSparkVersion(_ context.Context, _ string) (string, bool, string, error) { + return s.sparkVersion, s.isServerless, s.version, nil +} + +func TestResolveJobClassicUsesSparkVersionReturn(t *testing.T) { + // Contract: for a classic-compute job the Spark version is the FIRST return. + // The third "recorded version" return differs here to catch use of the wrong one. + c := jobStubCompute{sparkVersion: "15.4.x-scala2.12", isServerless: false, version: "wrong-recorded"} + ti, err := ResolveTarget(t.Context(), TargetFlags{Job: "42"}, c, BundleTarget{}) + require.NoError(t, err) + assert.Equal(t, "job", ti.Source) + assert.Equal(t, "15.4.x-scala2.12", ti.SparkVersion) + assert.Equal(t, "dbr/15.4.x-scala2.12", ti.EnvKey) +} + +func TestValidateTargetFlagsMutuallyExclusive(t *testing.T) { + assert.Error(t, ValidateTargetFlags(TargetFlags{Cluster: "a", Serverless: "v4"})) + assert.NoError(t, ValidateTargetFlags(TargetFlags{Cluster: "a"})) +} + +func TestResolveTargetRejectsConflictingFlags(t *testing.T) { + // ResolveTarget must reject incompatible flags rather than silently taking + // the first precedence branch, so a library caller bypassing Cobra can't + // resolve a different target than it asked for. + _, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "c", Serverless: "v4"}, stubCompute{}, BundleTarget{}) + var pe *PipelineError + require.ErrorAs(t, err, &pe) + assert.Equal(t, ErrResolve, pe.Code) +}