Skip to content

Commit e9113dd

Browse files
timsaucer and claude committed
feat(spark): make pyspark-optional params optional
Match pyspark's optional-parameter surface in the spark namespace:

- make_dt_interval, make_interval: all parts default to zero (int32 0 / lit 0.0)
- str_to_map: pair_delim defaults to ',', key_value_delim defaults to ':'
- round: scale defaults to 0 (HALF_UP rounding to nearest integer)
- shuffle: accepts `seed` kwarg for pyspark parity; raises NotImplementedError for non-None values until the Rust binding supports it
- like, ilike: accept `escapeChar` for pyspark parity; same NotImplementedError guard; first positional renamed `string` → `str` to match pyspark

ceil/floor `scale=` deferred — the underlying Rust expr_fn is single-arg.

Added a module-level `_ZERO_I32` literal to avoid rebuilding the pyarrow int32 zero scalar on every call.

Tests: positional-compat coverage for aggregates (`spark.avg(col)` etc.), defaults-omitted cases for the optional-arg functions, and NotImplementedError cases for `shuffle(seed=)` and `like/ilike(escapeChar=)`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ea25686 commit e9113dd

2 files changed

Lines changed: 179 additions & 37 deletions

File tree

python/datafusion/functions/spark.py

Lines changed: 112 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
from typing import TYPE_CHECKING, Any
3434

35+
import pyarrow as pa
36+
3537
from datafusion._internal import functions as _functions
3638
from datafusion.expr import Expr, sort_list_to_raw_sort_list
3739

@@ -41,6 +43,9 @@
4143

4244
_f = _functions.spark
4345

46+
# Reused int32 literal so optional-arg defaults don't rebuild it per call.
47+
_ZERO_I32 = Expr.literal(pa.scalar(0, type=pa.int32()))
48+
4449

4550
def _filter_raw(filter: Expr | None) -> Any:
4651
return filter.expr if filter is not None else None
@@ -203,9 +208,12 @@ def array(*cols: Expr) -> Expr:
203208
return Expr(_f.array(*[c.expr for c in cols]))
204209

205210

206-
def shuffle(col: Expr) -> Expr:
211+
def shuffle(col: Expr, seed: int | None = None) -> Expr:
207212
"""Spark ``shuffle``: returns a random permutation of the input array.
208213
214+
``seed`` is accepted for pyspark parity but is not yet wired through the
215+
Rust binding; passing a non-``None`` value raises ``NotImplementedError``.
216+
209217
Examples:
210218
>>> ctx = dfn.SessionContext()
211219
>>> df = ctx.from_pydict({"x": [1]})
@@ -217,6 +225,9 @@ def shuffle(col: Expr) -> Expr:
217225
>>> sorted(r.collect_column("v")[0].as_py())
218226
[1, 2, 3]
219227
"""
228+
if seed is not None:
229+
msg = "shuffle(seed=...) is not yet supported by the Spark UDF binding"
230+
raise NotImplementedError(msg)
220231
return Expr(_f.shuffle(col.expr))
221232

222233

@@ -589,59 +600,78 @@ def last_day(col: Expr) -> Expr:
589600
return Expr(_f.last_day(col.expr))
590601

591602

592-
def make_dt_interval(days: Expr, hours: Expr, mins: Expr, secs: Expr) -> Expr:
603+
def make_dt_interval(
604+
days: Expr | None = None,
605+
hours: Expr | None = None,
606+
mins: Expr | None = None,
607+
secs: Expr | None = None,
608+
) -> Expr:
593609
"""Spark ``make_dt_interval``: day-time interval from components.
594610
611+
All parts are optional; omitted parts default to zero, matching pyspark.
612+
595613
Examples:
596-
>>> import pyarrow as pa
597614
>>> ctx = dfn.SessionContext()
598615
>>> df = ctx.from_pydict({"x": [1]})
616+
>>> r = df.select(dfn.functions.spark.make_dt_interval().alias("v"))
617+
>>> r.collect_column("v")[0].as_py()
618+
datetime.timedelta(0)
619+
620+
>>> import pyarrow as pa
599621
>>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32()))
600622
>>> r = df.select(
601623
... dfn.functions.spark.make_dt_interval(
602-
... i32(1), i32(2), i32(3), dfn.lit(4.5)
624+
... days=i32(1), hours=i32(2), mins=i32(3), secs=dfn.lit(4.5)
603625
... ).alias("v")
604626
... )
605627
>>> r.collect_column("v")[0].as_py()
606628
datetime.timedelta(days=1, seconds=7384, microseconds=500000)
607629
"""
608-
return Expr(_f.make_dt_interval(days.expr, hours.expr, mins.expr, secs.expr))
630+
return Expr(
631+
_f.make_dt_interval(
632+
(days if days is not None else _ZERO_I32).expr,
633+
(hours if hours is not None else _ZERO_I32).expr,
634+
(mins if mins is not None else _ZERO_I32).expr,
635+
(secs if secs is not None else Expr.literal(0.0)).expr,
636+
)
637+
)
609638

