Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions backend/app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,14 +910,26 @@ async def oauth_callback(
raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported")

try:
# Exchange code for token
token_data = await auth_provider.exchange_code_for_token(data.code)
# Exchange code for token (pass redirect_uri for OAuth2 providers that require it)
if hasattr(auth_provider, 'exchange_code_for_token') and data.redirect_uri:
token_data = await auth_provider.exchange_code_for_token(data.code, redirect_uri=data.redirect_uri)
else:
token_data = await auth_provider.exchange_code_for_token(data.code)
access_token = token_data.get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="Failed to get access token from provider")

# Get user info
user_info = await auth_provider.get_user_info(access_token)
# Get user info with fallback to token_data extraction
try:
user_info = await auth_provider.get_user_info(access_token)
except Exception:
if hasattr(auth_provider, 'get_user_info_from_token_data'):
user_info = await auth_provider.get_user_info_from_token_data(token_data)
else:
raise
if not user_info.provider_user_id and hasattr(auth_provider, 'get_user_info_from_token_data'):
# try token_data as last resort
user_info = await auth_provider.get_user_info_from_token_data(token_data)

# Find or create user
user, is_new = await auth_provider.find_or_create_user(db, user_info)
Expand Down
16 changes: 15 additions & 1 deletion backend/app/api/enterprise.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ class OAuth2Config(BaseModel):
token_url: str | None = None # OAuth2 token endpoint
user_info_url: str | None = None # OAuth2 user info endpoint
scope: str | None = "openid profile email"
field_mapping: dict | None = None # Custom field name mapping

def to_config_dict(self) -> dict:
"""Convert to config dict with both naming conventions for compatibility."""
Expand All @@ -811,6 +812,8 @@ def to_config_dict(self) -> dict:
config["user_info_url"] = self.user_info_url
if self.scope:
config["scope"] = self.scope
if self.field_mapping:
config["field_mapping"] = self.field_mapping
return config

@classmethod
Expand All @@ -823,6 +826,7 @@ def from_config_dict(cls, config: dict) -> "OAuth2Config":
token_url=config.get("token_url"),
user_info_url=config.get("user_info_url"),
scope=config.get("scope"),
field_mapping=config.get("field_mapping"),
)


Expand All @@ -837,6 +841,7 @@ class IdentityProviderOAuth2Create(BaseModel):
token_url: str
user_info_url: str
scope: str | None = "openid profile email"
field_mapping: dict | None = None
tenant_id: uuid.UUID | None = None


Expand Down Expand Up @@ -936,6 +941,7 @@ async def create_oauth2_provider(
token_url=data.token_url,
user_info_url=data.user_info_url,
scope=data.scope,
field_mapping=data.field_mapping,
)
config = oauth_config.to_config_dict()

Expand All @@ -962,6 +968,7 @@ async def create_oauth2_provider(
provider_type="oauth2",
name=data.name,
is_active=data.is_active,
sso_login_enabled=True,
config=config,
tenant_id=tid
)
Expand All @@ -981,6 +988,7 @@ class OAuth2ConfigUpdate(BaseModel):
token_url: str | None = None
user_info_url: str | None = None
scope: str | None = None
field_mapping: dict | None = None # Custom field name mapping


@router.patch("/identity-providers/{provider_id}/oauth2", response_model=IdentityProviderOut)
Expand Down Expand Up @@ -1009,7 +1017,7 @@ async def update_oauth2_provider(
provider.is_active = data.is_active

# Update config fields
if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope]):
if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope, data.field_mapping is not None]):
current_config = provider.config.copy()

