Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions packages/bigframes/bigframes/bigquery/_operations/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import bigframes.dataframe as dataframe
import bigframes.ml.base
import bigframes.session
import bigframes.core.col as col
from bigframes.bigquery._operations import utils


Expand All @@ -50,7 +51,9 @@ def create_model(
input_schema: Optional[Mapping[str, str]] = None,
output_schema: Optional[Mapping[str, str]] = None,
connection_name: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
options: Optional[
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
] = None,
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
session: Optional[bigframes.session.Session] = None,
Expand Down Expand Up @@ -78,7 +81,7 @@ def create_model(
The OUTPUT clause, which specifies the schema of the output data.
connection_name (str, optional):
The connection to use for the model.
options (Mapping[str, Union[str, int, float, bool, list]], optional):
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.col.Expression]], optional):
The OPTIONS clause, which specifies the model options.
training_data (Union[bigframes.pandas.DataFrame, str], optional):
The query or DataFrame to use for training the model.
Expand Down
11 changes: 9 additions & 2 deletions packages/bigframes/bigframes/core/sql/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from typing import Any, Dict, List, Mapping, Optional, Union

import bigframes.core.col as col
from bigframes.core.compile.sqlglot import sql as sg_sql
from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler


def create_model_ddl(
Expand All @@ -28,7 +30,9 @@ def create_model_ddl(
input_schema: Optional[Mapping[str, str]] = None,
output_schema: Optional[Mapping[str, str]] = None,
connection_name: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
options: Optional[
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
] = None,
training_data: Optional[str] = None,
custom_holiday: Optional[str] = None,
) -> str:
Expand Down Expand Up @@ -70,7 +74,10 @@ def create_model_ddl(
if options:
rendered_options = []
for option_name, option_value in options.items():
if isinstance(option_value, (list, tuple)):
if isinstance(option_value, col.Expression):
sg_expr = expression_compiler.compile_expression(option_value._value)
rendered_val = sg_sql.to_sql(sg_expr)
elif isinstance(option_value, (list, tuple)):
# Handle list options like model_registry="vertex_ai"
# wait, usually options are key=value.
# if value is list, it is [val1, val2]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL `my_model`
OPTIONS(l2_reg = 0.1 * 10, booster_type = 'gbtree')
AS SELECT * FROM t
24 changes: 24 additions & 0 deletions packages/bigframes/tests/unit/core/sql/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

import pytest

import bigframes.core.col as col
import bigframes.core.expression as ex
import bigframes.core.sql.ml
import bigframes.dtypes as dtypes
import bigframes.operations.numeric_ops as numeric_ops

pytest.importorskip("pytest_snapshot")

Expand Down Expand Up @@ -97,6 +101,26 @@ def test_create_model_list_option(snapshot):
snapshot.assert_match(sql, "create_model_list_option.sql")


def test_create_model_expression_option(snapshot):
    """Snapshot-test the DDL produced when an option value is a ``col.Expression``."""
    # Build the scalar expression `0.1 * 10` from two constant operands.
    operands = (
        ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE),
        ex.ScalarConstantExpression(10, dtypes.INT_DTYPE),
    )
    product = ex.OpExpression(op=numeric_ops.mul_op, inputs=operands)

    # Mix an expression-valued option with a plain string option.
    model_options = {
        "l2_reg": col.Expression(product),
        "booster_type": "gbtree",
    }

    ddl = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        options=model_options,
        training_data="SELECT * FROM t",
    )

    snapshot.assert_match(ddl, "create_model_expression_option.sql")


def test_evaluate_model_basic(snapshot):
sql = bigframes.core.sql.ml.evaluate(
model_name="my_project.my_dataset.my_model",
Expand Down
Loading