610639

611640
def make_interval(
612-
years: Expr,
613-
months: Expr,
614-
weeks: Expr,
615-
days: Expr,
616-
hours: Expr,
617-
mins: Expr,
618-
secs: Expr,
641+
years: Expr | None = None,
642+
months: Expr | None = None,
643+
weeks: Expr | None = None,
644+
days: Expr | None = None,
645+
hours: Expr | None = None,
646+
mins: Expr | None = None,
647+
secs: Expr | None = None,
619648
) -> Expr:
620649
"""Spark ``make_interval``: interval from year/month/week/day/hour/min/sec parts.
621650
651+
All parts are optional; omitted parts default to zero, matching pyspark.
652+
622653
Examples:
623-
>>> import pyarrow as pa
624654
>>> ctx = dfn.SessionContext()
625655
>>> df = ctx.from_pydict({"x": [1]})
656+
>>> r = df.select(dfn.functions.spark.make_interval().alias("v"))
657+
>>> r.collect_column("v")[0].as_py().months
658+
0
659+
660+
>>> import pyarrow as pa
626661
>>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32()))
627-
>>> r = df.select(
628-
... dfn.functions.spark.make_interval(
629-
... i32(1), i32(0), i32(0), i32(0),
630-
... i32(0), i32(0), dfn.lit(0.0)
631-
... ).alias("v")
632-
... )
662+
>>> r = df.select(dfn.functions.spark.make_interval(years=i32(1)).alias("v"))
633663
>>> r.collect_column("v")[0].as_py().months
634664
12
635665
"""
636666
return Expr(
637667
_f.make_interval(
638-
years.expr,
639-
months.expr,
640-
weeks.expr,
641-
days.expr,
642-
hours.expr,
643-
mins.expr,
644-
secs.expr,
668+
(years if years is not None else _ZERO_I32).expr,
669+
(months if months is not None else _ZERO_I32).expr,
670+
(weeks if weeks is not None else _ZERO_I32).expr,
671+
(days if days is not None else _ZERO_I32).expr,
672+
(hours if hours is not None else _ZERO_I32).expr,
673+
(mins if mins is not None else _ZERO_I32).expr,
674+
(secs if secs is not None else Expr.literal(0.0)).expr,
645675
)
646676
)
647677

@@ -984,21 +1014,36 @@ def map_from_entries(col: Expr) -> Expr:
9841014
return Expr(_f.map_from_entries(col.expr))
9851015

9861016

987-
def str_to_map(text: Expr, pair_delim: Expr, key_value_delim: Expr) -> Expr:
1017+
def str_to_map(
1018+
text: Expr,
1019+
pair_delim: Expr | None = None,
1020+
key_value_delim: Expr | None = None,
1021+
) -> Expr:
9881022
"""Spark ``str_to_map``: split text into key/value pairs using delimiters.
9891023
1024+
Delimiters default to ``","`` and ``":"`` when omitted, matching pyspark.
1025+
9901026
Examples:
9911027
>>> ctx = dfn.SessionContext()
9921028
>>> df = ctx.from_pydict({"x": [1]})
1029+
>>> r = df.select(
1030+
... dfn.functions.spark.str_to_map(dfn.lit("a:1,b:2")).alias("v"))
1031+
>>> r.collect_column("v")[0].as_py()
1032+
[('a', '1'), ('b', '2')]
1033+
9931034
>>> r = df.select(
9941035
... dfn.functions.spark.str_to_map(
995-
... dfn.lit("a:1,b:2"), dfn.lit(","), dfn.lit(":")
1036+
... dfn.lit("a=1;b=2"),
1037+
... pair_delim=dfn.lit(";"),
1038+
... key_value_delim=dfn.lit("="),
9961039
... ).alias("v")
9971040
... )
9981041
>>> r.collect_column("v")[0].as_py()
9991042
[('a', '1'), ('b', '2')]
10001043
"""
1001-
return Expr(_f.str_to_map(text.expr, pair_delim.expr, key_value_delim.expr))
1044+
pd = pair_delim if pair_delim is not None else Expr.literal(",")
1045+
kvd = key_value_delim if key_value_delim is not None else Expr.literal(":")
1046+
return Expr(_f.str_to_map(text.expr, pd.expr, kvd.expr))
10021047

