Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions src/dve/core_engine/backends/implementations/duckdb/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
duckdb_read_parquet,
duckdb_record_index,
duckdb_write_parquet,
get_duckdb_cast_statement_from_annotation,
get_duckdb_type_from_annotation,
relation_is_empty,
)
Expand Down Expand Up @@ -101,18 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument
_lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable
return self._connection.sql("select * from _lazy_df")

@staticmethod
def generate_ddb_cast_statement(
column_name: str, dtype: DuckDBPyType, null_flag: bool = False
) -> str:
"""Helper method to generate sql statements for casting datatypes (permissively).
Current duckdb python API doesn't play well with this currently.
"""
if not null_flag:
return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"'
return f'cast(NULL AS {dtype}) AS "{column_name}"'

# pylint: disable=R0914
# pylint: disable=R0914,R0915
def apply_data_contract(
self,
working_dir: URI,
Expand Down Expand Up @@ -180,12 +170,16 @@ def apply_data_contract(

casting_statements = [
(
self.generate_ddb_cast_statement(column, dtype)
get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation)
+ f""" AS "{column}" """
if column in relation.columns
else self.generate_ddb_cast_statement(column, dtype, null_flag=True)
else f"CAST(NULL AS {ddb_schema[column]}) AS {column}"
)
for column, dtype in ddb_schema.items()
for column, mdl_fld in entity_fields.items()
]
casting_statements.append(
f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301
)
try:
relation = relation.project(", ".join(casting_statements))
except Exception as err: # pylint: disable=broad-except
Expand Down
105 changes: 105 additions & 0 deletions src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,108 @@ def duckdb_record_index(cls):
setattr(cls, "add_record_index", _add_duckdb_record_index)
setattr(cls, "drop_record_index", _drop_duckdb_record_index)
return cls


def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str:
    """Wrap *field_expr* in a permissive ``try_cast`` to the DuckDB type for *type_annotation*."""
    duckdb_type = get_duckdb_type_from_annotation(type_annotation)
    return f"try_cast({field_expr} as {duckdb_type})"


def _ddb_safely_quote_name(field_name: str) -> str:
"""Quote field names in case reserved"""
try:
sep_idx = field_name.index(".")
return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:]
except ValueError:
return f'"{field_name}"'


