@@ -1325,6 +1325,111 @@ def test_arrow_cast_with_pyarrow_type(df):
13251325 assert result .column (2 ) == pa .array (["4" , "5" , "6" ], type = pa .string ())
13261326
13271327
1328+ def test_arrow_try_cast (df ):
1329+ df = df .select (
1330+ f .arrow_try_cast (column ("b" ), "Float64" ).alias ("b_as_float" ),
1331+ f .arrow_try_cast (column ("b" ), "Int32" ).alias ("b_as_int" ),
1332+ )
1333+ result = df .collect ()[0 ]
1334+
1335+ assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1336+ assert result .column (1 ) == pa .array ([4 , 5 , 6 ], type = pa .int32 ())
1337+
1338+
1339+ def test_arrow_try_cast_with_pyarrow_type (df ):
1340+ df = df .select (
1341+ f .arrow_try_cast (column ("b" ), pa .float64 ()).alias ("b_as_float" ),
1342+ f .arrow_try_cast (column ("b" ), pa .int32 ()).alias ("b_as_int" ),
1343+ )
1344+ result = df .collect ()[0 ]
1345+
1346+ assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1347+ assert result .column (1 ) == pa .array ([4 , 5 , 6 ], type = pa .int32 ())
1348+
1349+
1350+ def test_arrow_try_cast_null_on_failure ():
1351+ ctx = SessionContext ()
1352+ batch = pa .RecordBatch .from_arrays ([pa .array (["1.5" , "oops" , "3" ])], names = ["s" ])
1353+ df = ctx .create_dataframe ([[batch ]])
1354+
1355+ result = df .select (
1356+ f .arrow_try_cast (column ("s" ), "Float64" ).alias ("c" ),
1357+ f .arrow_try_cast (column ("s" ), pa .float64 ()).alias ("c_pa" ),
1358+ ).collect ()[0 ]
1359+
1360+ assert result .column (0 ).to_pylist () == [1.5 , None , 3.0 ]
1361+ assert result .column (1 ).to_pylist () == [1.5 , None , 3.0 ]
1362+
1363+
1364+ def test_arrow_field ():
1365+ ctx = SessionContext ()
1366+ field = pa .field ("val" , pa .int64 (), metadata = {"k" : "v" })
1367+ schema = pa .schema ([field ])
1368+ batch = pa .RecordBatch .from_arrays ([pa .array ([1 ])], schema = schema )
1369+ df = ctx .create_dataframe ([[batch ]])
1370+
1371+ out = (
1372+ df .select (f .arrow_field (column ("val" )).alias ("f" ))
1373+ .collect_column ("f" )[0 ]
1374+ .as_py ()
1375+ )
1376+ assert out == {
1377+ "name" : "val" ,
1378+ "data_type" : "Int64" ,
1379+ "nullable" : True ,
1380+ "metadata" : [("k" , "v" )],
1381+ }
1382+
1383+
1384+ def test_cast_to_type ():
1385+ ctx = SessionContext ()
1386+ batch = pa .RecordBatch .from_arrays (
1387+ [pa .array ([4 , 5 , 6 ]), pa .array ([1.0 , 2.0 , 3.0 ])],
1388+ names = ["b" , "fl" ],
1389+ )
1390+ df = ctx .create_dataframe ([[batch ]])
1391+
1392+ result = df .select (f .cast_to_type (column ("b" ), column ("fl" )).alias ("c" )).collect ()[
1393+ 0
1394+ ]
1395+
1396+ assert result .column (0 ) == pa .array ([4.0 , 5.0 , 6.0 ], type = pa .float64 ())
1397+
1398+
1399+ def test_cast_to_type_try_cast_null_on_failure ():
1400+ ctx = SessionContext ()
1401+ batch = pa .RecordBatch .from_arrays (
1402+ [pa .array (["oops" , "2" , "3" ]), pa .array ([1.0 , 2.0 , 3.0 ])],
1403+ names = ["a" , "fl" ],
1404+ )
1405+ df = ctx .create_dataframe ([[batch ]])
1406+
1407+ result = df .select (
1408+ f .cast_to_type (column ("a" ), column ("fl" ), try_cast = True ).alias ("c" )
1409+ ).collect ()[0 ]
1410+
1411+ assert result .column (0 ).to_pylist () == [None , 2.0 , 3.0 ]
1412+ assert result .column (0 ).type == pa .float64 ()
1413+
1414+
1415+ def test_with_metadata_round_trip (df ):
1416+ df = df .select (f .with_metadata (column ("b" ), {"unit" : "ms" }).alias ("b" ))
1417+ result = df .select (f .arrow_metadata (column ("b" ), "unit" ).alias ("u" )).collect_column (
1418+ "u"
1419+ )
1420+ assert result [0 ].as_py () == "ms"
1421+
1422+
1423+ def test_with_metadata_empty_dict_noop (df ):
1424+ out = df .select (f .with_metadata (column ("b" ), {}).alias ("b" )).collect ()[0 ]
1425+ assert out .column (0 ) == pa .array ([4 , 5 , 6 ])
1426+
1427+
1428+ def test_with_metadata_empty_key_raises (df ):
1429+ with pytest .raises (ValueError , match = "non-empty" ):
1430+ f .with_metadata (column ("b" ), {"" : "v" })
1431+
1432+
13281433def test_case (df ):
13291434 df = df .select (
13301435 f .case (column ("b" )).when (literal (4 ), literal (10 )).otherwise (literal (8 )),
0 commit comments