From 3587a027fa0922e8748ee19b014b34a7cc89f569 Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:18:44 +0100 Subject: [PATCH 1/5] Feature/ndit 1146 improve contract casting statements (#90) * fix: enhance duckdb casting to be less permissive of poorly formatted dates and trim whitespace * feat: integrated duckdb casting into data contract and added initial spark casting * refactor: added further spark cast work with tests, small fixes to duckdb casting * style: address linting issues * refactor: add in time type for duckdb casting in contract and fix spark test involving regexp for spark casting in contract --- .../implementations/duckdb/contract.py | 24 ++-- .../implementations/duckdb/duckdb_helpers.py | 105 ++++++++++++++++++ .../implementations/spark/spark_helpers.py | 100 +++++++++++++++++ src/dve/pipeline/foundry_ddb_pipeline.py | 4 +- .../test_duckdb/test_duckdb_helpers.py | 86 +++++++++++++- .../test_spark}/test_spark_helpers.py | 66 ++++++++++- 6 files changed, 364 insertions(+), 21 deletions(-) rename tests/test_core_engine/{ => test_backends/test_implementations/test_spark}/test_spark_helpers.py (54%) diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 25fb8a7..3595716 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -31,6 +31,7 @@ duckdb_read_parquet, duckdb_record_index, duckdb_write_parquet, + get_duckdb_cast_statement_from_annotation, get_duckdb_type_from_annotation, relation_is_empty, ) @@ -101,18 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument _lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable return self._connection.sql("select * from _lazy_df") - @staticmethod - def generate_ddb_cast_statement( - column_name: str, dtype: DuckDBPyType, null_flag: bool = False - ) -> str: - """Helper method to generate sql statements for casting datatypes (permissively). - Current duckdb python API doesn't play well with this currently. - """ - if not null_flag: - return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"' - return f'cast(NULL AS {dtype}) AS "{column_name}"' - - # pylint: disable=R0914 + # pylint: disable=R0914,R0915 def apply_data_contract( self, working_dir: URI, @@ -180,12 +170,16 @@ def apply_data_contract( casting_statements = [ ( - self.generate_ddb_cast_statement(column, dtype) + get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + + f""" AS "{column}" """ if column in relation.columns - else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + else f"CAST(NULL AS {ddb_schema[column]}) AS {column}" ) - for column, dtype in ddb_schema.items() + for column, mdl_fld in entity_fields.items() ] + casting_statements.append( + f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301 + ) try: relation = relation.project(", ".join(casting_statements)) except Exception as err: # pylint: disable=broad-except diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index f5b0fe9..394cd01 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -313,3 +313,108 @@ def duckdb_record_index(cls): setattr(cls, "add_record_index", _add_duckdb_record_index) setattr(cls, "drop_record_index", _drop_duckdb_record_index) return cls + + +def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str: + """Cast to Duck DB type""" + return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})""" + + +def _ddb_safely_quote_name(field_name: str) -> str: + """Quote field names in case reserved""" + try: + sep_idx = field_name.index(".") + return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:] + except ValueError: + return f'"{field_name}"' + + +# pylint: disable=R0801,R0911,R0912 +def get_duckdb_cast_statement_from_annotation( + element_name: str, + type_annotation: Any, + parent_element: bool = True, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 + time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$", +) -> str: + """Generate casting statements for duckdb relations from type annotations""" + type_origin = get_origin(type_annotation) + + quoted_name = _ddb_safely_quote_name(element_name) + + # An `Optional` or `Union` type, check to ensure non-heterogenity. + if type_origin is Union: + python_type = _get_non_heterogenous_type(get_args(type_annotation)) + return get_duckdb_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) + + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): + element_type = _get_non_heterogenous_type(get_args(type_annotation)) + stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) + + if type_origin is Annotated: + python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable + return get_duckdb_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) # add other expected params here + # Ensure that we have a concrete type at this point. + if not isinstance(type_annotation, type): + raise ValueError(f"Unsupported type annotation {type_annotation!r}") + + if ( + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. + (issubclass(type_annotation, dict) and type_annotation is not dict) + # Type hint is a dataclass. + or is_dataclass(type_annotation) + # Type hint is a `pydantic` model. + or (type_origin is None and issubclass(type_annotation, BaseModel)) + ): + fields: dict[str, str] = {} + for field_name, field_annotation in get_type_hints(type_annotation).items(): + # Technically non-string keys are disallowed, but people are bad. + if not isinstance(field_name, str): + raise ValueError( + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" + ) # pragma: no cover + if get_origin(field_annotation) is ClassVar: + continue + + fields[field_name] = get_duckdb_cast_statement_from_annotation( + f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex + ) + + if not fields: + raise ValueError( + f"No type annotations in dict/dataclass type (got {type_annotation!r})" + ) + cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()]) + stmt = f"struct_pack({cast_exprs})" + return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) + + if type_annotation is list: + raise ValueError( + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + ) + if type_annotation is dict or type_origin is dict: + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + + for type_ in type_annotation.mro(): + # datetime is subclass of date, so needs to be handled first + if issubclass(type_, datetime): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 + return stmt + if issubclass(type_, date): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 + return stmt + if issubclass(type_, time): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301 + return stmt + duck_type = get_duckdb_type_from_annotation(type_) + if duck_type: + stmt = f"trim({quoted_name})" + return _cast_as_ddb_type(stmt, type_) if parent_element else stmt + raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index 07a4a04..ced985a 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -439,3 +439,103 @@ def spark_record_index(cls): setattr(cls, "add_record_index", _add_spark_record_index) setattr(cls, "drop_record_index", _drop_spark_record_index) return cls + + +def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column: + """Cast to spark type""" + return sf.expr(field_expr).cast(get_type_from_annotation(field_type)) + + +def _spark_safely_quote_name(field_name: str) -> str: + """Quote field names in case reserved""" + try: + sep_idx = field_name.index(".") + return f"`{field_name[: sep_idx]}`" + field_name[sep_idx:] + except ValueError: + return f"`{field_name}`" + + +# pylint: disable=R0801 +def get_spark_cast_statement_from_annotation( + element_name: str, + type_annotation: Any, + parent_element: bool = True, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 +): + """Generate casting statements for spark dataframes based on type annotations""" + type_origin = get_origin(type_annotation) + + quoted_name = _spark_safely_quote_name(element_name) + + # An `Optional` or `Union` type, check to ensure non-heterogenity. + if type_origin is Union: + python_type = _get_non_heterogenous_type(get_args(type_annotation)) + return get_spark_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) + + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): + element_type = _get_non_heterogenous_type(get_args(type_annotation)) + stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) + + if type_origin is Annotated: + python_type, *_ = get_args(type_annotation) # pylint: disable=unused-variable + return get_spark_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) # add other expected params here + # Ensure that we have a concrete type at this point. + if not isinstance(type_annotation, type): + raise ValueError(f"Unsupported type annotation {type_annotation!r}") + + if ( + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. + (issubclass(type_annotation, dict) and type_annotation is not dict) + # Type hint is a dataclass. + or is_dataclass(type_annotation) + # Type hint is a `pydantic` model. + or (type_origin is None and issubclass(type_annotation, BaseModel)) + ): + fields: dict[str, str] = {} + for field_name, field_annotation in get_type_hints(type_annotation).items(): + # Technically non-string keys are disallowed, but people are bad. + if not isinstance(field_name, str): + raise ValueError( + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" + ) # pragma: no cover + if get_origin(field_annotation) is ClassVar: + continue + + fields[field_name] = get_spark_cast_statement_from_annotation( + f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex + ) + + if not fields: + raise ValueError( + f"No type annotations in dict/dataclass type (got {type_annotation!r})" + ) + cast_exprs = ",".join([f"{stmt} AS `{nme}`" for nme, stmt in fields.items()]) + stmt = f"struct({cast_exprs})" + return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) + if type_annotation is list: + raise ValueError( + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + ) + if type_annotation is dict or type_origin is dict: + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + + for type_ in type_annotation.mro(): + # datetime is subclass of date, so needs to be handled first + if issubclass(type_, dt.datetime): + stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + if issubclass(type_, dt.date): + stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + spark_type = get_type_from_annotation(type_) + if spark_type: + stmt = f"trim({quoted_name})" + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + raise ValueError(f"No equivalent Spark type for {type_annotation!r}") diff --git a/src/dve/pipeline/foundry_ddb_pipeline.py b/src/dve/pipeline/foundry_ddb_pipeline.py index 3b0e55f..21cac56 100644 --- a/src/dve/pipeline/foundry_ddb_pipeline.py +++ b/src/dve/pipeline/foundry_ddb_pipeline.py @@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI: write_to.parent.mkdir(parents=True, exist_ok=True) write_to = write_to.as_posix() self.write_parquet( # type: ignore # pylint: disable=E1101 - self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212 + self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212 f"submission_id = '{submission_info.submission_id}'" ), fh.joinuri(write_to, "processing_status.parquet"), ) self.write_parquet( # type: ignore # pylint: disable=E1101 - self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212 + self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212 f"submission_id = '{submission_info.submission_id}'" ), fh.joinuri(write_to, "submission_statistics.parquet"), diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index 5c39e36..19e96e2 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -3,17 +3,73 @@ import datetime import tempfile from pathlib import Path -from typing import Any +from typing import Any, List import pytest import pyspark.sql.types as pst from duckdb import DuckDBPyRelation, DuckDBPyConnection +from pydantic import BaseModel from pyspark.sql import Row, SparkSession from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( _ddb_read_parquet, - duckdb_rel_to_dictionaries) + duckdb_rel_to_dictionaries, + get_duckdb_cast_statement_from_annotation, + get_duckdb_type_from_annotation) +@pytest.fixture +def casting_test_table(temp_ddb_conn): + _, conn = temp_ddb_conn + conn.sql("""CREATE TABLE test_casting ( + str_test VARCHAR, + int_test VARCHAR, + date_test VARCHAR, + timestamp_test VARCHAR, + list_int_field VARCHAR[], + basic_model STRUCT(str_field VARCHAR, date_field VARCHAR), + another_model STRUCT(unique_id VARCHAR, basic_models STRUCT(str_field VARCHAR, date_field VARCHAR)[]))""") + + conn.sql("""INSERT INTO test_casting + VALUES( + 'good_one', + '1', + '2024-11-13', + '2024-04-15 12:25:36', + ['1', '2', '3'], + {'str_field': 'test', 'date_field': '2024-12-11'}, + {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}), + ( + 'dodgy_dates', + '2', + '24-11-13', + '2024-4-15 12:25:36', + ['4', '5', '6'], + {'str_field': 'test', 'date_field': '202-1-11'}, + {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-5'}]})""") + + + yield temp_ddb_conn + + conn.sql("DROP TABLE IF EXISTS test_casting") + + + +class BasicModel(BaseModel): + str_field: str + date_field: datetime.date + +class AnotherModel(BaseModel): + unique_id: int + basic_models: List[BasicModel] + +class CastingRecord(BaseModel): + str_test: str + int_test: int + date_test: datetime.date + timestamp_test: datetime.datetime + list_int_field: list[int] + basic_model: BasicModel + another_model: AnotherModel class TempConnection: """ @@ -25,6 +81,7 @@ def __init__(self, connection: DuckDBPyConnection) -> None: self._connection = connection + @pytest.mark.parametrize( "outpath", [ @@ -94,4 +151,29 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection, res.append(chunk) assert res == data + +# add decimal check +@pytest.mark.parametrize("field_name,field_type,cast_statement", + [("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"), + ("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"), + ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"), + ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"), + ("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"), + ("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), + ("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) +def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement): + assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement + + +def test_use_cast_statements(casting_test_table): + _, conn = casting_test_table + test_rel = conn.sql("SELECT * from test_casting") + casting_statements = [ f"{get_duckdb_cast_statement_from_annotation(fld.name, fld.annotation)} as {fld.name}" for fld in CastingRecord.__fields__.values()] + test_rel = test_rel.project(",".join(casting_statements)) + assert dict(zip(test_rel.columns, test_rel.dtypes)) == {fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} + dodgy_date_rec = test_rel.pl()[1].to_dicts()[0] + assert (not dodgy_date_rec.get("date_test") and + not dodgy_date_rec.get("basic_model",{}).get("date_field") + and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + ) diff --git a/tests/test_core_engine/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py similarity index 54% rename from tests/test_core_engine/test_spark_helpers.py rename to tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py index a3f167d..7502673 100644 --- a/tests/test_core_engine/test_spark_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py @@ -12,17 +12,56 @@ from pydantic.types import condecimal from pyspark.sql import DataFrame, SparkSession from pyspark.sql import types as st -from pyspark.sql.functions import col +from pyspark.sql.functions import col, expr +from pyspark.sql.types import ArrayType, DateType, LongType, StringType, StructField, StructType, TimestampType from typing_extensions import Annotated, TypedDict from dve.core_engine.backends.implementations.spark.spark_helpers import ( DecimalConfig, create_udf, + get_spark_cast_statement_from_annotation, get_type_from_annotation, object_to_spark_literal, ) -from ..fixtures import spark # pylint: disable=unused-import +from .....fixtures import spark # pylint: disable=unused-import + +@pytest.fixture +def casting_dataframe(spark): + data = [{"str_test": "good_one", "int_test": "1", "date_test": "2024-11-13", "timestamp_test": "2024-04-15 12:25:36", + "list_int_field":['1', '2', '3'], "basic_model": {'str_field': 'test', 'date_field': '2024-12-11'}, + "another_model": {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}}, + {"str_test": "dodgy_dates", "int_test": "2", "date_test": "24-11-13", "timestamp_test": "2024-4-15 12:25:36", + "list_int_field":['4', '5', '6'], "basic_model": {'str_field': 'test', 'date_field': '202-12-11'}, + "another_model": {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-05'}]}}] + + bm_schema = StructType([StructField("str_field", StringType()), StructField("date_field", StringType())]) + + schema = StructType([StructField("str_test", StringType()), StructField("int_test", StringType()), StructField("date_test", StringType()), + StructField("timestamp_test", StringType()), StructField("list_int_field", ArrayType(StringType())), + StructField("basic_model", bm_schema), + StructField("another_model", StructType([StructField("unique_id", StringType()), StructField("basic_models", ArrayType(bm_schema))]))]) + yield spark.createDataFrame(data, schema=schema) + + + + +class BasicModel(BaseModel): + str_field: str + date_field: dt.date + +class AnotherModel(BaseModel): + unique_id: int + basic_models: List[BasicModel] + +class CastingRecord(BaseModel): + str_test: str + int_test: int + date_test: dt.date + timestamp_test: dt.datetime + list_int_field: list[int] + basic_model: BasicModel + another_model: AnotherModel EXPECTED_STRUCT = st.StructType( [ @@ -203,3 +242,26 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any): """ with pytest.raises(ValueError): object_to_spark_literal(obj) + +@pytest.mark.parametrize("field_name,field_type,expression,spark_type", + [("str_test", str, "trim(`str_test`)", StringType()), + ("int_test", int, "trim(`int_test`)", LongType()), + ("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()), + ("timestamp_test", dt.datetime, r"CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), + ("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)), + ("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])), + ("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))]) +def test_get_spark_cast_statement_from_annotation(field_name, field_type, expression, spark_type): + assert str(get_spark_cast_statement_from_annotation(field_name, field_type)) == str(expr(expression).cast(spark_type)) + + +def test_use_cast_statements(spark, casting_dataframe): + casting_statements = [ get_spark_cast_statement_from_annotation(fld.name, fld.annotation).alias(fld.name) for fld in CastingRecord.__fields__.values()] + cast_df = casting_dataframe.select(*casting_statements) + assert {fld.name: fld.dataType for fld in cast_df.schema} == {fld.name: get_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} + dodgy_date_rec = [rw.asDict(True) for rw in cast_df.collect()][1] + assert (not dodgy_date_rec.get("date_test") and + not dodgy_date_rec.get("basic_model",{}).get("date_field") + and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + ) + assert cast_df \ No newline at end of file From 11930b81d8ed60e9957ec3018edb4c9989be304f Mon Sep 17 00:00:00 2001 From: George Robertson <50412379+georgeRobertson@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:14:28 +0100 Subject: [PATCH 2/5] fix: make time format more strict to stop invalid time date flowing (#92) * fix: make time format more strict to stop invalid time date flowing * style: address sonarcube l523 comment --- src/dve/metadata_parser/domain_types.py | 3 ++- tests/test_model_generation/test_domain_types.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/dve/metadata_parser/domain_types.py b/src/dve/metadata_parser/domain_types.py index 545429f..6e102ea 100644 --- a/src/dve/metadata_parser/domain_types.py +++ b/src/dve/metadata_parser/domain_types.py @@ -519,7 +519,8 @@ def validate(cls, value: Union[dt.time, dt.datetime, str]) -> dt.time | None: raise ValueError("Provided time has timezone, but this is forbidden for this field") if cls.TIMEZONE_TREATMENT == "require" and not new_time.tzinfo: raise ValueError("Provided time missing timezone, but this is required for this field") - + if isinstance(value, str) and cls.TIME_FORMAT and value != str(new_time): + raise ValueError("Provided time is not matching expected time format supplied.") return new_time @classmethod diff --git a/tests/test_model_generation/test_domain_types.py b/tests/test_model_generation/test_domain_types.py index a494494..56cf3f9 100644 --- a/tests/test_model_generation/test_domain_types.py +++ b/tests/test_model_generation/test_domain_types.py @@ -335,9 +335,6 @@ def test_reportingperiod_raises(field, value): ["23:00:00Z", None, "require", dt.time(23, 0, 0, tzinfo=UTC)], ["12:00:00Zam", None, "permit", dt.time(0, 0, 0, tzinfo=UTC)], ["12:00:00pm", None, "forbid", dt.time(12, 0, 0)], - ["1970-01-01", "%Y-%m-%d", "forbid", dt.time(0, 0)], - # not great that it effectively returns incorrect time object here. However, this would be - # down to user error in setting up the dischema. [dt.datetime(2025, 12, 1, 13, 0, 5), "%H:%M:%S", "forbid", dt.time(13, 0, 5)], [dt.datetime(2025, 12, 1, 13, 0, 5, tzinfo=UTC), "%H:%M:%S", "require", dt.time(13, 0, 5, tzinfo=UTC)], [dt.time(13, 0, 0), "%H:%M:%S", "forbid", dt.time(13, 0, 0)], @@ -364,6 +361,9 @@ def test_formattedtime( ["23:00:00", "%I:%M:%S", "permit",], ["23:00:00", "%H:%M:%S", "require",], ["23:00:00Z", "%I:%M:%S", "forbid",], + ["2:10:13", "%H:%M:%S", "forbid",], + ["20:0:13", "%H:%M:%S", "forbid",], + ["20:10:1", "%H:%M:%S", "forbid",], [dt.datetime(2025, 12, 1, 13, 0, 5, tzinfo=UTC), "%H:%M:%S", "forbid",], [dt.time(13, 0, 5, tzinfo=UTC), "%H:%M:%S", "forbid",], ["12:00", "%H:%M:%S", "forbid",], From 76ecd7e230aec1f1506e63367d5ed06a04e9a385 Mon Sep 17 00:00:00 2001 From: "george.robertson1" <50412379+georgeRobertson@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:26:48 +0100 Subject: [PATCH 3/5] build: update poetry.lock --- poetry.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index 297787b..141c41b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. [[package]] name = "argcomplete" @@ -58,7 +58,7 @@ tomli = {version = ">=1.1.0", markers = "python_version >= \"3.0\" and python_ve [package.extras] develop = ["PyHamcrest (<2.0) ; python_version < \"3.0\"", "PyHamcrest (>=2.0.2) ; python_version >= \"3.0\"", "build (>=0.5.1)", "coverage (>=5.0)", "invoke (>=1.7.0) ; python_version >= \"3.6\"", "invoke (>=1.7.0,<2.0) ; python_version < \"3.6\"", "mock (<4.0) ; python_version < \"3.6\"", "mock (>=4.0) ; python_version >= \"3.6\"", "modernize (>=0.5)", "path (>=13.1.0) ; python_version >= \"3.5\"", "path.py (>=11.5.0) ; python_version < \"3.5\"", "pycmd", "pylint", "pytest (>=4.2,<5.0) ; python_version < \"3.0\"", "pytest (>=5.0) ; python_version >= \"3.0\"", "pytest-cov", "pytest-html (>=1.19.0,<2.0) ; python_version < \"3.0\"", "pytest-html (>=2.0) ; python_version >= \"3.0\"", "ruff ; python_version >= \"3.7\"", "tox (>=3.28.0,<4.0)", "twine (>=1.13.0)", "virtualenv (<20.22.0) ; python_version < \"3.7\"", "virtualenv (>=20.26.6) ; python_version >= \"3.7\""] -docs = ["furo (>=2024.04.27) ; python_version >= \"3.8\"", "sphinx (>=1.6,<4.4) ; python_version < \"3.7\"", "sphinx (>=7.4.0) ; python_version >= \"3.7\"", "sphinx-copybutton (>=0.5.2) ; python_version >= \"3.7\"", "sphinxcontrib-applehelp (>=1.0.8) ; python_version >= \"3.7\"", "sphinxcontrib-htmlhelp (>=2.0.5) ; python_version >= \"3.7\""] +docs = ["furo (>=2024.4.27) ; python_version >= \"3.8\"", "sphinx (>=1.6,<4.4) ; python_version < \"3.7\"", "sphinx (>=7.4.0) ; python_version >= \"3.7\"", "sphinx-copybutton (>=0.5.2) ; python_version >= \"3.7\"", "sphinxcontrib-applehelp (>=1.0.8) ; python_version >= \"3.7\"", "sphinxcontrib-htmlhelp (>=2.0.5) ; python_version >= \"3.7\""] formatters = ["behave-html-formatter (>=0.9.10) ; python_version >= \"3.6\"", "behave-html-pretty-formatter (>=1.9.1) ; python_version >= \"3.6\""] testing = ["PyHamcrest (<2.0) ; python_version < \"3.0\"", "PyHamcrest (>=2.0.2) ; python_version >= \"3.0\"", "assertpy (>=1.1)", "chardet", "freezegun (>=1.5.1) ; python_version > \"3.7\"", "mock (<4.0) ; python_version < \"3.6\"", "mock (>=4.0) ; python_version >= \"3.6\"", "path (>=13.1.0) ; python_version >= \"3.5\"", "path.py (>=11.5.0,<13.0) ; python_version < \"3.5\"", "pytest (<5.0) ; python_version < \"3.0\"", "pytest (>=5.0) ; python_version >= \"3.0\"", "pytest-html (>=1.19.0,<2.0) ; python_version < \"3.0\"", "pytest-html (>=2.0) ; python_version >= \"3.0\""] @@ -842,7 +842,7 @@ deprecated = ">=1.2.13,<2" jinja2 = ">=2.10.3" packaging = ">=19" prompt_toolkit = "!=3.0.52" -pyyaml = ">=3.08" +pyyaml = ">=3.8" questionary = ">=2.0,<3.0" termcolor = ">=1.1.0,<4.0.0" tomlkit = ">=0.5.3,<1.0.0" @@ -2520,7 +2520,7 @@ files = [ ] [package.dependencies] -astroid = ">=2.14.2,<=2.16.0-dev0" +astroid = ">=2.14.2,<=2.16.0.dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -2831,10 +2831,10 @@ files = [ ] [package.dependencies] -botocore = ">=1.33.2,<2.0a.0" +botocore = ">=1.33.2,<2.0a0" [package.extras] -crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] +crt = ["botocore[crt] (>=1.33.2,<2.0a0)"] [[package]] name = "six" From 865040cccffaeaa8987151a286700be29f8e2802 Mon Sep 17 00:00:00 2001 From: "george.robertson1" <50412379+georgeRobertson@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:38:55 +0100 Subject: [PATCH 4/5] =?UTF-8?q?bump:=20version=200.7.2=20=E2=86=92=200.7.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 6 ++++++ pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6682d5..47eb833 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## v0.7.3 (2026-04-16) + +### Fix + +- make time format more strict to stop invalid time date flowing (#92) + ## v0.7.2 (2026-04-02) ### Fix diff --git a/pyproject.toml b/pyproject.toml index 4b27f76..26071f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ Issues = "https://github.com/NHSDigital/data-validation-engine/issues" Changelog = "https://github.com/NHSDigital/data-validation-engine/blob/main/CHANGELOG.md" [tool.poetry] -version = "0.7.2" +version = "0.7.3" packages = [ { include = "dve", from = "src" }, ] From 3f057dd70d8b29ac9314a40c7361e944f21ee8f3 Mon Sep 17 00:00:00 2001 From: "george.robertson1" <50412379+georgeRobertson@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:43:28 +0100 Subject: [PATCH 5/5] docs: manually update changelog due to missing changes --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47eb833..dbb56c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ ### Fix - make time format more strict to stop invalid time date flowing (#92) +- enhance duckdb casting to be less permissive of poorly formatted dates and trim whitespace + +### Refactor + +- integrated duckdb casting into data contract and added initial spark casting ## v0.7.2 (2026-04-02)