10031048

10041049
# ---------------------------------------------------------------------------
@@ -1130,18 +1175,28 @@ def rint(col: Expr) -> Expr:
11301175
return Expr(_f.rint(col.expr))
11311176

11321177

1133-
def round(col: Expr, scale: Expr) -> Expr:
1178+
def round(col: Expr, scale: Expr | None = None) -> Expr:
11341179
"""Spark ``round``: round to ``scale`` decimal places, HALF_UP rounding.
11351180
1181+
``scale`` defaults to zero when omitted, matching pyspark.
1182+
11361183
Examples:
11371184
>>> ctx = dfn.SessionContext()
11381185
>>> df = ctx.from_pydict({"x": [1]})
1139-
>>> r = df.select(
1140-
... dfn.functions.spark.round(dfn.lit(2.5), dfn.lit(0)).alias("v"))
1186+
>>> r = df.select(dfn.functions.spark.round(dfn.lit(2.5)).alias("v"))
11411187
>>> r.collect_column("v")[0].as_py()
11421188
3.0
1189+
1190+
>>> r = df.select(
1191+
... dfn.functions.spark.round(
1192+
... dfn.lit(2.345), scale=dfn.lit(2)
1193+
... ).alias("v")
1194+
... )
1195+
>>> r.collect_column("v")[0].as_py()
1196+
2.35
11431197
"""
1144-
return Expr(_f.round(col.expr, scale.expr))
1198+
scale_expr = scale if scale is not None else _ZERO_I32
1199+
return Expr(_f.round(col.expr, scale_expr.expr))
11451200

11461201

11471202
def unhex(col: Expr) -> Expr:
@@ -1306,9 +1361,16 @@ def elt(*inputs: Expr) -> Expr:
13061361
return Expr(_f.elt(*[i.expr for i in inputs]))
13071362

13081363

1309-
def ilike(string: Expr, pattern: Expr) -> Expr:
1364+
def ilike(
1365+
str: Expr,
1366+
pattern: Expr,
1367+
escapeChar: str | None = None, # noqa: N803
1368+
) -> Expr:
13101369
"""Spark ``ilike``: case-insensitive pattern match.
13111370
1371+
``escapeChar`` is accepted for pyspark parity but is not yet wired through
1372+
the Rust binding; passing a non-``None`` value raises ``NotImplementedError``.
1373+
13121374
Examples:
13131375
>>> ctx = dfn.SessionContext()
13141376
>>> df = ctx.from_pydict({"x": [1]})
@@ -1317,7 +1379,10 @@ def ilike(string: Expr, pattern: Expr) -> Expr:
13171379
>>> r.collect_column("v")[0].as_py()
13181380
True
13191381
"""
1320-
return Expr(_f.ilike(string.expr, pattern.expr))
1382+
if escapeChar is not None:
1383+
msg = "ilike(escapeChar=...) is not yet supported by the Spark UDF binding"
1384+
raise NotImplementedError(msg)
1385+
return Expr(_f.ilike(str.expr, pattern.expr))
13211386

13221387

13231388
def length(col: Expr) -> Expr:
@@ -1333,9 +1398,16 @@ def length(col: Expr) -> Expr:
13331398
return Expr(_f.length(col.expr))
13341399

13351400

