diff --git a/core/group/service.go b/core/group/service.go index 98e149675..d2bc189e5 100644 --- a/core/group/service.go +++ b/core/group/service.go @@ -74,7 +74,12 @@ func (s Service) Create(ctx context.Context, grp Group) (Group, error) { return Group{}, err } - if err = s.membershipService.OnGroupCreated(ctx, newGroup.ID, newGroup.OrganizationID, principal.ID, principal.Type); err != nil { + // PAT → resolve to underlying user so ownership is on the user, not the token + subjectID, subjectType := principal.ResolveSubject() + if err = s.membershipService.OnGroupCreated(ctx, newGroup.ID, newGroup.OrganizationID, subjectID, subjectType); err != nil { + if cleanupErr := s.repository.Delete(ctx, newGroup.ID); cleanupErr != nil { + return Group{}, errors.Join(err, fmt.Errorf("rollback group create: %w", cleanupErr)) + } return Group{}, err } diff --git a/core/group/service_test.go b/core/group/service_test.go index 1a8c0efd7..f5a7da59f 100644 --- a/core/group/service_test.go +++ b/core/group/service_test.go @@ -68,7 +68,63 @@ func TestService_Create(t *testing.T) { assert.Equal(t, strings.Contains(err.Error(), authenticate.ErrInvalidID.Error()), true) }) - t.Run("should propagate error from membership.OnGroupCreated", func(t *testing.T) { + t.Run("PAT creator resolves to underlying user before OnGroupCreated", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockRelationSvc := mocks.NewRelationService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) + + userID := uuid.New().String() + patID := uuid.New().String() + mockAuthnSvc.On("GetPrincipal", mock.Anything).Return(authenticate.Principal{ + ID: patID, + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: patID, UserID: userID}, + }, nil) + + groupParam := group.Group{Name: "g", OrganizationID: uuid.New().String()} + groupInRepo := groupParam + groupInRepo.ID = uuid.New().String() + mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) + mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, userID, schema.UserPrincipal).Return(nil) + + _, err := svc.Create(context.Background(), groupParam) + assert.NoError(t, err) + }) + + t.Run("OnGroupCreated failure rolls back the Postgres group row", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockRelationSvc := mocks.NewRelationService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) + + mockUserID := uuid.New().String() + mockAuthnSvc.On("GetPrincipal", mock.Anything).Return(authenticate.Principal{ + ID: mockUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: mockUserID}, + }, nil) + + groupParam := group.Group{Name: "g", OrganizationID: uuid.New().String()} + groupInRepo := groupParam + groupInRepo.ID = uuid.New().String() + mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) + mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, mockUserID, schema.UserPrincipal).Return(errors.New("spicedb down")) + mockRepo.On("Delete", mock.Anything, groupInRepo.ID).Return(nil).Once() + + _, err := svc.Create(context.Background(), groupParam) + assert.ErrorContains(t, err, "spicedb down") + }) + + t.Run("rollback failure surfaces both errors", func(t *testing.T) { mockRepo := mocks.NewRepository(t) mockAuthnSvc := mocks.NewAuthnService(t) mockRelationSvc := mocks.NewRelationService(t) @@ -90,9 +146,12 @@ func TestService_Create(t *testing.T) { groupInRepo.ID = uuid.New().String() mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, mockUserID, schema.UserPrincipal).Return(errors.New("spicedb down")) + mockRepo.On("Delete", mock.Anything, groupInRepo.ID).Return(errors.New("pg gone")).Once() _, err := svc.Create(context.Background(), groupParam) assert.ErrorContains(t, err, "spicedb down") + assert.ErrorContains(t, err, "rollback group create") + assert.ErrorContains(t, err, "pg gone") }) }