Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 136 additions & 42 deletions src/openlayer/lib/core/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import abc
import json
import time
import asyncio
import inspect
import argparse
from typing import Any, Dict, Tuple
from typing import Any, Dict, List, Tuple, Optional
from dataclasses import field, dataclass

import pandas as pd
Expand Down Expand Up @@ -36,6 +37,12 @@ class OpenlayerModel(abc.ABC):

It is more conventional to implement the `run` method.

``run`` may be defined as either ``def run`` (called sequentially per row)
or ``async def run``. When ``run`` is async, ``run_batch_from_df`` will drive
rows concurrently with ``asyncio.gather``; pass ``max_workers > 1`` to enable
concurrent execution. Use async-native I/O (``httpx``, ``openai-async``, etc.)
inside an async ``run`` to actually benefit from concurrency.

Refer to Openlayer's templates for examples of how to implement this class.
"""

Expand All @@ -59,6 +66,15 @@ def run_from_cli(self) -> None:
required=False,
help="Custom arguments in format 'key1=value1,key2=value2'",
)
parser.add_argument(
"--max-workers",
type=int,
default=None,
help=(
"Max concurrent rows when run() is async. "
"Defaults to 4 for async run, 1 for sync run."
),
)

# Parse the arguments
args = parser.parse_args()
Expand All @@ -76,9 +92,12 @@ def run_from_cli(self) -> None:
return self.batch(
dataset_path=args.dataset_path,
output_dir=args.output_dir,
max_workers=args.max_workers,
)

def batch(self, dataset_path: str, output_dir: str) -> None:
def batch(
self, dataset_path: str, output_dir: str, max_workers: Optional[int] = None
) -> None:
"""Reads the dataset from a file and runs the model on it."""
# Load the dataset into a pandas DataFrame
fmt = "csv"
Expand All @@ -91,50 +110,125 @@ def batch(self, dataset_path: str, output_dir: str) -> None:
raise ValueError(f"Unsupported dataset format: {dataset_path}")

# Call the model's run_batch method, passing in the DataFrame
output_df, config = self.run_batch_from_df(df)
output_df, config = self.run_batch_from_df(df, max_workers=max_workers)
self.write_output_to_directory(output_df, config, output_dir, fmt)

def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""Function that runs the model and returns the result."""
# Ensure the 'output' column exists
if "output" not in df.columns:
df["output"] = None
def run_batch_from_df(
self, df: pd.DataFrame, max_workers: Optional[int] = None
) -> Tuple[pd.DataFrame, dict]:
"""Function that runs the model and returns the result.

# Get the signature of the 'run' method
If ``run`` is defined as ``async def run(...)``, rows are dispatched
concurrently with ``asyncio.gather`` gated by ``asyncio.Semaphore(max_workers)``.
``max_workers`` defaults to 4 for an async ``run`` (writing `async def`
is the opt-in signal that interleaving is safe). For a synchronous
``run``, rows are processed sequentially and ``max_workers`` must be 1.

A row's exception propagates and aborts the batch. For the async path,
``asyncio.gather`` cancels in-flight siblings before re-raising.
"""
run_signature = inspect.signature(self.run)
valid_params = set(run_signature.parameters)
is_async = inspect.iscoroutinefunction(self.run)

if max_workers is None:
max_workers = 4 if is_async else 1
elif max_workers < 1:
raise ValueError("max_workers must be >= 1")

if max_workers > 1 and not is_async:
raise ValueError(
"max_workers > 1 requires an async `run` method. "
"Define `run` as `async def run(self, ...)` to enable "
"concurrent execution."
)

for col in ("output", "steps", "latency", "cost", "tokens", "context"):
if col not in df.columns:
df[col] = None

rows = [
(
idx,
{k: v for k, v in row.to_dict().items() if k in valid_params},
)
for idx, row in df.iterrows()
]

if is_async:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"run_batch_from_df was called from inside a running event "
"loop. Call `await self._run_rows_async(...)` directly "
"from async code."
)
results = asyncio.run(self._run_rows_async(rows, max_workers))
else:
results = [
(idx, self.run(**kwargs), tracer.get_current_trace())
for idx, kwargs in rows
]

for index, output, trace in results:
self._apply_row_result(df, index, output, trace)

for index, row in df.iterrows():
# Filter row_dict to only include keys that are valid parameters
# for the 'run' method
row_dict = row.to_dict()
filtered_kwargs = {
k: v for k, v in row_dict.items() if k in run_signature.parameters
}

# Call the run method with filtered kwargs
output = self.run(**filtered_kwargs)

df.at[index, "output"] = output.output

for k, v in output.other_fields.items():
if k not in df.columns:
df[k] = None
df.at[index, k] = v

trace = tracer.get_current_trace()
if trace:
processed_trace, _ = tracer.post_process_trace(trace_obj=trace)
df.at[index, "steps"] = trace.to_dict()
if "latency" in processed_trace:
df.at[index, "latency"] = processed_trace["latency"]
if "cost" in processed_trace:
df.at[index, "cost"] = processed_trace["cost"]
if "tokens" in processed_trace:
df.at[index, "tokens"] = processed_trace["tokens"]
if "context" in processed_trace:
df.at[index, "context"] = processed_trace["context"]

config = {
return df, self._build_config(run_signature, df)

async def _run_rows_async(
self,
rows: List[Tuple[Any, Dict[str, Any]]],
max_workers: int,
) -> List[Tuple[Any, RunReturn, Optional[Any]]]:
"""Drive an async ``run`` over all rows with bounded concurrency.

The first row to raise causes ``asyncio.gather`` to cancel in-flight
siblings and re-raise the original exception.
"""
sem = asyncio.Semaphore(max_workers)

async def _one(index: Any, kwargs: Dict[str, Any]):
async with sem:
output = await self.run(**kwargs)
return index, output, tracer.get_current_trace()

return await asyncio.gather(*(_one(i, k) for i, k in rows))

def _apply_row_result(
self,
df: pd.DataFrame,
index: Any,
output: RunReturn,
trace: Optional[Any],
) -> None:
"""Write a single row's output and trace fields into ``df`` in place."""
df.at[index, "output"] = output.output

for k, v in output.other_fields.items():
if k not in df.columns:
df[k] = None
df.at[index, k] = v

if trace:
processed_trace, _ = tracer.post_process_trace(trace_obj=trace)
df.at[index, "steps"] = trace.to_dict()
if "latency" in processed_trace:
df.at[index, "latency"] = processed_trace["latency"]
if "cost" in processed_trace:
df.at[index, "cost"] = processed_trace["cost"]
if "tokens" in processed_trace:
df.at[index, "tokens"] = processed_trace["tokens"]
if "context" in processed_trace:
df.at[index, "context"] = processed_trace["context"]

def _build_config(
self, run_signature: inspect.Signature, df: pd.DataFrame
) -> Dict[str, Any]:
"""Build the config dict returned alongside the output DataFrame."""
config: Dict[str, Any] = {
"outputColumnName": "output",
"inputVariableNames": list(run_signature.parameters.keys()),
"metadata": {
Expand All @@ -154,7 +248,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
for k, v in self.custom_args.items():
config["metadata"][k] = v

return df, config
return config

def write_output_to_directory(
self,
Expand Down