# pylint: disable=R0801,R0911,R0912
def get_duckdb_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    parent_element: bool = True,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$",  # pylint: disable=C0301
    time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$",
) -> str:
    """Generate a permissive DuckDB cast expression for a column from a type annotation.

    The generated SQL trims string input, regex-validates temporal values
    (non-matching values become NULL rather than erroring), and uses
    ``try_cast`` so unparseable scalars also degrade to NULL. Nested models,
    dataclasses and typed dicts are rebuilt with ``struct_pack``; lists are
    cast element-wise via ``list_transform``.

    Args:
        element_name: Column name or dotted ``parent.child`` path to read from.
        type_annotation: Python type annotation describing the target type.
        parent_element: True for the outermost call; recursive calls pass
            False so only the top-level expression applies the final cast.
        date_regex: Pattern a trimmed value must match to be cast to DATE.
        timestamp_regex: Pattern a trimmed value must match for TIMESTAMP.
        time_regex: Pattern a trimmed value must match for TIME.

    Returns:
        A SQL expression string usable inside a ``SELECT`` projection.

    Raises:
        ValueError: For bare ``list``/``dict`` annotations, non-type
            annotations, field-less struct types, or types with no DuckDB
            equivalent.
    """
    type_origin = get_origin(type_annotation)

    quoted_name = _ddb_safely_quote_name(element_name)

    # An `Optional` or `Union` type, check to ensure non-heterogenity.
    if type_origin is Union:
        python_type = _get_non_heterogenous_type(get_args(type_annotation))
        # BUGFIX: `time_regex` is now forwarded; previously it was dropped on
        # recursion, silently reverting custom time patterns to the default.
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex, time_regex
        )

    # Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
    if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
        element_type = _get_non_heterogenous_type(get_args(type_annotation))
        # `x` is the lambda variable DuckDB binds to each list element.
        stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x', element_type, False, date_regex, timestamp_regex, time_regex)})"  # pylint: disable=C0301
        return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)

    if type_origin is Annotated:
        # Unwrap the underlying type; annotation metadata is currently unused.
        python_type, *other_args = get_args(type_annotation)  # pylint: disable=unused-variable
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex, time_regex
        )  # add other expected params here
    # Ensure that we have a concrete type at this point.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    if (
        # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # Type hint is a dataclass.
        or is_dataclass(type_annotation)
        # Type hint is a `pydantic` model.
        or (type_origin is None and issubclass(type_annotation, BaseModel))
    ):
        fields: dict[str, str] = {}
        for field_name, field_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(field_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                )  # pragma: no cover
            # Class-level constants are not data columns; skip them.
            if get_origin(field_annotation) is ClassVar:
                continue

            fields[field_name] = get_duckdb_cast_statement_from_annotation(
                f"{element_name}.{field_name}",
                field_annotation,
                False,
                date_regex,
                timestamp_regex,
                time_regex,
            )

        if not fields:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()])
        stmt = f"struct_pack({cast_exprs})"
        return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)

    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or type_origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for type_ in type_annotation.mro():
        # datetime is subclass of date, so needs to be handled first
        if issubclass(type_, datetime):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END"  # pylint: disable=C0301
        if issubclass(type_, date):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"  # pylint: disable=C0301
        if issubclass(type_, time):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END"  # pylint: disable=C0301
        duck_type = get_duckdb_type_from_annotation(type_)
        if duck_type:
            stmt = f"trim({quoted_name})"
            return _cast_as_ddb_type(stmt, type_) if parent_element else stmt
    raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")
100 changes: 100 additions & 0 deletions src/dve/core_engine/backends/implementations/spark/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,103 @@ def spark_record_index(cls):
setattr(cls, "add_record_index", _add_spark_record_index)
setattr(cls, "drop_record_index", _drop_spark_record_index)
return cls


def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column:
    """Build a Spark ``Column`` that casts *field_expr* to the Spark type for *field_type*."""
    target_type = get_type_from_annotation(field_type)
    return sf.expr(field_expr).cast(target_type)


def _spark_safely_quote_name(field_name: str) -> str:
"""Quote field names in case reserved"""
try:
sep_idx = field_name.index(".")
return f"`{field_name[: sep_idx]}`" + field_name[sep_idx:]
except ValueError:
return f"`{field_name}`"


