From ccca618e763517a4b159e350e7cdf5e127a36359 Mon Sep 17 00:00:00 2001 From: "langu.gjl" Date: Mon, 1 Jun 2026 15:21:21 +0800 Subject: [PATCH] [python] Support query auth (row filter & column masking) for REST catalog Adds query-auth support to the Python client so it honors the row-level filter and column masking rules returned by a REST catalog, matching the existing JVM client behavior. When the new option `query-auth.enabled` is set to true, the client calls `POST /v1/.../databases/{db}/tables/{tb}/auth` before producing a plan, receives `{ filter, columnMasking }`, and applies them on the read path: * `predicate_json_parser` parses Paimon predicate JSON into a PyArrow compute filter (EQ/NEQ/LT/LTEQ/GT/GTEQ/IS_NULL/IS_NOT_NULL/ IN/NOT_IN/STARTS_WITH/ENDS_WITH/CONTAINS/AND/OR/NOT). * `AuthFilterReader` / `AuthMaskingReader` / `ColumnProjectReader` perform row filtering, column masking transforms (NULL, FIELD_REF, CAST, UPPER, LOWER, CONCAT, CONCAT_WS) and final projection back to the user's requested columns. * `TableQueryAuth` / `TableQueryAuthResult` wrap the result and convert each split to a `QueryAuthSplit`. Behavior is gated by `CoreOptions.QUERY_AUTH_ENABLED` (default false), so existing users see no change. --- paimon-python/pypaimon/api/api_response.py | 12 + paimon-python/pypaimon/api/resource_paths.py | 6 + paimon-python/pypaimon/api/rest_api.py | 14 +- paimon-python/pypaimon/catalog/catalog.py | 3 + .../pypaimon/catalog/catalog_environment.py | 9 + .../pypaimon/catalog/rest/rest_catalog.py | 19 +- .../pypaimon/catalog/table_query_auth.py | 87 ++++ .../pypaimon/common/options/core_options.py | 11 + .../pypaimon/common/predicate_json_parser.py | 290 +++++++++++ .../pypaimon/read/query_auth_split.py | 55 +++ paimon-python/pypaimon/read/read_builder.py | 7 +- .../read/reader/auth_masking_reader.py | 158 ++++++ .../pypaimon/read/stream_read_builder.py | 3 +- paimon-python/pypaimon/read/table_read.py | 59 ++- paimon-python/pypaimon/read/table_scan.py | 18 +- .../pypaimon/table/file_store_table.py | 7 +- .../tests/auth_masking_reader_test.py | 301 ++++++++++++ .../tests/predicate_json_parser_test.py | 462 ++++++++++++++++++ .../pypaimon/tests/table_query_auth_test.py | 249 ++++++++++ 19 files changed, 1759 insertions(+), 11 deletions(-) create mode 100644 paimon-python/pypaimon/catalog/table_query_auth.py create mode 100644 paimon-python/pypaimon/common/predicate_json_parser.py create mode 100644 paimon-python/pypaimon/read/query_auth_split.py create mode 100644 paimon-python/pypaimon/read/reader/auth_masking_reader.py create mode 100644 paimon-python/pypaimon/tests/auth_masking_reader_test.py create mode 100644 paimon-python/pypaimon/tests/predicate_json_parser_test.py create mode 100644 paimon-python/pypaimon/tests/table_query_auth_test.py diff --git a/paimon-python/pypaimon/api/api_response.py b/paimon-python/pypaimon/api/api_response.py index 2df704b234e6..0d392b9eefd5 100644 --- a/paimon-python/pypaimon/api/api_response.py +++ b/paimon-python/pypaimon/api/api_response.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from typing import Dict, Generic, List, Optional +from pypaimon.api.api_request import RESTRequest from pypaimon.common.identifier import Identifier from pypaimon.common.json_util import T, json_field from pypaimon.common.options import Options @@ -600,3 +601,14 @@ def to_dict(self) -> Dict: result["functions"] = None result["nextPageToken"] = self.next_page_token return result + + +@dataclass +class AuthTableQueryRequest(RESTRequest): + select: Optional[List[str]] = json_field("select", default=None) + + +@dataclass +class AuthTableQueryResponse(RESTResponse): + filter: Optional[List[str]] = json_field("filter", default=None) + column_masking: Optional[Dict[str, str]] = json_field("columnMasking", default=None) diff --git a/paimon-python/pypaimon/api/resource_paths.py b/paimon-python/pypaimon/api/resource_paths.py index fad221c3bf87..9207d4266ecb 100644 --- a/paimon-python/pypaimon/api/resource_paths.py +++ b/paimon-python/pypaimon/api/resource_paths.py @@ -133,3 +133,9 @@ def rename_branch(self, database_name: str, table_name: str, branch_name: str) - def forward_branch(self, database_name: str, table_name: str, branch_name: str) -> str: return "{}/{}".format( self.branch(database_name, table_name, branch_name), self.FORWARD) + + def auth_table(self, database_name: str, table_name: str) -> str: + return "{}/{}/{}/{}/{}/auth".format( + self.base_path, self.DATABASES, RESTUtil.encode_string(database_name), + self.TABLES, RESTUtil.encode_string(table_name) + ) diff --git a/paimon-python/pypaimon/api/rest_api.py b/paimon-python/pypaimon/api/rest_api.py index b6ed08860c5f..78afe0d9145e 100755 --- a/paimon-python/pypaimon/api/rest_api.py +++ b/paimon-python/pypaimon/api/rest_api.py @@ -40,7 +40,9 @@ ListTablesResponse, ListTagsResponse, PagedList, PagedResponse, GetTableSnapshotResponse, - Partition) + Partition, + AuthTableQueryRequest, + AuthTableQueryResponse) from pypaimon.api.auth import AuthProviderFactory, RESTAuthFunction from pypaimon.api.client import HttpClient from pypaimon.api.resource_paths import ResourcePaths @@ -688,6 +690,16 @@ def alter_function(self, identifier: Identifier, changes: List) -> None: self.rest_auth_function, ) + def auth_table_query(self, identifier: Identifier, select: Optional[List[str]]) -> AuthTableQueryResponse: + database_name, table_name = self.__validate_identifier(identifier) + request = AuthTableQueryRequest(select=select) + return self.client.post_with_response_type( + self.resource_paths.auth_table(database_name, table_name), + request, + AuthTableQueryResponse, + self.rest_auth_function, + ) + @staticmethod def __validate_identifier(identifier: Identifier): if not identifier: diff --git a/paimon-python/pypaimon/catalog/catalog.py b/paimon-python/pypaimon/catalog/catalog.py index 4a364b06aab5..c52b65b4ae48 100644 --- a/paimon-python/pypaimon/catalog/catalog.py +++ b/paimon-python/pypaimon/catalog/catalog.py @@ -401,3 +401,6 @@ def list_tags_paged( raise NotImplementedError( "list_tags_paged is not supported by this catalog." ) + + def auth_table_query(self, identifier: Identifier, select: Optional[List[str]]) -> 'TableQueryAuthResult': + raise NotImplementedError("auth_table_query not supported by this catalog") diff --git a/paimon-python/pypaimon/catalog/catalog_environment.py b/paimon-python/pypaimon/catalog/catalog_environment.py index a754d90ca42c..f1b3a2eef6fc 100644 --- a/paimon-python/pypaimon/catalog/catalog_environment.py +++ b/paimon-python/pypaimon/catalog/catalog_environment.py @@ -117,3 +117,12 @@ def empty() -> 'CatalogEnvironment': catalog_loader=None, supports_version_management=False ) + + def table_query_auth(self, options, identifier): + if not options.query_auth_enabled or self.catalog_loader is None: + return lambda select: None + + def auth(select): + catalog = self.catalog_loader.load() + return catalog.auth_table_query(identifier, select) + return auth diff --git a/paimon-python/pypaimon/catalog/rest/rest_catalog.py b/paimon-python/pypaimon/catalog/rest/rest_catalog.py index d6e89d50b9b1..fe60cb865b49 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_catalog.py +++ b/paimon-python/pypaimon/catalog/rest/rest_catalog.py @@ -21,7 +21,8 @@ from pypaimon.api.rest_api import RESTApi from pypaimon.catalog.catalog_exception import IllegalArgumentError from pypaimon.api.rest_exception import (NoSuchResourceException, AlreadyExistsException, - ForbiddenException, BadRequestException) + ForbiddenException, BadRequestException, + ServiceFailureException, NotImplementedException) from pypaimon.catalog.catalog import Catalog from pypaimon.catalog.catalog_context import CatalogContext from pypaimon.catalog.catalog_environment import CatalogEnvironment @@ -756,3 +757,19 @@ def create(file_io: FileIO, ) -> FileStoreTable: """Create FileStoreTable with dynamic options and catalog environment""" return FileStoreTable(file_io, catalog_environment.identifier, table_path, table_schema, catalog_environment) + + def auth_table_query(self, identifier, select=None): + from pypaimon.catalog.table_query_auth import TableQueryAuthResult, TableNoPermissionException + try: + response = self.rest_api.auth_table_query(identifier, select) + return TableQueryAuthResult(response.filter, response.column_masking) + except NoSuchResourceException as e: + raise TableNotExistException(identifier) from e + except ForbiddenException as e: + raise TableNoPermissionException(identifier, e) from e + except ServiceFailureException as e: + raise RuntimeError(e.args[0] if e.args else str(e)) from e + except NotImplementedException as e: + raise NotImplementedError(e.args[0] if e.args else str(e)) from e + except BadRequestException as e: + raise RuntimeError(str(e)) from e diff --git a/paimon-python/pypaimon/catalog/table_query_auth.py b/paimon-python/pypaimon/catalog/table_query_auth.py new file mode 100644 index 000000000000..921366f5932d --- /dev/null +++ b/paimon-python/pypaimon/catalog/table_query_auth.py @@ -0,0 +1,87 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Callable, Dict, List, Optional + +import pyarrow as pa +import pyarrow.compute as pc + +from pypaimon.common.predicate_json_parser import ( + extract_referenced_fields, + parse_predicate_to_batch_filter, +) +from pypaimon.schema.data_types import DataField + + +class TableNoPermissionException(Exception): + MSG = "Table %s has no permission. Cause by %s." + + def __init__(self, identifier, cause=None): + cause_msg = str(cause) if cause else "" + super().__init__(self.MSG % (identifier, cause_msg)) + self.identifier = identifier + self.__cause__ = cause + + +class TableQueryAuthResult: + + def __init__(self, filter: Optional[List[str]], column_masking: Optional[Dict[str, str]]): + self.filter = filter + self.column_masking = column_masking + + def convert_plan(self, plan): + from pypaimon.read.query_auth_split import QueryAuthSplit + from pypaimon.read.plan import Plan + + if not self.filter and not self.column_masking: + return plan + auth_splits = [QueryAuthSplit(split, self) for split in plan.splits()] + return Plan(auth_splits) + + def extract_row_filter(self) -> Optional[Callable[[pa.RecordBatch], pa.Array]]: + if not self.filter: + return None + filters = [parse_predicate_to_batch_filter(json_str) for json_str in self.filter] + if len(filters) == 1: + return filters[0] + + def combined(batch: pa.RecordBatch) -> pa.Array: + result = filters[0](batch) + for f in filters[1:]: + result = pc.and_(result, f(batch)) + return result + return combined + + def get_extra_fields_for_filter( + self, + read_fields: List[DataField], + table_fields: List[DataField], + ) -> List[DataField]: + if not self.filter: + return [] + read_field_names = {f.name for f in read_fields} + extra = [] + for json_str in self.filter: + referenced = extract_referenced_fields(json_str) + for name in referenced: + if name not in read_field_names: + field = next((f for f in table_fields if f.name == name), None) + if field: + extra.append(field) + read_field_names.add(name) + return extra diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py index 874b888faf42..ced1ca651a8f 100644 --- a/paimon-python/pypaimon/common/options/core_options.py +++ b/paimon-python/pypaimon/common/options/core_options.py @@ -709,6 +709,13 @@ class CoreOptions: ) ) + QUERY_AUTH_ENABLED: ConfigOption[bool] = ( + ConfigOptions.key("query-auth.enabled") + .boolean_type() + .default_value(False) + .with_description("Whether to enable query auth.") + ) + PARTITION_DEFAULT_NAME: ConfigOption[str] = ( ConfigOptions.key("partition.default-name") .string_type() @@ -1080,3 +1087,7 @@ def add_column_before_partition(self) -> bool: def dynamic_partition_overwrite(self) -> bool: return self.options.get(CoreOptions.DYNAMIC_PARTITION_OVERWRITE) + + @property + def query_auth_enabled(self) -> bool: + return self.options.get(CoreOptions.QUERY_AUTH_ENABLED) diff --git a/paimon-python/pypaimon/common/predicate_json_parser.py b/paimon-python/pypaimon/common/predicate_json_parser.py new file mode 100644 index 000000000000..f306eb571f15 --- /dev/null +++ b/paimon-python/pypaimon/common/predicate_json_parser.py @@ -0,0 +1,290 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import json +import re +from typing import Callable + +import pyarrow as pa +import pyarrow.compute as pc + + +def parse_predicate_to_batch_filter(json_str: str) -> Callable[[pa.RecordBatch], pa.Array]: + data = json.loads(json_str) + return _build_filter(data) + + +def _build_filter(data: dict) -> Callable[[pa.RecordBatch], pa.Array]: + kind = data["kind"] + if kind == "LEAF": + return _build_leaf_filter(data) + elif kind == "COMPOUND": + return _build_compound_filter(data) + raise ValueError(f"Unknown predicate kind: {kind}") + + +def _build_leaf_filter(data: dict) -> Callable: + transform = data["transform"] + function = data["function"] + literals = data.get("literals", []) + + def filter_fn(batch: pa.RecordBatch) -> pa.Array: + value_array = _apply_predicate_transform(transform, batch) + return _apply_leaf_function(function, value_array, literals, len(batch)) + + return filter_fn + + +def _build_compound_filter(data: dict) -> Callable: + function = data["function"] + child_filters = [_build_filter(child) for child in data["children"]] + + def filter_fn(batch: pa.RecordBatch) -> pa.Array: + if function == "AND": + result = child_filters[0](batch) + for cf in child_filters[1:]: + result = pc.and_(result, cf(batch)) + return result + elif function == "OR": + result = child_filters[0](batch) + for cf in child_filters[1:]: + result = pc.or_(result, cf(batch)) + return result + raise ValueError(f"Unknown compound function: {function}") + + return filter_fn + + +def _apply_predicate_transform(transform: dict, batch: pa.RecordBatch) -> pa.Array: + name = transform["name"] + + if name == "FIELD_REF": + return batch.column(transform["fieldRef"]["name"]) + + elif name == "CAST": + col = batch.column(transform["fieldRef"]["name"]) + target_type = _paimon_type_to_arrow(transform["type"]) + return pc.cast(col, target_type) + + elif name == "UPPER": + input_col = _resolve_transform_input(transform["inputs"][0], batch) + return pc.utf8_upper(input_col) + + elif name == "LOWER": + input_col = _resolve_transform_input(transform["inputs"][0], batch) + return pc.utf8_lower(input_col) + + elif name == "CONCAT": + resolved = [_resolve_transform_input(inp, batch) for inp in transform["inputs"]] + if not resolved: + return pa.nulls(len(batch), type=pa.string()) + return pc.binary_join_element_wise(*resolved, "") + + elif name == "CONCAT_WS": + sep = _resolve_transform_input(transform["inputs"][0], batch) + values = [_resolve_transform_input(inp, batch) for inp in transform["inputs"][1:]] + if not values: + return pa.nulls(len(batch), type=pa.string()) + return pc.binary_join_element_wise(*values, sep, null_handling='skip') + + elif name == "NULL": + return pa.nulls(len(batch), type=pa.bool_()) + + raise ValueError(f"Unknown transform type in predicate: {name}") + + +def _resolve_transform_input(inp, batch: pa.RecordBatch) -> pa.Array: + if isinstance(inp, dict): + return batch.column(inp["name"]) + elif isinstance(inp, str): + return pa.array([inp] * len(batch), type=pa.string()) + elif inp is None: + return pa.nulls(len(batch), type=pa.string()) + return pa.array([str(inp)] * len(batch), type=pa.string()) + + +def _apply_leaf_function(function: str, value_array: pa.Array, literals: list, batch_len: int) -> pa.Array: + converted = [_convert_literal(lit, value_array.type) for lit in literals] + + if function == "EQUAL": + return pc.equal(value_array, converted[0]) + elif function == "NOT_EQUAL": + return pc.not_equal(value_array, converted[0]) + elif function == "LESS_THAN": + return pc.less(value_array, converted[0]) + elif function == "LESS_OR_EQUAL": + return pc.less_equal(value_array, converted[0]) + elif function == "GREATER_THAN": + return pc.greater(value_array, converted[0]) + elif function == "GREATER_OR_EQUAL": + return pc.greater_equal(value_array, converted[0]) + elif function == "IS_NULL": + return pc.is_null(value_array) + elif function == "IS_NOT_NULL": + return pc.is_valid(value_array) + elif function == "IN": + return pc.is_in(value_array, pa.array(converted, type=value_array.type)) + elif function == "NOT_IN": + return pc.invert(pc.is_in(value_array, pa.array(converted, type=value_array.type))) + elif function == "BETWEEN": + return pc.and_(pc.greater_equal(value_array, converted[0]), + pc.less_equal(value_array, converted[1])) + elif function == "NOT_BETWEEN": + return pc.or_(pc.less(value_array, converted[0]), + pc.greater(value_array, converted[1])) + elif function == "STARTS_WITH": + return pc.starts_with(value_array, converted[0]) + elif function == "ENDS_WITH": + return pc.ends_with(value_array, converted[0]) + elif function == "CONTAINS": + return pc.match_substring(value_array, converted[0]) + elif function == "LIKE": + raw = literals[0] + escaped = re.escape(raw) + pattern = escaped.replace("%", ".*").replace("_", ".") + return pc.match_substring_regex(value_array, f"^{pattern}$") + elif function == "TRUE": + return pa.array([True] * batch_len, type=pa.bool_()) + elif function == "FALSE": + return pa.array([False] * batch_len, type=pa.bool_()) + raise ValueError(f"Unknown leaf function: {function}") + + +def _convert_literal(literal, target_type: pa.DataType): + if literal is None: + return None + if pa.types.is_timestamp(target_type): + import datetime + if isinstance(literal, str): + dt = datetime.datetime.fromisoformat(literal.replace("Z", "+00:00")) + return pa.scalar(dt, type=target_type) + elif pa.types.is_date(target_type): + import datetime + if isinstance(literal, str): + return pa.scalar(datetime.date.fromisoformat(literal), type=target_type) + elif pa.types.is_time(target_type): + import datetime + if isinstance(literal, str): + t = datetime.time.fromisoformat(literal) + return pa.scalar(t, type=target_type) + elif pa.types.is_decimal(target_type): + import decimal + return pa.scalar(decimal.Decimal(str(literal)), type=target_type) + return literal + + +def _paimon_type_to_arrow(paimon_type: str) -> pa.DataType: + type_str = paimon_type.strip().upper() + + m = re.match(r"^([A-Z_ ]+?)(?:\((.+)\))?(?:\s+NOT\s+NULL)?$", type_str) + if not m: + raise ValueError(f"Cannot parse Paimon type: '{paimon_type}'") + base_type = m.group(1).strip() + params = m.group(2) + + simple_mapping = { + "INT": pa.int32(), + "BIGINT": pa.int64(), + "SMALLINT": pa.int16(), + "TINYINT": pa.int8(), + "FLOAT": pa.float32(), + "DOUBLE": pa.float64(), + "STRING": pa.string(), + "BOOLEAN": pa.bool_(), + "BYTES": pa.binary(), + "DATE": pa.date32(), + } + if base_type in simple_mapping: + return simple_mapping[base_type] + + if base_type in ("VARCHAR", "CHAR"): + return pa.string() + + if base_type in ("VARBINARY", "BINARY"): + return pa.binary() + + if base_type == "TIMESTAMP": + precision = int(params) if params else 6 + unit = _timestamp_precision_to_unit(precision) + return pa.timestamp(unit) + + if base_type in ("TIMESTAMP WITH LOCAL TIME ZONE", "TIMESTAMP_WITH_LOCAL_TIME_ZONE", "TIMESTAMP_LTZ"): + precision = int(params) if params else 6 + unit = _timestamp_precision_to_unit(precision) + return pa.timestamp(unit, tz="UTC") + + if base_type == "TIME": + precision = int(params) if params else 6 + unit = _timestamp_precision_to_unit(precision) + return pa.time64(unit) if unit in ("us", "ns") else pa.time32(unit) + + if base_type == "DECIMAL": + if params: + parts = [x.strip() for x in params.split(",")] + if len(parts) == 2: + return pa.decimal128(int(parts[0]), int(parts[1])) + raise ValueError(f"DECIMAL type requires (precision, scale): '{paimon_type}'") + + raise ValueError( + f"Unsupported Paimon type for PyArrow conversion: '{paimon_type}'. " + f"Supported: INT, BIGINT, SMALLINT, TINYINT, FLOAT, DOUBLE, STRING, VARCHAR, CHAR, " + f"BOOLEAN, BYTES, VARBINARY, DATE, TIME(p), TIMESTAMP(p), " + f"TIMESTAMP WITH LOCAL TIME ZONE(p), DECIMAL(p,s)." + ) + + +def _timestamp_precision_to_unit(precision: int) -> str: + if precision == 0: + return "s" + elif precision <= 3: + return "ms" + elif precision <= 6: + return "us" + else: + return "ns" + + +def extract_referenced_fields(json_str: str) -> set: + data = json.loads(json_str) + fields = set() + _collect_fields(data, fields) + return fields + + +def _collect_fields(data: dict, fields: set): + kind = data.get("kind") + if kind == "LEAF": + _collect_all_field_refs_from_transform(data["transform"], fields) + elif kind == "COMPOUND": + for child in data["children"]: + _collect_fields(child, fields) + + +def _collect_all_field_refs_from_transform(transform: dict, fields: set = None) -> set: + if fields is None: + fields = set() + name = transform.get("name") + if name == "FIELD_REF" and "fieldRef" in transform: + fields.add(transform["fieldRef"]["name"]) + elif name == "CAST" and "fieldRef" in transform: + fields.add(transform["fieldRef"]["name"]) + elif name in ("UPPER", "LOWER", "CONCAT", "CONCAT_WS"): + for inp in transform.get("inputs", []): + if isinstance(inp, dict) and "name" in inp: + fields.add(inp["name"]) + return fields diff --git a/paimon-python/pypaimon/read/query_auth_split.py b/paimon-python/pypaimon/read/query_auth_split.py new file mode 100644 index 000000000000..667ddb063a1a --- /dev/null +++ b/paimon-python/pypaimon/read/query_auth_split.py @@ -0,0 +1,55 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from pypaimon.read.split import Split + + +class QueryAuthSplit(Split): + + def __init__(self, split: Split, auth_result): + self._split = split + self._auth_result = auth_result + + @property + def split(self) -> Split: + return self._split + + @property + def auth_result(self): + return self._auth_result + + @property + def row_count(self) -> int: + return self._split.row_count + + @property + def files(self): + return self._split.files + + @property + def partition(self): + return self._split.partition + + @property + def bucket(self) -> int: + return self._split.bucket + + def merged_row_count(self): + if self._auth_result.filter: + return None + return self._split.merged_row_count() diff --git a/paimon-python/pypaimon/read/read_builder.py b/paimon-python/pypaimon/read/read_builder.py index 5537c97dd647..447406ad4709 100644 --- a/paimon-python/pypaimon/read/read_builder.py +++ b/paimon-python/pypaimon/read/read_builder.py @@ -33,7 +33,7 @@ class ReadBuilder: """Implementation of ReadBuilder for native Python reading.""" - def __init__(self, table): + def __init__(self, table, query_auth=None): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table @@ -45,6 +45,7 @@ def __init__(self, table): self._projection: Optional[List[str]] = None self._nested_paths: Optional[List[List[int]]] = None self._limit: Optional[int] = None + self._query_auth = query_auth def with_filter(self, predicate: Predicate) -> 'ReadBuilder': self._predicate = predicate @@ -77,7 +78,9 @@ def new_scan(self) -> TableScan: return TableScan( table=self.table, predicate=self._predicate, - limit=self._limit + limit=self._limit, + query_auth=self._query_auth, + read_type=self.read_type() ) def new_read(self) -> TableRead: diff --git a/paimon-python/pypaimon/read/reader/auth_masking_reader.py b/paimon-python/pypaimon/read/reader/auth_masking_reader.py new file mode 100644 index 000000000000..e1a842f969eb --- /dev/null +++ b/paimon-python/pypaimon/read/reader/auth_masking_reader.py @@ -0,0 +1,158 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import json +from typing import Callable, Dict, List, Optional + +import pyarrow as pa +import pyarrow.compute as pc + +from pypaimon.common.predicate_json_parser import ( + _collect_all_field_refs_from_transform, + _paimon_type_to_arrow, +) +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader + + +class AuthFilterReader(RecordBatchReader): + + def __init__(self, inner_reader: RecordBatchReader, filter_fn: Callable[[pa.RecordBatch], pa.Array]): + self._inner = inner_reader + self._filter_fn = filter_fn + + def read_arrow_batch(self) -> Optional[pa.RecordBatch]: + batch = self._inner.read_arrow_batch() + if batch is None: + return None + mask = self._filter_fn(batch) + return batch.filter(mask) + + def close(self): + self._inner.close() + + +class AuthMaskingReader(RecordBatchReader): + + def __init__(self, inner_reader: RecordBatchReader, masking_rules: Dict[str, str], read_fields: List): + self._inner = inner_reader + self._masking_rules = masking_rules + self._read_fields = read_fields + self._parsed_rules = {col: json.loads(tj) for col, tj in masking_rules.items()} + read_field_names = {f.name for f in read_fields} + for col_name, transform in self._parsed_rules.items(): + for ref_name in _collect_all_field_refs_from_transform(transform): + if ref_name not in read_field_names: + raise RuntimeError( + f"Column masking refers to field '{ref_name}' which is not present " + f"in output row type. Available fields: {read_field_names}" + ) + + def read_arrow_batch(self) -> Optional[pa.RecordBatch]: + batch = self._inner.read_arrow_batch() + if batch is None: + return None + original_batch = batch + masked_columns = {} + for col_name, transform in self._parsed_rules.items(): + if col_name in original_batch.schema.names: + col_idx = original_batch.schema.get_field_index(col_name) + target_col_type = original_batch.schema.field(col_idx).type + masked_columns[col_idx] = self._apply_transform(transform, original_batch, target_col_type) + for col_idx, masked_array in masked_columns.items(): + original_field = batch.schema.field(col_idx) + new_field = pa.field(original_field.name, masked_array.type, nullable=True) + batch = batch.set_column(col_idx, new_field, masked_array) + return batch + + def close(self): + self._inner.close() + + def _apply_transform( + self, + transform: dict, + original_batch: pa.RecordBatch, + target_col_type: pa.DataType, + ) -> pa.Array: + name = transform["name"] + + if name == "NULL": + return pa.nulls(len(original_batch), type=target_col_type) + + elif name == "FIELD_REF": + ref_name = transform["fieldRef"]["name"] + return original_batch.column(ref_name) + + elif name == "CAST": + ref_name = transform["fieldRef"]["name"] + source_col = original_batch.column(ref_name) + target_type = _paimon_type_to_arrow(transform["type"]) + return pc.cast(source_col, target_type) + + elif name == "UPPER": + input_col = self._resolve_input(transform["inputs"][0], original_batch) + return pc.utf8_upper(input_col) + + elif name == "LOWER": + input_col = self._resolve_input(transform["inputs"][0], original_batch) + return pc.utf8_lower(input_col) + + elif name == "CONCAT": + return self._apply_concat(transform["inputs"], original_batch) + + elif name == "CONCAT_WS": + return self._apply_concat_ws(transform["inputs"], original_batch) + + raise ValueError(f"Unknown transform type: {name}") + + def _resolve_input(self, inp, original_batch: pa.RecordBatch) -> pa.Array: + if isinstance(inp, dict): + return original_batch.column(inp["name"]) + elif isinstance(inp, str): + return pa.array([inp] * len(original_batch), type=pa.string()) + elif inp is None: + return pa.nulls(len(original_batch), type=pa.string()) + return pa.array([str(inp)] * len(original_batch), type=pa.string()) + + def _apply_concat(self, inputs: list, original_batch: pa.RecordBatch) -> pa.Array: + resolved = [self._resolve_input(inp, original_batch) for inp in inputs] + if not resolved: + return pa.nulls(len(original_batch), type=pa.string()) + return pc.binary_join_element_wise(*resolved, "") + + def _apply_concat_ws(self, inputs: list, original_batch: pa.RecordBatch) -> pa.Array: + if len(inputs) < 2: + return pa.nulls(len(original_batch), type=pa.string()) + sep = self._resolve_input(inputs[0], original_batch) + values = [self._resolve_input(inp, original_batch) for inp in inputs[1:]] + return pc.binary_join_element_wise(*values, sep, null_handling='skip') + + +class ColumnProjectReader(RecordBatchReader): + + def __init__(self, inner_reader: RecordBatchReader, columns: List[str]): + self._inner = inner_reader + self._columns = columns + + def read_arrow_batch(self) -> Optional[pa.RecordBatch]: + batch = self._inner.read_arrow_batch() + if batch is None: + return None + return batch.select(self._columns) + + def close(self): + self._inner.close() diff --git a/paimon-python/pypaimon/read/stream_read_builder.py b/paimon-python/pypaimon/read/stream_read_builder.py index bc4d6b88d522..5a919e11a92c 100644 --- a/paimon-python/pypaimon/read/stream_read_builder.py +++ b/paimon-python/pypaimon/read/stream_read_builder.py @@ -48,7 +48,7 @@ class StreamReadBuilder: process(arrow_table) """ - def __init__(self, table): + def __init__(self, table, query_auth=None): """Initialize the StreamReadBuilder.""" from pypaimon.table.file_store_table import FileStoreTable @@ -59,6 +59,7 @@ def __init__(self, table): self._include_row_kind: bool = False self._bucket_filter: Optional[Callable[[int], bool]] = None self._consumer_id: Optional[str] = None + self._query_auth = query_auth def with_filter(self, predicate: Predicate) -> 'StreamReadBuilder': """Set a filter predicate for the streaming read.""" diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index 826b2b4024a1..10fa8b83c75c 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -107,7 +107,7 @@ def _record_generator(): for split in splits: if limit is not None and count >= limit: return - reader = self._create_split_read(split).create_reader() + reader = self._create_reader_for_split(split) try: for batch in iter(reader.read_batch, None): for row in iter(batch.next, None): @@ -198,7 +198,7 @@ def _arrow_batch_generator(self, splits: List[Split], schema: pyarrow.Schema) -> for split in splits: if remaining is not None and remaining <= 0: break - reader = self._create_split_read(split).create_reader() + reader = self._create_reader_for_split(split) try: if isinstance(reader, RecordBatchReader): for batch in iter(reader.read_arrow_batch, None): @@ -611,6 +611,61 @@ def _widen_to_top_level_for_merge(self) -> List[DataField]: widened.append(field) return widened + def _create_reader_for_split(self, split): + from pypaimon.read.query_auth_split import QueryAuthSplit + + auth_result = None + if isinstance(split, QueryAuthSplit): + auth_result = split.auth_result + split = split.split + + if auth_result is not None: + return self._authed_reader(split, auth_result) + else: + return self._create_split_read(split).create_reader() + + def _authed_reader(self, split, auth_result): + from pypaimon.read.reader.auth_masking_reader import ( + AuthFilterReader, AuthMaskingReader, ColumnProjectReader) + + table_fields = self.table.fields + read_fields = self.read_type + + extra_fields = auth_result.get_extra_fields_for_filter(read_fields, table_fields) + effective_read_type = read_fields + if extra_fields: + effective_read_type = read_fields + extra_fields + + reader = self._create_split_read_with_read_type(split, effective_read_type).create_reader() + + filter_fn = auth_result.extract_row_filter() + if filter_fn: + reader = AuthFilterReader(reader, filter_fn) + + if auth_result.column_masking: + reader = AuthMaskingReader(reader, auth_result.column_masking, effective_read_type) + + if extra_fields: + original_columns = [f.name for f in read_fields] + reader = ColumnProjectReader(reader, original_columns) + + return reader + + def _create_split_read_with_read_type(self, split, read_type): + if self.table.is_primary_key_table and not split.raw_convertible: + return MergeFileSplitRead( + table=self.table, predicate=self.predicate, + read_type=read_type, split=split, row_tracking_enabled=False) + elif self.table.options.data_evolution_enabled(): + return DataEvolutionSplitRead( + table=self.table, predicate=self.predicate, + read_type=read_type, split=split, row_tracking_enabled=True) + else: + return RawFileSplitRead( + table=self.table, predicate=self.predicate, + read_type=read_type, split=split, + row_tracking_enabled=self.table.options.row_tracking_enabled()) + @staticmethod def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: pyarrow.Schema) -> pyarrow.RecordBatch: columns_data = zip(*row_tuples) diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py index 03a1c8b06297..f877c1a117bd 100755 --- a/paimon-python/pypaimon/read/table_scan.py +++ b/paimon-python/pypaimon/read/table_scan.py @@ -33,17 +33,31 @@ def __init__( self, table, predicate: Optional[Predicate], - limit: Optional[int] + limit: Optional[int], + query_auth=None, + read_type=None ): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table self.predicate = predicate self.limit = limit + self._query_auth = query_auth + self._read_type = read_type self.file_scanner = self._create_file_scanner() def plan(self) -> Plan: - return self.file_scanner.scan() + auth_result = self._auth_query() + plan = self.file_scanner.scan() + if auth_result is not None: + plan = auth_result.convert_plan(plan) + return plan + + def _auth_query(self): + if self._query_auth is None: + return None + select = [f.name for f in self._read_type] if self._read_type else None + return self._query_auth(select) def scan_with_stats(self) -> Tuple[Plan, ScanStats]: """Run :meth:`plan` while recording manifest / pruning counters. diff --git a/paimon-python/pypaimon/table/file_store_table.py b/paimon-python/pypaimon/table/file_store_table.py index 9fd61abed49b..6243bca753b8 100644 --- a/paimon-python/pypaimon/table/file_store_table.py +++ b/paimon-python/pypaimon/table/file_store_table.py @@ -63,6 +63,9 @@ def __init__(self, file_io: FileIO, identifier: Identifier, table_path: str, current_branch = self.options.branch() self.schema_manager = SchemaManager(file_io, table_path, branch=current_branch) + self._query_auth_fn = self.catalog_environment.table_query_auth( + self.options, self.identifier) + @classmethod def from_path(cls, table_path: str) -> 'FileStoreTable': """ @@ -408,10 +411,10 @@ def bucket_mode(self) -> BucketMode: return BucketMode.HASH_FIXED def new_read_builder(self) -> 'ReadBuilder': - return ReadBuilder(self) + return ReadBuilder(self, query_auth=self._query_auth_fn) def new_stream_read_builder(self) -> 'StreamReadBuilder': - return StreamReadBuilder(self) + return StreamReadBuilder(self, query_auth=self._query_auth_fn) def new_batch_write_builder(self) -> BatchWriteBuilder: return BatchWriteBuilder(self) diff --git a/paimon-python/pypaimon/tests/auth_masking_reader_test.py b/paimon-python/pypaimon/tests/auth_masking_reader_test.py new file mode 100644 index 000000000000..49e62c915def --- /dev/null +++ b/paimon-python/pypaimon/tests/auth_masking_reader_test.py @@ -0,0 +1,301 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import json +import unittest + +import pyarrow as pa + +from pypaimon.read.reader.auth_masking_reader import ( + AuthFilterReader, + AuthMaskingReader, + ColumnProjectReader, +) + + +class _FakeField: + def __init__(self, name): + self.name = name + + +class _FakeBatchReader: + def __init__(self, batches): + self._batches = iter(batches) + + def read_arrow_batch(self): + return next(self._batches, None) + + def close(self): + pass + + +class TestAuthMaskingReaderTransforms(unittest.TestCase): + + def setUp(self): + self.batch = pa.RecordBatch.from_pydict({ + "name": ["alice", "bob", "charlie"], + "email": ["A@x.com", "B@y.com", "C@z.com"], + "age": [25, 30, 35], + "dept": ["eng", "sales", "eng"], + }) + self.fields = [ + _FakeField("name"), + _FakeField("email"), + _FakeField("age"), + _FakeField("dept"), + ] + + def _apply_masking(self, masking_rules, batch=None, fields=None): + batch = batch or self.batch + fields = fields or self.fields + reader = AuthMaskingReader( + _FakeBatchReader([batch]), masking_rules, fields + ) + return reader.read_arrow_batch() + + def test_null_transform(self): + result = self._apply_masking( + {"name": json.dumps({"name": "NULL"})} + ) + self.assertEqual(result.column("name").to_pylist(), [None, None, None]) + self.assertEqual(result.schema.field("name").type, pa.string()) + + def test_upper_transform(self): + result = self._apply_masking({ + "email": json.dumps({ + "name": "UPPER", + "inputs": [{"index": 1, "name": "email", "type": "STRING"}], + }) + }) + self.assertEqual( + result.column("email").to_pylist(), + ["A@X.COM", "B@Y.COM", "C@Z.COM"], + ) + + def test_lower_transform(self): + result = self._apply_masking({ + "email": json.dumps({ + "name": "LOWER", + "inputs": [{"index": 1, "name": "email", "type": "STRING"}], + }) + }) + self.assertEqual( + result.column("email").to_pylist(), + ["a@x.com", "b@y.com", "c@z.com"], + ) + + def test_field_ref_transform(self): + result = self._apply_masking({ + "name": json.dumps({ + "name": "FIELD_REF", + "fieldRef": {"index": 3, "name": "dept", "type": "STRING"}, + }) + }) + self.assertEqual( + result.column("name").to_pylist(), ["eng", "sales", "eng"] + ) + + def test_cast_transform(self): + result = self._apply_masking({ + "age": json.dumps({ + "name": "CAST", + "fieldRef": {"index": 2, "name": "age", "type": "INT"}, + "type": "BIGINT", + }) + }) + self.assertEqual(result.column("age").type, pa.int64()) + self.assertEqual(result.column("age").to_pylist(), [25, 30, 35]) + + def test_concat_transform(self): + result = self._apply_masking({ + "name": json.dumps({ + "name": "CONCAT", + "inputs": [ + "***", + {"index": 1, "name": "email", "type": "STRING"}, + ], + }) + }) + self.assertEqual( + result.column("name").to_pylist(), + ["***A@x.com", "***B@y.com", "***C@z.com"], + ) + + def test_concat_null_emits_null(self): + batch = pa.RecordBatch.from_pydict({ + "name": ["alice", None, "charlie"], + "tag": ["x", "y", "z"], + }) + fields = [_FakeField("name"), _FakeField("tag")] + result = self._apply_masking( + { + "tag": json.dumps({ + "name": "CONCAT", + "inputs": [ + {"index": 0, "name": "name", "type": "STRING"}, + "@masked", + ], + }) + }, + batch=batch, + fields=fields, + ) + self.assertEqual( + result.column("tag").to_pylist(), + ["alice@masked", None, "charlie@masked"], + ) + + def test_concat_ws_transform(self): + batch = pa.RecordBatch.from_pydict({ + "name": ["alice", None, "charlie"], + "dept": ["eng", "sales", "eng"], + }) + fields = [_FakeField("name"), _FakeField("dept")] + result = self._apply_masking( + { + "name": json.dumps({ + "name": "CONCAT_WS", + "inputs": [ + "-", + {"index": 0, "name": "name", "type": "STRING"}, + {"index": 1, "name": "dept", "type": "STRING"}, + ], + }) + }, + batch=batch, + fields=fields, + ) + self.assertEqual( + result.column("name").to_pylist(), + ["alice-eng", "sales", "charlie-eng"], + ) + + def test_concat_ws_field_ref_separator(self): + batch = pa.RecordBatch.from_pydict({ + "sep": ["-", "|", ":"], + "a": ["x", "y", "z"], + "b": ["1", "2", "3"], + }) + fields = [_FakeField("sep"), _FakeField("a"), _FakeField("b")] + result = self._apply_masking( + { + "a": json.dumps({ + "name": "CONCAT_WS", + "inputs": [ + {"index": 0, "name": "sep", "type": "STRING"}, + {"index": 1, "name": "a", "type": "STRING"}, + {"index": 2, "name": "b", "type": "STRING"}, + ], + }) + }, + batch=batch, + fields=fields, + ) + self.assertEqual( + result.column("a").to_pylist(), ["x-1", "y|2", "z:3"] + ) + + +class TestMaskingOrderIndependence(unittest.TestCase): + + def test_cross_reference_uses_original_batch(self): + batch = pa.RecordBatch.from_pydict({"a": ["x", "y"], "b": ["p", "q"]}) + fields = [_FakeField("a"), _FakeField("b")] + masking = { + "a": json.dumps({ + "name": "FIELD_REF", + "fieldRef": {"index": 1, "name": "b", "type": "STRING"}, + }), + "b": json.dumps({ + "name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "a", "type": "STRING"}, + }), + } + reader = AuthMaskingReader(_FakeBatchReader([batch]), masking, fields) + result = reader.read_arrow_batch() + self.assertEqual(result.column("a").to_pylist(), ["p", "q"]) + self.assertEqual(result.column("b").to_pylist(), ["x", "y"]) + + +class TestMaskingFieldValidation(unittest.TestCase): + + def test_missing_field_raises(self): + batch = pa.RecordBatch.from_pydict({"name": ["alice"]}) + fields = [_FakeField("name")] + with self.assertRaises(RuntimeError) as ctx: + AuthMaskingReader( + _FakeBatchReader([batch]), + { + "name": json.dumps({ + "name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "nonexistent", "type": "STRING"}, + }) + }, + fields, + ) + self.assertIn("nonexistent", str(ctx.exception)) + + +class TestAuthFilterReader(unittest.TestCase): + + def test_filters_rows(self): + import pyarrow.compute as pc + + batch = pa.RecordBatch.from_pydict({ + "dept": ["eng", "sales", "eng", "hr"], + }) + + def filter_fn(b): + return pc.equal(b.column("dept"), "eng") + + reader = AuthFilterReader(_FakeBatchReader([batch]), filter_fn) + result = reader.read_arrow_batch() + self.assertEqual(result.num_rows, 2) + self.assertEqual(result.column("dept").to_pylist(), ["eng", "eng"]) + + def test_returns_none_at_end(self): + import pyarrow.compute as pc + + reader = AuthFilterReader( + _FakeBatchReader([]), + lambda b: pc.equal(b.column("x"), 1), + ) + self.assertIsNone(reader.read_arrow_batch()) + + +class TestColumnProjectReader(unittest.TestCase): + + def test_selects_columns(self): + batch = pa.RecordBatch.from_pydict({ + "a": [1, 2], + "b": ["x", "y"], + "c": [3.0, 4.0], + }) + reader = ColumnProjectReader(_FakeBatchReader([batch]), ["a", "c"]) + result = reader.read_arrow_batch() + self.assertEqual(result.schema.names, ["a", "c"]) + self.assertEqual(result.column("a").to_pylist(), [1, 2]) + self.assertEqual(result.column("c").to_pylist(), [3.0, 4.0]) + + def test_returns_none_at_end(self): + reader = ColumnProjectReader(_FakeBatchReader([]), ["a"]) + self.assertIsNone(reader.read_arrow_batch()) + + +if __name__ == "__main__": + unittest.main() diff --git a/paimon-python/pypaimon/tests/predicate_json_parser_test.py b/paimon-python/pypaimon/tests/predicate_json_parser_test.py new file mode 100644 index 000000000000..2e927aa2f38b --- /dev/null +++ b/paimon-python/pypaimon/tests/predicate_json_parser_test.py @@ -0,0 +1,462 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import json +import unittest + +import pyarrow as pa + +from pypaimon.common.predicate_json_parser import ( + _convert_literal, + _paimon_type_to_arrow, + extract_referenced_fields, + parse_predicate_to_batch_filter, +) + + +def _make_leaf(field_name, function, literals=None, field_type="INT"): + d = { + "kind": "LEAF", + "transform": { + "name": "FIELD_REF", + "fieldRef": {"index": 0, "name": field_name, "type": field_type}, + }, + "function": function, + } + if literals is not None: + d["literals"] = literals + return json.dumps(d) + + +class TestLeafFunctions(unittest.TestCase): + + def setUp(self): + self.batch = pa.RecordBatch.from_pydict({ + "id": [1, 2, 3, 4, 5], + "name": ["alice", "bob", "charlie", "alice", "eve"], + "score": [85.5, 92.0, 78.3, 95.1, 60.0], + }) + + def _filter(self, json_str): + return parse_predicate_to_batch_filter(json_str)(self.batch).to_pylist() + + def test_equal(self): + self.assertEqual( + self._filter(_make_leaf("id", "EQUAL", [3])), + [False, False, True, False, False], + ) + + def test_not_equal(self): + self.assertEqual( + self._filter(_make_leaf("id", "NOT_EQUAL", [3])), + [True, True, False, True, True], + ) + + def test_less_than(self): + self.assertEqual( + self._filter(_make_leaf("id", "LESS_THAN", [3])), + [True, True, False, False, False], + ) + + def test_less_or_equal(self): + self.assertEqual( + self._filter(_make_leaf("id", "LESS_OR_EQUAL", [3])), + [True, True, True, False, False], + ) + + def test_greater_than(self): + self.assertEqual( + self._filter(_make_leaf("id", "GREATER_THAN", [3])), + [False, False, False, True, True], + ) + + def test_greater_or_equal(self): + self.assertEqual( + self._filter(_make_leaf("id", "GREATER_OR_EQUAL", [3])), + [False, False, True, True, True], + ) + + def test_is_null(self): + batch = pa.RecordBatch.from_pydict({"val": [1, None, 3, None, 5]}) + f = parse_predicate_to_batch_filter(_make_leaf("val", "IS_NULL")) + self.assertEqual(f(batch).to_pylist(), [False, True, False, True, False]) + + def test_is_not_null(self): + batch = pa.RecordBatch.from_pydict({"val": [1, None, 3, None, 5]}) + f = parse_predicate_to_batch_filter(_make_leaf("val", "IS_NOT_NULL")) + self.assertEqual(f(batch).to_pylist(), [True, False, True, False, True]) + + def test_in(self): + self.assertEqual( + self._filter(_make_leaf("id", "IN", [1, 3, 5])), + [True, False, True, False, True], + ) + + def test_not_in(self): + self.assertEqual( + self._filter(_make_leaf("id", "NOT_IN", [1, 3, 5])), + [False, True, False, True, False], + ) + + def test_between(self): + self.assertEqual( + self._filter(_make_leaf("id", "BETWEEN", [2, 4])), + [False, True, True, True, False], + ) + + def test_not_between(self): + self.assertEqual( + self._filter(_make_leaf("id", "NOT_BETWEEN", [2, 4])), + [True, False, False, False, True], + ) + + def test_starts_with(self): + self.assertEqual( + self._filter(_make_leaf("name", "STARTS_WITH", ["al"], "STRING")), + [True, False, False, True, False], + ) + + def test_ends_with(self): + self.assertEqual( + self._filter(_make_leaf("name", "ENDS_WITH", ["e"], "STRING")), + [True, False, True, True, True], + ) + + def test_contains(self): + self.assertEqual( + self._filter(_make_leaf("name", "CONTAINS", ["li"], "STRING")), + [True, False, True, True, False], + ) + + def test_like(self): + like_json = json.dumps({ + "kind": "LEAF", + "transform": {"name": "FIELD_REF", "fieldRef": {"index": 0, "name": "name", "type": "STRING"}}, + "function": "LIKE", + "literals": ["%li%"], + }) + self.assertEqual( + parse_predicate_to_batch_filter(like_json)(self.batch).to_pylist(), + [True, False, True, True, False], + ) + + def test_true(self): + self.assertEqual( + self._filter(_make_leaf("id", "TRUE")), + [True, True, True, True, True], + ) + + def test_false(self): + self.assertEqual( + self._filter(_make_leaf("id", "FALSE")), + [False, False, False, False, False], + ) + + +class TestCompoundPredicates(unittest.TestCase): + + def setUp(self): + self.batch = pa.RecordBatch.from_pydict({ + "id": [1, 2, 3, 4, 5], + }) + + def test_and(self): + pred = json.dumps({ + "kind": "COMPOUND", + "function": "AND", + "children": [ + {"kind": "LEAF", + "transform": {"name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "id", "type": "INT"}}, + "function": "GREATER_THAN", "literals": [2]}, + {"kind": "LEAF", + "transform": {"name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "id", "type": "INT"}}, + "function": "LESS_THAN", "literals": [5]}, + ], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(self.batch).to_pylist(), [False, False, True, True, False]) + + def test_or(self): + pred = json.dumps({ + "kind": "COMPOUND", + "function": "OR", + "children": [ + {"kind": "LEAF", + "transform": {"name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "id", "type": "INT"}}, + "function": "EQUAL", "literals": [1]}, + {"kind": "LEAF", + "transform": {"name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "id", "type": "INT"}}, + "function": "EQUAL", "literals": [5]}, + ], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(self.batch).to_pylist(), [True, False, False, False, True]) + + +class TestPredicateTransforms(unittest.TestCase): + + def setUp(self): + self.batch = pa.RecordBatch.from_pydict({ + "id": [1, 2, 3, 4, 5], + "name": ["alice", "bob", "charlie", "alice", "eve"], + }) + + def test_upper_transform(self): + pred = json.dumps({ + "kind": "LEAF", + "transform": {"name": "UPPER", "inputs": [{"index": 1, "name": "name", "type": "STRING"}]}, + "function": "EQUAL", + "literals": ["ALICE"], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(self.batch).to_pylist(), [True, False, False, True, False]) + + def test_lower_transform(self): + pred = json.dumps({ + "kind": "LEAF", + "transform": {"name": "LOWER", "inputs": [{"index": 1, "name": "name", "type": "STRING"}]}, + "function": "EQUAL", + "literals": ["bob"], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(self.batch).to_pylist(), [False, True, False, False, False]) + + def test_cast_transform(self): + pred = json.dumps({ + "kind": "LEAF", + "transform": {"name": "CAST", "fieldRef": {"index": 0, "name": "id", "type": "INT"}, "type": "BIGINT"}, + "function": "GREATER_THAN", + "literals": [3], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(self.batch).to_pylist(), [False, False, False, True, True]) + + def test_null_transform(self): + pred = json.dumps({ + "kind": "LEAF", + "transform": {"name": "NULL"}, + "function": "IS_NULL", + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(self.batch).to_pylist(), [True, True, True, True, True]) + + def test_concat_transform_in_predicate(self): + batch = pa.RecordBatch.from_pydict({ + "first": ["john", "jane"], + "last": ["doe", "smith"], + }) + pred = json.dumps({ + "kind": "LEAF", + "transform": { + "name": "CONCAT", + "inputs": [ + {"index": 0, "name": "first", "type": "STRING"}, + {"index": 1, "name": "last", "type": "STRING"}, + ], + }, + "function": "EQUAL", + "literals": ["johndoe"], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(batch).to_pylist(), [True, False]) + + def test_concat_ws_transform_in_predicate(self): + batch = pa.RecordBatch.from_pydict({ + "first": ["john", "jane"], + "last": ["doe", "smith"], + }) + pred = json.dumps({ + "kind": "LEAF", + "transform": { + "name": "CONCAT_WS", + "inputs": [ + " ", + {"index": 0, "name": "first", "type": "STRING"}, + {"index": 1, "name": "last", "type": "STRING"}, + ], + }, + "function": "EQUAL", + "literals": ["john doe"], + }) + f = parse_predicate_to_batch_filter(pred) + self.assertEqual(f(batch).to_pylist(), [True, False]) + + +class TestPaimonTypeToArrow(unittest.TestCase): + + def test_simple_types(self): + self.assertEqual(_paimon_type_to_arrow("INT"), pa.int32()) + self.assertEqual(_paimon_type_to_arrow("BIGINT"), pa.int64()) + self.assertEqual(_paimon_type_to_arrow("SMALLINT"), pa.int16()) + self.assertEqual(_paimon_type_to_arrow("TINYINT"), pa.int8()) + self.assertEqual(_paimon_type_to_arrow("FLOAT"), pa.float32()) + self.assertEqual(_paimon_type_to_arrow("DOUBLE"), pa.float64()) + self.assertEqual(_paimon_type_to_arrow("STRING"), pa.string()) + self.assertEqual(_paimon_type_to_arrow("BOOLEAN"), pa.bool_()) + self.assertEqual(_paimon_type_to_arrow("BYTES"), pa.binary()) + self.assertEqual(_paimon_type_to_arrow("DATE"), pa.date32()) + + def test_varchar_char(self): + self.assertEqual(_paimon_type_to_arrow("VARCHAR"), pa.string()) + self.assertEqual(_paimon_type_to_arrow("VARCHAR(100)"), pa.string()) + self.assertEqual(_paimon_type_to_arrow("CHAR(10)"), pa.string()) + + def test_timestamp(self): + self.assertEqual(_paimon_type_to_arrow("TIMESTAMP(3)"), pa.timestamp("ms")) + self.assertEqual(_paimon_type_to_arrow("TIMESTAMP(6)"), pa.timestamp("us")) + self.assertEqual(_paimon_type_to_arrow("TIMESTAMP(9)"), pa.timestamp("ns")) + self.assertEqual(_paimon_type_to_arrow("TIMESTAMP(0)"), pa.timestamp("s")) + + def test_timestamp_ltz(self): + self.assertEqual( + _paimon_type_to_arrow("TIMESTAMP WITH LOCAL TIME ZONE(6)"), + pa.timestamp("us", tz="UTC"), + ) + + def test_decimal(self): + self.assertEqual( + _paimon_type_to_arrow("DECIMAL(10, 2)"), + pa.decimal128(10, 2), + ) + + def test_unsupported_type_raises(self): + with self.assertRaises(ValueError): + _paimon_type_to_arrow("ARRAY") + + +class TestConvertLiteral(unittest.TestCase): + + def test_none_literal(self): + self.assertIsNone(_convert_literal(None, pa.int32())) + + def test_int_passthrough(self): + self.assertEqual(_convert_literal(42, pa.int32()), 42) + + def test_string_passthrough(self): + self.assertEqual(_convert_literal("hello", pa.string()), "hello") + + def test_timestamp_literal(self): + result = _convert_literal("2024-01-15T10:30:00", pa.timestamp("us")) + self.assertIsInstance(result, pa.Scalar) + + def test_timestamp_z_suffix(self): + result = _convert_literal("2024-01-15T10:30:00Z", pa.timestamp("us", tz="UTC")) + self.assertIsInstance(result, pa.Scalar) + + def test_date_literal(self): + result = _convert_literal("2024-01-15", pa.date32()) + self.assertIsInstance(result, pa.Scalar) + + def test_decimal_literal(self): + result = _convert_literal(123.45, pa.decimal128(10, 2)) + self.assertIsInstance(result, pa.Scalar) + + +class TestExtractReferencedFields(unittest.TestCase): + + def test_leaf_field_ref(self): + refs = extract_referenced_fields(json.dumps({ + "kind": "LEAF", + "transform": {"name": "FIELD_REF", "fieldRef": {"index": 0, "name": "col1", "type": "INT"}}, + "function": "EQUAL", + "literals": [1], + })) + self.assertEqual(refs, {"col1"}) + + def test_leaf_cast(self): + refs = extract_referenced_fields(json.dumps({ + "kind": "LEAF", + "transform": {"name": "CAST", "fieldRef": {"index": 0, "name": "col1", "type": "INT"}, "type": "BIGINT"}, + "function": "EQUAL", + "literals": [1], + })) + self.assertEqual(refs, {"col1"}) + + def test_compound_collects_all(self): + refs = extract_referenced_fields(json.dumps({ + "kind": "COMPOUND", + "function": "AND", + "children": [ + {"kind": "LEAF", + "transform": {"name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "a", "type": "INT"}}, + "function": "EQUAL", "literals": [1]}, + {"kind": "LEAF", + "transform": {"name": "UPPER", + "inputs": [{"index": 1, "name": "b", "type": "STRING"}]}, + "function": "EQUAL", "literals": ["X"]}, + ], + })) + self.assertEqual(refs, {"a", "b"}) + + def test_concat_inputs(self): + refs = extract_referenced_fields(json.dumps({ + "kind": "LEAF", + "transform": { + "name": "CONCAT", + "inputs": [ + {"index": 0, "name": "first", "type": "STRING"}, + "literal", + {"index": 1, "name": "last", "type": "STRING"}, + ], + }, + "function": "EQUAL", + "literals": ["x"], + })) + self.assertEqual(refs, {"first", "last"}) + + +class TestLikeEdgeCases(unittest.TestCase): + + def test_dot_is_literal(self): + batch = pa.RecordBatch.from_pydict({"v": ["a.b", "axb", "a.bc"]}) + f = parse_predicate_to_batch_filter(json.dumps({ + "kind": "LEAF", + "transform": {"name": "FIELD_REF", "fieldRef": {"index": 0, "name": "v", "type": "STRING"}}, + "function": "LIKE", + "literals": ["a.b"], + })) + self.assertEqual(f(batch).to_pylist(), [True, False, False]) + + def test_underscore_matches_single_char(self): + batch = pa.RecordBatch.from_pydict({"v": ["abc", "axc", "ac", "abbc"]}) + f = parse_predicate_to_batch_filter(json.dumps({ + "kind": "LEAF", + "transform": {"name": "FIELD_REF", "fieldRef": {"index": 0, "name": "v", "type": "STRING"}}, + "function": "LIKE", + "literals": ["a_c"], + })) + self.assertEqual(f(batch).to_pylist(), [True, True, False, False]) + + def test_percent_matches_any(self): + batch = pa.RecordBatch.from_pydict({"v": ["abc", "ac", "axyzc", "def"]}) + f = parse_predicate_to_batch_filter(json.dumps({ + "kind": "LEAF", + "transform": {"name": "FIELD_REF", "fieldRef": {"index": 0, "name": "v", "type": "STRING"}}, + "function": "LIKE", + "literals": ["a%c"], + })) + self.assertEqual(f(batch).to_pylist(), [True, True, True, False]) + + +if __name__ == "__main__": + unittest.main() diff --git a/paimon-python/pypaimon/tests/table_query_auth_test.py b/paimon-python/pypaimon/tests/table_query_auth_test.py new file mode 100644 index 000000000000..4fd2af73a393 --- /dev/null +++ b/paimon-python/pypaimon/tests/table_query_auth_test.py @@ -0,0 +1,249 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import json +import unittest + +import pyarrow as pa + +from pypaimon.catalog.table_query_auth import ( + TableNoPermissionException, + TableQueryAuthResult, +) +from pypaimon.common.options import Options +from pypaimon.common.options.core_options import CoreOptions +from pypaimon.read.query_auth_split import QueryAuthSplit + + +class _FakeField: + def __init__(self, name): + self.name = name + + +class _FakeSplit: + def __init__(self, name="s1"): + self.name = name + + @property + def row_count(self): + return 100 + + @property + def files(self): + return [] + + @property + def partition(self): + return {} + + @property + def bucket(self): + return 0 + + def merged_row_count(self): + return 100 + + +class _FakePlan: + def __init__(self, splits): + self._splits = splits + + def splits(self): + return self._splits + + +def _simple_filter_json(field_name="dept", value="eng"): + return json.dumps({ + "kind": "LEAF", + "transform": { + "name": "FIELD_REF", + "fieldRef": {"index": 0, "name": field_name, "type": "STRING"}, + }, + "function": "EQUAL", + "literals": [value], + }) + + +class TestTableNoPermissionException(unittest.TestCase): + + def test_message_format(self): + exc = TableNoPermissionException("db.table", Exception("forbidden")) + self.assertIn("db.table", str(exc)) + self.assertIn("has no permission", str(exc)) + self.assertEqual(exc.identifier, "db.table") + + def test_without_cause(self): + exc = TableNoPermissionException("db.table") + self.assertIn("db.table", str(exc)) + + +class TestTableQueryAuthResultConvertPlan(unittest.TestCase): + + def test_no_auth_returns_original_plan(self): + result = TableQueryAuthResult(None, None) + plan = _FakePlan([_FakeSplit()]) + converted = result.convert_plan(plan) + self.assertIs(converted, plan) + + def test_empty_filter_and_masking_returns_original(self): + result = TableQueryAuthResult([], {}) + plan = _FakePlan([_FakeSplit()]) + converted = result.convert_plan(plan) + self.assertIs(converted, plan) + + def test_wraps_splits_with_filter(self): + result = TableQueryAuthResult([_simple_filter_json()], None) + plan = _FakePlan([_FakeSplit("s1"), _FakeSplit("s2")]) + converted = result.convert_plan(plan) + self.assertEqual(len(converted.splits()), 2) + for qs in converted.splits(): + self.assertIsInstance(qs, QueryAuthSplit) + + def test_wraps_splits_with_masking(self): + result = TableQueryAuthResult(None, {"col": '{"name":"NULL"}'}) + plan = _FakePlan([_FakeSplit()]) + converted = result.convert_plan(plan) + self.assertEqual(len(converted.splits()), 1) + self.assertIsInstance(converted.splits()[0], QueryAuthSplit) + + def test_inner_split_preserved(self): + result = TableQueryAuthResult([_simple_filter_json()], None) + plan = _FakePlan([_FakeSplit("original")]) + converted = result.convert_plan(plan) + self.assertEqual(converted.splits()[0].split.name, "original") + + +class TestTableQueryAuthResultExtractRowFilter(unittest.TestCase): + + def test_no_filter_returns_none(self): + result = TableQueryAuthResult(None, None) + self.assertIsNone(result.extract_row_filter()) + + def test_single_filter(self): + result = TableQueryAuthResult([_simple_filter_json("dept", "eng")], None) + fn = result.extract_row_filter() + self.assertIsNotNone(fn) + batch = pa.RecordBatch.from_pydict({"dept": ["eng", "sales", "eng"]}) + mask = fn(batch) + self.assertEqual(mask.to_pylist(), [True, False, True]) + + def test_multiple_filters_combined_with_and(self): + f1 = _simple_filter_json("dept", "eng") + f2 = json.dumps({ + "kind": "LEAF", + "transform": { + "name": "FIELD_REF", + "fieldRef": {"index": 0, "name": "dept", "type": "STRING"}, + }, + "function": "NOT_EQUAL", + "literals": ["eng"], + }) + result = TableQueryAuthResult([f1, f2], None) + fn = result.extract_row_filter() + batch = pa.RecordBatch.from_pydict({"dept": ["eng", "sales", "eng"]}) + mask = fn(batch) + self.assertEqual(mask.to_pylist(), [False, False, False]) + + +class TestTableQueryAuthResultExtraFields(unittest.TestCase): + + def test_detects_unprojected_field(self): + read_fields = [_FakeField("name"), _FakeField("age")] + table_fields = [ + _FakeField("name"), + _FakeField("age"), + _FakeField("dept"), + ] + result = TableQueryAuthResult([_simple_filter_json("dept")], None) + extra = result.get_extra_fields_for_filter(read_fields, table_fields) + self.assertEqual(len(extra), 1) + self.assertEqual(extra[0].name, "dept") + + def test_no_extra_when_already_projected(self): + read_fields = [_FakeField("name"), _FakeField("dept")] + table_fields = read_fields + [_FakeField("age")] + result = TableQueryAuthResult([_simple_filter_json("dept")], None) + extra = result.get_extra_fields_for_filter(read_fields, table_fields) + self.assertEqual(len(extra), 0) + + def test_no_extra_when_no_filter(self): + result = TableQueryAuthResult(None, None) + extra = result.get_extra_fields_for_filter( + [_FakeField("a")], [_FakeField("a"), _FakeField("b")] + ) + self.assertEqual(len(extra), 0) + + def test_deduplicates_extra_fields(self): + f1 = _simple_filter_json("dept", "eng") + f2 = _simple_filter_json("dept", "sales") + read_fields = [_FakeField("name")] + table_fields = [_FakeField("name"), _FakeField("dept")] + result = TableQueryAuthResult([f1, f2], None) + extra = result.get_extra_fields_for_filter(read_fields, table_fields) + self.assertEqual(len(extra), 1) + + +class TestQueryAuthSplit(unittest.TestCase): + + def test_delegates_properties(self): + auth = TableQueryAuthResult([_simple_filter_json()], None) + split = _FakeSplit() + qs = QueryAuthSplit(split, auth) + self.assertEqual(qs.row_count, 100) + self.assertEqual(qs.bucket, 0) + self.assertEqual(qs.files, []) + self.assertEqual(qs.partition, {}) + + def test_merged_row_count_none_with_filter(self): + auth = TableQueryAuthResult([_simple_filter_json()], None) + qs = QueryAuthSplit(_FakeSplit(), auth) + self.assertIsNone(qs.merged_row_count()) + + def test_merged_row_count_delegates_without_filter(self): + auth = TableQueryAuthResult(None, {"col": '{"name":"NULL"}'}) + qs = QueryAuthSplit(_FakeSplit(), auth) + self.assertEqual(qs.merged_row_count(), 100) + + def test_exposes_inner_split(self): + split = _FakeSplit("inner") + qs = QueryAuthSplit(split, TableQueryAuthResult(None, None)) + self.assertIs(qs.split, split) + + def test_exposes_auth_result(self): + auth = TableQueryAuthResult(None, None) + qs = QueryAuthSplit(_FakeSplit(), auth) + self.assertIs(qs.auth_result, auth) + + +class TestCoreOptionsQueryAuth(unittest.TestCase): + + def test_disabled_by_default(self): + opts = CoreOptions(Options({})) + self.assertFalse(opts.query_auth_enabled) + + def test_enabled_when_set(self): + opts = CoreOptions(Options({"query-auth.enabled": True})) + self.assertTrue(opts.query_auth_enabled) + + def test_disabled_when_false(self): + opts = CoreOptions(Options({"query-auth.enabled": False})) + self.assertFalse(opts.query_auth_enabled) + + +if __name__ == "__main__": + unittest.main()