Skip to content

Commit 7d8a435

Browse files
timsaucer and claude committed
test: add unit tests for arrow_try_cast, arrow_field, cast_to_type, with_metadata
Mirrors the existing test_arrow_cast pattern. Covers:
- arrow_try_cast: string-syntax, pa.DataType, and null-on-failure paths
- arrow_field: full returned struct shape (name, data_type, nullable, metadata)
- cast_to_type: type-from-expr happy path and try_cast=True null behavior
- with_metadata: round-trip through arrow_metadata, empty-dict no-op, and empty-key ValueError

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 83dca2e commit 7d8a435

1 file changed

Lines changed: 105 additions & 0 deletions

File tree

python/tests/test_functions.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,111 @@ def test_arrow_cast_with_pyarrow_type(df):
13251325
assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string())
13261326

13271327

1328+
def test_arrow_try_cast(df):
    """Check ``arrow_try_cast`` when the target type is given as a string."""
    # Build both projections up front so the select call stays short.
    as_float = f.arrow_try_cast(column("b"), "Float64").alias("b_as_float")
    as_int = f.arrow_try_cast(column("b"), "Int32").alias("b_as_int")

    batch = df.select(as_float, as_int).collect()[0]

    expected_float = pa.array([4.0, 5.0, 6.0], type=pa.float64())
    expected_int = pa.array([4, 5, 6], type=pa.int32())
    assert batch.column(0) == expected_float
    assert batch.column(1) == expected_int
1337+
1338+
1339+
def test_arrow_try_cast_with_pyarrow_type(df):
    """Check ``arrow_try_cast`` when the target type is a ``pa.DataType``."""
    projections = [
        f.arrow_try_cast(column("b"), pa.float64()).alias("b_as_float"),
        f.arrow_try_cast(column("b"), pa.int32()).alias("b_as_int"),
    ]

    batch = df.select(*projections).collect()[0]

    expected_float = pa.array([4.0, 5.0, 6.0], type=pa.float64())
    expected_int = pa.array([4, 5, 6], type=pa.int32())
    assert batch.column(0) == expected_float
    assert batch.column(1) == expected_int
1348+
1349+
1350+
def test_arrow_try_cast_null_on_failure():
    """Check that ``arrow_try_cast`` yields null for a value it cannot cast.

    The middle input string ("oops") is not numeric; both the string-typed
    and the pyarrow-typed call are expected to return ``None`` in that slot.
    """
    strings = pa.array(["1.5", "oops", "3"])
    batch = pa.RecordBatch.from_arrays([strings], names=["s"])
    frame = SessionContext().create_dataframe([[batch]])

    via_string = f.arrow_try_cast(column("s"), "Float64").alias("c")
    via_pyarrow = f.arrow_try_cast(column("s"), pa.float64()).alias("c_pa")
    out = frame.select(via_string, via_pyarrow).collect()[0]

    expected = [1.5, None, 3.0]
    assert out.column(0).to_pylist() == expected
    assert out.column(1).to_pylist() == expected
1362+
1363+
1364+
def test_arrow_field():
    """Check the full struct returned by ``arrow_field`` for a column with metadata."""
    # Single int64 column "val" carrying one metadata entry.
    val_field = pa.field("val", pa.int64(), metadata={"k": "v"})
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1])],
        schema=pa.schema([val_field]),
    )
    frame = SessionContext().create_dataframe([[batch]])

    described = frame.select(f.arrow_field(column("val")).alias("f"))
    first_row = described.collect_column("f")[0]
    actual = first_row.as_py()

    expected = {
        "name": "val",
        "data_type": "Int64",
        "nullable": True,
        "metadata": [("k", "v")],
    }
    assert actual == expected
1382+
1383+
1384+
def test_cast_to_type():
    """Check ``cast_to_type`` on an int column with a float column as the type source."""
    int_values = pa.array([4, 5, 6])
    float_values = pa.array([1.0, 2.0, 3.0])
    batch = pa.RecordBatch.from_arrays([int_values, float_values], names=["b", "fl"])
    frame = SessionContext().create_dataframe([[batch]])

    casted = f.cast_to_type(column("b"), column("fl")).alias("c")
    out = frame.select(casted).collect()[0]

    # The values of "b" are expected back as float64.
    expected = pa.array([4.0, 5.0, 6.0], type=pa.float64())
    assert out.column(0) == expected
1397+
1398+
1399+
def test_cast_to_type_try_cast_null_on_failure():
    """Check that ``cast_to_type(..., try_cast=True)`` yields null for an uncastable value."""
    text_values = pa.array(["oops", "2", "3"])
    float_values = pa.array([1.0, 2.0, 3.0])
    batch = pa.RecordBatch.from_arrays([text_values, float_values], names=["a", "fl"])
    frame = SessionContext().create_dataframe([[batch]])

    casted = f.cast_to_type(column("a"), column("fl"), try_cast=True).alias("c")
    out = frame.select(casted).collect()[0]

    converted = out.column(0)
    # "oops" cannot be parsed, so the first slot is expected to be null.
    assert converted.to_pylist() == [None, 2.0, 3.0]
    assert converted.type == pa.float64()
1413+
1414+
1415+
def test_with_metadata_round_trip(df):
    """Check that metadata set by ``with_metadata`` is read back by ``arrow_metadata``."""
    tagged = df.select(f.with_metadata(column("b"), {"unit": "ms"}).alias("b"))

    lookup = f.arrow_metadata(column("b"), "unit").alias("u")
    values = tagged.select(lookup).collect_column("u")

    assert values[0].as_py() == "ms"
1421+
1422+
1423+
def test_with_metadata_empty_dict_noop(df):
    """Check that ``with_metadata`` with an empty dict leaves the column values unchanged."""
    unchanged = f.with_metadata(column("b"), {}).alias("b")
    batch = df.select(unchanged).collect()[0]

    expected = pa.array([4, 5, 6])
    assert batch.column(0) == expected
1426+
1427+
1428+
def test_with_metadata_empty_key_raises(df):
    """Check that ``with_metadata`` rejects an empty metadata key with ``ValueError``."""
    bad_metadata = {"": "v"}

    # The error message is expected to mention "non-empty".
    with pytest.raises(ValueError, match="non-empty"):
        f.with_metadata(column("b"), bad_metadata)
1431+
1432+
13281433
def test_case(df):
13291434
df = df.select(
13301435
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),

0 commit comments

Comments
 (0)