Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions backend/biz/team/handler/http/v1/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package v1

import (
"log/slog"

"github.com/GoYoko/web"
"github.com/samber/do"

"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/middleware"
)

type TeamOIDCHandler struct {
logger *slog.Logger
usecase domain.TeamOIDCUsecase
}

func NewTeamOIDCHandler(i *do.Injector) (*TeamOIDCHandler, error) {
w := do.MustInvoke[*web.Web](i)
auth := do.MustInvoke[*middleware.AuthMiddleware](i)
logger := do.MustInvoke[*slog.Logger](i)

h := &TeamOIDCHandler{
logger: logger.With("module", "handler.team_oidc"),
usecase: do.MustInvoke[domain.TeamOIDCUsecase](i),
}

g := w.Group("/api/v1/teams/oidc")
g.GET("", web.BaseHandler(h.Get), auth.TeamAuth())
g.PUT("", web.BindHandler(h.Save), auth.TeamAuth())
g.POST("/test", web.BindHandler(h.Test), auth.TeamAuth())

return h, nil
}

// Get 获取团队 OIDC 配置
//
// @Summary 获取团队 OIDC 配置
// @Description 获取当前团队企业登录 OIDC 配置
// @Tags 【Team 管理员】企业登录
// @Accept json
// @Produce json
// @Security MonkeyCodeAITeamAuth
// @Success 200 {object} web.Resp{data=domain.TeamOIDCConfigResp}
// @Router /api/v1/teams/oidc [get]
func (h *TeamOIDCHandler) Get(c *web.Context) error {
resp, err := h.usecase.GetConfig(c.Request().Context(), middleware.GetTeamUser(c))
if err != nil {
return err
}
return c.Success(resp)
}

// Save 保存团队 OIDC 配置
//
// @Summary 保存团队 OIDC 配置
// @Description 新增或更新当前团队企业登录 OIDC 配置
// @Tags 【Team 管理员】企业登录
// @Accept json
// @Produce json
// @Security MonkeyCodeAITeamAuth
// @Param req body domain.SaveTeamOIDCConfigReq true "请求参数"
// @Success 200 {object} web.Resp{data=domain.TeamOIDCConfigResp}
// @Router /api/v1/teams/oidc [put]
func (h *TeamOIDCHandler) Save(c *web.Context, req domain.SaveTeamOIDCConfigReq) error {
resp, err := h.usecase.SaveConfig(c.Request().Context(), middleware.GetTeamUser(c), &req)
if err != nil {
return err
}
return c.Success(resp)
}

// Test 测试团队 OIDC 配置
//
// @Summary 测试团队 OIDC 配置
// @Description 拉取 OIDC discovery 文档验证配置可用性
// @Tags 【Team 管理员】企业登录
// @Accept json
// @Produce json
// @Security MonkeyCodeAITeamAuth
// @Param req body domain.SaveTeamOIDCConfigReq true "请求参数"
// @Success 200 {object} web.Resp{data=domain.TeamOIDCTestResp}
// @Router /api/v1/teams/oidc/test [post]
func (h *TeamOIDCHandler) Test(c *web.Context, req domain.SaveTeamOIDCConfigReq) error {
resp, err := h.usecase.TestConfig(c.Request().Context(), middleware.GetTeamUser(c), &req)
if err != nil {
return err
}
return c.Success(resp)
}
53 changes: 53 additions & 0 deletions backend/biz/team/handler/http/v1/oidc_route_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package v1

import (
"context"
"io"
"log/slog"
"testing"

"github.com/GoYoko/web"
"github.com/samber/do"

"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/middleware"
)

func TestNewTeamOIDCHandlerRegistersRoutes(t *testing.T) {
injector := do.New()
w := web.New()
do.ProvideValue(injector, w)
do.ProvideValue(injector, slog.New(slog.NewTextHandler(io.Discard, nil)))
do.ProvideValue[domain.TeamOIDCUsecase](injector, &teamOIDCUsecaseStub{})
do.ProvideValue(injector, &middleware.AuthMiddleware{})
do.ProvideValue(injector, middleware.NewTargetActiveMiddleware(slog.New(slog.NewTextHandler(io.Discard, nil)), nil))

if _, err := NewTeamOIDCHandler(injector); err != nil {
t.Fatal(err)
}

want := map[string]bool{
"GET /api/v1/teams/oidc": false,
"PUT /api/v1/teams/oidc": false,
"POST /api/v1/teams/oidc/test": false,
}
for _, route := range w.Routes() {
key := route.Method + " " + route.Path
if _, ok := want[key]; ok {
want[key] = true
}
}
for key, ok := range want {
if !ok {
t.Fatalf("route %s is not registered", key)
}
}
}