if data.app_id is not None:
Expand All @@ -1031,6 +1039,12 @@ async def update_oauth2_provider(
current_config["user_info_url"] = data.user_info_url
if data.scope is not None:
current_config["scope"] = data.scope
if data.field_mapping is not None:
# Empty dict or explicit None clears the mapping
if data.field_mapping:
current_config["field_mapping"] = data.field_mapping
else:
current_config.pop("field_mapping", None)

# Validate the updated config
validate_provider_config("oauth2", current_config)
Expand Down
91 changes: 91 additions & 0 deletions backend/app/api/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from urllib.parse import quote

from fastapi import APIRouter, Depends, HTTPException, Request, status
from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

Expand Down Expand Up @@ -141,5 +142,95 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
url = f"https://open.work.weixin.qq.com/wwopen/sso/qrConnect?appid={corp_id}&agentid={agent_id}&redirect_uri={quote(redir)}&state={sid}"
auth_urls.append({"provider_type": "wecom", "name": p.name, "url": url})

elif p.provider_type == "oauth2":
from app.services.auth_registry import auth_provider_registry
auth_provider = await auth_provider_registry.get_provider(
db, "oauth2", str(session.tenant_id) if session.tenant_id else None
)
if auth_provider:
redir = f"{public_base}/api/auth/oauth2/callback"
url = await auth_provider.get_authorization_url(redir, str(sid))
auth_urls.append({"provider_type": "oauth2", "name": p.name, "url": url})

return auth_urls


@router.get("/auth/oauth2/callback")
async def oauth2_callback(
code: str,
state: str = None,
db: AsyncSession = Depends(get_db),
):
"""Handle OAuth2 SSO callback -- exchange code for user session."""
from app.core.security import create_access_token
from fastapi.responses import HTMLResponse
from app.services.auth_registry import auth_provider_registry

# 1. Resolve tenant context from state (= session ID)
tenant_id = None
sid = None
if state:
try:
sid = uuid.UUID(state)
s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid))
session = s_res.scalar_one_or_none()
if session:
tenant_id = session.tenant_id
except (ValueError, AttributeError):
pass

# 2. Get OAuth2 provider
auth_provider = await auth_provider_registry.get_provider(
db, "oauth2", str(tenant_id) if tenant_id else None
)
if not auth_provider:
return HTMLResponse("Auth failed: OAuth2 provider not configured")

# 3. Exchange code -> token -> user info -> find/create user
try:
token_data = await auth_provider.exchange_code_for_token(code)
access_token = token_data.get("access_token")
if not access_token:
logger.error("OAuth2 token exchange returned no access_token")
return HTMLResponse("Auth failed: token exchange error")

user_info = await auth_provider.get_user_info(access_token)
if not user_info.provider_user_id:
logger.error("OAuth2 user info missing user ID")
return HTMLResponse("Auth failed: no user ID returned")

user, is_new = await auth_provider.find_or_create_user(
db, user_info, tenant_id=str(tenant_id) if tenant_id else None
)
if not user:
return HTMLResponse("Auth failed: user resolution failed")

except Exception as e:
logger.error("OAuth2 login error: %s", e)
return HTMLResponse(f"Auth failed: {e!s}")

# 4. Generate JWT, update SSO session
token = create_access_token(str(user.id), user.role)

if sid:
try:
s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid))
session = s_res.scalar_one_or_none()
if session:
session.status = "authorized"
session.provider_type = "oauth2"
session.user_id = user.id
session.access_token = token
session.error_msg = None
await db.commit()
return HTMLResponse(
'<html><head><meta charset="utf-8" /></head>'
'<body><div>SSO login successful. Redirecting...</div>'
f'<script>window.location.href = "/sso/entry?sid={sid}&complete=1";</script>'
'</body></html>'
)
except Exception as e:
logger.exception("Failed to update SSO session (oauth2): %s", e)

return HTMLResponse("Logged in successfully.")

1 change: 1 addition & 0 deletions backend/app/schemas/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class OAuthAuthorizeResponse(BaseModel):
class OAuthCallbackRequest(BaseModel):
code: str
state: str
redirect_uri: str = ""


class IdentityBindRequest(BaseModel):
Expand Down
149 changes: 149 additions & 0 deletions backend/app/services/auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,154 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo:
)


class OAuth2AuthProvider(BaseAuthProvider):
"""Generic OAuth2 provider (RFC 6749 Authorization Code flow).

Works with any OAuth2-compliant identity provider (Google, Azure AD,
Keycloak, Auth0, custom corporate OAuth2 servers, etc.).

Config keys:
client_id, client_secret, authorize_url, token_url, userinfo_url,
scope, field_mapping
"""

provider_type = "oauth2"

def __init__(self, provider=None, config=None):
super().__init__(provider, config)
self.client_id = self.config.get("client_id") or self.config.get("app_id", "")
self.client_secret = self.config.get("client_secret") or self.config.get("app_secret", "")
self.authorize_url = self.config.get("authorize_url", "")
self.token_url = self.config.get("token_url", "")
self.userinfo_url = self.config.get("userinfo_url") or self.config.get("user_info_url", "")
self.scope = self.config.get("scope", "openid profile email")

