From 022211153eb840df8f169088c108d3721a46bd46 Mon Sep 17 00:00:00 2001 From: "nap.liu" Date: Wed, 8 Apr 2026 20:12:04 +0800 Subject: [PATCH] feat: add generic OAuth2 SSO login with configurable field mapping Add a generic OAuth2AuthProvider that works with any OAuth2-compliant identity provider (Google, Azure AD, Keycloak, Auth0, custom corporate OAuth2 servers, etc.). Backend: - New OAuth2AuthProvider class with configurable authorize_url, token_url, userinfo_url, client_id, client_secret, scope, and field_mapping - Token exchange uses application/x-www-form-urlencoded (RFC 6749) - Graceful handling of userinfo 401/empty/invalid responses - Configurable field_mapping maps provider fields to Clawith fields (provider_user_id, email, display_name, mobile, avatar_url) - Standard OIDC field fallbacks when no custom mapping is configured - Provider registered in auth_registry as "oauth2" - SSO callback route (GET /auth/oauth2/callback) with session handling - OAuth2 provider type added to SSO config endpoint Frontend: - OAuth2 configuration form with Token URL, UserInfo URL, Scope fields - Field Mapping section for custom provider field names - Save/update via dedicated OAuth2 API endpoints Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/api/auth.py | 20 +- backend/app/api/enterprise.py | 16 +- backend/app/api/sso.py | 91 ++++ backend/app/schemas/schemas.py | 1 + backend/app/services/auth_provider.py | 149 +++++ backend/app/services/auth_registry.py | 1 + backend/app/services/dingtalk_stream.py | 636 +++++++++++++++++++--- frontend/src/i18n/en.json | 25 +- frontend/src/i18n/zh.json | 25 +- frontend/src/pages/EnterpriseSettings.tsx | 62 ++- 10 files changed, 945 insertions(+), 81 deletions(-) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index ef76429b5..dc74c7f7c 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -910,14 +910,26 @@ async def oauth_callback( raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported") try: - # Exchange code for token - token_data = await auth_provider.exchange_code_for_token(data.code) + # Exchange code for token (pass redirect_uri for OAuth2 providers that require it) + if hasattr(auth_provider, 'exchange_code_for_token') and data.redirect_uri: + token_data = await auth_provider.exchange_code_for_token(data.code, redirect_uri=data.redirect_uri) + else: + token_data = await auth_provider.exchange_code_for_token(data.code) access_token = token_data.get("access_token") if not access_token: raise HTTPException(status_code=400, detail="Failed to get access token from provider") - # Get user info - user_info = await auth_provider.get_user_info(access_token) + # Get user info with fallback to token_data extraction + try: + user_info = await auth_provider.get_user_info(access_token) + except Exception: + if hasattr(auth_provider, 'get_user_info_from_token_data'): + user_info = await auth_provider.get_user_info_from_token_data(token_data) + else: + raise + if not user_info.provider_user_id and hasattr(auth_provider, 'get_user_info_from_token_data'): + # try token_data as last resort + user_info = await auth_provider.get_user_info_from_token_data(token_data) # Find or create user user, is_new = await auth_provider.find_or_create_user(db, user_info) diff --git a/backend/app/api/enterprise.py b/backend/app/api/enterprise.py index 25d98c802..05f29b1de 100644 --- a/backend/app/api/enterprise.py +++ b/backend/app/api/enterprise.py @@ -793,6 +793,7 @@ class OAuth2Config(BaseModel): token_url: str | None = None # OAuth2 token endpoint user_info_url: str | None = None # OAuth2 user info endpoint scope: str | None = "openid profile email" + field_mapping: dict | None = None # Custom field name mapping def to_config_dict(self) -> dict: """Convert to config dict with both naming conventions for compatibility.""" @@ -811,6 +812,8 @@ def to_config_dict(self) -> dict: config["user_info_url"] = self.user_info_url if self.scope: config["scope"] = self.scope + if self.field_mapping: + config["field_mapping"] = self.field_mapping return config @classmethod @@ -823,6 +826,7 @@ def from_config_dict(cls, config: dict) -> "OAuth2Config": token_url=config.get("token_url"), user_info_url=config.get("user_info_url"), scope=config.get("scope"), + field_mapping=config.get("field_mapping"), ) @@ -837,6 +841,7 @@ class IdentityProviderOAuth2Create(BaseModel): token_url: str user_info_url: str scope: str | None = "openid profile email" + field_mapping: dict | None = None tenant_id: uuid.UUID | None = None @@ -936,6 +941,7 @@ async def create_oauth2_provider( token_url=data.token_url, user_info_url=data.user_info_url, scope=data.scope, + field_mapping=data.field_mapping, ) config = oauth_config.to_config_dict() @@ -962,6 +968,7 @@ async def create_oauth2_provider( provider_type="oauth2", name=data.name, is_active=data.is_active, + sso_login_enabled=True, config=config, tenant_id=tid ) @@ -981,6 +988,7 @@ class OAuth2ConfigUpdate(BaseModel): token_url: str | None = None user_info_url: str | None = None scope: str | None = None + field_mapping: dict | None = None # Custom field name mapping @router.patch("/identity-providers/{provider_id}/oauth2", response_model=IdentityProviderOut) @@ -1009,7 +1017,7 @@ async def update_oauth2_provider( provider.is_active = data.is_active # Update config fields - if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope]): + if any([data.app_id, data.app_secret is not None, data.authorize_url, data.token_url, data.user_info_url, data.scope, data.field_mapping is not None]): current_config = provider.config.copy() if data.app_id is not None: @@ -1031,6 +1039,12 @@ async def update_oauth2_provider( current_config["user_info_url"] = data.user_info_url if data.scope is not None: current_config["scope"] = data.scope + if data.field_mapping is not None: + # Empty dict or explicit None clears the mapping + if data.field_mapping: + current_config["field_mapping"] = data.field_mapping + else: + current_config.pop("field_mapping", None) # Validate the updated config validate_provider_config("oauth2", current_config) diff --git a/backend/app/api/sso.py b/backend/app/api/sso.py index 1c5210247..bb7679c93 100644 --- a/backend/app/api/sso.py +++ b/backend/app/api/sso.py @@ -4,6 +4,7 @@ from urllib.parse import quote from fastapi import APIRouter, Depends, HTTPException, Request, status +from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -141,5 +142,95 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De url = f"https://open.work.weixin.qq.com/wwopen/sso/qrConnect?appid={corp_id}&agentid={agent_id}&redirect_uri={quote(redir)}&state={sid}" auth_urls.append({"provider_type": "wecom", "name": p.name, "url": url}) + elif p.provider_type == "oauth2": + from app.services.auth_registry import auth_provider_registry + auth_provider = await auth_provider_registry.get_provider( + db, "oauth2", str(session.tenant_id) if session.tenant_id else None + ) + if auth_provider: + redir = f"{public_base}/api/auth/oauth2/callback" + url = await auth_provider.get_authorization_url(redir, str(sid)) + auth_urls.append({"provider_type": "oauth2", "name": p.name, "url": url}) + return auth_urls + +@router.get("/auth/oauth2/callback") +async def oauth2_callback( + code: str, + state: str = None, + db: AsyncSession = Depends(get_db), +): + """Handle OAuth2 SSO callback -- exchange code for user session.""" + from app.core.security import create_access_token + from fastapi.responses import HTMLResponse + from app.services.auth_registry import auth_provider_registry + + # 1. Resolve tenant context from state (= session ID) + tenant_id = None + sid = None + if state: + try: + sid = uuid.UUID(state) + s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid)) + session = s_res.scalar_one_or_none() + if session: + tenant_id = session.tenant_id + except (ValueError, AttributeError): + pass + + # 2. Get OAuth2 provider + auth_provider = await auth_provider_registry.get_provider( + db, "oauth2", str(tenant_id) if tenant_id else None + ) + if not auth_provider: + return HTMLResponse("Auth failed: OAuth2 provider not configured") + + # 3. Exchange code -> token -> user info -> find/create user + try: + token_data = await auth_provider.exchange_code_for_token(code) + access_token = token_data.get("access_token") + if not access_token: + logger.error("OAuth2 token exchange returned no access_token") + return HTMLResponse("Auth failed: token exchange error") + + user_info = await auth_provider.get_user_info(access_token) + if not user_info.provider_user_id: + logger.error("OAuth2 user info missing user ID") + return HTMLResponse("Auth failed: no user ID returned") + + user, is_new = await auth_provider.find_or_create_user( + db, user_info, tenant_id=str(tenant_id) if tenant_id else None + ) + if not user: + return HTMLResponse("Auth failed: user resolution failed") + + except Exception as e: + logger.error("OAuth2 login error: %s", e) + return HTMLResponse(f"Auth failed: {e!s}") + + # 4. Generate JWT, update SSO session + token = create_access_token(str(user.id), user.role) + + if sid: + try: + s_res = await db.execute(select(SSOScanSession).where(SSOScanSession.id == sid)) + session = s_res.scalar_one_or_none() + if session: + session.status = "authorized" + session.provider_type = "oauth2" + session.user_id = user.id + session.access_token = token + session.error_msg = None + await db.commit() + return HTMLResponse( + '' + '
SSO login successful. Redirecting...
' + f'' + '' + ) + except Exception as e: + logger.exception("Failed to update SSO session (oauth2): %s", e) + + return HTMLResponse("Logged in successfully.") + diff --git a/backend/app/schemas/schemas.py b/backend/app/schemas/schemas.py index 3870392b9..f0da50b62 100644 --- a/backend/app/schemas/schemas.py +++ b/backend/app/schemas/schemas.py @@ -180,6 +180,7 @@ class OAuthAuthorizeResponse(BaseModel): class OAuthCallbackRequest(BaseModel): code: str state: str + redirect_uri: str = "" class IdentityBindRequest(BaseModel): diff --git a/backend/app/services/auth_provider.py b/backend/app/services/auth_provider.py index d40cd583b..0d2543098 100644 --- a/backend/app/services/auth_provider.py +++ b/backend/app/services/auth_provider.py @@ -637,6 +637,154 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo: ) +class OAuth2AuthProvider(BaseAuthProvider): + """Generic OAuth2 provider (RFC 6749 Authorization Code flow). + + Works with any OAuth2-compliant identity provider (Google, Azure AD, + Keycloak, Auth0, custom corporate OAuth2 servers, etc.). + + Config keys: + client_id, client_secret, authorize_url, token_url, userinfo_url, + scope, field_mapping + """ + + provider_type = "oauth2" + + def __init__(self, provider=None, config=None): + super().__init__(provider, config) + self.client_id = self.config.get("client_id") or self.config.get("app_id", "") + self.client_secret = self.config.get("client_secret") or self.config.get("app_secret", "") + self.authorize_url = self.config.get("authorize_url", "") + self.token_url = self.config.get("token_url", "") + self.userinfo_url = self.config.get("userinfo_url") or self.config.get("user_info_url", "") + self.scope = self.config.get("scope", "openid profile email") + + # Derive token_url / userinfo_url from authorize_url if not provided + if self.authorize_url and not self.token_url: + base = self.authorize_url.rsplit("/", 1)[0] + self.token_url = f"{base}/token" + if self.authorize_url and not self.userinfo_url: + base = self.authorize_url.rsplit("/", 1)[0] + self.userinfo_url = f"{base}/userinfo" + + # Configurable field mapping: provider response field -> Clawith field + self.field_mapping = self.config.get("field_mapping") or {} + + async def get_authorization_url(self, redirect_uri: str, state: str) -> str: + from urllib.parse import quote, urlencode + params = urlencode({ + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": redirect_uri, + "scope": self.scope, + "state": state, + }) + return f"{self.authorize_url}?{params}" + + async def exchange_code_for_token(self, code: str, redirect_uri: str = "") -> dict: + """Exchange authorization code for access token. + + Uses application/x-www-form-urlencoded per RFC 6749 Section 4.1.3. + """ + async with httpx.AsyncClient() as client: + data = { + "grant_type": "authorization_code", + "code": code, + "client_id": self.client_id, + "client_secret": self.client_secret, + } + if redirect_uri: + data["redirect_uri"] = redirect_uri + resp = await client.post( + self.token_url, + data=data, + ) + if resp.status_code != 200: + logger.error( + "OAuth2 token exchange failed (HTTP %s): %s", + resp.status_code, + resp.text[:500], + ) + return {} + return resp.json() + + def _resolve_field(self, data: dict, clawith_field: str, default_keys: list[str]) -> str: + """Resolve a field using user-configured mapping first, then standard fallbacks.""" + custom_key = self.field_mapping.get(clawith_field) + if custom_key and data.get(custom_key): + return str(data[custom_key]) + for key in default_keys: + if data.get(key): + return str(data[key]) + return "" + + async def get_user_info(self, access_token: str) -> ExternalUserInfo: + async with httpx.AsyncClient() as client: + resp = await client.get( + self.userinfo_url, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + # Gracefully handle non-200 responses + if resp.status_code != 200: + logger.error("OAuth2 userinfo returned HTTP %s", resp.status_code) + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id="", + raw_data={"error": f"userinfo HTTP {resp.status_code}"}, + ) + + try: + resp_data = resp.json() + except Exception: + logger.error("OAuth2 userinfo response is not valid JSON") + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id="", + raw_data={"error": "userinfo JSON parse error"}, + ) + + # Some providers wrap payload in {"data": {...}}; standard OIDC is flat + if "data" in resp_data and isinstance(resp_data["data"], dict): + info = resp_data["data"] + else: + info = resp_data + + user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId"]) + name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname", "userName"]) + email = self._resolve_field(info, "email", ["email"]) + mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"]) + avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"]) + + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id=str(user_id), + name=name, + email=email, + mobile=mobile, + avatar_url=avatar, + raw_data=info, + ) + + async def get_user_info_from_token_data(self, token_data: dict) -> ExternalUserInfo: + """Fallback: extract user info directly from token response data.""" + info = token_data + user_id = self._resolve_field(info, "provider_user_id", ["sub", "id", "user_id", "userId", "openid"]) + name = self._resolve_field(info, "display_name", ["name", "preferred_username", "nickname"]) + email = self._resolve_field(info, "email", ["email"]) + mobile = self._resolve_field(info, "mobile", ["phone_number", "mobile", "phone"]) + avatar = self._resolve_field(info, "avatar_url", ["picture", "avatar_url", "avatar"]) + return ExternalUserInfo( + provider_type=self.provider_type, + provider_user_id=str(user_id), + name=name, + email=email, + mobile=mobile, + avatar_url=avatar, + raw_data=info, + ) + + class MicrosoftTeamsAuthProvider(BaseAuthProvider): """Microsoft Teams OAuth provider implementation.""" @@ -658,5 +806,6 @@ async def get_user_info(self, access_token: str) -> ExternalUserInfo: "feishu": FeishuAuthProvider, "dingtalk": DingTalkAuthProvider, "wecom": WeComAuthProvider, + "oauth2": OAuth2AuthProvider, "microsoft_teams": MicrosoftTeamsAuthProvider, } diff --git a/backend/app/services/auth_registry.py b/backend/app/services/auth_registry.py index c3e9b0035..3c655bade 100644 --- a/backend/app/services/auth_registry.py +++ b/backend/app/services/auth_registry.py @@ -15,6 +15,7 @@ DingTalkAuthProvider, FeishuAuthProvider, MicrosoftTeamsAuthProvider, + OAuth2AuthProvider, WeComAuthProvider, ) diff --git a/backend/app/services/dingtalk_stream.py b/backend/app/services/dingtalk_stream.py index 28a8ba8e2..73cf80b20 100644 --- a/backend/app/services/dingtalk_stream.py +++ b/backend/app/services/dingtalk_stream.py @@ -5,15 +5,401 @@ """ import asyncio +import base64 +import json import threading import uuid -from typing import Dict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import httpx from loguru import logger from sqlalchemy import select +from app.config import get_settings from app.database import async_session from app.models.channel_config import ChannelConfig +from app.services.dingtalk_token import dingtalk_token_manager + + +# ─── DingTalk Media Helpers ───────────────────────────── + + +async def _get_media_download_url( + access_token: str, download_code: str, robot_code: str +) -> Optional[str]: + """Get media file download URL from DingTalk API.""" + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.post( + "https://api.dingtalk.com/v1.0/robot/messageFiles/download", + headers={"x-acs-dingtalk-access-token": access_token}, + json={"downloadCode": download_code, "robotCode": robot_code}, + ) + data = resp.json() + url = data.get("downloadUrl") + if url: + return url + logger.error(f"[DingTalk] Failed to get download URL: {data}") + return None + except Exception as e: + logger.error(f"[DingTalk] Error getting download URL: {e}") + return None + + +async def _download_file(url: str) -> Optional[bytes]: + """Download a file from a URL and return its bytes.""" + try: + async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client: + resp = await client.get(url) + resp.raise_for_status() + return resp.content + except Exception as e: + logger.error(f"[DingTalk] Error downloading file: {e}") + return None + + +async def download_dingtalk_media( + app_key: str, app_secret: str, download_code: str +) -> Optional[bytes]: + """Download a media file from DingTalk using downloadCode. + + Steps: get access_token -> get download URL -> download file bytes. + """ + access_token = await dingtalk_token_manager.get_token(app_key, app_secret) + if not access_token: + return None + + download_url = await _get_media_download_url(access_token, download_code, app_key) + if not download_url: + return None + + return await _download_file(download_url) + + +def _resolve_upload_dir(agent_id: uuid.UUID) -> Path: + """Get the uploads directory for an agent, creating it if needed.""" + settings = get_settings() + upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" + upload_dir.mkdir(parents=True, exist_ok=True) + return upload_dir + + +async def _process_media_message( + msg_data: dict, + app_key: str, + app_secret: str, + agent_id: uuid.UUID, +) -> Tuple[str, Optional[List[str]], Optional[List[str]]]: + """Process a DingTalk message and extract text + media info. + + Returns: + (user_text, image_base64_list, saved_file_paths) + - user_text: text content for the LLM (may include markers) + - image_base64_list: list of base64-encoded image data URIs, or None + - saved_file_paths: list of saved file paths, or None + """ + msgtype = msg_data.get("msgtype", "text") + logger.info(f"[DingTalk] Processing message type: {msgtype}") + + image_base64_list: List[str] = [] + saved_file_paths: List[str] = [] + + if msgtype == "text": + text_content = msg_data.get("text", {}).get("content", "").strip() + return text_content, None, None + + elif msgtype == "picture": + download_code = msg_data.get("content", {}).get("downloadCode", "") + if not download_code: + download_code = msg_data.get("downloadCode", "") + if not download_code: + logger.warning("[DingTalk] Picture message without downloadCode") + return "[User sent an image, but it could not be downloaded]", None, None + + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if not file_bytes: + return "[User sent an image, but download failed]", None, None + + upload_dir = _resolve_upload_dir(agent_id) + filename = f"dingtalk_img_{uuid.uuid4().hex[:8]}.jpg" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved image to {save_path} ({len(file_bytes)} bytes)") + + b64_data = base64.b64encode(file_bytes).decode("ascii") + image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" + return ( + f"[User sent an image]\n{image_marker}", + [f"data:image/jpeg;base64,{b64_data}"], + [str(save_path)], + ) + + elif msgtype == "richText": + rich_text = msg_data.get("content", {}).get("richText", []) + text_parts: List[str] = [] + + for section in rich_text: + for item in section if isinstance(section, list) else [section]: + if "text" in item: + text_parts.append(item["text"]) + elif "downloadCode" in item: + file_bytes = await download_dingtalk_media( + app_key, app_secret, item["downloadCode"] + ) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + filename = f"dingtalk_richimg_{uuid.uuid4().hex[:8]}.jpg" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved rich text image to {save_path}") + + b64_data = base64.b64encode(file_bytes).decode("ascii") + image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" + text_parts.append(image_marker) + image_base64_list.append(f"data:image/jpeg;base64,{b64_data}") + saved_file_paths.append(str(save_path)) + + combined_text = "\n".join(text_parts).strip() + if not combined_text: + combined_text = "[User sent a rich text message]" + + return ( + combined_text, + image_base64_list if image_base64_list else None, + saved_file_paths if saved_file_paths else None, + ) + + elif msgtype == "audio": + content = msg_data.get("content", {}) + recognition = content.get("recognition", "") + if recognition: + logger.info(f"[DingTalk] Audio with recognition: {recognition[:80]}") + return f"[Voice message] {recognition}", None, None + + download_code = content.get("downloadCode", "") + if download_code: + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + duration = content.get("duration", "unknown") + filename = f"dingtalk_audio_{uuid.uuid4().hex[:8]}.amr" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved audio to {save_path} ({len(file_bytes)} bytes)") + return ( + f"[User sent a voice message, duration {duration}ms, saved to {filename}]", + None, + [str(save_path)], + ) + return "[User sent a voice message, but it could not be processed]", None, None + + elif msgtype == "video": + content = msg_data.get("content", {}) + download_code = content.get("downloadCode", "") + if download_code: + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + duration = content.get("duration", "unknown") + filename = f"dingtalk_video_{uuid.uuid4().hex[:8]}.mp4" + save_path = upload_dir / filename + save_path.write_bytes(file_bytes) + logger.info(f"[DingTalk] Saved video to {save_path} ({len(file_bytes)} bytes)") + return ( + f"[User sent a video, duration {duration}ms, saved to {filename}]", + None, + [str(save_path)], + ) + return "[User sent a video, but it could not be downloaded]", None, None + + elif msgtype == "file": + content = msg_data.get("content", {}) + download_code = content.get("downloadCode", "") + original_filename = content.get("fileName", "unknown_file") + if download_code: + file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) + if file_bytes: + upload_dir = _resolve_upload_dir(agent_id) + safe_name = f"dingtalk_{uuid.uuid4().hex[:8]}_{original_filename}" + save_path = upload_dir / safe_name + save_path.write_bytes(file_bytes) + logger.info( + f"[DingTalk] Saved file '{original_filename}' to {save_path} " + f"({len(file_bytes)} bytes)" + ) + return ( + f"[file:{original_filename}]", + None, + [str(save_path)], + ) + return f"[User sent file {original_filename}, but it could not be downloaded]", None, None + + else: + logger.warning(f"[DingTalk] Unsupported message type: {msgtype}") + return f"[User sent a {msgtype} message, which is not yet supported]", None, None + + +# ─── DingTalk Media Upload & Send ─────────────────────── + +async def _upload_dingtalk_media( + app_key: str, + app_secret: str, + file_path: str, + media_type: str = "file", +) -> Optional[str]: + """Upload a media file to DingTalk and return the mediaId. + + Args: + app_key: DingTalk app key (robotCode). + app_secret: DingTalk app secret. + file_path: Local file path to upload. + media_type: One of 'image', 'voice', 'video', 'file'. + + Returns: + mediaId string on success, None on failure. + """ + access_token = await dingtalk_token_manager.get_token(app_key, app_secret) + if not access_token: + return None + + file_p = Path(file_path) + if not file_p.exists(): + logger.error(f"[DingTalk] Upload failed: file not found: {file_path}") + return None + + try: + file_bytes = file_p.read_bytes() + async with httpx.AsyncClient(timeout=60) as client: + # Use the legacy oapi endpoint which is more reliable and widely supported. + upload_url = ( + f"https://oapi.dingtalk.com/media/upload" + f"?access_token={access_token}&type={media_type}" + ) + resp = await client.post( + upload_url, + files={"media": (file_p.name, file_bytes)}, + ) + data = resp.json() + # Legacy API returns media_id (snake_case), new API returns mediaId + media_id = data.get("media_id") or data.get("mediaId") + if media_id and data.get("errcode", 0) == 0: + logger.info( + f"[DingTalk] Uploaded {media_type} '{file_p.name}' -> mediaId={media_id[:20]}..." + ) + return media_id + logger.error(f"[DingTalk] Upload failed: {data}") + return None + except Exception as e: + logger.error(f"[DingTalk] Upload error: {e}") + return None + + +async def _send_dingtalk_media_message( + app_key: str, + app_secret: str, + target_id: str, + media_id: str, + media_type: str, + conversation_type: str, + filename: Optional[str] = None, +) -> bool: + """Send a media message via DingTalk proactive message API. + + Args: + app_key: DingTalk app key (robotCode). + app_secret: DingTalk app secret. + target_id: For P2P: sender_staff_id; For group: openConversationId. + media_id: The mediaId from upload. + media_type: One of 'image', 'voice', 'video', 'file'. + conversation_type: '1' for P2P, '2' for group. + filename: Original filename (used for file/video types). + + Returns: + True on success, False on failure. + """ + access_token = await dingtalk_token_manager.get_token(app_key, app_secret) + if not access_token: + return False + + headers = {"x-acs-dingtalk-access-token": access_token} + + # Build msgKey and msgParam based on media_type + if media_type == "image": + msg_key = "sampleImageMsg" + msg_param = json.dumps({"photoURL": media_id}) + elif media_type == "voice": + msg_key = "sampleAudio" + msg_param = json.dumps({"mediaId": media_id, "duration": "3000"}) + elif media_type == "video": + safe_name = filename or "video.mp4" + ext = Path(safe_name).suffix.lstrip(".") or "mp4" + msg_key = "sampleFile" + msg_param = json.dumps({ + "mediaId": media_id, + "fileName": safe_name, + "fileType": ext, + }) + else: + # file + safe_name = filename or "file" + ext = Path(safe_name).suffix.lstrip(".") or "bin" + msg_key = "sampleFile" + msg_param = json.dumps({ + "mediaId": media_id, + "fileName": safe_name, + "fileType": ext, + }) + + try: + async with httpx.AsyncClient(timeout=15) as client: + if conversation_type == "2": + # Group chat + resp = await client.post( + "https://api.dingtalk.com/v1.0/robot/groupMessages/send", + headers=headers, + json={ + "robotCode": app_key, + "openConversationId": target_id, + "msgKey": msg_key, + "msgParam": msg_param, + }, + ) + else: + # P2P chat + resp = await client.post( + "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend", + headers=headers, + json={ + "robotCode": app_key, + "userIds": [target_id], + "msgKey": msg_key, + "msgParam": msg_param, + }, + ) + + data = resp.json() + if resp.status_code >= 400 or data.get("errcode"): + logger.error(f"[DingTalk] Send media failed: {data}") + return False + + logger.info( + f"[DingTalk] Sent {media_type} message to {target_id[:16]}... " + f"(conv_type={conversation_type})" + ) + return True + except Exception as e: + logger.error(f"[DingTalk] Send media error: {e}") + return False + + +# ─── Stream Manager ───────────────────────────────────── + + +def _fire_and_forget(loop, coro): + """Schedule a coroutine on the main loop and log any unhandled exception.""" + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.add_done_callback(lambda f: f.exception() if not f.cancelled() else None) class DingTalkStreamManager: @@ -67,47 +453,72 @@ def _run_client_thread( app_secret: str, stop_event: threading.Event, ): - """Run the DingTalk Stream client in a blocking thread.""" + """Run the DingTalk Stream client with auto-reconnect.""" try: import dingtalk_stream + except ImportError: + logger.warning( + "[DingTalk Stream] dingtalk-stream package not installed. " + "Install with: pip install dingtalk-stream" + ) + self._threads.pop(agent_id, None) + self._stop_events.pop(agent_id, None) + return - # Reference to manager's main loop for async dispatch - main_loop = self._main_loop - - class ClawithChatbotHandler(dingtalk_stream.ChatbotHandler): - """Custom handler that dispatches messages to the Clawith LLM pipeline.""" - - async def process(self, callback: dingtalk_stream.CallbackMessage): - """Handle incoming bot message from DingTalk Stream. - - NOTE: The SDK invokes this method in the thread's own asyncio loop, - so we must dispatch to the main FastAPI loop for DB + LLM work. - """ - try: - # Parse the raw data into a ChatbotMessage via class method - incoming = dingtalk_stream.ChatbotMessage.from_dict(callback.data) - - # Extract text content + MAX_RETRIES = 5 + RETRY_DELAYS = [2, 5, 15, 30, 60] # exponential backoff, seconds + + # Reference to manager's main loop for async dispatch + main_loop = self._main_loop + retries = 0 + manager_self = self + + class ClawithChatbotHandler(dingtalk_stream.ChatbotHandler): + """Custom handler that dispatches messages to the Clawith LLM pipeline.""" + + async def process(self, callback: dingtalk_stream.CallbackMessage): + """Handle incoming bot message from DingTalk Stream. + + NOTE: The SDK invokes this method in the thread's own asyncio loop, + so we must dispatch to the main FastAPI loop for DB + LLM work. + """ + try: + # Parse the raw data + incoming = dingtalk_stream.ChatbotMessage.from_dict(callback.data) + msg_data = callback.data if isinstance(callback.data, dict) else json.loads(callback.data) + + msgtype = msg_data.get("msgtype", "text") + sender_staff_id = incoming.sender_staff_id or incoming.sender_id or "" + sender_nick = incoming.sender_nick or "" + message_id = incoming.message_id or "" + conversation_id = incoming.conversation_id or "" + conversation_type = incoming.conversation_type or "1" + session_webhook = incoming.session_webhook or "" + + logger.info( + f"[DingTalk Stream] Received {msgtype} message from {sender_staff_id}" + ) + + if msgtype == "text": + # Plain text: use existing logic text_list = incoming.get_text_list() user_text = " ".join(text_list).strip() if text_list else "" - if not user_text: return dingtalk_stream.AckMessage.STATUS_OK, "empty message" - sender_staff_id = incoming.sender_staff_id or incoming.sender_id or "" - conversation_id = incoming.conversation_id or "" - conversation_type = incoming.conversation_type or "1" - session_webhook = incoming.session_webhook or "" - logger.info( - f"[DingTalk Stream] Message from [{incoming.sender_nick}]{sender_staff_id}: {user_text[:80]}" + f"[DingTalk Stream] Text from {sender_staff_id}: {user_text[:80]}" ) - # Dispatch to the main FastAPI event loop for DB + LLM processing from app.api.dingtalk import process_dingtalk_message if main_loop and main_loop.is_running(): - future = asyncio.run_coroutine_threadsafe( + # Add thinking reaction immediately + from app.services.dingtalk_reaction import add_thinking_reaction + _fire_and_forget(main_loop, + add_thinking_reaction(app_key, app_secret, message_id, conversation_id)) + + _fire_and_forget(main_loop, process_dingtalk_message( agent_id=agent_id, sender_staff_id=sender_staff_id, @@ -115,50 +526,134 @@ async def process(self, callback: dingtalk_stream.CallbackMessage): conversation_id=conversation_id, conversation_type=conversation_type, session_webhook=session_webhook, - ), - main_loop, - ) - # Wait for result (with timeout) - try: - future.result(timeout=120) - except Exception as e: - logger.error(f"[DingTalk Stream] LLM processing error: {e}") - import traceback - traceback.print_exc() + sender_nick=sender_nick, + message_id=message_id, + )) else: - logger.warning("[DingTalk Stream] Main loop not available for dispatch") - - return dingtalk_stream.AckMessage.STATUS_OK, "ok" - except Exception as e: - logger.error(f"[DingTalk Stream] Error in message handler: {e}") - import traceback - traceback.print_exc() - return dingtalk_stream.AckMessage.STATUS_SYSTEM_EXCEPTION, str(e) - - credential = dingtalk_stream.Credential(client_id=app_key, client_secret=app_secret) - client = dingtalk_stream.DingTalkStreamClient(credential=credential) - client.register_callback_handler( - dingtalk_stream.chatbot.ChatbotMessage.TOPIC, - ClawithChatbotHandler(), - ) + logger.warning("[DingTalk Stream] Main loop not available") - logger.info(f"[DingTalk Stream] Connecting for agent {agent_id}...") - # start_forever() blocks until disconnected - client.start_forever() + else: + # Non-text message: process media in the main loop + if main_loop and main_loop.is_running(): + # Add thinking reaction immediately + from app.services.dingtalk_reaction import add_thinking_reaction + _fire_and_forget(main_loop, + add_thinking_reaction(app_key, app_secret, message_id, conversation_id)) + + _fire_and_forget(main_loop, + manager_self._handle_media_and_dispatch( + msg_data=msg_data, + app_key=app_key, + app_secret=app_secret, + agent_id=agent_id, + sender_staff_id=sender_staff_id, + conversation_id=conversation_id, + conversation_type=conversation_type, + session_webhook=session_webhook, + sender_nick=sender_nick, + message_id=message_id, + )) + else: + logger.warning("[DingTalk Stream] Main loop not available") + + return dingtalk_stream.AckMessage.STATUS_OK, "ok" + except Exception as e: + logger.error(f"[DingTalk Stream] Error in message handler: {e}") + import traceback + traceback.print_exc() + return dingtalk_stream.AckMessage.STATUS_SYSTEM_EXCEPTION, str(e) + + while not stop_event.is_set() and retries <= MAX_RETRIES: + try: + credential = dingtalk_stream.Credential(client_id=app_key, client_secret=app_secret) + client = dingtalk_stream.DingTalkStreamClient(credential=credential) + client.register_callback_handler( + dingtalk_stream.chatbot.ChatbotMessage.TOPIC, + ClawithChatbotHandler(), + ) - except ImportError: - logger.warning( - "[DingTalk Stream] dingtalk-stream package not installed. " - "Install with: pip install dingtalk-stream" + logger.info( + f"[DingTalk Stream] Connecting for agent {agent_id}... " + f"(attempt {retries + 1}/{MAX_RETRIES + 1})" + ) + # start_forever() blocks until disconnected + client.start_forever() + + # start_forever returned: connection dropped + if stop_event.is_set(): + break # intentional stop, no retry + + # Reset retries on successful connection (ran for a while then disconnected) + retries = 0 + retries += 1 + logger.warning( + f"[DingTalk Stream] Connection lost for agent {agent_id}, will retry..." + ) + + except Exception as e: + retries += 1 + logger.error( + f"[DingTalk Stream] Connection error for {agent_id} " + f"(attempt {retries}/{MAX_RETRIES + 1}): {e}" + ) + + if retries > MAX_RETRIES: + logger.error( + f"[DingTalk Stream] Agent {agent_id} exhausted all {MAX_RETRIES} retries, giving up" + ) + break + + delay = RETRY_DELAYS[min(retries - 1, len(RETRY_DELAYS) - 1)] + logger.info( + f"[DingTalk Stream] Retrying in {delay}s for agent {agent_id}..." ) - except Exception as e: - logger.error(f"[DingTalk Stream] Client error for {agent_id}: {e}") - import traceback - traceback.print_exc() - finally: - self._threads.pop(agent_id, None) - self._stop_events.pop(agent_id, None) - logger.info(f"[DingTalk Stream] Client stopped for agent {agent_id}") + # Use stop_event.wait so we exit immediately if stopped + if stop_event.wait(timeout=delay): + break # stop was requested during wait + + self._threads.pop(agent_id, None) + self._stop_events.pop(agent_id, None) + logger.info(f"[DingTalk Stream] Client stopped for agent {agent_id}") + + @staticmethod + async def _handle_media_and_dispatch( + msg_data: dict, + app_key: str, + app_secret: str, + agent_id: uuid.UUID, + sender_staff_id: str, + conversation_id: str, + conversation_type: str, + session_webhook: str, + sender_nick: str = "", + message_id: str = "", + ): + """Download media, then dispatch to process_dingtalk_message.""" + from app.api.dingtalk import process_dingtalk_message + + user_text, image_base64_list, saved_file_paths = await _process_media_message( + msg_data=msg_data, + app_key=app_key, + app_secret=app_secret, + agent_id=agent_id, + ) + + if not user_text: + logger.info("[DingTalk Stream] Empty content after media processing, skipping") + return + + await process_dingtalk_message( + agent_id=agent_id, + sender_staff_id=sender_staff_id, + user_text=user_text, + conversation_id=conversation_id, + conversation_type=conversation_type, + session_webhook=session_webhook, + image_base64_list=image_base64_list, + saved_file_paths=saved_file_paths, + sender_nick=sender_nick, + message_id=message_id, + ) async def stop_client(self, agent_id: uuid.UUID): """Stop a running Stream client for an agent.""" @@ -167,7 +662,10 @@ async def stop_client(self, agent_id: uuid.UUID): stop_event.set() thread = self._threads.pop(agent_id, None) if thread and thread.is_alive(): - logger.info(f"[DingTalk Stream] Stopping client for agent {agent_id}") + logger.info(f"[DingTalk Stream] Stopping client for agent {agent_id}, waiting for thread...") + thread.join(timeout=5) + if thread.is_alive(): + logger.warning(f"[DingTalk Stream] Thread for {agent_id} did not exit within 5s") async def start_all(self): """Start Stream clients for all configured DingTalk agents.""" diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index a5be3fd99..343203216 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -1359,7 +1359,30 @@ "messagingTitle": "3. Proactive 1-to-1 Messaging (AI-initiated)", "messagingDesc": "Sending messages to individual users requires a valid WeCom API access token, which can only be obtained from a server IP that has been whitelisted in the self-built app's settings. Unlike Feishu, WeCom mandates IP-level restrictions on all API calls — there is no token-only authentication option.", "footerText": "Due to the above constraints, WeCom integration currently cannot be easily set up by most users. We are actively exploring alternative approaches — including WeCom ISV (service provider) registration and lower-friction API options — or we may advocate for WeCom to relax these restrictions for SaaS platforms. Configuration will be re-enabled once a viable path is available." - } + }, + "avatarUrlField": "Avatar URL Field", + "avatarUrlFieldPlaceholder": "Default: picture", + "authorizeUrlPlaceholder": "https://sso.example.com/oauth2/authorize", + "clearAllMappings": "Clear All Field Mappings", + "clearField": "Clear this field", + "deleteConfirmProvider": "Are you sure you want to delete this configuration?", + "emailField": "Email Field", + "emailFieldPlaceholder": "Default: email", + "fieldMapping": "Field Mapping", + "fieldMappingHint": "Optional, leave empty to use standard OIDC fields", + "hasCustomMapping": "Custom field mapping configured", + "mobileField": "Mobile Field", + "mobileFieldPlaceholder": "Default: phone_number", + "nameField": "Name Field", + "nameFieldPlaceholder": "Default: name", + "oauth2": "OAuth2", + "oauth2Desc": "Generic OIDC Provider", + "scope": "Scope", + "scopePlaceholder": "openid profile email", + "tokenUrlPlaceholder": "Leave empty to auto-derive from Authorize URL", + "userIdField": "User ID Field", + "userIdFieldPlaceholder": "Default: sub", + "userInfoUrlPlaceholder": "Leave empty to auto-derive from Authorize URL" }, "dangerZone": "️ Danger Zone", "deleteCompanyDesc": "Permanently delete this company and all its data, including agents, models, tools, and skills. This action cannot be undone.", diff --git a/frontend/src/i18n/zh.json b/frontend/src/i18n/zh.json index ee020b0e1..2e7e358a4 100644 --- a/frontend/src/i18n/zh.json +++ b/frontend/src/i18n/zh.json @@ -1525,7 +1525,30 @@ "messagingTitle": "3. AI 主动发送一对一消息", "messagingDesc": "向企微成员主动发消息,需先获取有效的 API access_token,而获取 token 的请求 IP 必须在自建应用的「企业可信IP」白名单中。与飞书不同,企微对所有 API 调用均强制要求 IP 白名单,没有仅凭 token 认证的选项。", "footerText": "由于上述限制,目前大多数用户无法轻松完成企业微信集成配置。我们正在积极探索替代方案——包括申请企微服务商(ISV)资质及寻找限制更少的接入方式——同时也希望企微能够降低对 SaaS 平台的接入门槛。待可行路径明确后,配置入口将重新开启。" - } + }, + "avatarUrlField": "头像地址字段", + "avatarUrlFieldPlaceholder": "默认: picture", + "authorizeUrlPlaceholder": "https://sso.example.com/oauth2/authorize", + "clearAllMappings": "清空所有字段映射", + "clearField": "清空此字段", + "deleteConfirmProvider": "确定要删除此配置吗?", + "emailField": "邮箱字段", + "emailFieldPlaceholder": "默认:email", + "fieldMapping": "字段映射", + "fieldMappingHint": "可选,留空使用标准 OIDC 字段", + "hasCustomMapping": "当前已配置自定义字段映射", + "mobileField": "手机号字段", + "mobileFieldPlaceholder": "默认:phone_number", + "nameField": "姓名字段", + "nameFieldPlaceholder": "默认:name", + "oauth2": "OAuth2", + "oauth2Desc": "通用 OIDC 提供商", + "scope": "Scope", + "scopePlaceholder": "openid profile email", + "tokenUrlPlaceholder": "留空则从授权地址自动推导", + "userIdField": "用户 ID 字段", + "userIdFieldPlaceholder": "默认:sub", + "userInfoUrlPlaceholder": "留空则从授权地址自动推导" } }, "common": { diff --git a/frontend/src/pages/EnterpriseSettings.tsx b/frontend/src/pages/EnterpriseSettings.tsx index f9ade6332..7d5ec4956 100644 --- a/frontend/src/pages/EnterpriseSettings.tsx +++ b/frontend/src/pages/EnterpriseSettings.tsx @@ -419,7 +419,8 @@ function OrgTab({ tenant }: { tenant: any }) { authorize_url: '', token_url: '', user_info_url: '', - scope: 'openid profile email' + scope: 'openid profile email', + field_mapping: {} as Record, }); const currentTenantId = localStorage.getItem('current_tenant_id') || ''; @@ -522,7 +523,8 @@ function OrgTab({ tenant }: { tenant: any }) { authorize_url: config?.authorize_url || '', token_url: config?.token_url || '', user_info_url: config?.user_info_url || '', - scope: config?.scope || 'openid profile email' + scope: config?.scope || 'openid profile email', + field_mapping: config?.field_mapping || {}, }); const save = () => { @@ -565,7 +567,8 @@ function OrgTab({ tenant }: { tenant: any }) { name: nameMap[type] || type, config: defaults[type] || {}, app_id: '', app_secret: '', authorize_url: '', token_url: '', user_info_url: '', - scope: 'openid profile email' + scope: 'openid profile email', + field_mapping: {}, }); } setSelectedDept(null); @@ -650,8 +653,57 @@ function OrgTab({ tenant }: { tenant: any }) { setForm({ ...form, app_secret: e.target.value })} />
- - setForm({ ...form, authorize_url: e.target.value })} /> + + setForm({ ...form, authorize_url: e.target.value })} placeholder={t('enterprise.identity.authorizeUrlPlaceholder', 'https://sso.example.com/oauth2/authorize')} /> +
+
+ + setForm({ ...form, token_url: e.target.value })} placeholder={t('enterprise.identity.tokenUrlPlaceholder', 'Leave empty to auto-derive from Authorize URL')} /> +
+
+ + setForm({ ...form, user_info_url: e.target.value })} placeholder={t('enterprise.identity.userInfoUrlPlaceholder', 'Leave empty to auto-derive from Authorize URL')} /> +
+
+ + setForm({ ...form, scope: e.target.value })} placeholder={t('enterprise.identity.scopePlaceholder', 'openid profile email')} /> +
+
+
+ {t('enterprise.identity.fieldMapping', 'Field Mapping')} ({t('enterprise.identity.fieldMappingHint', 'Optional, leave empty to use standard OIDC fields')}) +
+
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, provider_user_id: e.target.value } })} + placeholder={t('enterprise.identity.userIdFieldPlaceholder', 'Default: sub')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, display_name: e.target.value } })} + placeholder={t('enterprise.identity.nameFieldPlaceholder', 'Default: name')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, email: e.target.value } })} + placeholder={t('enterprise.identity.emailFieldPlaceholder', 'Default: email')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, mobile: e.target.value } })} + placeholder={t('enterprise.identity.mobileFieldPlaceholder', 'Default: phone_number')} style={{ fontSize: '12px' }} /> +
+
+ + setForm({ ...form, field_mapping: { ...form.field_mapping, avatar_url: e.target.value } })} + placeholder={t('enterprise.identity.avatarUrlFieldPlaceholder', 'Default: picture')} style={{ fontSize: '12px' }} /> +
+
) : type === 'wecom' ? (