type teamOIDCUsecaseStub struct {
domain.TeamOIDCUsecase
}

func (s *teamOIDCUsecaseStub) GetConfig(ctx context.Context, teamUser *domain.TeamUser) (*domain.TeamOIDCConfigResp, error) {
return &domain.TeamOIDCConfigResp{}, nil
}
5 changes: 5 additions & 0 deletions backend/biz/team/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ func ProvideTeam(i *do.Injector) {
do.Provide(i, repo.NewTeamPolicyRepo)
do.Provide(i, usecase.NewTeamPolicyUsecase)
do.Provide(i, v1.NewTeamPolicyHandler)
do.Provide(i, repo.NewTeamOIDCRepo)
do.Provide(i, usecase.NewTeamOIDCUsecase)
do.Provide(i, usecase.NewTeamOIDCLoginUsecase)
do.Provide(i, v1.NewTeamOIDCHandler)
do.Provide(i, v1.NewTeamGroupUserHandler)
}

Expand All @@ -45,4 +49,5 @@ func InvokeTeam(i *do.Injector) {
do.MustInvoke[*v1.TeamImageHandler](i)
do.MustInvoke[*v1.TeamHostHandler](i)
do.MustInvoke[*v1.TeamPolicyHandler](i)
do.MustInvoke[*v1.TeamOIDCHandler](i)
}
208 changes: 208 additions & 0 deletions backend/biz/team/repo/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package repo

import (
"context"
"strings"

"entgo.io/ent/dialect/sql"
"github.com/google/uuid"
"github.com/samber/do"

"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/db/team"
"github.com/chaitin/MonkeyCode/backend/db/teamgroupmember"
"github.com/chaitin/MonkeyCode/backend/db/teammember"
"github.com/chaitin/MonkeyCode/backend/db/teamoidcconfig"
"github.com/chaitin/MonkeyCode/backend/db/user"
"github.com/chaitin/MonkeyCode/backend/db/useridentity"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/errcode"
"github.com/chaitin/MonkeyCode/backend/pkg/entx"
"github.com/chaitin/MonkeyCode/backend/pkg/oidc"
)

type TeamOIDCRepo struct {
db *db.Client
}

func NewTeamOIDCRepo(i *do.Injector) (domain.TeamOIDCRepo, error) {
return &TeamOIDCRepo{db: do.MustInvoke[*db.Client](i)}, nil
}

func (r *TeamOIDCRepo) GetConfig(ctx context.Context, teamID uuid.UUID) (*db.TeamOIDCConfig, error) {
return r.db.TeamOIDCConfig.Query().Where(teamoidcconfig.TeamIDEQ(teamID)).First(ctx)
}

func (r *TeamOIDCRepo) GetDefaultEnabledConfig(ctx context.Context) (*db.TeamOIDCConfig, error) {
return r.db.TeamOIDCConfig.Query().
Where(teamoidcconfig.EnabledEQ(true)).
WithTeam().
Modify(func(s *sql.Selector) {
t := sql.Table(team.Table)
s.Join(t).On(s.C(teamoidcconfig.FieldTeamID), t.C(team.FieldID))
s.OrderBy(t.C(team.FieldCreatedAt), s.C(teamoidcconfig.FieldCreatedAt))
}).
First(ctx)
}

func (r *TeamOIDCRepo) UpsertConfig(ctx context.Context, teamID uuid.UUID, req *domain.SaveTeamOIDCConfigReq) (*db.TeamOIDCConfig, error) {
issuer := oidc.CleanIssuer(req.Issuer)
scopes := strings.TrimSpace(req.Scopes)
if scopes == "" {
scopes = "openid email profile"
}
displayName := strings.TrimSpace(req.DisplayName)
if displayName == "" {
displayName = "企业登录"
}
create := r.db.TeamOIDCConfig.Create().
SetID(uuid.New()).
SetTeamID(teamID).
SetEnabled(req.Enabled).
SetDisplayName(displayName).
SetIssuer(issuer).
SetClientID(strings.TrimSpace(req.ClientID)).
SetScopes(scopes).
SetEmailDomain(strings.TrimSpace(strings.ToLower(req.EmailDomain))).
SetAutoCreateMember(req.AutoCreateMember).
SetAllowPasswordLogin(req.AllowPasswordLogin)
if req.ClientSecret != "" {
create.SetClientSecretCiphertext(req.ClientSecret)
}
id, err := create.
OnConflictColumns(teamoidcconfig.FieldTeamID).
Update(func(upsert *db.TeamOIDCConfigUpsert) {
upsert.SetEnabled(req.Enabled)
upsert.SetDisplayName(displayName)
upsert.SetIssuer(issuer)
upsert.SetClientID(strings.TrimSpace(req.ClientID))
upsert.SetScopes(scopes)
upsert.SetEmailDomain(strings.TrimSpace(strings.ToLower(req.EmailDomain)))
upsert.SetAutoCreateMember(req.AutoCreateMember)
upsert.SetAllowPasswordLogin(req.AllowPasswordLogin)
if req.ClientSecret != "" {
upsert.SetClientSecretCiphertext(req.ClientSecret)
}
}).
ID(ctx)
if err != nil {
return nil, err
}
return r.db.TeamOIDCConfig.Get(ctx, id)
}

