Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions libs/localenv/target.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package localenv

import (
"context"
"fmt"
"strings"
)

// ComputeClient is a narrow seam over the SDK so tests can stub it.
type ComputeClient interface {
// GetClusterSparkVersion returns the Spark version string for a cluster.
GetClusterSparkVersion(ctx context.Context, clusterID string) (string, error)
// GetJobSparkVersion returns either a Spark version (isServerless=false) or a
// serverless marker (isServerless=true) for a job, plus a recorded version string.
GetJobSparkVersion(ctx context.Context, jobID string) (sparkVersion string, isServerless bool, version string, err error)
}

// TargetFlags holds the mutually-exclusive compute target flags from the CLI.
type TargetFlags struct {
Cluster string
Serverless string
Job string
}

// BundleTarget is the three-state result of reading the bundle's configured
// target. Selected=false means nothing was configured.
type BundleTarget struct {
ClusterID string
Serverless bool
Selected bool
}

// noTargetMessage is the actionable message shown when no target is selected,
// matching spec §2.3.
const noTargetMessage = "No compute target is selected. Select a cluster or serverless target, or pass --cluster / --serverless / --job"

// ValidateTargetFlags returns an error if more than one of the three flags is set.
// Cobra marks them mutually exclusive too; this guards the library path.
func ValidateTargetFlags(f TargetFlags) error {
var set []string
if f.Cluster != "" {
set = append(set, "--cluster")
}
if f.Serverless != "" {
set = append(set, "--serverless")
}
if f.Job != "" {
set = append(set, "--job")
}
if len(set) > 1 {
return fmt.Errorf("flags %s are mutually exclusive; specify at most one", strings.Join(set, " and "))
}
return nil
}

// ResolveTarget resolves the compute target using ordered precedence:
// --cluster flag → --serverless flag → --job flag → bundle target.
// PythonVersion is left empty; it is filled later from constraint data.
//
// Incompatible flags are rejected up front: without this a library caller that
// bypasses Cobra (which also marks the flags mutually exclusive) and passes more
// than one target flag would have all but the first precedence branch silently
// ignored, resolving a different target than requested.
func ResolveTarget(ctx context.Context, f TargetFlags, c ComputeClient, bt BundleTarget) (*TargetInfo, error) {
if err := ValidateTargetFlags(f); err != nil {
return nil, NewError(ErrResolve, err, "invalid compute target flags")
}

if f.Cluster != "" {
v, err := c.GetClusterSparkVersion(ctx, f.Cluster)
if err != nil {
return nil, NewError(ErrResolve, err, "resolving cluster %s", f.Cluster)
}
return &TargetInfo{
Source: "cluster",
ClusterID: f.Cluster,
SparkVersion: v,
EnvKey: EnvKeyForSparkVersion(v),
}, nil
}

if f.Serverless != "" {
return &TargetInfo{
Source: "serverless",
ServerlessVersion: NormalizeServerless(f.Serverless),
EnvKey: EnvKeyForServerless(f.Serverless),
}, nil
}

if f.Job != "" {
sparkVersion, isServerless, version, err := c.GetJobSparkVersion(ctx, f.Job)
if err != nil {
return nil, NewError(ErrResolve, err, "resolving job %s", f.Job)
}
if isServerless {
// Default to v4 when the job is serverless; the serverless env version
// is not recorded in the bundle/project (documented stand-in from the
// original script, spec §4.3).
v := version
if v == "" {
v = "v4"
}
return &TargetInfo{
Source: "job",
ServerlessVersion: NormalizeServerless(v),
EnvKey: EnvKeyForServerless(v),
}, nil
}
// Classic compute: the Spark version is the first return per the
// GetJobSparkVersion contract, not the recorded-version third return.
return &TargetInfo{
Source: "job",
SparkVersion: sparkVersion,
EnvKey: EnvKeyForSparkVersion(sparkVersion),
}, nil
}

// Fall back to bundle target.
if !bt.Selected {
return nil, NewError(ErrNoTarget, nil, "%s", noTargetMessage)
}

if bt.Serverless {
// Default to serverless-v4: the serverless env version is not recorded
// in the bundle/project (documented stand-in from the original script).
return &TargetInfo{
Source: "bundle",
ServerlessVersion: "v4",
EnvKey: EnvKeyForServerless("v4"),
}, nil
}

if bt.ClusterID != "" {
v, err := c.GetClusterSparkVersion(ctx, bt.ClusterID)
if err != nil {
return nil, NewError(ErrResolve, err, "resolving bundle cluster %s", bt.ClusterID)
}
return &TargetInfo{
Source: "bundle",
ClusterID: bt.ClusterID,
SparkVersion: v,
EnvKey: EnvKeyForSparkVersion(v),
}, nil
}

// Bundle target is selected but has neither serverless nor a cluster ID —
// treat this the same as nothing selected so the user gets a clear message.
return nil, NewError(ErrNoTarget, nil, "%s", noTargetMessage)
}
106 changes: 106 additions & 0 deletions libs/localenv/target_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package localenv

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type stubCompute struct {
clusterVersion string
clusterErr error
}