# pylint: disable=R0801
def get_spark_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    parent_element: bool = True,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    # NOTE(review): the doubled backslashes here (vs the single ones in the DuckDB
    # variant) appear intended to survive Spark SQL string-literal unescaping —
    # confirm against the parser's escapedStringLiterals setting.
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$",  # pylint: disable=C0301
) -> Union[str, Column]:
    """Generate a Spark cast expression for a column from a Python type annotation.

    Mirrors the DuckDB variant: strings are trimmed, temporal values are
    regex-validated (non-matching values become NULL), lists are rebuilt with
    ``transform`` and nested models/dataclasses/typed dicts with ``struct``.

    Returns a SQL expression *string* when ``parent_element`` is False (for
    embedding in an outer expression) and a ``Column`` with the final cast
    applied when True.

    Raises ``ValueError`` for bare ``list``/``dict`` annotations, non-type
    annotations, field-less struct types, or types with no Spark equivalent.
    """
    type_origin = get_origin(type_annotation)

    quoted_name = _spark_safely_quote_name(element_name)

    # An `Optional` or `Union` type, check to ensure non-heterogenity.
    if type_origin is Union:
        python_type = _get_non_heterogenous_type(get_args(type_annotation))
        return get_spark_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex
        )

    # Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
    if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
        element_type = _get_non_heterogenous_type(get_args(type_annotation))
        # `x` is the lambda variable Spark binds to each array element.
        stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})"  # pylint: disable=C0301
        return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)

    if type_origin is Annotated:
        # Unwrap the underlying type; annotation metadata is currently unused.
        python_type, *_ = get_args(type_annotation)  # pylint: disable=unused-variable
        return get_spark_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex
        )  # add other expected params here
    # Ensure that we have a concrete type at this point.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    if (
        # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # Type hint is a dataclass.
        or is_dataclass(type_annotation)
        # Type hint is a `pydantic` model.
        or (type_origin is None and issubclass(type_annotation, BaseModel))
    ):
        fields: dict[str, str] = {}
        for field_name, field_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(field_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                )  # pragma: no cover
            # Class-level constants are not data columns; skip them.
            if get_origin(field_annotation) is ClassVar:
                continue

            fields[field_name] = get_spark_cast_statement_from_annotation(
                f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex
            )

        if not fields:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        cast_exprs = ",".join([f"{stmt} AS `{nme}`" for nme, stmt in fields.items()])
        stmt = f"struct({cast_exprs})"
        return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)
    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or type_origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for type_ in type_annotation.mro():
        # datetime is subclass of date, so needs to be handled first
        if issubclass(type_, dt.datetime):
            # Unlike the DuckDB variant, the validated string is left as-is here;
            # the final cast (parent level) performs the actual conversion.
            stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END"  # pylint: disable=C0301
            return _cast_as_spark_type(stmt, type_) if parent_element else stmt
        if issubclass(type_, dt.date):
            stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END"  # pylint: disable=C0301
            return _cast_as_spark_type(stmt, type_) if parent_element else stmt
        spark_type = get_type_from_annotation(type_)
        if spark_type:
            stmt = f"trim({quoted_name})"
            return _cast_as_spark_type(stmt, type_) if parent_element else stmt
    raise ValueError(f"No equivalent Spark type for {type_annotation!r}")
4 changes: 2 additions & 2 deletions src/dve/pipeline/foundry_ddb_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI:
write_to.parent.mkdir(parents=True, exist_ok=True)
write_to = write_to.as_posix()
self.write_parquet( # type: ignore # pylint: disable=E1101
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
f"submission_id = '{submission_info.submission_id}'"
),
fh.joinuri(write_to, "processing_status.parquet"),
)
self.write_parquet( # type: ignore # pylint: disable=E1101
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
f"submission_id = '{submission_info.submission_id}'"
),
fh.joinuri(write_to, "submission_statistics.parquet"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,73 @@
import datetime
import tempfile
from pathlib import Path
from typing import Any
from typing import Any, List

import pytest
import pyspark.sql.types as pst
from duckdb import DuckDBPyRelation, DuckDBPyConnection
from pydantic import BaseModel
from pyspark.sql import Row, SparkSession

from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
_ddb_read_parquet,
duckdb_rel_to_dictionaries)
duckdb_rel_to_dictionaries,
get_duckdb_cast_statement_from_annotation,
get_duckdb_type_from_annotation)

@pytest.fixture
def casting_test_table(temp_ddb_conn):
    """Create a `test_casting` table of raw string columns for casting tests.

    All columns are deliberately VARCHAR (or structs/lists of VARCHAR) so the
    generated cast statements perform the type conversion. Two rows are
    seeded: 'good_one' (well-formed values) and 'dodgy_dates' (malformed
    date/timestamp strings that should cast to NULL). The table is dropped on
    teardown. Yields the `temp_ddb_conn` fixture value unchanged.
    """
    _, conn = temp_ddb_conn
    conn.sql("""CREATE TABLE test_casting (
        str_test VARCHAR,
        int_test VARCHAR,
        date_test VARCHAR,
        timestamp_test VARCHAR,
        list_int_field VARCHAR[],
        basic_model STRUCT(str_field VARCHAR, date_field VARCHAR),
        another_model STRUCT(unique_id VARCHAR, basic_models STRUCT(str_field VARCHAR, date_field VARCHAR)[]))""")

    conn.sql("""INSERT INTO test_casting
        VALUES(
        'good_one',
        '1',
        '2024-11-13',
        '2024-04-15 12:25:36',
        ['1', '2', '3'],
        {'str_field': 'test', 'date_field': '2024-12-11'},
        {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}),
        (
        'dodgy_dates',
        '2',
        '24-11-13',
        '2024-4-15 12:25:36',
        ['4', '5', '6'],
        {'str_field': 'test', 'date_field': '202-1-11'},
        {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-5'}]})""")


    yield temp_ddb_conn

    # Clean up so the shared connection can be reused by other tests.
    conn.sql("DROP TABLE IF EXISTS test_casting")



