Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions src/dve/core_engine/backends/implementations/duckdb/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
duckdb_read_parquet,
duckdb_record_index,
duckdb_write_parquet,
get_duckdb_cast_statement_from_annotation,
get_duckdb_type_from_annotation,
relation_is_empty,
)
Expand Down Expand Up @@ -101,18 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument
_lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable
return self._connection.sql("select * from _lazy_df")

@staticmethod
def generate_ddb_cast_statement(
column_name: str, dtype: DuckDBPyType, null_flag: bool = False
) -> str:
"""Helper method to generate sql statements for casting datatypes (permissively).
Current duckdb python API doesn't play well with this currently.
"""
if not null_flag:
return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"'
return f'cast(NULL AS {dtype}) AS "{column_name}"'

# pylint: disable=R0914
# pylint: disable=R0914,R0915
def apply_data_contract(
self,
working_dir: URI,
Expand Down Expand Up @@ -180,12 +170,16 @@ def apply_data_contract(

casting_statements = [
(
self.generate_ddb_cast_statement(column, dtype)
get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation)
+ f""" AS "{column}" """
if column in relation.columns
else self.generate_ddb_cast_statement(column, dtype, null_flag=True)
else f"CAST(NULL AS {ddb_schema[column]}) AS {column}"
)
for column, dtype in ddb_schema.items()
for column, mdl_fld in entity_fields.items()
]
casting_statements.append(
f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301
)
try:
relation = relation.project(", ".join(casting_statements))
except Exception as err: # pylint: disable=broad-except
Expand Down
105 changes: 105 additions & 0 deletions src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,108 @@ def duckdb_record_index(cls):
setattr(cls, "add_record_index", _add_duckdb_record_index)
setattr(cls, "drop_record_index", _drop_duckdb_record_index)
return cls


def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str:
    """Wrap *field_expr* in a permissive DuckDB ``try_cast`` to the type derived from *type_annotation*."""
    ddb_type = get_duckdb_type_from_annotation(type_annotation)
    return f"""try_cast({field_expr} as {ddb_type})"""


def _ddb_safely_quote_name(field_name: str) -> str:
"""Quote field names in case reserved"""
try:
sep_idx = field_name.index(".")
return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:]
except ValueError:
return f'"{field_name}"'


# pylint: disable=R0801,R0911,R0912
def get_duckdb_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    parent_element: bool = True,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$",  # pylint: disable=C0301
    time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$",
) -> str:
    """Generate a DuckDB SQL expression that permissively casts *element_name*
    to the DuckDB equivalent of *type_annotation*.

    Handles ``Optional``/``Union`` (must be non-heterogeneous), ``List[...]``
    (via ``list_transform``), ``Annotated[...]``, nested ``TypedDict`` /
    dataclass / pydantic models (via ``struct_pack``), and scalar leaf types.
    Temporal leaf values are validated against the supplied regexes before
    casting, so malformed values become NULL instead of raising.

    :param element_name: Column (or dotted struct-field path) to cast.
    :param type_annotation: Python type annotation driving the target type.
    :param parent_element: True for the outermost expression; nested elements
        (list items, struct fields) are built with ``False`` so the cast is
        applied once at the enclosing level.
    :param date_regex: Pattern a value must match to be cast to DATE.
    :param timestamp_regex: Pattern a value must match to be cast to TIMESTAMP.
    :param time_regex: Pattern a value must match to be cast to TIME.
    :raises ValueError: For annotations with no DuckDB equivalent.
    """
    type_origin = get_origin(type_annotation)

    quoted_name = _ddb_safely_quote_name(element_name)

    # An `Optional` or `Union` type, check to ensure non-heterogenity.
    if type_origin is Union:
        python_type = _get_non_heterogenous_type(get_args(type_annotation))
        # BUG FIX: time_regex is now propagated; previously it was dropped in
        # every recursive call, silently reverting to the default pattern.
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex, time_regex
        )

    # Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
    if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
        element_type = _get_non_heterogenous_type(get_args(type_annotation))
        stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x', element_type, False, date_regex, timestamp_regex, time_regex)})"  # pylint: disable=C0301
        return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)

    if type_origin is Annotated:
        # Unwrap Annotated[...] and recurse on the underlying type (consistent
        # with the spark helper's `*_` form; the metadata args are unused).
        python_type, *_ = get_args(type_annotation)
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex, time_regex
        )  # add other expected params here
    # Ensure that we have a concrete type at this point.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    if (
        # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # Type hint is a dataclass.
        or is_dataclass(type_annotation)
        # Type hint is a `pydantic` model.
        or (type_origin is None and issubclass(type_annotation, BaseModel))
    ):
        # Build one cast expression per declared field, then pack into a struct.
        fields: dict[str, str] = {}
        for field_name, field_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(field_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                )  # pragma: no cover
            if get_origin(field_annotation) is ClassVar:
                continue

            fields[field_name] = get_duckdb_cast_statement_from_annotation(
                f"{element_name}.{field_name}",
                field_annotation,
                False,
                date_regex,
                timestamp_regex,
                time_regex,
            )

        if not fields:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()])
        stmt = f"struct_pack({cast_exprs})"
        return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)

    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or type_origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for type_ in type_annotation.mro():
        # datetime is subclass of date, so needs to be handled first
        if issubclass(type_, datetime):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END"  # pylint: disable=C0301
        if issubclass(type_, date):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"  # pylint: disable=C0301
        if issubclass(type_, time):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END"  # pylint: disable=C0301
        duck_type = get_duckdb_type_from_annotation(type_)
        if duck_type:
            stmt = f"trim({quoted_name})"
            return _cast_as_ddb_type(stmt, type_) if parent_element else stmt
    raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")