1336-
def like(string: Expr, pattern: Expr) -> Expr:
1401+
def like(
1402+
str: Expr,
1403+
pattern: Expr,
1404+
escapeChar: str | None = None, # noqa: N803
1405+
) -> Expr:
13371406
"""Spark ``like``: case-sensitive pattern match.
13381407
1408+
``escapeChar`` is accepted for pyspark parity but is not yet wired through
1409+
the Rust binding; passing a non-``None`` value raises ``NotImplementedError``.
1410+
13391411
Examples:
13401412
>>> ctx = dfn.SessionContext()
13411413
>>> df = ctx.from_pydict({"x": [1]})
@@ -1344,7 +1416,10 @@ def like(string: Expr, pattern: Expr) -> Expr:
13441416
>>> r.collect_column("v")[0].as_py()
13451417
True
13461418
"""
1347-
return Expr(_f.like(string.expr, pattern.expr))
1419+
if escapeChar is not None:
1420+
msg = "like(escapeChar=...) is not yet supported by the Spark UDF binding"
1421+
raise NotImplementedError(msg)
1422+
return Expr(_f.like(str.expr, pattern.expr))
13481423

13491424

13501425
def luhn_check(col: Expr) -> Expr:

python/tests/test_spark_functions.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,73 @@ def test_round_half_up():
229229
assert _val(df, spark.round(lit(2.5), lit(0))) == 3.0
230230

231231

232+
# ---------------------------------------------------------------------------
233+
# Optional parameter defaults / NotImplementedError
234+
# ---------------------------------------------------------------------------
235+
236+
237+
def test_round_scale_default():
238+
"""spark.round defaults scale to 0."""
239+
ctx = SessionContext()
240+
df = ctx.from_pydict({"x": [1]})
241+
assert _val(df, spark.round(lit(2.5))) == 3.0
242+
243+
244+
def test_make_dt_interval_defaults():
245+
"""spark.make_dt_interval with no args returns a zero day-time interval."""
246+
import datetime as dt
247+
248+
ctx = SessionContext()
249+
df = ctx.from_pydict({"x": [1]})
250+
assert _val(df, spark.make_dt_interval()) == dt.timedelta(0)
251+
252+
253+
def test_make_interval_defaults():
254+
"""spark.make_interval with no args returns a zero interval."""
255+
ctx = SessionContext()
256+
df = ctx.from_pydict({"x": [1]})
257+
assert _val(df, spark.make_interval()).months == 0
258+
259+
260+
def test_str_to_map_defaults():
261+
"""spark.str_to_map defaults delimiters to ',' and ':'."""
262+
ctx = SessionContext()
263+
df = ctx.from_pydict({"x": [1]})
264+
assert _val(df, spark.str_to_map(lit("a:1,b:2"))) == [("a", "1"), ("b", "2")]
265+
266+
267+
def test_shuffle_seed_raises():
268+
"""spark.shuffle(seed=...) raises NotImplementedError until Rust supports it."""
269+
with pytest.raises(NotImplementedError, match="seed"):
270+
spark.shuffle(spark.array(lit(1), lit(2)), seed=1)
271+
272+
273+
def test_like_escape_raises():
274+
"""spark.like/ilike escapeChar raises NotImplementedError until Rust supports."""
275+
with pytest.raises(NotImplementedError, match="escapeChar"):
276+
spark.like(lit("a"), lit("a"), escapeChar="\\")
277+
with pytest.raises(NotImplementedError, match="escapeChar"):
278+
spark.ilike(lit("a"), lit("a"), escapeChar="\\")
279+
280+
281+
def test_aggregate_positional_compat():
282+
"""Pyspark-style positional calls still work after the rename to ``col``."""
283+
ctx = SessionContext()
284+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
285+
out = df.aggregate(
286+
[],
287+
[
288+
spark.avg(col("a")).alias("av"),
289+
spark.try_sum(col("a")).alias("ts"),
290+
spark.collect_list(col("a")).alias("cl"),
291+
spark.collect_set(col("a")).alias("cs"),
292+
],
293+
).collect()
294+
rec = pa.Table.from_batches(out)
295+
assert rec.column("av")[0].as_py() == 2.0
296+
assert rec.column("ts")[0].as_py() == 6.0
297+
298+
232299
# ---------------------------------------------------------------------------
233300
# SQL path via enable_spark_functions
234301
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)