From ed308ca833590190f8c9f5d80294fab8c052f50f Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 3 Jul 2026 14:00:47 +0200 Subject: [PATCH] Add --group-id flag to assume a group during OAuth login Wire the SDK's u2m.WithAssumeGroup option into `databricks auth login` so users can assume a Databricks group during the U2M OAuth flow. When set, the numeric group ID is sent as the `assume_group` query parameter on the authorize request and the minted token is scoped to that group. Changes: - Add a `--group-id` flag to `auth login`. The value follows the same precedence as `--scopes`: an explicit flag wins, otherwise re-login preserves the group ID from the existing profile. - Wire WithAssumeGroup into both the standard and discovery login flows and persist the group ID to the profile as `group_id`. - Add `AssumeGroupID` to the profile struct and its file parsing so re-login can read the previously configured group ID back. - Skip `group_id` in the env loader's always-skip list so it comes from the selected profile only, consistent with other auth-steering fields. Co-authored-by: Isaac Signed-off-by: Martin Grund --- cmd/auth/login.go | 40 +++++++++--- cmd/auth/login_test.go | 90 +++++++++++++++++++++++++++ libs/databrickscfg/loader.go | 1 + libs/databrickscfg/profile/file.go | 1 + libs/databrickscfg/profile/profile.go | 1 + 5 files changed, 126 insertions(+), 7 deletions(-) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index f9f2531ac74..6c0dd4516a8 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -133,6 +133,7 @@ a new profile is created. var configureServerless bool var skipWorkspace bool var scopes string + var groupID string cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout, "Timeout for completing login challenge in the browser") cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false, @@ -143,6 +144,8 @@ a new profile is created. "Skip workspace selection for account-level access") cmd.Flags().StringVar(&scopes, "scopes", "", "Comma-separated list of OAuth scopes to request (defaults to 'all-apis')") + cmd.Flags().StringVar(&groupID, "group-id", "", + "Numeric Databricks group ID to assume during login (workspace-level logins only)") cmd.PreRunE = profileHostConflictCheck @@ -269,6 +272,7 @@ a new profile is created. profileName: profileName, timeout: loginTimeout, scopes: scopes, + groupID: groupID, existingProfile: existingProfile, browserFunc: getBrowserFunc(cmd), tokenStore: tokenStore, @@ -294,6 +298,14 @@ a new profile is created. scopesList = splitScopes(existingProfile.Scopes) } + // The assumed group ID follows the same precedence as scopes: an + // explicit --group-id flag wins, otherwise re-login preserves the + // value from the existing profile. + assumeGroupID := groupID + if assumeGroupID == "" && existingProfile != nil { + assumeGroupID = existingProfile.AssumeGroupID + } + oauthArgument, err := authArguments.ToOAuthArgument() if err != nil { return err @@ -306,6 +318,9 @@ a new profile is created. if len(scopesList) > 0 { persistentAuthOpts = append(persistentAuthOpts, u2m.WithScopes(scopesList)) } + if assumeGroupID != "" { + persistentAuthOpts = append(persistentAuthOpts, u2m.WithAssumeGroup(assumeGroupID)) + } persistentAuth, err := u2m.NewPersistentAuth(ctx, persistentAuthOpts...) if err != nil { return err @@ -394,6 +409,7 @@ a new profile is created. ConfigFile: env.Get(ctx, "DATABRICKS_CONFIG_FILE"), ServerlessComputeID: serverlessComputeID, Scopes: scopesList, + AssumeGroupID: assumeGroupID, }, clearKeys...) if err != nil { return err @@ -635,6 +651,7 @@ type discoveryLoginInputs struct { profileName string timeout time.Duration scopes string + groupID string existingProfile *profile.Profile browserFunc func(string) error tokenStore storage.Store @@ -655,6 +672,11 @@ func discoveryLogin(ctx context.Context, in discoveryLoginInputs) error { scopesList = splitScopes(in.existingProfile.Scopes) } + assumeGroupID := in.groupID + if assumeGroupID == "" && in.existingProfile != nil { + assumeGroupID = in.existingProfile.AssumeGroupID + } + opts := []u2m.PersistentAuthOption{ u2m.WithOAuthArgument(arg), u2m.WithBrowser(in.browserFunc), @@ -664,6 +686,9 @@ func discoveryLogin(ctx context.Context, in discoveryLoginInputs) error { if len(scopesList) > 0 { opts = append(opts, u2m.WithScopes(scopesList)) } + if assumeGroupID != "" { + opts = append(opts, u2m.WithAssumeGroup(assumeGroupID)) + } discoveryHost := env.Get(ctx, discoveryHostEnvVar) if discoveryHost != "" { opts = append(opts, u2m.WithDiscoveryHost(discoveryHost)) @@ -743,13 +768,14 @@ func discoveryLogin(ctx context.Context, in discoveryLoginInputs) error { "serverless_compute_id", ) err = databrickscfg.SaveToProfile(ctx, &config.Config{ - Profile: in.profileName, - Host: discoveredHost, - AuthType: authTypeDatabricksCLI, - AccountID: accountID, - WorkspaceID: workspaceID, - Scopes: scopesList, - ConfigFile: configFile, + Profile: in.profileName, + Host: discoveredHost, + AuthType: authTypeDatabricksCLI, + AccountID: accountID, + WorkspaceID: workspaceID, + Scopes: scopesList, + AssumeGroupID: assumeGroupID, + ConfigFile: configFile, }, clearKeys...) if err != nil { if configFile != "" { diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index a8eafb4be43..8b5c2918d24 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -89,6 +89,7 @@ type fakeDiscoveryClient struct { // For assertions introspectHost string introspectToken string + capturedOpts []u2m.PersistentAuthOption } func (f *fakeDiscoveryClient) NewOAuthArgument(profileName string) (*u2m.BasicDiscoveryOAuthArgument, error) { @@ -99,6 +100,7 @@ func (f *fakeDiscoveryClient) NewOAuthArgument(profileName string) (*u2m.BasicDi } func (f *fakeDiscoveryClient) NewPersistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (discoveryPersistentAuth, error) { + f.capturedOpts = opts if f.persistentAuthErr != nil { return nil, f.persistentAuthErr } @@ -1020,6 +1022,94 @@ func TestDiscoveryLogin_ExplicitScopesOverrideExistingProfile(t *testing.T) { assert.Equal(t, "all-apis", savedProfile.Scopes) } +func TestDiscoveryLogin_GroupIDFlagWiredAndSaved(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspectionErr: errors.New("introspection failed"), + } + + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, discoveryLoginInputs{ + dc: dc, + profileName: "DISCOVERY", + timeout: time.Second, + groupID: "987654", + browserFunc: func(string) error { return nil }, + tokenStore: newTestStore(), + }) + require.NoError(t, err) + + // The assume-group option must have been passed to the SDK. Applying the + // captured opts to a real PersistentAuth and asserting it constructs without + // error confirms WithAssumeGroup was included and is compatible with the + // discovery argument (workspace-level, not the account-target variant). + opts := append([]u2m.PersistentAuthOption{u2m.WithOAuthArgument(oauthArg), u2m.WithDiscoveryLogin()}, dc.capturedOpts...) + _, err = u2m.NewPersistentAuth(ctx, opts...) + require.NoError(t, err) + + // The group ID must round-trip into the saved profile so re-login preserves it. + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "987654", savedProfile.AssumeGroupID) +} + +func TestDiscoveryLogin_ReloginPreservesExistingProfileGroupID(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspectionErr: errors.New("introspection failed"), + } + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + Host: "https://old-workspace.example.com", + AssumeGroupID: "111222", + } + + // No --group-id flag: should fall back to the existing profile's group ID. + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, discoveryLoginInputs{ + dc: dc, + profileName: "DISCOVERY", + timeout: time.Second, + existingProfile: existingProfile, + browserFunc: func(string) error { return nil }, + tokenStore: newTestStore(), + }) + require.NoError(t, err) + + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "111222", savedProfile.AssumeGroupID) +} + func TestDiscoveryLogin_SPOGHostPopulatesAccountIDFromDiscovery(t *testing.T) { // Start a mock server that returns SPOG discovery metadata. server := newDiscoveryServer(t, map[string]any{ diff --git a/libs/databrickscfg/loader.go b/libs/databrickscfg/loader.go index 732dcf9024f..5c17ef2761d 100644 --- a/libs/databrickscfg/loader.go +++ b/libs/databrickscfg/loader.go @@ -110,6 +110,7 @@ var envAlwaysSkipAttrs = map[string]bool{ "discovery_url": true, "audience": true, "cloud": true, + "group_id": true, } // envLoader reads config attributes from environment variables. It always skips diff --git a/libs/databrickscfg/profile/file.go b/libs/databrickscfg/profile/file.go index b7f6074c811..6723b2be33b 100644 --- a/libs/databrickscfg/profile/file.go +++ b/libs/databrickscfg/profile/file.go @@ -87,6 +87,7 @@ func (f FileProfilerImpl) LoadProfiles(ctx context.Context, fn ProfileMatchFunct ServerlessComputeID: all["serverless_compute_id"], HasClientCredentials: all["client_id"] != "" && all["client_secret"] != "", Scopes: all["scopes"], + AssumeGroupID: all["group_id"], AuthType: all["auth_type"], } if fn(profile) { diff --git a/libs/databrickscfg/profile/profile.go b/libs/databrickscfg/profile/profile.go index efd358cd4e5..52758fa0c71 100644 --- a/libs/databrickscfg/profile/profile.go +++ b/libs/databrickscfg/profile/profile.go @@ -18,6 +18,7 @@ type Profile struct { ServerlessComputeID string HasClientCredentials bool Scopes string + AssumeGroupID string AuthType string }