Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
var configureServerless bool
var skipWorkspace bool
var scopes string
var groupID string
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
"Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
Expand All @@ -143,6 +144,8 @@
"Skip workspace selection for account-level access")
cmd.Flags().StringVar(&scopes, "scopes", "",
"Comma-separated list of OAuth scopes to request (defaults to 'all-apis')")
cmd.Flags().StringVar(&groupID, "group-id", "",
"Numeric Databricks group ID to assume during login (workspace-level logins only)")

cmd.PreRunE = profileHostConflictCheck

Expand Down Expand Up @@ -269,6 +272,7 @@
profileName: profileName,
timeout: loginTimeout,
scopes: scopes,
groupID: groupID,
existingProfile: existingProfile,
browserFunc: getBrowserFunc(cmd),
tokenStore: tokenStore,
Expand All @@ -294,6 +298,14 @@
scopesList = splitScopes(existingProfile.Scopes)
}

// The assumed group ID follows the same precedence as scopes: an
// explicit --group-id flag wins, otherwise re-login preserves the
// value from the existing profile.
assumeGroupID := groupID
if assumeGroupID == "" && existingProfile != nil {
assumeGroupID = existingProfile.AssumeGroupID
}

oauthArgument, err := authArguments.ToOAuthArgument()
if err != nil {
return err
Expand All @@ -306,6 +318,9 @@
if len(scopesList) > 0 {
persistentAuthOpts = append(persistentAuthOpts, u2m.WithScopes(scopesList))
}
if assumeGroupID != "" {
persistentAuthOpts = append(persistentAuthOpts, u2m.WithAssumeGroup(assumeGroupID))

Check failure on line 322 in cmd/auth/login.go

View workflow job for this annotation

GitHub Actions / lint

undefined: u2m.WithAssumeGroup
}
persistentAuth, err := u2m.NewPersistentAuth(ctx, persistentAuthOpts...)
if err != nil {
return err
Expand Down Expand Up @@ -394,6 +409,7 @@
ConfigFile: env.Get(ctx, "DATABRICKS_CONFIG_FILE"),
ServerlessComputeID: serverlessComputeID,
Scopes: scopesList,
AssumeGroupID: assumeGroupID,

Check failure on line 412 in cmd/auth/login.go

View workflow job for this annotation

GitHub Actions / lint

unknown field AssumeGroupID in struct literal of type "github.com/databricks/databricks-sdk-go/config".Config
}, clearKeys...)
if err != nil {
return err
Expand Down Expand Up @@ -635,6 +651,7 @@
profileName string
timeout time.Duration
scopes string
groupID string
existingProfile *profile.Profile
browserFunc func(string) error
tokenStore storage.Store
Expand All @@ -655,6 +672,11 @@
scopesList = splitScopes(in.existingProfile.Scopes)
}

assumeGroupID := in.groupID
if assumeGroupID == "" && in.existingProfile != nil {
assumeGroupID = in.existingProfile.AssumeGroupID
}

opts := []u2m.PersistentAuthOption{
u2m.WithOAuthArgument(arg),
u2m.WithBrowser(in.browserFunc),
Expand All @@ -664,6 +686,9 @@
if len(scopesList) > 0 {
opts = append(opts, u2m.WithScopes(scopesList))
}
if assumeGroupID != "" {
opts = append(opts, u2m.WithAssumeGroup(assumeGroupID))

Check failure on line 690 in cmd/auth/login.go

View workflow job for this annotation

GitHub Actions / lint

undefined: u2m.WithAssumeGroup
}
discoveryHost := env.Get(ctx, discoveryHostEnvVar)
if discoveryHost != "" {
opts = append(opts, u2m.WithDiscoveryHost(discoveryHost))
Expand Down Expand Up @@ -743,13 +768,14 @@
"serverless_compute_id",
)
err = databrickscfg.SaveToProfile(ctx, &config.Config{
Profile: in.profileName,
Host: discoveredHost,
AuthType: authTypeDatabricksCLI,
AccountID: accountID,
WorkspaceID: workspaceID,
Scopes: scopesList,
ConfigFile: configFile,
Profile: in.profileName,
Host: discoveredHost,
AuthType: authTypeDatabricksCLI,
AccountID: accountID,
WorkspaceID: workspaceID,
Scopes: scopesList,
AssumeGroupID: assumeGroupID,

Check failure on line 777 in cmd/auth/login.go

View workflow job for this annotation

GitHub Actions / lint

unknown field AssumeGroupID in struct literal of type "github.com/databricks/databricks-sdk-go/config".Config (typecheck)
ConfigFile: configFile,
}, clearKeys...)
if err != nil {
if configFile != "" {
Expand Down
90 changes: 90 additions & 0 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ type fakeDiscoveryClient struct {
// For assertions
introspectHost string
introspectToken string
capturedOpts []u2m.PersistentAuthOption
}

func (f *fakeDiscoveryClient) NewOAuthArgument(profileName string) (*u2m.BasicDiscoveryOAuthArgument, error) {
Expand All @@ -99,6 +100,7 @@ func (f *fakeDiscoveryClient) NewOAuthArgument(profileName string) (*u2m.BasicDi
}

func (f *fakeDiscoveryClient) NewPersistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (discoveryPersistentAuth, error) {
f.capturedOpts = opts
if f.persistentAuthErr != nil {
return nil, f.persistentAuthErr
}
Expand Down Expand Up @@ -1020,6 +1022,94 @@ func TestDiscoveryLogin_ExplicitScopesOverrideExistingProfile(t *testing.T) {
assert.Equal(t, "all-apis", savedProfile.Scopes)
}

func TestDiscoveryLogin_GroupIDFlagWiredAndSaved(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, ".databrickscfg")
err := os.WriteFile(configPath, []byte(""), 0o600)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configPath)

oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY")
require.NoError(t, err)
oauthArg.SetDiscoveredHost("https://workspace.example.com")

dc := &fakeDiscoveryClient{
oauthArg: oauthArg,
persistentAuth: &fakeDiscoveryPersistentAuth{
token: &oauth2.Token{AccessToken: "test-token"},
},
introspectionErr: errors.New("introspection failed"),
}

ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
err = discoveryLogin(ctx, discoveryLoginInputs{
dc: dc,
profileName: "DISCOVERY",
timeout: time.Second,
groupID: "987654",
browserFunc: func(string) error { return nil },
tokenStore: newTestStore(),
})
require.NoError(t, err)

// The assume-group option must have been passed to the SDK. Applying the
// captured opts to a real PersistentAuth and asserting it constructs without
// error confirms WithAssumeGroup was included and is compatible with the
// discovery argument (workspace-level, not the account-target variant).
opts := append([]u2m.PersistentAuthOption{u2m.WithOAuthArgument(oauthArg), u2m.WithDiscoveryLogin()}, dc.capturedOpts...)
_, err = u2m.NewPersistentAuth(ctx, opts...)
require.NoError(t, err)

// The group ID must round-trip into the saved profile so re-login preserves it.
savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler)
require.NoError(t, err)
require.NotNil(t, savedProfile)
assert.Equal(t, "987654", savedProfile.AssumeGroupID)
}

func TestDiscoveryLogin_ReloginPreservesExistingProfileGroupID(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, ".databrickscfg")
err := os.WriteFile(configPath, []byte(""), 0o600)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configPath)

oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY")
require.NoError(t, err)
oauthArg.SetDiscoveredHost("https://workspace.example.com")

dc := &fakeDiscoveryClient{
oauthArg: oauthArg,
persistentAuth: &fakeDiscoveryPersistentAuth{
token: &oauth2.Token{AccessToken: "test-token"},
},
introspectionErr: errors.New("introspection failed"),
}

existingProfile := &profile.Profile{
Name: "DISCOVERY",
Host: "https://old-workspace.example.com",
AssumeGroupID: "111222",
}

// No --group-id flag: should fall back to the existing profile's group ID.
ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
err = discoveryLogin(ctx, discoveryLoginInputs{
dc: dc,
profileName: "DISCOVERY",
timeout: time.Second,
existingProfile: existingProfile,
browserFunc: func(string) error { return nil },
tokenStore: newTestStore(),
})
require.NoError(t, err)

savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler)
require.NoError(t, err)
require.NotNil(t, savedProfile)
assert.Equal(t, "111222", savedProfile.AssumeGroupID)
}

func TestDiscoveryLogin_SPOGHostPopulatesAccountIDFromDiscovery(t *testing.T) {
// Start a mock server that returns SPOG discovery metadata.
server := newDiscoveryServer(t, map[string]any{
Expand Down
1 change: 1 addition & 0 deletions libs/databrickscfg/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ var envAlwaysSkipAttrs = map[string]bool{
"discovery_url": true,
"audience": true,
"cloud": true,
"group_id": true,
}

// envLoader reads config attributes from environment variables. It always skips
Expand Down
1 change: 1 addition & 0 deletions libs/databrickscfg/profile/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ func (f FileProfilerImpl) LoadProfiles(ctx context.Context, fn ProfileMatchFunct
ServerlessComputeID: all["serverless_compute_id"],
HasClientCredentials: all["client_id"] != "" && all["client_secret"] != "",
Scopes: all["scopes"],
AssumeGroupID: all["group_id"],
AuthType: all["auth_type"],
}
if fn(profile) {
Expand Down
1 change: 1 addition & 0 deletions libs/databrickscfg/profile/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Profile struct {
ServerlessComputeID string
HasClientCredentials bool
Scopes string
AssumeGroupID string
AuthType string
}

Expand Down
Loading