func (r *TeamOIDCRepo) FindUserByOIDCIdentity(ctx context.Context, identityID string) (*db.User, error) {
identity, err := r.db.UserIdentity.Query().
Where(useridentity.PlatformEQ(consts.UserPlatformOIDC), useridentity.IdentityIDEQ(identityID)).
WithUser().
First(ctx)
if err != nil {
return nil, err
}
return identity.Edges.User, nil
}

func (r *TeamOIDCRepo) FindTeamMemberByEmail(ctx context.Context, teamID uuid.UUID, email string) (*db.TeamMember, error) {
return r.db.TeamMember.Query().
Where(
teammember.TeamIDEQ(teamID),
teammember.HasUserWith(user.EmailEQ(normalizeEmail(email))),
).
WithUser().
First(ctx)
}

func (r *TeamOIDCRepo) BindOIDCIdentity(ctx context.Context, userID uuid.UUID, external *domain.OIDCExternalUser) error {
name := external.Name
if name == "" {
name = external.Username
}
if name == "" {
name = external.Email
}
return r.db.UserIdentity.Create().
SetID(uuid.New()).
SetUserID(userID).
SetPlatform(consts.UserPlatformOIDC).
SetIdentityID(oidc.IdentityID(external.Issuer, external.Subject)).
SetUsername(name).
SetEmail(external.Email).
SetAvatarURL(external.AvatarURL).
OnConflict(
sql.ConflictColumns(useridentity.FieldPlatform, useridentity.FieldIdentityID),
sql.ResolveWithIgnore(),
).
Exec(ctx)
}

func (r *TeamOIDCRepo) AutoCreateMember(ctx context.Context, teamID uuid.UUID, external *domain.OIDCExternalUser) (*db.User, error) {
var created *db.User
err := entx.WithTx2(ctx, r.db, func(tx *db.Tx) error {
tm, err := tx.Team.Get(ctx, teamID)
if err != nil {
return err
}
count, err := tx.TeamMember.Query().Where(teammember.TeamIDEQ(teamID)).Count(ctx)
if err != nil {
return err
}
if count >= tm.MemberLimit {
return errcode.ErrTeamMemberLimitExceeded
}
email := normalizeEmail(external.Email)
u, err := tx.User.Query().Where(user.EmailEQ(email), user.RoleEQ(consts.UserRoleSubAccount)).First(ctx)
if err != nil {
if !db.IsNotFound(err) {
return err
}
name := external.Name
if name == "" {
name = external.Username
}
if name == "" {
name = email
}
u, err = tx.User.Create().
SetID(uuid.New()).
SetName(name).
SetEmail(email).
SetAvatarURL(external.AvatarURL).
SetRole(consts.UserRoleSubAccount).
SetStatus(consts.UserStatusActive).
Save(ctx)
if err != nil {
return err
}
}
exists, err := tx.TeamMember.Query().Where(teammember.TeamIDEQ(teamID), teammember.UserIDEQ(u.ID)).Exist(ctx)
if err != nil {
return err
}
if !exists {
if _, err := tx.TeamMember.Create().SetID(uuid.New()).SetTeamID(teamID).SetUserID(u.ID).SetRole(consts.TeamMemberRoleUser).Save(ctx); err != nil {
return err
}
}
group, err := ensureDefaultTeamGroupTx(ctx, tx, teamID)
if err != nil {
return err
}
exists, err = tx.TeamGroupMember.Query().Where(teamgroupmember.GroupIDEQ(group.ID), teamgroupmember.UserIDEQ(u.ID)).Exist(ctx)
if err != nil {
return err
}
if !exists {
if err := tx.TeamGroupMember.Create().SetID(uuid.New()).SetGroupID(group.ID).SetUserID(u.ID).Exec(ctx); err != nil {
return err
}
}
created = u
return nil
})
return created, err
}

func normalizeEmail(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
Loading