diff --git a/src/drs/auth.py b/src/drs/auth.py
index 051f365..68ea573 100644
--- a/src/drs/auth.py
+++ b/src/drs/auth.py
@@ -19,10 +19,10 @@
import os
from pathlib import Path
-from typing import Any
+from typing import Any, Literal
import yaml
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "dremioai" / "config.yaml"
DEFAULT_URI = "https://api.dremio.cloud"
@@ -30,8 +30,16 @@
class DrsConfig(BaseModel):
uri: str = DEFAULT_URI
- pat: str
+ pat: str | None = None
project_id: str
+ auth_method: Literal["pat", "oauth"] = "pat"
+ oauth_access_token: str | None = None
+
+ @model_validator(mode="after")
+ def _require_credential(self) -> "DrsConfig":
+ if not self.pat and not self.oauth_access_token:
+ raise ValueError("Either 'pat' or 'oauth_access_token' must be provided.")
+ return self
def load_config(
@@ -58,6 +66,7 @@ def load_config(
"uri": raw.get("uri", raw.get("endpoint")),
"pat": raw.get("pat", raw.get("token")),
"project_id": raw.get("project_id", raw.get("projectId")),
+ "auth_method": raw.get("auth_method"),
}
file_values = {k: v for k, v in file_values.items() if v is not None}
@@ -83,4 +92,15 @@ def load_config(
if cli_token:
merged["pat"] = cli_token
+ # If no PAT is available, try loading OAuth tokens from the token store.
+ if "pat" not in merged or not merged["pat"]:
+ from drs import token_store
+
+ uri = merged.get("uri", DEFAULT_URI)
+ tokens = token_store.load(uri)
+ if tokens is not None:
+ merged["auth_method"] = "oauth"
+ merged["oauth_access_token"] = tokens.access_token
+ merged.pop("pat", None)
+
return DrsConfig(**merged)
diff --git a/src/drs/cli.py b/src/drs/cli.py
index 8a86e4d..34ece06 100644
--- a/src/drs/cli.py
+++ b/src/drs/cli.py
@@ -34,6 +34,7 @@
folder,
grant,
job,
+ login,
project,
query,
reflection,
@@ -69,6 +70,8 @@
app.add_typer(project.app, name="project")
app.add_typer(chat.app, name="chat")
app.command("setup")(setup.setup_command)
+app.command("login")(login.login_command)
+app.command("logout")(login.logout_command)
# Global state for config
_config: DrsConfig | None = None
@@ -137,8 +140,9 @@ def get_config() -> DrsConfig:
Console(stderr=True).print(
"\n[bold red]Configuration required[/bold red]\n\n"
- "The Dremio CLI needs a Personal Access Token and Project ID.\n\n"
- " [bold]Quick setup:[/] Run [bold cyan]dremio setup[/bold cyan]\n\n"
+ "The Dremio CLI needs authentication credentials and a Project ID.\n\n"
+ " [bold]Quick setup:[/] Run [bold cyan]dremio setup[/bold cyan]\n"
+ " [bold]OAuth login:[/] Run [bold cyan]dremio login[/bold cyan]\n\n"
" [dim]Or provide credentials manually:[/dim]\n"
" --token / DREMIO_TOKEN env var\n"
" --project-id / DREMIO_PROJECT_ID env var\n"
diff --git a/src/drs/client.py b/src/drs/client.py
index e1ab86a..10d2a1d 100644
--- a/src/drs/client.py
+++ b/src/drs/client.py
@@ -44,9 +44,10 @@ class DremioClient:
def __init__(self, config: DrsConfig) -> None:
self.config = config
+ token = config.oauth_access_token if "oauth" == config.auth_method else config.pat
self._client = httpx.AsyncClient(
headers={
- "Authorization": f"Bearer {config.pat}",
+ "Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
timeout=120.0,
@@ -70,12 +71,46 @@ def _v1(self, path: str) -> str:
# -- HTTP helpers with retry --
+ async def _refresh_oauth_token(self) -> None:
+ """Refresh the OAuth access token using the stored refresh token."""
+ from drs import oauth, token_store
+
+ tokens = token_store.load(self.config.uri)
+ if tokens is None or tokens.refresh_token is None:
+ raise RuntimeError("No refresh token available — please run 'dremio login' again.")
+
+ metadata = oauth.discover(self.config.uri)
+ new_tokens = oauth.refresh_access_token(
+ metadata.token_endpoint,
+ tokens.client_id,
+ tokens.client_secret,
+ tokens.refresh_token,
+ )
+ token_store.save(self.config.uri, new_tokens)
+ self._client.headers["authorization"] = f"Bearer {new_tokens.access_token}"
+ logger.info("OAuth access token refreshed successfully.")
+
async def _request_with_retry(self, method: str, url: str, **kwargs: Any) -> httpx.Response:
"""Execute an HTTP request with retry on transient errors."""
+ auth_refreshed = False
last_exc: Exception | None = None
for attempt in range(_MAX_RETRIES):
try:
resp = await self._client.request(method, url, **kwargs)
+ # 401 with OAuth: attempt one token refresh per call
+ if (
+ resp.status_code == 401
+ and "oauth" == self.config.auth_method
+ and not auth_refreshed
+ ):
+ logger.info("Received 401 — attempting OAuth token refresh.")
+ try:
+ await self._refresh_oauth_token()
+ except Exception:
+ logger.warning("OAuth token refresh failed.", exc_info=True)
+ return resp
+ auth_refreshed = True
+ continue
if resp.status_code in _RETRYABLE_STATUS_CODES and attempt < _MAX_RETRIES - 1:
delay = _RETRY_BACKOFF[attempt]
logger.warning(
diff --git a/src/drs/commands/login.py b/src/drs/commands/login.py
new file mode 100644
index 0000000..35dbacb
--- /dev/null
+++ b/src/drs/commands/login.py
@@ -0,0 +1,216 @@
+#
+# Copyright (C) 2017-2026 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""``dremio login`` and ``dremio logout`` commands — OAuth browser flow."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from urllib.parse import urlparse
+
+import httpx
+import typer
+import yaml
+from rich.console import Console
+from rich.panel import Panel
+
+from drs import oauth, token_store
+from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI
+
+console = Console()
+err_console = Console(stderr=True)
+
+
+def _resolve_uri(ctx: typer.Context, explicit_uri: str | None = None) -> str:
+ """Resolve the Dremio API URI from explicit arg, CLI flags, config file, or default."""
+ if explicit_uri:
+ return explicit_uri
+
+ # Check if parent set cli_uri (global --uri flag)
+ from drs.cli import _cli_opts
+
+ cli_uri = _cli_opts.get("cli_uri")
+ if cli_uri:
+ return cli_uri
+
+ # Try config file
+ config_path_obj = ctx.obj.get("config_path") if ctx.obj else None
+ path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH
+ if path.exists():
+ with path.open() as f:
+ raw = yaml.safe_load(f) or {}
+ uri = raw.get("uri", raw.get("endpoint"))
+ if uri:
+ return uri
+
+ return DEFAULT_URI
+
+
+def _api_url(uri: str) -> str:
+ """Derive the API URL from a Dremio URL (app.X -> api.X)."""
+ parsed = urlparse(uri)
+ host = parsed.hostname or ""
+ if host.startswith("app."):
+ host = "api." + host[4:]
+ return f"{parsed.scheme}://{host}"
+
+
+_ACTIVE_STATES = {"ACTIVE", "HIBERNATED"}
+
+
+def _fetch_projects(api_base: str, access_token: str) -> list[dict]:
+ """Fetch active/hibernated projects using the OAuth access token."""
+ resp = httpx.get(
+ f"{api_base}/v0/projects",
+ headers={"Authorization": f"Bearer {access_token}"},
+ timeout=30.0,
+ )
+ resp.raise_for_status()
+ data = resp.json()
+ projects = data.get("data", data) if isinstance(data, dict) else data
+ return [p for p in projects if p.get("state", "").upper() in _ACTIVE_STATES]
+
+
+def _format_date(raw: str | None) -> str:
+ """Format an ISO timestamp to a short date string."""
+ if not raw:
+ return ""
+ return raw[:10] # YYYY-MM-DD
+
+
+def _prompt_project_selection(projects: list[dict]) -> str:
+ """Display a numbered list of projects and let the user choose."""
+ console.print()
+ lines = "[bold]Select a project:[/bold]\n"
+ for i, proj in enumerate(projects, 1):
+ name = proj.get("name", "unnamed")
+ pid = proj.get("id", "")
+ desc = proj.get("description", "")
+ state = proj.get("state", "")
+ created = _format_date(proj.get("createdAt"))
+ lines += f"\n [cyan]{i}[/cyan]) [bold]{name}[/bold] [dim]({pid})[/dim]"
+ if desc:
+ lines += f"\n {desc}"
+ details = [s for s in [state, f"created {created}" if created else ""] if s]
+ if details:
+ lines += f"\n [dim]{' · '.join(details)}[/dim]"
+ console.print(Panel(lines, title="Projects", border_style="blue"))
+ choice = typer.prompt(f"Enter 1-{len(projects)}").strip()
+ try:
+ idx = int(choice) - 1
+ if 0 <= idx < len(projects):
+ selected = projects[idx]
+ console.print(f" -> [bold]{selected.get('name')}[/bold]")
+ return selected["id"]
+ except (ValueError, KeyError):
+ pass
+ err_console.print("[yellow]Invalid choice — please enter a project ID manually.[/yellow]")
+ return typer.prompt("Enter your Dremio Cloud Project ID").strip()
+
+
+def _resolve_project_id(ctx: typer.Context, uri: str, access_token: str) -> str:
+ """Resolve project_id from CLI flags, config, or interactive project picker."""
+ from drs.cli import _cli_opts
+
+ cli_project_id = _cli_opts.get("cli_project_id")
+ if cli_project_id:
+ return cli_project_id
+
+ config_path_obj = ctx.obj.get("config_path") if ctx.obj else None
+ path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH
+ if path.exists():
+ with path.open() as f:
+ raw = yaml.safe_load(f) or {}
+ project_id = raw.get("project_id", raw.get("projectId"))
+ if project_id:
+ return project_id
+
+ # Fetch projects and let the user pick
+ api_base = _api_url(uri)
+ try:
+ projects = _fetch_projects(api_base, access_token)
+ except Exception:
+ console.print("[yellow]Could not fetch project list.[/yellow]")
+ return typer.prompt("Enter your Dremio Cloud Project ID").strip()
+
+ if not projects:
+ console.print("[yellow]No projects found in this organization.[/yellow]")
+ return typer.prompt("Enter your Dremio Cloud Project ID").strip()
+
+ if len(projects) == 1:
+ proj = projects[0]
+ console.print(f" Auto-selected project: [bold]{proj.get('name')}[/bold] ({proj['id']})")
+ return proj["id"]
+
+ return _prompt_project_selection(projects)
+
+
+def login_command(
+ ctx: typer.Context,
+ uri: str = typer.Option(None, "--uri", "-u", help="Dremio API URL (e.g. https://app.dev.dremio.site)"),
+) -> None:
+ """Log in to Dremio Cloud via OAuth (opens your browser)."""
+ uri = _resolve_uri(ctx, explicit_uri=uri)
+ console.print(f"\nLogging in to [bold]{uri}[/bold] ...")
+
+ try:
+ tokens = oauth.run_login_flow(uri)
+ except Exception as exc:
+ err_console.print(f"\n[bold red]Login failed:[/bold red] {exc}")
+ raise typer.Exit(1)
+
+ # Ensure we have a project_id to write into the config
+ project_id = _resolve_project_id(ctx, uri, tokens.access_token)
+
+ token_store.save(uri, tokens)
+
+ # Also persist auth_method + project_id in config file so subsequent
+ # commands pick up OAuth automatically.
+ config_path_obj = ctx.obj.get("config_path") if ctx.obj else None
+ config_path: Path = config_path_obj if config_path_obj else DEFAULT_CONFIG_PATH
+ _update_config_file(config_path, uri, project_id)
+
+ console.print(f"\n[green]Logged in successfully.[/green] Tokens saved for {uri}")
+
+
+def _update_config_file(config_path: Path, uri: str, project_id: str) -> None:
+ """Ensure the config file records auth_method=oauth and project_id."""
+ data: dict = {}
+ if config_path.exists():
+ with config_path.open() as f:
+ data = yaml.safe_load(f) or {}
+
+ if uri != DEFAULT_URI:
+ data["uri"] = uri
+ data["auth_method"] = "oauth"
+ data["project_id"] = project_id
+ # Remove PAT if present — OAuth replaces it.
+ data.pop("pat", None)
+ data.pop("token", None)
+
+ config_path.parent.mkdir(parents=True, exist_ok=True)
+ header = "# Dremio CLI config — generated by 'dremio login'\n"
+ config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False))
+ config_path.chmod(0o600)
+
+
+def logout_command(
+ ctx: typer.Context,
+ uri: str = typer.Option(None, "--uri", "-u", help="Dremio API URL to log out from"),
+) -> None:
+ """Log out of Dremio Cloud (removes stored OAuth tokens)."""
+ uri = _resolve_uri(ctx, explicit_uri=uri)
+ token_store.clear(uri)
+ console.print(f"Logged out of [bold]{uri}[/bold]. OAuth tokens removed.")
diff --git a/src/drs/commands/setup.py b/src/drs/commands/setup.py
index fb77ea2..9775e78 100644
--- a/src/drs/commands/setup.py
+++ b/src/drs/commands/setup.py
@@ -32,6 +32,11 @@
from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI, DrsConfig
from drs.client import DremioClient
+AUTH_METHODS = {
+ "1": "oauth",
+ "2": "pat",
+}
+
REGIONS = {
"1": ("US", "https://api.dremio.cloud", "https://app.dremio.cloud"),
"2": ("EU", "https://api.eu.dremio.cloud", "https://app.eu.dremio.cloud"),
@@ -65,16 +70,25 @@ async def validate_credentials(uri: str, pat: str, project_id: str) -> tuple[boo
await client.close()
-def write_config(uri: str, pat: str, project_id: str, config_path: Path) -> None:
+def write_config(
+ uri: str,
+ pat: str | None,
+ project_id: str,
+ config_path: Path,
+ auth_method: str = "pat",
+) -> None:
"""Write YAML config file, creating parent directories as needed."""
data: dict[str, str] = {}
if uri != DEFAULT_URI:
data["uri"] = uri
- data["pat"] = pat
+ if auth_method != "pat":
+ data["auth_method"] = auth_method
+ if pat:
+ data["pat"] = pat
data["project_id"] = project_id
config_path.parent.mkdir(parents=True, exist_ok=True)
- header = "# Dremio CLI config — generated by 'dremio setup'\n# PAT is stored in plaintext. Keep this file private (mode 600).\n"
+ header = "# Dremio CLI config — generated by 'dremio setup'\n# Keep this file private (mode 600).\n"
config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False))
config_path.chmod(0o600)
@@ -100,12 +114,33 @@ def _prompt_region() -> tuple[str, str]:
return api_uri, app_url
+def _prompt_auth_method() -> str:
+ """Prompt for authentication method. Returns 'oauth' or 'pat'."""
+ console.print()
+ console.print(
+ Panel(
+ "[bold]Step 2: Choose authentication method[/bold]\n\n"
+ " [cyan]1[/cyan]) OAuth (browser login) — recommended\n"
+ " [cyan]2[/cyan]) PAT (manual personal access token)",
+ title="Authentication",
+ border_style="blue",
+ )
+ )
+ choice = typer.prompt("Enter 1 or 2", default="1").strip()
+ if choice not in AUTH_METHODS:
+ console.print("[yellow]Invalid choice, defaulting to OAuth.[/yellow]")
+ choice = "1"
+ method = AUTH_METHODS[choice]
+ console.print(f" -> Auth method: [bold]{method.upper()}[/bold]")
+ return method
+
+
def _prompt_pat(app_url: str) -> str:
"""Prompt for Personal Access Token with step-by-step instructions."""
console.print()
console.print(
Panel(
- "[bold]Step 2: Create a Personal Access Token (PAT)[/bold]\n\n"
+ "[bold]Step 3: Create a Personal Access Token (PAT)[/bold]\n\n"
f" 1. Open [link={app_url}]{app_url}[/link] and sign in\n"
" 2. Click your profile icon (bottom-left) → [bold]Account Settings[/bold]\n"
" 3. Go to [bold]Personal Access Tokens[/bold]\n"
@@ -127,7 +162,7 @@ def _prompt_project_id(app_url: str) -> str:
console.print()
console.print(
Panel(
- "[bold]Step 3: Find your Project ID[/bold]\n\n"
+ "[bold]Find your Project ID[/bold]\n\n"
f" 1. Open [link={app_url}]{app_url}[/link]\n"
" 2. Select your project from the top-left dropdown\n"
" 3. Go to [bold]Project Settings[/bold] → [bold]General[/bold]\n"
@@ -174,7 +209,7 @@ def setup_command(
"This wizard will help you connect the Dremio CLI to your Dremio Cloud account.\n\n"
"You'll need:\n"
" • A [bold]Dremio Cloud account[/bold] (sign up at [link=https://app.dremio.cloud]app.dremio.cloud[/link])\n"
- " • A [bold]Personal Access Token[/bold] (we'll walk you through creating one)\n"
+ " • A [bold]Personal Access Token[/bold] or [bold]OAuth browser login[/bold]\n"
" • A [bold]Project ID[/bold] (we'll show you where to find it)",
title="[bold]Dremio CLI Setup[/bold]",
border_style="cyan",
@@ -191,53 +226,79 @@ def setup_command(
# Step 1: Region
api_uri, app_url = _prompt_region()
- # Step 2: PAT (with retry loop)
- pat = _prompt_pat(app_url)
+ # Step 2: Auth method
+ auth_method = _prompt_auth_method()
- # Step 3: Project ID (with retry loop)
- project_id = _prompt_project_id(app_url)
+ pat: str | None = None
+ project_name: str = ""
- # Validate
- console.print()
- with console.status("[bold]Validating credentials...[/bold]", spinner="dots"):
- ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id))
+ if auth_method == "oauth":
+ # OAuth flow — browser login, then project ID
+ from drs import oauth, token_store
- while not ok:
- console.print(f"\n[red]✗ {message}[/red]")
- if not typer.confirm("Would you like to try again?", default=True):
- console.print("Setup cancelled.")
+ console.print("\n[bold]Starting OAuth browser login...[/bold]")
+ try:
+ tokens = oauth.run_login_flow(api_uri)
+ except Exception as exc:
+ err_console.print(f"\n[bold red]OAuth login failed:[/bold red] {exc}")
raise typer.Exit(1)
- if "Authentication" in message:
- console.print("[dim]Let's try the PAT again.[/dim]")
- pat = _prompt_pat(app_url)
- elif "Access denied" in message:
- console.print(
- "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID"
- )
- choice = typer.prompt("Which would you like to fix?", default="1").strip()
- if choice == "2":
- project_id = _prompt_project_id(app_url)
- else:
- pat = _prompt_pat(app_url)
- elif "Project" in message:
- console.print("[dim]Let's try the Project ID again.[/dim]")
- project_id = _prompt_project_id(app_url)
- else:
- console.print("[dim]Let's try the region again.[/dim]")
- api_uri, app_url = _prompt_region()
- pat = _prompt_pat(app_url)
- project_id = _prompt_project_id(app_url)
+ token_store.save(api_uri, tokens)
+ console.print("[green]OAuth login successful.[/green]")
+
+ # Step 3: Project ID
+ project_id = _prompt_project_id(app_url)
+ project_name = project_id
+ write_config(api_uri, None, project_id, config_path, auth_method="oauth")
+ else:
+ # PAT flow (existing behavior)
+ pat = _prompt_pat(app_url)
+
+ # Step 3: Project ID (with retry loop)
+ project_id = _prompt_project_id(app_url)
+
+ # Validate
console.print()
with console.status("[bold]Validating credentials...[/bold]", spinner="dots"):
ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id))
- # Success — write config
- project_name = project_data.get("name", project_id) if project_data else project_id
- console.print(f"\n[green]✓ {message}[/green]")
+ while not ok:
+ console.print(f"\n[red]x {message}[/red]")
+ if not typer.confirm("Would you like to try again?", default=True):
+ console.print("Setup cancelled.")
+ raise typer.Exit(1)
+
+ if "Authentication" in message:
+ console.print("[dim]Let's try the PAT again.[/dim]")
+ pat = _prompt_pat(app_url)
+ elif "Access denied" in message:
+ console.print(
+ "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID"
+ )
+ choice = typer.prompt("Which would you like to fix?", default="1").strip()
+ if choice == "2":
+ project_id = _prompt_project_id(app_url)
+ else:
+ pat = _prompt_pat(app_url)
+ elif "Project" in message:
+ console.print("[dim]Let's try the Project ID again.[/dim]")
+ project_id = _prompt_project_id(app_url)
+ else:
+ console.print("[dim]Let's try the region again.[/dim]")
+ api_uri, app_url = _prompt_region()
+ pat = _prompt_pat(app_url)
+ project_id = _prompt_project_id(app_url)
+
+ console.print()
+ with console.status("[bold]Validating credentials...[/bold]", spinner="dots"):
+ ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id))
+
+ # Success
+ project_name = project_data.get("name", project_id) if project_data else project_id
+ console.print(f"\n[green]v {message}[/green]")
- write_config(api_uri, pat, project_id, config_path)
+ write_config(api_uri, pat, project_id, config_path)
console.print()
success = Text()
diff --git a/src/drs/oauth.py b/src/drs/oauth.py
new file mode 100644
index 0000000..c78ecf9
--- /dev/null
+++ b/src/drs/oauth.py
@@ -0,0 +1,292 @@
+#
+# Copyright (C) 2017-2026 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""OAuth 2.0 PKCE browser-login flow for Dremio Cloud."""
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import logging
+import secrets
+import socket
+import threading
+import time
+import webbrowser
+from concurrent.futures import Future
+from dataclasses import dataclass
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from urllib.parse import parse_qs, urlencode, urlparse
+
+import httpx
+
+from drs.token_store import OAuthTokens
+
+logger = logging.getLogger(__name__)
+
+_CLIENT_ID = "https://connectors.dremio.app/claude"
+
+_SUCCESS_HTML = """\
+
+
Dremio CLI
+
+Login successful
+You can close this tab and return to the terminal.
+
+"""
+
+
+@dataclass
+class OAuthServerMetadata:
+ authorization_endpoint: str
+ token_endpoint: str
+ registration_endpoint: str | None = None
+
+
+def _login_url(dremio_url: str) -> str:
+ """Derive the OAuth login host from a Dremio URL (app.X -> login.X)."""
+ parsed = urlparse(dremio_url)
+ host = parsed.hostname or ""
+ if host.startswith("app."):
+ host = "login." + host[4:]
+ return f"{parsed.scheme}://{host}"
+
+
+def discover(dremio_url: str) -> OAuthServerMetadata:
+ """Fetch OAuth Authorization Server Metadata from *dremio_url*."""
+ base = _login_url(dremio_url)
+ url = f"{base}/.well-known/oauth-authorization-server"
+ resp = httpx.get(url, timeout=30.0)
+ resp.raise_for_status()
+ data = resp.json()
+ return OAuthServerMetadata(
+ authorization_endpoint=data["authorization_endpoint"],
+ token_endpoint=data["token_endpoint"],
+ registration_endpoint=data.get("registration_endpoint"),
+ )
+
+
+def register_client(registration_endpoint: str, redirect_uri: str) -> tuple[str, str | None]:
+ """Dynamic Client Registration. Returns ``(client_id, client_secret)``.
+
+ Returns ``None`` if the server does not support DCR (403/400).
+ """
+ body = {
+ "client_name": _CLIENT_ID,
+ "redirect_uris": [redirect_uri],
+ "grant_types": ["authorization_code", "refresh_token"],
+ "response_types": ["code"],
+ }
+ resp = httpx.post(registration_endpoint, json=body, timeout=30.0)
+ if resp.status_code in (400, 403):
+ logger.info("DCR not available (%s) — using well-known client_id.", resp.status_code)
+ return None, None
+ resp.raise_for_status()
+ data = resp.json()
+ return data["client_id"], data.get("client_secret")
+
+
+def generate_pkce() -> tuple[str, str]:
+ """Generate PKCE ``(code_verifier, code_challenge)`` pair."""
+ code_verifier = secrets.token_urlsafe(32)
+ digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+ code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+ return code_verifier, code_challenge
+
+
+def build_authorization_url(
+ auth_endpoint: str,
+ client_id: str,
+ redirect_uri: str,
+ code_challenge: str,
+ state: str,
+) -> str:
+ """Construct the authorization URL the user's browser should open."""
+ params = {
+ "response_type": "code",
+ "client_id": client_id,
+ "redirect_uri": redirect_uri,
+ "code_challenge": code_challenge,
+ "code_challenge_method": "S256",
+ "scope": "dremio.all offline_access",
+ "state": state,
+ }
+ return f"{auth_endpoint}?{urlencode(params)}"
+
+
+def find_free_port() -> int:
+ """Bind to port 0 to obtain a free ephemeral port."""
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("127.0.0.1", 0))
+ return s.getsockname()[1]
+
+
+def start_callback_server(port: int) -> tuple[HTTPServer, Future[tuple[str, str]]]:
+ """Start a localhost HTTP server that captures the OAuth callback.
+
+ Returns ``(server, future)`` where the future resolves to ``(code, state)``.
+ """
+ future: Future[tuple[str, str]] = Future()
+
+ class _Handler(BaseHTTPRequestHandler):
+ def do_GET(self) -> None: # noqa: N802
+ qs = parse_qs(urlparse(self.path).query)
+ code = qs.get("code", [None])[0]
+ state = qs.get("state", [None])[0]
+ error = qs.get("error", [None])[0]
+
+ self.send_response(200)
+ self.send_header("Content-Type", "text/html")
+ self.end_headers()
+
+ if code and state:
+ self.wfile.write(_SUCCESS_HTML.encode())
+ future.set_result((code, state))
+ else:
+ error_msg = error or "unknown"
+ self.wfile.write(f"Login failed: {error_msg}
".encode())
+ if not future.done():
+ future.set_exception(RuntimeError(f"OAuth callback error: {error_msg}"))
+
+ def log_message(self, format: str, *args: object) -> None: # noqa: A002
+ logger.debug(format, *args)
+
+ server = HTTPServer(("127.0.0.1", port), _Handler)
+ thread = threading.Thread(target=server.serve_forever, daemon=True)
+ thread.start()
+ return server, future
+
+
+def exchange_code(
+ token_endpoint: str,
+ code: str,
+ redirect_uri: str,
+ client_id: str,
+ client_secret: str | None,
+ code_verifier: str,
+) -> OAuthTokens:
+ """Exchange an authorization code for tokens."""
+ data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": redirect_uri,
+ "client_id": client_id,
+ "code_verifier": code_verifier,
+ }
+ if client_secret:
+ data["client_secret"] = client_secret
+ resp = httpx.post(token_endpoint, data=data, timeout=30.0)
+ resp.raise_for_status()
+ body = resp.json()
+ expires_at: float | None = None
+ if "expires_in" in body:
+ expires_at = time.time() + body["expires_in"]
+ return OAuthTokens(
+ access_token=body["access_token"],
+ refresh_token=body.get("refresh_token"),
+ expires_at=expires_at,
+ client_id=client_id,
+ client_secret=client_secret,
+ )
+
+
+def refresh_access_token(
+ token_endpoint: str,
+ client_id: str,
+ client_secret: str | None,
+ refresh_token: str,
+) -> OAuthTokens:
+ """Use a refresh token to obtain a new access token."""
+ data = {
+ "grant_type": "refresh_token",
+ "client_id": client_id,
+ "refresh_token": refresh_token,
+ }
+ if client_secret:
+ data["client_secret"] = client_secret
+ resp = httpx.post(token_endpoint, data=data, timeout=30.0)
+ resp.raise_for_status()
+ body = resp.json()
+ expires_at: float | None = None
+ if "expires_in" in body:
+ expires_at = time.time() + body["expires_in"]
+ return OAuthTokens(
+ access_token=body["access_token"],
+ refresh_token=body.get("refresh_token", refresh_token),
+ expires_at=expires_at,
+ client_id=client_id,
+ client_secret=client_secret,
+ )
+
+
+def run_login_flow(dremio_url: str) -> OAuthTokens:
+ """Orchestrate the full browser-based OAuth login flow.
+
+ 1. Discover OAuth endpoints
+ 2. Find a free port & register the client (DCR)
+ 3. Generate PKCE codes
+ 4. Start localhost callback server
+ 5. Open the browser (fall back to printing URL)
+ 6. Wait for the callback
+ 7. Exchange the code for tokens
+ """
+ metadata = discover(dremio_url)
+
+ port = find_free_port()
+ redirect_uri = f"http://localhost:{port}/Callback"
+
+ client_id, client_secret = None, None
+ if metadata.registration_endpoint:
+ client_id, client_secret = register_client(metadata.registration_endpoint, redirect_uri)
+ if not client_id:
+ client_id, client_secret = _CLIENT_ID, None
+
+ code_verifier, code_challenge = generate_pkce()
+ state = secrets.token_urlsafe(16)
+
+ auth_url = build_authorization_url(
+ metadata.authorization_endpoint,
+ client_id,
+ redirect_uri,
+ code_challenge,
+ state,
+ )
+
+ server, future = start_callback_server(port)
+ try:
+ try:
+ opened = webbrowser.open(auth_url)
+ except Exception:
+ opened = False
+ if opened:
+ logger.info("Opened browser for OAuth login.")
+ else:
+ # Headless / no-browser fallback
+ print(f"\nOpen this URL in your browser to log in:\n\n {auth_url}\n")
+
+ code, returned_state = future.result(timeout=300)
+ if returned_state != state:
+ raise RuntimeError("OAuth state mismatch — possible CSRF attack.")
+ finally:
+ server.shutdown()
+
+ return exchange_code(
+ metadata.token_endpoint,
+ code,
+ redirect_uri,
+ client_id,
+ client_secret,
+ code_verifier,
+ )
diff --git a/src/drs/token_store.py b/src/drs/token_store.py
new file mode 100644
index 0000000..8d8638a
--- /dev/null
+++ b/src/drs/token_store.py
@@ -0,0 +1,70 @@
+#
+# Copyright (C) 2017-2026 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Persistent storage for OAuth tokens, keyed by Dremio URL."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import yaml
+from pydantic import BaseModel
+
+STORE_PATH = Path.home() / ".config" / "dremioai" / "oauth_tokens.yaml"
+
+
+class OAuthTokens(BaseModel):
+ access_token: str
+ refresh_token: str | None = None
+ expires_at: float | None = None
+ client_id: str
+ client_secret: str | None = None
+
+
+def load(dremio_url: str, store_path: Path = STORE_PATH) -> OAuthTokens | None:
+ """Read stored OAuth tokens for *dremio_url*. Returns ``None`` if absent."""
+ if not store_path.exists():
+ return None
+ with store_path.open() as f:
+ data = yaml.safe_load(f) or {}
+ entry = data.get(dremio_url)
+ if entry is None:
+ return None
+ return OAuthTokens(**entry)
+
+
+def save(dremio_url: str, tokens: OAuthTokens, store_path: Path = STORE_PATH) -> None:
+ """Persist *tokens* under *dremio_url*. Creates dirs and sets mode 600."""
+ store_path.parent.mkdir(parents=True, exist_ok=True)
+
+ data: dict = {}
+ if store_path.exists():
+ with store_path.open() as f:
+ data = yaml.safe_load(f) or {}
+
+ data[dremio_url] = tokens.model_dump(exclude_none=True)
+ store_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False))
+ store_path.chmod(0o600)
+
+
+def clear(dremio_url: str, store_path: Path = STORE_PATH) -> None:
+ """Remove stored tokens for *dremio_url*."""
+ if not store_path.exists():
+ return
+ with store_path.open() as f:
+ data = yaml.safe_load(f) or {}
+ data.pop(dremio_url, None)
+ store_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False))
+ store_path.chmod(0o600)
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 9187f81..348de25 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -25,7 +25,8 @@
import yaml
from pydantic import ValidationError
-from drs.auth import load_config
+from drs.auth import DrsConfig, load_config
+from drs.token_store import OAuthTokens
def test_config_from_env_vars(tmp_path: Path) -> None:
@@ -137,6 +138,55 @@ def test_dremio_token_overrides_dremio_pat(tmp_path: Path) -> None:
assert config.pat == "new-token"
+def test_config_oauth_from_token_store(tmp_path: Path) -> None:
+ """When no PAT is available but OAuth tokens exist, auth_method should be 'oauth'."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text(yaml.dump({"project_id": "proj-1"}))
+
+ fake_tokens = OAuthTokens(access_token="oauth-at", refresh_token="oauth-rt", client_id="cid")
+
+ with (
+ patch.dict(os.environ, {}, clear=False),
+ patch("drs.token_store.load", return_value=fake_tokens),
+ ):
+ for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]:
+ os.environ.pop(k, None)
+ config = load_config(config_file)
+
+ assert config.auth_method == "oauth"
+ assert config.oauth_access_token == "oauth-at"
+ assert config.pat is None
+
+
+def test_config_pat_still_works(tmp_path: Path) -> None:
+ """PAT auth should continue to work exactly as before."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text(yaml.dump({"pat": "my-pat", "project_id": "proj-1"}))
+
+ with patch.dict(os.environ, {}, clear=False):
+ for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]:
+ os.environ.pop(k, None)
+ config = load_config(config_file)
+
+ assert config.auth_method == "pat"
+ assert config.pat == "my-pat"
+
+
+def test_config_neither_pat_nor_oauth_fails(tmp_path: Path) -> None:
+ """Missing both PAT and OAuth tokens should raise ValidationError."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text(yaml.dump({"project_id": "proj-1"}))
+
+ with (
+ patch.dict(os.environ, {}, clear=False),
+ patch("drs.token_store.load", return_value=None),
+ pytest.raises(ValidationError),
+ ):
+ for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]:
+ os.environ.pop(k, None)
+ load_config(config_file)
+
+
def test_cli_args_override_env(tmp_path: Path) -> None:
"""CLI args should override env vars and file values."""
config_file = tmp_path / "config.yaml"
diff --git a/tests/test_client_retry.py b/tests/test_client_retry.py
index 485229b..3a29222 100644
--- a/tests/test_client_retry.py
+++ b/tests/test_client_retry.py
@@ -22,6 +22,7 @@
import httpx
import pytest
+from drs.auth import DrsConfig
from drs.client import DremioClient
@@ -154,3 +155,87 @@ async def test_retry_backoff_delays(config) -> None:
assert mock_sleep.call_count == 2
mock_sleep.assert_any_call(1.0)
mock_sleep.assert_any_call(2.0)
+
+
+@pytest.mark.asyncio
+async def test_401_triggers_oauth_refresh() -> None:
+ """401 with OAuth auth should attempt token refresh and retry."""
+ oauth_config = DrsConfig(
+ uri="https://api.dremio.cloud",
+ oauth_access_token="old-token",
+ project_id="proj",
+ auth_method="oauth",
+ )
+ client = DremioClient(oauth_config)
+
+ unauthorized = httpx.Response(401, request=httpx.Request("GET", "https://example.com"))
+ ok_response = httpx.Response(200, json={"ok": True}, request=httpx.Request("GET", "https://example.com"))
+
+ client._client.request = AsyncMock(side_effect=[unauthorized, ok_response])
+
+ with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock) as mock_refresh:
+ result = await client._get("https://example.com/test")
+
+ assert result == {"ok": True}
+ mock_refresh.assert_awaited_once()
+ assert client._client.request.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_401_no_refresh_for_pat(config) -> None:
+ """401 with PAT auth should NOT attempt token refresh."""
+ client = DremioClient(config)
+
+ unauthorized = httpx.Response(401, json={"error": "unauthorized"}, request=httpx.Request("GET", "https://example.com"))
+ client._client.request = AsyncMock(return_value=unauthorized)
+
+ with pytest.raises(httpx.HTTPStatusError):
+ await client._get("https://example.com/test")
+
+ assert client._client.request.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_401_refresh_failure_returns_original_response() -> None:
+ """When refresh fails, the original 401 response should be returned."""
+ oauth_config = DrsConfig(
+ uri="https://api.dremio.cloud",
+ oauth_access_token="old-token",
+ project_id="proj",
+ auth_method="oauth",
+ )
+ client = DremioClient(oauth_config)
+
+ unauthorized = httpx.Response(401, json={"error": "bad token"}, request=httpx.Request("GET", "https://example.com"))
+ client._client.request = AsyncMock(return_value=unauthorized)
+
+ with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock, side_effect=RuntimeError("no refresh")):
+ with pytest.raises(httpx.HTTPStatusError):
+ await client._get("https://example.com/test")
+
+ # Should only have tried once (no retry after failed refresh)
+ assert client._client.request.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_401_only_one_refresh_per_request() -> None:
+ """Only one refresh attempt per request — second 401 should not retry."""
+ oauth_config = DrsConfig(
+ uri="https://api.dremio.cloud",
+ oauth_access_token="old-token",
+ project_id="proj",
+ auth_method="oauth",
+ )
+ client = DremioClient(oauth_config)
+
+ unauthorized = httpx.Response(401, request=httpx.Request("GET", "https://example.com"))
+
+ # Both attempts return 401
+ client._client.request = AsyncMock(return_value=unauthorized)
+
+ with patch.object(client, "_refresh_oauth_token", new_callable=AsyncMock):
+ with pytest.raises(httpx.HTTPStatusError):
+ await client._get("https://example.com/test")
+
+ # 1st call -> 401 -> refresh -> 2nd call -> 401 -> return (no more refresh)
+ assert client._client.request.call_count == 2
diff --git a/tests/test_commands/test_login.py b/tests/test_commands/test_login.py
new file mode 100644
index 0000000..bf20f2b
--- /dev/null
+++ b/tests/test_commands/test_login.py
@@ -0,0 +1,132 @@
+#
+# Copyright (C) 2017-2026 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Tests for dremio login / dremio logout commands."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+from typer.testing import CliRunner
+
+from drs.cli import app
+from drs.token_store import OAuthTokens, load, save
+
+runner = CliRunner()
+
+
+def test_login_saves_tokens(tmp_path: Path) -> None:
+ config_path = tmp_path / "config.yaml"
+ config_path.write_text("project_id: proj-1\n")
+ store_path = tmp_path / "oauth_tokens.yaml"
+
+ fake_tokens = OAuthTokens(
+ access_token="at-new",
+ refresh_token="rt-new",
+ client_id="cid",
+ )
+
+ with (
+ patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens),
+ patch("drs.commands.login.token_store.save") as mock_save,
+ patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path),
+ patch("drs.commands.login._update_config_file"),
+ ):
+ result = runner.invoke(app, ["--config", str(config_path), "login"])
+
+ assert result.exit_code == 0
+ assert "successfully" in result.output.lower() or "logged in" in result.output.lower()
+ mock_save.assert_called_once()
+ saved_tokens = mock_save.call_args[0][1]
+ assert saved_tokens.access_token == "at-new"
+
+
+def test_login_picks_project_from_list(tmp_path: Path) -> None:
+ """When project_id is not in config, login should show project list."""
+ config_path = tmp_path / "config.yaml"
+ config_path.write_text("uri: https://api.dremio.cloud\n")
+
+ fake_tokens = OAuthTokens(access_token="at", client_id="cid")
+ fake_projects = [
+ {"id": "proj-aaa", "name": "Alpha"},
+ {"id": "proj-bbb", "name": "Beta"},
+ ]
+
+ with (
+ patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens),
+ patch("drs.commands.login.token_store.save"),
+ patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path),
+ patch("drs.commands.login._update_config_file") as mock_update,
+ patch("drs.commands.login._fetch_projects", return_value=fake_projects),
+ ):
+ result = runner.invoke(app, ["--config", str(config_path), "login"], input="2\n")
+
+ assert result.exit_code == 0
+ assert "Beta" in result.output
+ mock_update.assert_called_once()
+ assert mock_update.call_args[0][2] == "proj-bbb"
+
+
+def test_login_auto_selects_single_project(tmp_path: Path) -> None:
+ """When only one project exists, auto-select it."""
+ config_path = tmp_path / "config.yaml"
+ config_path.write_text("uri: https://api.dremio.cloud\n")
+
+ fake_tokens = OAuthTokens(access_token="at", client_id="cid")
+ fake_projects = [{"id": "proj-only", "name": "OnlyProject"}]
+
+ with (
+ patch("drs.commands.login.oauth.run_login_flow", return_value=fake_tokens),
+ patch("drs.commands.login.token_store.save"),
+ patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path),
+ patch("drs.commands.login._update_config_file") as mock_update,
+ patch("drs.commands.login._fetch_projects", return_value=fake_projects),
+ ):
+ result = runner.invoke(app, ["--config", str(config_path), "login"])
+
+ assert result.exit_code == 0
+ assert "OnlyProject" in result.output
+ mock_update.assert_called_once()
+ assert mock_update.call_args[0][2] == "proj-only"
+
+
+def test_logout_clears_tokens(tmp_path: Path) -> None:
+ config_path = tmp_path / "config.yaml"
+ config_path.write_text("uri: https://api.dremio.cloud\nproject_id: proj-1\n")
+
+ with (
+ patch("drs.commands.login.token_store.clear") as mock_clear,
+ patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path),
+ ):
+ result = runner.invoke(app, ["--config", str(config_path), "logout"])
+
+ assert result.exit_code == 0
+ assert "logged out" in result.output.lower() or "removed" in result.output.lower()
+ mock_clear.assert_called_once_with("https://api.dremio.cloud")
+
+
+def test_login_failure_exits_1(tmp_path: Path) -> None:
+ config_path = tmp_path / "config.yaml"
+ config_path.write_text("project_id: proj-1\n")
+
+ with (
+ patch("drs.commands.login.oauth.run_login_flow", side_effect=RuntimeError("network error")),
+ patch("drs.commands.login.DEFAULT_CONFIG_PATH", config_path),
+ ):
+ result = runner.invoke(app, ["--config", str(config_path), "login"])
+
+ assert result.exit_code == 1
+ assert "failed" in result.output.lower()
diff --git a/tests/test_commands/test_setup.py b/tests/test_commands/test_setup.py
index 2bab050..697eeb7 100644
--- a/tests/test_commands/test_setup.py
+++ b/tests/test_commands/test_setup.py
@@ -27,6 +27,7 @@
from drs.auth import DEFAULT_URI
from drs.cli import app
from drs.commands.setup import validate_credentials, write_config
+from drs.token_store import OAuthTokens
runner = CliRunner()
@@ -147,7 +148,7 @@ def test_write_config_includes_header(tmp_path) -> None:
raw = config_path.read_text()
assert raw.startswith("# Dremio CLI config")
- assert "plaintext" in raw
+ assert "mode 600" in raw
def test_setup_non_interactive(tmp_path) -> None:
@@ -174,8 +175,8 @@ def test_setup_happy_path(tmp_path) -> None:
patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path),
):
mock_sys.stdin.isatty.return_value = True
- # Input: region=1, PAT=test-pat, project_id=test-proj
- result = runner.invoke(app, ["setup"], input="1\ntest-pat\ntest-proj\n")
+ # Input: region=1, auth=2(PAT), PAT=test-pat, project_id=test-proj
+ result = runner.invoke(app, ["setup"], input="1\n2\ntest-pat\ntest-proj\n")
assert result.exit_code == 0
assert "Setup complete" in result.output
@@ -219,8 +220,8 @@ def test_setup_existing_config_overwrite(tmp_path) -> None:
patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path),
):
mock_sys.stdin.isatty.return_value = True
- # Input: overwrite=y, region=1, PAT=new-pat, project_id=new-proj
- result = runner.invoke(app, ["setup"], input="y\n1\nnew-pat\nnew-proj\n")
+ # Input: overwrite=y, region=1, auth=2(PAT), PAT=new-pat, project_id=new-proj
+ result = runner.invoke(app, ["setup"], input="y\n1\n2\nnew-pat\nnew-proj\n")
assert result.exit_code == 0
data = yaml.safe_load(config_path.read_text())
@@ -244,8 +245,8 @@ def test_setup_retry_then_abort(tmp_path) -> None:
patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path),
):
mock_sys.stdin.isatty.return_value = True
- # Input: region=1, PAT=bad, project_id=p1, then decline retry
- result = runner.invoke(app, ["setup"], input="1\nbad-pat\np1\nn\n")
+ # Input: region=1, auth=2(PAT), PAT=bad, project_id=p1, then decline retry
+ result = runner.invoke(app, ["setup"], input="1\n2\nbad-pat\np1\nn\n")
assert result.exit_code == 1
assert "cancelled" in result.output.lower()
@@ -265,9 +266,34 @@ def test_setup_global_config_passthrough(tmp_path) -> None:
patch("drs.commands.setup.DremioClient", return_value=mock_client),
):
mock_sys.stdin.isatty.return_value = True
- result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\nmy-pat\nmy-proj\n")
+ result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\n2\nmy-pat\nmy-proj\n")
assert result.exit_code == 0
assert config_path.exists()
data = yaml.safe_load(config_path.read_text())
assert data["pat"] == "my-pat"
+
+
+def test_setup_oauth_path(tmp_path) -> None:
+ """OAuth path through setup wizard: region, auth=1(OAuth), project_id."""
+ config_path = tmp_path / "config.yaml"
+
+ fake_tokens = OAuthTokens(access_token="at-oauth", client_id="cid")
+
+ with (
+ patch("drs.commands.setup.sys") as mock_sys,
+ patch("drs.oauth.run_login_flow", return_value=fake_tokens),
+ patch("drs.token_store.save") as mock_save,
+ patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path),
+ ):
+ mock_sys.stdin.isatty.return_value = True
+ # Input: region=1, auth=1(OAuth), project_id=my-proj
+ result = runner.invoke(app, ["setup"], input="1\n1\nmy-proj\n")
+
+ assert result.exit_code == 0
+ assert config_path.exists()
+ data = yaml.safe_load(config_path.read_text())
+ assert data.get("auth_method") == "oauth"
+ assert "pat" not in data
+ assert data["project_id"] == "my-proj"
+ mock_save.assert_called_once()
diff --git a/tests/test_oauth.py b/tests/test_oauth.py
new file mode 100644
index 0000000..0339ffc
--- /dev/null
+++ b/tests/test_oauth.py
@@ -0,0 +1,282 @@
+#
+# Copyright (C) 2017-2026 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Tests for drs.oauth — OAuth 2.0 PKCE flow mechanics."""
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import urllib.request
+from unittest.mock import MagicMock, patch
+from urllib.parse import parse_qs, urlparse
+
+import httpx
+import pytest
+
+from drs.oauth import (
+ build_authorization_url,
+ discover,
+ exchange_code,
+ find_free_port,
+ generate_pkce,
+ refresh_access_token,
+ run_login_flow,
+ start_callback_server,
+)
+
+
+class TestPKCE:
+ def test_generate_pkce_format(self) -> None:
+ verifier, challenge = generate_pkce()
+ assert len(verifier) > 20
+ assert len(challenge) > 20
+ # No padding characters
+ assert "=" not in challenge
+
+ def test_pkce_challenge_matches_verifier(self) -> None:
+ verifier, challenge = generate_pkce()
+ expected = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()).rstrip(b"=").decode("ascii")
+ assert challenge == expected
+
+ def test_pkce_uniqueness(self) -> None:
+ v1, _ = generate_pkce()
+ v2, _ = generate_pkce()
+ assert v1 != v2
+
+
+class TestBuildAuthorizationURL:
+ def test_url_construction(self) -> None:
+ url = build_authorization_url(
+ auth_endpoint="https://auth.example.com/authorize",
+ client_id="my-client",
+ redirect_uri="http://localhost:8080/Callback",
+ code_challenge="abc123",
+ state="state-xyz",
+ )
+ parsed = urlparse(url)
+ params = parse_qs(parsed.query)
+
+ assert parsed.scheme == "https"
+ assert parsed.netloc == "auth.example.com"
+ assert parsed.path == "/authorize"
+ assert params["response_type"] == ["code"]
+ assert params["client_id"] == ["my-client"]
+ assert params["redirect_uri"] == ["http://localhost:8080/Callback"]
+ assert params["code_challenge"] == ["abc123"]
+ assert params["code_challenge_method"] == ["S256"]
+ assert params["scope"] == ["dremio.all offline_access"]
+ assert params["state"] == ["state-xyz"]
+
+
+class TestLoginUrl:
+ def test_app_to_login(self) -> None:
+ from drs.oauth import _login_url
+
+ assert _login_url("https://app.dremio.cloud") == "https://login.dremio.cloud"
+ assert _login_url("https://app.dev.dremio.site") == "https://login.dev.dremio.site"
+ assert _login_url("https://app.eu.dremio.cloud") == "https://login.eu.dremio.cloud"
+
+ def test_non_app_host_unchanged(self) -> None:
+ from drs.oauth import _login_url
+
+ assert _login_url("https://login.dremio.cloud") == "https://login.dremio.cloud"
+ assert _login_url("https://custom.example.com") == "https://custom.example.com"
+
+
+class TestDiscover:
+ def test_discover_parses_metadata(self) -> None:
+ metadata_json = {
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ }
+
+ mock_response = MagicMock()
+ mock_response.json.return_value = metadata_json
+ mock_response.raise_for_status = MagicMock()
+
+ with patch("drs.oauth.httpx.get", return_value=mock_response) as mock_get:
+ result = discover("https://app.dremio.cloud")
+
+ mock_get.assert_called_once_with(
+ "https://login.dremio.cloud/.well-known/oauth-authorization-server",
+ timeout=30.0,
+ )
+ assert result.authorization_endpoint == "https://auth.example.com/authorize"
+ assert result.token_endpoint == "https://auth.example.com/token"
+ assert result.registration_endpoint == "https://auth.example.com/register"
+
+ def test_discover_optional_registration(self) -> None:
+ metadata_json = {
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ }
+
+ mock_response = MagicMock()
+ mock_response.json.return_value = metadata_json
+ mock_response.raise_for_status = MagicMock()
+
+ with patch("drs.oauth.httpx.get", return_value=mock_response):
+ result = discover("https://app.dremio.cloud")
+
+ assert result.registration_endpoint is None
+
+
+class TestExchangeCode:
+ def test_exchange_code_success(self) -> None:
+ token_json = {
+ "access_token": "at-new",
+ "refresh_token": "rt-new",
+ "expires_in": 3600,
+ }
+ mock_response = MagicMock()
+ mock_response.json.return_value = token_json
+ mock_response.raise_for_status = MagicMock()
+
+ with patch("drs.oauth.httpx.post", return_value=mock_response):
+ tokens = exchange_code(
+ "https://auth.example.com/token",
+ "auth-code-123",
+ "http://localhost:8080/Callback",
+ "my-client",
+ "my-secret",
+ "my-verifier",
+ )
+
+ assert tokens.access_token == "at-new"
+ assert tokens.refresh_token == "rt-new"
+ assert tokens.expires_at is not None
+ assert tokens.client_id == "my-client"
+ assert tokens.client_secret == "my-secret"
+
+
+class TestRefreshAccessToken:
+ def test_refresh_success(self) -> None:
+ token_json = {
+ "access_token": "at-refreshed",
+ "refresh_token": "rt-new",
+ "expires_in": 3600,
+ }
+ mock_response = MagicMock()
+ mock_response.json.return_value = token_json
+ mock_response.raise_for_status = MagicMock()
+
+ with patch("drs.oauth.httpx.post", return_value=mock_response):
+ tokens = refresh_access_token(
+ "https://auth.example.com/token",
+ "my-client",
+ "my-secret",
+ "rt-old",
+ )
+
+ assert tokens.access_token == "at-refreshed"
+ assert tokens.refresh_token == "rt-new"
+
+ def test_refresh_preserves_old_refresh_token(self) -> None:
+ """When server omits refresh_token in response, keep the old one."""
+ token_json = {
+ "access_token": "at-refreshed",
+ "expires_in": 3600,
+ }
+ mock_response = MagicMock()
+ mock_response.json.return_value = token_json
+ mock_response.raise_for_status = MagicMock()
+
+ with patch("drs.oauth.httpx.post", return_value=mock_response):
+ tokens = refresh_access_token(
+ "https://auth.example.com/token",
+ "my-client",
+ None,
+ "rt-old",
+ )
+
+ assert tokens.refresh_token == "rt-old"
+
+
+class TestCallbackServer:
+ def test_callback_captures_code_and_state(self) -> None:
+ port = find_free_port()
+ server, future = start_callback_server(port)
+ try:
+ url = f"http://localhost:{port}/Callback?code=test-code&state=test-state"
+ urllib.request.urlopen(url, timeout=5)
+ code, state = future.result(timeout=5)
+ assert code == "test-code"
+ assert state == "test-state"
+ finally:
+ server.shutdown()
+
+ def test_find_free_port_returns_positive(self) -> None:
+ port = find_free_port()
+ assert port > 0
+
+
+class TestRunLoginFlow:
+ @pytest.mark.parametrize("browser_behavior", ["raises", "returns_false"])
+ def test_headless_fallback_prints_url(self, browser_behavior: str) -> None:
+ """When webbrowser.open raises or returns False, the URL should be printed."""
+ metadata_json = {
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ }
+ dcr_json = {"client_id": "cid", "client_secret": "cs"}
+ token_json = {"access_token": "at", "refresh_token": "rt", "expires_in": 3600}
+
+ mock_get = MagicMock()
+ mock_get.json.return_value = metadata_json
+ mock_get.raise_for_status = MagicMock()
+
+ post_responses = []
+ # DCR call
+ dcr_resp = MagicMock()
+ dcr_resp.json.return_value = dcr_json
+ dcr_resp.raise_for_status = MagicMock()
+ post_responses.append(dcr_resp)
+ # Token exchange
+ token_resp = MagicMock()
+ token_resp.json.return_value = token_json
+ token_resp.raise_for_status = MagicMock()
+ post_responses.append(token_resp)
+
+ if browser_behavior == "raises":
+ browser_mock = MagicMock(side_effect=RuntimeError("no browser"))
+ else:
+ browser_mock = MagicMock(return_value=False)
+
+ with (
+ patch("drs.oauth.httpx.get", return_value=mock_get),
+ patch("drs.oauth.httpx.post", side_effect=post_responses),
+ patch("drs.oauth.webbrowser.open", browser_mock),
+ patch("drs.oauth.start_callback_server") as mock_server,
+ patch("builtins.print") as mock_print,
+ ):
+ mock_future = MagicMock()
+ mock_future.result.return_value = ("code-123", None)
+ mock_srv = MagicMock()
+ mock_server.return_value = (mock_srv, mock_future)
+
+ # The state won't match because we bypass the real flow,
+ # so we need to patch around the state check too.
+ try:
+ run_login_flow("https://api.dremio.cloud")
+ except RuntimeError:
+ pass # state mismatch expected in mocked test
+
+ # Verify fallback print was called
+ mock_print.assert_called_once()
+ printed_text = mock_print.call_args[0][0]
+ assert "browser" in printed_text.lower() or "url" in printed_text.lower() or "http" in printed_text.lower()
diff --git a/tests/test_token_store.py b/tests/test_token_store.py
new file mode 100644
index 0000000..45c71eb
--- /dev/null
+++ b/tests/test_token_store.py
@@ -0,0 +1,97 @@
+#
+# Copyright (C) 2017-2026 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Tests for drs.token_store — OAuth token persistence."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from drs.token_store import OAuthTokens, clear, load, save
+
+
+@pytest.fixture
+def store_path(tmp_path: Path) -> Path:
+ return tmp_path / "oauth_tokens.yaml"
+
+
+@pytest.fixture
+def sample_tokens() -> OAuthTokens:
+ return OAuthTokens(
+ access_token="at-123",
+ refresh_token="rt-456",
+ expires_at=1700000000.0,
+ client_id="test-client",
+ client_secret="test-secret",
+ )
+
+
+def test_round_trip(store_path: Path, sample_tokens: OAuthTokens) -> None:
+ url = "https://api.dremio.cloud"
+ save(url, sample_tokens, store_path=store_path)
+ loaded = load(url, store_path=store_path)
+ assert loaded is not None
+ assert loaded == sample_tokens
+
+
+def test_load_missing_url(store_path: Path, sample_tokens: OAuthTokens) -> None:
+ save("https://api.dremio.cloud", sample_tokens, store_path=store_path)
+ assert load("https://api.eu.dremio.cloud", store_path=store_path) is None
+
+
+def test_load_nonexistent_file(store_path: Path) -> None:
+ assert load("https://api.dremio.cloud", store_path=store_path) is None
+
+
+def test_clear_removes_entry(store_path: Path, sample_tokens: OAuthTokens) -> None:
+ url = "https://api.dremio.cloud"
+ save(url, sample_tokens, store_path=store_path)
+ clear(url, store_path=store_path)
+ assert load(url, store_path=store_path) is None
+
+
+def test_clear_nonexistent_is_noop(store_path: Path) -> None:
+ clear("https://api.dremio.cloud", store_path=store_path) # should not raise
+
+
+def test_file_mode_600(store_path: Path, sample_tokens: OAuthTokens) -> None:
+ save("https://api.dremio.cloud", sample_tokens, store_path=store_path)
+ assert oct(store_path.stat().st_mode & 0o777) == "0o600"
+
+
+def test_multi_instance_keying(store_path: Path) -> None:
+ url_us = "https://api.dremio.cloud"
+ url_eu = "https://api.eu.dremio.cloud"
+ tokens_us = OAuthTokens(access_token="us-token", client_id="cid")
+ tokens_eu = OAuthTokens(access_token="eu-token", client_id="cid")
+
+ save(url_us, tokens_us, store_path=store_path)
+ save(url_eu, tokens_eu, store_path=store_path)
+
+ loaded_us = load(url_us, store_path=store_path)
+ loaded_eu = load(url_eu, store_path=store_path)
+ assert loaded_us is not None
+ assert loaded_eu is not None
+ assert loaded_us.access_token == "us-token"
+ assert loaded_eu.access_token == "eu-token"
+
+
+def test_save_creates_parent_dirs(tmp_path: Path) -> None:
+ deep_path = tmp_path / "a" / "b" / "tokens.yaml"
+ tokens = OAuthTokens(access_token="at", client_id="cid")
+ save("https://api.dremio.cloud", tokens, store_path=deep_path)
+ assert deep_path.exists()