100 changes: 100 additions & 0 deletions src/dve/core_engine/backends/implementations/spark/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,103 @@ def spark_record_index(cls):
setattr(cls, "add_record_index", _add_spark_record_index)
setattr(cls, "drop_record_index", _drop_spark_record_index)
return cls


def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column:
    """Parse *field_expr* as a Spark column and cast it to the Spark type for *field_type*."""
    target_type = get_type_from_annotation(field_type)
    column = sf.expr(field_expr)
    return column.cast(target_type)


def _spark_safely_quote_name(field_name: str) -> str:
"""Quote field names in case reserved"""
try:
sep_idx = field_name.index(".")
return f"`{field_name[: sep_idx]}`" + field_name[sep_idx:]
except ValueError:
return f"`{field_name}`"


# pylint: disable=R0801
def get_spark_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    parent_element: bool = True,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301
) -> Union[str, Column]:
    """Generate casting statements for spark dataframes based on type annotations.

    Recursively walks *type_annotation* and returns either a Spark SQL
    expression string (nested/leaf elements, ``parent_element=False``) or a
    ``Column`` (the outermost call, ``parent_element=True``).

    Handles ``Optional``/``Union`` (must be non-heterogeneous), ``List[...]``
    (via ``transform``), ``Annotated[...]``, nested ``TypedDict`` / dataclass /
    pydantic models (via ``struct``), and scalar leaf types. Date/timestamp
    inputs are validated against the supplied regexes; non-matching values are
    replaced with NULL.

    NOTE(review): the default ``timestamp_regex`` doubles its backslashes
    (``\\+``/``\\-``), unlike the DuckDB sibling helper — presumably so they
    survive Spark SQL string-literal escaping inside ``REGEXP(...)``; confirm.

    :param element_name: Column (or dotted struct-field path) to cast.
    :param type_annotation: Python type annotation driving the target type.
    :param parent_element: True for the outermost expression; inner elements
        are returned as plain SQL strings without the outer cast.
    :param date_regex: Pattern a value must match to be treated as a date.
    :param timestamp_regex: Pattern a value must match to be treated as a timestamp.
    :raises ValueError: For annotations with no Spark equivalent.
    """
    type_origin = get_origin(type_annotation)

    # Quote the leading path segment in case it is a reserved word.
    quoted_name = _spark_safely_quote_name(element_name)

    # An `Optional` or `Union` type, check to ensure non-heterogenity.
    if type_origin is Union:
        python_type = _get_non_heterogenous_type(get_args(type_annotation))
        return get_spark_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex
        )

    # Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
    if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
        element_type = _get_non_heterogenous_type(get_args(type_annotation))
        # Element expressions are built with parent_element=False so the cast
        # is applied once for the whole list at this level.
        stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
        return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)

    if type_origin is Annotated:
        # Unwrap Annotated[...] and recurse on the underlying type; metadata
        # args are discarded.
        python_type, *_ = get_args(type_annotation) # pylint: disable=unused-variable
        return get_spark_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex
        ) # add other expected params here
    # Ensure that we have a concrete type at this point.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    if (
        # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # Type hint is a dataclass.
        or is_dataclass(type_annotation)
        # Type hint is a `pydantic` model.
        or (type_origin is None and issubclass(type_annotation, BaseModel))
    ):
        # Build one cast expression per declared field, then pack into a struct.
        fields: dict[str, str] = {}
        for field_name, field_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(field_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                ) # pragma: no cover
            if get_origin(field_annotation) is ClassVar:
                continue

            fields[field_name] = get_spark_cast_statement_from_annotation(
                f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex
            )

        if not fields:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        cast_exprs = ",".join([f"{stmt} AS `{nme}`" for nme, stmt in fields.items()])
        stmt = f"struct({cast_exprs})"
        return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)
    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or type_origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for type_ in type_annotation.mro():
        # datetime is subclass of date, so needs to be handled first
        if issubclass(type_, dt.datetime):
            # Keep the trimmed string when it matches; the actual cast (if any)
            # is applied by _cast_as_spark_type at the parent level.
            stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
            return _cast_as_spark_type(stmt, type_) if parent_element else stmt
        if issubclass(type_, dt.date):
            stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
            return _cast_as_spark_type(stmt, type_) if parent_element else stmt
        # Any other type with a Spark equivalent: trim and (optionally) cast.
        spark_type = get_type_from_annotation(type_)
        if spark_type:
            stmt = f"trim({quoted_name})"
            return _cast_as_spark_type(stmt, type_) if parent_element else stmt
    raise ValueError(f"No equivalent Spark type for {type_annotation!r}")
3 changes: 2 additions & 1 deletion src/dve/metadata_parser/domain_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ def validate(cls, value: Union[dt.time, dt.datetime, str]) -> dt.time | None:
raise ValueError("Provided time has timezone, but this is forbidden for this field")
if cls.TIMEZONE_TREATMENT == "require" and not new_time.tzinfo:
raise ValueError("Provided time missing timezone, but this is required for this field")

if isinstance(value, str) and cls.TIME_FORMAT and value != str(new_time):
raise ValueError("Provided time is not matching expected time format supplied.")
return new_time

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/dve/pipeline/foundry_ddb_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI:
write_to.parent.mkdir(parents=True, exist_ok=True)
write_to = write_to.as_posix()
self.write_parquet( # type: ignore # pylint: disable=E1101
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
f"submission_id = '{submission_info.submission_id}'"
),
fh.joinuri(write_to, "processing_status.parquet"),
)
self.write_parquet( # type: ignore # pylint: disable=E1101
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
f"submission_id = '{submission_info.submission_id}'"
),
fh.joinuri(write_to, "submission_statistics.parquet"),
Expand Down
Loading
Loading