# Derive token_url / userinfo_url from authorize_url if not provided
if self.authorize_url and not self.token_url:
base = self.authorize_url.rsplit("/", 1)[0]
self.token_url = f"{base}/token"
if self.authorize_url and not self.userinfo_url:
base = self.authorize_url.rsplit("/", 1)[0]
self.userinfo_url = f"{base}/userinfo"

# Configurable field mapping: provider response field -> Clawith field
self.field_mapping = self.config.get("field_mapping") or {}

async def get_authorization_url(self, redirect_uri: str, state: str) -> str:
from urllib.parse import quote, urlencode
params = urlencode({
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": redirect_uri,
"scope": self.scope,
"state": state,
})
return f"{self.authorize_url}?{params}"

async def exchange_code_for_token(self, code: str, redirect_uri: str = "") -> dict:
"""Exchange authorization code for access token.

Uses application/x-www-form-urlencoded per RFC 6749 Section 4.1.3.
"""
async with httpx.AsyncClient() as client:
data = {
"grant_type": "authorization_code",
"code": code,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if redirect_uri:
data["redirect_uri"] = redirect_uri
resp = await client.post(
self.token_url,
data=data,
)
if resp.status_code != 200:
logger.error(
"OAuth2 token exchange failed (HTTP %s): %s",
resp.status_code,
resp.text[:500],
)
return {}
return resp.json()

def _resolve_field(self, data: dict, clawith_field: str, default_keys: list[str]) -> str:
"""Resolve a field using user-configured mapping first, then standard fallbacks."""
custom_key = self.field_mapping.get(clawith_field)
if custom_key and data.get(custom_key):
return str(data[custom_key])
for key in default_keys:
if data.get(key):
return str(data[key])
return ""

async def get_user_info(self, access_token: str) -> ExternalUserInfo:
async with httpx.AsyncClient() as client:
resp = await client.get(
self.userinfo_url,
headers={"Authorization": f"Bearer {access_token}"},
)

# Gracefully handle non-200 responses
if resp.status_code != 200:
logger.error("OAuth2 userinfo returned HTTP %s", resp.status_code)
return ExternalUserInfo(
provider_type=self.provider_type,
provider_user_id="",
raw_data={"error": f"userinfo HTTP {resp.status_code}"},
)

try:
resp_data = resp.json()
except Exception:
logger.error("OAuth2 userinfo response is not valid JSON")
return ExternalUserInfo(
provider_type=self.provider_type,
provider_user_id="",
raw_data={"error": "userinfo JSON parse error"},
)

# Some providers wrap payload in {"data": {...}}; standard OIDC is flat
if "data" in resp_data and isinstance(resp_data["data"], dict):
info = resp_data["data"]
else:
info = resp_data

user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId"])
name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname", "userName"])
email = self._resolve_field(info, "email", ["email"])
mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"])
avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"])

return ExternalUserInfo(
provider_type=self.provider_type,
provider_user_id=str(user_id),
name=name,
email=email,
mobile=mobile,
avatar_url=avatar,
raw_data=info,
)

async def get_user_info_from_token_data(self, token_data: dict) -> ExternalUserInfo:
"""Fallback: extract user info directly from token response data."""
info = token_data
user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId", "openid"])
name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname"])
email = self._resolve_field(info, "email", ["email"])
mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"])
avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"])
return ExternalUserInfo(
provider_type=self.provider_type,
provider_user_id=str(user_id),
name=name,
email=email,
mobile=mobile,
avatar_url=avatar,
raw_data=info,
)


class MicrosoftTeamsAuthProvider(BaseAuthProvider):
"""Microsoft Teams OAuth provider implementation."""

Expand All @@ -658,5 +806,6 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo:
"feishu": FeishuAuthProvider,
"dingtalk": DingTalkAuthProvider,
"wecom": WeComAuthProvider,
"oauth2": OAuth2AuthProvider,
"microsoft_teams": MicrosoftTeamsAuthProvider,
}
1 change: 1 addition & 0 deletions backend/app/services/auth_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DingTalkAuthProvider,
FeishuAuthProvider,
MicrosoftTeamsAuthProvider,
OAuth2AuthProvider,
WeComAuthProvider,
)

Expand Down
Loading