Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
ANTHROPIC_API_KEY=your-key-here
DATABRICKS_HOST=https://community.cloud.databricks.com
DATABRICKS_TOKEN=your-databricks-token-here
MLFLOW_TRACKING_URI=databricks
MLFLOW_EXPERIMENT_NAME=/Users/your-email/queryforge-eval
1 change: 1 addition & 0 deletions .github/workflows/claude-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
contents: read
pull-requests: write
issues: write
id-token: write
steps:
- name: Run Claude Code
uses: anthropics/claude-code-action@v1
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@ node_modules/
chroma_data/
*.egg-info/
.DS_Store
mlruns/
spark-warehouse/
derby.log
metastore_db/
12 changes: 12 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "8000"]
49 changes: 39 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data.
+----------------+----------------+
| |
+----------v----------+ +----------v----------+
| Claude API | | PostgreSQL |
| Chain-of-thought | | Financial schemas: |
| Claude API | | Databricks CE |
| Chain-of-thought | | Delta Tables: |
| reasoning + SQL gen | | - accounts |
+----------+----------+ | - transactions |
| | - risk_metrics |
Expand All @@ -43,6 +43,7 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data.
|
+----------v----------+
| MLflow Tracking |
| (Databricks CE) |
+---------------------+
```

Expand All @@ -51,12 +52,17 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data.
- **MLflow Prompt Registry** — Version-controlled prompt templates (system prompt, schema context, few-shot examples)
- **MLflow Experiments** — Every evaluation run logged with SQL accuracy, execution rate, RAGAS scores
- **MLflow Model Registry** — `production` vs `staging` prompt versions, with promotion gates
- **Databricks CE Notebooks** — Schema introspection, Gold layer sample queries for few-shot context
- **Databricks CE Notebooks** — Schema introspection, Delta table setup, Gold layer sample queries for few-shot context

## Storage Layer (Delta Tables on Databricks CE)

- **Delta Lake** — ACID-compliant storage for all financial schemas
- **Four core tables** — `accounts`, `transactions`, `risk_metrics`, `model_inventory`
- **Setup notebook** — `notebooks/01_setup_delta_tables.py` creates and seeds all tables

## Inference + Evaluation Layer (Docker)

- **FastAPI** — `/generate-sql`, `/execute`, `/feedback` endpoints
- **Delta tables on Databricks CE** — Sample financial database (`accounts`, `transactions`, `risk_metrics`, `model_inventory`)
- **Claude API** — SQL generation with chain-of-thought reasoning before outputting SQL
- **RAGAS** — Automated evaluation: faithfulness, answer relevancy, context recall
- **GitHub Actions** — Evaluation pipeline runs on every PR, blocks merge if accuracy drops
Expand All @@ -67,6 +73,7 @@ Inspired by the text-to-SQL POC built at SMBC on real banking data.

- Docker and Docker Compose
- Anthropic API key
- Databricks Community Edition account (free)
- Python 3.11+

### Setup
Expand All @@ -78,9 +85,20 @@ cd QueryForge

# Set environment variables
cp .env.example .env
# Add your ANTHROPIC_API_KEY to .env
# Add your ANTHROPIC_API_KEY and DATABRICKS_TOKEN to .env
```

### Databricks CE Setup

