diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 1aec73ae..e72c09fd 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -208,7 +208,12 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { name = user.Name } else { controller.log.App.Debug().Msg("No name from OAuth provider, generating from email") - name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) + parts := strings.SplitN(user.Email, "@", 2) + if len(parts) == 2 { + name = fmt.Sprintf("%s (%s)", utils.Capitalize(parts[0]), parts[1]) + } else { + name = utils.Capitalize(user.Email) + } } var username string diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 142f0b40..e46e7e82 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -146,7 +146,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { client, ok := controller.oidc.GetClient(req.ClientID) if !ok { - controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "") + controller.authorizeError(c, fmt.Errorf("client not found: %s", req.ClientID), "Client not found", "The client ID is invalid", "", "", "") return } @@ -288,7 +288,7 @@ func (controller *OIDCController) Token(c *gin.Context) { entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) if err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { - controller.log.App.Error().Err(err).Msg("Failed to delete code") + controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code") } if errors.Is(err, service.ErrCodeNotFound) { controller.log.App.Warn().Msg("Code not found") diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index a721aa2b..925c2951 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -773,46 +773,49 @@ func (auth *AuthService) ensureOAuthSessionLimit() { auth.oauthMutex.Lock() defer auth.oauthMutex.Unlock() - if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions { - - cleanupIds := make([]string, 0, OAuthCleanupCount) + if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions { + return + } - for range OAuthCleanupCount { - oldestId := "" - oldestTime := int64(0) + type entry struct { + id string + expiresAt int64 + } - for id, session := range auth.oauthPendingSessions { - if oldestTime == 0 { - oldestId = id - oldestTime = session.ExpiresAt.Unix() - continue - } - if slices.Contains(cleanupIds, id) { - continue - } - if session.ExpiresAt.Unix() < oldestTime { - oldestId = id - oldestTime = session.ExpiresAt.Unix() - } - } + entries := make([]entry, 0, len(auth.oauthPendingSessions)) + for id, session := range auth.oauthPendingSessions { + entries = append(entries, entry{id, session.ExpiresAt.Unix()}) + } - cleanupIds = append(cleanupIds, oldestId) + slices.SortFunc(entries, func(a, b entry) int { + if a.expiresAt < b.expiresAt { + return -1 } - - for _, id := range cleanupIds { - delete(auth.oauthPendingSessions, id) + if a.expiresAt > b.expiresAt { + return 1 } + return 0 + }) + + for _, e := range entries[:OAuthCleanupCount] { + delete(auth.oauthPendingSessions, e.id) } } func (auth *AuthService) lockdownMode() { ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - auth.lockdownCtx = ctx - auth.lockdownCancelFunc = cancel auth.loginMutex.Lock() + if auth.lockdown != nil && auth.lockdown.Active { + auth.loginMutex.Unlock() + cancel() + return + } + + auth.lockdownCtx = ctx + auth.lockdownCancelFunc = cancel + auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.lockdown = &Lockdown{ @@ -825,10 +828,12 @@ func (auth *AuthService) lockdownMode() { auth.loginAttempts = make(map[string]*LoginAttempt) timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil)) - defer timer.Stop() auth.loginMutex.Unlock() + defer cancel() + defer timer.Stop() + select { case <-timer.C: // Timer expired, end lockdown diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 0def3143..dc0b7c08 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -26,6 +26,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: config.Insecure, + MinVersion: tls.VersionTLS12, }, }, }