func (s stubCompute) GetClusterSparkVersion(_ context.Context, _ string) (string, error) {
return s.clusterVersion, s.clusterErr
}

func (s stubCompute) GetJobSparkVersion(_ context.Context, _ string) (string, bool, string, error) {
return "", false, "", nil
}

func TestResolveServerlessFlag(t *testing.T) {
ti, err := ResolveTarget(t.Context(), TargetFlags{Serverless: "v4"}, stubCompute{}, BundleTarget{})
require.NoError(t, err)
assert.Equal(t, "serverless", ti.Source)
assert.Equal(t, "v4", ti.ServerlessVersion)
assert.Equal(t, "serverless/serverless-v4", ti.EnvKey)
}

func TestResolveClusterFlag(t *testing.T) {
c := stubCompute{clusterVersion: "15.4.x-scala2.12"}
ti, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "abc"}, c, BundleTarget{})
require.NoError(t, err)
assert.Equal(t, "cluster", ti.Source)
assert.Equal(t, "15.4.x-scala2.12", ti.SparkVersion)
assert.Equal(t, "dbr/15.4.x-scala2.12", ti.EnvKey)
assert.Equal(t, "abc", ti.ClusterID)
}

func TestResolveClusterFlagError(t *testing.T) {
c := stubCompute{clusterErr: errors.New("cluster not found")}
_, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "abc"}, c, BundleTarget{})
var pe *PipelineError
require.ErrorAs(t, err, &pe)
assert.Equal(t, ErrResolve, pe.Code)
}

func TestResolveBundleNothingSelected(t *testing.T) {
_, err := ResolveTarget(t.Context(), TargetFlags{}, stubCompute{}, BundleTarget{Selected: false})
var pe *PipelineError
require.ErrorAs(t, err, &pe)
assert.Equal(t, ErrNoTarget, pe.Code)
}

func TestResolveBundleServerless(t *testing.T) {
ti, err := ResolveTarget(t.Context(), TargetFlags{}, stubCompute{}, BundleTarget{Selected: true, Serverless: true})
require.NoError(t, err)
assert.Equal(t, "bundle", ti.Source)
assert.Equal(t, "serverless/serverless-v4", ti.EnvKey)
}

// jobStubCompute returns distinct values for the first (sparkVersion) and third
// (recorded version) results of GetJobSparkVersion so the classic-compute branch
// can be checked against the documented contract (it must use the first).
type jobStubCompute struct {
sparkVersion string
isServerless bool
version string
}

func (jobStubCompute) GetClusterSparkVersion(_ context.Context, _ string) (string, error) {
return "", nil
}

func (s jobStubCompute) GetJobSparkVersion(_ context.Context, _ string) (string, bool, string, error) {
return s.sparkVersion, s.isServerless, s.version, nil
}

func TestResolveJobClassicUsesSparkVersionReturn(t *testing.T) {
// Contract: for a classic-compute job the Spark version is the FIRST return.
// The third "recorded version" return differs here to catch use of the wrong one.
c := jobStubCompute{sparkVersion: "15.4.x-scala2.12", isServerless: false, version: "wrong-recorded"}
ti, err := ResolveTarget(t.Context(), TargetFlags{Job: "42"}, c, BundleTarget{})
require.NoError(t, err)
assert.Equal(t, "job", ti.Source)
assert.Equal(t, "15.4.x-scala2.12", ti.SparkVersion)
assert.Equal(t, "dbr/15.4.x-scala2.12", ti.EnvKey)
}

func TestValidateTargetFlagsMutuallyExclusive(t *testing.T) {
assert.Error(t, ValidateTargetFlags(TargetFlags{Cluster: "a", Serverless: "v4"}))
assert.NoError(t, ValidateTargetFlags(TargetFlags{Cluster: "a"}))
}

func TestResolveTargetRejectsConflictingFlags(t *testing.T) {
// ResolveTarget must reject incompatible flags rather than silently taking
// the first precedence branch, so a library caller bypassing Cobra can't
// resolve a different target than it asked for.
_, err := ResolveTarget(t.Context(), TargetFlags{Cluster: "c", Serverless: "v4"}, stubCompute{}, BundleTarget{})
var pe *PipelineError
require.ErrorAs(t, err, &pe)
assert.Equal(t, ErrResolve, pe.Code)
}
Loading