1. Sign up at [Databricks Community Edition](https://community.cloud.databricks.com)
2. Import `notebooks/01_setup_delta_tables.py` as a notebook
3. Attach to a cluster and run all cells — creates the `queryforge` database with Delta tables
4. Generate a personal access token: User Settings > Developer > Access Tokens

### Run Locally

```bash
# Start the API server
docker-compose up -d

# Run evaluation
Expand All @@ -91,11 +109,22 @@ python scripts/evaluate.py

```
QueryForge/
├── src/ # Application source code
├── src/
│ ├── api/ # FastAPI endpoints
│ │ └── main.py
│ ├── core/ # Business logic
│ │ ├── config.py # Settings and env vars
│ │ ├── sql_generator.py # Claude-powered NL2SQL
│ │ └── mlflow_tracker.py # MLflow experiment logging
│ └── eval/ # Evaluation framework
│ └── evaluator.py # RAGAS + accuracy benchmarks
├── notebooks/ # Databricks CE notebooks
│ └── 01_setup_delta_tables.py
├── tests/ # Test suite
├── scripts/ # Utility and evaluation scripts
├── data/ # Schema definitions and sample data
├── docker-compose.yml # Service orchestration
├── scripts/ # CLI utilities
│ └── evaluate.py
├── docker-compose.yml # API service orchestration
├── Dockerfile
├── requirements.txt # Python dependencies
└── .github/workflows/ # CI/CD pipelines
```
Expand Down
11 changes: 11 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
services:
queryforge-api:
build: .
ports:
- "8000:8000"
env_file:
- .env
volumes:
- ./src:/app/src
- ./data:/app/data
command: uvicorn src.api.main:app --host 0.0.0.0 --port 8000 --reload
140 changes: 140 additions & 0 deletions notebooks/01_setup_delta_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Databricks notebook source
# MAGIC %md
# MAGIC # QueryForge - Delta Table Setup
# MAGIC Run this notebook in Databricks Community Edition to create the financial data schemas as Delta tables.

# COMMAND ----------

# MAGIC %md
# MAGIC ## Create Database

# COMMAND ----------

# Create the dedicated database idempotently, then make it the session default
# so later unqualified table names resolve under queryforge.*.
spark.sql("CREATE DATABASE IF NOT EXISTS queryforge")
spark.sql("USE queryforge")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Accounts Table

# COMMAND ----------

# Import only the types actually used below — StructType, StructField and
# StringType were imported but never referenced anywhere in this notebook.
from pyspark.sql.types import DecimalType, DateType

# Seed rows: (account_id, customer_name, account_type, balance, open_date, branch_code, status)
accounts_data = [
    ("ACC001", "Alice Chen", "checking", 125000.50, "2020-03-15", "NYC001", "active"),
    ("ACC002", "Bob Martinez", "savings", 89000.00, "2019-07-22", "LA002", "active"),
    ("ACC003", "Carol Williams", "loan", 250000.00, "2021-01-10", "CHI003", "active"),
    ("ACC004", "David Kim", "credit", 15000.75, "2018-11-05", "NYC001", "active"),
    ("ACC005", "Eva Patel", "checking", 340000.00, "2022-06-18", "SF004", "active"),
    ("ACC006", "Frank Johnson", "savings", 45000.00, "2017-02-28", "LA002", "closed"),
    ("ACC007", "Grace Lee", "checking", 178000.25, "2020-09-14", "CHI003", "active"),
    ("ACC008", "Henry Brown", "loan", 500000.00, "2023-03-01", "SF004", "active"),
    ("ACC009", "Irene Davis", "credit", 8500.00, "2021-08-20", "NYC001", "frozen"),
    ("ACC010", "Jack Wilson", "checking", 92000.00, "2019-12-03", "LA002", "active"),
]

accounts_df = spark.createDataFrame(
    accounts_data,
    ["account_id", "customer_name", "account_type", "balance", "open_date", "branch_code", "status"],
)
# Tighten the inferred schema: money as DECIMAL(18,2), open_date strings as DATE.
accounts_df = accounts_df.withColumn("balance", accounts_df.balance.cast(DecimalType(18, 2)))
accounts_df = accounts_df.withColumn("open_date", accounts_df.open_date.cast(DateType()))

accounts_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.accounts")
print(f"Accounts table created with {accounts_df.count()} rows")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Transactions Table

# COMMAND ----------

from pyspark.sql.types import TimestampType

# Seed rows: (txn_id, account_id, txn_date, amount, txn_type, category, description)
transactions_data = [
    ("TXN001", "ACC001", "2024-01-15 09:30:00", 5000.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN002", "ACC001", "2024-01-16 14:22:00", 150.00, "debit", "utilities", "Electric bill payment"),
    ("TXN003", "ACC002", "2024-01-15 10:00:00", 2000.00, "credit", "transfer", "Transfer from checking"),
    ("TXN004", "ACC003", "2024-01-20 08:00:00", 3500.00, "debit", "loan_payment", "Monthly loan payment"),
    ("TXN005", "ACC004", "2024-01-18 16:45:00", 250.00, "debit", "shopping", "Online purchase"),
    ("TXN006", "ACC005", "2024-01-15 09:00:00", 12000.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN007", "ACC005", "2024-01-22 11:30:00", 800.00, "debit", "dining", "Restaurant payment"),
    ("TXN008", "ACC007", "2024-01-15 09:15:00", 8500.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN009", "ACC008", "2024-01-25 08:00:00", 5000.00, "debit", "loan_payment", "Monthly loan payment"),
    ("TXN010", "ACC010", "2024-01-15 09:45:00", 6000.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN011", "ACC001", "2024-02-15 09:30:00", 5000.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN012", "ACC001", "2024-02-20 13:10:00", 2200.00, "debit", "rent", "Monthly rent payment"),
    ("TXN013", "ACC005", "2024-02-15 09:00:00", 12000.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN014", "ACC007", "2024-02-15 09:15:00", 8500.00, "credit", "salary", "Monthly salary deposit"),
    ("TXN015", "ACC010", "2024-02-15 09:45:00", 6000.00, "credit", "salary", "Monthly salary deposit"),
]

transactions_df = spark.createDataFrame(
    transactions_data,
    ["txn_id", "account_id", "txn_date", "amount", "txn_type", "category", "description"],
)
# Tighten the inferred schema: amounts to DECIMAL(18,2), date strings to TIMESTAMP.
transactions_df = transactions_df.withColumn("amount", transactions_df["amount"].cast(DecimalType(18, 2)))
transactions_df = transactions_df.withColumn("txn_date", transactions_df["txn_date"].cast(TimestampType()))

transactions_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.transactions")
print(f"Transactions table created with {transactions_df.count()} rows")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Risk Metrics Table

# COMMAND ----------

from pyspark.sql.types import IntegerType

# Seed rows: (account_id, metric_date, credit_score, risk_rating,
#             probability_of_default, loss_given_default)
risk_data = [
    ("ACC001", "2024-01-31", 750, "low", 0.0120, 0.3500),
    ("ACC002", "2024-01-31", 680, "medium", 0.0450, 0.4000),
    ("ACC003", "2024-01-31", 620, "high", 0.0890, 0.5500),
    ("ACC004", "2024-01-31", 580, "high", 0.1200, 0.6000),
    ("ACC005", "2024-01-31", 800, "low", 0.0050, 0.2500),
    ("ACC007", "2024-01-31", 720, "low", 0.0180, 0.3200),
    ("ACC008", "2024-01-31", 550, "critical", 0.1800, 0.7500),
    ("ACC009", "2024-01-31", 490, "critical", 0.2500, 0.8500),
    ("ACC010", "2024-01-31", 700, "medium", 0.0350, 0.3800),
]

risk_df = spark.createDataFrame(
    risk_data,
    ["account_id", "metric_date", "credit_score", "risk_rating", "probability_of_default", "loss_given_default"],
)
# Tighten the inferred schema column by column.
for col_name, col_type in [
    ("metric_date", DateType()),
    ("credit_score", IntegerType()),
    ("probability_of_default", DecimalType(5, 4)),
    ("loss_given_default", DecimalType(5, 4)),
]:
    risk_df = risk_df.withColumn(col_name, risk_df[col_name].cast(col_type))

risk_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.risk_metrics")
print(f"Risk metrics table created with {risk_df.count()} rows")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Model Inventory Table

# COMMAND ----------

# Seed rows: (model_id, model_name, model_type, deployment_date, owner, status, last_validation_date)
model_data = [
    ("MDL001", "Credit Scoring v3", "classification", "2023-06-15", "Risk Team", "active", "2024-01-15"),
    ("MDL002", "Fraud Detection v2", "anomaly_detection", "2023-09-01", "Fraud Team", "active", "2024-01-20"),
    ("MDL003", "Churn Predictor", "classification", "2022-11-10", "Marketing", "retired", "2023-06-10"),
    ("MDL004", "LGD Model v1", "regression", "2024-01-05", "Risk Team", "validation", "2024-01-05"),
    ("MDL005", "Transaction Classifier", "classification", "2023-03-20", "Operations", "active", "2023-12-15"),
]

model_columns = ["model_id", "model_name", "model_type", "deployment_date", "owner", "status", "last_validation_date"]
model_df = spark.createDataFrame(model_data, model_columns)
# Both date columns arrive as strings; cast them to DATE.
for date_col in ("deployment_date", "last_validation_date"):
    model_df = model_df.withColumn(date_col, model_df[date_col].cast(DateType()))

model_df.write.format("delta").mode("overwrite").saveAsTable("queryforge.model_inventory")
print(f"Model inventory table created with {model_df.count()} rows")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Verify All Tables

# COMMAND ----------

# Sanity check: report the row count of every table created above.
for tbl in ("accounts", "transactions", "risk_metrics", "model_inventory"):
    row_count = spark.sql(f"SELECT COUNT(*) FROM queryforge.{tbl}").collect()[0][0]
    print(f"queryforge.{tbl}: {row_count} rows")
14 changes: 14 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
anthropic>=0.42.0
fastapi>=0.115.0
uvicorn>=0.34.0
mlflow>=2.19.0
databricks-sdk>=0.38.0
delta-spark>=3.3.0
pyspark>=3.5.0
ragas>=0.2.0
langchain>=0.3.0
langchain-anthropic>=0.3.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-dotenv>=1.0.0
httpx>=0.28.0
32 changes: 32 additions & 0 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Run QueryForge evaluation suite from CLI."""
import sys
import json
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.eval.evaluator import run_evaluation


def main():
    """CLI entry point: run the evaluation suite and print a summary report.

    Takes an optional prompt version as the first CLI argument (default "v1").
    """
    prompt_version = "v1"
    if len(sys.argv) > 1:
        prompt_version = sys.argv[1]
    print(f"Running QueryForge evaluation (prompt version: {prompt_version})...")

    results = run_evaluation(prompt_version=prompt_version)
    metrics = results["metrics"]
    divider = "=" * 60

    # Summary header and aggregate metrics.
    print(f"\n{divider}")
    print(f"Evaluation Complete — Prompt Version: {results['prompt_version']}")
    print(divider)
    print(f"Total Questions: {metrics['total_questions']}")
    print(f"Total Tokens: {metrics['total_tokens']}")
    print(f"Avg Tokens/Query: {metrics['avg_tokens_per_query']:.0f}")

    # Per-question detail, SQL truncated to 80 characters for readability.
    for item in results["results"]:
        print(f"\nQ: {item['question']}")
        print(f" Generated: {item['generated_sql'][:80]}...")

    print(f"\nFull results logged to MLflow experiment: {prompt_version}")


if __name__ == "__main__":
    main()
Empty file added src/__init__.py
Empty file.
Empty file added src/api/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions src/api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from src.core.sql_generator import generate_sql
from src.core.mlflow_tracker import init_mlflow, log_generation

# FastAPI application instance; title/version are surfaced in the OpenAPI docs.
app = FastAPI(title="QueryForge", version="0.1.0")


class QueryRequest(BaseModel):
    """Request body for /generate-sql."""

    # Natural-language question to be converted into SQL.
    question: str


class QueryResponse(BaseModel):
    """Response body for /generate-sql, including token usage accounting."""

    # Echo of the input question.
    question: str
    # Generated SQL statement.
    sql: str
    # Identifier of the model that produced the SQL.
    model: str
    input_tokens: int
    output_tokens: int


class FeedbackRequest(BaseModel):
    """Request body for /feedback: human judgement on a generated SQL statement."""

    question: str
    generated_sql: str
    # True if the generated SQL answered the question correctly.
    is_correct: bool
    # Optional human-supplied correction when is_correct is False.
    corrected_sql: str | None = None
    notes: str | None = None


@app.on_event("startup")
async def startup():
    """Initialize MLflow tracking once when the API process starts."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
    # favor of the lifespan context manager — confirm the installed version
    # before migrating; behavior here is unchanged.
    init_mlflow()


@app.get("/health")
async def health():
    """Liveness probe — always reports the service as healthy."""
    return dict(status="healthy")


@app.post("/generate-sql", response_model=QueryResponse)
async def generate(request: QueryRequest):
    """Generate SQL for a natural-language question.

    Runs the Claude-backed generator, logs the attempt to MLflow, and returns
    the SQL plus token accounting. Any failure surfaces as HTTP 500 with the
    error message in the detail field.
    """
    try:
        result = generate_sql(request.question)
        log_generation(**result)
        return result
    except Exception as e:
        # Chain the cause (PEP 3134) so server tracebacks keep the root error
        # instead of reporting it as raised "during handling of" the original.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/feedback")
async def feedback(request: FeedbackRequest):
    """Accept human feedback on a generated SQL statement.

    Currently acknowledges receipt only; nothing is persisted yet.
    """
    # TODO: Store feedback in Delta table for fine-tuning loop
    return {"status": "recorded", "question": request.question}
Empty file added src/core/__init__.py
Empty file.
Loading