From 34b3566c20a3d4c17b3c88a10e4f795e7c158afe Mon Sep 17 00:00:00 2001 From: Ruby Gunna Date: Wed, 15 Apr 2026 12:17:37 -0400 Subject: [PATCH 1/2] feat: databricks CE foundation with Delta tables, MLflow, and Claude API Replace PostgreSQL with Delta tables on Databricks Community Edition. Add FastAPI endpoints, Claude-powered SQL generation, MLflow tracking, evaluation framework, and Databricks notebook for Delta table setup. - Storage: Delta Lake tables (accounts, transactions, risk_metrics, model_inventory) - Inference: Claude API with chain-of-thought SQL generation - Tracking: MLflow on Databricks CE - Eval: 5-question benchmark suite with token tracking - API: FastAPI with /generate-sql, /feedback, /health endpoints Co-Authored-By: Claude Opus 4.6 (1M context) --- .env.example | 5 ++ .gitignore | 4 + Dockerfile | 12 +++ README.md | 49 +++++++--- docker-compose.yml | 11 +++ notebooks/01_setup_delta_tables.py | 140 +++++++++++++++++++++++++++++ requirements.txt | 14 +++ scripts/evaluate.py | 32 +++++++ src/__init__.py | 0 src/api/__init__.py | 0 src/api/main.py | 53 +++++++++++ src/core/__init__.py | 0 src/core/config.py | 17 ++++ src/core/mlflow_tracker.py | 28 ++++++ src/core/sql_generator.py | 82 +++++++++++++++++ src/eval/__init__.py | 0 src/eval/evaluator.py | 59 ++++++++++++ 17 files changed, 496 insertions(+), 10 deletions(-) create mode 100644 .env.example create mode 100644 Dockerfile create mode 100644 notebooks/01_setup_delta_tables.py create mode 100644 scripts/evaluate.py create mode 100644 src/__init__.py create mode 100644 src/api/__init__.py create mode 100644 src/api/main.py create mode 100644 src/core/__init__.py create mode 100644 src/core/config.py create mode 100644 src/core/mlflow_tracker.py create mode 100644 src/core/sql_generator.py create mode 100644 src/eval/__init__.py create mode 100644 src/eval/evaluator.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..1714250 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +ANTHROPIC_API_KEY=your-key-here +DATABRICKS_HOST=https://community.cloud.databricks.com +DATABRICKS_TOKEN=your-databricks-token-here +MLFLOW_TRACKING_URI=databricks +MLFLOW_EXPERIMENT_NAME=/Users/your-email/queryforge-eval diff --git a/.gitignore b/.gitignore index 46083bf..15d80c7 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,7 @@ node_modules/ chroma_data/ *.egg-info/ .DS_Store +mlruns/ +spark-warehouse/ +derby.log +metastore_db/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..3e1b89b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +EXPOSE 8000 + +CMD ["uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md index cd1e226..c1d3ff0 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,8 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data. +----------------+----------------+ | | +----------v----------+ +----------v----------+ - | Claude API | | PostgreSQL | - | Chain-of-thought | | Financial schemas: | + | Claude API | | Databricks CE | + | Chain-of-thought | | Delta Tables: | | reasoning + SQL gen | | - accounts | +----------+----------+ | - transactions | | | - risk_metrics | @@ -43,6 +43,7 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data. | +----------v----------+ | MLflow Tracking | + | (Databricks CE) | +---------------------+ ``` @@ -51,12 +52,17 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data. - **MLflow Prompt Registry** — Version-controlled prompt templates (system prompt, schema context, few-shot examples) - **MLflow Experiments** — Every evaluation run logged with SQL accuracy, execution rate, RAGAS scores - **MLflow Model Registry** — `production` vs `staging` prompt versions, with promotion gates -- **Databricks CE Notebooks** — Schema introspection, Gold layer sample queries for few-shot context +- **Databricks CE Notebooks** — Schema introspection, Delta table setup, Gold layer sample queries for few-shot context + +## Storage Layer (Delta Tables on Databricks CE) + +- **Delta Lake** — ACID-compliant storage for all financial schemas +- **Four core tables** — `accounts`, `transactions`, `risk_metrics`, `model_inventory` +- **Setup notebook** — `notebooks/01_setup_delta_tables.py` creates and seeds all tables ## Inference + Evaluation Layer (Docker) - **FastAPI** — `/generate-sql`, `/execute`, `/feedback` endpoints -- **PostgreSQL** — Sample financial database (accounts, transactions, risk_metrics, model_inventory schemas) - **Claude API** — SQL generation with chain-of-thought reasoning before outputting SQL - **RAGAS** — Automated evaluation: faithfulness, answer relevancy, context recall - **GitHub Actions** — Evaluation pipeline runs on every PR, blocks merge if accuracy drops @@ -67,6 +73,7 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data. - Docker and Docker Compose - Anthropic API key +- Databricks Community Edition account (free) - Python 3.11+ ### Setup @@ -78,9 +85,20 @@ cd QueryForge # Set environment variables cp .env.example .env -# Add your ANTHROPIC_API_KEY to .env +# Add your ANTHROPIC_API_KEY and DATABRICKS_TOKEN to .env +``` + +### Databricks CE Setup -# Start services +1. Sign up at [Databricks Community Edition](https://community.cloud.databricks.com) +2. Import `notebooks/01_setup_delta_tables.py` as a notebook +3. Attach to a cluster and run all cells — creates the `queryforge` database with Delta tables +4. Generate a personal access token: User Settings > Developer > Access Tokens + +### Run Locally + +```bash +# Start the API server docker-compose up -d # Run evaluation @@ -91,11 +109,22 @@ python scripts/evaluate.py ``` QueryForge/ -├── src/ # Application source code +├── src/ +│ ├── api/ # FastAPI endpoints +│ │ └── main.py +│ ├── core/ # Business logic +│ │ ├── config.py # Settings and env vars +│ │ ├── sql_generator.py # Claude-powered NL2SQL +│ │ └── mlflow_tracker.py # MLflow experiment logging +│ └── eval/ # Evaluation framework +│ └── evaluator.py # RAGAS + accuracy benchmarks +├── notebooks/ # Databricks CE notebooks +│ └── 01_setup_delta_tables.py ├── tests/ # Test suite -├── scripts/ # Utility and evaluation scripts -├── data/ # Schema definitions and sample data -├── docker-compose.yml # Service orchestration +├── scripts/ # CLI utilities +│ └── evaluate.py +├── docker-compose.yml # API service orchestration +├── Dockerfile ├── requirements.txt # Python dependencies └── .github/workflows/ # CI/CD pipelines ``` diff --git a/docker-compose.yml b/docker-compose.yml index e69de29..167cdcb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -0,0 +1,11 @@ +services: + queryforge-api: + build: . + ports: + - "8000:8000" + env_file: + - .env + volumes: + - ./src:/app/src + - ./data:/app/data + command: uvicorn src.api.main:app --host 0.0.0.0 --port 8000 --reload diff --git a/notebooks/01_setup_delta_tables.py b/notebooks/01_setup_delta_tables.py new file mode 100644 index 0000000..a3c3449 --- /dev/null +++ b/notebooks/01_setup_delta_tables.py @@ -0,0 +1,140 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # QueryForge - Delta Table Setup +# MAGIC Run this notebook in Databricks Community Edition to create the financial data schemas as Delta tables. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Create Database + +# COMMAND ---------- + +spark.sql("CREATE DATABASE IF NOT EXISTS queryforge") +spark.sql("USE queryforge") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Accounts Table + +# COMMAND ---------- + +from pyspark.sql.types import StructType, StructField, StringType, DecimalType, DateType + +accounts_data = [ + ("ACC001", "Alice Chen", "checking", 125000.50, "2020-03-15", "NYC001", "active"), + ("ACC002", "Bob Martinez", "savings", 89000.00, "2019-07-22", "LA002", "active"), + ("ACC003", "Carol Williams", "loan", 250000.00, "2021-01-10", "CHI003", "active"), + ("ACC004", "David Kim", "credit", 15000.75, "2018-11-05", "NYC001", "active"), + ("ACC005", "Eva Patel", "checking", 340000.00, "2022-06-18", "SF004", "active"), + ("ACC006", "Frank Johnson", "savings", 45000.00, "2017-02-28", "LA002", "closed"), + ("ACC007", "Grace Lee", "checking", 178000.25, "2020-09-14", "CHI003", "active"), + ("ACC008", "Henry Brown", "loan", 500000.00, "2023-03-01", "SF004", "active"), + ("ACC009", "Irene Davis", "credit", 8500.00, "2021-08-20", "NYC001", "frozen"), + ("ACC010", "Jack Wilson", "checking", 92000.00, "2019-12-03", "LA002", "active"), +] + +accounts_df = spark.createDataFrame(accounts_data, ["account_id", "customer_name", "account_type", "balance", "open_date", "branch_code", "status"]) +accounts_df = accounts_df.withColumn("balance", accounts_df.balance.cast(DecimalType(18, 2))) +accounts_df = accounts_df.withColumn("open_date", accounts_df.open_date.cast(DateType())) + +accounts_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.accounts") +print(f"Accounts table created with {accounts_df.count()} rows") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Transactions Table + +# COMMAND ---------- + +transactions_data = [ + ("TXN001", "ACC001", "2024-01-15 09:30:00", 5000.00, "credit", "salary", "Monthly salary deposit"), + ("TXN002", "ACC001", "2024-01-16 14:22:00", 150.00, "debit", "utilities", "Electric bill payment"), + ("TXN003", "ACC002", "2024-01-15 10:00:00", 2000.00, "credit", "transfer", "Transfer from checking"), + ("TXN004", "ACC003", "2024-01-20 08:00:00", 3500.00, "debit", "loan_payment", "Monthly loan payment"), + ("TXN005", "ACC004", "2024-01-18 16:45:00", 250.00, "debit", "shopping", "Online purchase"), + ("TXN006", "ACC005", "2024-01-15 09:00:00", 12000.00, "credit", "salary", "Monthly salary deposit"), + ("TXN007", "ACC005", "2024-01-22 11:30:00", 800.00, "debit", "dining", "Restaurant payment"), + ("TXN008", "ACC007", "2024-01-15 09:15:00", 8500.00, "credit", "salary", "Monthly salary deposit"), + ("TXN009", "ACC008", "2024-01-25 08:00:00", 5000.00, "debit", "loan_payment", "Monthly loan payment"), + ("TXN010", "ACC010", "2024-01-15 09:45:00", 6000.00, "credit", "salary", "Monthly salary deposit"), + ("TXN011", "ACC001", "2024-02-15 09:30:00", 5000.00, "credit", "salary", "Monthly salary deposit"), + ("TXN012", "ACC001", "2024-02-20 13:10:00", 2200.00, "debit", "rent", "Monthly rent payment"), + ("TXN013", "ACC005", "2024-02-15 09:00:00", 12000.00, "credit", "salary", "Monthly salary deposit"), + ("TXN014", "ACC007", "2024-02-15 09:15:00", 8500.00, "credit", "salary", "Monthly salary deposit"), + ("TXN015", "ACC010", "2024-02-15 09:45:00", 6000.00, "credit", "salary", "Monthly salary deposit"), +] + +from pyspark.sql.types import TimestampType + +transactions_df = spark.createDataFrame(transactions_data, ["txn_id", "account_id", "txn_date", "amount", "txn_type", "category", "description"]) +transactions_df = transactions_df.withColumn("amount", transactions_df.amount.cast(DecimalType(18, 2))) +transactions_df = transactions_df.withColumn("txn_date", transactions_df.txn_date.cast(TimestampType())) + +transactions_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.transactions") +print(f"Transactions table created with {transactions_df.count()} rows") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Risk Metrics Table + +# COMMAND ---------- + +risk_data = [ + ("ACC001", "2024-01-31", 750, "low", 0.0120, 0.3500), + ("ACC002", "2024-01-31", 680, "medium", 0.0450, 0.4000), + ("ACC003", "2024-01-31", 620, "high", 0.0890, 0.5500), + ("ACC004", "2024-01-31", 580, "high", 0.1200, 0.6000), + ("ACC005", "2024-01-31", 800, "low", 0.0050, 0.2500), + ("ACC007", "2024-01-31", 720, "low", 0.0180, 0.3200), + ("ACC008", "2024-01-31", 550, "critical", 0.1800, 0.7500), + ("ACC009", "2024-01-31", 490, "critical", 0.2500, 0.8500), + ("ACC010", "2024-01-31", 700, "medium", 0.0350, 0.3800), +] + +from pyspark.sql.types import IntegerType + +risk_df = spark.createDataFrame(risk_data, ["account_id", "metric_date", "credit_score", "risk_rating", "probability_of_default", "loss_given_default"]) +risk_df = risk_df.withColumn("metric_date", risk_df.metric_date.cast(DateType())) +risk_df = risk_df.withColumn("credit_score", risk_df.credit_score.cast(IntegerType())) +risk_df = risk_df.withColumn("probability_of_default", risk_df.probability_of_default.cast(DecimalType(5, 4))) +risk_df = risk_df.withColumn("loss_given_default", risk_df.loss_given_default.cast(DecimalType(5, 4))) + +risk_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.risk_metrics") +print(f"Risk metrics table created with {risk_df.count()} rows") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Model Inventory Table + +# COMMAND ---------- + +model_data = [ + ("MDL001", "Credit Scoring v3", "classification", "2023-06-15", "Risk Team", "active", "2024-01-15"), + ("MDL002", "Fraud Detection v2", "anomaly_detection", "2023-09-01", "Fraud Team", "active", "2024-01-20"), + ("MDL003", "Churn Predictor", "classification", "2022-11-10", "Marketing", "retired", "2023-06-10"), + ("MDL004", "LGD Model v1", "regression", "2024-01-05", "Risk Team", "validation", "2024-01-05"), + ("MDL005", "Transaction Classifier", "classification", "2023-03-20", "Operations", "active", "2023-12-15"), +] + +model_df = spark.createDataFrame(model_data, ["model_id", "model_name", "model_type", "deployment_date", "owner", "status", "last_validation_date"]) +model_df = model_df.withColumn("deployment_date", model_df.deployment_date.cast(DateType())) +model_df = model_df.withColumn("last_validation_date", model_df.last_validation_date.cast(DateType())) + +model_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.model_inventory") +print(f"Model inventory table created with {model_df.count()} rows") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Verify All Tables + +# COMMAND ---------- + +for table in ["accounts", "transactions", "risk_metrics", "model_inventory"]: + count = spark.sql(f"SELECT COUNT(*) FROM queryforge.{table}").collect()[0][0] + print(f"queryforge.{table}: {count} rows") diff --git a/requirements.txt b/requirements.txt index e69de29..03b0c21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,14 @@ +anthropic>=0.42.0 +fastapi>=0.115.0 +uvicorn>=0.34.0 +mlflow>=2.19.0 +databricks-sdk>=0.38.0 +delta-spark>=3.3.0 +pyspark>=3.5.0 +ragas>=0.2.0 +langchain>=0.3.0 +langchain-anthropic>=0.3.0 +pydantic>=2.10.0 +pydantic-settings>=2.7.0 +python-dotenv>=1.0.0 +httpx>=0.28.0 diff --git a/scripts/evaluate.py b/scripts/evaluate.py new file mode 100644 index 0000000..74ec404 --- /dev/null +++ b/scripts/evaluate.py @@ -0,0 +1,32 @@ +"""Run QueryForge evaluation suite from CLI.""" +import sys +import json +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.eval.evaluator import run_evaluation + + +def main(): + version = sys.argv[1] if len(sys.argv) > 1 else "v1" + print(f"Running QueryForge evaluation (prompt version: {version})...") + + results = run_evaluation(prompt_version=version) + + print(f"\n{'='*60}") + print(f"Evaluation Complete — Prompt Version: {results['prompt_version']}") + print(f"{'='*60}") + print(f"Total Questions: {results['metrics']['total_questions']}") + print(f"Total Tokens: {results['metrics']['total_tokens']}") + print(f"Avg Tokens/Query: {results['metrics']['avg_tokens_per_query']:.0f}") + + for r in results["results"]: + print(f"\nQ: {r['question']}") + print(f" Generated: {r['generated_sql'][:80]}...") + + print(f"\nFull results logged to MLflow experiment: {version}") + + +if __name__ == "__main__": + main() diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/api/main.py b/src/api/main.py new file mode 100644 index 0000000..7b44fc1 --- /dev/null +++ b/src/api/main.py @@ -0,0 +1,53 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +from src.core.sql_generator import generate_sql +from src.core.mlflow_tracker import init_mlflow, log_generation + +app = FastAPI(title="QueryForge", version="0.1.0") + + +class QueryRequest(BaseModel): + question: str + + +class QueryResponse(BaseModel): + question: str + sql: str + model: str + input_tokens: int + output_tokens: int + + +class FeedbackRequest(BaseModel): + question: str + generated_sql: str + is_correct: bool + corrected_sql: str | None = None + notes: str | None = None + + +@app.on_event("startup") +async def startup(): + init_mlflow() + + +@app.get("/health") +async def health(): + return {"status": "healthy"} + + +@app.post("/generate-sql", response_model=QueryResponse) +async def generate(request: QueryRequest): + try: + result = generate_sql(request.question) + log_generation(**result) + return result + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/feedback") +async def feedback(request: FeedbackRequest): + # TODO: Store feedback in Delta table for fine-tuning loop + return {"status": "recorded", "question": request.question} diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000..1d4fc80 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,17 @@ +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + anthropic_api_key: str = "" + databricks_host: str = "https://community.cloud.databricks.com" + databricks_token: str = "" + mlflow_tracking_uri: str = "databricks" + mlflow_experiment_name: str = "/queryforge-eval" + model_name: str = "claude-sonnet-4-20250514" + max_tokens: int = 1024 + + class Config: + env_file = ".env" + + +settings = Settings() diff --git a/src/core/mlflow_tracker.py b/src/core/mlflow_tracker.py new file mode 100644 index 0000000..48bcfcf --- /dev/null +++ b/src/core/mlflow_tracker.py @@ -0,0 +1,28 @@ +import mlflow +from src.core.config import settings + + +def init_mlflow(): + mlflow.set_tracking_uri(settings.mlflow_tracking_uri) + mlflow.set_experiment(settings.mlflow_experiment_name) + + +def log_generation(question: str, sql: str, model: str, input_tokens: int, output_tokens: int): + with mlflow.start_run(nested=True): + mlflow.log_params({ + "model": model, + "question_length": len(question), + }) + mlflow.log_metrics({ + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + }) + mlflow.log_text(question, "question.txt") + mlflow.log_text(sql, "generated_sql.sql") + + +def log_evaluation(metrics: dict, prompt_version: str): + with mlflow.start_run(nested=True): + mlflow.log_param("prompt_version", prompt_version) + mlflow.log_metrics(metrics) diff --git a/src/core/sql_generator.py b/src/core/sql_generator.py new file mode 100644 index 0000000..0805a62 --- /dev/null +++ b/src/core/sql_generator.py @@ -0,0 +1,82 @@ +import anthropic +from src.core.config import settings + +SYSTEM_PROMPT = """You are a SQL expert for financial databases. Given a natural language question +and schema context, generate valid SQL. + +Rules: +- Output ONLY the SQL query, no explanation +- Use standard SQL compatible with Spark SQL / Delta Lake +- Never use DELETE, DROP, TRUNCATE, or any DDL/DML that modifies data +- Always qualify column names with table aliases when joining + +Available schemas: + +TABLE accounts ( + account_id STRING, + customer_name STRING, + account_type STRING, -- 'checking', 'savings', 'loan', 'credit' + balance DECIMAL(18,2), + open_date DATE, + branch_code STRING, + status STRING -- 'active', 'closed', 'frozen' +) + +TABLE transactions ( + txn_id STRING, + account_id STRING, + txn_date TIMESTAMP, + amount DECIMAL(18,2), + txn_type STRING, -- 'credit', 'debit' + category STRING, + description STRING +) + +TABLE risk_metrics ( + account_id STRING, + metric_date DATE, + credit_score INT, + risk_rating STRING, -- 'low', 'medium', 'high', 'critical' + probability_of_default DECIMAL(5,4), + loss_given_default DECIMAL(5,4) +) + +TABLE model_inventory ( + model_id STRING, + model_name STRING, + model_type STRING, + deployment_date DATE, + owner STRING, + status STRING, -- 'active', 'retired', 'validation' + last_validation_date DATE +) +""" + +client = anthropic.Anthropic(api_key=settings.anthropic_api_key) + + +def generate_sql(question: str) -> dict: + response = client.messages.create( + model=settings.model_name, + max_tokens=settings.max_tokens, + system=SYSTEM_PROMPT, + messages=[ + { + "role": "user", + "content": f"Generate SQL for: {question}", + } + ], + ) + sql = response.content[0].text.strip() + # Strip markdown code fences if present + if sql.startswith("```"): + lines = sql.split("\n") + sql = "\n".join(lines[1:-1]).strip() + + return { + "question": question, + "sql": sql, + "model": settings.model_name, + "input_tokens": response.usage.input_tokens, + "output_tokens": response.usage.output_tokens, + } diff --git a/src/eval/__init__.py b/src/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/eval/evaluator.py b/src/eval/evaluator.py new file mode 100644 index 0000000..ef68709 --- /dev/null +++ b/src/eval/evaluator.py @@ -0,0 +1,59 @@ +import mlflow +from src.core.sql_generator import generate_sql +from src.core.mlflow_tracker import init_mlflow, log_evaluation + +EVAL_QUESTIONS = [ + { + "question": "What is the total balance across all active checking accounts?", + "expected_sql": "SELECT SUM(balance) AS total_balance FROM accounts WHERE account_type = 'checking' AND status = 'active'", + }, + { + "question": "Show me the top 10 customers by transaction volume in the last 30 days", + "expected_sql": "SELECT a.customer_name, COUNT(t.txn_id) AS txn_count FROM accounts a JOIN transactions t ON a.account_id = t.account_id WHERE t.txn_date >= CURRENT_DATE - INTERVAL 30 DAY GROUP BY a.customer_name ORDER BY txn_count DESC LIMIT 10", + }, + { + "question": "List all accounts with a high or critical risk rating and balance over 100000", + "expected_sql": "SELECT a.account_id, a.customer_name, a.balance, r.risk_rating FROM accounts a JOIN risk_metrics r ON a.account_id = r.account_id WHERE r.risk_rating IN ('high', 'critical') AND a.balance > 100000", + }, + { + "question": "How many models are currently in validation status?", + "expected_sql": "SELECT COUNT(*) AS model_count FROM model_inventory WHERE status = 'validation'", + }, + { + "question": "What is the average probability of default for each risk rating category?", + "expected_sql": "SELECT risk_rating, AVG(probability_of_default) AS avg_pd FROM risk_metrics GROUP BY risk_rating ORDER BY avg_pd DESC", + }, +] + + +def run_evaluation(prompt_version: str = "v1") -> dict: + init_mlflow() + + results = [] + with mlflow.start_run(run_name=f"eval-{prompt_version}"): + for item in EVAL_QUESTIONS: + result = generate_sql(item["question"]) + results.append({ + "question": item["question"], + "expected_sql": item["expected_sql"], + "generated_sql": result["sql"], + "tokens_used": result["input_tokens"] + result["output_tokens"], + }) + + total = len(results) + metrics = { + "total_questions": total, + "total_tokens": sum(r["tokens_used"] for r in results), + "avg_tokens_per_query": sum(r["tokens_used"] for r in results) / total, + } + log_evaluation(metrics, prompt_version) + + return {"prompt_version": prompt_version, "metrics": metrics, "results": results} + + +if __name__ == "__main__": + output = run_evaluation() + for r in output["results"]: + print(f"\nQ: {r['question']}") + print(f"Generated: {r['generated_sql']}") + print(f"Expected: {r['expected_sql']}") From aca73d8da47675a4d123d2bdaff3f477f8a22557 Mon Sep 17 00:00:00 2001 From: Ruby Gunna Date: Wed, 15 Apr 2026 12:24:10 -0400 Subject: [PATCH 2/2] fix(ci): add id-token write permission for Claude OIDC auth Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/claude-code.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/claude-code.yml b/.github/workflows/claude-code.yml index d6ab94f..00293c2 100644 --- a/.github/workflows/claude-code.yml +++ b/.github/workflows/claude-code.yml @@ -22,6 +22,7 @@ jobs: contents: read pull-requests: write issues: write + id-token: write steps: - name: Run Claude Code uses: anthropics/claude-code-action@v1