From 9dd60d49e464dd57d5700ad473a4c28716f5dad3 Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Tue, 14 Apr 2026 19:22:44 +0000 Subject: [PATCH 1/8] feat: add browser-based OAuth login flow (DX-118868) Add OAuth Authorization Code + PKCE flow as an alternative to PAT authentication. Users can now run `dremio login` to authenticate via browser, with tokens stored in ~/.config/dremioai/oauth_tokens.yaml. - OAuth discovery via .well-known/oauth-authorization-server - Dynamic Client Registration (DCR) with PKCE - Localhost callback server for auth code capture - Automatic token refresh on 401 (one retry per request) - Headless fallback: prints URL when browser unavailable - `dremio login` / `dremio logout` commands - Setup wizard offers OAuth alongside PAT Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/auth.py | 26 ++- src/drs/cli.py | 8 +- src/drs/client.py | 37 +++- src/drs/commands/login.py | 127 ++++++++++++++ src/drs/commands/setup.py | 145 +++++++++++----- src/drs/oauth.py | 275 ++++++++++++++++++++++++++++++ src/drs/token_store.py | 70 ++++++++ tests/test_auth.py | 52 +++++- tests/test_client_retry.py | 85 +++++++++ tests/test_commands/test_login.py | 102 +++++++++++ tests/test_commands/test_setup.py | 42 ++++- tests/test_oauth.py | 267 +++++++++++++++++++++++++++++ tests/test_token_store.py | 97 +++++++++++ 13 files changed, 1276 insertions(+), 57 deletions(-) create mode 100644 src/drs/commands/login.py create mode 100644 src/drs/oauth.py create mode 100644 src/drs/token_store.py create mode 100644 tests/test_commands/test_login.py create mode 100644 tests/test_oauth.py create mode 100644 tests/test_token_store.py diff --git a/src/drs/auth.py b/src/drs/auth.py index 051f365..68ea573 100644 --- a/src/drs/auth.py +++ b/src/drs/auth.py @@ -19,10 +19,10 @@ import os from pathlib import Path -from typing import Any +from typing import Any, Literal import yaml -from pydantic import BaseModel +from pydantic import BaseModel, model_validator DEFAULT_CONFIG_PATH = Path.home() / ".config" / "dremioai" / "config.yaml" DEFAULT_URI = "https://api.dremio.cloud" @@ -30,8 +30,16 @@ class DrsConfig(BaseModel): uri: str = DEFAULT_URI - pat: str + pat: str | None = None project_id: str + auth_method: Literal["pat", "oauth"] = "pat" + oauth_access_token: str | None = None + + @model_validator(mode="after") + def _require_credential(self) -> "DrsConfig": + if not self.pat and not self.oauth_access_token: + raise ValueError("Either 'pat' or 'oauth_access_token' must be provided.") + return self def load_config( @@ -58,6 +66,7 @@ def load_config( "uri": raw.get("uri", raw.get("endpoint")), "pat": raw.get("pat", raw.get("token")), "project_id": raw.get("project_id", raw.get("projectId")), + "auth_method": raw.get("auth_method"), } file_values = {k: v for k, v in file_values.items() if v is not None} @@ -83,4 +92,15 @@ def load_config( if cli_token: merged["pat"] = cli_token + # If no PAT is available, try loading OAuth tokens from the token store. + if "pat" not in merged or not merged["pat"]: + from drs import token_store + + uri = merged.get("uri", DEFAULT_URI) + tokens = token_store.load(uri) + if tokens is not None: + merged["auth_method"] = "oauth" + merged["oauth_access_token"] = tokens.access_token + merged.pop("pat", None) + return DrsConfig(**merged) diff --git a/src/drs/cli.py b/src/drs/cli.py index 8a86e4d..34ece06 100644 --- a/src/drs/cli.py +++ b/src/drs/cli.py @@ -34,6 +34,7 @@ folder, grant, job, + login, project, query, reflection, @@ -69,6 +70,8 @@ app.add_typer(project.app, name="project") app.add_typer(chat.app, name="chat") app.command("setup")(setup.setup_command) +app.command("login")(login.login_command) +app.command("logout")(login.logout_command) # Global state for config _config: DrsConfig | None = None @@ -137,8 +140,9 @@ def get_config() -> DrsConfig: Console(stderr=True).print( "\n[bold red]Configuration required[/bold red]\n\n" - "The Dremio CLI needs a Personal Access Token and Project ID.\n\n" - " [bold]Quick setup:[/] Run [bold cyan]dremio setup[/bold cyan]\n\n" + "The Dremio CLI needs authentication credentials and a Project ID.\n\n" + " [bold]Quick setup:[/] Run [bold cyan]dremio setup[/bold cyan]\n" + " [bold]OAuth login:[/] Run [bold cyan]dremio login[/bold cyan]\n\n" " [dim]Or provide credentials manually:[/dim]\n" " --token / DREMIO_TOKEN env var\n" " --project-id / DREMIO_PROJECT_ID env var\n" diff --git a/src/drs/client.py b/src/drs/client.py index e1ab86a..10d2a1d 100644 --- a/src/drs/client.py +++ b/src/drs/client.py @@ -44,9 +44,10 @@ class DremioClient: def __init__(self, config: DrsConfig) -> None: self.config = config + token = config.oauth_access_token if "oauth" == config.auth_method else config.pat self._client = httpx.AsyncClient( headers={ - "Authorization": f"Bearer {config.pat}", + "Authorization": f"Bearer {token}", "Content-Type": "application/json", }, timeout=120.0, @@ -70,12 +71,46 @@ def _v1(self, path: str) -> str: # -- HTTP helpers with retry -- + async def _refresh_oauth_token(self) -> None: + """Refresh the OAuth access token using the stored refresh token.""" + from drs import oauth, token_store + + tokens = token_store.load(self.config.uri) + if tokens is None or tokens.refresh_token is None: + raise RuntimeError("No refresh token available — please run 'dremio login' again.") + + metadata = oauth.discover(self.config.uri) + new_tokens = oauth.refresh_access_token( + metadata.token_endpoint, + tokens.client_id, + tokens.client_secret, + tokens.refresh_token, + ) + token_store.save(self.config.uri, new_tokens) + self._client.headers["authorization"] = f"Bearer {new_tokens.access_token}" + logger.info("OAuth access token refreshed successfully.") + async def _request_with_retry(self, method: str, url: str, **kwargs: Any) -> httpx.Response: """Execute an HTTP request with retry on transient errors.""" + auth_refreshed = False last_exc: Exception | None = None for attempt in range(_MAX_RETRIES): try: resp = await self._client.request(method, url, **kwargs) + # 401 with OAuth: attempt one token refresh per call + if ( + resp.status_code == 401 + and "oauth" == self.config.auth_method + and not auth_refreshed + ): + logger.info("Received 401 — attempting OAuth token refresh.") + try: + await self._refresh_oauth_token() + except Exception: + logger.warning("OAuth token refresh failed.", exc_info=True) + return resp + auth_refreshed = True + continue if resp.status_code in _RETRYABLE_STATUS_CODES and attempt < _MAX_RETRIES - 1: delay = _RETRY_BACKOFF[attempt] logger.warning( diff --git a/src/drs/commands/login.py b/src/drs/commands/login.py new file mode 100644 index 0000000..bb9979f --- /dev/null +++ b/src/drs/commands/login.py @@ -0,0 +1,127 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""``dremio login`` and ``dremio logout`` commands — OAuth browser flow.""" + +from __future__ import annotations + +from pathlib import Path + +import typer +import yaml +from rich.console import Console + +from drs import oauth, token_store +from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI + +console = Console() +err_console = Console(stderr=True) + + +def _resolve_uri(ctx: typer.Context) -> str: + """Resolve the Dremio API URI from CLI flags, config file, or default.""" + # Check CLI --uri + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None + # Check if parent set cli_uri + from drs.cli import _cli_opts + + cli_uri = _cli_opts.get("cli_uri") + if cli_uri: + return cli_uri + + # Try config file + path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH + if path.exists(): + with path.open() as f: + raw = yaml.safe_load(f) or {} + uri = raw.get("uri", raw.get("endpoint")) + if uri: + return uri + + return DEFAULT_URI + + +def _resolve_project_id(ctx: typer.Context) -> str: + """Resolve project_id from CLI flags, config file, or prompt the user.""" + from drs.cli import _cli_opts + + cli_project_id = _cli_opts.get("cli_project_id") + if cli_project_id: + return cli_project_id + + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None + path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH + if path.exists(): + with path.open() as f: + raw = yaml.safe_load(f) or {} + project_id = raw.get("project_id", raw.get("projectId")) + if project_id: + return project_id + + # Prompt user + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + +def login_command(ctx: typer.Context) -> None: + """Log in to Dremio Cloud via OAuth (opens your browser).""" + uri = _resolve_uri(ctx) + console.print(f"\nLogging in to [bold]{uri}[/bold] ...") + + try: + tokens = oauth.run_login_flow(uri) + except Exception as exc: + err_console.print(f"\n[bold red]Login failed:[/bold red] {exc}") + raise typer.Exit(1) + + # Ensure we have a project_id to write into the config + project_id = _resolve_project_id(ctx) + + token_store.save(uri, tokens) + + # Also persist auth_method + project_id in config file so subsequent + # commands pick up OAuth automatically. + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None + config_path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH + _update_config_file(config_path, uri, project_id) + + console.print(f"\n[green]Logged in successfully.[/green] Tokens saved for {uri}") + + +def _update_config_file(config_path: Path, uri: str, project_id: str) -> None: + """Ensure the config file records auth_method=oauth and project_id.""" + data: dict = {} + if config_path.exists(): + with config_path.open() as f: + data = yaml.safe_load(f) or {} + + if uri != DEFAULT_URI: + data["uri"] = uri + data["auth_method"] = "oauth" + data["project_id"] = project_id + # Remove PAT if present — OAuth replaces it. + data.pop("pat", None) + data.pop("token", None) + + config_path.parent.mkdir(parents=True, exist_ok=True) + header = "# Dremio CLI config — generated by 'dremio login'\n" + config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False)) + config_path.chmod(0o600) + + +def logout_command(ctx: typer.Context) -> None: + """Log out of Dremio Cloud (removes stored OAuth tokens).""" + uri = _resolve_uri(ctx) + token_store.clear(uri) + console.print(f"Logged out of [bold]{uri}[/bold]. OAuth tokens removed.") diff --git a/src/drs/commands/setup.py b/src/drs/commands/setup.py index fb77ea2..9775e78 100644 --- a/src/drs/commands/setup.py +++ b/src/drs/commands/setup.py @@ -32,6 +32,11 @@ from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI, DrsConfig from drs.client import DremioClient +AUTH_METHODS = { + "1": "oauth", + "2": "pat", +} + REGIONS = { "1": ("US", "https://api.dremio.cloud", "https://app.dremio.cloud"), "2": ("EU", "https://api.eu.dremio.cloud", "https://app.eu.dremio.cloud"), @@ -65,16 +70,25 @@ async def validate_credentials(uri: str, pat: str, project_id: str) -> tuple[boo await client.close() -def write_config(uri: str, pat: str, project_id: str, config_path: Path) -> None: +def write_config( + uri: str, + pat: str | None, + project_id: str, + config_path: Path, + auth_method: str = "pat", +) -> None: """Write YAML config file, creating parent directories as needed.""" data: dict[str, str] = {} if uri != DEFAULT_URI: data["uri"] = uri - data["pat"] = pat + if auth_method != "pat": + data["auth_method"] = auth_method + if pat: + data["pat"] = pat data["project_id"] = project_id config_path.parent.mkdir(parents=True, exist_ok=True) - header = "# Dremio CLI config — generated by 'dremio setup'\n# PAT is stored in plaintext. Keep this file private (mode 600).\n" + header = "# Dremio CLI config — generated by 'dremio setup'\n# Keep this file private (mode 600).\n" config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False)) config_path.chmod(0o600) @@ -100,12 +114,33 @@ def _prompt_region() -> tuple[str, str]: return api_uri, app_url +def _prompt_auth_method() -> str: + """Prompt for authentication method. Returns 'oauth' or 'pat'.""" + console.print() + console.print( + Panel( + "[bold]Step 2: Choose authentication method[/bold]\n\n" + " [cyan]1[/cyan]) OAuth (browser login) — recommended\n" + " [cyan]2[/cyan]) PAT (manual personal access token)", + title="Authentication", + border_style="blue", + ) + ) + choice = typer.prompt("Enter 1 or 2", default="1").strip() + if choice not in AUTH_METHODS: + console.print("[yellow]Invalid choice, defaulting to OAuth.[/yellow]") + choice = "1" + method = AUTH_METHODS[choice] + console.print(f" -> Auth method: [bold]{method.upper()}[/bold]") + return method + + def _prompt_pat(app_url: str) -> str: """Prompt for Personal Access Token with step-by-step instructions.""" console.print() console.print( Panel( - "[bold]Step 2: Create a Personal Access Token (PAT)[/bold]\n\n" + "[bold]Step 3: Create a Personal Access Token (PAT)[/bold]\n\n" f" 1. Open [link={app_url}]{app_url}[/link] and sign in\n" " 2. Click your profile icon (bottom-left) → [bold]Account Settings[/bold]\n" " 3. Go to [bold]Personal Access Tokens[/bold]\n" @@ -127,7 +162,7 @@ def _prompt_project_id(app_url: str) -> str: console.print() console.print( Panel( - "[bold]Step 3: Find your Project ID[/bold]\n\n" + "[bold]Find your Project ID[/bold]\n\n" f" 1. Open [link={app_url}]{app_url}[/link]\n" " 2. Select your project from the top-left dropdown\n" " 3. Go to [bold]Project Settings[/bold] → [bold]General[/bold]\n" @@ -174,7 +209,7 @@ def setup_command( "This wizard will help you connect the Dremio CLI to your Dremio Cloud account.\n\n" "You'll need:\n" " • A [bold]Dremio Cloud account[/bold] (sign up at [link=https://app.dremio.cloud]app.dremio.cloud[/link])\n" - " • A [bold]Personal Access Token[/bold] (we'll walk you through creating one)\n" + " • A [bold]Personal Access Token[/bold] or [bold]OAuth browser login[/bold]\n" " • A [bold]Project ID[/bold] (we'll show you where to find it)", title="[bold]Dremio CLI Setup[/bold]", border_style="cyan", @@ -191,53 +226,79 @@ def setup_command( # Step 1: Region api_uri, app_url = _prompt_region() - # Step 2: PAT (with retry loop) - pat = _prompt_pat(app_url) + # Step 2: Auth method + auth_method = _prompt_auth_method() - # Step 3: Project ID (with retry loop) - project_id = _prompt_project_id(app_url) + pat: str | None = None + project_name: str = "" - # Validate - console.print() - with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): - ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) + if auth_method == "oauth": + # OAuth flow — browser login, then project ID + from drs import oauth, token_store - while not ok: - console.print(f"\n[red]✗ {message}[/red]") - if not typer.confirm("Would you like to try again?", default=True): - console.print("Setup cancelled.") + console.print("\n[bold]Starting OAuth browser login...[/bold]") + try: + tokens = oauth.run_login_flow(api_uri) + except Exception as exc: + err_console.print(f"\n[bold red]OAuth login failed:[/bold red] {exc}") raise typer.Exit(1) - if "Authentication" in message: - console.print("[dim]Let's try the PAT again.[/dim]") - pat = _prompt_pat(app_url) - elif "Access denied" in message: - console.print( - "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID" - ) - choice = typer.prompt("Which would you like to fix?", default="1").strip() - if choice == "2": - project_id = _prompt_project_id(app_url) - else: - pat = _prompt_pat(app_url) - elif "Project" in message: - console.print("[dim]Let's try the Project ID again.[/dim]") - project_id = _prompt_project_id(app_url) - else: - console.print("[dim]Let's try the region again.[/dim]") - api_uri, app_url = _prompt_region() - pat = _prompt_pat(app_url) - project_id = _prompt_project_id(app_url) + token_store.save(api_uri, tokens) + console.print("[green]OAuth login successful.[/green]") + + # Step 3: Project ID + project_id = _prompt_project_id(app_url) + project_name = project_id + write_config(api_uri, None, project_id, config_path, auth_method="oauth") + else: + # PAT flow (existing behavior) + pat = _prompt_pat(app_url) + + # Step 3: Project ID (with retry loop) + project_id = _prompt_project_id(app_url) + + # Validate console.print() with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) - # Success — write config - project_name = project_data.get("name", project_id) if project_data else project_id - console.print(f"\n[green]✓ {message}[/green]") + while not ok: + console.print(f"\n[red]x {message}[/red]") + if not typer.confirm("Would you like to try again?", default=True): + console.print("Setup cancelled.") + raise typer.Exit(1) + + if "Authentication" in message: + console.print("[dim]Let's try the PAT again.[/dim]") + pat = _prompt_pat(app_url) + elif "Access denied" in message: + console.print( + "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID" + ) + choice = typer.prompt("Which would you like to fix?", default="1").strip() + if choice == "2": + project_id = _prompt_project_id(app_url) + else: + pat = _prompt_pat(app_url) + elif "Project" in message: + console.print("[dim]Let's try the Project ID again.[/dim]") + project_id = _prompt_project_id(app_url) + else: + console.print("[dim]Let's try the region again.[/dim]") + api_uri, app_url = _prompt_region() + pat = _prompt_pat(app_url) + project_id = _prompt_project_id(app_url) + + console.print() + with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): + ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) + + # Success + project_name = project_data.get("name", project_id) if project_data else project_id + console.print(f"\n[green]v {message}[/green]") - write_config(api_uri, pat, project_id, config_path) + write_config(api_uri, pat, project_id, config_path) console.print() success = Text() diff --git a/src/drs/oauth.py b/src/drs/oauth.py new file mode 100644 index 0000000..67f22d2 --- /dev/null +++ b/src/drs/oauth.py @@ -0,0 +1,275 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""OAuth 2.0 PKCE browser-login flow for Dremio Cloud.""" + +from __future__ import annotations + +import base64 +import hashlib +import logging +import secrets +import socket +import threading +import time +import webbrowser +from concurrent.futures import Future +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import parse_qs, urlencode, urlparse + +import httpx + +from drs.token_store import OAuthTokens + +logger = logging.getLogger(__name__) + +_CLIENT_ID = "https://connectors.dremio.app/claude" + +_SUCCESS_HTML = """\ + +Dremio CLI + +

Login successful

+

You can close this tab and return to the terminal.

+ +""" + + +@dataclass +class OAuthServerMetadata: + authorization_endpoint: str + token_endpoint: str + registration_endpoint: str | None = None + + +def discover(dremio_url: str) -> OAuthServerMetadata: + """Fetch OAuth Authorization Server Metadata from *dremio_url*.""" + url = f"{dremio_url}/.well-known/oauth-authorization-server" + resp = httpx.get(url, timeout=30.0) + resp.raise_for_status() + data = resp.json() + return OAuthServerMetadata( + authorization_endpoint=data["authorization_endpoint"], + token_endpoint=data["token_endpoint"], + registration_endpoint=data.get("registration_endpoint"), + ) + + +def register_client(registration_endpoint: str, redirect_uri: str) -> tuple[str, str | None]: + """Dynamic Client Registration. Returns ``(client_id, client_secret)``.""" + body = { + "client_id": _CLIENT_ID, + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + } + resp = httpx.post(registration_endpoint, json=body, timeout=30.0) + resp.raise_for_status() + data = resp.json() + return data["client_id"], data.get("client_secret") + + +def generate_pkce() -> tuple[str, str]: + """Generate PKCE ``(code_verifier, code_challenge)`` pair.""" + code_verifier = secrets.token_urlsafe(32) + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return code_verifier, code_challenge + + +def build_authorization_url( + auth_endpoint: str, + client_id: str, + redirect_uri: str, + code_challenge: str, + state: str, +) -> str: + """Construct the authorization URL the user's browser should open.""" + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "scope": "openid offline_access", + "state": state, + } + return f"{auth_endpoint}?{urlencode(params)}" + + +def find_free_port() -> int: + """Bind to port 0 to obtain a free ephemeral port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def start_callback_server(port: int) -> tuple[HTTPServer, Future[tuple[str, str]]]: + """Start a localhost HTTP server that captures the OAuth callback. + + Returns ``(server, future)`` where the future resolves to ``(code, state)``. + """ + future: Future[tuple[str, str]] = Future() + + class _Handler(BaseHTTPRequestHandler): + def do_GET(self) -> None: # noqa: N802 + qs = parse_qs(urlparse(self.path).query) + code = qs.get("code", [None])[0] + state = qs.get("state", [None])[0] + error = qs.get("error", [None])[0] + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + + if code and state: + self.wfile.write(_SUCCESS_HTML.encode()) + future.set_result((code, state)) + else: + error_msg = error or "unknown" + self.wfile.write(f"

Login failed: {error_msg}

".encode()) + if not future.done(): + future.set_exception(RuntimeError(f"OAuth callback error: {error_msg}")) + + def log_message(self, format: str, *args: object) -> None: # noqa: A002 + logger.debug(format, *args) + + server = HTTPServer(("127.0.0.1", port), _Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, future + + +def exchange_code( + token_endpoint: str, + code: str, + redirect_uri: str, + client_id: str, + client_secret: str | None, + code_verifier: str, +) -> OAuthTokens: + """Exchange an authorization code for tokens.""" + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "code_verifier": code_verifier, + } + if client_secret: + data["client_secret"] = client_secret + resp = httpx.post(token_endpoint, data=data, timeout=30.0) + resp.raise_for_status() + body = resp.json() + expires_at: float | None = None + if "expires_in" in body: + expires_at = time.time() + body["expires_in"] + return OAuthTokens( + access_token=body["access_token"], + refresh_token=body.get("refresh_token"), + expires_at=expires_at, + client_id=client_id, + client_secret=client_secret, + ) + + +def refresh_access_token( + token_endpoint: str, + client_id: str, + client_secret: str | None, + refresh_token: str, +) -> OAuthTokens: + """Use a refresh token to obtain a new access token.""" + data = { + "grant_type": "refresh_token", + "client_id": client_id, + "refresh_token": refresh_token, + } + if client_secret: + data["client_secret"] = client_secret + resp = httpx.post(token_endpoint, data=data, timeout=30.0) + resp.raise_for_status() + body = resp.json() + expires_at: float | None = None + if "expires_in" in body: + expires_at = time.time() + body["expires_in"] + return OAuthTokens( + access_token=body["access_token"], + refresh_token=body.get("refresh_token", refresh_token), + expires_at=expires_at, + client_id=client_id, + client_secret=client_secret, + ) + + +def run_login_flow(dremio_url: str) -> OAuthTokens: + """Orchestrate the full browser-based OAuth login flow. + + 1. Discover OAuth endpoints + 2. Find a free port & register the client (DCR) + 3. Generate PKCE codes + 4. Start localhost callback server + 5. Open the browser (fall back to printing URL) + 6. Wait for the callback + 7. Exchange the code for tokens + """ + metadata = discover(dremio_url) + + port = find_free_port() + redirect_uri = f"http://localhost:{port}/callback" + + if metadata.registration_endpoint: + client_id, client_secret = register_client(metadata.registration_endpoint, redirect_uri) + else: + client_id, client_secret = _CLIENT_ID, None + + code_verifier, code_challenge = generate_pkce() + state = secrets.token_urlsafe(16) + + auth_url = build_authorization_url( + metadata.authorization_endpoint, + client_id, + redirect_uri, + code_challenge, + state, + ) + + server, future = start_callback_server(port) + try: + try: + opened = webbrowser.open(auth_url) + except Exception: + opened = False + if opened: + logger.info("Opened browser for OAuth login.") + else: + # Headless / no-browser fallback + print(f"\nOpen this URL in your browser to log in:\n\n {auth_url}\n") + + code, returned_state = future.result(timeout=300) + if returned_state != state: + raise RuntimeError("OAuth state mismatch — possible CSRF attack.") + finally: + server.shutdown() + + return exchange_code( + metadata.token_endpoint, + code, + redirect_uri, + client_id, + client_secret, + code_verifier, + ) diff --git a/src/drs/token_store.py b/src/drs/token_store.py new file mode 100644 index 0000000..8d8638a --- /dev/null +++ b/src/drs/token_store.py @@ -0,0 +1,70 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Persistent storage for OAuth tokens, keyed by Dremio URL.""" + +from __future__ import annotations + +from pathlib import Path + +import yaml +from pydantic import BaseModel + +STORE_PATH = Path.home() / ".config" / "dremioai" / "oauth_tokens.yaml" + + +class OAuthTokens(BaseModel): + access_token: str + refresh_token: str | None = None + expires_at: float | None = None + client_id: str + client_secret: str | None = None + + +def load(dremio_url: str, store_path: Path = STORE_PATH) -> OAuthTokens | None: + """Read stored OAuth tokens for *dremio_url*. Returns ``None`` if absent.""" + if not store_path.exists(): + return None + with store_path.open() as f: + data = yaml.safe_load(f) or {} + entry = data.get(dremio_url) + if entry is None: + return None + return OAuthTokens(**entry) + + +def save(dremio_url: str, tokens: OAuthTokens, store_path: Path = STORE_PATH) -> None: + """Persist *tokens* under *dremio_url*. Creates dirs and sets mode 600.""" + store_path.parent.mkdir(parents=True, exist_ok=True) + + data: dict = {} + if store_path.exists(): + with store_path.open() as f: + data = yaml.safe_load(f) or {} + + data[dremio_url] = tokens.model_dump(exclude_none=True) + store_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False)) + store_path.chmod(0o600) + + +def clear(dremio_url: str, store_path: Path = STORE_PATH) -> None: + """Remove stored tokens for *dremio_url*.""" + if not store_path.exists(): + return + with store_path.open() as f: + data = yaml.safe_load(f) or {} + data.pop(dremio_url, None) + store_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False)) + store_path.chmod(0o600) diff --git a/tests/test_auth.py b/tests/test_auth.py index 9187f81..348de25 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,8 @@ import yaml from pydantic import ValidationError -from drs.auth import load_config +from drs.auth import DrsConfig, load_config +from drs.token_store import OAuthTokens def test_config_from_env_vars(tmp_path: Path) -> None: @@ -137,6 +138,55 @@ def test_dremio_token_overrides_dremio_pat(tmp_path: Path) -> None: assert config.pat == "new-token" +def test_config_oauth_from_token_store(tmp_path: Path) -> None: + """When no PAT is available but OAuth tokens exist, auth_method should be 'oauth'.""" + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"project_id": "proj-1"})) + + fake_tokens = OAuthTokens(access_token="oauth-at", refresh_token="oauth-rt", client_id="cid") + + with ( + patch.dict(os.environ, {}, clear=False), + patch("drs.token_store.load", return_value=fake_tokens), + ): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.auth_method == "oauth" + assert config.oauth_access_token == "oauth-at" + assert config.pat is None + + +def test_config_pat_still_works(tmp_path: Path) -> None: + """PAT auth should continue to work exactly as before.""" + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"pat": "my-pat", "project_id": "proj-1"})) + + with patch.dict(os.environ, {}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.auth_method == "pat" + assert config.pat == "my-pat" + + +def test_config_neither_pat_nor_oauth_fails(tmp_path: Path) -> None: + """Missing both PAT and OAuth tokens should raise ValidationError.""" + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"project_id": "proj-1"})) + + with ( + patch.dict(os.environ, {}, clear=False), + patch("drs.token_store.load", return_value=None), + pytest.raises(ValidationError), + ): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + load_config(config_file) + + def test_cli_args_override_env(tmp_path: Path) -> None: """CLI args should override env vars and file values.""" config_file = tmp_path / "config.yaml" diff --git a/tests/test_client_retry.py b/tests/test_client_retry.py index 485229b..3a29222 100644 --- a/tests/test_client_retry.py +++ b/tests/test_client_retry.py @@ -22,6 +22,7 @@ import httpx import pytest +from drs.auth import DrsConfig from drs.client import DremioClient @@ -154,3 +155,87 @@ async def test_retry_backoff_delays(config) -> None: assert mock_sleep.call_count == 2 mock_sleep.assert_any_call(1.0) mock_sleep.assert_any_call(2.0) + + +@pytest.mark.asyncio +async def test_401_triggers_oauth_refresh() -> None: + """401 with OAuth auth should attempt token refresh and retry.""" + oauth_config = DrsConfig( + uri="https://api.dremio.cloud", + oauth_access_token="old-token", + project_id="proj", + auth_method="oauth", + ) + client = DremioClient(oauth_config) + + unauthorized = httpx.Response(401, request=httpx.Request("GET", "https://example.com")) + ok_response = httpx.Response(200, json={"ok": True}, request=httpx.Request("GET", "https://example.com")) + + client._client.request = AsyncMock(side_effect=[unauthorized, ok_response]) + + with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock) as mock_refresh: + result = await client._get("https://example.com/test") + + assert result == {"ok": True} + mock_refresh.assert_awaited_once() + assert client._client.request.call_count == 2 + + +@pytest.mark.asyncio +async def test_401_no_refresh_for_pat(config) -> None: + """401 with PAT auth should NOT attempt token refresh.""" + client = DremioClient(config) + + unauthorized = httpx.Response(401, json={"error": "unauthorized"}, request=httpx.Request("GET", "https://example.com")) + client._client.request = AsyncMock(return_value=unauthorized) + + with pytest.raises(httpx.HTTPStatusError): + await client._get("https://example.com/test") + + assert client._client.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_401_refresh_failure_returns_original_response() -> None: + """When refresh fails, the original 401 response should be returned.""" + oauth_config = DrsConfig( + uri="https://api.dremio.cloud", + oauth_access_token="old-token", + project_id="proj", + auth_method="oauth", + ) + client = DremioClient(oauth_config) + + unauthorized = httpx.Response(401, json={"error": "bad token"}, request=httpx.Request("GET", "https://example.com")) + client._client.request = AsyncMock(return_value=unauthorized) + + with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock, side_effect=RuntimeError("no refresh")): + with pytest.raises(httpx.HTTPStatusError): + await client._get("https://example.com/test") + + # Should only have tried once (no retry after failed refresh) + assert client._client.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_401_only_one_refresh_per_request() -> None: + """Only one refresh attempt per request — second 401 should not retry.""" + oauth_config = DrsConfig( + uri="https://api.dremio.cloud", + oauth_access_token="old-token", + project_id="proj", + auth_method="oauth", + ) + client = DremioClient(oauth_config) + + unauthorized = httpx.Response(401, request=httpx.Request("GET", "https://example.com")) + + # Both attempts return 401 + client._client.request = AsyncMock(return_value=unauthorized) + + with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock): + with pytest.raises(httpx.HTTPStatusError): + await client._get("https://example.com/test") + + # 1st call -> 401 -> refresh -> 2nd call -> 401 -> return (no more refresh) + assert client._client.request.call_count == 2 diff --git a/tests/test_commands/test_login.py b/tests/test_commands/test_login.py new file mode 100644 index 0000000..eb7fe18 --- /dev/null +++ b/tests/test_commands/test_login.py @@ -0,0 +1,102 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for dremio login / dremio logout commands.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from typer.testing import CliRunner + +from drs.cli import app +from drs.token_store import OAuthTokens, load, save + +runner = CliRunner() + + +def test_login_saves_tokens(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text("project_id: proj-1\n") + store_path = tmp_path / "oauth_tokens.yaml" + + fake_tokens = OAuthTokens( + access_token="at-new", + refresh_token="rt-new", + client_id="cid", + ) + + with ( + patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.commands.login.token_store.save") as mock_save, + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + patch("drs.commands.login._update_config_file"), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"]) + + assert result.exit_code == 0 + assert "successfully" in result.output.lower() or "logged in" in result.output.lower() + mock_save.assert_called_once() + saved_tokens = mock_save.call_args[0][1] + assert saved_tokens.access_token == "at-new" + + +def test_login_prompts_for_project_id(tmp_path: Path) -> None: + """When project_id is not in config, login should prompt for it.""" + config_path = tmp_path / "config.yaml" + # No project_id in config + config_path.write_text("uri: https://api.dremio.cloud\n") + + fake_tokens = OAuthTokens(access_token="at", client_id="cid") + + with ( + patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.commands.login.token_store.save"), + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + patch("drs.commands.login._update_config_file"), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"], input="my-project-id\n") + + assert result.exit_code == 0 + + +def test_logout_clears_tokens(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text("uri: https://api.dremio.cloud\nproject_id: proj-1\n") + + with ( + patch("drs.commands.login.token_store.clear") as mock_clear, + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + ): + result = runner.invoke(app, ["--config", str(config_path), "logout"]) + + assert result.exit_code == 0 + assert "logged out" in result.output.lower() or "removed" in result.output.lower() + mock_clear.assert_called_once_with("https://api.dremio.cloud") + + +def test_login_failure_exits_1(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text("project_id: proj-1\n") + + with ( + patch("drs.commands.login.oauth.run_login_flow", side_effect=RuntimeError("network error")), + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"]) + + assert result.exit_code == 1 + assert "failed" in result.output.lower() diff --git a/tests/test_commands/test_setup.py b/tests/test_commands/test_setup.py index 2bab050..697eeb7 100644 --- a/tests/test_commands/test_setup.py +++ b/tests/test_commands/test_setup.py @@ -27,6 +27,7 @@ from drs.auth import DEFAULT_URI from drs.cli import app from drs.commands.setup import validate_credentials, write_config +from drs.token_store import OAuthTokens runner = CliRunner() @@ -147,7 +148,7 @@ def test_write_config_includes_header(tmp_path) -> None: raw = config_path.read_text() assert raw.startswith("# Dremio CLI config") - assert "plaintext" in raw + assert "mode 600" in raw def test_setup_non_interactive(tmp_path) -> None: @@ -174,8 +175,8 @@ def test_setup_happy_path(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: region=1, PAT=test-pat, project_id=test-proj - result = runner.invoke(app, ["setup"], input="1\ntest-pat\ntest-proj\n") + # Input: region=1, auth=2(PAT), PAT=test-pat, project_id=test-proj + result = runner.invoke(app, ["setup"], input="1\n2\ntest-pat\ntest-proj\n") assert result.exit_code == 0 assert "Setup complete" in result.output @@ -219,8 +220,8 @@ def test_setup_existing_config_overwrite(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: overwrite=y, region=1, PAT=new-pat, project_id=new-proj - result = runner.invoke(app, ["setup"], input="y\n1\nnew-pat\nnew-proj\n") + # Input: overwrite=y, region=1, auth=2(PAT), PAT=new-pat, project_id=new-proj + result = runner.invoke(app, ["setup"], input="y\n1\n2\nnew-pat\nnew-proj\n") assert result.exit_code == 0 data = yaml.safe_load(config_path.read_text()) @@ -244,8 +245,8 @@ def test_setup_retry_then_abort(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: region=1, PAT=bad, project_id=p1, then decline retry - result = runner.invoke(app, ["setup"], input="1\nbad-pat\np1\nn\n") + # Input: region=1, auth=2(PAT), PAT=bad, project_id=p1, then decline retry + result = runner.invoke(app, ["setup"], input="1\n2\nbad-pat\np1\nn\n") assert result.exit_code == 1 assert "cancelled" in result.output.lower() @@ -265,9 +266,34 @@ def test_setup_global_config_passthrough(tmp_path) -> None: patch("drs.commands.setup.DremioClient", return_value=mock_client), ): mock_sys.stdin.isatty.return_value = True - result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\nmy-pat\nmy-proj\n") + result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\n2\nmy-pat\nmy-proj\n") assert result.exit_code == 0 assert config_path.exists() data = yaml.safe_load(config_path.read_text()) assert data["pat"] == "my-pat" + + +def test_setup_oauth_path(tmp_path) -> None: + """OAuth path through setup wizard: region, auth=1(OAuth), project_id.""" + config_path = tmp_path / "config.yaml" + + fake_tokens = OAuthTokens(access_token="at-oauth", client_id="cid") + + with ( + patch("drs.commands.setup.sys") as mock_sys, + patch("drs.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.token_store.save") as mock_save, + patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), + ): + mock_sys.stdin.isatty.return_value = True + # Input: region=1, auth=1(OAuth), project_id=my-proj + result = runner.invoke(app, ["setup"], input="1\n1\nmy-proj\n") + + assert result.exit_code == 0 + assert config_path.exists() + data = yaml.safe_load(config_path.read_text()) + assert data.get("auth_method") == "oauth" + assert "pat" not in data + assert data["project_id"] == "my-proj" + mock_save.assert_called_once() diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..0799ef9 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,267 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for drs.oauth — OAuth 2.0 PKCE flow mechanics.""" + +from __future__ import annotations + +import base64 +import hashlib +import urllib.request +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest + +from drs.oauth import ( + build_authorization_url, + discover, + exchange_code, + find_free_port, + generate_pkce, + refresh_access_token, + run_login_flow, + start_callback_server, +) + + +class TestPKCE: + def test_generate_pkce_format(self) -> None: + verifier, challenge = generate_pkce() + assert len(verifier) > 20 + assert len(challenge) > 20 + # No padding characters + assert "=" not in challenge + + def test_pkce_challenge_matches_verifier(self) -> None: + verifier, challenge = generate_pkce() + expected = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()).rstrip(b"=").decode("ascii") + assert challenge == expected + + def test_pkce_uniqueness(self) -> None: + v1, _ = generate_pkce() + v2, _ = generate_pkce() + assert v1 != v2 + + +class TestBuildAuthorizationURL: + def test_url_construction(self) -> None: + url = build_authorization_url( + auth_endpoint="https://auth.example.com/authorize", + client_id="my-client", + redirect_uri="http://localhost:8080/callback", + code_challenge="abc123", + state="state-xyz", + ) + parsed = urlparse(url) + params = parse_qs(parsed.query) + + assert parsed.scheme == "https" + assert parsed.netloc == "auth.example.com" + assert parsed.path == "/authorize" + assert params["response_type"] == ["code"] + assert params["client_id"] == ["my-client"] + assert params["redirect_uri"] == ["http://localhost:8080/callback"] + assert params["code_challenge"] == ["abc123"] + assert params["code_challenge_method"] == ["S256"] + assert params["scope"] == ["openid offline_access"] + assert params["state"] == ["state-xyz"] + + +class TestDiscover: + def test_discover_parses_metadata(self) -> None: + metadata_json = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + } + + mock_response = MagicMock() + mock_response.json.return_value = metadata_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.get", return_value=mock_response) as mock_get: + result = discover("https://api.dremio.cloud") + + mock_get.assert_called_once_with( + "https://api.dremio.cloud/.well-known/oauth-authorization-server", + timeout=30.0, + ) + assert result.authorization_endpoint == "https://auth.example.com/authorize" + assert result.token_endpoint == "https://auth.example.com/token" + assert result.registration_endpoint == "https://auth.example.com/register" + + def test_discover_optional_registration(self) -> None: + metadata_json = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + } + + mock_response = MagicMock() + mock_response.json.return_value = metadata_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.get", return_value=mock_response): + result = discover("https://api.dremio.cloud") + + assert result.registration_endpoint is None + + +class TestExchangeCode: + def test_exchange_code_success(self) -> None: + token_json = { + "access_token": "at-new", + "refresh_token": "rt-new", + "expires_in": 3600, + } + mock_response = MagicMock() + mock_response.json.return_value = token_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.post", return_value=mock_response): + tokens = exchange_code( + "https://auth.example.com/token", + "auth-code-123", + "http://localhost:8080/callback", + "my-client", + "my-secret", + "my-verifier", + ) + + assert tokens.access_token == "at-new" + assert tokens.refresh_token == "rt-new" + assert tokens.expires_at is not None + assert tokens.client_id == "my-client" + assert tokens.client_secret == "my-secret" + + +class TestRefreshAccessToken: + def test_refresh_success(self) -> None: + token_json = { + "access_token": "at-refreshed", + "refresh_token": "rt-new", + "expires_in": 3600, + } + mock_response = MagicMock() + mock_response.json.return_value = token_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.post", return_value=mock_response): + tokens = refresh_access_token( + "https://auth.example.com/token", + "my-client", + "my-secret", + "rt-old", + ) + + assert tokens.access_token == "at-refreshed" + assert tokens.refresh_token == "rt-new" + + def test_refresh_preserves_old_refresh_token(self) -> None: + """When server omits refresh_token in response, keep the old one.""" + token_json = { + "access_token": "at-refreshed", + "expires_in": 3600, + } + mock_response = MagicMock() + mock_response.json.return_value = token_json + mock_response.raise_for_status = MagicMock() + + with patch("drs.oauth.httpx.post", return_value=mock_response): + tokens = refresh_access_token( + "https://auth.example.com/token", + "my-client", + None, + "rt-old", + ) + + assert tokens.refresh_token == "rt-old" + + +class TestCallbackServer: + def test_callback_captures_code_and_state(self) -> None: + port = find_free_port() + server, future = start_callback_server(port) + try: + url = f"http://localhost:{port}/callback?code=test-code&state=test-state" + urllib.request.urlopen(url, timeout=5) + code, state = future.result(timeout=5) + assert code == "test-code" + assert state == "test-state" + finally: + server.shutdown() + + def test_find_free_port_returns_positive(self) -> None: + port = find_free_port() + assert port > 0 + + +class TestRunLoginFlow: + @pytest.mark.parametrize("browser_behavior", ["raises", "returns_false"]) + def test_headless_fallback_prints_url(self, browser_behavior: str) -> None: + """When webbrowser.open raises or returns False, the URL should be printed.""" + metadata_json = { + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + } + dcr_json = {"client_id": "cid", "client_secret": "cs"} + token_json = {"access_token": "at", "refresh_token": "rt", "expires_in": 3600} + + mock_get = MagicMock() + mock_get.json.return_value = metadata_json + mock_get.raise_for_status = MagicMock() + + post_responses = [] + # DCR call + dcr_resp = MagicMock() + dcr_resp.json.return_value = dcr_json + dcr_resp.raise_for_status = MagicMock() + post_responses.append(dcr_resp) + # Token exchange + token_resp = MagicMock() + token_resp.json.return_value = token_json + token_resp.raise_for_status = MagicMock() + post_responses.append(token_resp) + + if browser_behavior == "raises": + browser_mock = MagicMock(side_effect=RuntimeError("no browser")) + else: + browser_mock = MagicMock(return_value=False) + + with ( + patch("drs.oauth.httpx.get", return_value=mock_get), + patch("drs.oauth.httpx.post", side_effect=post_responses), + patch("drs.oauth.webbrowser.open", browser_mock), + patch("drs.oauth.start_callback_server") as mock_server, + patch("builtins.print") as mock_print, + ): + mock_future = MagicMock() + mock_future.result.return_value = ("code-123", None) + mock_srv = MagicMock() + mock_server.return_value = (mock_srv, mock_future) + + # The state won't match because we bypass the real flow, + # so we need to patch around the state check too. + try: + run_login_flow("https://api.dremio.cloud") + except RuntimeError: + pass # state mismatch expected in mocked test + + # Verify fallback print was called + mock_print.assert_called_once() + printed_text = mock_print.call_args[0][0] + assert "browser" in printed_text.lower() or "url" in printed_text.lower() or "http" in printed_text.lower() diff --git a/tests/test_token_store.py b/tests/test_token_store.py new file mode 100644 index 0000000..45c71eb --- /dev/null +++ b/tests/test_token_store.py @@ -0,0 +1,97 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for drs.token_store — OAuth token persistence.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from drs.token_store import OAuthTokens, clear, load, save + + +@pytest.fixture +def store_path(tmp_path: Path) -> Path: + return tmp_path / "oauth_tokens.yaml" + + +@pytest.fixture +def sample_tokens() -> OAuthTokens: + return OAuthTokens( + access_token="at-123", + refresh_token="rt-456", + expires_at=1700000000.0, + client_id="test-client", + client_secret="test-secret", + ) + + +def test_round_trip(store_path: Path, sample_tokens: OAuthTokens) -> None: + url = "https://api.dremio.cloud" + save(url, sample_tokens, store_path=store_path) + loaded = load(url, store_path=store_path) + assert loaded is not None + assert loaded == sample_tokens + + +def test_load_missing_url(store_path: Path, sample_tokens: OAuthTokens) -> None: + save("https://api.dremio.cloud", sample_tokens, store_path=store_path) + assert load("https://api.eu.dremio.cloud", store_path=store_path) is None + + +def test_load_nonexistent_file(store_path: Path) -> None: + assert load("https://api.dremio.cloud", store_path=store_path) is None + + +def test_clear_removes_entry(store_path: Path, sample_tokens: OAuthTokens) -> None: + url = "https://api.dremio.cloud" + save(url, sample_tokens, store_path=store_path) + clear(url, store_path=store_path) + assert load(url, store_path=store_path) is None + + +def test_clear_nonexistent_is_noop(store_path: Path) -> None: + clear("https://api.dremio.cloud", store_path=store_path) # should not raise + + +def test_file_mode_600(store_path: Path, sample_tokens: OAuthTokens) -> None: + save("https://api.dremio.cloud", sample_tokens, store_path=store_path) + assert oct(store_path.stat().st_mode & 0o777) == "0o600" + + +def test_multi_instance_keying(store_path: Path) -> None: + url_us = "https://api.dremio.cloud" + url_eu = "https://api.eu.dremio.cloud" + tokens_us = OAuthTokens(access_token="us-token", client_id="cid") + tokens_eu = OAuthTokens(access_token="eu-token", client_id="cid") + + save(url_us, tokens_us, store_path=store_path) + save(url_eu, tokens_eu, store_path=store_path) + + loaded_us = load(url_us, store_path=store_path) + loaded_eu = load(url_eu, store_path=store_path) + assert loaded_us is not None + assert loaded_eu is not None + assert loaded_us.access_token == "us-token" + assert loaded_eu.access_token == "eu-token" + + +def test_save_creates_parent_dirs(tmp_path: Path) -> None: + deep_path = tmp_path / "a" / "b" / "tokens.yaml" + tokens = OAuthTokens(access_token="at", client_id="cid") + save("https://api.dremio.cloud", tokens, store_path=deep_path) + assert deep_path.exists() From 96e2682f471ead284a1737a6100930fa1422301e Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:02:28 +0000 Subject: [PATCH 2/8] fix: add --uri option to login/logout commands (DX-118868) Allow targeting a custom Dremio URL (e.g. app.dev.dremio.site) instead of defaulting to app.dremio.cloud. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/commands/login.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/drs/commands/login.py b/src/drs/commands/login.py index bb9979f..9f7179c 100644 --- a/src/drs/commands/login.py +++ b/src/drs/commands/login.py @@ -30,11 +30,12 @@ err_console = Console(stderr=True) -def _resolve_uri(ctx: typer.Context) -> str: - """Resolve the Dremio API URI from CLI flags, config file, or default.""" - # Check CLI --uri - config_path_obj = ctx.obj.get("config_path") if ctx.obj else None - # Check if parent set cli_uri +def _resolve_uri(ctx: typer.Context, explicit_uri: str | None = None) -> str: + """Resolve the Dremio API URI from explicit arg, CLI flags, config file, or default.""" + if explicit_uri: + return explicit_uri + + # Check if parent set cli_uri (global --uri flag) from drs.cli import _cli_opts cli_uri = _cli_opts.get("cli_uri") @@ -42,6 +43,7 @@ def _resolve_uri(ctx: typer.Context) -> str: return cli_uri # Try config file + config_path_obj = ctx.obj.get("config_path") if ctx.obj else None path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH if path.exists(): with path.open() as f: @@ -74,9 +76,12 @@ def _resolve_project_id(ctx: typer.Context) -> str: return typer.prompt("Enter your Dremio Cloud Project ID").strip() -def login_command(ctx: typer.Context) -> None: +def login_command( + ctx: typer.Context, + uri: str = typer.Option(None, "--uri", "-u", help="Dremio API URL (e.g. https://app.dev.dremio.site)"), +) -> None: """Log in to Dremio Cloud via OAuth (opens your browser).""" - uri = _resolve_uri(ctx) + uri = _resolve_uri(ctx, explicit_uri=uri) console.print(f"\nLogging in to [bold]{uri}[/bold] ...") try: @@ -120,8 +125,11 @@ def _update_config_file(config_path: Path, uri: str, project_id: str) -> None: config_path.chmod(0o600) -def logout_command(ctx: typer.Context) -> None: +def logout_command( + ctx: typer.Context, + uri: str = typer.Option(None, "--uri", "-u", help="Dremio API URL to log out from"), +) -> None: """Log out of Dremio Cloud (removes stored OAuth tokens).""" - uri = _resolve_uri(ctx) + uri = _resolve_uri(ctx, explicit_uri=uri) token_store.clear(uri) console.print(f"Logged out of [bold]{uri}[/bold]. OAuth tokens removed.") From 9c30dac64abddf32a064da67dc8eed64c791634c Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:04:33 +0000 Subject: [PATCH 3/8] fix: derive login host from app URL for OAuth discovery (DX-118868) The .well-known/oauth-authorization-server endpoint lives on the login subdomain (login.X), not the app subdomain (app.X). Rewrite app.* to login.* when constructing the discovery URL. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/oauth.py | 12 +++++++++++- tests/test_oauth.py | 21 ++++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/drs/oauth.py b/src/drs/oauth.py index 67f22d2..1784267 100644 --- a/src/drs/oauth.py +++ b/src/drs/oauth.py @@ -55,9 +55,19 @@ class OAuthServerMetadata: registration_endpoint: str | None = None +def _login_url(dremio_url: str) -> str: + """Derive the OAuth login host from a Dremio URL (app.X -> login.X).""" + parsed = urlparse(dremio_url) + host = parsed.hostname or "" + if host.startswith("app."): + host = "login." + host[4:] + return f"{parsed.scheme}://{host}" + + def discover(dremio_url: str) -> OAuthServerMetadata: """Fetch OAuth Authorization Server Metadata from *dremio_url*.""" - url = f"{dremio_url}/.well-known/oauth-authorization-server" + base = _login_url(dremio_url) + url = f"{base}/.well-known/oauth-authorization-server" resp = httpx.get(url, timeout=30.0) resp.raise_for_status() data = resp.json() diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 0799ef9..bb3d2fe 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -81,6 +81,21 @@ def test_url_construction(self) -> None: assert params["state"] == ["state-xyz"] +class TestLoginUrl: + def test_app_to_login(self) -> None: + from drs.oauth import _login_url + + assert _login_url("https://app.dremio.cloud") == "https://login.dremio.cloud" + assert _login_url("https://app.dev.dremio.site") == "https://login.dev.dremio.site" + assert _login_url("https://app.eu.dremio.cloud") == "https://login.eu.dremio.cloud" + + def test_non_app_host_unchanged(self) -> None: + from drs.oauth import _login_url + + assert _login_url("https://login.dremio.cloud") == "https://login.dremio.cloud" + assert _login_url("https://custom.example.com") == "https://custom.example.com" + + class TestDiscover: def test_discover_parses_metadata(self) -> None: metadata_json = { @@ -94,10 +109,10 @@ def test_discover_parses_metadata(self) -> None: mock_response.raise_for_status = MagicMock() with patch("drs.oauth.httpx.get", return_value=mock_response) as mock_get: - result = discover("https://api.dremio.cloud") + result = discover("https://app.dremio.cloud") mock_get.assert_called_once_with( - "https://api.dremio.cloud/.well-known/oauth-authorization-server", + "https://login.dremio.cloud/.well-known/oauth-authorization-server", timeout=30.0, ) assert result.authorization_endpoint == "https://auth.example.com/authorize" @@ -115,7 +130,7 @@ def test_discover_optional_registration(self) -> None: mock_response.raise_for_status = MagicMock() with patch("drs.oauth.httpx.get", return_value=mock_response): - result = discover("https://api.dremio.cloud") + result = discover("https://app.dremio.cloud") assert result.registration_endpoint is None From 088551974d015bb74910e8eff9db98b3b1072d0c Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:06:24 +0000 Subject: [PATCH 4/8] fix: use client_name for DCR and fall back when DCR is rejected (DX-118868) - Send `client_name` (not `client_id`) in DCR body per RFC 7591 - Gracefully fall back to the well-known client_id when the server rejects DCR (400/403), as Dremio servers may not allow open registration Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/oauth.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/drs/oauth.py b/src/drs/oauth.py index 1784267..fb6f153 100644 --- a/src/drs/oauth.py +++ b/src/drs/oauth.py @@ -79,14 +79,20 @@ def discover(dremio_url: str) -> OAuthServerMetadata: def register_client(registration_endpoint: str, redirect_uri: str) -> tuple[str, str | None]: - """Dynamic Client Registration. Returns ``(client_id, client_secret)``.""" + """Dynamic Client Registration. Returns ``(client_id, client_secret)``. + + Returns ``None`` if the server does not support DCR (403/400). + """ body = { - "client_id": _CLIENT_ID, + "client_name": _CLIENT_ID, "redirect_uris": [redirect_uri], "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], } resp = httpx.post(registration_endpoint, json=body, timeout=30.0) + if resp.status_code in (400, 403): + logger.info("DCR not available (%s) — using well-known client_id.", resp.status_code) + return None, None resp.raise_for_status() data = resp.json() return data["client_id"], data.get("client_secret") @@ -241,9 +247,10 @@ def run_login_flow(dremio_url: str) -> OAuthTokens: port = find_free_port() redirect_uri = f"http://localhost:{port}/callback" + client_id, client_secret = None, None if metadata.registration_endpoint: client_id, client_secret = register_client(metadata.registration_endpoint, redirect_uri) - else: + if not client_id: client_id, client_secret = _CLIENT_ID, None code_verifier, code_challenge = generate_pkce() From 93e6c481386b0f2bc72a510371f11f3dfd4aaf46 Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:07:36 +0000 Subject: [PATCH 5/8] fix: use dremio.all scope instead of openid for OAuth (DX-118868) Dremio's OAuth server supports `dremio.all offline_access`, not `openid offline_access`. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/oauth.py | 2 +- tests/test_oauth.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/drs/oauth.py b/src/drs/oauth.py index fb6f153..07bc08a 100644 --- a/src/drs/oauth.py +++ b/src/drs/oauth.py @@ -120,7 +120,7 @@ def build_authorization_url( "redirect_uri": redirect_uri, "code_challenge": code_challenge, "code_challenge_method": "S256", - "scope": "openid offline_access", + "scope": "dremio.all offline_access", "state": state, } return f"{auth_endpoint}?{urlencode(params)}" diff --git a/tests/test_oauth.py b/tests/test_oauth.py index bb3d2fe..d88176f 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -77,7 +77,7 @@ def test_url_construction(self) -> None: assert params["redirect_uri"] == ["http://localhost:8080/callback"] assert params["code_challenge"] == ["abc123"] assert params["code_challenge_method"] == ["S256"] - assert params["scope"] == ["openid offline_access"] + assert params["scope"] == ["dremio.all offline_access"] assert params["state"] == ["state-xyz"] From da7baa17b7fadfb86d9ba89fbbe6fdd8cac2f507 Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:08:57 +0000 Subject: [PATCH 6/8] fix: use /Callback (capital C) for OAuth redirect URI (DX-118868) The Dremio OAuth server's allow-listed redirect path is case-sensitive: /Callback, not /callback. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/oauth.py | 2 +- tests/test_oauth.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/drs/oauth.py b/src/drs/oauth.py index 07bc08a..c78ecf9 100644 --- a/src/drs/oauth.py +++ b/src/drs/oauth.py @@ -245,7 +245,7 @@ def run_login_flow(dremio_url: str) -> OAuthTokens: metadata = discover(dremio_url) port = find_free_port() - redirect_uri = f"http://localhost:{port}/callback" + redirect_uri = f"http://localhost:{port}/Callback" client_id, client_secret = None, None if metadata.registration_endpoint: diff --git a/tests/test_oauth.py b/tests/test_oauth.py index d88176f..0339ffc 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -62,7 +62,7 @@ def test_url_construction(self) -> None: url = build_authorization_url( auth_endpoint="https://auth.example.com/authorize", client_id="my-client", - redirect_uri="http://localhost:8080/callback", + redirect_uri="http://localhost:8080/Callback", code_challenge="abc123", state="state-xyz", ) @@ -74,7 +74,7 @@ def test_url_construction(self) -> None: assert parsed.path == "/authorize" assert params["response_type"] == ["code"] assert params["client_id"] == ["my-client"] - assert params["redirect_uri"] == ["http://localhost:8080/callback"] + assert params["redirect_uri"] == ["http://localhost:8080/Callback"] assert params["code_challenge"] == ["abc123"] assert params["code_challenge_method"] == ["S256"] assert params["scope"] == ["dremio.all offline_access"] @@ -150,7 +150,7 @@ def test_exchange_code_success(self) -> None: tokens = exchange_code( "https://auth.example.com/token", "auth-code-123", - "http://localhost:8080/callback", + "http://localhost:8080/Callback", "my-client", "my-secret", "my-verifier", @@ -211,7 +211,7 @@ def test_callback_captures_code_and_state(self) -> None: port = find_free_port() server, future = start_callback_server(port) try: - url = f"http://localhost:{port}/callback?code=test-code&state=test-state" + url = f"http://localhost:{port}/Callback?code=test-code&state=test-state" urllib.request.urlopen(url, timeout=5) code, state = future.result(timeout=5) assert code == "test-code" From b0331845662eef1a2e0e76baf0d63967305cfd09 Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:11:36 +0000 Subject: [PATCH 7/8] feat: show interactive project picker during dremio login (DX-118868) After OAuth login succeeds, fetch the project list and let the user pick from a numbered menu instead of typing a project ID manually. Auto-selects when only one project exists. Falls back to manual prompt if the project list fetch fails. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/commands/login.py | 72 ++++++++++++++++++++++++++++--- tests/test_commands/test_login.py | 40 ++++++++++++++--- 2 files changed, 102 insertions(+), 10 deletions(-) diff --git a/src/drs/commands/login.py b/src/drs/commands/login.py index 9f7179c..982964e 100644 --- a/src/drs/commands/login.py +++ b/src/drs/commands/login.py @@ -18,10 +18,13 @@ from __future__ import annotations from pathlib import Path +from urllib.parse import urlparse +import httpx import typer import yaml from rich.console import Console +from rich.panel import Panel from drs import oauth, token_store from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI @@ -55,8 +58,51 @@ def _resolve_uri(ctx: typer.Context, explicit_uri: str | None = None) -> str: return DEFAULT_URI -def _resolve_project_id(ctx: typer.Context) -> str: - """Resolve project_id from CLI flags, config file, or prompt the user.""" +def _api_url(uri: str) -> str: + """Derive the API URL from a Dremio URL (app.X -> api.X).""" + parsed = urlparse(uri) + host = parsed.hostname or "" + if host.startswith("app."): + host = "api." + host[4:] + return f"{parsed.scheme}://{host}" + + +def _fetch_projects(api_base: str, access_token: str) -> list[dict]: + """Fetch the list of projects using the OAuth access token.""" + resp = httpx.get( + f"{api_base}/v0/projects", + headers={"Authorization": f"Bearer {access_token}"}, + timeout=30.0, + ) + resp.raise_for_status() + data = resp.json() + return data.get("data", data) if isinstance(data, dict) else data + + +def _prompt_project_selection(projects: list[dict]) -> str: + """Display a numbered list of projects and let the user choose.""" + console.print() + lines = "[bold]Select a project:[/bold]\n" + for i, proj in enumerate(projects, 1): + name = proj.get("name", "unnamed") + pid = proj.get("id", "") + lines += f"\n [cyan]{i}[/cyan]) {name} [dim]({pid})[/dim]" + console.print(Panel(lines, title="Projects", border_style="blue")) + choice = typer.prompt(f"Enter 1-{len(projects)}").strip() + try: + idx = int(choice) - 1 + if 0 <= idx < len(projects): + selected = projects[idx] + console.print(f" -> [bold]{selected.get('name')}[/bold]") + return selected["id"] + except (ValueError, KeyError): + pass + err_console.print("[yellow]Invalid choice — please enter a project ID manually.[/yellow]") + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + +def _resolve_project_id(ctx: typer.Context, uri: str, access_token: str) -> str: + """Resolve project_id from CLI flags, config, or interactive project picker.""" from drs.cli import _cli_opts cli_project_id = _cli_opts.get("cli_project_id") @@ -72,8 +118,24 @@ def _resolve_project_id(ctx: typer.Context) -> str: if project_id: return project_id - # Prompt user - return typer.prompt("Enter your Dremio Cloud Project ID").strip() + # Fetch projects and let the user pick + api_base = _api_url(uri) + try: + projects = _fetch_projects(api_base, access_token) + except Exception: + console.print("[yellow]Could not fetch project list.[/yellow]") + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + if not projects: + console.print("[yellow]No projects found in this organization.[/yellow]") + return typer.prompt("Enter your Dremio Cloud Project ID").strip() + + if len(projects) == 1: + proj = projects[0] + console.print(f" Auto-selected project: [bold]{proj.get('name')}[/bold] ({proj['id']})") + return proj["id"] + + return _prompt_project_selection(projects) def login_command( @@ -91,7 +153,7 @@ def login_command( raise typer.Exit(1) # Ensure we have a project_id to write into the config - project_id = _resolve_project_id(ctx) + project_id = _resolve_project_id(ctx, uri, tokens.access_token) token_store.save(uri, tokens) diff --git a/tests/test_commands/test_login.py b/tests/test_commands/test_login.py index eb7fe18..bf20f2b 100644 --- a/tests/test_commands/test_login.py +++ b/tests/test_commands/test_login.py @@ -54,23 +54,53 @@ def test_login_saves_tokens(tmp_path: Path) -> None: assert saved_tokens.access_token == "at-new" -def test_login_prompts_for_project_id(tmp_path: Path) -> None: - """When project_id is not in config, login should prompt for it.""" +def test_login_picks_project_from_list(tmp_path: Path) -> None: + """When project_id is not in config, login should show project list.""" config_path = tmp_path / "config.yaml" - # No project_id in config config_path.write_text("uri: https://api.dremio.cloud\n") fake_tokens = OAuthTokens(access_token="at", client_id="cid") + fake_projects = [ + {"id": "proj-aaa", "name": "Alpha"}, + {"id": "proj-bbb", "name": "Beta"}, + ] with ( patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), patch("drs.commands.login.token_store.save"), patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), - patch("drs.commands.login._update_config_file"), + patch("drs.commands.login._update_config_file") as mock_update, + patch("drs.commands.login._fetch_projects", return_value=fake_projects), + ): + result = runner.invoke(app, ["--config", str(config_path), "login"], input="2\n") + + assert result.exit_code == 0 + assert "Beta" in result.output + mock_update.assert_called_once() + assert mock_update.call_args[0][2] == "proj-bbb" + + +def test_login_auto_selects_single_project(tmp_path: Path) -> None: + """When only one project exists, auto-select it.""" + config_path = tmp_path / "config.yaml" + config_path.write_text("uri: https://api.dremio.cloud\n") + + fake_tokens = OAuthTokens(access_token="at", client_id="cid") + fake_projects = [{"id": "proj-only", "name": "OnlyProject"}] + + with ( + patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens), + patch("drs.commands.login.token_store.save"), + patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path), + patch("drs.commands.login._update_config_file") as mock_update, + patch("drs.commands.login._fetch_projects", return_value=fake_projects), ): - result = runner.invoke(app, ["--config", str(config_path), "login"], input="my-project-id\n") + result = runner.invoke(app, ["--config", str(config_path), "login"]) assert result.exit_code == 0 + assert "OnlyProject" in result.output + mock_update.assert_called_once() + assert mock_update.call_args[0][2] == "proj-only" def test_logout_clears_tokens(tmp_path: Path) -> None: From 840295d1dd39f00d1ffc04a64493362be2742870 Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 15 Apr 2026 02:15:59 +0000 Subject: [PATCH 8/8] feat: filter active/hibernated projects and show details in picker (DX-118868) - Only show ACTIVE and HIBERNATED projects (skip DELETED etc.) - Display project name, description, state, and creation date - Richer formatting in the project selection panel Co-Authored-By: Claude Opus 4.6 (1M context) --- src/drs/commands/login.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/drs/commands/login.py b/src/drs/commands/login.py index 982964e..35dbacb 100644 --- a/src/drs/commands/login.py +++ b/src/drs/commands/login.py @@ -67,8 +67,11 @@ def _api_url(uri: str) -> str: return f"{parsed.scheme}://{host}" +_ACTIVE_STATES = {"ACTIVE", "HIBERNATED"} + + def _fetch_projects(api_base: str, access_token: str) -> list[dict]: - """Fetch the list of projects using the OAuth access token.""" + """Fetch active/hibernated projects using the OAuth access token.""" resp = httpx.get( f"{api_base}/v0/projects", headers={"Authorization": f"Bearer {access_token}"}, @@ -76,7 +79,15 @@ def _fetch_projects(api_base: str, access_token: str) -> list[dict]: ) resp.raise_for_status() data = resp.json() - return data.get("data", data) if isinstance(data, dict) else data + projects = data.get("data", data) if isinstance(data, dict) else data + return [p for p in projects if p.get("state", "").upper() in _ACTIVE_STATES] + + +def _format_date(raw: str | None) -> str: + """Format an ISO timestamp to a short date string.""" + if not raw: + return "" + return raw[:10] # YYYY-MM-DD def _prompt_project_selection(projects: list[dict]) -> str: @@ -86,7 +97,15 @@ def _prompt_project_selection(projects: list[dict]) -> str: for i, proj in enumerate(projects, 1): name = proj.get("name", "unnamed") pid = proj.get("id", "") - lines += f"\n [cyan]{i}[/cyan]) {name} [dim]({pid})[/dim]" + desc = proj.get("description", "") + state = proj.get("state", "") + created = _format_date(proj.get("createdAt")) + lines += f"\n [cyan]{i}[/cyan]) [bold]{name}[/bold] [dim]({pid})[/dim]" + if desc: + lines += f"\n {desc}" + details = [s for s in [state, f"created {created}" if created else ""] if s] + if details: + lines += f"\n [dim]{' · '.join(details)}[/dim]" console.print(Panel(lines, title="Projects", border_style="blue")) choice = typer.prompt(f"Enter 1-{len(projects)}").strip() try: