diff --git a/src/drs/auth.py b/src/drs/auth.py index 051f365..68ea573 100644 --- a/src/drs/auth.py +++ b/src/drs/auth.py @@ -19,10 +19,10 @@ import os from pathlib import Path -from typing import Any +from typing import Any, Literal import yaml -from pydantic import BaseModel +from pydantic import BaseModel, model_validator DEFAULT_CONFIG_PATH = Path.home() / ".config" / "dremioai" / "config.yaml" DEFAULT_URI = "https://api.dremio.cloud" @@ -30,8 +30,16 @@ class DrsConfig(BaseModel): uri: str = DEFAULT_URI - pat: str + pat: str | None = None project_id: str + auth_method: Literal["pat", "oauth"] = "pat" + oauth_access_token: str | None = None + + @model_validator(mode="after") + def _require_credential(self) -> "DrsConfig": + if not self.pat and not self.oauth_access_token: + raise ValueError("Either 'pat' or 'oauth_access_token' must be provided.") + return self def load_config( @@ -58,6 +66,7 @@ def load_config( "uri": raw.get("uri", raw.get("endpoint")), "pat": raw.get("pat", raw.get("token")), "project_id": raw.get("project_id", raw.get("projectId")), + "auth_method": raw.get("auth_method"), } file_values = {k: v for k, v in file_values.items() if v is not None} @@ -83,4 +92,15 @@ def load_config( if cli_token: merged["pat"] = cli_token + # If no PAT is available, try loading OAuth tokens from the token store. + if "pat" not in merged or not merged["pat"]: + from drs import token_store + + uri = merged.get("uri", DEFAULT_URI) + tokens = token_store.load(uri) + if tokens is not None: + merged["auth_method"] = "oauth" + merged["oauth_access_token"] = tokens.access_token + merged.pop("pat", None) + return DrsConfig(**merged) diff --git a/src/drs/cli.py b/src/drs/cli.py index 8a86e4d..34ece06 100644 --- a/src/drs/cli.py +++ b/src/drs/cli.py @@ -34,6 +34,7 @@ folder, grant, job, + login, project, query, reflection, @@ -69,6 +70,8 @@ app.add_typer(project.app, name="project") app.add_typer(chat.app, name="chat") app.command("setup")(setup.setup_command) +app.command("login")(login.login_command) +app.command("logout")(login.logout_command) # Global state for config _config: DrsConfig | None = None @@ -137,8 +140,9 @@ def get_config() -> DrsConfig: Console(stderr=True).print( "\n[bold red]Configuration required[/bold red]\n\n" - "The Dremio CLI needs a Personal Access Token and Project ID.\n\n" - " [bold]Quick setup:[/] Run [bold cyan]dremio setup[/bold cyan]\n\n" + "The Dremio CLI needs authentication credentials and a Project ID.\n\n" + " [bold]Quick setup:[/] Run [bold cyan]dremio setup[/bold cyan]\n" + " [bold]OAuth login:[/] Run [bold cyan]dremio login[/bold cyan]\n\n" " [dim]Or provide credentials manually:[/dim]\n" " --token / DREMIO_TOKEN env var\n" " --project-id / DREMIO_PROJECT_ID env var\n" diff --git a/src/drs/client.py b/src/drs/client.py index e1ab86a..10d2a1d 100644 --- a/src/drs/client.py +++ b/src/drs/client.py @@ -44,9 +44,10 @@ class DremioClient: def __init__(self, config: DrsConfig) -> None: self.config = config + token = config.oauth_access_token if "oauth" == config.auth_method else config.pat self._client = httpx.AsyncClient( headers={ - "Authorization": f"Bearer {config.pat}", + "Authorization": f"Bearer {token}", "Content-Type": "application/json", }, timeout=120.0, @@ -70,12 +71,46 @@ def _v1(self, path: str) -> str: # -- HTTP helpers with retry -- + async def _refresh_oauth_token(self) -> None: + """Refresh the OAuth access token using the stored refresh token.""" + from drs import oauth, token_store + + tokens = token_store.load(self.config.uri) + if tokens is None or tokens.refresh_token is None: + raise RuntimeError("No refresh token available — please run 'dremio login' again.") + + metadata = oauth.discover(self.config.uri) + new_tokens = oauth.refresh_access_token( + metadata.token_endpoint, + tokens.client_id, + tokens.client_secret, + tokens.refresh_token, + ) + token_store.save(self.config.uri, new_tokens) + self._client.headers["authorization"] = f"Bearer {new_tokens.access_token}" + logger.info("OAuth access token refreshed successfully.") + async def _request_with_retry(self, method: str, url: str, **kwargs: Any) -> httpx.Response: """Execute an HTTP request with retry on transient errors.""" + auth_refreshed = False last_exc: Exception | None = None for attempt in range(_MAX_RETRIES): try: resp = await self._client.request(method, url, **kwargs) + # 401 with OAuth: attempt one token refresh per call + if ( + resp.status_code == 401 + and "oauth" == self.config.auth_method + and not auth_refreshed + ): + logger.info("Received 401 — attempting OAuth token refresh.") + try: + await self._refresh_oauth_token() + except Exception: + logger.warning("OAuth token refresh failed.", exc_info=True) + return resp + auth_refreshed = True + continue if resp.status_code in _RETRYABLE_STATUS_CODES and attempt < _MAX_RETRIES - 1: delay = _RETRY_BACKOFF[attempt] logger.warning( diff --git a/src/drs/commands/login.py b/src/drs/commands/login.py new file mode 100644 index 0000000..35dbacb --- /dev/null +++ b/src/drs/commands/login.py @@ -0,0 +1,216 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""``dremio login`` and ``dremio logout`` commands — OAuth browser flow.""" + +from __future__ import annotations + +from pathlib import Path +from urllib.parse import urlparse + +import httpx +import typer +import yaml +from rich.console import Console +from rich.panel import Panel + +from drs import oauth, token_store +from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI + +console = Console() +err_console = Console(stderr=True) + + +def _resolve_uri(ctx: typer.Context, explicit_uri: str | None = None) -> str: + """Resolve the Dremio API URI from explicit arg, CLI flags, config file, or default.""" + if explicit_uri: + return explicit_uri + + # Check if parent set cli_uri (global --uri flag) + from drs.cli import _cli_opts + + cli_uri = _cli_opts.get("cli_uri") + if cli_uri: + return cli_uri + + # Try config file + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None + path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH + if path.exists(): + with path.open() as f: + raw = yaml.safe_load(f) or {} + uri = raw.get("uri", raw.get("endpoint")) + if uri: + return uri + + return DEFAULT_URI + + +def _api_url(uri: str) -> str: + """Derive the API URL from a Dremio URL (app.X -> api.X).""" + parsed = urlparse(uri) + host = parsed.hostname or "" + if host.startswith("app."): + host = "api." + host[4:] + return f"{parsed.scheme}://{host}" + + +_ACTIVE_STATES = {"ACTIVE", "HIBERNATED"} + + +def _fetch_projects(api_base: str, access_token: str) -> list[dict]: + """Fetch active/hibernated projects using the OAuth access token.""" + resp = httpx.get( + f"{api_base}/v0/projects", + headers={"Authorization": f"Bearer {access_token}"}, + timeout=30.0, + ) + resp.raise_for_status() + data = resp.json() + projects = data.get("data", data) if isinstance(data, dict) else data + return [p for p in projects if p.get("state", "").upper() in _ACTIVE_STATES] + + +def _format_date(raw: str | None) -> str: + """Format an ISO timestamp to a short date string.""" + if not raw: + return "" + return raw[:10] # YYYY-MM-DD + + +def _prompt_project_selection(projects: list[dict]) -> str: + """Display a numbered list of projects and let the user choose.""" + console.print() + lines = "[bold]Select a project:[/bold]\n" + for i, proj in enumerate(projects, 1): + name = proj.get("name", "unnamed") + pid = proj.get("id", "") + desc = proj.get("description", "") + state = proj.get("state", "") + created = _format_date(proj.get("createdAt")) + lines += f"\n [cyan]{i}[/cyan]) [bold]{name}[/bold] [dim]({pid})[/dim]" + if desc: + lines += f"\n {desc}" + details = [s for s in [state, f"created {created}" if created else ""] if s] + if details: + lines += f"\n [dim]{' · '.join(details)}[/dim]" + console.print(Panel(lines, title="Projects", border_style="blue")) + choice = typer.prompt(f"Enter 1-{len(projects)}").strip() + try: + idx = int(choice) - 1 + if 0 <= idx < len(projects): + selected = projects[idx] + console.print(f" -> [bold]{selected.get('name')}[/bold]") + return selected["id"] + except (ValueError, KeyError): + pass + err_console.print("[yellow]Invalid choice — please enter a project ID manually.[/yellow]") + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + +def _resolve_project_id(ctx: typer.Context, uri: str, access_token: str) -> str: + """Resolve project_id from CLI flags, config, or interactive project picker.""" + from drs.cli import _cli_opts + + cli_project_id = _cli_opts.get("cli_project_id") + if cli_project_id: + return cli_project_id + + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None + path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH + if path.exists(): + with path.open() as f: + raw = yaml.safe_load(f) or {} + project_id = raw.get("project_id", raw.get("projectId")) + if project_id: + return project_id + + # Fetch projects and let the user pick + api_base = _api_url(uri) + try: + projects = _fetch_projects(api_base, access_token) + except Exception: + console.print("[yellow]Could not fetch project list.[/yellow]") + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + if not projects: + console.print("[yellow]No projects found in this organization.[/yellow]") + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + if len(projects) == 1: + proj = projects[0] + console.print(f" Auto-selected project: [bold]{proj.get('name')}[/bold] ({proj['id']})") + return proj["id"] + + return _prompt_project_selection(projects) + + +def login_command( + ctx: typer.Context, + uri: str = typer.Option(None, "--uri", "-u", help="Dremio API URL (e.g. https://app.dev.dremio.site)"), +) -> None: + """Log in to Dremio Cloud via OAuth (opens your browser).""" + uri = _resolve_uri(ctx, explicit_uri=uri) + console.print(f"\nLogging in to [bold]{uri}[/bold] ...") + + try: + tokens = oauth.run_login_flow(uri) + except Exception as exc: + err_console.print(f"\n[bold red]Login failed:[/bold red] {exc}") + raise typer.Exit(1) + + # Ensure we have a project_id to write into the config + project_id = _resolve_project_id(ctx, uri, tokens.access_token) + + token_store.save(uri, tokens) + + # Also persist auth_method + project_id in config file so subsequent + # commands pick up OAuth automatically. + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None + config_path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH + _update_config_file(config_path, uri, project_id) + + console.print(f"\n[green]Logged in successfully.[/green] Tokens saved for {uri}") + + +def _update_config_file(config_path: Path, uri: str, project_id: str) -> None: + """Ensure the config file records auth_method=oauth and project_id.""" + data: dict = {} + if config_path.exists(): + with config_path.open() as f: + data = yaml.safe_load(f) or {} + + if uri != DEFAULT_URI: + data["uri"] = uri + data["auth_method"] = "oauth" + data["project_id"] = project_id + # Remove PAT if present — OAuth replaces it. + data.pop("pat", None) + data.pop("token", None) + + config_path.parent.mkdir(parents=True, exist_ok=True) + header = "# Dremio CLI config — generated by 'dremio login'\n" + config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False)) + config_path.chmod(0o600) + + +def logout_command( + ctx: typer.Context, + uri: str = typer.Option(None, "--uri", "-u", help="Dremio API URL to log out from"), +) -> None: + """Log out of Dremio Cloud (removes stored OAuth tokens).""" + uri = _resolve_uri(ctx, explicit_uri=uri) + token_store.clear(uri) + console.print(f"Logged out of [bold]{uri}[/bold]. OAuth tokens removed.") diff --git a/src/drs/commands/setup.py b/src/drs/commands/setup.py index fb77ea2..9775e78 100644 --- a/src/drs/commands/setup.py +++ b/src/drs/commands/setup.py @@ -32,6 +32,11 @@ from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI, DrsConfig from drs.client import DremioClient +AUTH_METHODS = { + "1": "oauth", + "2": "pat", +} + REGIONS = { "1": ("US", "https://api.dremio.cloud", "https://app.dremio.cloud"), "2": ("EU", "https://api.eu.dremio.cloud", "https://app.eu.dremio.cloud"), @@ -65,16 +70,25 @@ async def validate_credentials(uri: str, pat: str, project_id: str) -> tuple[boo await client.close() -def write_config(uri: str, pat: str, project_id: str, config_path: Path) -> None: +def write_config( + uri: str, + pat: str | None, + project_id: str, + config_path: Path, + auth_method: str = "pat", +) -> None: """Write YAML config file, creating parent directories as needed.""" data: dict[str, str] = {} if uri != DEFAULT_URI: data["uri"] = uri - data["pat"] = pat + if auth_method != "pat": + data["auth_method"] = auth_method + if pat: + data["pat"] = pat data["project_id"] = project_id config_path.parent.mkdir(parents=True, exist_ok=True) - header = "# Dremio CLI config — generated by 'dremio setup'\n# PAT is stored in plaintext. Keep this file private (mode 600).\n" + header = "# Dremio CLI config — generated by 'dremio setup'\n# Keep this file private (mode 600).\n" config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False)) config_path.chmod(0o600) @@ -100,12 +114,33 @@ def _prompt_region() -> tuple[str, str]: return api_uri, app_url +def _prompt_auth_method() -> str: + """Prompt for authentication method. Returns 'oauth' or 'pat'.""" + console.print() + console.print( + Panel( + "[bold]Step 2: Choose authentication method[/bold]\n\n" + " [cyan]1[/cyan]) OAuth (browser login) — recommended\n" + " [cyan]2[/cyan]) PAT (manual personal access token)", + title="Authentication", + border_style="blue", + ) + ) + choice = typer.prompt("Enter 1 or 2", default="1").strip() + if choice not in AUTH_METHODS: + console.print("[yellow]Invalid choice, defaulting to OAuth.[/yellow]") + choice = "1" + method = AUTH_METHODS[choice] + console.print(f" -> Auth method: [bold]{method.upper()}[/bold]") + return method + + def _prompt_pat(app_url: str) -> str: """Prompt for Personal Access Token with step-by-step instructions.""" console.print() console.print( Panel( - "[bold]Step 2: Create a Personal Access Token (PAT)[/bold]\n\n" + "[bold]Step 3: Create a Personal Access Token (PAT)[/bold]\n\n" f" 1. Open [link={app_url}]{app_url}[/link] and sign in\n" " 2. Click your profile icon (bottom-left) → [bold]Account Settings[/bold]\n" " 3. Go to [bold]Personal Access Tokens[/bold]\n" @@ -127,7 +162,7 @@ def _prompt_project_id(app_url: str) -> str: console.print() console.print( Panel( - "[bold]Step 3: Find your Project ID[/bold]\n\n" + "[bold]Find your Project ID[/bold]\n\n" f" 1. Open [link={app_url}]{app_url}[/link]\n" " 2. Select your project from the top-left dropdown\n" " 3. Go to [bold]Project Settings[/bold] → [bold]General[/bold]\n" @@ -174,7 +209,7 @@ def setup_command( "This wizard will help you connect the Dremio CLI to your Dremio Cloud account.\n\n" "You'll need:\n" " • A [bold]Dremio Cloud account[/bold] (sign up at [link=https://app.dremio.cloud]app.dremio.cloud[/link])\n" - " • A [bold]Personal Access Token[/bold] (we'll walk you through creating one)\n" + " • A [bold]Personal Access Token[/bold] or [bold]OAuth browser login[/bold]\n" " • A [bold]Project ID[/bold] (we'll show you where to find it)", title="[bold]Dremio CLI Setup[/bold]", border_style="cyan", @@ -191,53 +226,79 @@ def setup_command( # Step 1: Region api_uri, app_url = _prompt_region() - # Step 2: PAT (with retry loop) - pat = _prompt_pat(app_url) + # Step 2: Auth method + auth_method = _prompt_auth_method() - # Step 3: Project ID (with retry loop) - project_id = _prompt_project_id(app_url) + pat: str | None = None + project_name: str = "" - # Validate - console.print() - with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): - ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) + if auth_method == "oauth": + # OAuth flow — browser login, then project ID + from drs import oauth, token_store - while not ok: - console.print(f"\n[red]✗ {message}[/red]") - if not typer.confirm("Would you like to try again?", default=True): - console.print("Setup cancelled.") + console.print("\n[bold]Starting OAuth browser login...[/bold]") + try: + tokens = oauth.run_login_flow(api_uri) + except Exception as exc: + err_console.print(f"\n[bold red]OAuth login failed:[/bold red] {exc}") raise typer.Exit(1) - if "Authentication" in message: - console.print("[dim]Let's try the PAT again.[/dim]") - pat = _prompt_pat(app_url) - elif "Access denied" in message: - console.print( - "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID" - ) - choice = typer.prompt("Which would you like to fix?", default="1").strip() - if choice == "2": - project_id = _prompt_project_id(app_url) - else: - pat = _prompt_pat(app_url) - elif "Project" in message: - console.print("[dim]Let's try the Project ID again.[/dim]") - project_id = _prompt_project_id(app_url) - else: - console.print("[dim]Let's try the region again.[/dim]") - api_uri, app_url = _prompt_region() - pat = _prompt_pat(app_url) - project_id = _prompt_project_id(app_url) + token_store.save(api_uri, tokens) + console.print("[green]OAuth login successful.[/green]") + + # Step 3: Project ID + project_id = _prompt_project_id(app_url) + project_name = project_id + write_config(api_uri, None, project_id, config_path, auth_method="oauth") + else: + # PAT flow (existing behavior) + pat = _prompt_pat(app_url) + + # Step 3: Project ID (with retry loop) + project_id = _prompt_project_id(app_url) + + # Validate console.print() with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) - # Success — write config - project_name = project_data.get("name", project_id) if project_data else project_id - console.print(f"\n[green]✓ {message}[/green]") + while not ok: + console.print(f"\n[red]x {message}[/red]") + if not typer.confirm("Would you like to try again?", default=True): + console.print("Setup cancelled.") + raise typer.Exit(1) + + if "Authentication" in message: + console.print("[dim]Let's try the PAT again.[/dim]") + pat = _prompt_pat(app_url) + elif "Access denied" in message: + console.print( + "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID" + ) + choice = typer.prompt("Which would you like to fix?", default="1").strip() + if choice == "2": + project_id = _prompt_project_id(app_url) + else: + pat = _prompt_pat(app_url) + elif "Project" in message: + console.print("[dim]Let's try the Project ID again.[/dim]") + project_id = _prompt_project_id(app_url) + else: + console.print("[dim]Let's try the region again.[/dim]") + api_uri, app_url = _prompt_region() + pat = _prompt_pat(app_url) + project_id = _prompt_project_id(app_url) + + console.print() + with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): + ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) + + # Success + project_name = project_data.get("name", project_id) if project_data else project_id + console.print(f"\n[green]v {message}[/green]") - write_config(api_uri, pat, project_id, config_path) + write_config(api_uri, pat, project_id, config_path) console.print() success = Text() diff --git a/src/drs/oauth.py b/src/drs/oauth.py new file mode 100644 index 0000000..c78ecf9 --- /dev/null +++ b/src/drs/oauth.py @@ -0,0 +1,292 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""OAuth 2.0 PKCE browser-login flow for Dremio Cloud.""" + +from __future__ import annotations + +import base64 +import hashlib +import logging +import secrets +import socket +import threading +import time +import webbrowser +from concurrent.futures import Future +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import parse_qs, urlencode, urlparse + +import httpx + +from drs.token_store import OAuthTokens + +logger = logging.getLogger(__name__) + +_CLIENT_ID = "https://connectors.dremio.app/claude" + +_SUCCESS_HTML = """\ + +Dremio CLI + +

Login successful

+

You can close this tab and return to the terminal.

+ +""" + + +@dataclass +class OAuthServerMetadata: + authorization_endpoint: str + token_endpoint: str + registration_endpoint: str | None = None + + +def _login_url(dremio_url: str) -> str: + """Derive the OAuth login host from a Dremio URL (app.X -> login.X).""" + parsed = urlparse(dremio_url) + host = parsed.hostname or "" + if host.startswith("app."): + host = "login." + host[4:] + return f"{parsed.scheme}://{host}" + + +def discover(dremio_url: str) -> OAuthServerMetadata: + """Fetch OAuth Authorization Server Metadata from *dremio_url*.""" + base = _login_url(dremio_url) + url = f"{base}/.well-known/oauth-authorization-server" + resp = httpx.get(url, timeout=30.0) + resp.raise_for_status() + data = resp.json() + return OAuthServerMetadata( + authorization_endpoint=data["authorization_endpoint"], + token_endpoint=data["token_endpoint"], + registration_endpoint=data.get("registration_endpoint"), + ) + + +def register_client(registration_endpoint: str, redirect_uri: str) -> tuple[str, str | None]: + """Dynamic Client Registration. Returns ``(client_id, client_secret)``. + + Returns ``None`` if the server does not support DCR (403/400). + """ + body = { + "client_name": _CLIENT_ID, + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + } + resp = httpx.post(registration_endpoint, json=body, timeout=30.0) + if resp.status_code in (400, 403): + logger.info("DCR not available (%s) — using well-known client_id.", resp.status_code) + return None, None + resp.raise_for_status() + data = resp.json() + return data["client_id"], data.get("client_secret") + + +def generate_pkce() -> tuple[str, str]: + """Generate PKCE ``(code_verifier, code_challenge)`` pair.""" + code_verifier = secrets.token_urlsafe(32) + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return code_verifier, code_challenge + + +def build_authorization_url( + auth_endpoint: str, + client_id: str, + redirect_uri: str, + code_challenge: str, + state: str, +) -> str: + """Construct the authorization URL the user's browser should open.""" + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "scope": "dremio.all offline_access", + "state": state, + } + return f"{auth_endpoint}?{urlencode(params)}" + + +def find_free_port() -> int: + """Bind to port 0 to obtain a free ephemeral port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def start_callback_server(port: int) -> tuple[HTTPServer, Future[tuple[str, str]]]: + """Start a localhost HTTP server that captures the OAuth callback. + + Returns ``(server, future)`` where the future resolves to ``(code, state)``. + """ + future: Future[tuple[str, str]] = Future() + + class _Handler(BaseHTTPRequestHandler): + def do_GET(self) -> None: # noqa: N802 + qs = parse_qs(urlparse(self.path).query) + code = qs.get("code", [None])[0] + state = qs.get("state", [None])[0] + error = qs.get("error", [None])[0] + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + + if code and state: + self.wfile.write(_SUCCESS_HTML.encode()) + future.set_result((code, state)) + else: + error_msg = error or "unknown" + self.wfile.write(f"

Login failed: {error_msg}

".encode()) + if not future.done(): + future.set_exception(RuntimeError(f"OAuth callback error: {error_msg}")) + + def log_message(self, format: str, *args: object) -> None: # noqa: A002 + logger.debug(format, *args) + + server = HTTPServer(("127.0.0.1", port), _Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, future + + +def exchange_code( + token_endpoint: str, + code: str, + redirect_uri: str, + client_id: str, + client_secret: str | None, + code_verifier: str, +) -> OAuthTokens: + """Exchange an authorization code for tokens.""" + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "code_verifier": code_verifier, + } + if client_secret: + data["client_secret"] = client_secret + resp = httpx.post(token_endpoint, data=data, timeout=30.0) + resp.raise_for_status() + body = resp.json() + expires_at: float | None = None + if "expires_in" in body: + expires_at = time.time() + body["expires_in"] + return OAuthTokens( + access_token=body["access_token"], + refresh_token=body.get("refresh_token"), + expires_at=expires_at, + client_id=client_id, + client_secret=client_secret, + ) + + +def refresh_access_token( + token_endpoint: str, + client_id: str, + client_secret: str | None, + refresh_token: str, +) -> OAuthTokens: + """Use a refresh token to obtain a new access token.""" + data = { + "grant_type": "refresh_token", + "client_id": client_id, + "refresh_token": refresh_token, + } + if client_secret: + data["client_secret"] = client_secret + resp = httpx.post(token_endpoint, data=data, timeout=30.0) + resp.raise_for_status() + body = resp.json() + expires_at: float | None = None + if "expires_in" in body: + expires_at = time.time() + body["expires_in"] + return OAuthTokens( + access_token=body["access_token"], + refresh_token=body.get("refresh_token", refresh_token), + expires_at=expires_at, + client_id=client_id, + client_secret=client_secret, + ) + + +def run_login_flow(dremio_url: str) -> OAuthTokens: + """Orchestrate the full browser-based OAuth login flow. + + 1. Discover OAuth endpoints + 2. Find a free port & register the client (DCR) + 3. Generate PKCE codes + 4. Start localhost callback server + 5. Open the browser (fall back to printing URL) + 6. Wait for the callback + 7. Exchange the code for tokens + """ + metadata = discover(dremio_url) + + port = find_free_port() + redirect_uri = f"http://localhost:{port}/Callback" + + client_id, client_secret = None, None + if metadata.registration_endpoint: + client_id, client_secret = register_client(metadata.registration_endpoint, redirect_uri) + if not client_id: + client_id, client_secret = _CLIENT_ID, None + + code_verifier, code_challenge = generate_pkce() + state = secrets.token_urlsafe(16) + + auth_url = build_authorization_url( + metadata.authorization_endpoint, + client_id, + redirect_uri, + code_challenge, + state, + ) + + server, future = start_callback_server(port) + try: + try: + opened = webbrowser.open(auth_url) + except Exception: + opened = False + if opened: + logger.info("Opened browser for OAuth login.") + else: + # Headless / no-browser fallback + print(f"\nOpen this URL in your browser to log in:\n\n {auth_url}\n") + + code, returned_state = future.result(timeout=300) + if returned_state != state: + raise RuntimeError("OAuth state mismatch — possible CSRF attack.") + finally: + server.shutdown() + + return exchange_code( + metadata.token_endpoint, + code, + redirect_uri, + client_id, + client_secret, + code_verifier, + ) diff --git a/src/drs/token_store.py b/src/drs/token_store.py new file mode 100644 index 0000000..8d8638a --- /dev/null +++ b/src/drs/token_store.py @@ -0,0 +1,70 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Persistent storage for OAuth tokens, keyed by Dremio URL.""" + +from __future__ import annotations + +from pathlib import Path + +import yaml +from pydantic import BaseModel + +STORE_PATH = Path.home() / ".config" / "dremioai" / "oauth_tokens.yaml" + + +class OAuthTokens(BaseModel): + access_token: str + refresh_token: str | None = None + expires_at: float | None = None + client_id: str + client_secret: str | None = None + + +def load(dremio_url: str, store_path: Path = STORE_PATH) -> OAuthTokens | None: + """Read stored OAuth tokens for *dremio_url*. Returns ``None`` if absent.""" + if not store_path.exists(): + return None + with store_path.open() as f: + data = yaml.safe_load(f) or {} + entry = data.get(dremio_url) + if entry is None: + return None + return OAuthTokens(**entry) + + +def save(dremio_url: str, tokens: OAuthTokens, store_path: Path = STORE_PATH) -> None: + """Persist *tokens* under *dremio_url*. Creates dirs and sets mode 600.""" + store_path.parent.mkdir(parents=True, exist_ok=True) + + data: dict = {} + if store_path.exists(): + with store_path.open() as f: + data = yaml.safe_load(f) or {} + + data[dremio_url] = tokens.model_dump(exclude_none=True) + store_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False)) + store_path.chmod(0o600) + + +def clear(dremio_url: str, store_path: Path = STORE_PATH) -> None: + """Remove stored tokens for *dremio_url*.""" + if not store_path.exists(): + return + with store_path.open() as f: + data = yaml.safe_load(f) or {} + data.pop(dremio_url, None) + store_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False)) + store_path.chmod(0o600) diff --git a/tests/test_auth.py b/tests/test_auth.py index 9187f81..348de25 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,8 @@ import yaml from pydantic import ValidationError -from drs.auth import load_config +from drs.auth import DrsConfig, load_config +from drs.token_store import OAuthTokens def test_config_from_env_vars(tmp_path: Path) -> None: @@ -137,6 +138,55 @@ def test_dremio_token_overrides_dremio_pat(tmp_path: Path) -> None: assert config.pat == "new-token" +def test_config_oauth_from_token_store(tmp_path: Path) -> None: + """When no PAT is available but OAuth tokens exist, auth_method should be 'oauth'.""" + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"project_id": "proj-1"})) + + fake_tokens = OAuthTokens(access_token="oauth-at", refresh_token="oauth-rt", client_id="cid") + + with ( + patch.dict(os.environ, {}, clear=False), + patch("drs.token_store.load", return_value=fake_tokens), + ): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.auth_method == "oauth" + assert config.oauth_access_token == "oauth-at" + assert config.pat is None + + +def test_config_pat_still_works(tmp_path: Path) -> None: + """PAT auth should continue to work exactly as before.""" + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"pat": "my-pat", "project_id": "proj-1"})) + + with patch.dict(os.environ, {}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.auth_method == "pat" + assert config.pat == "my-pat" + + +def test_config_neither_pat_nor_oauth_fails(tmp_path: Path) -> None: + """Missing both PAT and OAuth tokens should raise ValidationError.""" + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"project_id": "proj-1"})) + + with ( + patch.dict(os.environ, {}, clear=False), + patch("drs.token_store.load", return_value=None), + pytest.raises(ValidationError), + ): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + load_config(config_file) + + def test_cli_args_override_env(tmp_path: Path) -> None: """CLI args should override env vars and file values.""" config_file = tmp_path / "config.yaml" diff --git a/tests/test_client_retry.py b/tests/test_client_retry.py index 485229b..3a29222 100644 --- a/tests/test_client_retry.py +++ b/tests/test_client_retry.py @@ -22,6 +22,7 @@ import httpx import pytest +from drs.auth import DrsConfig from drs.client import DremioClient @@ -154,3 +155,87 @@ async def test_retry_backoff_delays(config) -> None: assert mock_sleep.call_count == 2 mock_sleep.assert_any_call(1.0) mock_sleep.assert_any_call(2.0) + + +@pytest.mark.asyncio +async def test_401_triggers_oauth_refresh() -> None: + """401 with OAuth auth should attempt token refresh and retry.""" + oauth_config = DrsConfig( + uri="https://api.dremio.cloud", + oauth_access_token="old-token", + project_id="proj", + auth_method="oauth", + ) + client = DremioClient(oauth_config) + + unauthorized = httpx.Response(401, request=httpx.Request("GET", "https://example.com")) + ok_response = httpx.Response(200, json={"ok": True}, request=httpx.Request("GET", "https://example.com")) + + client._client.request = AsyncMock(side_effect=[unauthorized, ok_response]) + + with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock) as mock_refresh: + result = await client._get("https://example.com/test") + + assert result == {"ok": True} + mock_refresh.assert_awaited_once() + assert client._client.request.call_count == 2 + + +@pytest.mark.asyncio +async def test_401_no_refresh_for_pat(config) -> None: + """401 with PAT auth should NOT attempt token refresh.""" + client = DremioClient(config) + + unauthorized = httpx.Response(401, json={"error": "unauthorized"}, request=httpx.Request("GET", "https://example.com")) + client._client.request = AsyncMock(return_value=unauthorized) + + with pytest.raises(httpx.HTTPStatusError): + await client._get("https://example.com/test") + + assert client._client.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_401_refresh_failure_returns_original_response() -> None: + """When refresh fails, the original 401 response should be returned.""" + oauth_config = DrsConfig( + uri="https://api.dremio.cloud", + oauth_access_token="old-token", + project_id="proj", + auth_method="oauth", + ) + client = DremioClient(oauth_config) + + unauthorized = httpx.Response(401, json={"error": "bad token"}, request=httpx.Request("GET", "https://example.com")) + client._client.request = AsyncMock(return_value=unauthorized) + + with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock, side_effect=RuntimeError("no refresh")): + with pytest.raises(httpx.HTTPStatusError): + await client._get("https://example.com/test") + + # Should only have tried once (no retry after failed refresh) + assert client._client.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_401_only_one_refresh_per_request() -> None: + """Only one refresh attempt per request — second 401 should not retry.""" + oauth_config = DrsConfig( + uri="https://api.dremio.cloud", + oauth_access_token="old-token", + project_id="proj", + auth_method="oauth", + ) + client = DremioClient(oauth_config) + + unauthorized = httpx.Response(401, request=httpx.Request("GET", "https://example.com")) + + # Both attempts return 401 + client._client.request = AsyncMock(return_value=unauthorized) + + with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock): + with pytest.raises(httpx.HTTPStatusError): + await client._get("https://example.com/test") + + # 1st call -> 401 -> refresh -> 2nd call -> 401 -> return (no more refresh) + assert client._client.request.call_count == 2 diff --git a/tests/test_commands/test_login.py b/tests/test_commands/test_login.py new file mode 100644 index 0000000..bf20f2b --- /dev/null +++ b/tests/test_commands/test_login.py @@ -0,0 +1,132 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for dremio login / dremio logout commands.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from typer.testing import CliRunner + +from drs.cli import app +from drs.token_store import OAuthTokens, load, save + +runner = CliRunner() + + +def test_login_saves_tokens(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text("project_id: proj-1\n") + store_path = tmp_path / "oauth_tokens.yaml" + + fake_tokens = OAuthTokens( + access_token="at-new", + refresh_token="rt-new", + client_id="cid", + ) + + with ( + patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.commands.login.token_store.save") as mock_save, + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + patch("drs.commands.login._update_config_file"), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"]) + + assert result.exit_code == 0 + assert "successfully" in result.output.lower() or "logged in" in result.output.lower() + mock_save.assert_called_once() + saved_tokens = mock_save.call_args[0][1] + assert saved_tokens.access_token == "at-new" + + +def test_login_picks_project_from_list(tmp_path: Path) -> None: + """When project_id is not in config, login should show project list.""" + config_path = tmp_path / "config.yaml" + config_path.write_text("uri: https://api.dremio.cloud\n") + + fake_tokens = OAuthTokens(access_token="at", client_id="cid") + fake_projects = [ + {"id": "proj-aaa", "name": "Alpha"}, + {"id": "proj-bbb", "name": "Beta"}, + ] + + with ( + patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.commands.login.token_store.save"), + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + patch("drs.commands.login._update_config_file") as mock_update, + patch("drs.commands.login._fetch_projects", return_value=fake_projects), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"], input="2\n") + + assert result.exit_code == 0 + assert "Beta" in result.output + mock_update.assert_called_once() + assert mock_update.call_args[0][2] == "proj-bbb" + + +def test_login_auto_selects_single_project(tmp_path: Path) -> None: + """When only one project exists, auto-select it.""" + config_path = tmp_path / "config.yaml" + config_path.write_text("uri: https://api.dremio.cloud\n") + + fake_tokens = OAuthTokens(access_token="at", client_id="cid") + fake_projects = [{"id": "proj-only", "name": "OnlyProject"}] + + with ( + patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.commands.login.token_store.save"), + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + patch("drs.commands.login._update_config_file") as mock_update, + patch("drs.commands.login._fetch_projects", return_value=fake_projects), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"]) + + assert result.exit_code == 0 + assert "OnlyProject" in result.output + mock_update.assert_called_once() + assert mock_update.call_args[0][2] == "proj-only" + + +def test_logout_clears_tokens(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text("uri: https://api.dremio.cloud\nproject_id: proj-1\n") + + with ( + patch("drs.commands.login.token_store.clear") as mock_clear, + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + ): + result = runner.invoke(app, ["--config", str(config_path), "logout"]) + + assert result.exit_code == 0 + assert "logged out" in result.output.lower() or "removed" in result.output.lower() + mock_clear.assert_called_once_with("https://api.dremio.cloud") + + +def test_login_failure_exits_1(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text("project_id: proj-1\n") + + with ( + patch("drs.commands.login.oauth.run_login_flow", side_effect=RuntimeError("network error")), + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"]) + + assert result.exit_code == 1 + assert "failed" in result.output.lower() diff --git a/tests/test_commands/test_setup.py b/tests/test_commands/test_setup.py index 2bab050..697eeb7 100644 --- a/tests/test_commands/test_setup.py +++ b/tests/test_commands/test_setup.py @@ -27,6 +27,7 @@ from drs.auth import DEFAULT_URI from drs.cli import app from drs.commands.setup import validate_credentials, write_config +from drs.token_store import OAuthTokens runner = CliRunner() @@ -147,7 +148,7 @@ def test_write_config_includes_header(tmp_path) -> None: raw = config_path.read_text() assert raw.startswith("# Dremio CLI config") - assert "plaintext" in raw + assert "mode 600" in raw def test_setup_non_interactive(tmp_path) -> None: @@ -174,8 +175,8 @@ def test_setup_happy_path(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: region=1, PAT=test-pat, project_id=test-proj - result = runner.invoke(app, ["setup"], input="1\ntest-pat\ntest-proj\n") + # Input: region=1, auth=2(PAT), PAT=test-pat, project_id=test-proj + result = runner.invoke(app, ["setup"], input="1\n2\ntest-pat\ntest-proj\n") assert result.exit_code == 0 assert "Setup complete" in result.output @@ -219,8 +220,8 @@ def test_setup_existing_config_overwrite(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: overwrite=y, region=1, PAT=new-pat, project_id=new-proj - result = runner.invoke(app, ["setup"], input="y\n1\nnew-pat\nnew-proj\n") + # Input: overwrite=y, region=1, auth=2(PAT), PAT=new-pat, project_id=new-proj + result = runner.invoke(app, ["setup"], input="y\n1\n2\nnew-pat\nnew-proj\n") assert result.exit_code == 0 data = yaml.safe_load(config_path.read_text()) @@ -244,8 +245,8 @@ def test_setup_retry_then_abort(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: region=1, PAT=bad, project_id=p1, then decline retry - result = runner.invoke(app, ["setup"], input="1\nbad-pat\np1\nn\n") + # Input: region=1, auth=2(PAT), PAT=bad, project_id=p1, then decline retry + result = runner.invoke(app, ["setup"], input="1\n2\nbad-pat\np1\nn\n") assert result.exit_code == 1 assert "cancelled" in result.output.lower() @@ -265,9 +266,34 @@ def test_setup_global_config_passthrough(tmp_path) -> None: patch("drs.commands.setup.DremioClient", return_value=mock_client), ): mock_sys.stdin.isatty.return_value = True - result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\nmy-pat\nmy-proj\n") + result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\n2\nmy-pat\nmy-proj\n") assert result.exit_code == 0 assert config_path.exists() data = yaml.safe_load(config_path.read_text()) assert data["pat"] == "my-pat" + + +def test_setup_oauth_path(tmp_path) -> None: + """OAuth path through setup wizard: region, auth=1(OAuth), project_id.""" + config_path = tmp_path / "config.yaml" + + fake_tokens = OAuthTokens(access_token="at-oauth", client_id="cid") + + with ( + patch("drs.commands.setup.sys") as mock_sys, + patch("drs.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.token_store.save") as mock_save, + patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), + ): + mock_sys.stdin.isatty.return_value = True + # Input: region=1, auth=1(OAuth), project_id=my-proj + result = runner.invoke(app, ["setup"], input="1\n1\nmy-proj\n") + + assert result.exit_code == 0 + assert config_path.exists() + data = yaml.safe_load(config_path.read_text()) + assert data.get("auth_method") == "oauth" + assert "pat" not in data + assert data["project_id"] == "my-proj" + mock_save.assert_called_once() diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..0339ffc --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,282 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for drs.oauth — OAuth 2.0 PKCE flow mechanics.""" + +from __future__ import annotations + +import base64 +import hashlib +import urllib.request +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest + +from drs.oauth import ( + build_authorization_url, + discover, + exchange_code, + find_free_port, + generate_pkce, + refresh_access_token, + run_login_flow, + start_callback_server, +) + + +class TestPKCE: + def test_generate_pkce_format(self) -> None: + verifier, challenge = generate_pkce() + assert len(verifier) > 20 + assert len(challenge) > 20 + # No padding characters + assert "=" not in challenge + + def test_pkce_challenge_matches_verifier(self) -> None: + verifier, challenge = generate_pkce() + expected = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()).rstrip(b"=").decode("ascii") + assert challenge == expected + + def test_pkce_uniqueness(self) -> None: + v1, _ = generate_pkce() + v2, _ = generate_pkce() + assert v1 != v2 + + +class TestBuildAuthorizationURL: + def test_url_construction(self) -> None: + url = build_authorization_url( + auth_endpoint="https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:8080/Callback", + code_challenge="abc123", + state="state-xyz", + ) + parsed = urlparse(url) + params = parse_qs(parsed.query) + + assert parsed.scheme == "https" + assert parsed.netloc == "auth.example.com" + assert parsed.path == "/authorize" + assert params["response_type"] == ["code"] + assert params["client_id"] == ["my-client"] + assert params["redirect_uri"] == ["http://localhost:8080/Callback"] + assert params["code_challenge"] == ["abc123"] + assert params["code_challenge_method"] == ["S256"] + assert params["scope"] == ["dremio.all offline_access"] + assert params["state"] == ["state-xyz"] + + +class TestLoginUrl: + def test_app_to_login(self) -> None: + from drs.oauth import _login_url + + assert _login_url("https://app.dremio.cloud") == "https://login.dremio.cloud" + assert _login_url("https://app.dev.dremio.site") == "https://login.dev.dremio.site" + assert _login_url("https://app.eu.dremio.cloud") == "https://login.eu.dremio.cloud" + + def test_non_app_host_unchanged(self) -> None: + from drs.oauth import _login_url + + assert _login_url("https://login.dremio.cloud") == "https://login.dremio.cloud" + assert _login_url("https://custom.example.com") == "https://custom.example.com" + + +class TestDiscover: + def test_discover_parses_metadata(self) -> None: + metadata_json = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + } + + mock_response = MagicMock() + mock_response.json.return_value = metadata_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.get", return_value=mock_response) as mock_get: + result = discover("https://app.dremio.cloud") + + mock_get.assert_called_once_with( + "https://login.dremio.cloud/.well-known/oauth-authorization-server", + timeout=30.0, + ) + assert result.authorization_endpoint == "https://auth.example.com/authorize" + assert result.token_endpoint == "https://auth.example.com/token" + assert result.registration_endpoint == "https://auth.example.com/register" + + def test_discover_optional_registration(self) -> None: + metadata_json = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + } + + mock_response = MagicMock() + mock_response.json.return_value = metadata_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.get", return_value=mock_response): + result = discover("https://app.dremio.cloud") + + assert result.registration_endpoint is None + + +class TestExchangeCode: + def test_exchange_code_success(self) -> None: + token_json = { + "access_token": "at-new", + "refresh_token": "rt-new", + "expires_in": 3600, + } + mock_response = MagicMock() + mock_response.json.return_value = token_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.post", return_value=mock_response): + tokens = exchange_code( + "https://auth.example.com/token", + "auth-code-123", + "http://localhost:8080/Callback", + "my-client", + "my-secret", + "my-verifier", + ) + + assert tokens.access_token == "at-new" + assert tokens.refresh_token == "rt-new" + assert tokens.expires_at is not None + assert tokens.client_id == "my-client" + assert tokens.client_secret == "my-secret" + + +class TestRefreshAccessToken: + def test_refresh_success(self) -> None: + token_json = { + "access_token": "at-refreshed", + "refresh_token": "rt-new", + "expires_in": 3600, + } + mock_response = MagicMock() + mock_response.json.return_value = token_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.post", return_value=mock_response): + tokens = refresh_access_token( + "https://auth.example.com/token", + "my-client", + "my-secret", + "rt-old", + ) + + assert tokens.access_token == "at-refreshed" + assert tokens.refresh_token == "rt-new" + + def test_refresh_preserves_old_refresh_token(self) -> None: + """When server omits refresh_token in response, keep the old one.""" + token_json = { + "access_token": "at-refreshed", + "expires_in": 3600, + } + mock_response = MagicMock() + mock_response.json.return_value = token_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.post", return_value=mock_response): + tokens = refresh_access_token( + "https://auth.example.com/token", + "my-client", + None, + "rt-old", + ) + + assert tokens.refresh_token == "rt-old" + + +class TestCallbackServer: + def test_callback_captures_code_and_state(self) -> None: + port = find_free_port() + server, future = start_callback_server(port) + try: + url = f"http://localhost:{port}/Callback?code=test-code&state=test-state" + urllib.request.urlopen(url, timeout=5) + code, state = future.result(timeout=5) + assert code == "test-code" + assert state == "test-state" + finally: + server.shutdown() + + def test_find_free_port_returns_positive(self) -> None: + port = find_free_port() + assert port > 0 + + +class TestRunLoginFlow: + @pytest.mark.parametrize("browser_behavior", ["raises", "returns_false"]) + def test_headless_fallback_prints_url(self, browser_behavior: str) -> None: + """When webbrowser.open raises or returns False, the URL should be printed.""" + metadata_json = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + } + dcr_json = {"client_id": "cid", "client_secret": "cs"} + token_json = {"access_token": "at", "refresh_token": "rt", "expires_in": 3600} + + mock_get = MagicMock() + mock_get.json.return_value = metadata_json + mock_get.raise_for_status = MagicMock() + + post_responses = [] + # DCR call + dcr_resp = MagicMock() + dcr_resp.json.return_value = dcr_json + dcr_resp.raise_for_status = MagicMock() + post_responses.append(dcr_resp) + # Token exchange + token_resp = MagicMock() + token_resp.json.return_value = token_json + token_resp.raise_for_status = MagicMock() + post_responses.append(token_resp) + + if browser_behavior == "raises": + browser_mock = MagicMock(side_effect=RuntimeError("no browser")) + else: + browser_mock = MagicMock(return_value=False) + + with ( + patch("drs.oauth.httpx.get", return_value=mock_get), + patch("drs.oauth.httpx.post", side_effect=post_responses), + patch("drs.oauth.webbrowser.open", browser_mock), + patch("drs.oauth.start_callback_server") as mock_server, + patch("builtins.print") as mock_print, + ): + mock_future = MagicMock() + mock_future.result.return_value = ("code-123", None) + mock_srv = MagicMock() + mock_server.return_value = (mock_srv, mock_future) + + # The state won't match because we bypass the real flow, + # so we need to patch around the state check too. + try: + run_login_flow("https://api.dremio.cloud") + except RuntimeError: + pass # state mismatch expected in mocked test + + # Verify fallback print was called + mock_print.assert_called_once() + printed_text = mock_print.call_args[0][0] + assert "browser" in printed_text.lower() or "url" in printed_text.lower() or "http" in printed_text.lower() diff --git a/tests/test_token_store.py b/tests/test_token_store.py new file mode 100644 index 0000000..45c71eb --- /dev/null +++ b/tests/test_token_store.py @@ -0,0 +1,97 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for drs.token_store — OAuth token persistence.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from drs.token_store import OAuthTokens, clear, load, save + + +@pytest.fixture +def store_path(tmp_path: Path) -> Path: + return tmp_path / "oauth_tokens.yaml" + + +@pytest.fixture +def sample_tokens() -> OAuthTokens: + return OAuthTokens( + access_token="at-123", + refresh_token="rt-456", + expires_at=1700000000.0, + client_id="test-client", + client_secret="test-secret", + ) + + +def test_round_trip(store_path: Path, sample_tokens: OAuthTokens) -> None: + url = "https://api.dremio.cloud" + save(url, sample_tokens, store_path=store_path) + loaded = load(url, store_path=store_path) + assert loaded is not None + assert loaded == sample_tokens + + +def test_load_missing_url(store_path: Path, sample_tokens: OAuthTokens) -> None: + save("https://api.dremio.cloud", sample_tokens, store_path=store_path) + assert load("https://api.eu.dremio.cloud", store_path=store_path) is None + + +def test_load_nonexistent_file(store_path: Path) -> None: + assert load("https://api.dremio.cloud", store_path=store_path) is None + + +def test_clear_removes_entry(store_path: Path, sample_tokens: OAuthTokens) -> None: + url = "https://api.dremio.cloud" + save(url, sample_tokens, store_path=store_path) + clear(url, store_path=store_path) + assert load(url, store_path=store_path) is None + + +def test_clear_nonexistent_is_noop(store_path: Path) -> None: + clear("https://api.dremio.cloud", store_path=store_path) # should not raise + + +def test_file_mode_600(store_path: Path, sample_tokens: OAuthTokens) -> None: + save("https://api.dremio.cloud", sample_tokens, store_path=store_path) + assert oct(store_path.stat().st_mode & 0o777) == "0o600" + + +def test_multi_instance_keying(store_path: Path) -> None: + url_us = "https://api.dremio.cloud" + url_eu = "https://api.eu.dremio.cloud" + tokens_us = OAuthTokens(access_token="us-token", client_id="cid") + tokens_eu = OAuthTokens(access_token="eu-token", client_id="cid") + + save(url_us, tokens_us, store_path=store_path) + save(url_eu, tokens_eu, store_path=store_path) + + loaded_us = load(url_us, store_path=store_path) + loaded_eu = load(url_eu, store_path=store_path) + assert loaded_us is not None + assert loaded_eu is not None + assert loaded_us.access_token == "us-token" + assert loaded_eu.access_token == "eu-token" + + +def test_save_creates_parent_dirs(tmp_path: Path) -> None: + deep_path = tmp_path / "a" / "b" / "tokens.yaml" + tokens = OAuthTokens(access_token="at", client_id="cid") + save("https://api.dremio.cloud", tokens, store_path=deep_path) + assert deep_path.exists()