3232
3333from typing import TYPE_CHECKING , Any
3434
35+ import pyarrow as pa
36+
3537from datafusion ._internal import functions as _functions
3638from datafusion .expr import Expr , sort_list_to_raw_sort_list
3739
4143
4244_f = _functions .spark
4345
46+ # Reused int32 literal so optional-arg defaults don't rebuild it per call.
47+ _ZERO_I32 = Expr .literal (pa .scalar (0 , type = pa .int32 ()))
48+
4449
4550def _filter_raw (filter : Expr | None ) -> Any :
4651 return filter .expr if filter is not None else None
@@ -203,9 +208,12 @@ def array(*cols: Expr) -> Expr:
203208 return Expr (_f .array (* [c .expr for c in cols ]))
204209
205210
206- def shuffle (col : Expr ) -> Expr :
211+ def shuffle (col : Expr , seed : int | None = None ) -> Expr :
207212 """Spark ``shuffle``: returns a random permutation of the input array.
208213
214+ ``seed`` is accepted for pyspark parity but is not yet wired through the
215+ Rust binding; passing a non-``None`` value raises ``NotImplementedError``.
216+
209217 Examples:
210218 >>> ctx = dfn.SessionContext()
211219 >>> df = ctx.from_pydict({"x": [1]})
@@ -217,6 +225,9 @@ def shuffle(col: Expr) -> Expr:
217225 >>> sorted(r.collect_column("v")[0].as_py())
218226 [1, 2, 3]
219227 """
228+ if seed is not None :
229+ msg = "shuffle(seed=...) is not yet supported by the Spark UDF binding"
230+ raise NotImplementedError (msg )
220231 return Expr (_f .shuffle (col .expr ))
221232
222233
@@ -589,59 +600,78 @@ def last_day(col: Expr) -> Expr:
589600 return Expr (_f .last_day (col .expr ))
590601
591602
592- def make_dt_interval (days : Expr , hours : Expr , mins : Expr , secs : Expr ) -> Expr :
603+ def make_dt_interval (
604+ days : Expr | None = None ,
605+ hours : Expr | None = None ,
606+ mins : Expr | None = None ,
607+ secs : Expr | None = None ,
608+ ) -> Expr :
593609 """Spark ``make_dt_interval``: day-time interval from components.
594610
611+ All parts are optional; omitted parts default to zero, matching pyspark.
612+
595613 Examples:
596- >>> import pyarrow as pa
597614 >>> ctx = dfn.SessionContext()
598615 >>> df = ctx.from_pydict({"x": [1]})
616+ >>> r = df.select(dfn.functions.spark.make_dt_interval().alias("v"))
617+ >>> r.collect_column("v")[0].as_py()
618+ datetime.timedelta(0)
619+
620+ >>> import pyarrow as pa
599621 >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32()))
600622 >>> r = df.select(
601623 ... dfn.functions.spark.make_dt_interval(
602- ... i32(1), i32(2), i32(3), dfn.lit(4.5)
624+ ... days= i32(1), hours= i32(2), mins= i32(3), secs= dfn.lit(4.5)
603625 ... ).alias("v")
604626 ... )
605627 >>> r.collect_column("v")[0].as_py()
606628 datetime.timedelta(days=1, seconds=7384, microseconds=500000)
607629 """
608- return Expr (_f .make_dt_interval (days .expr , hours .expr , mins .expr , secs .expr ))
630+ return Expr (
631+ _f .make_dt_interval (
632+ (days if days is not None else _ZERO_I32 ).expr ,
633+ (hours if hours is not None else _ZERO_I32 ).expr ,
634+ (mins if mins is not None else _ZERO_I32 ).expr ,
635+ (secs if secs is not None else Expr .literal (0.0 )).expr ,
636+ )
637+ )
609638
610639
611640def make_interval (
612- years : Expr ,
613- months : Expr ,
614- weeks : Expr ,
615- days : Expr ,
616- hours : Expr ,
617- mins : Expr ,
618- secs : Expr ,
641+ years : Expr | None = None ,
642+ months : Expr | None = None ,
643+ weeks : Expr | None = None ,
644+ days : Expr | None = None ,
645+ hours : Expr | None = None ,
646+ mins : Expr | None = None ,
647+ secs : Expr | None = None ,
619648) -> Expr :
620649 """Spark ``make_interval``: interval from year/month/week/day/hour/min/sec parts.
621650
651+ All parts are optional; omitted parts default to zero, matching pyspark.
652+
622653 Examples:
623- >>> import pyarrow as pa
624654 >>> ctx = dfn.SessionContext()
625655 >>> df = ctx.from_pydict({"x": [1]})
656+ >>> r = df.select(dfn.functions.spark.make_interval().alias("v"))
657+ >>> r.collect_column("v")[0].as_py().months
658+ 0
659+
660+ >>> import pyarrow as pa
626661 >>> i32 = lambda n: dfn.lit(pa.scalar(n, type=pa.int32()))
627- >>> r = df.select(
628- ... dfn.functions.spark.make_interval(
629- ... i32(1), i32(0), i32(0), i32(0),
630- ... i32(0), i32(0), dfn.lit(0.0)
631- ... ).alias("v")
632- ... )
662+ >>> r = df.select(dfn.functions.spark.make_interval(years=i32(1)).alias("v"))
633663 >>> r.collect_column("v")[0].as_py().months
634664 12
635665 """
636666 return Expr (
637667 _f .make_interval (
638- years .expr ,
639- months .expr ,
640- weeks .expr ,
641- days .expr ,
642- hours .expr ,
643- mins .expr ,
644- secs .expr ,
668+ ( years if years is not None else _ZERO_I32 ) .expr ,
669+ ( months if months is not None else _ZERO_I32 ) .expr ,
670+ ( weeks if weeks is not None else _ZERO_I32 ) .expr ,
671+ ( days if days is not None else _ZERO_I32 ) .expr ,
672+ ( hours if hours is not None else _ZERO_I32 ) .expr ,
673+ ( mins if mins is not None else _ZERO_I32 ) .expr ,
674+ ( secs if secs is not None else Expr . literal ( 0.0 )) .expr ,
645675 )
646676 )
647677
@@ -984,21 +1014,36 @@ def map_from_entries(col: Expr) -> Expr:
9841014 return Expr (_f .map_from_entries (col .expr ))
9851015
9861016
987- def str_to_map (text : Expr , pair_delim : Expr , key_value_delim : Expr ) -> Expr :
1017+ def str_to_map (
1018+ text : Expr ,
1019+ pair_delim : Expr | None = None ,
1020+ key_value_delim : Expr | None = None ,
1021+ ) -> Expr :
9881022 """Spark ``str_to_map``: split text into key/value pairs using delimiters.
9891023
1024+ Delimiters default to ``","`` and ``":"`` when omitted, matching pyspark.
1025+
9901026 Examples:
9911027 >>> ctx = dfn.SessionContext()
9921028 >>> df = ctx.from_pydict({"x": [1]})
1029+ >>> r = df.select(
1030+ ... dfn.functions.spark.str_to_map(dfn.lit("a:1,b:2")).alias("v"))
1031+ >>> r.collect_column("v")[0].as_py()
1032+ [('a', '1'), ('b', '2')]
1033+
9931034 >>> r = df.select(
9941035 ... dfn.functions.spark.str_to_map(
995- ... dfn.lit("a:1,b:2"), dfn.lit(","), dfn.lit(":")
1036+ ... dfn.lit("a=1;b=2"),
1037+ ... pair_delim=dfn.lit(";"),
1038+ ... key_value_delim=dfn.lit("="),
9961039 ... ).alias("v")
9971040 ... )
9981041 >>> r.collect_column("v")[0].as_py()
9991042 [('a', '1'), ('b', '2')]
10001043 """
1001- return Expr (_f .str_to_map (text .expr , pair_delim .expr , key_value_delim .expr ))
1044+ pd = pair_delim if pair_delim is not None else Expr .literal ("," )
1045+ kvd = key_value_delim if key_value_delim is not None else Expr .literal (":" )
1046+ return Expr (_f .str_to_map (text .expr , pd .expr , kvd .expr ))
10021047
10031048
10041049# ---------------------------------------------------------------------------
@@ -1130,18 +1175,28 @@ def rint(col: Expr) -> Expr:
11301175 return Expr (_f .rint (col .expr ))
11311176
11321177
1133- def round (col : Expr , scale : Expr ) -> Expr :
1178+ def round (col : Expr , scale : Expr | None = None ) -> Expr :
11341179 """Spark ``round``: round to ``scale`` decimal places, HALF_UP rounding.
11351180
1181+ ``scale`` defaults to zero when omitted, matching pyspark.
1182+
11361183 Examples:
11371184 >>> ctx = dfn.SessionContext()
11381185 >>> df = ctx.from_pydict({"x": [1]})
1139- >>> r = df.select(
1140- ... dfn.functions.spark.round(dfn.lit(2.5), dfn.lit(0)).alias("v"))
1186+ >>> r = df.select(dfn.functions.spark.round(dfn.lit(2.5)).alias("v"))
11411187 >>> r.collect_column("v")[0].as_py()
11421188 3.0
1189+
1190+ >>> r = df.select(
1191+ ... dfn.functions.spark.round(
1192+ ... dfn.lit(2.345), scale=dfn.lit(2)
1193+ ... ).alias("v")
1194+ ... )
1195+ >>> r.collect_column("v")[0].as_py()
1196+ 2.35
11431197 """
1144- return Expr (_f .round (col .expr , scale .expr ))
1198+ scale_expr = scale if scale is not None else _ZERO_I32
1199+ return Expr (_f .round (col .expr , scale_expr .expr ))
11451200
11461201
11471202def unhex (col : Expr ) -> Expr :
@@ -1306,9 +1361,16 @@ def elt(*inputs: Expr) -> Expr:
13061361 return Expr (_f .elt (* [i .expr for i in inputs ]))
13071362
13081363
1309- def ilike (string : Expr , pattern : Expr ) -> Expr :
1364+ def ilike (
1365+ str : Expr ,
1366+ pattern : Expr ,
1367+ escapeChar : str | None = None , # noqa: N803
1368+ ) -> Expr :
13101369 """Spark ``ilike``: case-insensitive pattern match.
13111370
1371+ ``escapeChar`` is accepted for pyspark parity but is not yet wired through
1372+ the Rust binding; passing a non-``None`` value raises ``NotImplementedError``.
1373+
13121374 Examples:
13131375 >>> ctx = dfn.SessionContext()
13141376 >>> df = ctx.from_pydict({"x": [1]})
@@ -1317,7 +1379,10 @@ def ilike(string: Expr, pattern: Expr) -> Expr:
13171379 >>> r.collect_column("v")[0].as_py()
13181380 True
13191381 """
1320- return Expr (_f .ilike (string .expr , pattern .expr ))
1382+ if escapeChar is not None :
1383+ msg = "ilike(escapeChar=...) is not yet supported by the Spark UDF binding"
1384+ raise NotImplementedError (msg )
1385+ return Expr (_f .ilike (str .expr , pattern .expr ))
13211386
13221387
13231388def length (col : Expr ) -> Expr :
@@ -1333,9 +1398,16 @@ def length(col: Expr) -> Expr:
13331398 return Expr (_f .length (col .expr ))
13341399
13351400
1336- def like (string : Expr , pattern : Expr ) -> Expr :
1401+ def like (
1402+ str : Expr ,
1403+ pattern : Expr ,
1404+ escapeChar : str | None = None , # noqa: N803
1405+ ) -> Expr :
13371406 """Spark ``like``: case-sensitive pattern match.
13381407
1408+ ``escapeChar`` is accepted for pyspark parity but is not yet wired through
1409+ the Rust binding; passing a non-``None`` value raises ``NotImplementedError``.
1410+
13391411 Examples:
13401412 >>> ctx = dfn.SessionContext()
13411413 >>> df = ctx.from_pydict({"x": [1]})
@@ -1344,7 +1416,10 @@ def like(string: Expr, pattern: Expr) -> Expr:
13441416 >>> r.collect_column("v")[0].as_py()
13451417 True
13461418 """
1347- return Expr (_f .like (string .expr , pattern .expr ))
1419+ if escapeChar is not None :
1420+ msg = "like(escapeChar=...) is not yet supported by the Spark UDF binding"
1421+ raise NotImplementedError (msg )
1422+ return Expr (_f .like (str .expr , pattern .expr ))
13481423
13491424
13501425def luhn_check (col : Expr ) -> Expr :
0 commit comments