class BasicModel(BaseModel):
    """Leaf pydantic model exercising struct casting: a plain string plus a date field."""

    str_field: str
    date_field: datetime.date

class AnotherModel(BaseModel):
    """Nested pydantic model: an integer id plus a list of `BasicModel` structs."""

    unique_id: int
    basic_models: List[BasicModel]

class CastingRecord(BaseModel):
    """Full record schema covering every cast path: scalars, temporals, lists, nested structs.

    Mirrors the columns of the `test_casting` fixture table.
    """

    str_test: str
    int_test: int
    date_test: datetime.date
    timestamp_test: datetime.datetime
    list_int_field: list[int]
    basic_model: BasicModel
    another_model: AnotherModel

class TempConnection:
"""
Expand All @@ -25,6 +81,7 @@ def __init__(self, connection: DuckDBPyConnection) -> None:
self._connection = connection



@pytest.mark.parametrize(
"outpath",
[
Expand Down Expand Up @@ -94,4 +151,29 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection,
res.append(chunk)

assert res == data

# add decimal check
@pytest.mark.parametrize(
    "field_name,field_type,cast_statement",
    [
        ("str_test", str, 'try_cast(trim("str_test") as VARCHAR)'),
        ("int_test", int, 'try_cast(trim("int_test") as BIGINT)'),
        (
            "date_test",
            datetime.date,
            "CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END",
        ),
        (
            "timestamp_test",
            datetime.datetime,
            # BUGFIX: `\+`/`\-` were written in a non-raw string — invalid escape
            # sequences (SyntaxWarning on Python 3.12+). `\\+`/`\\-` yields the
            # same string value without the warning.
            "CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END",
        ),
        (
            "list_int_field",
            list[int],
            'try_cast(list_transform("list_int_field", x -> trim("x")) as BIGINT[])',
        ),
        (
            "basic_model",
            BasicModel,
            "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))",
        ),
        (
            "another_model",
            AnotherModel,
            "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))",
        ),
    ],
)
def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement):
    """Each annotation should produce the exact permissive, regex-guarded DuckDB cast SQL."""
    assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement


def test_use_cast_statements(casting_test_table):
    """Project the fixture table through generated casts; check the resulting schema
    and that the malformed dates in the second row were nulled out."""
    _uri, connection = casting_test_table
    relation = connection.sql("SELECT * from test_casting")
    model_fields = CastingRecord.__fields__.values()
    cast_exprs = [
        f"{get_duckdb_cast_statement_from_annotation(fld.name, fld.annotation)} as {fld.name}"
        for fld in model_fields
    ]
    relation = relation.project(",".join(cast_exprs))
    expected_schema = {
        fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in model_fields
    }
    assert dict(zip(relation.columns, relation.dtypes)) == expected_schema
    # Second row is 'dodgy_dates': every malformed date should have become NULL.
    bad_record = relation.pl()[1].to_dicts()[0]
    assert (
        not bad_record.get("date_test")
        and not bad_record.get("basic_model", {}).get("date_field")
        and all(
            not entry.get("date_field")
            for entry in bad_record.get("another_model", {}).get("basic_models", [])
        )
    )

Loading
Loading