From 202acc62762eedd0319bb9e4249728d675a408df Mon Sep 17 00:00:00 2001 From: jkukreja Date: Wed, 27 May 2026 14:46:46 -0400 Subject: [PATCH 1/4] [flink] Implement FLIP-314 LineageVertexProvider for non-table APIs --- .../paimon/flink/lineage/LineageUtils.java | 6 +- .../flink/lineage/PaimonLineageDataset.java | 51 +++++++++++ .../sink/FlinkFormatTableDataStreamSink.java | 10 ++- .../apache/paimon/flink/sink/FlinkSink.java | 3 +- .../flink/sink/PaimonDiscardingSink.java | 46 ++++++++++ .../flink/source/FlinkSourceBuilder.java | 5 +- .../flink/source/PaimonDataStreamSource.java | 90 +++++++++++++++++++ .../flink/source/operator/MonitorSource.java | 44 +++++++-- .../flink/lineage/LineageUtilsTest.java | 67 ++++++++++++-- .../FlinkFormatTableDataStreamSinkTest.java | 73 +++++++++++++++ .../sink/FlinkSinkBuilderLineageTest.java | 84 +++++++++++++++++ .../flink/source/FlinkSourceBuilderTest.java | 68 ++++++++++++++ 12 files changed, 528 insertions(+), 19 deletions(-) create mode 100644 paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java create mode 100644 paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java create mode 100644 paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSinkTest.java create mode 100644 paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkSinkBuilderLineageTest.java diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java index 110365c76ee4..3d41de6f9d55 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java @@ -78,7 +78,8 @@ public static String getNamespace(Table 
table) { public static SourceLineageVertex sourceLineageVertex( String name, boolean isBounded, Table table) { LineageDataset dataset = - new PaimonLineageDataset(name, getNamespace(table), buildConfigMap(table)); + new PaimonLineageDataset( + name, getNamespace(table), buildConfigMap(table), table.rowType()); Boundedness boundedness = isBounded ? Boundedness.BOUNDED : Boundedness.CONTINUOUS_UNBOUNDED; return new PaimonSourceLineageVertex(boundedness, Collections.singletonList(dataset)); @@ -92,7 +93,8 @@ public static SourceLineageVertex sourceLineageVertex( */ public static LineageVertex sinkLineageVertex(String name, Table table) { LineageDataset dataset = - new PaimonLineageDataset(name, getNamespace(table), buildConfigMap(table)); + new PaimonLineageDataset( + name, getNamespace(table), buildConfigMap(table), table.rowType()); return new PaimonSinkLineageVertex(Collections.singletonList(dataset)); } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/PaimonLineageDataset.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/PaimonLineageDataset.java index 5e99df0b2d4d..c122d14cc83e 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/PaimonLineageDataset.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/PaimonLineageDataset.java @@ -18,11 +18,19 @@ package org.apache.paimon.flink.lineage; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.RowType; + import org.apache.flink.streaming.api.lineage.DatasetConfigFacet; +import org.apache.flink.streaming.api.lineage.DatasetSchemaFacet; +import org.apache.flink.streaming.api.lineage.DatasetSchemaField; import org.apache.flink.streaming.api.lineage.LineageDataset; import org.apache.flink.streaming.api.lineage.LineageDatasetFacet; +import javax.annotation.Nullable; + import java.util.HashMap; +import java.util.LinkedHashMap; import 
java.util.Map; /** @@ -34,11 +42,21 @@ public class PaimonLineageDataset implements LineageDataset { private final String name; private final String namespace; private final Map tableOptions; + @Nullable private final RowType rowType; public PaimonLineageDataset(String name, String namespace, Map tableOptions) { + this(name, namespace, tableOptions, null); + } + + public PaimonLineageDataset( + String name, + String namespace, + Map tableOptions, + @Nullable RowType rowType) { this.name = name; this.namespace = namespace; this.tableOptions = tableOptions; + this.rowType = rowType; } @Override @@ -67,6 +85,39 @@ public Map config() { return tableOptions; } }); + if (rowType != null) { + facets.put( + "schema", + new DatasetSchemaFacet() { + @Override + public String name() { + return "schema"; + } + + @Override + public Map> fields() { + Map> result = new LinkedHashMap<>(); + for (DataField field : rowType.getFields()) { + String fieldName = field.name(); + String fieldType = field.type().asSQLString(); + result.put( + fieldName, + new DatasetSchemaField() { + @Override + public String name() { + return fieldName; + } + + @Override + public String type() { + return fieldType; + } + }); + } + return result; + } + }); + } return facets; } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java index 3d133a9ceace..7305076cba57 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java @@ -20,6 +20,7 @@ import org.apache.paimon.data.InternalRow; import org.apache.paimon.flink.FlinkRowWrapper; +import org.apache.paimon.flink.lineage.LineageUtils; import org.apache.paimon.table.FormatTable; import 
org.apache.paimon.table.format.FormatTableWrite; import org.apache.paimon.table.sink.BatchTableCommit; @@ -32,6 +33,8 @@ import org.apache.flink.api.connector.sink2.WriterInitContext; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStreamSink; +import org.apache.flink.streaming.api.lineage.LineageVertex; +import org.apache.flink.streaming.api.lineage.LineageVertexProvider; import org.apache.flink.table.data.RowData; import java.util.List; @@ -55,7 +58,7 @@ public DataStreamSink sinkFrom(DataStream dataStream) { return dataStream.sinkTo(new FormatTableSink(table, overwrite, staticPartitions)); } - private static class FormatTableSink implements Sink { + private static class FormatTableSink implements Sink, LineageVertexProvider { private final FormatTable table; private final boolean overwrite; @@ -68,6 +71,11 @@ public FormatTableSink( this.staticPartitions = staticPartitions; } + @Override + public LineageVertex getLineageVertex() { + return LineageUtils.sinkLineageVertex(table.fullName(), table); + } + /** * Do not annotate with @override here to maintain compatibility with Flink * 2.0+. 
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java index 959132ad58e0..9948863c547b 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java @@ -40,7 +40,6 @@ import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.CheckpointConfig; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink; import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory; import javax.annotation.Nullable; @@ -240,7 +239,7 @@ public DataStreamSink doCommit(DataStream written, String commit } configureSlotSharingGroup( committed, options.get(SINK_COMMITTER_CPU), options.get(SINK_COMMITTER_MEMORY)); - return committed.sinkTo(new DiscardingSink<>()).name("end").setParallelism(1); + return committed.sinkTo(new PaimonDiscardingSink<>(table)).name("end").setParallelism(1); } public static void configureSlotSharingGroup( diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java new file mode 100644 index 000000000000..c5949dfa07ac --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.sink; + +import org.apache.paimon.flink.lineage.LineageUtils; +import org.apache.paimon.table.FileStoreTable; + +import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink; +import org.apache.flink.streaming.api.lineage.LineageVertex; +import org.apache.flink.streaming.api.lineage.LineageVertexProvider; + +/** + * A {@link DiscardingSink} that implements {@link LineageVertexProvider} so Flink's lineage graph + * discovers the Paimon sink table when using the DataStream API. 
+ */ +public class PaimonDiscardingSink extends DiscardingSink implements LineageVertexProvider { + + private static final long serialVersionUID = 1L; + + private final FileStoreTable table; + + public PaimonDiscardingSink(FileStoreTable table) { + this.table = table; + } + + @Override + public LineageVertex getLineageVertex() { + return LineageUtils.sinkLineageVertex(table.fullName(), table); + } +} diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java index 3e96dec1ea50..37346573a262 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java @@ -240,7 +240,7 @@ private DataStream buildAlignedContinuousFileSource() { private DataStream toDataStream(Source source) { DataStreamSource dataStream = env.fromSource( - source, + new PaimonDataStreamSource<>(source, table), watermarkStrategy == null ? WatermarkStrategy.noWatermarks() : watermarkStrategy, @@ -354,7 +354,8 @@ private DataStream buildDedicatedSplitGenSource(boolean isBounded) { unordered, outerProject(), isBounded, - limit); + limit, + table); if (parallelism != null) { dataStream.getTransformation().setParallelism(parallelism); } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java new file mode 100644 index 000000000000..95999ab39e3f --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.source; + +import org.apache.paimon.flink.lineage.LineageUtils; +import org.apache.paimon.table.Table; + +import org.apache.flink.api.connector.source.Boundedness; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.connector.source.SourceReader; +import org.apache.flink.api.connector.source.SourceReaderContext; +import org.apache.flink.api.connector.source.SourceSplit; +import org.apache.flink.api.connector.source.SplitEnumerator; +import org.apache.flink.api.connector.source.SplitEnumeratorContext; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.streaming.api.lineage.LineageVertex; +import org.apache.flink.streaming.api.lineage.LineageVertexProvider; + +/** + * A {@link Source} wrapper that preserves the wrapped source behavior and exposes Paimon lineage + * for sources built through {@link FlinkSourceBuilder}. 
+ */ +public class PaimonDataStreamSource + implements Source, LineageVertexProvider { + + private static final long serialVersionUID = 1L; + + private final Source source; + private final Table table; + + public PaimonDataStreamSource(Source source, Table table) { + this.source = source; + this.table = table; + } + + @Override + public Boundedness getBoundedness() { + return source.getBoundedness(); + } + + @Override + public SourceReader createReader(SourceReaderContext readerContext) + throws Exception { + return source.createReader(readerContext); + } + + @Override + public SplitEnumerator createEnumerator( + SplitEnumeratorContext enumContext) throws Exception { + return source.createEnumerator(enumContext); + } + + @Override + public SplitEnumerator restoreEnumerator( + SplitEnumeratorContext enumContext, CheckpointT checkpoint) throws Exception { + return source.restoreEnumerator(enumContext, checkpoint); + } + + @Override + public SimpleVersionedSerializer getSplitSerializer() { + return source.getSplitSerializer(); + } + + @Override + public SimpleVersionedSerializer getEnumeratorCheckpointSerializer() { + return source.getEnumeratorCheckpointSerializer(); + } + + @Override + public LineageVertex getLineageVertex() { + return LineageUtils.sourceLineageVertex( + table.fullName(), getBoundedness() == Boundedness.BOUNDED, table); + } +} diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/MonitorSource.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/MonitorSource.java index d9b7d054cfcb..b6de64472b5b 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/MonitorSource.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/MonitorSource.java @@ -21,9 +21,12 @@ import org.apache.paimon.flink.NestedProjectedRowData; import org.apache.paimon.flink.source.AbstractNonCoordinatedSource; 
import org.apache.paimon.flink.source.AbstractNonCoordinatedSourceReader; +import org.apache.paimon.flink.source.NoOpEnumState; +import org.apache.paimon.flink.source.PaimonDataStreamSource; import org.apache.paimon.flink.source.SimpleSourceSplit; import org.apache.paimon.flink.source.SplitListState; import org.apache.paimon.flink.utils.JavaTypeInfo; +import org.apache.paimon.table.Table; import org.apache.paimon.table.sink.ChannelComputer; import org.apache.paimon.table.source.DataSplit; import org.apache.paimon.table.source.EndOfScanException; @@ -37,6 +40,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.connector.source.Boundedness; import org.apache.flink.api.connector.source.ReaderOutput; +import org.apache.flink.api.connector.source.Source; import org.apache.flink.api.connector.source.SourceReader; import org.apache.flink.api.connector.source.SourceReaderContext; import org.apache.flink.api.java.tuple.Tuple2; @@ -242,13 +246,43 @@ public static DataStream buildSource( NestedProjectedRowData nestedProjectedRowData, boolean isBounded, @Nullable Long limit) { + return buildSource( + env, + name, + typeInfo, + readBuilder, + monitorInterval, + emitSnapshotWatermark, + shuffleBucketWithPartition, + unordered, + nestedProjectedRowData, + isBounded, + limit, + null); + } + + public static DataStream buildSource( + StreamExecutionEnvironment env, + String name, + TypeInformation typeInfo, + ReadBuilder readBuilder, + long monitorInterval, + boolean emitSnapshotWatermark, + boolean shuffleBucketWithPartition, + boolean unordered, + NestedProjectedRowData nestedProjectedRowData, + boolean isBounded, + @Nullable Long limit, + @Nullable Table table) { + MonitorSource monitorSource = + new MonitorSource(readBuilder, monitorInterval, emitSnapshotWatermark, isBounded); + Source source = monitorSource; + if (table != null) { + source = new PaimonDataStreamSource<>(monitorSource, table); + } SingleOutputStreamOperator operator = 
env.fromSource( - new MonitorSource( - readBuilder, - monitorInterval, - emitSnapshotWatermark, - isBounded), + source, WatermarkStrategy.noWatermarks(), name + "-Monitor", new JavaTypeInfo<>(Split.class)) diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java index 62d601ec1b23..cea640ab8f10 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java @@ -20,7 +20,10 @@ import org.apache.paimon.CoreOptions; import org.apache.paimon.flink.PaimonDataStreamScanProvider; -import org.apache.paimon.flink.PaimonDataStreamSinkProvider; +import org.apache.paimon.flink.sink.PaimonDiscardingSink; +import org.apache.paimon.flink.source.ContinuousFileStoreSource; +import org.apache.paimon.flink.source.PaimonDataStreamSource; +import org.apache.paimon.flink.source.operator.MonitorSource; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.local.LocalFileIO; import org.apache.paimon.schema.Schema; @@ -33,6 +36,7 @@ import org.apache.flink.api.connector.source.Boundedness; import org.apache.flink.streaming.api.lineage.DatasetConfigFacet; +import org.apache.flink.streaming.api.lineage.DatasetSchemaFacet; import org.apache.flink.streaming.api.lineage.LineageDataset; import org.apache.flink.streaming.api.lineage.LineageDatasetFacet; import org.apache.flink.streaming.api.lineage.LineageVertex; @@ -177,6 +181,24 @@ void testConfigFacetWithEmptyKeys() throws Exception { assertThat(config).containsEntry("primary-keys", ""); } + @Test + void testSchemaFacetContainsPaimonFields() throws Exception { + FileStoreTable table = + createTable(new HashMap<>(), Collections.emptyList(), Arrays.asList("f0")); + + LineageVertex vertex = 
LineageUtils.sinkLineageVertex("paimon.db.t", table); + LineageDataset dataset = vertex.datasets().get(0); + + Map facets = dataset.facets(); + assertThat(facets).containsKey("schema"); + + DatasetSchemaFacet schemaFacet = (DatasetSchemaFacet) facets.get("schema"); + assertThat(schemaFacet.fields()).containsOnlyKeys("f0", "f1", "f2"); + assertThat(schemaFacet.fields().get("f0").type()).isEqualTo("INT NOT NULL"); + assertThat(schemaFacet.fields().get("f1").type()).isEqualTo("VARCHAR(100)"); + assertThat(schemaFacet.fields().get("f2").type()).isEqualTo("INT"); + } + @Test void testScanProviderImplementsLineageVertexProvider() throws Exception { FileStoreTable table = @@ -193,16 +215,47 @@ void testScanProviderImplementsLineageVertexProvider() throws Exception { } @Test - void testSinkProviderImplementsLineageVertexProvider() throws Exception { + void testSinkLineageViaPaimonDiscardingSink() throws Exception { FileStoreTable table = createTable(new HashMap<>(), Collections.emptyList(), Arrays.asList("f0")); - PaimonDataStreamSinkProvider provider = - new PaimonDataStreamSinkProvider(dataStream -> null, "paimon.db.sink", table); + PaimonDiscardingSink sink = new PaimonDiscardingSink<>(table); - assertThat(provider).isInstanceOf(LineageVertexProvider.class); - LineageVertex vertex = provider.getLineageVertex(); + assertThat(sink).isInstanceOf(LineageVertexProvider.class); + LineageVertex vertex = sink.getLineageVertex(); + assertThat(vertex.datasets()).hasSize(1); + } + + @Test + void testPaimonDataStreamSourceWrapsMonitorSourceLineageVertex() throws Exception { + FileStoreTable table = + createTable(new HashMap<>(), Collections.emptyList(), Arrays.asList("f0")); + + PaimonDataStreamSource source = + new PaimonDataStreamSource<>( + new MonitorSource(table.newReadBuilder(), 10, false, true), table); + + assertThat(source).isInstanceOf(LineageVertexProvider.class); + SourceLineageVertex vertex = (SourceLineageVertex) source.getLineageVertex(); + 
assertThat(vertex.boundedness()).isEqualTo(Boundedness.BOUNDED); + assertThat(vertex.datasets()).hasSize(1); + assertThat(vertex.datasets().get(0).name()).isEqualTo(table.fullName()); + } + + @Test + void testPaimonDataStreamSourceWrapsFlinkSourceLineageVertex() throws Exception { + FileStoreTable table = + createTable(new HashMap<>(), Collections.emptyList(), Arrays.asList("f0")); + + PaimonDataStreamSource source = + new PaimonDataStreamSource<>( + new ContinuousFileStoreSource( + table.newReadBuilder(), table.options(), null), + table); + + SourceLineageVertex vertex = (SourceLineageVertex) source.getLineageVertex(); + assertThat(vertex.boundedness()).isEqualTo(Boundedness.CONTINUOUS_UNBOUNDED); assertThat(vertex.datasets()).hasSize(1); - assertThat(vertex.datasets().get(0).name()).isEqualTo("paimon.db.sink"); + assertThat(vertex.datasets().get(0).name()).isEqualTo(table.fullName()); } } diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSinkTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSinkTest.java new file mode 100644 index 000000000000..b97792e82425 --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSinkTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.sink; + +import org.apache.paimon.catalog.CatalogContext; +import org.apache.paimon.catalog.Identifier; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.options.Options; +import org.apache.paimon.table.FormatTable; +import org.apache.paimon.types.IntType; +import org.apache.paimon.types.RowType; + +import org.apache.flink.streaming.api.lineage.LineageVertex; +import org.apache.flink.streaming.api.lineage.LineageVertexProvider; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.lang.reflect.Constructor; +import java.util.Collections; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link FlinkFormatTableDataStreamSink}. 
*/ +class FlinkFormatTableDataStreamSinkTest { + + @TempDir java.nio.file.Path temp; + + @Test + void testFormatTableSinkLineageVertex() throws Exception { + FormatTable table = + FormatTable.builder() + .fileIO(LocalFileIO.create()) + .identifier(Identifier.create("test_db", "test_table")) + .rowType(RowType.of(new IntType())) + .partitionKeys(Collections.emptyList()) + .location(new Path(temp.toUri().toString()).toString()) + .format(FormatTable.Format.PARQUET) + .options(Collections.singletonMap("path", temp.toUri().toString())) + .catalogContext(CatalogContext.create(new Options())) + .build(); + + Class sinkClass = + Class.forName( + "org.apache.paimon.flink.sink.FlinkFormatTableDataStreamSink$FormatTableSink"); + Constructor constructor = + sinkClass.getDeclaredConstructor(FormatTable.class, boolean.class, Map.class); + constructor.setAccessible(true); + Object sink = constructor.newInstance(table, false, Collections.emptyMap()); + + assertThat(sink).isInstanceOf(LineageVertexProvider.class); + LineageVertex vertex = ((LineageVertexProvider) sink).getLineageVertex(); + assertThat(vertex.datasets()).hasSize(1); + assertThat(vertex.datasets().get(0).name()).isEqualTo(table.fullName()); + } +} diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkSinkBuilderLineageTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkSinkBuilderLineageTest.java new file mode 100644 index 000000000000..a7ba7f687985 --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/sink/FlinkSinkBuilderLineageTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.sink; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.schema.Schema; +import org.apache.paimon.schema.SchemaManager; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.FileStoreTableFactory; +import org.apache.paimon.types.IntType; +import org.apache.paimon.types.RowType; + +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.DataStreamSink; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.transformations.SinkTransformation; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.paimon.flink.LogicalTypeConversion.toLogicalType; +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for sink lineage in {@link FlinkSinkBuilder}. 
*/ +class FlinkSinkBuilderLineageTest { + + @TempDir java.nio.file.Path temp; + + @Test + void testFlinkSinkBuilderUsesPaimonDiscardingSinkForLineage() throws Exception { + FileStoreTable table = createTable(); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStream input = + env.fromCollection( + Collections.singletonList((RowData) GenericRowData.of(1)), + InternalTypeInfo.of(toLogicalType(table.rowType()))); + + DataStreamSink sink = new FlinkSinkBuilder(table).forRowData(input).build(); + + assertThat(sink.getTransformation()).isInstanceOf(SinkTransformation.class); + SinkTransformation transformation = + (SinkTransformation) sink.getTransformation(); + assertThat(transformation.getSink()).isInstanceOf(PaimonDiscardingSink.class); + } + + private FileStoreTable createTable() throws Exception { + Path tablePath = new Path(temp.toUri().toString()); + Map options = new HashMap<>(); + options.put(CoreOptions.BUCKET.key(), "-1"); + new SchemaManager(LocalFileIO.create(), tablePath) + .createTable( + new Schema( + RowType.of(new IntType()).getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + return FileStoreTableFactory.create(LocalFileIO.create(), tablePath); + } +} diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkSourceBuilderTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkSourceBuilderTest.java index bc2ccb0fed13..9f6c46c2a793 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkSourceBuilderTest.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkSourceBuilderTest.java @@ -22,16 +22,25 @@ import org.apache.paimon.catalog.CatalogContext; import org.apache.paimon.catalog.CatalogFactory; import org.apache.paimon.catalog.Identifier; +import org.apache.paimon.flink.source.operator.MonitorSource; import 
org.apache.paimon.schema.Schema; import org.apache.paimon.table.Table; import org.apache.paimon.types.DataTypes; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.transformations.SourceTransformation; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import java.nio.file.Path; +import static org.apache.paimon.flink.LogicalTypeConversion.toLogicalType; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -114,4 +123,63 @@ public void testUnawareBucket() throws Exception { builder = new FlinkSourceBuilder(table); assertTrue(builder.isUnordered()); } + + @Test + public void testBuildWrapsStaticSourceWithPaimonDataStreamSource() throws Exception { + Table table = createTable("static_source", false, -1, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream dataStream = + new FlinkSourceBuilder(table).env(env).sourceBounded(true).build(); + + assertThat(dataStream.getTransformation()).isInstanceOf(SourceTransformation.class); + SourceTransformation transformation = + (SourceTransformation) dataStream.getTransformation(); + assertThat(transformation.getSource()).isInstanceOf(PaimonDataStreamSource.class); + } + + @Test + public void testBuildWrapsContinuousSourceWithPaimonDataStreamSource() throws Exception { + Table table = createTable("continuous_source", false, -1, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream dataStream = + new 
FlinkSourceBuilder(table).env(env).sourceBounded(false).build(); + + assertThat(dataStream.getTransformation()).isInstanceOf(SourceTransformation.class); + SourceTransformation transformation = + (SourceTransformation) dataStream.getTransformation(); + assertThat(transformation.getSource()).isInstanceOf(PaimonDataStreamSource.class); + } + + @Test + public void testMonitorSourceBuildSourceWrapsWithPaimonDataStreamSource() throws Exception { + Table table = createTable("monitor_source", false, -1, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + DataStream dataStream = + MonitorSource.buildSource( + env, + "source", + InternalTypeInfo.of(toLogicalType(table.rowType())), + table.newReadBuilder(), + 10, + false, + false, + false, + null, + true, + null, + table); + + assertThat(dataStream.getTransformation().getTransitivePredecessors()) + .filteredOn(Transformation.class::isInstance) + .filteredOn(transformation -> transformation instanceof SourceTransformation) + .anySatisfy( + transformation -> + assertThat( + ((SourceTransformation) transformation) + .getSource()) + .isInstanceOf(PaimonDataStreamSource.class)); + } } From fa3f15bc239586bf599343643c8615ca58e5e6ef Mon Sep 17 00:00:00 2001 From: jkukreja Date: Sun, 31 May 2026 09:58:25 -0400 Subject: [PATCH 2/4] Trigger CI From d88cdc0d31729bc9836e007d1be9ac1ed1e736bd Mon Sep 17 00:00:00 2001 From: jkukreja Date: Fri, 5 Jun 2026 00:24:43 -0400 Subject: [PATCH 3/4] Merge branch 'master' into add-lineage-for-non-table-apis --- .github/workflows/paimon-python-checks.yml | 26 +- ...ase-rust-native.yml => utcase-tantivy.yml} | 42 +- SECURITY.md | 391 +++++ docs/docs/append-table/blob.mdx | 800 ---------- docs/docs/append-table/bucketed.mdx | 2 +- docs/docs/append-table/index.mdx | 2 +- docs/docs/concepts/data-types.md | 2 +- docs/docs/concepts/spec/fileformat.md | 60 +- docs/docs/concepts/system-tables.mdx | 34 + docs/docs/flink/procedures.md | 38 +- docs/docs/index.mdx | 
33 +- docs/docs/learn-paimon/scenario-guide.mdx | 30 +- docs/docs/maintenance/metrics.md | 10 + docs/docs/migration/migration-from-hive.mdx | 2 +- docs/docs/multimodal-table/blob.mdx | 605 +++++++ .../data-evolution.md | 0 .../global-index.mdx | 59 + docs/docs/multimodal-table/index.mdx | 47 + .../vector.mdx | 172 +- docs/docs/primary-key-table/chain-table.md | 64 + docs/docs/primary-key-table/index.md | 2 +- .../merge-engine/aggregation.mdx | 2 + .../primary-key-table/sequence-rowkind.mdx | 47 + docs/docs/project/security.md | 73 + docs/docs/pypaimon/blob.md | 4 +- docs/docs/pypaimon/cli.md | 5 +- docs/docs/pypaimon/data-evolution.md | 2 +- docs/docs/pypaimon/global-index.md | 10 +- docs/docs/pypaimon/index.md | 15 + docs/docs/pypaimon/python-api.mdx | 42 +- docs/docs/pypaimon/ray-data.md | 197 ++- docs/docs/spark/procedures.md | 13 +- docs/docs/spark/sql-write.md | 357 +++-- docs/generated/core_configuration.html | 42 +- .../generated/hive_catalog_configuration.html | 6 + .../spark_connector_configuration.html | 10 +- docs/redirects.js | 24 +- docs/sidebars.js | 34 +- docs/src/css/custom.css | 15 + .../java/org/apache/paimon/CoreOptions.java | 138 +- .../java/org/apache/paimon/rest/RESTApi.java | 22 - .../org/apache/paimon/rest/ResourcePaths.java | 13 - .../rest/requests/RenameBranchRequest.java | 46 - .../arrow/ArrowFieldTypeConversion.java | 6 +- .../arrow/vector/ArrowFormatWriterTest.java | 21 + .../apache/paimon/data/BlobViewResolver.java | 4 + .../apache/paimon/data/BlobViewStruct.java | 6 + .../ApplyBitmapIndexFileRecordIterator.java | 13 +- .../bitmap/ApplyBitmapIndexRecordReader.java | 9 +- .../fs/cache/CachingSeekableInputStream.java | 50 +- .../globalindex/GlobalIndexEvaluator.java | 29 +- .../paimon/globalindex/GlobalIndexReader.java | 19 +- .../paimon/globalindex/GlobalIndexResult.java | 49 +- .../GlobalIndexResultSerializer.java | 4 +- .../paimon/globalindex/GlobalIndexer.java | 7 +- .../globalindex/OffsetGlobalIndexReader.java | 85 +- 
.../globalindex/ScoredGlobalIndexResult.java | 17 +- .../globalindex/UnionGlobalIndexReader.java | 156 +- .../globalindex/bitmap/BitmapGlobalIndex.java | 78 - .../bitmap/BitmapGlobalIndexerFactory.java | 42 - .../globalindex/btree/BTreeGlobalIndexer.java | 9 +- .../globalindex/btree/BTreeIndexReader.java | 242 +-- .../globalindex/btree/KeySerializer.java | 24 +- .../btree/LazyFilteredBTreeReader.java | 312 ++-- .../wrap/FileIndexReaderWrapper.java | 123 -- .../wrap/FileIndexWriterWrapper.java | 66 - .../paimon/predicate/FullTextSearch.java | 23 +- .../apache/paimon/predicate/VectorSearch.java | 7 - .../paimon/reader/DataEvolutionArray.java | 30 + .../org/apache/paimon/sst/BlockCache.java | 15 +- .../paimon/utils/FileBasedBloomFilter.java | 11 +- .../org/apache/paimon/utils/FileType.java | 2 +- .../org/apache/paimon/utils/LazyField.java | 22 +- ...he.paimon.globalindex.GlobalIndexerFactory | 3 +- .../paimon/data/BlobViewStructTest.java | 11 + .../ApplyBitmapIndexRecordReaderTest.java | 255 +++ .../globalindex/GlobalIndexEvaluatorTest.java | 286 ++-- .../GlobalIndexSerDeUtilsTest.java | 8 +- .../bitmapindex/BitmapGlobalIndexTest.java | 311 ---- .../btree/AbstractIndexReaderTest.java | 30 +- .../btree/BTreeIndexReaderTest.java | 7 +- .../btree/BTreeThreadSafetyTest.java | 385 +++++ .../LazyFilteredBTreeIndexReaderTest.java | 291 +++- .../TestFullTextGlobalIndexReader.java | 79 +- .../TestFullTextGlobalIndexer.java | 5 +- .../TestVectorGlobalIndexReader.java | 77 +- .../testvector/TestVectorGlobalIndexer.java | 5 +- .../org/apache/paimon/AbstractFileStore.java | 84 +- .../java/org/apache/paimon/FileStore.java | 6 +- .../main/java/org/apache/paimon/KeyValue.java | 5 + .../DedicatedFormatRollingFileWriter.java | 58 +- .../DataEvolutionCompactCoordinator.java | 29 +- .../DataEvolutionCompactTask.java | 56 +- .../GlobalIndexReadThreadPool.java | 47 + .../globalindex/GlobalIndexScanner.java | 69 +- .../btree/BTreeGlobalIndexBuilder.java | 13 +- 
.../apache/paimon/index/IndexFileHandler.java | 49 + .../org/apache/paimon/io/DataFileMeta.java | 7 + .../io/KeyValueDataFileRecordReader.java | 32 +- .../paimon/io/KeyValueFileReaderFactory.java | 25 +- .../IndexManifestEntrySerializer.java | 15 + .../compact/PartialUpdateMergeFunction.java | 19 +- .../aggregate/FieldNestedUpdateAgg.java | 69 +- .../factory/FieldNestedUpdateAggFactory.java | 13 +- .../ChainTableCommitPreCallback.java | 98 +- .../operation/ChainTablePartitionExpire.java | 446 ++++++ .../operation/DataEvolutionFileStoreScan.java | 119 +- .../paimon/operation/FileStoreCommitImpl.java | 66 +- .../operation/ManifestAdjacentSortedRun.java | 102 ++ .../paimon/operation/ManifestFileMerger.java | 71 +- .../paimon/operation/ManifestFileSorter.java | 1149 +++++++++++++ .../operation/ManifestPickStrategy.java | 149 ++ .../operation/NormalPartitionExpire.java | 228 +++ .../paimon/operation/PartitionExpire.java | 222 +-- .../operation/metrics/CommitMetrics.java | 4 + .../paimon/operation/metrics/CommitStats.java | 10 +- .../paimon/privilege/PrivilegedFileStore.java | 8 +- .../org/apache/paimon/rest/RESTCatalog.java | 10 +- .../paimon/schema/ColumnDirectiveUtils.java | 397 +++++ .../apache/paimon/schema/SchemaManager.java | 77 +- .../paimon/schema/SchemaMergingUtils.java | 199 ++- .../paimon/schema/SchemaValidation.java | 75 +- .../paimon/sort/BinaryExternalSortBuffer.java | 6 +- .../paimon/table/AbstractFileStoreTable.java | 9 +- .../apache/paimon/table/KnownSplitsTable.java | 9 + .../paimon/table/format/FormatTableScan.java | 27 +- .../table/source/BlobViewResolvingRow.java | 12 + .../paimon/table/source/DvAwareStats.java | 41 + .../paimon/table/source/FullTextReadImpl.java | 76 +- .../table/source/FullTextSearchBuilder.java | 3 + .../source/FullTextSearchBuilderImpl.java | 9 +- .../paimon/table/source/PushDownUtils.java | 18 + .../paimon/table/source/VectorReadImpl.java | 68 +- .../source/snapshot/SnapshotReaderImpl.java | 60 +- 
.../table/system/CompactBucketsTable.java | 2 - .../apache/paimon/utils/BlobViewLookup.java | 164 +- .../org/apache/paimon/utils/DVMetaCache.java | 86 +- .../java/org/apache/paimon/JavaPyE2ETest.java | 62 + .../apache/paimon/append/BlobTableTest.java | 124 +- .../paimon/append/MultipleBlobTableTest.java | 3 +- .../DataEvolutionCompactCoordinatorTest.java | 18 + .../btree/BTreeGlobalIndexBuilderTest.java | 70 + .../paimon/index/IndexFileHandlerTest.java | 62 + .../paimon/manifest/ManifestFileMetaTest.java | 701 +++++++- .../NoPartitionManifestFileMetaTest.java | 20 +- .../PartialUpdateMergeFunctionTest.java | 63 + .../SortMergeSnapshotOrderingTest.java | 134 ++ .../aggregate/FieldAggregatorTest.java | 255 +++ .../ChainTablePartitionExpireTest.java | 786 +++++++++ .../paimon/operation/FileStoreCommitTest.java | 2 +- .../paimon/operation/PartitionExpireTest.java | 18 +- .../operation/metrics/CommitMetricsTest.java | 13 +- .../operation/metrics/CommitStatsTest.java | 12 +- .../apache/paimon/rest/RESTCatalogServer.java | 27 +- .../apache/paimon/rest/RESTCatalogTest.java | 20 +- .../schema/ColumnDirectiveUtilsTest.java | 373 +++++ .../paimon/schema/SchemaMergingUtilsTest.java | 452 ++++-- .../paimon/schema/SchemaValidationTest.java | 176 ++ .../table/AppendOnlySimpleTableTest.java | 101 +- .../table/BitmapGlobalIndexTableTest.java | 255 --- .../table/BtreeGlobalIndexTableTest.java | 37 +- .../paimon/table/DataEvolutionTableTest.java | 376 +++++ .../table/PrimaryKeySimpleTableTest.java | 597 ++++++- .../paimon/table/SchemaEvolutionTest.java | 348 ++++ .../paimon/table/SimpleTableTestBase.java | 11 + .../table/format/FormatTableScanTest.java | 37 +- .../paimon/table/source/DvAwareStatsTest.java | 76 + .../table/source/PushDownUtilsTest.java | 116 ++ .../apache/paimon/utils/DVMetaCacheTest.java | 92 +- paimon-filesystems/paimon-hadoop-uber/pom.xml | 1 + .../TableAwareFileStoreSourceSplit.java | 6 + .../source/reader/CDCSourceSplitReader.java | 60 +- 
.../mysql/MySqlSyncDatabaseActionITCase.java | 2 +- ...areFileStoreSourceSplitSerializerTest.java | 50 +- .../reader/CDCSourceSplitReaderTest.java | 166 +- .../flink/action/ExpirePartitionsAction.java | 6 +- .../paimon/flink/lineage/LineageUtils.java | 52 + .../procedure/ExpirePartitionsProcedure.java | 2 - .../sink/FlinkFormatTableDataStreamSink.java | 2 +- .../flink/sink/PaimonDiscardingSink.java | 2 +- .../flink/source/PaimonDataStreamSource.java | 3 +- .../metrics/FileStoreSourceReaderMetrics.java | 1 - .../flink/source/operator/ReadOperator.java | 15 +- .../paimon/flink/BatchFileStoreITCase.java | 32 + .../paimon/flink/PartialUpdateITCase.java | 2 +- .../paimon/flink/SchemaChangeITCase.java | 21 + .../flink/action/ConsumerActionITCase.java | 86 +- .../flink/lineage/LineageUtilsTest.java | 54 + .../DropGlobalIndexProcedureITCase.java | 60 +- .../operator/DedicatedSplitReadLimitTest.java | 163 ++ .../source/operator/OperatorSourceTest.java | 8 +- .../paimon/format/row/RowFormatWriter.java | 1 - .../paimon/hive/HiveAlterTableUtils.java | 12 +- .../org/apache/paimon/hive/HiveCatalog.java | 8 +- .../paimon/hive/HiveCatalogOptions.java | 8 + .../paimon/hive/migrate/HiveMigrator.java | 14 +- .../paimon/hive/HiveTableStatsTest.java | 11 +- .../org/apache/paimon/hive/HiveTypeUtils.java | 6 + .../hive/mapred/PaimonOutputFormat.java | 124 +- .../hive/mapred/PartitionedRecordWriter.java | 92 ++ .../PaimonObjectInspectorFactory.java | 5 + .../apache/paimon/hive/HiveWriteITCase.java | 69 + .../hive/mapred/PaimonOutputFormatTest.java | 169 ++ .../PaimonBlobObjectInspectorTest.java | 80 + paimon-lumina/pom.xml | 2 +- .../index/LuminaVectorGlobalIndexReader.java | 102 +- .../index/LuminaVectorGlobalIndexer.java | 7 +- .../lumina/index/LuminaVectorBenchmark.java | 7 +- .../index/LuminaVectorGlobalIndexTest.java | 78 +- paimon-mosaic/pom.xml | 86 + .../format/mosaic/MosaicFileFormat.java | 238 +++ .../mosaic/MosaicFileFormatFactory.java | 38 + 
.../format/mosaic/MosaicInputFileAdapter.java | 79 + .../paimon/format/mosaic/MosaicObjects.java | 103 ++ .../format/mosaic/MosaicReaderFactory.java | 60 + .../format/mosaic/MosaicRecordsReader.java | 214 +++ .../format/mosaic/MosaicRecordsWriter.java | 199 +++ .../mosaic/MosaicSimpleStatsExtractor.java | 180 +++ .../format/mosaic/MosaicWriterFactory.java | 75 + .../format/mosaic/MosaicWriterMetadata.java | 53 + ...org.apache.paimon.format.FileFormatFactory | 16 + .../format/mosaic/MosaicFileFormatTest.java | 124 ++ .../mosaic/MosaicFormatReadWriteTest.java | 137 ++ .../format/mosaic/MosaicObjectsTest.java | 209 +++ .../format/mosaic/MosaicReaderWriterTest.java | 361 +++++ .../MosaicSimpleStatsExtractorTest.java | 212 +++ .../mosaic/MosaicWriterMetadataTest.java | 386 +++++ paimon-python/README.md | 79 + paimon-python/dev/requirements-dev.txt | 7 +- paimon-python/dev/requirements.txt | 5 +- paimon-python/dev/run_mixed_tests.sh | 65 +- .../pypaimon/benchmark/hdfs_io_bench.py | 168 ++ .../pypaimon/catalog/catalog_factory.py | 8 +- .../pypaimon/catalog/filesystem_catalog.py | 6 +- .../pypaimon/catalog/jdbc_catalog.py | 622 ++++++++ .../pypaimon/catalog/jdbc_catalog_loader.py | 32 + .../catalog/rest/rest_token_file_io.py | 3 + .../pypaimon/common/external_path_provider.py | 192 ++- paimon-python/pypaimon/common/file_io.py | 65 +- .../pypaimon/common/merge_engine_dispatch.py | 167 ++ .../pypaimon/common/options/config.py | 42 + .../pypaimon/common/options/core_options.py | 266 ++- paimon-python/pypaimon/common/predicate.py | 2 +- paimon-python/pypaimon/daft/__init__.py | 4 +- paimon-python/pypaimon/daft/daft_catalog.py | 49 +- paimon-python/pypaimon/daft/daft_datasink.py | 128 +- .../pypaimon/daft/daft_datasource.py | 469 +++++- paimon-python/pypaimon/daft/daft_explain.py | 160 ++ paimon-python/pypaimon/daft/daft_paimon.py | 144 +- .../pypaimon/daft/daft_predicate_visitor.py | 74 +- .../pypaimon/filesystem/_kerberos.py | 45 + .../pypaimon/filesystem/caching_file_io.py | 
51 +- .../filesystem/hdfs_native_file_io.py | 698 ++++++++ .../pypaimon/filesystem/local_file_io.py | 31 + paimon-python/pypaimon/filesystem/pvfs.py | 13 +- .../pypaimon/filesystem/pyarrow_file_io.py | 45 +- .../pypaimon/globalindex/btree/__init__.py | 10 +- .../btree/btree_file_meta_selector.py | 132 ++ .../globalindex/btree/btree_index_reader.py | 397 ++--- .../btree/lazy_filtered_btree_reader.py | 210 +++ .../globalindex/btree/sst_file_reader.py | 30 +- .../pypaimon/globalindex/full_text_search.py | 17 +- .../globalindex/global_index_evaluator.py | 53 +- .../globalindex/global_index_reader.py | 112 +- .../globalindex/global_index_result.py | 43 +- .../globalindex/global_index_scanner.py | 45 +- .../lumina_vector_global_index_reader.py | 61 +- .../globalindex/offset_global_index_reader.py | 98 +- .../tantivy_full_text_global_index_reader.py | 458 +++++- .../globalindex/union_global_index_reader.py | 132 +- .../pypaimon/globalindex/vector_search.py | 3 +- .../globalindex/vector_search_result.py | 25 +- .../pypaimon/manifest/index_manifest_file.py | 119 +- paimon-python/pypaimon/ray/__init__.py | 21 +- .../pypaimon/ray/data_evolution_merge_into.py | 574 +++++++ .../pypaimon/ray/data_evolution_merge_join.py | 388 +++++ .../ray/data_evolution_merge_transform.py | 150 ++ paimon-python/pypaimon/ray/merge_condition.py | 104 ++ paimon-python/pypaimon/ray/ray_paimon.py | 82 +- paimon-python/pypaimon/ray/shuffle.py | 86 +- .../pypaimon/read/merge_engine_support.py | 216 ++- .../read/reader/aggregate/__init__.py | 84 + .../read/reader/aggregate/aggregators.py | 283 ++++ .../read/reader/aggregate/field_aggregator.py | 81 + .../read/reader/aggregation_merge_function.py | 203 +++ .../reader/blob_descriptor_convert_reader.py | 179 ++- .../read/reader/concat_batch_reader.py | 130 +- .../read/reader/data_file_batch_reader.py | 108 +- .../read/reader/deduplicate_merge_function.py | 50 + .../pypaimon/read/reader/field_bunch.py | 52 + .../read/reader/first_row_merge_function.py | 55 
+ .../read/reader/format_blob_reader.py | 156 +- .../read/reader/format_mosaic_reader.py | 140 ++ .../pypaimon/read/reader/format_row_reader.py | 469 ++++++ .../read/reader/format_vortex_reader.py | 28 +- .../read/reader/limited_record_reader.py | 38 + .../reader/partial_update_merge_function.py | 83 +- .../pypaimon/read/reader/sort_merge_reader.py | 195 ++- .../pypaimon/read/scanner/file_scanner.py | 17 +- paimon-python/pypaimon/read/split_read.py | 247 ++- paimon-python/pypaimon/read/table_read.py | 28 + paimon-python/pypaimon/read/table_scan.py | 103 ++ .../pypaimon/schema/column_directive_utils.py | 236 +++ paimon-python/pypaimon/schema/schema.py | 34 - .../pypaimon/schema/schema_manager.py | 126 +- .../pypaimon/table/file_store_table.py | 18 +- paimon-python/pypaimon/table/row/blob.py | 160 +- paimon-python/pypaimon/table/row/key_value.py | 14 + .../pypaimon/table/source/full_text_read.py | 38 +- .../table/source/full_text_search_builder.py | 13 +- .../table/source/vector_search_read.py | 29 +- .../pypaimon/table/special_fields.py | 19 + .../pypaimon/table/system/buckets_table.py | 165 ++ .../table/system/system_table_loader.py | 4 +- .../pypaimon/tests/blob_table_test.py | 959 ++++++++++- paimon-python/pypaimon/tests/blob_test.py | 121 +- .../tests/btree_thread_safety_test.py | 228 +++ .../tests/column_directive_utils_test.py | 228 +++ .../pypaimon/tests/daft/daft_catalog_test.py | 17 + .../pypaimon/tests/daft/daft_data_test.py | 350 +++- .../pypaimon/tests/daft/daft_explain_test.py | 420 +++++ .../tests/daft/daft_integration_test.py | 236 +++ .../pypaimon/tests/daft/daft_sink_test.py | 22 + .../tests/data_evolution_formats_test.py | 1063 ++++++++++++ .../pypaimon/tests/e2e/hdfs/README.md | 75 + .../pypaimon/tests/e2e/hdfs/__init__.py | 16 + .../tests/e2e/hdfs/docker-compose.yml | 55 + .../tests/e2e/hdfs/hdfs_native_e2e_test.py | 107 ++ .../tests/e2e/java_py_read_write_test.py | 216 ++- .../pypaimon/tests/external_paths_test.py | 292 +++- 
.../tests/external_storage_blob_test.py | 2 +- .../tests/global_index_evaluator_test.py | 154 +- .../pypaimon/tests/global_index_test.py | 84 + .../pypaimon/tests/hdfs_native_test.py | 917 +++++++++++ .../tests/index_manifest_write_test.py | 121 ++ .../pypaimon/tests/jdbc_catalog_test.py | 223 +++ paimon-python/pypaimon/tests/kerberos_test.py | 2 +- .../tests/lumina_vector_index_test.py | 4 +- .../tests/nested_type_read_write_test.py | 168 ++ .../pypaimon/tests/predicates_test.py | 8 + paimon-python/pypaimon/tests/pvfs_test.py | 44 + .../tests/py36/rest_ao_read_write_test.py | 7 +- .../ray_data_evolution_merge_into_test.py | 1420 +++++++++++++++++ .../pypaimon/tests/ray_integration_test.py | 118 +- .../pypaimon/tests/ray_repartition_test.py | 212 ++- paimon-python/pypaimon/tests/ray_sink_test.py | 138 ++ .../pypaimon/tests/reader_base_test.py | 7 +- .../pypaimon/tests/reader_primary_key_test.py | 13 +- .../tests/schema_evolution_read_test.py | 151 ++ .../tests/system/buckets_table_test.py | 136 ++ .../tests/system/system_table_loader_test.py | 2 +- .../pypaimon/tests/table_scan_mode_test.py | 94 ++ .../pypaimon/tests/table_update_test.py | 83 +- .../tests/table_upsert_by_key_test.py | 42 + .../pypaimon/tests/test_aggregation_e2e.py | 375 +++++ .../tests/test_aggregation_merge_function.py | 300 ++++ .../tests/test_field_aggregator_registry.py | 103 ++ .../pypaimon/tests/test_field_aggregators.py | 274 ++++ .../pypaimon/tests/test_first_row_e2e.py | 200 +++ .../tests/test_first_row_merge_function.py | 146 ++ .../tests/test_format_mosaic_reader_writer.py | 304 ++++ .../tests/test_format_mosaic_table.py | 193 +++ .../tests/test_format_row_reader_writer.py | 534 +++++++ .../pypaimon/tests/test_format_row_table.py | 503 ++++++ .../pypaimon/tests/test_limit_pushdown.py | 52 +- .../tests/test_merge_engine_dispatch.py | 135 ++ .../pypaimon/tests/test_partial_update_e2e.py | 173 +- .../test_partial_update_merge_function.py | 44 +- 
.../pypaimon/tests/test_ray_shuffle_helper.py | 115 +- .../tests/test_sequence_field_read.py | 545 +++++++ .../pypaimon/tests/test_write_merge_buffer.py | 366 +++++ .../tests/vector_search_filter_test.py | 800 +++++++++- .../pypaimon/utils/blob_view_lookup.py | 274 ++++ .../pypaimon/utils/file_store_path_factory.py | 11 +- paimon-python/pypaimon/utils/file_type.py | 2 +- .../pypaimon/write/blob_format_writer.py | 23 +- .../pypaimon/write/commit_message.py | 10 +- .../pypaimon/write/file_store_commit.py | 56 +- .../pypaimon/write/file_store_write.py | 118 +- .../write/global_index_update_checker.py | 82 + paimon-python/pypaimon/write/ray_datasink.py | 45 +- paimon-python/pypaimon/write/table_update.py | 12 +- .../pypaimon/write/table_update_by_row_id.py | 235 ++- paimon-python/pypaimon/write/table_write.py | 42 +- paimon-python/pypaimon/write/write_builder.py | 4 +- .../pypaimon/write/writer/blob_file_writer.py | 32 +- .../pypaimon/write/writer/blob_writer.py | 23 +- .../write/writer/data_vector_writer.py | 4 + .../pypaimon/write/writer/data_writer.py | 4 + ...b_writer.py => dedicated_format_writer.py} | 226 ++- .../write/writer/format_row_writer.py | 408 +++++ .../write/writer/key_value_data_writer.py | 223 ++- paimon-python/setup.py | 6 + .../paimon/spark/PaimonScanBuilder.scala | 18 +- .../paimon/spark/PaimonScanBuilderTest.scala | 44 + .../paimon/spark/sql/CopyIntoTest.scala | 2 +- .../paimon/spark/DataFrameWriteTest.scala | 2 - .../paimon/spark/sql/CopyIntoTest.scala | 2 +- .../apache/paimon/spark/PaimonSinkTest.scala | 1 + .../paimon/spark/sql/CopyIntoTest.scala | 2 +- .../paimon/spark/sql/CopyIntoTest.scala | 2 +- .../paimon/spark/sql/RowTrackingTest.scala | 19 +- .../PaimonSpark4SqlExtensionsParser.scala | 30 + .../MergeIntoPaimonDataEvolutionTable.scala | 148 +- .../spark/commands/MergeIntoPaimonTable.scala | 2 + ...stractPaimonSparkSqlExtensionsParser.scala | 468 ------ .../paimon/spark/sql/CopyIntoTest.scala | 2 +- .../paimon/spark/sql/CopyIntoTest.scala 
| 2 +- .../paimon/spark/sql/RowTrackingTest.scala | 109 +- .../PaimonSqlExtensions.g4 | 10 +- .../spark/AbstractSparkInternalRow.java | 9 +- .../apache/paimon/spark/DataConverter.java | 14 + .../org/apache/paimon/spark/SparkCatalog.java | 33 +- .../paimon/spark/SparkConnectorOptions.java | 18 +- .../paimon/spark/SparkFilterConverter.java | 25 +- .../paimon/spark/SparkInternalRowWrapper.java | 24 +- .../org/apache/paimon/spark/SparkRow.java | 96 +- .../schema/PaimonMetadataColumnBase.java | 40 + .../PaimonV2MetadataAwareDataWriter.java | 62 + .../spark/PaimonRecordReaderIterator.scala | 27 +- .../paimon/spark/PaimonSparkTableBase.scala | 3 + .../org/apache/paimon/spark/SparkTable.scala | 29 +- .../apache/paimon/spark/SparkTypeUtils.java | 31 +- .../aggregate/AggregatePushDownUtils.scala | 8 +- .../analysis/MergeSchemaEvolutionHelper.scala | 38 +- .../catalyst/analysis/PaimonAnalysis.scala | 21 +- .../analysis/PaimonAssignmentUtils.scala | 10 +- .../catalyst/analysis/PaimonMergeInto.scala | 3 +- .../analysis/PaimonOutputResolver.scala | 12 +- .../catalyst/analysis/RowLevelHelper.scala | 18 +- .../plans/logical/CopyIntoTableCommand.scala | 7 +- .../catalyst/plans/logical/CopyOptions.scala | 155 +- .../commands/DataEvolutionPaimonWriter.scala | 42 +- .../MergeIntoPaimonDataEvolutionTable.scala | 147 +- .../spark/commands/MergeIntoPaimonTable.scala | 2 + ...imonDynamicPartitionOverwriteCommand.scala | 4 +- .../commands/SchemaEvolutionHelper.scala | 200 +++ .../paimon/spark/commands/SchemaHelper.scala | 159 -- .../spark/commands/WriteIntoPaimonTable.scala | 6 +- .../copyinto/CopyIntoResultBuilder.scala | 145 ++ .../execution/CopyIntoCastValidator.scala | 196 +++ .../execution/CopyIntoDataFrameBuilder.scala | 187 +++ .../execution/CopyIntoErrorHandler.scala | 276 ++++ .../spark/execution/CopyIntoHelper.scala | 138 ++ .../execution/CopyIntoLocationExec.scala | 11 +- .../spark/execution/CopyIntoTableExec.scala | 525 +++--- .../spark/execution/CopyIntoUtils.scala | 15 + 
.../spark/execution/PaimonStrategy.scala | 3 +- .../paimon/spark/read/BinPackingSplits.scala | 43 +- .../PaimonSparkCopyOnWriteOperation.scala | 9 +- .../spark/schema/PaimonMetadataColumn.scala | 26 +- .../paimon/spark/sources/PaimonSink.scala | 4 +- .../paimon/spark/util/OptionUtils.scala | 41 +- .../write/DataEvolutionTableDataWrite.scala | 6 +- .../spark/write/PaimonBatchWriteBase.scala | 44 +- .../spark/write/PaimonV2DataWriter.scala | 68 +- .../paimon/spark/write/PaimonV2Write.scala | 7 +- .../org/apache/spark/sql/PaimonUtils.scala | 6 +- ...stractPaimonSparkSqlExtensionsParser.scala | 28 +- .../PaimonSqlExtensionsAstBuilder.scala | 10 +- .../paimon/spark/SparkChainTableITCase.java | 64 + .../spark/SparkDataEvolutionITCase.java | 306 ++++ .../spark/SparkFilterConverterTest.java | 34 + .../paimon/spark/SparkMultimodalITCase.java | 143 ++ .../spark/SparkSchemaEvolutionITCase.java | 34 + .../apache/paimon/spark/SparkTypeTest.java | 19 + .../paimon/spark/BinPackingSplitsTest.scala | 188 ++- .../apache/paimon/spark/PaimonSinkTest.scala | 1 + ...DynamicPartitionOverwriteCommandTest.scala | 113 ++ .../execution/CopyIntoCastValidatorTest.scala | 204 +++ .../CopyIntoDataFrameBuilderTest.scala | 231 +++ .../spark/execution/CopyIntoHelperTest.scala | 133 ++ .../CreateGlobalIndexProcedureTest.scala | 155 -- .../DropGlobalIndexProcedureTest.scala | 44 +- .../procedure/MigrateTableProcedureTest.scala | 15 + .../CatalogQualifiedCreateTableLikeTest.scala | 10 + .../spark/sql/CopyIntoOnErrorTest.scala | 527 ++++++ .../paimon/spark/sql/CopyIntoTestBase.scala | 768 ++++++++- .../apache/paimon/spark/sql/DDLTestBase.scala | 2 +- .../spark/sql/DataFrameWriteTestBase.scala | 96 +- .../spark/sql/PushDownAggregatesTest.scala | 19 + .../spark/sql/RowTrackingTestBase.scala | 193 ++- .../spark/sql/V2WriteMergeSchemaTest.scala | 1 + .../spark/sql/WriteMergeSchemaTest.scala | 457 +++++- .../PaimonSpark3SqlExtensionsParser.scala | 3 +- .../PaimonSpark4SqlExtensionsParser.scala | 3 +- 
.../analysis/PureAppendOnlyScope.scala | 50 +- .../Spark41DeleteMetadataRestore.scala | 10 +- .../analysis/Spark41MergeIntoRewrite.scala | 15 +- .../analysis/Spark41UpdateTableRewrite.scala | 7 +- paimon-tantivy/paimon-tantivy-index/README.md | 49 + .../TantivyFullTextGlobalIndexReader.java | 126 +- .../TantivyFullTextGlobalIndexWriter.java | 13 +- .../index/TantivyFullTextGlobalIndexer.java | 19 +- .../TantivyFullTextGlobalIndexerFactory.java | 18 +- .../index/TantivyFullTextIndexOptions.java | 420 +++++ .../tantivy/index/JavaPyTantivyE2ETest.java | 107 +- .../index/TantivyFullTextGlobalIndexTest.java | 97 +- .../TantivyFullTextIndexOptionsTest.java | 160 ++ paimon-tantivy/paimon-tantivy-jni/README.md | 51 + .../paimon-tantivy-jni/rust/Cargo.toml | 2 + .../paimon-tantivy-jni/rust/src/lib.rs | 668 ++++++-- .../paimon/tantivy/TantivyIndexWriter.java | 29 + .../paimon/tantivy/TantivySearcher.java | 92 +- .../apache/paimon/tantivy/TantivyJniTest.java | 74 +- .../format/vortex/VortexFileFormat.java | 10 +- .../vortex/VortexPredicateConverter.java | 96 +- .../format/vortex/VortexRecordsReader.java | 118 +- .../format/vortex/VortexRecordsWriter.java | 137 +- .../paimon/format/vortex/VortexTypeUtils.java | 198 --- .../format/vortex/VortexWriterFactory.java | 17 +- .../vortex/VortexFileFormatReadWriteTest.java | 27 - .../vortex/VortexPredicateConverterTest.java | 346 ++-- .../format/vortex/VortexReaderWriterTest.java | 66 +- .../format/vortex/VortexTypeUtilsTest.java | 132 -- paimon-vortex/paimon-vortex-jni/README.md | 71 +- paimon-vortex/paimon-vortex-jni/pom.xml | 55 +- .../src/main/java/dev/vortex/api/Array.java | 67 - .../src/main/java/dev/vortex/api/DType.java | 270 ---- .../main/java/dev/vortex/api/DataSource.java | 100 ++ .../main/java/dev/vortex/api/Expression.java | 204 ++- .../src/main/java/dev/vortex/api/Files.java | 49 - .../main/java/dev/vortex/api/Partition.java | 70 + .../src/main/java/dev/vortex/api/Scan.java | 76 + 
.../main/java/dev/vortex/api/ScanOptions.java | 46 +- .../src/main/java/dev/vortex/api/Session.java | 51 + .../java/dev/vortex/api/VortexWriter.java | 80 +- .../dev/vortex/api/expressions/Binary.java | 234 --- .../dev/vortex/api/expressions/GetItem.java | 95 -- .../dev/vortex/api/expressions/IsNotNull.java | 89 -- .../dev/vortex/api/expressions/IsNull.java | 88 - .../dev/vortex/api/expressions/Literal.java | 781 --------- .../java/dev/vortex/api/expressions/Not.java | 88 - .../java/dev/vortex/api/expressions/Root.java | 64 - .../dev/vortex/api/expressions/Unknown.java | 51 - .../java/dev/vortex/api/proto/DTypes.java | 201 --- .../dev/vortex/api/proto/EndianUtils.java | 74 - .../dev/vortex/api/proto/Expressions.java | 70 - .../java/dev/vortex/api/proto/Scalars.java | 378 ----- .../vortex/api/proto/TemporalMetadatas.java | 98 -- .../main/java/dev/vortex/jni/JNIArray.java | 167 -- .../java/dev/vortex/jni/JNIArrayIterator.java | 75 - .../main/java/dev/vortex/jni/JNIDType.java | 129 -- .../src/main/java/dev/vortex/jni/JNIFile.java | 71 - .../main/java/dev/vortex/jni/JNIWriter.java | 74 - .../dev/vortex/jni/NativeArrayMethods.java | 74 - .../dev/vortex/jni/NativeDTypeMethods.java | 93 -- ...FileMethods.java => NativeDataSource.java} | 24 +- .../java/dev/vortex/jni/NativeExpression.java | 83 + .../NativeFiles.java} | 23 +- .../java/dev/vortex/jni/NativeLoader.java | 9 + .../java/dev/vortex/jni/NativeLogging.java | 14 +- .../java/dev/vortex/jni/NativePartition.java | 37 + ...teratorMethods.java => NativeRuntime.java} | 13 +- .../main/java/dev/vortex/jni/NativeScan.java | 49 + .../{api/File.java => jni/NativeSession.java} | 18 +- ...veWriterMethods.java => NativeWriter.java} | 14 +- .../src/main/proto/dtype.proto | 101 -- .../src/main/proto/expr.proto | 97 -- .../src/main/proto/scalar.proto | 39 - pom.xml | 1 + 561 files changed, 50935 insertions(+), 12063 deletions(-) rename .github/workflows/{utitcase-rust-native.yml => utcase-tantivy.yml} (64%) create mode 100644 
SECURITY.md delete mode 100644 docs/docs/append-table/blob.mdx create mode 100644 docs/docs/multimodal-table/blob.mdx rename docs/docs/{append-table => multimodal-table}/data-evolution.md (100%) rename docs/docs/{append-table => multimodal-table}/global-index.mdx (72%) create mode 100644 docs/docs/multimodal-table/index.mdx rename docs/docs/{append-table => multimodal-table}/vector.mdx (67%) create mode 100644 docs/docs/project/security.md delete mode 100644 paimon-api/src/main/java/org/apache/paimon/rest/requests/RenameBranchRequest.java delete mode 100644 paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndex.java delete mode 100644 paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndexerFactory.java delete mode 100644 paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexReaderWrapper.java delete mode 100644 paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexWriterWrapper.java create mode 100644 paimon-common/src/test/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReaderTest.java delete mode 100644 paimon-common/src/test/java/org/apache/paimon/globalindex/bitmapindex/BitmapGlobalIndexTest.java create mode 100644 paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeThreadSafetyTest.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexReadThreadPool.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/operation/ChainTablePartitionExpire.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/operation/ManifestAdjacentSortedRun.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileSorter.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/operation/ManifestPickStrategy.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/operation/NormalPartitionExpire.java create mode 100644 
paimon-core/src/main/java/org/apache/paimon/schema/ColumnDirectiveUtils.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/table/source/DvAwareStats.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeSnapshotOrderingTest.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/operation/ChainTablePartitionExpireTest.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/schema/ColumnDirectiveUtilsTest.java delete mode 100644 paimon-core/src/test/java/org/apache/paimon/table/BitmapGlobalIndexTableTest.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/table/source/DvAwareStatsTest.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/table/source/PushDownUtilsTest.java create mode 100644 paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/DedicatedSplitReadLimitTest.java create mode 100644 paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PartitionedRecordWriter.java create mode 100644 paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/mapred/PaimonOutputFormatTest.java create mode 100644 paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/objectinspector/PaimonBlobObjectInspectorTest.java create mode 100644 paimon-mosaic/pom.xml create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormat.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormatFactory.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicInputFileAdapter.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicObjects.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicReaderFactory.java create mode 100644 
paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsReader.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsWriter.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractor.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterFactory.java create mode 100644 paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterMetadata.java create mode 100644 paimon-mosaic/src/main/resources/META-INF/services/org.apache.paimon.format.FileFormatFactory create mode 100644 paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFileFormatTest.java create mode 100644 paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFormatReadWriteTest.java create mode 100644 paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicObjectsTest.java create mode 100644 paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicReaderWriterTest.java create mode 100644 paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractorTest.java create mode 100644 paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicWriterMetadataTest.java create mode 100644 paimon-python/pypaimon/benchmark/hdfs_io_bench.py create mode 100644 paimon-python/pypaimon/catalog/jdbc_catalog.py create mode 100644 paimon-python/pypaimon/catalog/jdbc_catalog_loader.py create mode 100644 paimon-python/pypaimon/common/merge_engine_dispatch.py create mode 100644 paimon-python/pypaimon/daft/daft_explain.py create mode 100644 paimon-python/pypaimon/filesystem/_kerberos.py create mode 100644 paimon-python/pypaimon/filesystem/hdfs_native_file_io.py create mode 100644 paimon-python/pypaimon/globalindex/btree/btree_file_meta_selector.py create mode 100644 paimon-python/pypaimon/globalindex/btree/lazy_filtered_btree_reader.py create mode 100644 
paimon-python/pypaimon/ray/data_evolution_merge_into.py create mode 100644 paimon-python/pypaimon/ray/data_evolution_merge_join.py create mode 100644 paimon-python/pypaimon/ray/data_evolution_merge_transform.py create mode 100644 paimon-python/pypaimon/ray/merge_condition.py create mode 100644 paimon-python/pypaimon/read/reader/aggregate/__init__.py create mode 100644 paimon-python/pypaimon/read/reader/aggregate/aggregators.py create mode 100644 paimon-python/pypaimon/read/reader/aggregate/field_aggregator.py create mode 100644 paimon-python/pypaimon/read/reader/aggregation_merge_function.py create mode 100644 paimon-python/pypaimon/read/reader/deduplicate_merge_function.py create mode 100644 paimon-python/pypaimon/read/reader/first_row_merge_function.py create mode 100644 paimon-python/pypaimon/read/reader/format_mosaic_reader.py create mode 100644 paimon-python/pypaimon/read/reader/format_row_reader.py create mode 100644 paimon-python/pypaimon/schema/column_directive_utils.py create mode 100644 paimon-python/pypaimon/table/system/buckets_table.py create mode 100644 paimon-python/pypaimon/tests/btree_thread_safety_test.py create mode 100644 paimon-python/pypaimon/tests/column_directive_utils_test.py create mode 100644 paimon-python/pypaimon/tests/daft/daft_explain_test.py create mode 100644 paimon-python/pypaimon/tests/daft/daft_integration_test.py create mode 100644 paimon-python/pypaimon/tests/data_evolution_formats_test.py create mode 100644 paimon-python/pypaimon/tests/e2e/hdfs/README.md create mode 100644 paimon-python/pypaimon/tests/e2e/hdfs/__init__.py create mode 100644 paimon-python/pypaimon/tests/e2e/hdfs/docker-compose.yml create mode 100644 paimon-python/pypaimon/tests/e2e/hdfs/hdfs_native_e2e_test.py create mode 100644 paimon-python/pypaimon/tests/hdfs_native_test.py create mode 100644 paimon-python/pypaimon/tests/index_manifest_write_test.py create mode 100644 paimon-python/pypaimon/tests/jdbc_catalog_test.py create mode 100644 
paimon-python/pypaimon/tests/nested_type_read_write_test.py create mode 100644 paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py create mode 100644 paimon-python/pypaimon/tests/system/buckets_table_test.py create mode 100644 paimon-python/pypaimon/tests/table_scan_mode_test.py create mode 100644 paimon-python/pypaimon/tests/test_aggregation_e2e.py create mode 100644 paimon-python/pypaimon/tests/test_aggregation_merge_function.py create mode 100644 paimon-python/pypaimon/tests/test_field_aggregator_registry.py create mode 100644 paimon-python/pypaimon/tests/test_field_aggregators.py create mode 100644 paimon-python/pypaimon/tests/test_first_row_e2e.py create mode 100644 paimon-python/pypaimon/tests/test_first_row_merge_function.py create mode 100644 paimon-python/pypaimon/tests/test_format_mosaic_reader_writer.py create mode 100644 paimon-python/pypaimon/tests/test_format_mosaic_table.py create mode 100644 paimon-python/pypaimon/tests/test_format_row_reader_writer.py create mode 100644 paimon-python/pypaimon/tests/test_format_row_table.py create mode 100644 paimon-python/pypaimon/tests/test_merge_engine_dispatch.py create mode 100644 paimon-python/pypaimon/tests/test_sequence_field_read.py create mode 100644 paimon-python/pypaimon/tests/test_write_merge_buffer.py create mode 100644 paimon-python/pypaimon/utils/blob_view_lookup.py create mode 100644 paimon-python/pypaimon/write/global_index_update_checker.py rename paimon-python/pypaimon/write/writer/{data_blob_writer.py => dedicated_format_writer.py} (64%) create mode 100644 paimon-python/pypaimon/write/writer/format_row_writer.py create mode 100755 paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/PaimonScanBuilderTest.scala create mode 100644 paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala delete mode 100644 
paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala create mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/schema/PaimonMetadataColumnBase.java create mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/write/PaimonV2MetadataAwareDataWriter.java create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaEvolutionHelper.scala delete mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/copyinto/CopyIntoResultBuilder.scala create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoCastValidator.scala create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilder.scala create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoErrorHandler.scala create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoHelper.scala create mode 100644 paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkDataEvolutionITCase.java create mode 100644 paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommandTest.scala create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoCastValidatorTest.scala create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilderTest.scala create mode 100644 
paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoHelperTest.scala create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoOnErrorTest.scala create mode 100644 paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptionsTest.java delete mode 100644 paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexTypeUtils.java delete mode 100644 paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexTypeUtilsTest.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Array.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DType.java create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DataSource.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Files.java create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Partition.java create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Scan.java create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Session.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Binary.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/GetItem.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNotNull.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNull.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Literal.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Not.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Root.java delete mode 100644 
paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Unknown.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/DTypes.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/EndianUtils.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Expressions.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Scalars.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/TemporalMetadatas.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArray.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArrayIterator.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIDType.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIFile.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIWriter.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayMethods.java delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDTypeMethods.java rename paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/{NativeFileMethods.java => NativeDataSource.java} (56%) create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java rename paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/{api/ArrayIterator.java => jni/NativeFiles.java} (62%) create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativePartition.java rename paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/{NativeArrayIteratorMethods.java => NativeRuntime.java} (74%) create mode 100644 paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeScan.java rename 
paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/{api/File.java => jni/NativeSession.java} (73%) rename paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/{NativeWriterMethods.java => NativeWriter.java} (70%) delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/proto/dtype.proto delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/proto/expr.proto delete mode 100644 paimon-vortex/paimon-vortex-jni/src/main/proto/scalar.proto diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml index de00337be359..f27219e6facb 100755 --- a/.github/workflows/paimon-python-checks.yml +++ b/.github/workflows/paimon-python-checks.yml @@ -99,18 +99,6 @@ jobs: mkdir -p ${RESOURCE_DIR} cp paimon-tantivy/paimon-tantivy-jni/rust/target/release/libtantivy_jni.so ${RESOURCE_DIR}/ - - name: Clone and build Vortex native library - run: | - git clone --depth 1 -b 0.69.0 https://github.com/spiraldb/vortex.git ${RUNNER_TEMP}/vortex - cd ${RUNNER_TEMP}/vortex - cargo build --package vortex-jni --release - - - name: Copy Vortex native library to resources - run: | - RESOURCE_DIR=paimon-vortex/paimon-vortex-jni/src/main/resources/native/linux-amd64 - mkdir -p ${RESOURCE_DIR} - cp ${RUNNER_TEMP}/vortex/target/release/libvortex_jni.so ${RESOURCE_DIR}/ - - name: Verify Python version run: python --version @@ -145,11 +133,12 @@ jobs: else python -m pip install --upgrade pip pip install torch --index-url https://download.pytorch.org/whl/cpu - python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' + python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 
zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0 'datafusion>=52' python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION }}' -i https://pypi.org/simple/ if python -c "import sys; sys.exit(0 if sys.version_info >= (3, 11) else 1)"; then python -m pip install vortex-data==0.70.0 fi + python -m pip install 'paimon-mosaic>=0.1.0' fi df -h @@ -163,15 +152,6 @@ jobs: maturin build --release pip install target/wheels/tantivy-*.whl - - name: Build and install pypaimon-rust from source - if: matrix.python-version != '3.6.15' - shell: bash - run: | - git clone https://github.com/apache/paimon-rust.git /tmp/paimon-rust - cd /tmp/paimon-rust/bindings/python - maturin build --release -o dist - pip install dist/pypaimon_rust-*.whl - pip install 'datafusion>=52' - name: Run lint-python.sh shell: bash @@ -205,7 +185,7 @@ jobs: run: | python -m pip install --upgrade pip pip install torch --index-url https://download.pytorch.org/whl/cpu - python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 + python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION }}' -i https://pypi.org/simple/ - name: Run lint-python.sh shell: bash diff --git a/.github/workflows/utitcase-rust-native.yml b/.github/workflows/utcase-tantivy.yml similarity index 64% rename from 
.github/workflows/utitcase-rust-native.yml rename to .github/workflows/utcase-tantivy.yml index 32f53787e36c..f2ebb097db76 100644 --- a/.github/workflows/utitcase-rust-native.yml +++ b/.github/workflows/utcase-tantivy.yml @@ -16,16 +16,14 @@ # limitations under the License. ################################################################################ -name: UTCase Rust Native +name: UTCase Tantivy on: push: paths: - - 'paimon-vortex/**' - 'paimon-tantivy/**' pull_request: paths: - - 'paimon-vortex/**' - 'paimon-tantivy/**' env: @@ -37,44 +35,6 @@ concurrency: cancel-in-progress: true jobs: - vortex_test: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Set up JDK ${{ env.JDK_VERSION }} - uses: actions/setup-java@v5 - with: - java-version: ${{ env.JDK_VERSION }} - distribution: 'temurin' - - - name: Install Rust toolchain - run: | - curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable --profile minimal - echo "$HOME/.cargo/bin" >> $GITHUB_PATH - - - name: Clone and build Vortex native library - run: | - git clone --depth 1 -b 0.69.0 https://github.com/spiraldb/vortex.git ${RUNNER_TEMP}/vortex - cd ${RUNNER_TEMP}/vortex - cargo build --package vortex-jni --release - - - name: Copy native library to resources - run: | - RESOURCE_DIR=paimon-vortex/paimon-vortex-jni/src/main/resources/native/linux-amd64 - mkdir -p ${RESOURCE_DIR} - cp ${RUNNER_TEMP}/vortex/target/release/libvortex_jni.so ${RESOURCE_DIR}/ - - - name: Build and test Vortex modules - timeout-minutes: 30 - run: | - mvn -T 2C -B -ntp clean install -DskipTests - mvn -B -ntp verify -pl paimon-vortex/paimon-vortex-jni,paimon-vortex/paimon-vortex-format -Dcheckstyle.skip=true -Dspotless.check.skip=true - env: - MAVEN_OPTS: -Xmx4096m - tantivy_test: runs-on: ubuntu-latest diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000000..00de7a150e3a --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,391 @@ + 
+ +# Apache Paimon Security Threat Model + +This document describes Apache Paimon's detailed security threat model for +maintainers and automated security triage. + +It complements the shorter public-facing security model in +[`docs/docs/project/security.md`](docs/docs/project/security.md) (published at the project website) by making +Paimon's trust assumptions, security boundaries, and recurring non-security +bug classes more explicit. + +## Purpose + +Apache Paimon is a streaming data lake platform that is often deployed as a +library and integration layer inside larger systems (Flink, Spark, Hive, and +other query engines) that provide their own authentication, authorization, and +credential management. Because of that deployment model, many bug classes that +look security-relevant in the abstract are not actually security +vulnerabilities in Paimon itself. + +This model is intended to answer: + +- what Paimon generally treats as a security vulnerability +- what Paimon generally treats as correctness, hardening, or deployment work +- which boundaries are primarily owned by Paimon versus the surrounding + catalog, engine, or service +- which issue classes should be downgraded by default by scanners + +## Scope + +This model is scoped to the Apache Paimon project itself: + +- the table format implementation (paimon-core) +- client libraries (paimon-api, paimon-common) +- the REST Catalog client and protocol (paimon-api, paimon-core) +- engine integrations (Flink, Spark, Hive connectors) +- the Python client (pypaimon) + +It is not a general threat model for every deployment that embeds Paimon. 
+ +In particular, it does not attempt to define the complete security model for: + +- query engines or applications that embed Paimon +- storage-level authorization enforced outside Paimon +- REST Catalog server implementations (Paimon defines the client and protocol, + not the server) + +## Security Goals + +Paimon should: + +- avoid exposing secrets or delegated credentials to principals that were not + already trusted with them +- avoid creating new unauthorized capabilities in Paimon-owned components or + integrations +- avoid violating trust boundaries that Paimon itself owns, such as leaking + auth, signer, or credential-bearing state across catalog or session + boundaries in the same process +- avoid leaking delegated storage tokens (data tokens) across table or + principal boundaries + +Paimon does not aim to be the primary enforcement point for: + +- user-to-user authorization inside a query engine +- storage-level authorization (e.g., object store IAM policies) +- service-side authorization performed by a REST Catalog server +- row-level or column-level access control (Paimon relays server-provided + filters and column masking rules, but enforcement is in the server) + +## Roles + +### Operator + +The operator deploys and configures the catalog, REST Catalog server, engine, +and storage integration around Paimon. This role is trusted to choose +endpoints, warehouses, and storage integrations, configure credentials, and +decide which users may create, read, or modify tables. + +### Catalog Control Plane + +The catalog control plane is responsible for resolving tables and supplying +metadata, locations, configuration, and delegated credentials to Paimon. +This role may be implemented by: + +- a REST Catalog server +- a Hive Metastore +- a JDBC-backed catalog +- a filesystem-based catalog + +Regardless of implementation, it should not expose secrets to unintended +principals or leak credential-bearing state across unintended boundaries. 
+ +Paimon assumes a trusted catalog or metastore, which is outside its primary +security boundary. + +### REST Catalog Server + +In REST deployments, part of the catalog control plane is implemented by a +server that returns metadata, configuration, delegated storage credentials +(data tokens), and query-level authorization (row filters and column masking) +to the client. This server is generally treated as a trusted control-plane +component. + +The REST Catalog server is responsible for: + +- authenticating clients +- authorizing catalog operations (create/drop/alter databases, tables, views, + functions) +- issuing scoped, time-limited data tokens for storage access +- providing row-level filters and column masking rules via the auth table + query API +- returning server-side configuration to merge with client configuration + +### REST Catalog Client + +In REST deployments, the client-side catalog (`RESTCatalog`, `RESTApi`) +consumes server-provided metadata, configuration, and credentials. Where the +client and server are meaningfully distinct, client-side bugs in token +handling, caching, or reuse may still be security-relevant. This is especially +true when the Paimon-owned client implementation leaks credential-bearing +state across catalog, session, or principal boundaries it is expected to +preserve. + +The REST Catalog client is responsible for: + +- sending authenticated requests using a configured `AuthProvider` +- refreshing tokens before expiration (with a configurable safe time margin) +- caching `FileIO` instances keyed by data token (via `RESTTokenFileIO`) + and evicting them when tokens expire +- not mixing data tokens or auth state across different catalog instances or + tables in the same process + +### Engine or Embedding Application + +Query engines (Flink, Spark, Hive, Trino, StarRocks, etc.) and applications +may expose only a subset of Paimon capabilities to users. 
They are responsible +for their own user-facing authorization boundaries unless Paimon explicitly +documents otherwise. + +### Table Writer or Maintainer + +This role may already have legitimate power to write or replace table +metadata, write or delete data files, manage snapshots, create or delete +branches and tags, and invoke destructive maintenance operations (compaction, +expiration, rollback). If a report only shows a new way to achieve the same +effect this role can already cause legitimately, it is usually not a security +issue in Paimon. + +## Trust Boundaries + +### Boundary 1: Operator-Trusted Configuration + +The following are generally treated as trusted operator or deployment inputs: + +- catalog properties (including `uri`, `warehouse`, `token.provider`) +- REST Catalog server endpoint configuration +- warehouse and storage roots +- authentication credentials +- Kerberos keytab paths and principal names + (`security.kerberos.login.keytab`, `security.kerberos.login.principal`) +- metastore wiring (Hive Metastore URI, JDBC connection strings) +- custom HTTP headers (`header.*`) + +If a report depends on the attacker controlling those values directly, it is +usually not a vulnerability in Paimon itself. + +### Boundary 2: Catalog-Supplied Metadata + +Paimon often accepts metadata locations, table properties, database +properties, schema definitions, and related control-plane information from a +catalog or metastore. By default, Paimon treats those sources as trusted. + +This means a malicious catalog supplying incorrect or malicious metadata is +usually not a Paimon vulnerability by itself. 
+ +### Boundary 3: REST Catalog Server-Supplied Configuration and Delegated Storage Access + +In REST deployments, Paimon accepts the following from the REST Catalog server: + +- **Server configuration**: merged into client options via the `/v1/config` + endpoint, including catalog prefix and additional headers +- **Data tokens**: time-limited storage credentials returned by the + `/v1/{prefix}/databases/{database}/tables/{table}/token` endpoint, used by + `RESTTokenFileIO` to access the underlying object store +- **Auth table query responses**: row-level filters and column masking rules + returned by the `/v1/{prefix}/databases/{database}/tables/{table}/auth` + endpoint + +By default, these are treated as trusted control-plane inputs unless Paimon +explicitly documents a stronger guarantee. + +This means a malicious REST Catalog server sending dangerous configuration or +overly broad data tokens is usually not a Paimon vulnerability by itself. It +also means many client-side token-selection bugs are often correctness or +specification issues rather than security boundary failures. + +The major exception is **secret exposure**. If Paimon surfaces credentials or +secrets to a new audience that was not already trusted with them, that is +security-relevant. In particular: + +- Data tokens for one table leaking to operations on a different table +- Auth state from one catalog instance leaking into another +- Credentials appearing in logs, error messages, or serialized state + +### Boundary 4: Storage-Level Authorization + +Object store permissions (e.g., OSS, S3, HDFS ACLs) are enforced by the +storage provider and the credentials the surrounding deployment chooses to +hand to Paimon. Paimon is not the root authority for bucket- or object-level +authorization. + +Reports that depend primarily on over-broad IAM policies or permissive +storage ACLs are usually deployment-sensitive rather than product-security +issues in Paimon. 
+ +### Boundary 5: Engine-Level User Authorization + +Paimon integrations may surface data and operations through a query engine or +application, but Paimon is not a complete user-authorization framework for +those systems. + +Paimon does provide a mechanism for the REST Catalog server to supply +row-level filters and column masking rules via `authTableQuery`, but +enforcement of those rules is a shared responsibility between the engine +integration and the catalog server. Paimon relays the rules; the engine +must apply them. + +## In-Scope Security Vulnerabilities + +The following categories are generally security-relevant in Paimon when the +report is credible and reproducible. + +### 1. Secret or Credential Disclosure to a New Audience + +Examples include: + +- catalog credentials exposed through a user-visible engine surface + (e.g., query results, EXPLAIN output, table properties) +- one catalog's credentials or auth state leaking into another catalog or + session within the same process +- data tokens for table A being used for (or exposed to) table B +- credentials or tokens logged at INFO or lower levels without redaction +- credentials surviving in serialized `RESTTokenFileIO` or `RESTApi` state + beyond their intended scope + +### 2. Paimon-Owned Trust-Boundary Violations + +Security issues exist when Paimon itself is expected to separate catalogs, +principals, or sessions and fails to do so. + +Examples include: + +- process-global auth provider or signer state crossing catalog instances + (e.g., the `FILE_IO_CACHE` in `RESTTokenFileIO` returning a `FileIO` + belonging to a different principal) +- a data token obtained for one table being reused for a different table's + data access +- auth header state from one `RESTApi` instance leaking into another + +### 3. 
Row-Level and Column-Level Access Control Bypass + +If Paimon's client-side handling of `authTableQuery` responses (row filters +or column masking rules) allows a caller to bypass filters that the server +intended to enforce, that is security-relevant when the bypass occurs within +Paimon-owned code rather than in the engine integration. + +## Usually Out of Scope or Non-Security by Default + +These categories may still be real bugs worth fixing, but they are not usually +security vulnerabilities in Paimon itself. + +### 1. Correctness Bugs + +Examples: + +- wrong byte offsets or stale decoded values in file formats +- incorrect merge-tree compaction producing wrong query results +- race conditions or logic bugs that do not create a new trust-boundary + violation +- snapshot or schema version conflicts that produce incorrect metadata + +### 2. Parser Hardening and Malformed-Input Robustness + +Malformed-input crashes, raw runtime exceptions from invalid JSON or Avro +data, and memory amplification from oversized manifests or schemas are usually +treated as robustness or hardening work rather than security issues in Paimon +itself. + +### 3. Malicious Catalog, Metastore, or External Service Scenarios + +Reports that require a malicious catalog, metastore, REST Catalog server, or +other external service are usually outside Paimon's primary security boundary. + +Examples: + +- a REST Catalog server returning a data token with overly broad storage + permissions +- a Hive Metastore returning a table location pointing to a sensitive path +- a REST Catalog server returning malicious row filters designed to extract + data through side channels + +### 4. Equivalent-Harm Reports + +If the actor already has a legitimate capability that can cause the same harm, +the new path is usually not a security issue. 
This often applies to writers or +maintainers who already control metadata layout, file layout, or destructive +maintenance operations (snapshot expiration, orphan file cleanup, branch +deletion). + +### 5. Denial of Service Through Normal Operations + +Resource exhaustion caused by legitimate but expensive operations (e.g., large +compaction, scanning many partitions, listing all snapshots) is usually +treated as an operational concern rather than a security vulnerability. + +## REST Catalog Specific Security Considerations + +### Authentication + +Paimon's REST Catalog client supports pluggable authentication through the +`AuthProvider` interface. + +Authentication providers are created via the `AuthProviderFactory` SPI, loaded +using Java's `ServiceLoader` mechanism based on the `token.provider` +configuration. The authentication provider is process-level per catalog +instance and must not share mutable state across instances. + +### Data Token Lifecycle + +When `data-token.enabled` is `true`, `RESTTokenFileIO` manages delegated +storage credentials: + +1. The client calls the table token endpoint to obtain a time-limited data + token +2. The token is cached and used to construct a `FileIO` instance for storage + access +3. Tokens are refreshed before expiration (1 hour safe time margin by default) +4. 
`FileIO` instances are cached in a process-global cache + (`FILE_IO_CACHE`) keyed by `RESTToken`, with a maximum size of 1000 + entries and 10-hour expiry + +Security-relevant invariants: + +- Data tokens must be scoped to specific tables by the server +- The `FILE_IO_CACHE` keys on the full `RESTToken` (token content + + expiration), so different tokens produce different `FileIO` instances +- Token refresh creates a new `RESTApi` instance from the catalog context if + the original instance is unavailable (e.g., after deserialization) + +### Kerberos + +Paimon supports Kerberos authentication for Hadoop-based deployments through +`SecurityContext` and `SecurityConfiguration`. Keytab paths and principals +are treated as trusted operator configuration. + +## Scanner Calibration Rules + +A scanner targeting Paimon should treat a finding as higher-confidence only if +it plausibly shows one of the following: + +- exposure of a secret or delegated credential to a new audience +- creation of a new unauthorized capability in a Paimon-owned component +- violation of a Paimon-owned trust boundary (e.g., cross-catalog credential + leak, cross-table data token reuse) + +A finding should be downgraded or rejected by default if it instead depends +primarily on: + +- malformed-input robustness or denial-of-service behavior +- a malicious catalog, metastore, REST Catalog server, or external service +- a principal that already has equivalent power through legitimate write or + maintenance capabilities +- operator misconfiguration (overly broad credentials, missing TLS, etc.) 
diff --git a/docs/docs/append-table/blob.mdx b/docs/docs/append-table/blob.mdx deleted file mode 100644 index dfe78709b2c9..000000000000 --- a/docs/docs/append-table/blob.mdx +++ /dev/null @@ -1,800 +0,0 @@ ---- -title: "Blob Storage" -sidebar_position: 7 ---- - -import Tabs from '@theme/Tabs'; -import TabItem from '@theme/TabItem'; - - - -# Blob Storage - -## Overview - -The `BLOB` (Binary Large Object) type is a data type designed for storing multimodal data such as images, videos, audio files, and other large binary objects in Paimon tables. Unlike traditional `BYTES` type which stores binary data inline with other columns, `BLOB` type stores large binary data in separate files and maintains references to them, providing better performance for large objects. - -The Blob Storage is based on Data Evolution mode. - -The Blob type is ideal for: - -- **Image Storage**: Store product images, user avatars, medical imaging data -- **Video Content**: Store video clips, surveillance footage, multimedia content -- **Audio Files**: Store voice recordings, music files, podcast episodes -- **Document Storage**: Store PDF documents, office files, large text files -- **Machine Learning**: Store embeddings, model weights, feature vectors -- **Any Large Binary Data**: Any data that is too large to store efficiently inline - -## Storage Layout - -When you define a table with a Blob column, Paimon automatically separates the storage: - -1. **Normal Data Files** (e.g., `.parquet`, `.orc`): Store regular columns (INT, STRING, etc.) -2. **Blob Data Files** (`.blob`): Store the actual blob data - -For example, given a table with schema `(id INT, name STRING, picture BLOB)`: - -``` -table/ -├── bucket-0/ -│ ├── data-uuid-0.parquet # Contains id, name columns -│ ├── data-uuid-1.blob # Contains picture blob data -│ ├── data-uuid-2.blob # Contains more picture blob data -│ └── ... 
-├── manifest/ -├── schema/ -└── snapshot/ -``` - -This separation provides several benefits: -- Efficient column projection (reading non-blob columns doesn't load blob data) -- Optimized file rolling based on blob size -- Better compression for regular columnar data - -For details about the blob file format structure, see [File Format - BLOB](../concepts/spec/fileformat#blob). - -## Storage Modes - -Paimon supports four storage modes for BLOB fields: - -1. **Default blob storage** - Blob bytes are written to Paimon-managed `.blob` files under the table path. - -2. **Descriptor-only storage** - Fields configured in `blob-descriptor-field` store only serialized `BlobDescriptor` bytes inline in data files. Paimon does not write `.blob` files for these fields, and writes must provide descriptor-based input. - -3. **External-storage descriptor mode** - Fields configured in `blob-external-storage-field` are a subset of `blob-descriptor-field`. At write time, Paimon writes the raw blob data to the configured `blob-external-storage-path` and stores only serialized `BlobDescriptor` bytes inline in data files. - -4. **Blob view storage** - Fields configured in `blob-view-field` store serialized `BlobViewStruct` bytes inline in data files. The struct points to a BLOB value in an upstream table by table identifier, BLOB field, and row id. The actual blob bytes are resolved from the upstream table at read time. - -This allows one table to mix raw-data BLOB fields, descriptor-only BLOB fields, descriptor-based BLOB fields backed by external storage, and view fields that reference upstream BLOB values. - -## Table Options - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
OptionRequiredDefaultTypeDescription
blob-field
No(none)StringSpecifies column names that should be stored as blob type. This is used when you want to treat a BYTES column as a BLOB. Fields listed in blob-descriptor-field or blob-view-field are also treated as BLOB fields.
blob-as-descriptor
NofalseBooleanControls read output format for blob fields. When set to true, queries return serialized BlobDescriptor bytes; when false, queries return actual blob bytes. This option is dynamic and can be changed with ALTER TABLE ... SET.
blob-descriptor-field
No(none)String - Comma-separated field names treated as BLOB fields and stored as serialized BlobDescriptor bytes inline in normal data files. - By default, all blob fields store blob bytes in separate .blob files. - If configured, one table can mix: - some BLOB fields in .blob files and some as descriptor references. -
blob-write-null-on-missing-file
NofalseBoolean - When enabled for Flink writes, if a descriptor BLOB value references a file that does not exist, Paimon writes NULL for that value and logs a warning instead of failing when reading the descriptor. -
blob-view-field
No(none)String - Comma-separated field names treated as BLOB fields and stored as serialized BlobViewStruct bytes inline in normal data files. - The field values reference BLOB values in upstream tables and are resolved at read time. - This option must not overlap with blob-descriptor-field. -
blob-view.resolve.enabled
NotrueBoolean - Controls whether blob-view-field values are resolved to the upstream BLOB - content at read time. Set this dynamic option to false when forwarding blob view - references from one view table to another view table and you want the target table to keep - referencing the original upstream BLOB. -
blob-external-storage-field
No(none)String - Comma-separated BLOB field names whose raw data should be written to external storage at write time. - This option must be a subset of blob-descriptor-field. - For these fields, Paimon stores serialized BlobDescriptor bytes inline in data files. -
blob-external-storage-path
No(none)String - External storage path used by fields configured in blob-external-storage-field. - Orphan file cleanup is not applied to this path. -
blob.target-file-size
No(same as target-file-size)MemorySizeTarget size for blob files. When a blob file reaches this size, a new file is created. If not specified, uses the same value as target-file-size.
row-tracking.enabled
Yes*falseBooleanMust be enabled for blob tables to support row-level operations.
data-evolution.enabled
Yes*falseBooleanMust be enabled for blob tables to support schema evolution.
- -*Required for blob functionality to work correctly. - -Specifically, if the storage system of the input BlobDescriptor differs from that used by Paimon, -you can specify the storage configuration for the input blob descriptor using the prefix -`blob-descriptor.`. For example, if the source data is stored in a different OSS endpoint, -you can configure it as below (using flink sql as an example): -```sql -CREATE TABLE image_table ( - id INT, - name STRING, - image BYTES -) WITH ( - 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-field' = 'image', - 'fs.oss.endpoint' = 'aaa', -- This is for Paimon's own config - 'blob-descriptor.fs.oss.endpoint' = 'bbb' -- This is for input blob descriptors' config -); -``` - -## SQL Usage - -### Creating a Table - - - - - -```sql --- Create a table with a blob field --- Note: In Flink SQL, use BYTES type and specify blob-field option -CREATE TABLE image_table ( - id INT, - name STRING, - image BYTES -) WITH ( - 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-field' = 'image' -); -``` - - - - - -```sql --- Create a table with a blob field --- Note: In Spark SQL, use BINARY type and specify blob-field option -CREATE TABLE image_table ( - id INT, - name STRING, - image BINARY -) TBLPROPERTIES ( - 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-field' = 'image' -); -``` - - - - - -### Inserting Blob Data - - - - - -```sql --- Insert data with inline blob bytes -INSERT INTO image_table VALUES (1, 'sample', X'89504E470D0A1A0A'); - --- Insert from another table -INSERT INTO image_table -SELECT id, name, content FROM source_table; -``` - - - - - -```sql --- Insert data with inline blob bytes -INSERT INTO image_table VALUES (1, 'sample', X'89504E470D0A1A0A'); -``` - - - - - -### Querying Blob Data - -```sql --- Select all columns including blob -SELECT * FROM image_table; - --- Select only non-blob columns (efficient - doesn't load blob data) -SELECT 
id, name FROM image_table; - --- Select specific rows with blob -SELECT * FROM image_table WHERE id = 1; -``` - -### Blob Read Output Mode (`blob-as-descriptor`) - -`blob-as-descriptor` only controls how blob values are returned when reading. - -```sql --- Return descriptor bytes -ALTER TABLE blob_table SET ('blob-as-descriptor' = 'true'); -SELECT image FROM blob_table; - --- Return actual blob bytes -ALTER TABLE blob_table SET ('blob-as-descriptor' = 'false'); -SELECT image FROM blob_table; -``` - -### Blob View - -Blob view is useful when a downstream table should reference BLOB values already stored in an upstream table, without copying the bytes or creating new `.blob` files. A blob view field stores only a small `BlobViewStruct` inline. When the field is read, Paimon resolves the referenced BLOB from the upstream table. - -Blob view requires: - -- the upstream table to have row tracking enabled, so each row has a stable `_ROW_ID` -- the downstream field to be listed in `blob-view-field` -- writes to provide a serialized `BlobViewStruct`; in Flink SQL, use the built-in `sys.blob_view` function - -The Flink SQL function signature is: - -```sql -sys.blob_view(table_name, field_name, row_id) -``` - -Arguments: - -- `table_name`: the upstream table name. It must be fully qualified as `database.table` or `catalog.database.table`. Unqualified table names are rejected. -- `field_name`: the upstream BLOB field name. -- `row_id`: the `_ROW_ID` value from the upstream row-tracking table. 
- -The following example writes a downstream table whose `image_ref` field views the `image` field in `image_table`: - -```sql -CREATE TABLE image_table ( - id INT, - name STRING, - image BYTES -) WITH ( - 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-field' = 'image' -); - -CREATE TABLE image_view_table ( - id INT, - label STRING, - image_ref BYTES -) WITH ( - 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-view-field' = 'image_ref' -); - -INSERT INTO image_view_table -SELECT - id, - name AS label, - sys.blob_view('default.image_table', 'image', _ROW_ID) -FROM `image_table$row_tracking`; -``` - -If the current Paimon catalog name is included in the table name, the function also accepts `catalog.database.table`: - -```sql -SELECT sys.blob_view('my_catalog.default.image_table', 'image', _ROW_ID) -FROM `image_table$row_tracking`; -``` - -Reads from `image_view_table.image_ref` return the referenced BLOB bytes in the same way as normal blob fields. The referenced upstream table and row must remain available for the view to be resolved. - -#### Forward Blob View References - -By default, reading a blob view field resolves the `BlobViewStruct` and returns the upstream BLOB -content. If you want to import data from one blob view table into another blob view table without -copying the BLOB bytes, read the source table with `blob-view.resolve.enabled=false` and write the -result into a target field configured by `blob-view-field`. - -With this option disabled, Paimon preserves the serialized `BlobViewStruct` during reads. When the -preserved value is written to another blob view field, the target table stores the same upstream -reference instead of creating a chained view reference. - -For example, if table `T1` contains blob view references to BLOBs in table `T0`, importing `T1` into -`T2` with `blob-view.resolve.enabled=false` makes `T2` keep referencing `T0` directly. 
- -```sql -CREATE TABLE t2 ( - id INT, - image_ref BYTES -) WITH ( - 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-view-field' = 'image_ref' -); - --- Flink SQL example: the source table is read with blob view resolution disabled. -INSERT INTO t2 -SELECT id, image_ref -FROM t1 /*+ OPTIONS('blob-view.resolve.enabled'='false') */; -``` - -### MERGE INTO Support - -For Data Evolution writes in Flink and Spark: - -- raw-data BLOB columns are still rejected in partial-column `MERGE INTO` updates -- descriptor-based BLOB columns are allowed - -## Java API Usage - -### Creating a Table - -The following example demonstrates how to create a table with a blob column, write blob data, and read it back using Paimon's Java API. - -```java -import org.apache.paimon.catalog.Catalog; -import org.apache.paimon.catalog.CatalogContext; -import org.apache.paimon.catalog.CatalogFactory; -import org.apache.paimon.catalog.Identifier; -import org.apache.paimon.CoreOptions; -import org.apache.paimon.data.BinaryString; -import org.apache.paimon.data.Blob; -import org.apache.paimon.data.BlobData; -import org.apache.paimon.data.GenericRow; -import org.apache.paimon.data.InternalRow; -import org.apache.paimon.fs.Path; -import org.apache.paimon.fs.SeekableInputStream; -import org.apache.paimon.reader.RecordReader; -import org.apache.paimon.schema.Schema; -import org.apache.paimon.table.Table; -import org.apache.paimon.table.sink.BatchTableCommit; -import org.apache.paimon.table.sink.BatchTableWrite; -import org.apache.paimon.table.sink.BatchWriteBuilder; -import org.apache.paimon.table.source.ReadBuilder; -import org.apache.paimon.types.DataTypes; - -import java.io.ByteArrayInputStream; -import java.nio.file.Files; - -public class BlobTableExample { - - public static void main(String[] args) throws Exception { - // 1. 
Create catalog - Path warehouse = new Path("/tmp/paimon-warehouse"); - Catalog catalog = CatalogFactory.createCatalog(CatalogContext.create(warehouse)); - catalog.createDatabase("my_db", true); - - // 2. Define schema with BLOB column - Schema schema = Schema.newBuilder() - .column("id", DataTypes.INT()) - .column("name", DataTypes.STRING()) - .column("image", DataTypes.BLOB()) // Blob column for storing images - .option(CoreOptions.ROW_TRACKING_ENABLED.key(), "true") - .option(CoreOptions.DATA_EVOLUTION_ENABLED.key(), "true") - .build(); - - // 3. Create table - Identifier tableId = Identifier.create("my_db", "image_table"); - catalog.createTable(tableId, schema, true); - Table table = catalog.getTable(tableId); - - // 4. Write blob data - writeBlobData(table); - - // 5. Read blob data back - readBlobData(table); - } - - private static void writeBlobData(Table table) throws Exception { - BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); - - try (BatchTableWrite write = writeBuilder.newWrite(); - BatchTableCommit commit = writeBuilder.newCommit()) { - - // Method 1: Create blob from byte array - byte[] imageBytes = loadImageBytes("/path/to/image1.png"); - GenericRow row1 = GenericRow.of( - 1, - BinaryString.fromString("image1"), - new BlobData(imageBytes) - ); - write.write(row1); - - // Method 2: Create blob from local file - GenericRow row2 = GenericRow.of( - 2, - BinaryString.fromString("image2"), - Blob.fromLocal("/path/to/image2.png") - ); - write.write(row2); - - // Method 3: Create blob from InputStream (useful for streaming data) - byte[] streamData = loadImageBytes("/path/to/image3.png"); - ByteArrayInputStream inputStream = new ByteArrayInputStream(streamData); - GenericRow row3 = GenericRow.of( - 3, - BinaryString.fromString("image3"), - Blob.fromInputStream(() -> SeekableInputStream.wrap(inputStream)) - ); - write.write(row3); - - // Method 4: Create blob from HTTP URL - GenericRow row4 = GenericRow.of( - 4, - 
BinaryString.fromString("remote_image"), - Blob.fromHttp("https://example.com/image.png") - ); - write.write(row4); - - // Commit all writes - commit.commit(write.prepareCommit()); - } - - System.out.println("Successfully wrote 4 rows with blob data"); - } - - private static void readBlobData(Table table) throws Exception { - ReadBuilder readBuilder = table.newReadBuilder(); - RecordReader reader = - readBuilder.newRead().createReader(readBuilder.newScan().plan()); - - reader.forEachRemaining(row -> { - int id = row.getInt(0); - String name = row.getString(1).toString(); - Blob blob = row.getBlob(2); - - // Method 1: Read blob as byte array (loads entire blob into memory) - byte[] data = blob.toData(); - System.out.println("Row " + id + ": " + name + ", blob size: " + data.length); - - // Method 2: Read blob as stream (better for large blobs) - try (SeekableInputStream in = blob.newInputStream()) { - // Process stream without loading entire blob into memory - byte[] buffer = new byte[1024]; - int bytesRead; - long totalSize = 0; - while ((bytesRead = in.read(buffer)) != -1) { - totalSize += bytesRead; - // Process buffer... - } - System.out.println("Streamed " + totalSize + " bytes"); - } catch (Exception e) { - e.printStackTrace(); - } - }); - } - - private static byte[] loadImageBytes(String path) throws Exception { - return Files.readAllBytes(java.nio.file.Path.of(path)); - } -} -``` - -### Construct blob from different sources - -```java -// From byte array (data already in memory) -Blob blob = Blob.fromData(imageBytes); - -// From local file system -Blob blob = Blob.fromLocal("/path/to/image.png"); - -// From any FileIO (supports HDFS, S3, OSS, etc.) 
-FileIO fileIO = FileIO.get(new Path("s3://bucket"), catalogContext); -Blob blob = Blob.fromFile(fileIO, "s3://bucket/path/to/image.png"); - -// From FileIO with offset and length (read partial file) -Blob blob = Blob.fromFile(fileIO, "s3://bucket/large-file.bin", 1024, 2048); - -// From HTTP/HTTPS URL -Blob blob = Blob.fromHttp("https://example.com/image.png"); - -// From InputStream supplier (lazy loading) -Blob blob = Blob.fromInputStream(() -> new FileInputStream("/path/to/image.png")); - -// From BlobDescriptor (reconstruct blob reference from descriptor) -BlobDescriptor descriptor = new BlobDescriptor("s3://bucket/path/to/image.png", 0, 1024); -UriReader uriReader = UriReader.fromFile(fileIO); -Blob blob = Blob.fromDescriptor(uriReader, descriptor); -``` - -### Querying Blob Data - -```java -// Get blob from row (column index 2 in this example) -Blob blob = row.getBlob(2); - -// Read as byte array (simple but loads entire blob into memory) -byte[] data = blob.toData(); - -// Read as stream (recommended for large blobs) -try (SeekableInputStream in = blob.newInputStream()) { - // SeekableInputStream supports random access - in.seek(100); // Jump to position 100 - byte[] buffer = new byte[1024]; - int bytesRead = in.read(buffer); -} - -// Get blob descriptor (for reference-based blobs) -// Note: Only works for BlobRef, not BlobData -BlobDescriptor descriptor = blob.toDescriptor(); -String uri = descriptor.uri(); // e.g., "s3://bucket/path/to/blob" -long offset = descriptor.offset(); // Starting position in the file -long length = descriptor.length(); // Length of the blob data -``` - -### Descriptor-Aware Write Behavior - -Paimon write path is descriptor-aware automatically: - -1. For blob fields stored in `.blob` files, input can be either blob bytes or a `BlobDescriptor`. -2. For fields configured in `blob-descriptor-field`, Paimon stores descriptor bytes inline in data files (no `.blob` files for those fields), and input must be a descriptor. -3. 
For fields configured in `blob-external-storage-field`, Paimon writes the blob data to `blob-external-storage-path` and stores descriptor bytes inline in data files. -4. This behavior does not depend on `blob-as-descriptor`. - -```java -import org.apache.paimon.catalog.Catalog; -import org.apache.paimon.catalog.CatalogContext; -import org.apache.paimon.catalog.CatalogFactory; -import org.apache.paimon.catalog.Identifier; -import org.apache.paimon.CoreOptions; -import org.apache.paimon.data.BinaryString; -import org.apache.paimon.data.Blob; -import org.apache.paimon.data.BlobData; -import org.apache.paimon.data.BlobDescriptor; -import org.apache.paimon.data.GenericRow; -import org.apache.paimon.data.InternalRow; -import org.apache.paimon.fs.Path; -import org.apache.paimon.reader.RecordReader; -import org.apache.paimon.schema.Schema; -import org.apache.paimon.table.Table; -import org.apache.paimon.table.sink.BatchTableCommit; -import org.apache.paimon.table.sink.BatchTableWrite; -import org.apache.paimon.table.sink.BatchWriteBuilder; -import org.apache.paimon.table.source.ReadBuilder; -import org.apache.paimon.types.DataTypes; - -public class BlobDescriptorExample { - - public static void main(String[] args) throws Exception { - Path warehouse = new Path("s3://my-bucket/paimon-warehouse"); - CatalogContext catalogContext = CatalogContext.create(warehouse); - Catalog catalog = CatalogFactory.createCatalog(catalogContext); - catalog.createDatabase("my_db", true); - - // Create table: store "video" as descriptor bytes inline - Schema schema = Schema.newBuilder() - .column("id", DataTypes.INT()) - .column("name", DataTypes.STRING()) - .column("video", DataTypes.BLOB()) - .option(CoreOptions.ROW_TRACKING_ENABLED.key(), "true") - .option(CoreOptions.DATA_EVOLUTION_ENABLED.key(), "true") - .option(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), "video") - .build(); - - Identifier tableId = Identifier.create("my_db", "video_table"); - catalog.createTable(tableId, schema, true); - 
Table table = catalog.getTable(tableId); - - // Write blob using descriptor reference - writeLargeBlobWithDescriptor(table); - - // Read blob data - readBlobData(table); - } - - private static void writeLargeBlobWithDescriptor(Table table) throws Exception { - BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); - - try (BatchTableWrite write = writeBuilder.newWrite(); - BatchTableCommit commit = writeBuilder.newCommit()) { - - // For a very large file (e.g., 2GB video), instead of loading into memory: - // byte[] hugeVideo = Files.readAllBytes(...); // This would cause OutOfMemoryError! - // - // Create a descriptor reference to external blob - String externalUri = "s3://my-bucket/videos/large_video.mp4"; - long fileSize = 2L * 1024 * 1024 * 1024; // 2GB - - BlobDescriptor descriptor = new BlobDescriptor(externalUri, 0, fileSize); - // file io should be accessible to externalUri - FileIO fileIO = Table.fileIO(); - UriReader uriReader = UriReader.fromFile(fileIO); - Blob blob = Blob.fromDescriptor(uriReader, descriptor); - - GenericRow row = GenericRow.of( - 1, - BinaryString.fromString("large_video"), - blob); - write.write(row); - - commit.commit(write.prepareCommit()); - } - - System.out.println("Successfully wrote large blob using descriptor reference"); - } - - private static void readBlobData(Table table) throws Exception { - ReadBuilder readBuilder = table.newReadBuilder(); - RecordReader reader = - readBuilder.newRead().createReader(readBuilder.newScan().plan()); - - reader.forEachRemaining(row -> { - int id = row.getInt(0); - String name = row.getString(1).toString(); - Blob blob = row.getBlob(2); - - // Field is configured in blob-descriptor-field, so descriptor is stored inline - BlobDescriptor descriptor = blob.toDescriptor(); - System.out.println("Row " + id + ": " + name); - System.out.println(" Blob URI: " + descriptor.uri()); - System.out.println(" Length: " + descriptor.length()); - }); - } -} -``` - -**Reading blob data with different 
output modes:** - -The `blob-as-descriptor` option affects only read output: - -```sql --- When blob-as-descriptor = true: Returns BlobDescriptor bytes (reference to Paimon blob file) -ALTER TABLE video_table SET ('blob-as-descriptor' = 'true'); -SELECT * FROM video_table; -- Returns serialized BlobDescriptor - --- When blob-as-descriptor = false: Returns actual blob bytes -ALTER TABLE video_table SET ('blob-as-descriptor' = 'false'); -SELECT * FROM video_table; -- Returns actual blob bytes from Paimon storage -``` - -### Blob storage mode: DESCRIPTOR ONLY - -If you want downstream tables to **reuse** upstream blob files (no copying and no new .blob files), configure the target blob field(s): - -```sql -'blob-descriptor-field' = 'image' -``` - -For these configured fields, Paimon stores only serialized BlobDescriptor bytes in normal data files. Reading the blob follows the descriptor URI to access bytes, and writing requires descriptor input for those fields. - -### Blob storage mode: EXTERNAL STORAGE - -If you want Paimon to write raw blob data to a separate external location while keeping only descriptor bytes inline, configure the target blob field(s): - -```sql -'blob-descriptor-field' = 'image', -'blob-external-storage-field' = 'image', -'blob-external-storage-path' = 'oss://bucket/path/' -``` - -For these configured fields: - -- raw blob data is written to the configured external storage path -- normal data files keep only serialized BlobDescriptor bytes -- writes can still start from raw BLOB input -- the field is treated as descriptor-based for operations such as `MERGE INTO` - -For the Python equivalent, see [Blob Storage in pypaimon](../pypaimon/blob). - -## Limitations - -1. **Append Table Only**: Blob type is designed for append-only tables. Primary key tables are not supported. -2. **No Predicate Pushdown**: Blob columns cannot be used in filter predicates. -3. **No Statistics**: Statistics collection is not supported for blob columns. -4. 
**Required Options**: `row-tracking.enabled` and `data-evolution.enabled` must be set to `true`. -5. **External Storage Cleanup**: Files written through `blob-external-storage-path` are outside Paimon's orphan file cleanup scope. -6. **Blob View Dependency**: Blob view fields depend on the referenced upstream table and row. If the upstream data is removed or no longer readable, the view cannot be resolved. - -## Best Practices - -1. **Use Column Projection**: Always select only the columns you need. Avoid `SELECT *` if you don't need blob data. - -2. **Set Appropriate Target File Size**: Configure `blob.target-file-size` based on your blob sizes. Larger values mean fewer files but larger individual files. - -3. **Use Descriptor Fields When Reusing External Blob Files**: Configure `blob-descriptor-field` for fields that should keep descriptor references instead of writing new `.blob` files. - -4. **Use External-Storage Fields When Accepting Raw Input But Storing Descriptors**: Configure `blob-external-storage-field` together with `blob-external-storage-path` when upstream writes raw blob bytes but you want descriptor-based storage. - -5. **Manage External Storage Lifecycle Separately**: Files written to `blob-external-storage-path` are not cleaned up by Paimon, so retention and deletion should be managed externally. - -6. **Use Blob View to Avoid Copying BLOB Data**: Configure `blob-view-field` when a downstream table only needs to reference BLOB values from an upstream table. - -7. **Use Partitioning**: Partition your blob tables by date or other dimensions to improve query performance and data management. 
diff --git a/docs/docs/append-table/bucketed.mdx b/docs/docs/append-table/bucketed.mdx index aac4f02c1967..eb77931e3724 100644 --- a/docs/docs/append-table/bucketed.mdx +++ b/docs/docs/append-table/bucketed.mdx @@ -1,5 +1,5 @@ --- -title: "Bucketed" +title: "Bucketed Append" sidebar_position: 3 --- diff --git a/docs/docs/append-table/index.mdx b/docs/docs/append-table/index.mdx index 2014eb336ef8..152cb9c648a0 100644 --- a/docs/docs/append-table/index.mdx +++ b/docs/docs/append-table/index.mdx @@ -1,5 +1,5 @@ --- -title: "Table w/o PK" +title: "Append Table" sidebar_position: 3 --- diff --git a/docs/docs/concepts/data-types.md b/docs/docs/concepts/data-types.md index 39e367b46a4a..d8149e0c078b 100644 --- a/docs/docs/concepts/data-types.md +++ b/docs/docs/concepts/data-types.md @@ -184,7 +184,7 @@ All data types supported by Paimon are as follows: BLOB Data type of a binary large object.

Designed for storing large binary data such as images, videos, audio files, and other multimodal data. Unlike BYTES type which stores data inline, BLOB stores large binary data in separate files and maintains references to them, providing better performance for large objects.

- Note: Requires 'row-tracking.enabled' and 'data-evolution.enabled' to be set to true. See Blob Type for details. + Note: Requires 'row-tracking.enabled' and 'data-evolution.enabled' to be set to true. See Blob Type for details. diff --git a/docs/docs/concepts/spec/fileformat.md b/docs/docs/concepts/spec/fileformat.md index a38a667cbb83..39092054a0e9 100644 --- a/docs/docs/concepts/spec/fileformat.md +++ b/docs/docs/concepts/spec/fileformat.md @@ -24,9 +24,11 @@ under the License. # File Format -Currently, supports Parquet, Avro, ORC, CSV, JSON, Lance, and Row file formats. +Currently, Paimon supports Parquet, Avro, ORC, CSV, JSON, Lance, Vortex, Mosaic, and Row file formats. - Recommended column format is Parquet, which has a high compression rate and fast column projection queries. - Recommended row based format is Avro, which has good performance on reading and writing full row (all columns). +- Recommended format for wide tables is [Mosaic](https://paimon.apache.org/docs/mosaic/), a columnar-bucket hybrid format with column bucketing for parallel I/O. +- Recommended columnar format for point lookups is [Vortex](https://github.com/spiraldb/vortex), which uses adaptive encoding for excellent point-query performance and efficient vector data compression. - Recommended format for row-number based O(1) lookups is Row, which stores data in row-oriented blocks with ZSTD compression and supports fast random access by row number. - Recommended testing format is CSV, which has better readability but the worst read-write performance. - Recommended format for ML workloads is Lance, which is optimized for vector search and machine learning use cases. @@ -755,6 +757,60 @@ Limitations: 1. Lance file format does not support `MAP` type. 2. Lance file format does not support `TIMESTAMP_LOCAL_ZONE` type. 
+## VORTEX + +[Vortex](https://github.com/spiraldb/vortex) is a columnar file format that uses adaptive, data-dependent encodings to achieve high compression ratios while maintaining fast scan performance. It supports native predicate pushdown and efficient column projection. + +Key features: +- **Adaptive Encoding**: Automatically selects the best encoding per column based on data distribution +- **Native Predicate Pushdown**: Supports filter expressions pushed down to the scan layer +- **Column Projection**: Only reads requested columns from disk + +Limitations: +1. Vortex does not support `MAP` or `MULTISET` types. + +## MOSAIC + +[Mosaic](https://paimon.apache.org/docs/mosaic/) is a columnar-bucket hybrid format optimized for wide tables. It groups columns into buckets and compresses each bucket independently with ZSTD, enabling efficient column projection that only reads the buckets containing requested columns. + +Key features: +- **Column Bucketing**: Columns are grouped into configurable buckets for parallel I/O, significantly reducing read amplification on wide tables +- **Row Group Statistics**: Per-row-group min/max/null_count statistics enable row group skipping during scan +- **ZSTD Compression**: All data is compressed with ZSTD (configurable level) +- **Arrow-native**: Uses Apache Arrow as the in-memory representation for zero-copy integration + +Format Options: + + + + + + + + + + + + + + + + + + + + + + + + +
OptionDefaultTypeDescription
mosaic.num-buckets
autoIntegerNumber of column buckets for parallel I/O. When set to 0 or not specified, the format auto-determines the bucket count.
mosaic.stats-columns
(empty)StringComma-separated column names to collect min/max statistics for filter pushdown. Empty means no statistics are collected.
+ +Limitations: +1. Mosaic does not support complex types: ARRAY, MAP, MULTISET, ROW, VARIANT, BLOB, VECTOR. + +For more details, see the [Mosaic documentation](https://paimon.apache.org/docs/mosaic/). + ## ROW The Row format is a row-oriented storage format designed for O(1) random access by row number. Data is organized in blocks with ZSTD Level 1 compression. Each block contains complete rows serialized in a compact binary format with an offset array for direct row positioning. @@ -803,4 +859,4 @@ Limitations: 2. BLOB format does not support predicate pushdown. 3. Statistics collection is not supported for BLOB columns. -For usage details, configuration options, and examples, see [Blob Type](../../append-table/blob). +For usage details, configuration options, and examples, see [Blob Type](../../multimodal-table/blob). diff --git a/docs/docs/concepts/system-tables.mdx b/docs/docs/concepts/system-tables.mdx index 59ca5742440d..164d68a61cca 100644 --- a/docs/docs/concepts/system-tables.mdx +++ b/docs/docs/concepts/system-tables.mdx @@ -501,6 +501,40 @@ SELECT * FROM my_table$table_indexes; */ ``` +### Row Tracking Table + +If you need to query the unique row id assigned to each row in an append table, you can use the `row_tracking` system table. +The `row_tracking` table appends `_ROW_ID` and `_SEQUENCE_NUMBER` metadata columns to the original table schema. + +:::note + +The table must have `'row-tracking.enabled' = 'true'` set. This feature is only supported for append tables. + +::: + +```sql +SELECT * FROM my_table$row_tracking; + +/* ++----------+-----------+---------+------------------+ +| id | data | _ROW_ID | _SEQUENCE_NUMBER | ++----------+-----------+---------+------------------+ +| 11 | a | 0 | 1 | +| 22 | b | 1 | 1 | ++----------+-----------+---------+------------------+ +2 rows in set +*/ +``` + +- `_ROW_ID`: A globally unique row identifier within the table, assigned during write. 
+- `_SEQUENCE_NUMBER`: The sequence number (snapshot id) when the row was written. + +You can also select these columns directly from the original table (without using the system table) when row tracking is enabled: + +```sql +SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM my_table; +``` + ## Global System Table Global system tables contain the statistical information of all the tables exists in paimon. For convenient of searching, we create a reference system database called `sys`. diff --git a/docs/docs/flink/procedures.md b/docs/docs/flink/procedures.md index c1286f2be6c7..31601354f1a8 100644 --- a/docs/docs/flink/procedures.md +++ b/docs/docs/flink/procedures.md @@ -342,7 +342,7 @@ All available procedures are listed below. matched_update_set => 'matchedUpdateSet',
sink_parallelism => sinkParallelism)

- To perform "MERGE INTO" syntax specially implemented for data-evolution tables. Please see data evolution for more information. + To perform "MERGE INTO" syntax specially implemented for data-evolution tables. Please see data evolution for more information. -- for Flink 1.18
CALL [catalog].sys.data_evolution_merge_into('default.T', '', '', 'S', 'T.id=S.id', 'name=S.name', 2)

@@ -1004,22 +1004,40 @@ All available procedures are listed below. To create a global index on a table for accelerating queries. Arguments:
  • table(required): the target table identifier.
  • index_column(required): the column name to build index on.
  • -
  • index_type(required): the type of global index, supported types include 'bitmap', 'btree', 'lumina', 'tantivy-fulltext'.
  • +
  • index_type(required): the type of global index, supported types include 'btree', 'lumina', 'tantivy-fulltext'.
  • partitions(optional): partition filter for selective index creation.
  • options(optional): additional dynamic options for index creation.
  • - -- Create bitmap index
    + -- Create btree index
    CALL sys.create_global_index(
    `table` => 'default.T',
    `index_column` => 'name',
    - `index_type` => 'bitmap')

    + `index_type` => 'btree')

    -- Create index for specific partitions
    CALL sys.create_global_index(
    `table` => 'default.T',
    `index_column` => 'name',
    - `index_type` => 'bitmap',
    - `partitions` => 'pt=p1;pt=p2') + `index_type` => 'btree',
    + `partitions` => 'pt=p1;pt=p2')

    + -- Create Tantivy full-text index with ngram tokenizer
    + CALL sys.create_global_index(
    + `table` => 'default.T',
    + `index_column` => 'content',
    + `index_type` => 'tantivy-fulltext',
    + `options` => 'tantivy.tokenizer=ngram,tantivy.ngram.min-gram=2,tantivy.ngram.max-gram=2')

    + -- Create Tantivy full-text index with jieba tokenizer
    + CALL sys.create_global_index(
    + `table` => 'default.T',
    + `index_column` => 'content',
    + `index_type` => 'tantivy-fulltext',
    + `options` => 'tantivy.tokenizer=jieba')

    + -- Create Tantivy full-text index with a custom analyzer
    + CALL sys.create_global_index(
    + `table` => 'default.T',
    + `index_column` => 'content',
    + `index_type` => 'tantivy-fulltext',
    + `options` => 'tantivy.tokenizer=simple,tantivy.stem=true,tantivy.remove-stop-words=true') @@ -1035,20 +1053,20 @@ All available procedures are listed below. To drop global index files from a table. Arguments:
  • table(required): the target table identifier.
  • index_column(required): the column name for which to drop the index.
  • -
  • index_type(required): the type of global index to drop, e.g., 'bitmap', 'btree'.
  • +
  • index_type(required): the type of global index to drop, e.g., 'btree'.
  • partitions(optional): partition specification for selective index deletion.
  • - -- Drop all bitmap indexes for column 'name'
    + -- Drop all btree indexes for column 'name'
    CALL sys.drop_global_index(
    `table` => 'default.T',
    `index_column` => 'name',
    - `index_type` => 'bitmap')

    + `index_type` => 'btree')

    -- Drop indexes only for specific partitions
    CALL sys.drop_global_index(
    `table` => 'default.T',
    `index_column` => 'name',
    - `index_type` => 'bitmap',
    + `index_type` => 'btree',
    `partitions` => 'pt=p1;pt=p2') diff --git a/docs/docs/index.mdx b/docs/docs/index.mdx index 0eb7428d24bc..1561cfa3f369 100644 --- a/docs/docs/index.mdx +++ b/docs/docs/index.mdx @@ -31,6 +31,9 @@ under the License.

    Data Lake Platform — unified batch, streaming, and multimodal AI in a single lake format.

    +

    +Concepts · Configurations · Download +

    @@ -41,27 +44,27 @@ Data Lake Platform — unified batch, streaming, and multimodal AI in a single l
    Petabyte-scale tables with time travel, fast scan planning, schema evolution, and incremental clustering.
    - -
    📖
    +
    +
    📋
    -

    Concepts

    -

    Core architecture, data types, catalogs, and system tables

    +

    Append Table

    +

    Append-only tables, incremental clustering, and streaming append

    - -
    📋
    +
    +
    -

    Table w/o PK

    -

    Append-only tables, incremental clustering, blob & vector support

    +

    Spark

    +

    Quick start, SQL operations, DataFrames, structured streaming

    - +
    ⚙️

    Maintenance

    -

    Configuration, snapshots, tags, metrics, and performance tuning

    +

    Snapshots, tags, metrics, compaction, and performance tuning

    @@ -76,7 +79,7 @@ Data Lake Platform — unified batch, streaming, and multimodal AI in a single l
    🔑
    -

    Table with PK

    +

    PrimaryKey Table

    Merge engines, changelog producers, compaction, and streaming updates

    @@ -105,11 +108,11 @@ Data Lake Platform — unified batch, streaming, and multimodal AI in a single l
    Vector search, full-text search, blob tables, and native Python SDK for ML pipelines.
    - -
    +
    +
    🧩
    -

    Spark

    -

    Quick start, SQL operations, DataFrames, structured streaming

    +

    Multimodal Table

    +

    Data evolution, blob storage, vector storage, and global index

    diff --git a/docs/docs/learn-paimon/scenario-guide.mdx b/docs/docs/learn-paimon/scenario-guide.mdx index 877c055f9d8c..333c91689a2e 100644 --- a/docs/docs/learn-paimon/scenario-guide.mdx +++ b/docs/docs/learn-paimon/scenario-guide.mdx @@ -43,7 +43,7 @@ configurations that are suited for different scenarios. | High-frequency point queries on key | Append Table | `bucket = N, bucket-key = col` | | Queue-like ordered streaming | Append Table | `bucket = N, bucket-key = col` | | Large-scale OLAP with ad-hoc queries | Append Table | Incremental Clustering | -| Store images / videos / documents | Append Table (Blob) | `blob-field`, Data Evolution enabled | +| Store images / videos / documents | Append Table (Blob) | `__BLOB_FIELD` comment, Data Evolution enabled | | AI vector search / RAG | Append Table (Vector) | `VECTOR` type, Global Index (DiskANN) | | AI feature engineering & column evolution | Append Table | `data-evolution.enabled = true` | | Python AI pipeline (Ray / PyTorch) | Append Table | PyPaimon SDK | @@ -338,7 +338,7 @@ See [Bucketed Streaming](../append-table/bucketed#bucketed-streaming). Paimon is a multimodal lakehouse for AI. You can keep multimodal data, metadata, and embeddings in the same table and query them via vector search, full-text search, or SQL. All multimodal features are built on top of Append Tables with -[Data Evolution](../append-table/data-evolution) mode enabled. +[Data Evolution](../multimodal-table/data-evolution) mode enabled. 
### Scenario 8: Storing Multimodal Data (Blob Table) @@ -356,11 +356,10 @@ CREATE TABLE image_table ( id INT, name STRING, label STRING, - image BYTES + image BYTES COMMENT '__BLOB_FIELD' ) WITH ( 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-field' = 'image' + 'data-evolution.enabled' = 'true' ); ``` @@ -373,11 +372,10 @@ CREATE TABLE image_table ( id INT, name STRING, label STRING, - image BINARY + image BINARY COMMENT '__BLOB_FIELD' ) TBLPROPERTIES ( 'row-tracking.enabled' = 'true', - 'data-evolution.enabled' = 'true', - 'blob-field' = 'image' + 'data-evolution.enabled' = 'true' ); ``` @@ -385,7 +383,7 @@ CREATE TABLE image_table ( -**Why:** The [Blob Storage](../append-table/blob) separates large binary data into dedicated `.blob` files +**Why:** The [Blob Storage](../multimodal-table/blob) separates large binary data into dedicated `.blob` files while metadata stays in standard columnar files (Parquet/ORC). This means: - `SELECT id, name, label FROM image_table` does **not** load any blob data — very fast. @@ -398,12 +396,10 @@ while metadata stays in standard columnar files (Parquet/ORC). This means: CREATE TABLE video_table ( id INT, title STRING, - video BYTES + video BYTES COMMENT '__BLOB_EXTERNAL_STORAGE_FIELD' ) WITH ( 'row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true', - 'blob-descriptor-field' = 'video', - 'blob-external-storage-field' = 'video', 'blob-external-storage-path' = 's3://my-bucket/paimon-blobs/' ); ``` @@ -426,13 +422,11 @@ CREATE TABLE doc_embeddings ( doc_id INT, title STRING, content STRING, - embedding ARRAY + embedding ARRAY COMMENT '__VECTOR_FIELD;768' ) TBLPROPERTIES ( 'row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true', 'global-index.enabled' = 'true', - 'vector-field' = 'embedding', - 'field.embedding.vector-dim' = '768', 'vector.file.format' = 'lance' ); ``` @@ -476,7 +470,7 @@ SELECT * FROM vector_search('doc_embeddings', 'embedding', array(0.1f, 0.2f, ... 
The legacy index type `lumina-vector-ann` is still accepted for existing tables and SQL compatibility. -**Why:** The [Global Index](../append-table/global-index) with DiskANN provides high-performance ANN search. +**Why:** The [Global Index](../multimodal-table/global-index) with DiskANN provides high-performance ANN search. Vector data is stored in dedicated `.vector.lance` files optimized for dense vectors, while scalar columns stay in Parquet. You can also build a **BTree Index** on scalar columns for efficient filtering: @@ -524,7 +518,7 @@ ON t.user_id = s.user_id WHEN MATCHED THEN UPDATE SET t.embedding = s.embedding; ``` -**Why:** [Data Evolution](../append-table/data-evolution) mode writes only the updated columns to new files +**Why:** [Data Evolution](../multimodal-table/data-evolution) mode writes only the updated columns to new files and merges them at read time. This is ideal for: - Adding new feature columns and backfilling data without rewriting the entire table. @@ -610,7 +604,7 @@ Do you need upsert / update / delete? │ → bucket = N, bucket-key = col (Bucketed Append) │ └── AI / Multimodal scenarios? → Enable Data Evolution - ├── Store images / videos / docs? → Blob Table (blob-field) + ├── Store images / videos / docs? → Blob Table (__BLOB_FIELD comment) ├── Vector search / RAG? → VECTOR type + Global Index (DiskANN) ├── Feature engineering? → Data Evolution (MERGE INTO partial columns) └── Python pipeline? → PyPaimon (Ray / PyTorch / Pandas) diff --git a/docs/docs/maintenance/metrics.md b/docs/docs/maintenance/metrics.md index f6e1576dead2..ef764c13e148 100644 --- a/docs/docs/maintenance/metrics.md +++ b/docs/docs/maintenance/metrics.md @@ -60,6 +60,11 @@ Below is lists of Paimon built-in metrics. They are summarized into types of sca Histogram Distributions of the time taken by the last few scans. + + lastScannedSnapshotId + Gauge + The snapshot ID scanned in the last scan. 0 if no scan has occurred. 
+ lastScannedManifests Gauge @@ -179,6 +184,11 @@ Below is lists of Paimon built-in metrics. They are summarized into types of sca Gauge Total size of the output files for the last compaction. + + lastCommittedSnapshotId + Gauge + The snapshot ID created by the last commit. -1 if no commit has occurred. + diff --git a/docs/docs/migration/migration-from-hive.mdx b/docs/docs/migration/migration-from-hive.mdx index 6e8485593704..92fa84c19350 100644 --- a/docs/docs/migration/migration-from-hive.mdx +++ b/docs/docs/migration/migration-from-hive.mdx @@ -37,7 +37,7 @@ At the same time, you can use paimon hive catalog with Migrate Database Procedur * Migrate Table Procedure: Paimon table does not exist, use the procedure upgrade hive table to paimon table. Hive table will disappear after action done. * Migrate Database Procedure: Paimon table does not exist, use the procedure upgrade all hive tables in database to paimon table. All hive tables will disappear after action done. -These three actions now support file format of hive "orc" and "parquet" and "avro". +These two actions now support the Hive file formats "orc", "parquet", and "avro". **We highly recommend to back up hive table data before migrating, because migrating action is not atomic. If been interrupted while migrating, you may lose your data.** diff --git a/docs/docs/multimodal-table/blob.mdx b/docs/docs/multimodal-table/blob.mdx new file mode 100644 index 000000000000..55fd0621f2a0 --- /dev/null +++ b/docs/docs/multimodal-table/blob.mdx @@ -0,0 +1,605 @@ +--- +title: "Blob Storage" +sidebar_position: 7 +--- + +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + + +# Blob Storage + +## Overview + +The `BLOB` (Binary Large Object) type is a data type designed for storing multimodal data such as images, videos, audio files, and other large binary objects in Paimon tables.
Unlike traditional `BYTES` type which stores binary data inline with other columns, `BLOB` type stores large binary data in separate files and maintains references to them, providing better performance for large objects. + +The Blob Storage is based on Data Evolution mode. + +The Blob type is ideal for: + +- **Image Storage**: Store product images, user avatars, medical imaging data +- **Video Content**: Store video clips, surveillance footage, multimedia content +- **Audio Files**: Store voice recordings, music files, podcast episodes +- **Document Storage**: Store PDF documents, office files, large text files +- **Machine Learning**: Store embeddings, model weights, feature vectors +- **Any Large Binary Data**: Any data that is too large to store efficiently inline + +## Storage Layout + +When you define a table with a Blob column, Paimon automatically separates the storage: + +1. **Normal Data Files** (e.g., `.parquet`, `.orc`): Store regular columns (INT, STRING, etc.) +2. **Blob Data Files** (`.blob`): Store the actual blob data + +For example, given a table with schema `(id INT, name STRING, picture BLOB)`: + +``` +table/ +├── bucket-0/ +│ ├── data-uuid-0.parquet # Contains id, name columns +│ ├── data-uuid-1.blob # Contains picture blob data +│ ├── data-uuid-2.blob # Contains more picture blob data +│ └── ... +├── manifest/ +├── schema/ +└── snapshot/ +``` + +This separation provides several benefits: +- Efficient column projection (reading non-blob columns doesn't load blob data) +- Optimized file rolling based on blob size +- Better compression for regular columnar data + +For details about the blob file format structure, see [File Format - BLOB](../concepts/spec/fileformat#blob). + +## Storage Modes + +Paimon supports four storage modes for BLOB fields, selected via **comment directives** on the column: + +1. **Default blob storage** (`__BLOB_FIELD`) + Blob bytes are written to Paimon-managed `.blob` files under the table path. + +2. 
**Descriptor-only storage** (`__BLOB_DESCRIPTOR_FIELD`) + Only serialized `BlobDescriptor` bytes are stored inline in data files. Paimon does not write `.blob` files for these fields, and writes must provide descriptor-based input. + +3. **External-storage descriptor mode** (`__BLOB_EXTERNAL_STORAGE_FIELD`) + At write time, Paimon writes the raw blob data to the configured `blob-external-storage-path` and stores only serialized `BlobDescriptor` bytes inline in data files. + +4. **Blob view storage** (`__BLOB_VIEW_FIELD`) + Serialized `BlobViewStruct` bytes are stored inline. The struct points to a BLOB value in an upstream table by table identifier, BLOB field, and row id. The actual blob bytes are resolved from the upstream table at read time. + +This allows one table to mix different storage modes for different BLOB columns. + +## Table Options + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Option | Required | Default | Type | Description
    blob-as-descriptor
    No | false | Boolean | Controls read output format for blob fields. When set to true, queries return serialized BlobDescriptor bytes; when false, queries return actual blob bytes. This option is dynamic and can be changed with ALTER TABLE ... SET.
    blob-write-null-on-missing-file
    No | false | Boolean | + When enabled for Flink writes, if a descriptor BLOB value references a file that does not exist, Paimon writes NULL for that value and logs a warning instead of failing when reading the descriptor. +
    blob-view.resolve.enabled
    No | true | Boolean | + Controls whether blob view fields are resolved to the upstream BLOB + content at read time. Set to false when forwarding blob view + references from one view table to another. +
    blob-external-storage-path
    No | (none) | String | + External storage path for fields declared with __BLOB_EXTERNAL_STORAGE_FIELD. + Orphan file cleanup is not applied to this path. +
    blob.target-file-size
    No | (same as target-file-size) | MemorySize | Target size for blob files. When a blob file reaches this size, a new file is created. If not specified, uses the same value as target-file-size.
    row-tracking.enabled
    Yes* | false | Boolean | Must be enabled for blob tables to support row-level operations.
    data-evolution.enabled
    Yes* | false | Boolean | Must be enabled for blob tables to support schema evolution.
    + +*Required for blob functionality to work correctly. + +Specifically, if the storage system of the input BlobDescriptor differs from that used by Paimon, +you can specify the storage configuration for the input blob descriptor using the prefix +`blob-descriptor.`. For example, if the source data is stored in a different OSS endpoint, +you can configure it as below (using Flink SQL as an example): +```sql +CREATE TABLE image_table ( + id INT, + name STRING, + image BYTES COMMENT '__BLOB_FIELD' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true', + 'fs.oss.endpoint' = 'aaa', -- This is for Paimon's own config + 'blob-descriptor.fs.oss.endpoint' = 'bbb' -- This is for input blob descriptors' config +); +``` + +## Creating a Table + +The recommended way to create a blob table in SQL is to use one of the **comment directives** `__BLOB_FIELD`, `__BLOB_DESCRIPTOR_FIELD`, `__BLOB_EXTERNAL_STORAGE_FIELD`, or `__BLOB_VIEW_FIELD` on the column. Paimon automatically converts the column type to `BLOB` and registers it in the corresponding option.
+ + + + + +```sql +CREATE TABLE image_table ( + id INT, + name STRING, + image BYTES COMMENT '__BLOB_FIELD; product image' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); + +-- Multiple blob columns with different storage modes +CREATE TABLE media_table ( + id INT, + photo BYTES COMMENT '__BLOB_FIELD; original photo', + thumbnail BYTES COMMENT '__BLOB_DESCRIPTOR_FIELD; thumbnail descriptor', + preview BYTES COMMENT '__BLOB_VIEW_FIELD; preview from upstream' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); +``` + + + + + +```sql +CREATE TABLE image_table ( + id INT, + name STRING, + image BINARY COMMENT '__BLOB_FIELD; product image' +) TBLPROPERTIES ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); +``` + + + + + +```java +// Java API uses BlobType directly, with comment directive for auto option registration +Schema schema = Schema.newBuilder() + .column("id", DataTypes.INT()) + .column("name", DataTypes.STRING()) + .column("image", DataTypes.BYTES(), "__BLOB_FIELD; product image") + .option("row-tracking.enabled", "true") + .option("data-evolution.enabled", "true") + .build(); +``` + + + + + +```python +import pyarrow as pa +from pypaimon import Schema + +# pa.large_binary() is automatically recognized as BLOB type +pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('image', pa.large_binary()), +]) + +schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } +) +``` + + + + + +The comment directive format is `__DIRECTIVE; optional user comment`. Paimon converts the `BYTES`/`BINARY` type to `BLOB`, registers the field in the corresponding option, and stores the text after `;` as the column's real comment. 
+ +Supported directives: + +| Directive | Storage mode | Option | +|-----------|-------------|--------| +| `__BLOB_FIELD` | Raw bytes in `.blob` files | `blob-field` | +| `__BLOB_DESCRIPTOR_FIELD` | Descriptor bytes inline | `blob-descriptor-field` | +| `__BLOB_VIEW_FIELD` | View reference inline | `blob-view-field` | +| `__BLOB_EXTERNAL_STORAGE_FIELD` | Raw data to external path, descriptor inline | `blob-external-storage-field` + `blob-descriptor-field` | + +## Adding a Blob Column + +The same comment directive works with `ALTER TABLE ADD COLUMN`: + + + + + +```sql +ALTER TABLE image_table ADD picture BYTES COMMENT '__BLOB_FIELD'; + +ALTER TABLE image_table + ADD video BYTES COMMENT '__BLOB_DESCRIPTOR_FIELD; promotional video'; +``` + + + + + +```sql +ALTER TABLE image_table ADD COLUMN picture BINARY COMMENT '__BLOB_FIELD'; + +ALTER TABLE image_table + ADD COLUMN video BINARY COMMENT '__BLOB_DESCRIPTOR_FIELD; promotional video'; +``` + + + + + +```java +// Java API: add column with BlobType directly +schemaManager.commitChanges( + SchemaChange.addColumn("picture", DataTypes.BLOB())); + +// Or use comment directive like SQL +schemaManager.commitChanges( + SchemaChange.addColumn("video", DataTypes.BYTES(), + "__BLOB_DESCRIPTOR_FIELD; promotional video", null)); +``` + + + + + +```python +# pa.large_binary() is automatically recognized as BLOB type +catalog.alter_table( + 'default.image_table', + [('add', 'picture', pa.large_binary())] +) +``` + + + + + +## Inserting Blob Data + + + + + +```sql +INSERT INTO image_table VALUES (1, 'sample', X'89504E470D0A1A0A'); + +INSERT INTO image_table +SELECT id, name, content FROM source_table; +``` + + + + + +```sql +INSERT INTO image_table VALUES (1, 'sample', X'89504E470D0A1A0A'); +``` + + + + + +```java +GenericRow row = GenericRow.of( + 1, + BinaryString.fromString("sample"), + new BlobData(imageBytes) +); +write.write(row); +``` + + + + + +```python +data = pa.table({ + 'id': pa.array([1], type=pa.int32()), + 'name': 
pa.array(['sample']), + 'image': pa.array([b'\x89PNG\r\n\x1a\n'], type=pa.large_binary()), +}) +writer.write_arrow(data) +``` + + + + + +## Querying Blob Data + + + + + +```sql +-- Select all columns including blob +SELECT * FROM image_table; + +-- Select only non-blob columns (efficient - doesn't load blob data) +SELECT id, name FROM image_table; + +-- Return descriptor bytes instead of actual blob bytes +ALTER TABLE image_table SET ('blob-as-descriptor' = 'true'); +SELECT image FROM image_table; +``` + + + + + +```java +Blob blob = row.getBlob(2); + +// Load into memory +byte[] data = blob.toData(); + +// Stream (recommended for large blobs) +try (SeekableInputStream in = blob.newInputStream()) { + in.seek(100); // random access + byte[] buffer = new byte[1024]; + int bytesRead = in.read(buffer); +} + +// Get descriptor reference (for descriptor-based blobs) +BlobDescriptor descriptor = blob.toDescriptor(); +``` + + + + + +## Blob Construct Sources (Java API) + +```java +Blob blob = Blob.fromData(imageBytes); // byte array +Blob blob = Blob.fromLocal("/path/to/image.png"); // local file +Blob blob = Blob.fromFile(fileIO, "s3://bucket/path/to/image.png"); // any FileIO +Blob blob = Blob.fromFile(fileIO, "s3://bucket/large-file.bin", 1024, 2048); // partial file +Blob blob = Blob.fromHttp("https://example.com/image.png"); // HTTP URL +Blob blob = Blob.fromInputStream(() -> new FileInputStream("...")); // InputStream +Blob blob = Blob.fromDescriptor(uriReader, descriptor); // BlobDescriptor +``` + +## Descriptor-Only Storage + +If you want downstream tables to **reuse** upstream blob files (no copying and no new `.blob` files), use `__BLOB_DESCRIPTOR_FIELD`: + +```sql +CREATE TABLE descriptor_table ( + id INT, + image BYTES COMMENT '__BLOB_DESCRIPTOR_FIELD; reused image' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); +``` + +Paimon stores only serialized `BlobDescriptor` bytes in normal data files. 
Reading the blob follows the descriptor URI to access bytes, and writing requires descriptor input for those fields. + +## External Storage + +If you want Paimon to write raw blob data to a separate external location while keeping only descriptor bytes inline, use `__BLOB_EXTERNAL_STORAGE_FIELD`: + +```sql +CREATE TABLE external_table ( + id INT, + image BYTES COMMENT '__BLOB_EXTERNAL_STORAGE_FIELD' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true', + 'blob-external-storage-path' = 'oss://bucket/path/' +); +``` + +For these fields: + +- raw blob data is written to the configured external storage path +- normal data files keep only serialized `BlobDescriptor` bytes +- writes can still start from raw BLOB input +- the field is treated as descriptor-based for operations such as `MERGE INTO` + +## Blob View + +Blob view is useful when a downstream table should reference BLOB values already stored in an upstream table, without copying the bytes or creating new `.blob` files. A blob view field stores only a small `BlobViewStruct` inline. When the field is read, Paimon resolves the referenced BLOB from the upstream table. + +Blob view requires: + +- the upstream table to have row tracking enabled, so each row has a stable `_ROW_ID` +- the downstream field to be declared with `__BLOB_VIEW_FIELD` comment directive +- writes to provide a serialized `BlobViewStruct`; in Flink SQL, use the built-in `sys.blob_view` function + +The Flink SQL function signature is: + +```sql +sys.blob_view(table_name, field_name, row_id) +``` + +Arguments: + +- `table_name`: the upstream table name. It must be fully qualified as `database.table` or `catalog.database.table`. Unqualified table names are rejected. +- `field_name`: the upstream BLOB field name. +- `row_id`: the `_ROW_ID` value from the upstream row-tracking table. 
+ +The following example writes a downstream table whose `image_ref` field views the `image` field in `image_table`: + +```sql +CREATE TABLE image_table ( + id INT, + name STRING, + image BYTES COMMENT '__BLOB_FIELD' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); + +CREATE TABLE image_view_table ( + id INT, + label STRING, + image_ref BYTES COMMENT '__BLOB_VIEW_FIELD' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); + +INSERT INTO image_view_table +SELECT + id, + name AS label, + sys.blob_view('default.image_table', 'image', _ROW_ID) +FROM `image_table$row_tracking`; +``` + +If the current Paimon catalog name is included in the table name, the function also accepts `catalog.database.table`: + +```sql +SELECT sys.blob_view('my_catalog.default.image_table', 'image', _ROW_ID) +FROM `image_table$row_tracking`; +``` + +Reads from `image_view_table.image_ref` return the referenced BLOB bytes in the same way as normal blob fields. The referenced upstream table and row must remain available for the view to be resolved. + +**Forward Blob View References** + +By default, reading a blob view field resolves the `BlobViewStruct` and returns the upstream BLOB +content. If you want to import data from one blob view table into another blob view table without +copying the BLOB bytes, read the source table with `blob-view.resolve.enabled=false` and write the +result into a target field declared with `__BLOB_VIEW_FIELD`. + +With this option disabled, Paimon preserves the serialized `BlobViewStruct` during reads. When the +preserved value is written to another blob view field, the target table stores the same upstream +reference instead of creating a chained view reference. + +For example, if table `T1` contains blob view references to BLOBs in table `T0`, importing `T1` into +`T2` with `blob-view.resolve.enabled=false` makes `T2` keep referencing `T0` directly. 
+ +```sql +CREATE TABLE t2 ( + id INT, + image_ref BYTES COMMENT '__BLOB_VIEW_FIELD' +) WITH ( + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); + +-- Flink SQL example: the source table is read with blob view resolution disabled. +INSERT INTO t2 +SELECT id, image_ref +FROM t1 /*+ OPTIONS('blob-view.resolve.enabled'='false') */; +``` + +## MERGE INTO Support + +For Data Evolution writes in Flink and Spark: + +- raw-data BLOB columns are still rejected in partial-column `MERGE INTO` updates +- descriptor-based BLOB columns are allowed + +For the Python equivalent, see [Blob Storage in pypaimon](../pypaimon/blob). + +## Limitations + +1. **Append Table Only**: Blob type is designed for append-only tables. Primary key tables are not supported. +2. **No Predicate Pushdown**: Blob columns cannot be used in filter predicates. +3. **No Statistics**: Statistics collection is not supported for blob columns. +4. **Required Options**: `row-tracking.enabled` and `data-evolution.enabled` must be set to `true`. +5. **External Storage Cleanup**: Files written through `blob-external-storage-path` are outside Paimon's orphan file cleanup scope. +6. **Blob View Dependency**: Blob view fields depend on the referenced upstream table and row. If the upstream data is removed or no longer readable, the view cannot be resolved. + +## Best Practices + +1. **Use Column Projection**: Always select only the columns you need. Avoid `SELECT *` if you don't need blob data. + +2. **Set Appropriate Target File Size**: Configure `blob.target-file-size` based on your blob sizes. Larger values mean fewer files but larger individual files. + +3. **Use Descriptor Fields When Reusing External Blob Files**: Use `__BLOB_DESCRIPTOR_FIELD` for fields that should keep descriptor references instead of writing new `.blob` files. + +4. 
**Use External-Storage Fields When Accepting Raw Input But Storing Descriptors**: Use `__BLOB_EXTERNAL_STORAGE_FIELD` with `blob-external-storage-path` when upstream writes raw blob bytes but you want descriptor-based storage. + +5. **Manage External Storage Lifecycle Separately**: Files written to `blob-external-storage-path` are not cleaned up by Paimon, so retention and deletion should be managed externally. + +6. **Use Blob View to Avoid Copying BLOB Data**: Use `__BLOB_VIEW_FIELD` when a downstream table only needs to reference BLOB values from an upstream table. + +7. **Use Partitioning**: Partition your blob tables by date or other dimensions to improve query performance and data management. diff --git a/docs/docs/append-table/data-evolution.md b/docs/docs/multimodal-table/data-evolution.md similarity index 100% rename from docs/docs/append-table/data-evolution.md rename to docs/docs/multimodal-table/data-evolution.md diff --git a/docs/docs/append-table/global-index.mdx b/docs/docs/multimodal-table/global-index.mdx similarity index 72% rename from docs/docs/append-table/global-index.mdx rename to docs/docs/multimodal-table/global-index.mdx index 4a8cc0877855..8edcfe87e2ad 100644 --- a/docs/docs/append-table/global-index.mdx +++ b/docs/docs/multimodal-table/global-index.mdx @@ -187,6 +187,63 @@ CALL sys.create_global_index( ); ``` +For other content where users often search by short character fragments, build the +index with Tantivy's `ngram` tokenizer: + +```sql +CALL sys.create_global_index( + table => 'db.my_table', + index_column => 'content', + index_type => 'tantivy-fulltext', + options => 'tantivy.tokenizer=ngram,tantivy.ngram.min-gram=2,tantivy.ngram.max-gram=2' +); +``` + +For Chinese word segmentation, build the index with the `jieba` tokenizer: + +```sql +CALL sys.create_global_index( + table => 'db.my_table', + index_column => 'content', + index_type => 'tantivy-fulltext', + options => 'tantivy.tokenizer=jieba' +); +``` + +For LanceDB-style 
analyzer customization, choose a base tokenizer and compose token filters: + +```sql +CALL sys.create_global_index( + table => 'db.my_table', + index_column => 'content', + index_type => 'tantivy-fulltext', + options => 'tantivy.tokenizer=simple,tantivy.stem=true,tantivy.remove-stop-words=true' +); +``` + +Supported tokenizer options: + +| Option | Default | Description | +|---|---|---| +| `tantivy.tokenizer` | `default` | Tokenizer used by the full-text index. Supported values: `default`, `simple`, `whitespace`, `raw`, `ngram`, `jieba`. | +| `tantivy.ngram.min-gram` | `2` | Minimum gram length for the `ngram` tokenizer. | +| `tantivy.ngram.max-gram` | `2` | Maximum gram length for the `ngram` tokenizer. | +| `tantivy.ngram.prefix-only` | `false` | Whether the `ngram` tokenizer only emits prefix ngrams. | +| `tantivy.lower-case` | `true` | Whether configurable tokenizers lowercase emitted tokens. | +| `tantivy.max-token-length` | `40` | Maximum token length kept by configurable tokenizers. | +| `tantivy.ascii-folding` | `false` | Whether to normalize non-ASCII Latin characters to ASCII. | +| `tantivy.stem` | `false` | Whether to apply stemming to emitted tokens. | +| `tantivy.language` | `english` | Language used by stemming and built-in stop word filters. | +| `tantivy.remove-stop-words` | `false` | Whether to remove built-in stop words for the configured language. | +| `tantivy.stop-words` | ` ` | Semicolon-separated custom stop words to remove. | +| `tantivy.with-position` | `true` | Whether to store term positions for phrase queries. Disable it to reduce index size when phrase queries are not needed. | + +Tokenizer settings are persisted in global index metadata. Existing index files keep using the +tokenizer they were built with, even if later index builds use different options. +Paimon does not load arbitrary Rust tokenizer plugins from configuration; custom analysis is +provided by composing the supported tokenizer and filter options above. 
PyPaimon can query +`jieba` indexes when the Python `jieba` package is installed. + **Full-Text Search** @@ -208,6 +265,7 @@ Table table = catalog.getTable(identifier); // Step 1: Build full-text search GlobalIndexResult result = table.newFullTextSearchBuilder() .withQueryText("paimon lake format") + .withQueryOperator("and") .withLimit(10) .withTextColumn("content") .executeLocal(); @@ -233,6 +291,7 @@ table = catalog.get_table('db.my_table') builder = table.new_full_text_search_builder() builder.with_text_column('content') builder.with_query_text('paimon lake format') +builder.with_query_operator('and') builder.with_limit(10) result = builder.execute_local() diff --git a/docs/docs/multimodal-table/index.mdx b/docs/docs/multimodal-table/index.mdx new file mode 100644 index 000000000000..4e0f42a38712 --- /dev/null +++ b/docs/docs/multimodal-table/index.mdx @@ -0,0 +1,47 @@ +--- +title: "Multimodal Table" +sidebar_position: 4 +--- + +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + + +# Overview + +Multimodal Table extends the Append Table with capabilities for storing and querying multimodal data — images, videos, +audio, vectors, and full-text content — all within a single table. It is built on top of the +[Data Evolution](./data-evolution) mode, which enables efficient partial column updates and schema evolution without +rewriting entire data files. + +Key capabilities: + +- **[Data Evolution](./data-evolution)**: Update partial columns without rewriting entire files, enabling efficient schema evolution. +- **[Blob Storage](./blob)**: Store large binary objects (images, videos, audio) in dedicated `.blob` files with efficient column projection. +- **[Vector Storage](./vector)**: Store and manage vector embeddings in dedicated Vortex-format files optimized for vector workloads. +- **[Global Index](./global-index)**: Build BTree, vector (DiskANN), and full-text (Tantivy) indexes for efficient lookups and similarity search. 
+ +All multimodal features require the following table properties: + +```sql +'row-tracking.enabled' = 'true', +'data-evolution.enabled' = 'true' +``` diff --git a/docs/docs/append-table/vector.mdx b/docs/docs/multimodal-table/vector.mdx similarity index 67% rename from docs/docs/append-table/vector.mdx rename to docs/docs/multimodal-table/vector.mdx index c957d64b5f55..8cbc905e17e4 100644 --- a/docs/docs/append-table/vector.mdx +++ b/docs/docs/multimodal-table/vector.mdx @@ -65,8 +65,6 @@ table/ └── snapshot/ ``` -### Configuration - | Option | Description | |--------|-------------| | `vector.file.format` | File format for dedicated vector files. Set to `vortex` to enable dedicated vector storage. | @@ -74,42 +72,48 @@ table/ | `row-tracking.enabled` | Must be `true` for dedicated vector storage. | | `data-evolution.enabled` | Must be `true` for dedicated vector storage. | -## Usage Examples +## Create Table -### Create Table +The recommended way to create a vector table in SQL is to use the **comment directive** `__VECTOR_FIELD;dim` on the column. Paimon automatically converts the `ARRAY` type to `VECTOR` and registers the field in the `vector-field` option. 
```sql -CREATE TABLE IF NOT EXISTS vector_table ( +-- Comment directive: __VECTOR_FIELD;{dim}; optional comment +CREATE TABLE vector_table ( id BIGINT, - embed ARRAY + embed ARRAY COMMENT '__VECTOR_FIELD;128; product embedding' +) WITH ( + 'vector.file.format' = 'vortex', + 'row-tracking.enabled' = 'true', + 'data-evolution.enabled' = 'true' +); + +-- Multiple vector columns +CREATE TABLE multi_vector_table ( + id BIGINT, + embed1 ARRAY COMMENT '__VECTOR_FIELD;128', + embed2 ARRAY COMMENT '__VECTOR_FIELD;768' ) WITH ( 'vector.file.format' = 'vortex', - 'vector-field' = 'embed', - 'field.embed.vector-dim' = '128', 'row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true' ); ``` -Since engine layers typically don't have dedicated vector types, Paimon provides configuration to convert the engine's `ARRAY` type to Paimon's `VECTOR` type: - - **`vector-field`**: Declare columns as `VECTOR` type, multiple columns separated by commas (`,`); - - **`field.{field-name}.vector-dim`**: Declare the dimension of the vector column. 
-Multiple vector columns example: + + + + ```sql -CREATE TABLE IF NOT EXISTS multi_vector_table ( +CREATE TABLE vector_table ( id BIGINT, - embed1 ARRAY, - embed2 ARRAY -) WITH ( + embed ARRAY COMMENT '__VECTOR_FIELD;128; product embedding' +) TBLPROPERTIES ( 'vector.file.format' = 'vortex', - 'vector-field' = 'embed1,embed2', - 'field.embed1.vector-dim' = '128', - 'field.embed2.vector-dim' = '768', 'row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true' ); @@ -117,18 +121,10 @@ CREATE TABLE IF NOT EXISTS multi_vector_table ( - + ```java -import org.apache.paimon.catalog.*; -import org.apache.paimon.data.*; -import org.apache.paimon.schema.Schema; -import org.apache.paimon.table.FileStoreTable; -import org.apache.paimon.table.sink.*; -import org.apache.paimon.table.source.*; -import org.apache.paimon.types.DataTypes; - -// Build schema with VECTOR column +// Java API uses VectorType directly — no comment directive needed Schema schema = Schema.newBuilder() .column("id", DataTypes.BIGINT()) .column("embed", DataTypes.VECTOR(128, DataTypes.FLOAT())) @@ -137,29 +133,20 @@ Schema schema = Schema.newBuilder() .option("data-evolution.enabled", "true") .option("bucket", "-1") .build(); - -// Create table -Identifier identifier = Identifier.create("default", "vector_table"); -catalog.createTable(identifier, schema, false); -FileStoreTable table = (FileStoreTable) catalog.getTable(identifier); ``` - + ```python import pyarrow as pa -from pypaimon import CatalogFactory, Schema - -# Create catalog -catalog = CatalogFactory.create({'warehouse': '/path/to/warehouse'}) -catalog.create_database('default', True) +from pypaimon import Schema -# Define schema with vector column (fixed-size list in PyArrow) +# Fixed-size list is automatically recognized as VECTOR type pa_schema = pa.schema([ ('id', pa.int64()), - ('embed', pa.list_(pa.float32(), 128)), # VECTOR(128, FLOAT) + ('embed', pa.list_(pa.float32(), 128)), ]) schema = Schema.from_pyarrow_schema( @@ -171,25 +158,88 
@@ schema = Schema.from_pyarrow_schema( 'bucket': '-1', } ) +``` + + + + + +## Adding a Vector Column + + + + + +```sql +ALTER TABLE vector_table + ADD embed2 ARRAY COMMENT '__VECTOR_FIELD;768; text embedding'; +``` + + + + + +```sql +ALTER TABLE vector_table + ADD COLUMN embed2 ARRAY COMMENT '__VECTOR_FIELD;768; text embedding'; +``` -catalog.create_table('default.vector_table', schema, False) + + + + +```java +// Java API: add column with VectorType directly +schemaManager.commitChanges( + SchemaChange.addColumn("embed2", DataTypes.VECTOR(768, DataTypes.FLOAT()))); + +// Or use comment directive like SQL +schemaManager.commitChanges( + SchemaChange.addColumn("embed2", DataTypes.ARRAY(DataTypes.FLOAT()), + "__VECTOR_FIELD;768; text embedding", null)); +``` + + + + + +```python +catalog.alter_table( + 'default.vector_table', + [('add', 'embed2', pa.list_(pa.float32(), 768))] +) ``` -### Write Data +## Write Data - + + +```sql +INSERT INTO vector_table VALUES (1, ARRAY[1.0, 2.0, 3.0, ...]); +``` + + + + + +```sql +INSERT INTO vector_table VALUES (1, ARRAY(1.0, 2.0, 3.0, ...)); +``` + + + + ```java BatchWriteBuilder builder = table.newBatchWriteBuilder(); try (BatchTableWrite write = builder.newWrite(); BatchTableCommit commit = builder.newCommit()) { - // Create a vector using BinaryVector InternalVector vector = BinaryVector.fromPrimitiveArray(new float[] {1.0f, 2.0f, 3.0f}); write.write(GenericRow.of(1L, vector)); commit.commit(write.prepareCommit()); @@ -198,14 +248,11 @@ try (BatchTableWrite write = builder.newWrite(); - + ```python import pyarrow as pa -table = catalog.get_table('default.vector_table') - -# Prepare vector data as PyArrow FixedSizeListArray data = pa.table({ 'id': pa.array([1, 2, 3], type=pa.int64()), 'embed': pa.FixedSizeListArray.from_arrays( @@ -226,11 +273,27 @@ writer.close() -### Read Data +## Read Data - + + +```sql +SELECT id, embed FROM vector_table; +``` + + + + + +```sql +SELECT id, embed FROM vector_table; +``` + + + + ```java 
ReadBuilder readBuilder = table.newReadBuilder(); @@ -246,16 +309,13 @@ try (RecordReader reader = readBuilder.newRead().createReader(plan) - + ```python -table = catalog.get_table('default.vector_table') - read_builder = table.new_read_builder() splits = read_builder.new_scan().plan().splits() result = read_builder.new_read().to_arrow(splits) -# Result contains vector columns as PyArrow FixedSizeListArray print(result.column('embed').to_pylist()) # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] ``` diff --git a/docs/docs/primary-key-table/chain-table.md b/docs/docs/primary-key-table/chain-table.md index b5890d428b5d..b061efd36a1c 100644 --- a/docs/docs/primary-key-table/chain-table.md +++ b/docs/docs/primary-key-table/chain-table.md @@ -204,3 +204,67 @@ partition keys: This treats `(dt, hour)` as the composite chain dimension and everything before it (e.g., `region`) as the group dimension. + +## Partition Expiration + +Chain tables support automatic partition expiration via the standard `partition.expiration-time` option. +However, the expiration algorithm differs from normal tables to preserve chain integrity. + +### How It Works + +In a normal table, every partition older than the cutoff (`now - partition.expiration-time`) is dropped +independently. Chain tables cannot do this because a delta partition depends on its nearest earlier +snapshot partition as an anchor for merge-on-read. Dropping the anchor would break the chain. + +Chain table expiration works in **segments**. A segment consists of one snapshot partition and all the +delta partitions whose time falls between that snapshot and the next snapshot in sorted order. The +segment is the atomic unit of expiration: either the entire segment is expired, or nothing in it is. + +The algorithm per group: +1. List all snapshot branch partitions sorted by chain partition time. +2. Filter to those before the cutoff (`now - partition.expiration-time`). +3. 
If fewer than 2 snapshots are before the cutoff, nothing can be expired — a lone snapshot before the cutoff must be kept
+- **Conflict detection is anchor-aware.** When `partition.expiration-strategy` is `values-time`, the + conflict detection during writes correctly recognizes that anchor partitions are retained and does not + reject writes to them. +- The `partition.expiration-time` and `partition.expiration-check-interval` options should be set + consistently across the main table and both branches. diff --git a/docs/docs/primary-key-table/index.md b/docs/docs/primary-key-table/index.md index a912bc6ac73f..ec48c1277ef5 100644 --- a/docs/docs/primary-key-table/index.md +++ b/docs/docs/primary-key-table/index.md @@ -1,5 +1,5 @@ --- -title: "Table with PK" +title: "PrimaryKey Table" sidebar_position: 2 --- diff --git a/docs/docs/primary-key-table/merge-engine/aggregation.mdx b/docs/docs/primary-key-table/merge-engine/aggregation.mdx index cb8829720584..96e6e4f3f076 100644 --- a/docs/docs/primary-key-table/merge-engine/aggregation.mdx +++ b/docs/docs/primary-key-table/merge-engine/aggregation.mdx @@ -311,6 +311,8 @@ public static class BitmapContainsUDF extends ScalarFunction { Use `fields..nested-key=pk0,pk1,...` to specify the primary keys of the nested table. If no keys, row will be appended to array\. + Use `fields..nested-sequence-field=seq0,seq1,...` to control the update sequence of a nested table, you must configure `fields..nested-key` when using it. + Use `fields..count-limit=` to specify the maximum number of rows in the nested table. When no nested-key, it will select data sequentially up to limit; but if nested-key is specified, it cannot guarantee the correctness of the aggregation result. This option can be used to avoid abnormal input. 
diff --git a/docs/docs/primary-key-table/sequence-rowkind.mdx b/docs/docs/primary-key-table/sequence-rowkind.mdx index 96fcc19963cb..8a0f49fb7307 100644 --- a/docs/docs/primary-key-table/sequence-rowkind.mdx +++ b/docs/docs/primary-key-table/sequence-rowkind.mdx @@ -71,3 +71,50 @@ By default, the primary key table determines the row kind according to the input `'rowkind.field'` to use a field to extract row kind. The valid row kind string should be `'+I'`, `'-U'`, `'+U'` or `'-D'`. + +## Snapshot Ordering + +In multi-writer scenarios where wall-clock sequence numbers cannot be globally ordered across writers, +you can enable `'sequence.snapshot-ordering'` to use the commit snapshot id as the ordering key when +merging records with the same primary key. Records from later snapshots are considered newer, +regardless of their per-record sequence number. + + + + +```sql +CREATE TABLE my_table ( + pk BIGINT PRIMARY KEY NOT ENFORCED, + v1 DOUBLE, + v2 BIGINT +) WITH ( + 'sequence.snapshot-ordering' = 'true', + 'write-only' = 'true' +); +``` + + + + +:::warning +This option requires `'write-only' = 'true'`. Compaction must be performed by a separate dedicated +compact job. This ensures that compaction correctly preserves the snapshot id of each record. +::: + +:::info +`'sequence.snapshot-ordering'` is mutually exclusive with `'sequence.field'`. You cannot enable both +on the same table. +::: + +:::info +The ordering key is the commit snapshot id only; the order of records **within the same snapshot** is +not guaranteed, and this is by design. Under the default configuration it is harmless: a writer +buffers a commit's writes (`'write-buffer-spillable' = 'true'`) and runs them through the merge +function before flushing, so at most one record per primary key is written per snapshot — the common +case is fully covered. We therefore deliberately do not handle the case where the same key is spread +across multiple files of one snapshot. 
That case only arises with `'write-buffer-spillable' = 'false'`, +or when the spilled data exceeds `'write-buffer-spill-disk-size'`, where the buffer may be flushed +mid-commit; the same key can then land in multiple files of the same snapshot with equal sequence +numbers and their relative order becomes undefined. This affects only intra-snapshot order, never the +cross-snapshot ordering this feature provides. +::: diff --git a/docs/docs/project/security.md b/docs/docs/project/security.md new file mode 100644 index 000000000000..134569d19df9 --- /dev/null +++ b/docs/docs/project/security.md @@ -0,0 +1,73 @@ +--- +title: "Security" +sidebar_position: 4 +--- + + + +# Security + +## Reporting Security Issues + +The Apache Paimon Project uses the standard process outlined by the +[Apache Security Team](https://www.apache.org/security/) for reporting +vulnerabilities. + +Note that vulnerabilities should not be publicly disclosed until the project +has responded. + +To report a possible security vulnerability, please email +**[security@apache.org](mailto:security@apache.org)**. + +## Security Model + +Apache Paimon is a data lake platform and a set of libraries and integrations +used inside larger systems such as catalogs, query engines, and services. + +In most deployments, the primary trust and authorization boundaries are +enforced by the surrounding catalog, engine, service, operator configuration, +and storage-level authorization rather than by Paimon alone. + +Paimon security issues generally include: + +- Secret or credential disclosure to a newly reachable audience (e.g., bearer + tokens, access keys, or delegated storage tokens leaking across catalog, + session, or table boundaries) +- Other cases where Paimon itself creates a new unauthorized capability + rather than merely reflecting the trust decisions of a catalog, engine, or + operator + +Many other issues may still be valid bugs, but are not normally considered +security vulnerabilities in Paimon. 
This includes: + +- Robustness issues such as malformed-input crashes or memory exhaustion +- Issues that require a malicious catalog, metastore, REST Catalog server, or + other external service +- Issues that depend on operator misconfiguration (e.g., overly broad IAM + policies, missing TLS) + +Potential vulnerabilities that fall within this security model should be +reported privately using the process above. Other bugs and hardening issues +should be reported through the +[public issue tracker](https://github.com/apache/paimon/issues). + +For a more detailed threat model used for maintainer triage and scanner +calibration, see the +[Apache Paimon Security Threat Model](https://github.com/apache/paimon/blob/master/SECURITY.md). diff --git a/docs/docs/pypaimon/blob.md b/docs/docs/pypaimon/blob.md index dee0b9d6243b..9147a96873ef 100644 --- a/docs/docs/pypaimon/blob.md +++ b/docs/docs/pypaimon/blob.md @@ -24,7 +24,7 @@ under the License. # Blob Storage in pypaimon For Paimon's Blob storage concepts (storage modes, table options, SQL usage, -Java API), see [Blob Storage](../append-table/blob). +Java API), see [Blob Storage](../multimodal-table/blob). This page covers the Python API for reading and writing BLOB columns. @@ -148,7 +148,7 @@ header). This mirrors Java's `Blob.fromBytes(...)`. ## See Also -- [Blob Storage](../append-table/blob) — concept, storage modes, +- [Blob Storage](../multimodal-table/blob) — concept, storage modes, SQL/Java API - [Data Evolution](./data-evolution) — required for blob tables diff --git a/docs/docs/pypaimon/cli.md b/docs/docs/pypaimon/cli.md index f7afc2972cb3..94c0fdd6881b 100644 --- a/docs/docs/pypaimon/cli.md +++ b/docs/docs/pypaimon/cli.md @@ -461,7 +461,10 @@ Output: 4 Data lake platforms like Paimon handle large-... ``` -**Note:** The table must have a Tantivy full-text index built on the target column. See [Global Index](../append-table/global-index) for how to create full-text indexes. 
+**Note:** The table must have a Tantivy full-text index built on the target column. PyPaimon uses +the tokenizer settings stored in the index metadata; ngram full-text indexes require a `tantivy-py` +package with custom tokenizer support, and jieba full-text indexes require the Python `jieba` +package. See [Global Index](../multimodal-table/global-index) for how to create full-text indexes. ### Table Drop diff --git a/docs/docs/pypaimon/data-evolution.md b/docs/docs/pypaimon/data-evolution.md index c20f438800bb..2ce89d23144c 100644 --- a/docs/docs/pypaimon/data-evolution.md +++ b/docs/docs/pypaimon/data-evolution.md @@ -25,7 +25,7 @@ under the License. # Data Evolution -PyPaimon for Data Evolution mode. See [Data Evolution](../append-table/data-evolution). +PyPaimon for Data Evolution mode. See [Data Evolution](../multimodal-table/data-evolution). ## Prerequisites diff --git a/docs/docs/pypaimon/global-index.md b/docs/docs/pypaimon/global-index.md index 37e5f8032be4..5795303a8692 100644 --- a/docs/docs/pypaimon/global-index.md +++ b/docs/docs/pypaimon/global-index.md @@ -30,7 +30,7 @@ PyPaimon supports querying global indexes built on Data Evolution (append) table - **Vector Index (Lumina)**: Approximate nearest neighbor (ANN) index for vector similarity search. - **Full-Text Index (Tantivy)**: Full-text search index for text retrieval with relevance scoring. -> Global indexes must be built beforehand (e.g., via Spark or Flink). See [Global Index](../append-table/global-index) for how to create indexes. +> Global indexes must be built beforehand (e.g., via Spark or Flink). See [Global Index](../multimodal-table/global-index) for how to create indexes. ## BTree Index @@ -109,6 +109,13 @@ data = read.to_arrow(scan.plan().splits) ## Full-Text Index (Tantivy) Use `FullTextSearchBuilder` to perform full-text search on a text column, then read the matched rows. +PyPaimon reads the Tantivy tokenizer settings stored in the global index metadata. 
Indexes built +with `tantivy.tokenizer=ngram` can be queried from Python when the installed `tantivy-py` package +provides custom tokenizer support, including `Tokenizer.ngram`, `Tokenizer.simple`, +`Tokenizer.whitespace`, `Tokenizer.raw`, `TextAnalyzerBuilder`, tokenizer filters, and +`Index.register_tokenizer`. Indexes built with `tantivy.tokenizer=jieba` can be queried from Python +when the `jieba` package is installed. Query terms use OR semantics by default; use +`with_query_operator("and")` to require all terms. ```python table = catalog.get_table("db.my_table") @@ -119,6 +126,7 @@ index_result = ( builder .with_text_column("content") .with_query_text("search keywords") + .with_query_operator("and") .with_limit(20) .execute_local() ) diff --git a/docs/docs/pypaimon/index.md b/docs/docs/pypaimon/index.md index 18989aaf8289..e35f8b8bf762 100644 --- a/docs/docs/pypaimon/index.md +++ b/docs/docs/pypaimon/index.md @@ -50,3 +50,18 @@ pip3 install dist/*.tar.gz ``` The command will install the package and core dependencies to your local Python environment. + +## Optional Dependencies + +PyPaimon supports multiple file formats via optional extras: + +```shell +# Mosaic format (columnar-bucket hybrid, optimized for wide tables) +pip install pypaimon[mosaic] + +# Lance format (optimized for ML / vector search) +pip install pypaimon[lance] + +# Vortex format (requires Python >= 3.11) +pip install pypaimon[vortex] +``` diff --git a/docs/docs/pypaimon/python-api.mdx b/docs/docs/pypaimon/python-api.mdx index 52afbc2676b7..c1887c537c75 100644 --- a/docs/docs/pypaimon/python-api.mdx +++ b/docs/docs/pypaimon/python-api.mdx @@ -103,6 +103,46 @@ catalog_options = { + + +PyPaimon keeps the catalog type `jdbc` for compatibility with Paimon catalog options, +but connects with native Python DB-API drivers instead of JVM JDBC drivers. 
+ +```python +from pypaimon import CatalogFactory + +# Note that keys and values are all string +catalog_options = { + 'metastore': 'jdbc', + 'warehouse': 'file:///path/to/warehouse', + 'uri': 'jdbc:sqlite:/path/to/catalog.db', + # Optional. Defaults to 'jdbc'. + 'catalog-key': 'jdbc', +} +catalog = CatalogFactory.create(catalog_options) +``` + +For MySQL or PostgreSQL, install the corresponding Python DB-API driver and use the same Paimon +JDBC catalog options: + +```python +catalog_options = { + 'metastore': 'jdbc', + 'warehouse': 's3://bucket/path/to/warehouse', + 'uri': 'jdbc:mysql://:/', + 'jdbc.user': '...', + 'jdbc.password': '...', + 'catalog-key': 'jdbc', +} +``` + +Unlike Flink or Spark, PyPaimon does not use JVM JDBC drivers or load JDBC connector jars. +It keeps the `metastore='jdbc'` and `jdbc:` URI format for compatibility with Paimon's +JDBC catalog configuration, but the database connection is created through native Python DB-API +drivers such as `pymysql`, `mysql-connector-python`, `psycopg2`, or `psycopg`. + + + The sample code is as follows. The detailed meaning of option can be found in [REST](../concepts/rest/). @@ -124,7 +164,7 @@ catalog = CatalogFactory.create(catalog_options) -Currently, PyPaimon only support filesystem catalog and rest catalog. See [Catalog](../concepts/catalog). +Currently, PyPaimon supports filesystem catalog, jdbc catalog and rest catalog. See [Catalog](../concepts/catalog). You can use the catalog to create table for writing data. 
diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md index 3ee4db328979..d160b4302f65 100644 --- a/docs/docs/pypaimon/ray-data.md +++ b/docs/docs/pypaimon/ray-data.md @@ -205,19 +205,49 @@ write_paimon( ) ``` -**Automatic (partition, bucket) clustering for HASH_FIXED tables:** +**HASH_FIXED pre-clustering:** -For HASH_FIXED tables, `write_paimon` automatically clusters rows by -`(partition_keys..., bucket)` before writing so each (partition, -bucket) lands in a single Ray task — one writer, one file group. This -avoids the small-file storm that Ray's default round-robin -distribution would otherwise produce (`partitions × buckets × -ray_tasks` files instead of `partitions × buckets`). +HASH_FIXED rows are always assigned to the correct Paimon bucket by +the writer. Pre-clustering is only a file-count optimization. -Bucket assignment uses the same hash routine the writer uses, so the -bucket seen by the groupby is byte-equivalent to the one the writer -would compute. No user configuration is required. For non-HASH_FIXED -tables the dataset is written as-is. +By default, `write_paimon` writes append-only HASH_FIXED tables +without pre-clustering. This avoids Ray `groupby().map_groups()` +materializing an entire `(partition_keys..., bucket)` group on one Ray +node. + +HASH_FIXED primary-key tables reject the default/off mode. Direct Ray +writes can send the same bucket to multiple writer tasks, and those +writers can allocate overlapping sequence numbers. Use the explicit +`map_groups` mode until a bounded pre-clustering strategy preserves +per-bucket sequence ordering. 
+ +If every `(partition_keys..., bucket)` group fits in memory on a +single Ray node, you can opt in to the legacy small-file optimization: + +```python +write_paimon( + ray_dataset, + "database_name.table_name", + catalog_options={"warehouse": "/path/to/warehouse"}, + hash_fixed_precluster="map_groups", +) +``` + +`hash_fixed_precluster="map_groups"` groups rows by +`(partition_keys..., bucket)` before writing so each group lands in a +single Ray task. This can reduce file count and keeps HASH_FIXED +primary-key sequence generation per bucket in one writer task, but it +inherits Ray's `map_groups()` memory bound. Large append-only buckets +or hot append-only partitions should use the default mode or +`hash_fixed_precluster="off"`. + +For non-HASH_FIXED append-only tables, the dataset is written as-is. +Postpone-bucket primary-key tables (`bucket = -2`) are also written +as-is to the `bucket-postpone` directory. HASH_DYNAMIC and +CROSS_PARTITION primary-key Ray writes are not supported and fail fast, +including the default dynamic-bucket primary-key table (`bucket = -1`). +Ray write tasks create independent Paimon writers, which can assign +overlapping buckets or sequence numbers for those modes. **Parameters:** - `dataset`: the Ray Dataset to write. @@ -227,53 +257,156 @@ tables the dataset is written as-is. - `concurrency`: optional max number of Ray write tasks to run concurrently. - `ray_remote_args`: optional kwargs passed to `ray.remote()` in write tasks (e.g. `{"num_cpus": 2}`). +- `hash_fixed_precluster`: HASH_FIXED pre-clustering mode. `"auto"` and + `"off"` write append-only HASH_FIXED tables directly and reject + HASH_FIXED primary-key tables. `"map_groups"` enables the legacy + small-file optimization for HASH_FIXED primary-key tables and requires + each `(partition_keys..., bucket)` group to fit in memory on one Ray + node. This option does not enable Ray writes for HASH_DYNAMIC or + CROSS_PARTITION primary-key tables. 
### `TableWrite.write_ray()` (lower-level) If you have already constructed a `table_write` from a write builder, you can -hand a Ray Dataset directly to it. `write_ray()` commits through the Ray -Datasink API, so there is no `prepare_commit` / `commit` step to run for the -Ray write itself — just close the writer when you are done with it: +hand a Ray Dataset directly to it. `write_ray()` uses the same HASH_FIXED +pre-clustering modes and safety checks as the top-level `write_paimon()` API. +It commits through the Ray Datasink API, so there is no `prepare_commit` / +`commit` step to run for the Ray write itself — just close the writer when you +are done with it: ```python import ray table = catalog.get_table('database_name.table_name') -# 1. Create table write and commit (commit is only needed for non-Ray writes -# on the same table_write instance — see below). -write_builder = table.new_batch_write_builder() -table_write = write_builder.new_write() -table_commit = write_builder.new_commit() +# 1. Create table write. +table_write = table.new_batch_write_builder().new_write() # 2. Write Ray Dataset ray_dataset = ray.data.read_json("/path/to/data.jsonl") -table_write.write_ray(ray_dataset, overwrite=False, concurrency=2) +table_write.write_ray( + ray_dataset, + overwrite=False, + concurrency=2, + hash_fixed_precluster="auto", + static_partition=None, +) # Parameters: # - dataset: Ray Dataset to write # - overwrite: Whether to overwrite existing data (default: False) # - concurrency: Optional max number of concurrent Ray tasks # - ray_remote_args: Optional kwargs passed to ray.remote() (e.g., {"num_cpus": 2}) +# - hash_fixed_precluster: Same HASH_FIXED modes and primary-key safety +# checks as write_paimon() +# - static_partition: Optional partition spec to overwrite. When set, +# write_ray() runs in overwrite mode for this partition. -# 3. 
Commit data (required for write_pandas/write_arrow/write_arrow_batch only) -commit_messages = table_write.prepare_commit() -table_commit.commit(commit_messages) - -# 4. Close resources +# 3. Close resources table_write.close() -table_commit.close() ``` -### Overwrite at builder level +### Overwrite -The recommended way to overwrite via `write_paimon` is the `overwrite=True` -flag above. When using the lower-level builder API, you can also configure -overwrite mode on the write builder itself: +The top-level `write_paimon()` API supports whole-table overwrite with the +`overwrite=True` flag above. With the lower-level `write_ray()` API, you can +use `overwrite=True` for whole-table overwrite and `static_partition={...}` for +partition overwrite: + +```python +table_write.write_ray(ray_dataset, overwrite=True) +table_write.write_ray(ray_dataset, static_partition={'dt': '2024-01-01'}) +``` + +When using the lower-level builder API, you can also configure overwrite mode +on the write builder itself. The resulting `table_write` carries the overwrite +partition into `write_ray()`. A `static_partition` argument passed directly to +`write_ray()` overrides the builder-level partition: ```python # overwrite whole table -write_builder = table.new_batch_write_builder().overwrite() +table_write = table.new_batch_write_builder().overwrite().new_write() +table_write.write_ray(ray_dataset) # overwrite partition 'dt=2024-01-01' -write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'}) +table_write = ( + table.new_batch_write_builder() + .overwrite({'dt': '2024-01-01'}) + .new_write() +) +table_write.write_ray(ray_dataset) ``` + +## Merge Into + +`merge_into` updates (and optionally inserts) rows of a **data-evolution** table +from a source, like SQL `MERGE INTO`. Matched rows are updated in place by +`_ROW_ID`; only the touched columns are rewritten. Requires `ray >= 2.50` and a +target table with `'data-evolution.enabled'` and `'row-tracking.enabled'` set. 
+ +```python +from pypaimon.ray import merge_into, WhenMatched, WhenNotMatched + +metrics = merge_into( + target="database_name.table_name", + source=ray_dataset, # ray.data.Dataset / pa.Table / pandas / table-name str + catalog_options={"warehouse": "/path/to/warehouse"}, + on=["id"], # or {"target_col": "source_col"} for renamed keys + when_matched=[WhenMatched(update="*")], + when_not_matched=[WhenNotMatched(insert="*")], # optional +) +print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0} +``` + +Conditional clauses filter which matched/unmatched rows are acted on: + +```python +merge_into( + target="db.table", + source=source_ds, + catalog_options=catalog_options, + on=["id"], + when_matched=[WhenMatched(update="*", condition="s.age > t.age")], + when_not_matched=[WhenNotMatched(insert="*", condition="s.age > 18")], +) +``` + +Conditions use SQL-style expressions with `s.` (source) and `t.` (target) +column prefixes. `WhenNotMatched` conditions may only reference source +columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`. + +- `update` / `insert`: `"*"` updates/inserts all non-blob columns from source. + A mapping selects specific columns: + ```python + from pypaimon.ray import source_col, target_col, lit + + WhenMatched(update={"age": source_col("age"), "name": target_col("name")}) + WhenNotMatched(insert={"id": source_col("id"), "status": lit("new")}) + ``` + `"s."` / `"t."` shorthands also work (`t.*` only in update). + Use `lit()` for literals starting with `s.` or `t.`. +- `condition`: an optional SQL-style boolean expression. Use `s.` and + `t.` to reference source and target columns. + +**Parameters:** +- `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a + Paimon table identifier string. When a string is passed, it reads the table + from the same `catalog_options` at the latest snapshot. +- `on`: key columns, or `{target_col: source_col}` for renamed keys. 
+- `num_partitions`: shuffle parallelism for the join and the write; defaults to + `max(1, cluster_cpus * 2)`. Raise it for large merges on big clusters. +- `ray_remote_args`: Ray remote options applied to the merge's map/group + tasks (update transform, group write, insert transform). +- `concurrency`: scheduling for the insert sink. + +**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. `num_matched` +counts the rows actually updated (after condition filtering). `num_unchanged` +is `0` in the current implementation. + +**Notes:** +- Partition key columns cannot be updated by matched clauses. If the target + table is partitioned, `merge_into` raises an error when `when_matched` is + specified, because cross-partition row movement is not implemented. + Not-matched inserts into partitioned tables work normally. +- Blob columns are not written by `merge_into`: update leaves the existing + `.blob` files untouched, and insert fills blob columns with `NULL`. The + source data does not need to (and should not) carry blob columns. diff --git a/docs/docs/spark/procedures.md b/docs/docs/spark/procedures.md index 008c892a377a..d693346e00a8 100644 --- a/docs/docs/spark/procedures.md +++ b/docs/docs/spark/procedures.md @@ -517,14 +517,16 @@ This section introduce all available spark procedures about paimon. To create global index files for a given column. The table must have row-tracking.enabled=true. Arguments:
  • table: the target table identifier. Cannot be empty.
  • index_column: the name of the column to index. Cannot be empty.
  • -
  • index_type: type of the index to build, e.g. 'btree' or 'bitmap'. Cannot be empty.
  • +
  • index_type: type of the index to build, e.g. 'btree'. Cannot be empty.
  • partitions: partition filter to limit the partitions on which to build the index. The comma (",") represents "AND", the semicolon (";") represents "OR". Left empty for all partitions.
  • options: additional dynamic options of the table. They take precedence over the original `tableProp` but are overridden by `procedureArg`.
  • - CALL sys.create_global_index(table => 'default.T', index_column => 'name', index_type => 'bitmap')

    CALL sys.create_global_index(table => 'default.T', index_column => 'name', index_type => 'btree')

    - CALL sys.create_global_index(table => 'default.T', index_column => 'name', index_type => 'btree', partitions => 'pt=p1;pt=p2') + CALL sys.create_global_index(table => 'default.T', index_column => 'name', index_type => 'btree', partitions => 'pt=p1;pt=p2')

    + CALL sys.create_global_index(table => 'default.T', index_column => 'content', index_type => 'tantivy-fulltext', options => 'tantivy.tokenizer=ngram,tantivy.ngram.min-gram=2,tantivy.ngram.max-gram=2')

    + CALL sys.create_global_index(table => 'default.T', index_column => 'content', index_type => 'tantivy-fulltext', options => 'tantivy.tokenizer=jieba')

    + CALL sys.create_global_index(table => 'default.T', index_column => 'content', index_type => 'tantivy-fulltext', options => 'tantivy.tokenizer=simple,tantivy.stem=true,tantivy.remove-stop-words=true') @@ -533,12 +535,11 @@ This section introduce all available spark procedures about paimon. To drop global index files for a given column. Arguments:
  • table: the target table identifier. Cannot be empty.
  • index_column: the name of the indexed column. Cannot be empty.
  • -
  • index_type: type of the index to drop, e.g. 'btree' or 'bitmap'. Cannot be empty.
  • +
  • index_type: type of the index to drop, e.g. 'btree'. Cannot be empty.
  • partitions: partition filter to limit the partitions from which to drop the index. The comma (",") represents "AND", the semicolon (";") represents "OR". Left empty for all partitions.
  • - CALL sys.drop_global_index(table => 'default.T', index_column => 'name', index_type => 'bitmap')

    - CALL sys.drop_global_index(table => 'default.T', index_column => 'name', index_type => 'bitmap', partitions => 'pt=p1') + CALL sys.drop_global_index(table => 'default.T', index_column => 'name', index_type => 'btree', partitions => 'pt=p1') diff --git a/docs/docs/spark/sql-write.md b/docs/docs/spark/sql-write.md index 518bf0e5b40c..703976857a15 100644 --- a/docs/docs/spark/sql-write.md +++ b/docs/docs/spark/sql-write.md @@ -175,132 +175,108 @@ DELETE FROM my_table WHERE id = 1; Merges a set of updates, insertions and deletions based on a source table into a target table. -Note: - :::info -In update clause, to update primary key columns is not supported when the target table is a primary key table. +Updating primary key columns is not supported when the target table is a primary key table. ::: -**Example: One** - -This is a simple demo that, if a row exists in the target table update it, else insert it. +### Syntax ```sql --- Here both source and target tables have the same schema: (a INT, b INT, c STRING), and a is a primary key. - MERGE INTO target USING source -ON target.a = source.a -WHEN MATCHED THEN -UPDATE SET * -WHEN NOT MATCHED -THEN INSERT * +ON +WHEN MATCHED [AND ] THEN { UPDATE SET ... | DELETE } +WHEN NOT MATCHED [AND ] THEN INSERT ... ``` -**Example: Two** +Each `WHEN` clause can be repeated; clauses are evaluated in order, and the first matching one wins for a given row. -This is a demo with multiple, conditional clauses. +### Examples + +The examples below assume both source and target have schema `(a INT, b INT, c STRING)`, with `a` as the primary key. + +Simple upsert — update existing rows, insert new ones: ```sql --- Here both source and target tables have the same schema: (a INT, b INT, c STRING), and a is a primary key. 
+MERGE INTO target +USING source +ON target.a = source.a +WHEN MATCHED THEN UPDATE SET * +WHEN NOT MATCHED THEN INSERT * +``` + +Multiple conditional clauses: +```sql MERGE INTO target USING source ON target.a = source.a -WHEN MATCHED AND target.a = 5 THEN - UPDATE SET b = source.b + target.b -- when matched and meet the condition 1, then update b; -WHEN MATCHED AND source.c > 'c2' THEN - UPDATE SET * -- when matched and meet the condition 2, then update all the columns; -WHEN MATCHED THEN - DELETE -- when matched, delete this row in target table; -WHEN NOT MATCHED AND c > 'c9' THEN - INSERT (a, b, c) VALUES (a, b * 1.1, c) -- when not matched but meet the condition 3, then transform and insert this row; -WHEN NOT MATCHED THEN -INSERT * -- when not matched, insert this row without any transformation; +WHEN MATCHED AND target.a = 5 THEN UPDATE SET b = source.b + target.b +WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET * +WHEN MATCHED THEN DELETE +WHEN NOT MATCHED AND c > 'c9' THEN INSERT (a, b, c) VALUES (a, b * 1.1, c) +WHEN NOT MATCHED THEN INSERT * ``` ### Column Alignment Assignments are aligned to the target table by **column name**. -For explicit clauses (`UPDATE SET col = expr` / `INSERT (col list) VALUES ...`), only the mentioned columns are written. Unmentioned target columns preserve their current value for `UPDATE`, or are filled with NULL / `CURRENT_DEFAULT` for `INSERT`. - -For star clauses (`UPDATE SET *` / `INSERT *`), `*` expands against the **target** columns. 
The behavior when source and target columns don't match exactly depends on `spark.paimon.write.merge-schema` (see [Write Merge Schema](#write-merge-schema)): - -| Scenario | `merge-schema=false` (default) | `merge-schema=true` | -|----------|-------------------------------|---------------------| -| Top-level source-extra columns | Silently dropped (`*` only covers target columns) | Evolved into the target schema | -| Top-level target columns missing from source | Throws | `UPDATE *` preserves current value; `INSERT *` fills NULL | -| Nested struct source-extra fields | Throws | Evolved into the target schema | -| Nested struct target-missing fields | Throws | `UPDATE *` preserves current value; `INSERT *` fills NULL | - -The key difference between top-level and nested: under strict mode (`merge-schema=false`), top-level source-extras are silently dropped because `*` never references them, while nested source-extras inside a struct value throw an error to avoid silent data loss. +- **Explicit clauses** (`UPDATE SET col = expr` / `INSERT (col list) VALUES ...`) — only the mentioned columns are written. Unmentioned target columns preserve their current value for `UPDATE`, or get NULL / `CURRENT_DEFAULT` for `INSERT`. +- **Star clauses** (`UPDATE SET *` / `INSERT *`) — `*` expands against the **target** columns. When source and target columns don't match exactly, the behavior depends on `spark.paimon.write.merge-schema`; see [Column Alignment by Write Path](#column-alignment-by-write-path) under Write Merge Schema for the full table covering both `MERGE INTO *` and byName `INSERT` paths. ## Write Merge Schema +When `write.merge-schema` is enabled, Paimon automatically evolves the table schema during write to accommodate new columns in the incoming data, while preserving data integrity. + :::info Since the table schema may be updated during writing, catalog caching needs to be disabled to use this feature. Configure `spark.sql.catalog..cache-enabled` to `false`. 
::: -Write merge schema is a feature that allows users to easily modify the current schema of a table to adapt to existing data, or new data that changes over time, while maintaining data integrity and consistency. - -Paimon supports automatic schema merging of source data and current table data while data is being written, and uses the merged schema as the latest schema of the table, and it only requires configuring `write.merge-schema`. - -```scala -data.write - .format("paimon") - .mode("append") - .option("write.merge-schema", "true") - .save(location) -``` - -When enable `write.merge-schema`, Paimon can allow users to perform the following actions on table schema by default: -- Adding columns -- Up-casting the type of column(e.g. Int -> Long) - -Paimon also supports explicit type conversions between certain types (e.g. String -> Date, Long -> Int), it requires an explicit configuration `write.merge-schema.explicit-cast`. - -Write merge schema can be used in streaming mode at the same time. - -```scala -val inputData = MemoryStream[(Int, String)] -inputData - .toDS() - .toDF("col1", "col2") - .writeStream - .format("paimon") - .option("checkpointLocation", "/path/to/checkpoint") - .option("write.merge-schema", "true") - .option("write.merge-schema.explicit-cast", "true") - .start(location) -``` +### How It Evolves the Schema -Here list the configurations. +Three options control how aggressively the schema evolves; each only takes effect when the previous one is enabled: - - + + - + + + + + - +
    Scan ModeDescriptionOptionDescription
    write.merge-schema
    If true, merge the data schema and the table schema automatically before write data.If true, evolve the table schema to accept new columns from the incoming data. Existing column types are preserved and incoming values are cast to them; to also widen existing types, enable write.merge-schema.type-widening.
    write.merge-schema.type-widening
    Only effective when write.merge-schema is true. If true, widen an existing column type when the incoming data has a wider compatible type (e.g. INT -> BIGINT, DECIMAL precision increase). Lossy changes are still rejected unless write.merge-schema.explicit-cast is also true.
    write.merge-schema.explicit-cast
    If true, allow to merge data types if the two types meet the rules for explicit casting.Only effective when write.merge-schema.type-widening is true. If true, also allow lossy type changes between compatible types (e.g. BIGINT -> INT, STRING -> DATE).
    -This mode also supports Spark SQL. Here is an example: +### Examples + +DataFrame batch write: + +```scala +data.write + .format("paimon") + .mode("append") + .option("write.merge-schema", "true") + .saveAsTable("t") +``` + +Spark SQL (requires Spark 3.5+ for `BY NAME`): ```sql SET `spark.paimon.write.merge-schema` = true; @@ -308,15 +284,52 @@ SET `spark.paimon.write.merge-schema` = true; CREATE TABLE t (a INT, b STRING); INSERT INTO t VALUES (1, '1'), (2, '2'); --- Need using `BY NAME` statement (requires Spark 3.5+) INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c; ``` +Streaming write: + +```scala +val inputData = MemoryStream[(Int, String)] +inputData + .toDS() + .toDF("col1", "col2") + .writeStream + .format("paimon") + .option("checkpointLocation", "/path/to/checkpoint") + .option("write.merge-schema", "true") + .toTable("t") +``` + +### Column Alignment by Write Path + +When the source schema doesn't match the target schema exactly, the behavior depends on both `write.merge-schema` and the write path. For nested struct fields, all byName paths behave the same; at the top level, `MERGE INTO *` differs from regular byName `INSERT` because `*` expansion only references target columns. + +| Write path | Scenario | `merge-schema=false` (default) | `merge-schema=true` | +|------------|----------|-------------------------------|---------------------| +| **byName `INSERT`** (`INSERT INTO ... 
BY NAME` / `saveAsTable` / `writeTo`) | Top-level source-extra columns | Throws | Evolved into the target schema | +| | Top-level target columns missing from source | NULL-filled | NULL-filled | +| | Nested struct source-extra fields | Throws | Evolved into the target schema | +| | Nested struct target-missing fields | Throws | NULL-filled | +| **`MERGE INTO *`** (`UPDATE *` / `INSERT *`) | Top-level source-extra columns | Silently dropped (`*` only covers target columns) | Evolved into the target schema | +| | Top-level target columns missing from source | Throws | `UPDATE *` preserves current value; `INSERT *` fills `CURRENT_DEFAULT` (or NULL when no default) | +| | Nested struct source-extra fields | Throws | Evolved into the target schema | +| | Nested struct target-missing fields | Throws | `UPDATE *` preserves current value; `INSERT *` fills `CURRENT_DEFAULT` (or NULL when no default) | + +Notes: +- Position-based writes (e.g. `INSERT INTO t VALUES (...)` without `BY NAME`) require an exact column count match and don't engage schema evolution; only byName writes are covered above. +- Top-level target-missing under `merge-schema=false` for byName `INSERT` mirrors Spark's `INSERT FILL` semantics — only nested missing fields throw. +- Under strict mode (`merge-schema=false`), nested source-extra fields throw to avoid silent data loss; for `MERGE INTO *` at the top level, source-extras are silently dropped because `*` never references them. + ## COPY INTO -`COPY INTO` provides a SQL command for bulk loading CSV files into Paimon tables and writing table data to CSV files. +`COPY INTO` provides a SQL command for bulk loading data files into Paimon tables and exporting table data to files. Supported formats: **CSV**, **JSON**, and **Parquet**. 
-### CSV Import +:::info +**SQL dialect:** Paimon's `COPY INTO` is a Snowflake-style extension (`FILE_FORMAT = (TYPE = ...)`, `PATTERN`, `FORCE`, `ON_ERROR`), not the Databricks `COPY INTO` form (`FILEFORMAT` + `FORMAT_OPTIONS (...)` / `COPY_OPTIONS (...)`). It implements only a subset of the Snowflake syntax. In particular, `ON_ERROR` supports `ABORT_STATEMENT` (default), `CONTINUE`, and `SKIP_FILE`; the Snowflake variants `SKIP_FILE_` and `SKIP_FILE_%` are **not** supported. +::: + +#### CSV Import ```sql COPY INTO table_name [(col1, col2, ...)] @@ -324,7 +337,7 @@ FROM 'source_path' FILE_FORMAT = (TYPE = CSV [, option = value, ...]) [PATTERN = 'regex'] [FORCE = TRUE|FALSE] -[ON_ERROR = ABORT_STATEMENT] +[ON_ERROR = { ABORT_STATEMENT | CONTINUE | SKIP_FILE }] ``` **Basic import:** @@ -354,7 +367,67 @@ PATTERN = '.*\.csv' FORCE = FALSE; ``` -### Write CSV Files +#### JSON Import + +```sql +COPY INTO table_name [(col1, col2, ...)] +FROM 'source_path' +FILE_FORMAT = (TYPE = JSON [, option = value, ...]) +[PATTERN = 'regex'] +[FORCE = TRUE|FALSE] +[ON_ERROR = { ABORT_STATEMENT | CONTINUE | SKIP_FILE }] +``` + +**Basic import:** + +```sql +COPY INTO my_db.my_table +FROM '/data/json_files/' +FILE_FORMAT = (TYPE = JSON); +``` + +**Import multi-line JSON array:** + +```sql +COPY INTO my_db.events +FROM '/data/events/' +FILE_FORMAT = (TYPE = JSON, MULTI_LINE = TRUE); +``` + +JSON columns are matched **by column name** (not by position), so source field order does not matter. 
+ +#### Parquet Import + +```sql +COPY INTO table_name [(col1, col2, ...)] +FROM 'source_path' +FILE_FORMAT = (TYPE = PARQUET [, option = value, ...]) +[PATTERN = 'regex'] +[FORCE = TRUE|FALSE] +[ON_ERROR = { ABORT_STATEMENT | CONTINUE | SKIP_FILE }] +``` + +**Basic import:** + +```sql +COPY INTO my_db.my_table +FROM '/data/parquet_files/' +FILE_FORMAT = (TYPE = PARQUET); +``` + +**Import with PATTERN:** + +```sql +COPY INTO my_db.events +FROM '/data/lake/' +FILE_FORMAT = (TYPE = PARQUET) +PATTERN = '.*\.parquet' +FORCE = FALSE; +``` + +Parquet columns are matched **by column name** (not by position). Extra columns in the source files are ignored; missing columns become NULL. + +#### Write CSV Files ```sql COPY INTO 'target_path' @@ -372,15 +445,60 @@ FILE_FORMAT = (TYPE = CSV, HEADER = TRUE, FIELD_DELIMITER = ',') OVERWRITE = TRUE; ``` -### FILE_FORMAT Options +#### Write JSON Files + +```sql +COPY INTO 'target_path' +FROM table_name +FILE_FORMAT = (TYPE = JSON [, option = value, ...]) +[OVERWRITE = TRUE|FALSE] +``` + +**Basic JSON export:** + +```sql +COPY INTO '/export/events_backup/' +FROM my_db.events +FILE_FORMAT = (TYPE = JSON) +OVERWRITE = TRUE; +``` + +#### Write Parquet Files + +```sql +COPY INTO 'target_path' +FROM table_name +FILE_FORMAT = (TYPE = PARQUET [, option = value, ...]) +[OVERWRITE = TRUE|FALSE] +``` + +**Basic Parquet export:** + +```sql +COPY INTO '/export/data_backup/' +FROM my_db.events +FILE_FORMAT = (TYPE = PARQUET) +OVERWRITE = TRUE; +``` + +**Export with compression:** + +```sql +COPY INTO '/export/data_compressed/' +FROM my_db.events +FILE_FORMAT = (TYPE = PARQUET, COMPRESSION = GZIP) +OVERWRITE = TRUE; +``` -`FILE_FORMAT` is required and must include `TYPE = CSV`. +#### FILE_FORMAT Options -**Import options:** +`FILE_FORMAT` is required and must include `TYPE = CSV`, `TYPE = JSON`, or `TYPE = PARQUET`. + +**CSV import options:** | Option | Description | Default | |--------|-------------|---------| -| TYPE | File format type. 
Must be `CSV`. | (required) | +| TYPE | File format type. `CSV`, `JSON`, or `PARQUET`. | (required) | | FIELD_DELIMITER | Column delimiter character. | `,` | | SKIP_HEADER | Skip the first line as header. Only `0` or `1`. | `0` | | QUOTE | Quote character for enclosing fields. | `"` | @@ -389,46 +507,81 @@ OVERWRITE = TRUE; | EMPTY_FIELD_AS_NULL | Treat empty fields as NULL. `TRUE` or `FALSE`. | `FALSE` | | COMPRESSION | Compression codec (e.g. `GZIP`). | `NONE` | -**Write options:** +**JSON import options:** + +| Option | Description | Default | +|--------|-------------|---------| +| TYPE | File format type. `CSV`, `JSON`, or `PARQUET`. | (required) | +| MULTI_LINE | Parse multi-line JSON (e.g. JSON arrays or pretty-printed objects). | `FALSE` | +| NULL_IF | List of string values to interpret as NULL. | (none) | +| EMPTY_FIELD_AS_NULL | Treat empty string values as NULL. | `FALSE` | +| COMPRESSION | Compression codec (e.g. `GZIP`). | `NONE` | + +**Parquet import options:** + +| Option | Description | Default | +|--------|-------------|---------| +| TYPE | File format type. `CSV`, `JSON`, or `PARQUET`. | (required) | +| COMPRESSION | Compression codec. Usually auto-detected; rarely needed for import. | (auto) | + +**CSV write options:** | Option | Description | Default | |--------|-------------|---------| -| TYPE | File format type. Must be `CSV`. | (required) | +| TYPE | File format type. `CSV`, `JSON`, or `PARQUET`. | (required) | | FIELD_DELIMITER | Column delimiter character. | `,` | | HEADER | Write column names as the first line. `TRUE` or `FALSE`. | `FALSE` | | QUOTE | Quote character for enclosing fields. | `"` | | ESCAPE | Escape character within quoted fields. | `\` | | COMPRESSION | Compression codec (e.g. `GZIP`). | `NONE` | -### Import Options +**JSON write options:** + +| Option | Description | Default | +|--------|-------------|---------| +| TYPE | File format type. `CSV`, `JSON`, or `PARQUET`. 
| (required) | +| DATE_FORMAT | Custom date format pattern. | Spark default | +| TIMESTAMP_FORMAT | Custom timestamp format pattern. | Spark default | +| COMPRESSION | Compression codec (e.g. `GZIP`). | `NONE` | + +**Parquet write options:** + +| Option | Description | Default | +|--------|-------------|---------| +| TYPE | File format type. `CSV`, `JSON`, or `PARQUET`. | (required) | +| COMPRESSION | Compression codec (`SNAPPY`, `GZIP`, `NONE`, etc.). | `SNAPPY` | + +#### Import Options | Option | Description | Default | |--------|-------------|---------| | PATTERN | Regex to filter source files by base file name. Only matching files are loaded. | (all files) | | FORCE | `FALSE`: skip files already loaded (idempotent). `TRUE`: reload all files. | `FALSE` | -| ON_ERROR | Error handling strategy. Only `ABORT_STATEMENT` is supported. | `ABORT_STATEMENT` | +| ON_ERROR | Error handling strategy. `ABORT_STATEMENT`: abort on any error. `CONTINUE`: skip bad rows and continue loading. `SKIP_FILE`: skip files that contain errors. | `ABORT_STATEMENT` | -### File Write Options +#### File Write Options | Option | Description | Default | |--------|-------------|---------| | OVERWRITE | `FALSE`: fail if target path exists. `TRUE`: overwrite existing files. | `FALSE` | -### Column Mapping +#### Column Mapping When an explicit column list is provided (e.g., `COPY INTO t (col1, col2) FROM ...`): -- CSV columns are mapped **positionally** to the specified column list. -- The number of CSV columns must match the column list length. +- **CSV**: Columns are mapped **positionally** to the specified column list. +- **JSON**: Columns are matched **by name** to the specified column list. +- **Parquet**: Columns are matched **by name** to the specified column list. +- The number of source columns must match the column list length (CSV). For JSON and Parquet, missing fields in the source become NULL. 
- Columns not in the list are filled with their **DEFAULT value** (if defined in the table schema) or **NULL**. - Non-nullable columns without a default value that are not in the list will cause an error. When no column list is provided: -- CSV columns are mapped positionally to all writable columns in the target table. -- The number of CSV columns must match the number of writable columns. +- **CSV**: Columns are mapped positionally to all writable columns in the target table. The number of CSV columns must match the number of writable columns. +- **JSON**: Columns are matched by name to the writable columns. Missing fields in JSON become NULL. -### Repeated Imports +#### Repeated Imports By default (`FORCE = FALSE`), COPY INTO tracks which files have been successfully loaded. A file is identified by its path, size, and last-modified timestamp. @@ -436,16 +589,18 @@ By default (`FORCE = FALSE`), COPY INTO tracks which files have been successfull - If a source file is modified (size or timestamp changes), it becomes eligible for re-loading. - `FORCE = TRUE` bypasses load history and always re-imports all matching files. 
-### Result Output +#### Result Output **Import** returns one row per source file: | Column | Type | Description | |--------|------|-------------| | file_name | STRING | Source file name | -| status | STRING | `LOADED` or `SKIPPED` | +| status | STRING | `LOADED`, `PARTIALLY_LOADED`, `LOAD_FAILED`, or `SKIPPED` | | rows_loaded | BIGINT | Number of rows written | | rows_parsed | BIGINT | Number of rows parsed from the file | +| errors_seen | BIGINT | Number of error rows (parse or cast failures) | +| first_error | STRING | First error message encountered (NULL if no errors) | **File write** returns a single row: @@ -455,11 +610,11 @@ By default (`FORCE = FALSE`), COPY INTO tracks which files have been successfull | file_count | INT | Number of files written | | rows_written | BIGINT | Total rows written | -### Limitations +#### Limitations -- Only **CSV** format is supported. +- **CSV column-count mismatch**: Rows with fewer or more columns than the target schema are treated as malformed records. With `ON_ERROR = CONTINUE`, these rows are skipped and counted as errors. +- Only **CSV**, **JSON**, and **Parquet** formats are supported. - Writing files only supports `FROM table_name`; `FROM (SELECT ...)` is not supported. -- `ON_ERROR = CONTINUE` is not supported; any parse or cast error aborts the entire command. - `SINGLE = TRUE` (single-file output) is not supported. - File format options must be specified inline in `FILE_FORMAT = (...)`. - File listing is **non-recursive**: only direct files under the source path are processed. Subdirectories are ignored. diff --git a/docs/generated/core_configuration.html b/docs/generated/core_configuration.html index 7bc02cf87bc6..3094d2e9d0d3 100644 --- a/docs/generated/core_configuration.html +++ b/docs/generated/core_configuration.html @@ -470,6 +470,18 @@ Boolean Whether enable data evolution for row tracking table. + +
    data-evolution.merge-into.file-pruning
    + true + Boolean + If true, enables the file-level pruning step for MergeInto partial column update on data-evolution tables. Set this to false when most files in the target partition are expected to be updated, because in that case the overhead of collecting touched file IDs outweighs the benefit of pruning untouched files. + + 
    data-evolution.merge-into.source-persist
    + false + Boolean + Whether to persist the source when processing a merge-into action on a data-evolution table. + 
    data-file.external-paths
    (none) @@ -736,9 +748,9 @@
    global-index.thread-num
    - (none) + 32 Integer - The maximum number of concurrent scanner for global index.By default is the number of processors available to the Java virtual machine. + The maximum number of concurrent threads for global index I/O.
    ignore-delete
    @@ -919,7 +931,25 @@
    manifest.merge-min-count
    30 Integer - To avoid frequent manifest merges, this parameter specifies the minimum number of ManifestFileMeta to merge. + To avoid frequent manifest merges, this parameter specifies the minimum number of ManifestFileMeta to merge.
    Note: when 'manifest-sort.enabled' is true, this minimum-count gate is only applied to the trailing sub-segment of a section that exceeds 'manifest-sort.max-rewrite-size'. Small under-budget sections are sorted and rewritten directly, so two small manifest files may be merged into one even when their count is below this threshold and full compaction is not triggered. + + +
    manifest-sort.enabled
    + false + Boolean + Whether to invoke manifest sort rewrite during commit.
    Note: enabling this changes the semantics of 'manifest.merge-min-count'. In the sort rewrite path, small manifest files within the rewrite budget are sorted and merged directly, so the minimum-count gate no longer prevents merging a small number of under-budget manifest files when full compaction is not triggered. + + +
    manifest-sort.partition-field
    + (none) + String + Partition field name to sort manifest entries by. The name is checked during schema validation; if not configured, it defaults to the first partition field. + + 
    manifest-sort.max-rewrite-size
    + 256 mb + MemorySize + Maximum total size of manifest files to rewrite in a single sort rewrite pass. Sections exceeding this limit are skipped. Set to a larger value to allow more aggressive sort rewriting. The cap only limits the sorted rewrite portion and full/minor cleanup may still happen beyond it.
    manifest.target-file-size
    @@ -1337,6 +1367,12 @@

    Enum

    Specify the order of sequence.field.

    Possible values:
    • "ascending": specifies sequence.field sort order is ascending.
    • "descending": specifies sequence.field sort order is descending.
    + +
    sequence.snapshot-ordering
    + false + Boolean + When enabled, merge uses the commit snapshot id as the ordering key for primary-key conflicts: records from later snapshots always win. Designed for multi-writer scenarios on the same primary-key table where wall-clock sequence numbers cannot be globally ordered. The order of records within the same snapshot is not guaranteed. Mutually exclusive with sequence.field. Requires a primary-key table with write-only=true. Inline compaction is not allowed because snapshot ids are assigned only after commit. To compact such tables, run a dedicated compaction job/action with write-only=false. +
    sink.process-time-zone
    (none) diff --git a/docs/generated/hive_catalog_configuration.html b/docs/generated/hive_catalog_configuration.html index 48adc1114fb3..63eaab9327e4 100644 --- a/docs/generated/hive_catalog_configuration.html +++ b/docs/generated/hive_catalog_configuration.html @@ -56,6 +56,12 @@ If not configured, try to load from 'HIVE_CONF_DIR' env. + +
    hive.skip-update-stats
    + false + Boolean + If true, sets DO_NOT_UPDATE_STATS in the Hive EnvironmentContext when altering tables, preventing Hive from updating table statistics. +
    location-in-properties
    false diff --git a/docs/generated/spark_connector_configuration.html b/docs/generated/spark_connector_configuration.html index a6c8278f4c44..0dbc8b830a77 100644 --- a/docs/generated/spark_connector_configuration.html +++ b/docs/generated/spark_connector_configuration.html @@ -84,13 +84,19 @@
    write.merge-schema
    false Boolean - If true, merge the data schema and the table schema automatically before write data. + If true, evolve the table schema to accept new columns from the incoming data. Existing column types are preserved and incoming values are cast to them; to also widen existing types, enable 'write.merge-schema.type-widening'.
    write.merge-schema.explicit-cast
    false Boolean - If true, allow to merge data types if the two types meet the rules for explicit casting. + Only effective when 'write.merge-schema.type-widening' is true. If true, also allow lossy type changes between compatible types (e.g. BIGINT -> INT, STRING -> DATE). + + +
    write.merge-schema.type-widening
    + false + Boolean + Only effective when 'write.merge-schema' is true. If true, widen an existing column type when the incoming data has a wider compatible type (e.g. INT -> BIGINT, DECIMAL precision increase). Lossy changes are still rejected unless 'write.merge-schema.explicit-cast' is also true.
    write.use-v2-write
    diff --git a/docs/redirects.js b/docs/redirects.js index 3e37353c9fe6..d3facf9b7b49 100644 --- a/docs/redirects.js +++ b/docs/redirects.js @@ -2,7 +2,11 @@ module.exports = [ { "from": "/append-table/blob-storage.html", - "to": "/append-table/blob" + "to": "/multimodal-table/blob" + }, + { + "from": "/append-table/blob", + "to": "/multimodal-table/blob" }, { "from": "/append-table/bucketed.html", @@ -10,11 +14,19 @@ module.exports = [ }, { "from": "/append-table/data-evolution.html", - "to": "/append-table/data-evolution" + "to": "/multimodal-table/data-evolution" + }, + { + "from": "/append-table/data-evolution", + "to": "/multimodal-table/data-evolution" }, { "from": "/append-table/global-index.html", - "to": "/append-table/global-index" + "to": "/multimodal-table/global-index" + }, + { + "from": "/append-table/global-index", + "to": "/multimodal-table/global-index" }, { "from": "/append-table/incremental-clustering.html", @@ -26,7 +38,11 @@ module.exports = [ }, { "from": "/append-table/vector-storage.html", - "to": "/append-table/vector" + "to": "/multimodal-table/vector" + }, + { + "from": "/append-table/vector", + "to": "/multimodal-table/vector" }, { "from": "/cdc-ingestion/flink-cdc.html", diff --git a/docs/sidebars.js b/docs/sidebars.js index d6f35bec727f..6f6fa15afa2c 100644 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -57,7 +57,21 @@ const sidebars = { }, { type: "category", - "label": "Table with PK", + "label": "Append Table", + "collapsed": true, + "link": { + type: "doc", + "id": "append-table/index" + }, + "items": [ + "append-table/incremental-clustering", + "append-table/bucketed", + "append-table/row-tracking" + ] + }, + { + type: "category", + "label": "PrimaryKey Table", "collapsed": true, "link": { type: "doc", @@ -90,20 +104,17 @@ const sidebars = { }, { type: "category", - "label": "Table w/o PK", + "label": "Multimodal Table", "collapsed": true, "link": { type: "doc", - "id": "append-table/index" + "id": "multimodal-table/index" }, 
"items": [ - "append-table/incremental-clustering", - "append-table/bucketed", - "append-table/row-tracking", - "append-table/data-evolution", - "append-table/blob", - "append-table/vector", - "append-table/global-index" + "multimodal-table/data-evolution", + "multimodal-table/blob", + "multimodal-table/vector", + "multimodal-table/global-index" ] }, { @@ -290,7 +301,8 @@ const sidebars = { "items": [ "project/download", "project/contributing", - "project/committer" + "project/committer", + "project/security" ] }, { diff --git a/docs/src/css/custom.css b/docs/src/css/custom.css index 618ae52f3976..32f34dd984ed 100644 --- a/docs/src/css/custom.css +++ b/docs/src/css/custom.css @@ -50,6 +50,21 @@ line-height: 1.5; } +.hero-links { + margin: 0.6rem auto 0; + font-size: 0.95rem; +} + +.hero-links a { + color: var(--ifm-color-primary); + text-decoration: none; + font-weight: 500; +} + +.hero-links a:hover { + text-decoration: underline; +} + /* ===== Feature Columns ===== */ .feature-columns { display: grid; diff --git a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java index 6b0ef75ff823..80f0a8cec0c4 100644 --- a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java @@ -80,6 +80,8 @@ public class CoreOptions implements Serializable { public static final String NESTED_KEY = "nested-key"; + public static final String NESTED_SEQUENCE_FIELD = "nested-sequence-field"; + public static final String COUNT_LIMIT = "count-limit"; public static final String DISTINCT = "distinct"; @@ -466,8 +468,61 @@ public InlineElement getDescription() { .intType() .defaultValue(30) .withDescription( - "To avoid frequent manifest merges, this parameter specifies the minimum number " - + "of ManifestFileMeta to merge."); + Description.builder() + .text( + "To avoid frequent manifest merges, this parameter specifies the minimum number " + + "of 
ManifestFileMeta to merge.") + .linebreak() + .text( + "Note: when '" + + "manifest-sort.enabled" + + "' is true, this minimum-count gate is only " + + "applied to the trailing sub-segment of a " + + "section that exceeds '" + + "manifest-sort.max-rewrite-size" + + "'. Small under-budget sections are sorted " + + "and rewritten directly, so two small manifest " + + "files may be merged into one even when their " + + "count is below this threshold and full " + + "compaction is not triggered.") + .build()); + + public static final ConfigOption MANIFEST_SORT_ENABLED = + key("manifest-sort.enabled") + .booleanType() + .defaultValue(false) + .withDescription( + Description.builder() + .text("Whether to invoke manifest sort rewrite during commit.") + .linebreak() + .text( + "Note: enabling this changes the semantics of '" + + "manifest.merge-min-count" + + "'. In the sort rewrite path, small manifest " + + "files within the rewrite budget are sorted " + + "and merged directly, so the minimum-count " + + "gate no longer prevents merging a small " + + "number of under-budget manifest files when " + + "full compaction is not triggered.") + .build()); + + public static final ConfigOption MANIFEST_SORT_PARTITION_FIELD = + key("manifest-sort.partition-field") + .stringType() + .noDefaultValue() + .withDescription( + "Partition field name to sort manifest entries by. Validated by" + + " schema validation, if not configured, defaults to the first partition field."); + + public static final ConfigOption MANIFEST_SORT_MAX_REWRITE_SIZE = + key("manifest-sort.max-rewrite-size") + .memoryType() + .defaultValue(MemorySize.ofMebiBytes(256)) + .withDescription( + "Maximum total size of manifest files to rewrite in a single" + + " sort rewrite pass. Sections exceeding this limit are" + + " skipped. Set to a larger value to allow more aggressive" + + " sort rewriting. 
The cap only limits the sorted rewrite portion and full/minor cleanup may still happen beyond it."); public static final ConfigOption UPSERT_KEY = key("upsert-key") @@ -965,6 +1020,24 @@ public InlineElement getDescription() { .defaultValue(SortOrder.ASCENDING) .withDescription("Specify the order of sequence.field."); + @Immutable + public static final ConfigOption SEQUENCE_SNAPSHOT_ORDERING = + key("sequence.snapshot-ordering") + .booleanType() + .defaultValue(false) + .withDescription( + "When enabled, merge uses the commit snapshot id as the ordering key " + + "for primary-key conflicts: records from later snapshots " + + "always win. Designed for multi-writer scenarios on the same " + + "primary-key table where wall-clock sequence numbers cannot " + + "be globally ordered. The order of records within the same " + + "snapshot is not guaranteed. Mutually exclusive with " + + "sequence.field. Requires a primary-key table with " + + "write-only=true. Inline compaction is not allowed because " + + "snapshot ids are assigned only after commit. To compact such " + + "tables, run a dedicated compaction job/action with " + + "write-only=false."); + @Immutable public static final ConfigOption AGGREGATION_REMOVE_RECORD_ON_DELETE = key("aggregation.remove-record-on-delete") @@ -2236,6 +2309,24 @@ public InlineElement getDescription() { .defaultValue(false) .withDescription("Whether enable data evolution for row tracking table."); + public static final ConfigOption DATA_EVOLUTION_MERGE_INTO_FILE_PRUNING = + key("data-evolution.merge-into.file-pruning") + .booleanType() + .defaultValue(true) + .withDescription( + "If true, enables the file-level pruning step for MergeInto partial column " + + "update on data-evolution tables. 
" + + "Set this to false when most files in the target partition are expected " + + "to be updated, so that the overhead of collecting touched file IDs " + + "outweighs the benefit of pruning untouched files."); + + public static final ConfigOption DATA_EVOLUTION_MERGE_INTO_SOURCE_PERSIST = + key("data-evolution.merge-into.source-persist") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether to persist source when process merge into action on data evolution table."); + public static final ConfigOption BLOB_COMPACTION_ENABLED = key("blob-compaction.enabled") .booleanType() @@ -2308,6 +2399,7 @@ public InlineElement getDescription() { .noDefaultValue() .withDescription("Format table commit hive sync uri."); + @Immutable public static final ConfigOption BLOB_FIELD = key("blob-field") .stringType() @@ -2445,10 +2537,9 @@ public InlineElement getDescription() { public static final ConfigOption GLOBAL_INDEX_THREAD_NUM = key("global-index.thread-num") .intType() - .noDefaultValue() + .defaultValue(32) .withDescription( - "The maximum number of concurrent scanner for global index." 
- + "By default is the number of processors available to the Java virtual machine."); + "The maximum number of concurrent threads for global index I/O."); public static final ConfigOption OVERWRITE_UPGRADE = key("overwrite-upgrade") @@ -2601,6 +2692,19 @@ public MemorySize manifestFullCompactionThresholdSize() { return options.get(MANIFEST_FULL_COMPACTION_FILE_SIZE); } + public boolean manifestSortEnabled() { + return options.get(MANIFEST_SORT_ENABLED); + } + + @Nullable + public String manifestSortPartitionField() { + return options.get(MANIFEST_SORT_PARTITION_FIELD); + } + + public long manifestSortMaxRewriteSize() { + return options.get(MANIFEST_SORT_MAX_REWRITE_SIZE).getBytes(); + } + public String partitionDefaultName() { return options.get(PARTITION_DEFAULT_NAME); } @@ -2740,6 +2844,18 @@ public List fieldNestedUpdateAggNestedKey(String fieldName) { return Arrays.stream(keyString.split(",")).map(String::trim).collect(Collectors.toList()); } + public List fieldNestedUpdateAggNestedSequenceField(String fieldName) { + String keyString = + options.get( + key(FIELDS_PREFIX + "." + fieldName + "." + NESTED_SEQUENCE_FIELD) + .stringType() + .noDefaultValue()); + if (keyString == null) { + return Collections.emptyList(); + } + return Arrays.stream(keyString.split(",")).map(String::trim).collect(Collectors.toList()); + } + public int fieldNestedUpdateAggCountLimit(String fieldName) { return options.get( key(FIELDS_PREFIX + "." + fieldName + "." 
+ COUNT_LIMIT) @@ -3401,6 +3517,10 @@ public boolean sequenceFieldSortOrderIsAscending() { return options.get(SEQUENCE_FIELD_SORT_ORDER) == SortOrder.ASCENDING; } + public boolean snapshotSequenceOrdering() { + return options.get(SEQUENCE_SNAPSHOT_ORDERING); + } + public Optional rowkindField() { return options.getOptional(ROWKIND_FIELD); } @@ -3730,6 +3850,14 @@ public boolean dataEvolutionEnabled() { return options.get(DATA_EVOLUTION_ENABLED); } + public boolean dataEvolutionMergeIntoFilePruning() { + return options.get(DATA_EVOLUTION_MERGE_INTO_FILE_PRUNING); + } + + public boolean dataEvolutionMergeIntoSourcePersist() { + return options.get(DATA_EVOLUTION_MERGE_INTO_SOURCE_PERSIST); + } + public boolean blobCompactionEnabled() { return options.get(BLOB_COMPACTION_ENABLED); } diff --git a/paimon-api/src/main/java/org/apache/paimon/rest/RESTApi.java b/paimon-api/src/main/java/org/apache/paimon/rest/RESTApi.java index 7433a30388f8..263c4e2c0640 100644 --- a/paimon-api/src/main/java/org/apache/paimon/rest/RESTApi.java +++ b/paimon-api/src/main/java/org/apache/paimon/rest/RESTApi.java @@ -49,7 +49,6 @@ import org.apache.paimon.rest.requests.ListPartitionsByNamesRequest; import org.apache.paimon.rest.requests.MarkDonePartitionsRequest; import org.apache.paimon.rest.requests.RegisterTableRequest; -import org.apache.paimon.rest.requests.RenameBranchRequest; import org.apache.paimon.rest.requests.RenameTableRequest; import org.apache.paimon.rest.requests.ReplaceTableRequest; import org.apache.paimon.rest.requests.ResetConsumerRequest; @@ -985,27 +984,6 @@ public void dropBranch(Identifier identifier, String branch) { restAuthFunction); } - /** - * Rename branch for table. - * - * @param identifier database name and table name. 
- * @param fromBranch source branch name - * @param toBranch target branch name - * @throws NoSuchResourceException Exception thrown on HTTP 404 means the branch not exists - * @throws AlreadyExistsException Exception thrown on HTTP 409 means the target branch already - * exists - * @throws ForbiddenException Exception thrown on HTTP 403 means don't have the permission for - * this table - */ - public void renameBranch(Identifier identifier, String fromBranch, String toBranch) { - RenameBranchRequest request = new RenameBranchRequest(toBranch); - client.post( - resourcePaths.renameBranch( - identifier.getDatabaseName(), identifier.getObjectName(), fromBranch), - request, - restAuthFunction); - } - /** * Forward branch for table. * diff --git a/paimon-api/src/main/java/org/apache/paimon/rest/ResourcePaths.java b/paimon-api/src/main/java/org/apache/paimon/rest/ResourcePaths.java index 66a1653232b8..28f79d040995 100644 --- a/paimon-api/src/main/java/org/apache/paimon/rest/ResourcePaths.java +++ b/paimon-api/src/main/java/org/apache/paimon/rest/ResourcePaths.java @@ -273,19 +273,6 @@ public String forwardBranch(String databaseName, String tableName, String branch "forward"); } - public String renameBranch(String databaseName, String tableName, String branch) { - return SLASH.join( - V1, - prefix, - DATABASES, - encodeString(databaseName), - TABLES, - encodeString(tableName), - BRANCHES, - encodeString(branch), - "rename"); - } - public String tags(String databaseName, String objectName) { return SLASH.join( V1, diff --git a/paimon-api/src/main/java/org/apache/paimon/rest/requests/RenameBranchRequest.java b/paimon-api/src/main/java/org/apache/paimon/rest/requests/RenameBranchRequest.java deleted file mode 100644 index 63cf0011fac0..000000000000 --- a/paimon-api/src/main/java/org/apache/paimon/rest/requests/RenameBranchRequest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.paimon.rest.requests; - -import org.apache.paimon.rest.RESTRequest; - -import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonCreator; -import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonGetter; -import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonProperty; - -/** Request for renaming branch. 
*/ -@JsonIgnoreProperties(ignoreUnknown = true) -public class RenameBranchRequest implements RESTRequest { - - private static final String FIELD_TO_BRANCH = "toBranch"; - - @JsonProperty(FIELD_TO_BRANCH) - private final String toBranch; - - @JsonCreator - public RenameBranchRequest(@JsonProperty(FIELD_TO_BRANCH) String toBranch) { - this.toBranch = toBranch; - } - - @JsonGetter(FIELD_TO_BRANCH) - public String toBranch() { - return toBranch; - } -} diff --git a/paimon-arrow/src/main/java/org/apache/paimon/arrow/ArrowFieldTypeConversion.java b/paimon-arrow/src/main/java/org/apache/paimon/arrow/ArrowFieldTypeConversion.java index 0a72e89304cd..80d9208053b8 100644 --- a/paimon-arrow/src/main/java/org/apache/paimon/arrow/ArrowFieldTypeConversion.java +++ b/paimon-arrow/src/main/java/org/apache/paimon/arrow/ArrowFieldTypeConversion.java @@ -49,8 +49,6 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; -import java.time.ZoneId; - /** Utils for conversion between Paimon {@link DataType} and Arrow {@link FieldType}. 
*/ public class ArrowFieldTypeConversion { @@ -148,9 +146,7 @@ public FieldType visit(TimestampType timestampType) { public FieldType visit(LocalZonedTimestampType localZonedTimestampType) { int precision = localZonedTimestampType.getPrecision(); TimeUnit timeUnit = getTimeUnit(precision); - ArrowType arrowType = - new ArrowType.Timestamp( - timeUnit, ZoneId.systemDefault().normalized().toString()); + ArrowType arrowType = new ArrowType.Timestamp(timeUnit, "UTC"); return new FieldType(localZonedTimestampType.isNullable(), arrowType, null); } diff --git a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java index 9b0333f37647..5e7dc9539548 100644 --- a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java +++ b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java @@ -51,6 +51,8 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -693,4 +695,23 @@ private byte[] randomBytes(int minLength, int maxLength) { } return bytes; } + + @Test + public void testTimestampArrowFieldTypeTimezone() { + for (int precision : new int[] {0, 3, 6, 9}) { + // TIMESTAMP_LTZ should use UTC + FieldType ltzFieldType = + DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(precision) + .accept(ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR); + ArrowType.Timestamp ltzType = (ArrowType.Timestamp) ltzFieldType.getType(); + assertThat(ltzType.getTimezone()).isEqualTo("UTC"); + + // TIMESTAMP should have no timezone + FieldType tsFieldType = + DataTypes.TIMESTAMP(precision) + 
.accept(ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR); + ArrowType.Timestamp tsType = (ArrowType.Timestamp) tsFieldType.getType(); + assertThat(tsType.getTimezone()).isNull(); + } + } } diff --git a/paimon-common/src/main/java/org/apache/paimon/data/BlobViewResolver.java b/paimon-common/src/main/java/org/apache/paimon/data/BlobViewResolver.java index c2ea172e70b3..71a6da110170 100644 --- a/paimon-common/src/main/java/org/apache/paimon/data/BlobViewResolver.java +++ b/paimon-common/src/main/java/org/apache/paimon/data/BlobViewResolver.java @@ -24,4 +24,8 @@ public interface BlobViewResolver extends Serializable { void resolve(BlobView blobView); + + default boolean resolvesToNull(BlobView blobView) { + return false; + } } diff --git a/paimon-common/src/main/java/org/apache/paimon/data/BlobViewStruct.java b/paimon-common/src/main/java/org/apache/paimon/data/BlobViewStruct.java index f16daa921039..b5a98468a611 100644 --- a/paimon-common/src/main/java/org/apache/paimon/data/BlobViewStruct.java +++ b/paimon-common/src/main/java/org/apache/paimon/data/BlobViewStruct.java @@ -26,6 +26,7 @@ import java.util.Objects; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.apache.paimon.catalog.Identifier.UNKNOWN_DATABASE; /** * Serialized metadata for a BLOB view field. 
@@ -64,6 +65,11 @@ public long rowId() { } public byte[] serialize() { + if (UNKNOWN_DATABASE.equals(identifier.getDatabaseName())) { + throw new IllegalArgumentException( + "Blob view upstream table identifier must include database name: " + + identifier.getFullName()); + } byte[] identifierBytes = identifier.getFullName().getBytes(UTF_8); int totalSize = 1 + 8 + 4 + identifierBytes.length + 4 + 8; diff --git a/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexFileRecordIterator.java b/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexFileRecordIterator.java index eec931d3e98f..240d92fb5fb3 100644 --- a/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexFileRecordIterator.java +++ b/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexFileRecordIterator.java @@ -26,6 +26,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; /** * A {@link FileRecordIterator} wraps a {@link FileRecordIterator} and {@link BitmapIndexResult}. 
@@ -35,12 +36,16 @@ public class ApplyBitmapIndexFileRecordIterator implements FileRecordIterator iterator; private final RoaringBitmap32 bitmap; private final int last; + private final AtomicBoolean exhausted; - public ApplyBitmapIndexFileRecordIterator( - FileRecordIterator iterator, BitmapIndexResult fileIndexResult) { + ApplyBitmapIndexFileRecordIterator( + FileRecordIterator iterator, + BitmapIndexResult fileIndexResult, + AtomicBoolean exhausted) { this.iterator = iterator; this.bitmap = fileIndexResult.get(); this.last = bitmap.last(); + this.exhausted = exhausted; } @Override @@ -63,9 +68,13 @@ public InternalRow next() throws IOException { } int position = (int) returnedPosition(); if (position > last) { + exhausted.set(true); return null; } if (bitmap.contains(position)) { + if (position >= last) { + exhausted.set(true); + } return next; } } diff --git a/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReader.java b/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReader.java index 3b1207c8bd6e..9b17bcc3cc1e 100644 --- a/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReader.java +++ b/paimon-common/src/main/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReader.java @@ -26,6 +26,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; /** A {@link RecordReader} which apply {@link BitmapIndexResult} to filter record. 
*/ public class ApplyBitmapIndexRecordReader implements FileRecordReader { @@ -34,6 +35,8 @@ public class ApplyBitmapIndexRecordReader implements FileRecordReader reader, BitmapIndexResult fileIndexResult) { this.reader = reader; @@ -43,12 +46,16 @@ public ApplyBitmapIndexRecordReader( @Nullable @Override public FileRecordIterator readBatch() throws IOException { + if (exhausted.get()) { + return null; + } + FileRecordIterator batch = reader.readBatch(); if (batch == null) { return null; } - return new ApplyBitmapIndexFileRecordIterator(batch, fileIndexResult); + return new ApplyBitmapIndexFileRecordIterator(batch, fileIndexResult, exhausted); } @Override diff --git a/paimon-common/src/main/java/org/apache/paimon/fs/cache/CachingSeekableInputStream.java b/paimon-common/src/main/java/org/apache/paimon/fs/cache/CachingSeekableInputStream.java index b81e98b5722a..024bf1e40982 100644 --- a/paimon-common/src/main/java/org/apache/paimon/fs/cache/CachingSeekableInputStream.java +++ b/paimon-common/src/main/java/org/apache/paimon/fs/cache/CachingSeekableInputStream.java @@ -21,13 +21,14 @@ import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.SeekableInputStream; +import org.apache.paimon.fs.VectoredReadable; import javax.annotation.Nullable; import java.io.IOException; /** A {@link SeekableInputStream} that caches reads at block granularity on local disk. 
*/ -public class CachingSeekableInputStream extends SeekableInputStream { +public class CachingSeekableInputStream extends SeekableInputStream implements VectoredReadable { private final FileIO fileIO; private final Path path; @@ -110,6 +111,36 @@ public int read(byte[] b, int off, int len) throws IOException { return totalRead; } + @Override + public int pread(long position, byte[] buffer, int offset, int length) throws IOException { + if (length == 0) { + return 0; + } + long end = Math.min(position + length, fileSize()); + if (position >= end) { + return -1; + } + + int blockSize = cache.blockSize(); + int totalRead = 0; + + while (position < end) { + int blockIndex = (int) (position / blockSize); + byte[] blockData = readBlock(blockIndex); + + long blockStart = (long) blockIndex * blockSize; + int startInBlock = (int) (position - blockStart); + int endInBlock = (int) Math.min(end - blockStart, blockData.length); + int bytesToCopy = endInBlock - startInBlock; + + System.arraycopy(blockData, startInBlock, buffer, offset + totalRead, bytesToCopy); + totalRead += bytesToCopy; + position += bytesToCopy; + } + + return totalRead; + } + private byte[] readBlock(int blockIndex) throws IOException { byte[] cached = cache.getBlock(path.toString(), blockIndex); if (cached != null) { @@ -120,14 +151,25 @@ private byte[] readBlock(int blockIndex) throws IOException { long offset = (long) blockIndex * blockSize; int readSize = (int) Math.min(blockSize, fileSize() - offset); - SeekableInputStream stream = getRemoteStream(); - stream.seek(offset); - byte[] data = readFully(stream, readSize); + byte[] data = readRemote(offset, readSize); cache.putBlock(path.toString(), blockIndex, data); return data; } + private byte[] readRemote(long offset, int size) throws IOException { + SeekableInputStream stream = getRemoteStream(); + if (stream instanceof VectoredReadable) { + byte[] buf = new byte[size]; + ((VectoredReadable) stream).preadFully(offset, buf, 0, size); + return buf; + } + 
synchronized (stream) { + stream.seek(offset); + return readFully(stream, size); + } + } + private SeekableInputStream getRemoteStream() throws IOException { if (remoteStream == null) { remoteStream = fileIO.newInputStream(path); diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexEvaluator.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexEvaluator.java index 2f07d97fca89..5343df611771 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexEvaluator.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexEvaluator.java @@ -39,28 +39,20 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; import java.util.function.IntFunction; import java.util.stream.Collectors; -import static org.apache.paimon.shade.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; - /** Predicate for filtering data using global indexes. */ public class GlobalIndexEvaluator implements Closeable { private final RowType rowType; private final IntFunction> readersFunction; private final Map> indexReadersCache; - private final ExecutorService executorService; public GlobalIndexEvaluator( - RowType rowType, - IntFunction> readersFunction, - @Nullable ExecutorService executorService) { + RowType rowType, IntFunction> readersFunction) { this.rowType = rowType; this.readersFunction = readersFunction; - this.executorService = - executorService == null ? 
newDirectExecutorService() : executorService; this.indexReadersCache = new ConcurrentHashMap<>(); } @@ -104,19 +96,7 @@ private CompletableFuture> visitLeafAsync(LeafPredic List>> readerFutures = new ArrayList<>(readers.size()); for (GlobalIndexReader reader : readers) { - readerFutures.add( - CompletableFuture.supplyAsync( - () -> { - synchronized (reader) { - Optional result = - predicate - .function() - .visit(reader, fieldRef, predicate.literals()); - result.ifPresent(GlobalIndexResult::results); - return result; - } - }, - executorService)); + readerFutures.add(predicate.function().visit(reader, fieldRef, predicate.literals())); } return CompletableFuture.allOf(readerFutures.toArray(new CompletableFuture[0])) @@ -147,7 +127,10 @@ private CompletableFuture> visitCompoundAsync( CompoundPredicate predicate) { List children = flattenChildren(predicate); List>> childFutures = - children.stream().map(this::visitAsync).collect(Collectors.toList()); + new ArrayList<>(children.size()); + for (Predicate child : children) { + childFutures.add(visitAsync(child)); + } return CompletableFuture.allOf(childFutures.toArray(new CompletableFuture[0])) .thenApply( diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexReader.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexReader.java index f639273575f3..b16ce888af50 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexReader.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexReader.java @@ -26,30 +26,37 @@ import java.io.Closeable; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; /** Index reader for global index, return {@link GlobalIndexResult}. 
*/ -public interface GlobalIndexReader extends FunctionVisitor>, Closeable { +public interface GlobalIndexReader + extends FunctionVisitor>>, Closeable { @Override - default Optional visitAnd(List> children) { + default CompletableFuture> visitAnd( + List>> children) { throw new UnsupportedOperationException(); } @Override - default Optional visitOr(List> children) { + default CompletableFuture> visitOr( + List>> children) { throw new UnsupportedOperationException(); } @Override - default Optional visitNonFieldLeaf(LeafPredicate predicate) { + default CompletableFuture> visitNonFieldLeaf( + LeafPredicate predicate) { throw new UnsupportedOperationException(); } - default Optional visitVectorSearch(VectorSearch vectorSearch) { + default CompletableFuture> visitVectorSearch( + VectorSearch vectorSearch) { throw new UnsupportedOperationException(); } - default Optional visitFullTextSearch(FullTextSearch fullTextSearch) { + default CompletableFuture> visitFullTextSearch( + FullTextSearch fullTextSearch) { throw new UnsupportedOperationException(); } } diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResult.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResult.java index d9dd61139d31..b92990a839a0 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResult.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResult.java @@ -18,12 +18,10 @@ package org.apache.paimon.globalindex; -import org.apache.paimon.utils.LazyField; import org.apache.paimon.utils.Range; import org.apache.paimon.utils.RoaringNavigableMap64; import java.util.List; -import java.util.function.Supplier; /** Global index result represents row ids as a compressed bitmap. 
*/ public interface GlobalIndexResult { @@ -41,56 +39,39 @@ default GlobalIndexResult offset(long startOffset) { for (long rowId : roaringNavigableMap64) { roaringNavigableMap64Offset.add(rowId + startOffset); } - return create(() -> roaringNavigableMap64Offset); + return create(roaringNavigableMap64Offset); } - /** - * Returns the intersection of this result and the other result. - * - *

    Uses native bitmap AND operation for optimal performance. - */ default GlobalIndexResult and(GlobalIndexResult other) { - return create(() -> RoaringNavigableMap64.and(this.results(), other.results())); + return create(RoaringNavigableMap64.and(this.results(), other.results())); } - /** - * Returns the union of this result and the other result. - * - *

    Uses native bitmap OR operation for optimal performance. - */ default GlobalIndexResult or(GlobalIndexResult other) { - return create(() -> RoaringNavigableMap64.or(this.results(), other.results())); + return create(RoaringNavigableMap64.or(this.results(), other.results())); } /** Returns an empty {@link GlobalIndexResult}. */ static GlobalIndexResult createEmpty() { - return create(RoaringNavigableMap64::new); + return create(new RoaringNavigableMap64()); } - /** Returns a new {@link GlobalIndexResult} from supplier. */ - static GlobalIndexResult create(Supplier supplier) { - LazyField lazyField = new LazyField<>(supplier); - return lazyField::get; + /** Returns a new {@link GlobalIndexResult} from bitmap. */ + static GlobalIndexResult create(RoaringNavigableMap64 bitmap) { + return () -> bitmap; } /** Returns a new {@link GlobalIndexResult} from {@link Range}. */ static GlobalIndexResult fromRange(Range range) { - return create( - () -> { - RoaringNavigableMap64 result64 = new RoaringNavigableMap64(); - result64.addRange(range); - return result64; - }); + RoaringNavigableMap64 result64 = new RoaringNavigableMap64(); + result64.addRange(range); + return create(result64); } static GlobalIndexResult fromRanges(List ranges) { - return create( - () -> { - RoaringNavigableMap64 result64 = new RoaringNavigableMap64(); - for (Range range : ranges) { - result64.addRange(range); - } - return result64; - }); + RoaringNavigableMap64 result64 = new RoaringNavigableMap64(); + for (Range range : ranges) { + result64.addRange(range); + } + return create(result64); } } diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java index cf3fe6fb61e0..66a43b082d4a 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java +++ 
b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java @@ -94,7 +94,7 @@ public GlobalIndexResult deserialize(DataInputView dataInput) throws IOException int scoreSize = dataInput.readInt(); if (scoreSize == 0) { - return GlobalIndexResult.create(() -> roaringNavigableMap64); + return GlobalIndexResult.create(roaringNavigableMap64); } checkArgument( scoreSize == roaringNavigableMap64.getIntCardinality(), @@ -114,6 +114,6 @@ public GlobalIndexResult deserialize(DataInputView dataInput) throws IOException scoreMap.put(rowId, scores[i++]); } - return ScoredGlobalIndexResult.create(() -> roaringNavigableMap64, scoreMap::get); + return ScoredGlobalIndexResult.create(roaringNavigableMap64, scoreMap::get); } } diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexer.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexer.java index 33ea405a8f9f..74d223a60467 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexer.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexer.java @@ -25,14 +25,17 @@ import java.io.IOException; import java.util.List; +import java.util.concurrent.ExecutorService; /** Abstract base class for global indexers. 
*/ public interface GlobalIndexer { GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) throws IOException; - GlobalIndexReader createReader(GlobalIndexFileReader fileReader, List files) - throws IOException; + GlobalIndexReader createReader( + GlobalIndexFileReader fileReader, + List files, + ExecutorService executor); static GlobalIndexer create(String type, DataField dataField, Options options) { GlobalIndexerFactory globalIndexerFactory = GlobalIndexerFactoryUtils.load(type); diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/OffsetGlobalIndexReader.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/OffsetGlobalIndexReader.java index 0122060dfbe8..e2a03bca7684 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/OffsetGlobalIndexReader.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/OffsetGlobalIndexReader.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; /** * A {@link GlobalIndexReader} that wraps another reader and applies an offset to all row IDs in the @@ -43,89 +44,105 @@ public OffsetGlobalIndexReader(GlobalIndexReader wrapped, long offset, long to) } @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return applyOffset(wrapped.visitIsNotNull(fieldRef)); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return wrapped.visitIsNotNull(fieldRef).thenApply(this::applyOffset); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return applyOffset(wrapped.visitIsNull(fieldRef)); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return wrapped.visitIsNull(fieldRef).thenApply(this::applyOffset); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitStartsWith(fieldRef, literal)); + public CompletableFuture> visitStartsWith( + FieldRef fieldRef, Object literal) { 
+ return wrapped.visitStartsWith(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitEndsWith(fieldRef, literal)); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return wrapped.visitEndsWith(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitContains(fieldRef, literal)); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return wrapped.visitContains(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitLike(fieldRef, literal)); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return wrapped.visitLike(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitLessThan(fieldRef, literal)); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return wrapped.visitLessThan(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitGreaterOrEqual(fieldRef, literal)); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return wrapped.visitGreaterOrEqual(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitNotEqual(fieldRef, literal)); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return wrapped.visitNotEqual(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, 
Object literal) { - return applyOffset(wrapped.visitLessOrEqual(fieldRef, literal)); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return wrapped.visitLessOrEqual(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitEqual(fieldRef, literal)); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return wrapped.visitEqual(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return applyOffset(wrapped.visitGreaterThan(fieldRef, literal)); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return wrapped.visitGreaterThan(fieldRef, literal).thenApply(this::applyOffset); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return applyOffset(wrapped.visitIn(fieldRef, literals)); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return wrapped.visitIn(fieldRef, literals).thenApply(this::applyOffset); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return applyOffset(wrapped.visitNotIn(fieldRef, literals)); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return wrapped.visitNotIn(fieldRef, literals).thenApply(this::applyOffset); } @Override - public Optional visitBetween(FieldRef fieldRef, Object from, Object to) { - return applyOffset(wrapped.visitBetween(fieldRef, from, to)); + public CompletableFuture> visitBetween( + FieldRef fieldRef, Object from, Object to) { + return wrapped.visitBetween(fieldRef, from, to).thenApply(this::applyOffset); } @Override - public Optional visitVectorSearch(VectorSearch vectorSearch) { + public CompletableFuture> visitVectorSearch( + VectorSearch vectorSearch) { return wrapped.visitVectorSearch(vectorSearch.offsetRange(this.offset, 
this.to)) - .map(r -> r.offset(offset)); + .thenApply(opt -> opt.map(r -> r.offset(offset))); } @Override - public Optional visitFullTextSearch(FullTextSearch fullTextSearch) { - return wrapped.visitFullTextSearch(fullTextSearch).map(r -> r.offset(offset)); + public CompletableFuture> visitFullTextSearch( + FullTextSearch fullTextSearch) { + return wrapped.visitFullTextSearch(fullTextSearch) + .thenApply(opt -> opt.map(r -> r.offset(offset))); } private Optional applyOffset(Optional result) { diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/ScoredGlobalIndexResult.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/ScoredGlobalIndexResult.java index 0219155a1368..cc75b207a5bf 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/ScoredGlobalIndexResult.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/ScoredGlobalIndexResult.java @@ -18,12 +18,10 @@ package org.apache.paimon.globalindex; -import org.apache.paimon.utils.LazyField; import org.apache.paimon.utils.RoaringNavigableMap64; import java.util.Comparator; import java.util.PriorityQueue; -import java.util.function.Supplier; /** Vector search global index result for scored index. */ public interface ScoredGlobalIndexResult extends GlobalIndexResult { @@ -46,8 +44,7 @@ default ScoredGlobalIndexResult offset(long offset) { roaringNavigableMap64Offset.add(rowId + offset); } - return create( - () -> roaringNavigableMap64Offset, rowId -> thisScoreGetter.score(rowId - offset)); + return create(roaringNavigableMap64Offset, rowId -> thisScoreGetter.score(rowId - offset)); } @Override @@ -109,18 +106,16 @@ default ScoredGlobalIndexResult topK(int k) { topKRowIds.add(entry[0]); } - return ScoredGlobalIndexResult.create(() -> topKRowIds, scoreGetter); + return ScoredGlobalIndexResult.create(topKRowIds, scoreGetter); } /** Returns an empty {@link ScoredGlobalIndexResult}. 
*/ static ScoredGlobalIndexResult createEmpty() { - return create(RoaringNavigableMap64::new, rowId -> 0); + return create(new RoaringNavigableMap64(), rowId -> 0); } - /** Returns a new {@link ScoredGlobalIndexResult} from supplier. */ - static ScoredGlobalIndexResult create( - Supplier supplier, ScoreGetter scoreGetter) { - LazyField lazyField = new LazyField<>(supplier); + /** Returns a new {@link ScoredGlobalIndexResult} from bitmap. */ + static ScoredGlobalIndexResult create(RoaringNavigableMap64 bitmap, ScoreGetter scoreGetter) { return new ScoredGlobalIndexResult() { @Override public ScoreGetter scoreGetter() { @@ -129,7 +124,7 @@ public ScoreGetter scoreGetter() { @Override public RoaringNavigableMap64 results() { - return lazyField.get(); + return bitmap; } }; } diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/UnionGlobalIndexReader.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/UnionGlobalIndexReader.java index 23580b41716c..eca23818bd41 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/UnionGlobalIndexReader.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/UnionGlobalIndexReader.java @@ -22,10 +22,11 @@ import org.apache.paimon.predicate.VectorSearch; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import java.util.stream.Collectors; /** * A {@link GlobalIndexReader} that combines results from multiple readers by performing a union @@ -40,115 +41,144 @@ public UnionGlobalIndexReader(List readers) { } @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return union(reader -> reader.visitIsNotNull(fieldRef)); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return unionAsync(reader -> reader.visitIsNotNull(fieldRef)); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return union(reader -> 
reader.visitIsNull(fieldRef)); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return unionAsync(reader -> reader.visitIsNull(fieldRef)); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitStartsWith(fieldRef, literal)); + public CompletableFuture> visitStartsWith( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitStartsWith(fieldRef, literal)); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitEndsWith(fieldRef, literal)); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitEndsWith(fieldRef, literal)); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitContains(fieldRef, literal)); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitContains(fieldRef, literal)); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitLike(fieldRef, literal)); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitLike(fieldRef, literal)); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitLessThan(fieldRef, literal)); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitLessThan(fieldRef, literal)); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitGreaterOrEqual(fieldRef, literal)); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitGreaterOrEqual(fieldRef, literal)); } @Override - 
public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitNotEqual(fieldRef, literal)); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitNotEqual(fieldRef, literal)); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitLessOrEqual(fieldRef, literal)); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitLessOrEqual(fieldRef, literal)); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitEqual(fieldRef, literal)); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitEqual(fieldRef, literal)); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return union(reader -> reader.visitGreaterThan(fieldRef, literal)); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return unionAsync(reader -> reader.visitGreaterThan(fieldRef, literal)); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return union(reader -> reader.visitIn(fieldRef, literals)); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return unionAsync(reader -> reader.visitIn(fieldRef, literals)); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return union(reader -> reader.visitNotIn(fieldRef, literals)); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return unionAsync(reader -> reader.visitNotIn(fieldRef, literals)); } @Override - public Optional visitBetween(FieldRef fieldRef, Object from, Object to) { - return union(reader -> reader.visitBetween(fieldRef, from, to)); + public CompletableFuture> visitBetween( + FieldRef fieldRef, Object 
from, Object to) { + return unionAsync(reader -> reader.visitBetween(fieldRef, from, to)); } @Override - public Optional visitVectorSearch(VectorSearch vectorSearch) { - Optional result = Optional.empty(); - List> results = - executeAllReaders(reader -> reader.visitVectorSearch(vectorSearch)); - for (Optional current : results) { - if (!current.isPresent()) { - continue; - } - if (!result.isPresent()) { - result = current; - } - result = Optional.of(result.get().or(current.get())); + public CompletableFuture> visitVectorSearch( + VectorSearch vectorSearch) { + List>> futures = + new ArrayList<>(readers.size()); + for (GlobalIndexReader reader : readers) { + futures.add(reader.visitVectorSearch(vectorSearch)); } - return result; - } - - private Optional union( - Function> visitor) { - Optional result = Optional.empty(); - List> results = executeAllReaders(visitor); - for (Optional current : results) { - if (!current.isPresent()) { - continue; - } - if (!result.isPresent()) { - result = current; - } - result = Optional.of(result.get().or(current.get())); + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply( + v -> { + Optional result = Optional.empty(); + for (CompletableFuture> f : futures) { + Optional current = f.join(); + if (!current.isPresent()) { + continue; + } + if (!result.isPresent()) { + result = current; + } else { + result = Optional.of(result.get().or(current.get())); + } + } + return result; + }); + } + + private CompletableFuture> unionAsync( + Function>> visitor) { + List>> futures = + new ArrayList<>(readers.size()); + for (GlobalIndexReader reader : readers) { + futures.add(visitor.apply(reader)); } - return result; - } - - private List executeAllReaders(Function function) { - return readers.stream().map(function).collect(Collectors.toList()); + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply( + v -> { + Optional result = Optional.empty(); + for (CompletableFuture> f : 
futures) { + Optional current = f.join(); + if (!current.isPresent()) { + continue; + } + if (!result.isPresent()) { + result = current; + } else { + result = Optional.of(result.get().or(current.get())); + } + } + return result; + }); } @Override diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndex.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndex.java deleted file mode 100644 index 2f87f89d6eb5..000000000000 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndex.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.paimon.globalindex.bitmap; - -import org.apache.paimon.fileindex.FileIndexReader; -import org.apache.paimon.fileindex.FileIndexResult; -import org.apache.paimon.fileindex.FileIndexWriter; -import org.apache.paimon.fileindex.bitmap.BitmapFileIndex; -import org.apache.paimon.fileindex.bitmap.BitmapIndexResult; -import org.apache.paimon.fs.SeekableInputStream; -import org.apache.paimon.globalindex.GlobalIndexIOMeta; -import org.apache.paimon.globalindex.GlobalIndexReader; -import org.apache.paimon.globalindex.GlobalIndexResult; -import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; -import org.apache.paimon.globalindex.GlobalIndexer; -import org.apache.paimon.globalindex.io.GlobalIndexFileReader; -import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; -import org.apache.paimon.globalindex.wrap.FileIndexReaderWrapper; -import org.apache.paimon.globalindex.wrap.FileIndexWriterWrapper; - -import java.io.IOException; -import java.util.List; -import java.util.Optional; - -import static org.apache.paimon.utils.Preconditions.checkArgument; - -/** Bitmap global index. 
*/ -public class BitmapGlobalIndex implements GlobalIndexer { - - private final BitmapFileIndex index; - - public BitmapGlobalIndex(BitmapFileIndex index) { - this.index = index; - } - - @Override - public GlobalIndexSingletonWriter createWriter(GlobalIndexFileWriter fileWriter) - throws IOException { - FileIndexWriter writer = index.createWriter(); - return new FileIndexWriterWrapper( - fileWriter, writer, BitmapGlobalIndexerFactory.IDENTIFIER); - } - - public GlobalIndexReader createReader( - GlobalIndexFileReader fileReader, List files) throws IOException { - checkArgument(files.size() == 1); - GlobalIndexIOMeta indexMeta = files.get(0); - SeekableInputStream input = fileReader.getInputStream(indexMeta); - FileIndexReader reader = index.createReader(input, 0, (int) indexMeta.fileSize()); - return new FileIndexReaderWrapper(reader, this::toGlobalResult, input); - } - - private Optional toGlobalResult(FileIndexResult result) { - if (FileIndexResult.REMAIN == result) { - return Optional.empty(); - } else if (FileIndexResult.SKIP == result) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - BitmapIndexResult bitmapResult = (BitmapIndexResult) result; - return Optional.of(GlobalIndexResult.create(() -> bitmapResult.get().toNavigable64())); - } -} diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndexerFactory.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndexerFactory.java deleted file mode 100644 index 2150b3072ab0..000000000000 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/bitmap/BitmapGlobalIndexerFactory.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.paimon.globalindex.bitmap; - -import org.apache.paimon.fileindex.bitmap.BitmapFileIndex; -import org.apache.paimon.globalindex.GlobalIndexer; -import org.apache.paimon.globalindex.GlobalIndexerFactory; -import org.apache.paimon.options.Options; -import org.apache.paimon.types.DataField; - -/** Factory for creating bitmap global indexers. */ -public class BitmapGlobalIndexerFactory implements GlobalIndexerFactory { - - public static final String IDENTIFIER = "bitmap"; - - @Override - public String identifier() { - return IDENTIFIER; - } - - @Override - public GlobalIndexer create(DataField dataField, Options options) { - BitmapFileIndex bitmapFileIndex = new BitmapFileIndex(dataField.type(), options); - return new BitmapGlobalIndex(bitmapFileIndex); - } -} diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexer.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexer.java index 34110bc8a599..99d72317abaa 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexer.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexer.java @@ -32,6 +32,7 @@ import java.io.IOException; import java.util.List; +import java.util.concurrent.ExecutorService; /** * The {@link GlobalIndexer} for btree index. 
We do not build a B-tree directly in memory, instead, @@ -65,7 +66,6 @@ public class BTreeGlobalIndexer implements GlobalIndexer { public BTreeGlobalIndexer(DataField dataField, Options options) { this.keySerializer = KeySerializer.create(dataField.type()); this.options = options; - // todo: cacheManager can be null to disallow data cache. this.cacheManager = new LazyField<>( () -> @@ -92,7 +92,10 @@ public BTreeIndexWriter createWriter(GlobalIndexFileWriter fileWriter) throws IO @Override public GlobalIndexReader createReader( - GlobalIndexFileReader fileReader, List files) throws IOException { - return new LazyFilteredBTreeReader(files, keySerializer, fileReader, cacheManager.get()); + GlobalIndexFileReader fileReader, + List files, + ExecutorService executor) { + return new LazyFilteredBTreeReader( + files, keySerializer, fileReader, cacheManager.get(), executor); } } diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeIndexReader.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeIndexReader.java index 2d42b9d71154..69706c4b6628 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeIndexReader.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/BTreeIndexReader.java @@ -21,14 +21,12 @@ import org.apache.paimon.fs.Path; import org.apache.paimon.fs.SeekableInputStream; import org.apache.paimon.globalindex.GlobalIndexIOMeta; -import org.apache.paimon.globalindex.GlobalIndexReader; import org.apache.paimon.globalindex.GlobalIndexResult; import org.apache.paimon.globalindex.io.GlobalIndexFileReader; import org.apache.paimon.io.cache.CacheManager; import org.apache.paimon.memory.MemorySegment; import org.apache.paimon.memory.MemorySlice; import org.apache.paimon.memory.MemorySliceInput; -import org.apache.paimon.predicate.FieldRef; import org.apache.paimon.sst.BlockCache; import org.apache.paimon.sst.BlockHandle; import org.apache.paimon.sst.BlockIterator; @@ 
-40,6 +38,7 @@ import javax.annotation.Nullable; +import java.io.Closeable; import java.io.IOException; import java.util.Comparator; import java.util.List; @@ -49,8 +48,11 @@ import java.util.function.LongConsumer; import java.util.zip.CRC32; -/** The {@link GlobalIndexReader} implementation for btree index. */ -public class BTreeIndexReader implements GlobalIndexReader { +/** + * Synchronous index reader for a single BTree index file. Parallelism across multiple files is + * handled by {@link LazyFilteredBTreeReader}. + */ +public class BTreeIndexReader implements Closeable { private final SeekableInputStream input; private final SstFileReader reader; @@ -236,202 +238,98 @@ public void scanNullRowIds(LongConsumer consumer) { } } - @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - // nulls are stored separately in null bitmap. - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return allNonNullRows(); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitIsNotNull() { + return createResult(this::allNonNullRows); } - @Override - public Optional visitIsNull(FieldRef fieldRef) { - // nulls are stored separately in null bitmap. - return Optional.of(GlobalIndexResult.create(nullBitmap::get)); + public Optional visitIsNull() { + return createResult(nullBitmap::get); } - @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - // todo: `startsWith` can also be covered by btree index. 
- return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return allNonNullRows(); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitStartsWith(Object literal) { + return createResult(this::allNonNullRows); } - @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return allNonNullRows(); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitEndsWith(Object literal) { + return createResult(this::allNonNullRows); } - @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return allNonNullRows(); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitContains(Object literal) { + return createResult(this::allNonNullRows); } - @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return allNonNullRows(); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitLike(Object literal) { + return createResult(this::allNonNullRows); } - @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return rangeQuery(minKey, literal, true, false); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitLessThan(Object literal) { + return createResult(() -> rangeQuery(minKey, literal, true, false)); } - @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return Optional.of( - 
GlobalIndexResult.create( - () -> { - try { - return rangeQuery(literal, maxKey, true, true); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitGreaterOrEqual(Object literal) { + return createResult(() -> rangeQuery(literal, maxKey, true, true)); } - @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - RoaringNavigableMap64 result = allNonNullRows(); - result.andNot(rangeQuery(literal, literal, true, true)); - return result; - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitNotEqual(Object literal) { + return createResult( + () -> { + RoaringNavigableMap64 result = allNonNullRows(); + result.andNot(rangeQuery(literal, literal, true, true)); + return result; + }); } - @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return rangeQuery(minKey, literal, true, true); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitLessOrEqual(Object literal) { + return createResult(() -> rangeQuery(minKey, literal, true, true)); } - @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return rangeQuery(literal, literal, true, true); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitEqual(Object literal) { + return createResult(() -> rangeQuery(literal, literal, true, true)); } - @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return rangeQuery(literal, maxKey, false, 
true); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitGreaterThan(Object literal) { + return createResult(() -> rangeQuery(literal, maxKey, false, true)); } - @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - RoaringNavigableMap64 result = new RoaringNavigableMap64(); - for (Object literal : literals) { - result.or(rangeQuery(literal, literal, true, true)); - } - return result; - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitIn(List literals) { + return createResult( + () -> { + RoaringNavigableMap64 result = new RoaringNavigableMap64(); + for (Object literal : literals) { + result.or(rangeQuery(literal, literal, true, true)); + } + return result; + }); } - @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - RoaringNavigableMap64 result = allNonNullRows(); - result.andNot(this.visitIn(fieldRef, literals).get().results()); - return result; - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional visitNotIn(List literals) { + return createResult( + () -> { + RoaringNavigableMap64 result = allNonNullRows(); + RoaringNavigableMap64 inResult = new RoaringNavigableMap64(); + for (Object literal : literals) { + inResult.or(rangeQuery(literal, literal, true, true)); + } + result.andNot(inResult); + return result; + }); } - @Override - public Optional visitBetween(FieldRef fieldRef, Object from, Object to) { - return Optional.of( - GlobalIndexResult.create( - () -> { - try { - return rangeQuery(from, to, true, true); - } catch (IOException ioe) { - throw new RuntimeException("fail to read btree index file.", ioe); - } - })); + public Optional 
visitBetween(Object from, Object to) { + return createResult(() -> rangeQuery(from, to, true, true)); + } + + private Optional createResult(IOSupplier supplier) { + try { + return Optional.of(GlobalIndexResult.create(supplier.get())); + } catch (IOException e) { + throw new RuntimeException("fail to read btree index file.", e); + } + } + + @FunctionalInterface + private interface IOSupplier { + T get() throws IOException; } private RoaringNavigableMap64 allNonNullRows() throws IOException { diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/KeySerializer.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/KeySerializer.java index 9971fd27517d..a8c848e2b613 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/KeySerializer.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/KeySerializer.java @@ -41,9 +41,12 @@ import org.apache.paimon.types.TinyIntType; import org.apache.paimon.types.VarCharType; +import javax.annotation.concurrent.ThreadSafe; + import java.util.Comparator; /** This interface provides core methods to ser/de and compare btree index keys. */ +@ThreadSafe public interface KeySerializer { byte[] serialize(Object key); @@ -136,11 +139,10 @@ public KeySerializer visit(LocalZonedTimestampType localZonedTimestampType) { /** Serializer for int type. */ class IntSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(4); @Override public byte[] serialize(Object key) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(4); keyOut.writeInt((Integer) key); return keyOut.toSlice().copyBytes(); } @@ -158,11 +160,10 @@ public Comparator createComparator() { /** Serializer for long type. 
*/ class BigIntSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(8); @Override public byte[] serialize(Object key) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(8); keyOut.writeLong((Long) key); return keyOut.toSlice().copyBytes(); } @@ -199,11 +200,10 @@ public Comparator createComparator() { /** Serializer for small int type. */ class SmallIntSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(2); @Override public byte[] serialize(Object key) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(2); keyOut.writeShort((Short) key); return keyOut.toSlice().copyBytes(); } @@ -240,11 +240,10 @@ public Comparator createComparator() { /** Serializer for float type. */ class FloatSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(4); @Override public byte[] serialize(Object key) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(4); keyOut.writeInt(Float.floatToIntBits((Float) key)); return keyOut.toSlice().copyBytes(); } @@ -262,11 +261,10 @@ public Comparator createComparator() { /** Serializer for double type. */ class DoubleSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(8); @Override public byte[] serialize(Object key) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(8); keyOut.writeLong(Double.doubleToLongBits((Double) key)); return keyOut.toSlice().copyBytes(); } @@ -284,7 +282,6 @@ public Comparator createComparator() { /** Serializer for decimal type. 
*/ class DecimalSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(8); private final int precision; private final int scale; @@ -296,7 +293,7 @@ public DecimalSerializer(int precision, int scale) { @Override public byte[] serialize(Object key) { if (Decimal.isCompact(precision)) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(8); keyOut.writeLong(((Decimal) key).toUnscaledLong()); return keyOut.toSlice().copyBytes(); } @@ -338,7 +335,6 @@ public Comparator createComparator() { /** Serializer for timestamp. */ class TimestampSerializer implements KeySerializer { - private final MemorySliceOutput keyOut = new MemorySliceOutput(8); private final int precision; public TimestampSerializer(int precision) { @@ -347,7 +343,7 @@ public TimestampSerializer(int precision) { @Override public byte[] serialize(Object key) { - keyOut.reset(); + MemorySliceOutput keyOut = new MemorySliceOutput(12); if (Timestamp.isCompact(precision)) { keyOut.writeLong(((Timestamp) key).getMillisecond()); } else { diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeReader.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeReader.java index f5badf1a70a6..85eb93500e9b 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeReader.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeReader.java @@ -22,274 +22,220 @@ import org.apache.paimon.globalindex.GlobalIndexIOMeta; import org.apache.paimon.globalindex.GlobalIndexReader; import org.apache.paimon.globalindex.GlobalIndexResult; -import org.apache.paimon.globalindex.UnionGlobalIndexReader; import org.apache.paimon.globalindex.io.GlobalIndexFileReader; import org.apache.paimon.io.cache.CacheManager; import org.apache.paimon.predicate.FieldRef; import java.io.IOException; import java.util.ArrayList; -import 
java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.function.Function; +import java.util.function.Supplier; /** - * An Index Reader for BTree which dynamically filters file list by input predicate, then merge the - * result by an {@link org.apache.paimon.globalindex.UnionGlobalIndexReader}. In the ideal situation - * such as visiting an Equal predicate, only a very few files would be actually read. + * An Index Reader for BTree which dynamically filters file list by input predicate, then visits + * each selected file in parallel via an executor. Each index file is synchronized independently to + * allow maximum concurrency. */ public class LazyFilteredBTreeReader implements GlobalIndexReader { private final BTreeFileMetaSelector fileSelector; - private final Map readerCache; + private final Map readerCache; private final KeySerializer keySerializer; private final CacheManager cacheManager; private final GlobalIndexFileReader fileReader; + private final ExecutorService executor; public LazyFilteredBTreeReader( List files, KeySerializer keySerializer, GlobalIndexFileReader fileReader, - CacheManager cacheManager) { + CacheManager cacheManager, + ExecutorService executor) { this.fileSelector = new BTreeFileMetaSelector(files, keySerializer); - this.readerCache = new HashMap<>(); + this.readerCache = new ConcurrentHashMap<>(); this.cacheManager = cacheManager; this.fileReader = fileReader; this.keySerializer = keySerializer; + this.executor = executor; } @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - Optional> selectedOpt = fileSelector.visitIsNotNull(fieldRef); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - 
return createUnionReader(selected).visitIsNotNull(fieldRef); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return visitParallel( + () -> fileSelector.visitIsNotNull(fieldRef), BTreeIndexReader::visitIsNotNull); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - Optional> selectedOpt = fileSelector.visitIsNull(fieldRef); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitIsNull(fieldRef); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return visitParallel( + () -> fileSelector.visitIsNull(fieldRef), BTreeIndexReader::visitIsNull); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitStartsWith(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitStartsWith(fieldRef, literal); + public CompletableFuture> visitStartsWith( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitStartsWith(fieldRef, literal), + reader -> reader.visitStartsWith(literal)); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitEndsWith(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitEndsWith(fieldRef, literal); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitEndsWith(fieldRef, literal), + reader -> 
reader.visitEndsWith(literal)); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitContains(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitContains(fieldRef, literal); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitContains(fieldRef, literal), + reader -> reader.visitContains(literal)); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = fileSelector.visitLike(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitLike(fieldRef, literal); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitLike(fieldRef, literal), + reader -> reader.visitLike(literal)); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitLessThan(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitLessThan(fieldRef, literal); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitLessThan(fieldRef, literal), + reader -> reader.visitLessThan(literal)); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - 
fileSelector.visitGreaterOrEqual(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitGreaterOrEqual(fieldRef, literal); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitGreaterOrEqual(fieldRef, literal), + reader -> reader.visitGreaterOrEqual(literal)); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitNotEqual(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitNotEqual(fieldRef, literal); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitNotEqual(fieldRef, literal), + reader -> reader.visitNotEqual(literal)); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitLessOrEqual(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitLessOrEqual(fieldRef, literal); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitLessOrEqual(fieldRef, literal), + reader -> reader.visitLessOrEqual(literal)); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = fileSelector.visitEqual(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return 
Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitEqual(fieldRef, literal); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitEqual(fieldRef, literal), + reader -> reader.visitEqual(literal)); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - Optional> selectedOpt = - fileSelector.visitGreaterThan(fieldRef, literal); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitGreaterThan(fieldRef, literal); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return visitParallel( + () -> fileSelector.visitGreaterThan(fieldRef, literal), + reader -> reader.visitGreaterThan(literal)); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - Optional> selectedOpt = fileSelector.visitIn(fieldRef, literals); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return createUnionReader(selected).visitIn(fieldRef, literals); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return visitParallel( + () -> fileSelector.visitIn(fieldRef, literals), reader -> reader.visitIn(literals)); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - Optional> selectedOpt = fileSelector.visitNotIn(fieldRef, literals); - if (!selectedOpt.isPresent()) { - return Optional.empty(); - } - List selected = selectedOpt.get(); - if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); - } - return 
createUnionReader(selected).visitNotIn(fieldRef, literals); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return visitParallel( + () -> fileSelector.visitNotIn(fieldRef, literals), + reader -> reader.visitNotIn(literals)); } @Override - public Optional visitBetween(FieldRef fieldRef, Object from, Object to) { - Optional> selectedOpt = - fileSelector.visitBetween(fieldRef, from, to); + public CompletableFuture> visitBetween( + FieldRef fieldRef, Object from, Object to) { + return visitParallel( + () -> fileSelector.visitBetween(fieldRef, from, to), + reader -> reader.visitBetween(from, to)); + } + + private CompletableFuture> visitParallel( + Supplier>> selector, + Function> visitor) { + Optional> selectedOpt = selector.get(); if (!selectedOpt.isPresent()) { - return Optional.empty(); + return CompletableFuture.completedFuture(Optional.empty()); } List selected = selectedOpt.get(); if (selected.isEmpty()) { - return Optional.of(GlobalIndexResult.createEmpty()); + return CompletableFuture.completedFuture(Optional.of(GlobalIndexResult.createEmpty())); + } + + List>> futures = + new ArrayList<>(selected.size()); + for (GlobalIndexIOMeta meta : selected) { + futures.add( + CompletableFuture.supplyAsync( + () -> visitor.apply(getOrCreateReader(meta)), executor)); + } + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply(v -> unionResults(futures)); + } + + private Optional unionResults( + List>> futures) { + Optional result = Optional.empty(); + for (CompletableFuture> future : futures) { + Optional current = future.join(); + if (!current.isPresent()) { + continue; + } + if (!result.isPresent()) { + result = current; + } else { + result = Optional.of(result.get().or(current.get())); + } } - return createUnionReader(selected).visitBetween(fieldRef, from, to); + return result; + } + + private BTreeIndexReader getOrCreateReader(GlobalIndexIOMeta meta) { + return readerCache.computeIfAbsent(meta.filePath(), 
k -> createBTreeReader(meta)); } - /** - * Create a Union Reader for given files. The union reader is composed by readers from reader - * cache, so please do not close it. - */ - private UnionGlobalIndexReader createUnionReader(List files) { - List readers = new ArrayList<>(); - for (GlobalIndexIOMeta meta : files) { - readers.add( - readerCache.computeIfAbsent( - meta.filePath(), - name -> { - try { - return new BTreeIndexReader( - keySerializer, fileReader, meta, cacheManager); - } catch (IOException e) { - throw new RuntimeException( - "Can't create BTree index reader for " + name, e); - } - })); + private BTreeIndexReader createBTreeReader(GlobalIndexIOMeta meta) { + try { + return new BTreeIndexReader(keySerializer, fileReader, meta, cacheManager); + } catch (IOException e) { + throw new RuntimeException("Can't create BTree index reader for " + meta.filePath(), e); } - return new UnionGlobalIndexReader(readers); } @Override public void close() throws IOException { IOException exception = null; - for (Map.Entry entry : this.readerCache.entrySet()) { + for (Map.Entry entry : this.readerCache.entrySet()) { try { entry.getValue().close(); } catch (IOException ioe) { diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexReaderWrapper.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexReaderWrapper.java deleted file mode 100644 index 2fb8fe30655c..000000000000 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexReaderWrapper.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.paimon.globalindex.wrap; - -import org.apache.paimon.fileindex.FileIndexReader; -import org.apache.paimon.fileindex.FileIndexResult; -import org.apache.paimon.globalindex.GlobalIndexReader; -import org.apache.paimon.globalindex.GlobalIndexResult; -import org.apache.paimon.predicate.FieldRef; - -import java.io.Closeable; -import java.io.IOException; -import java.util.List; -import java.util.Optional; -import java.util.function.Function; - -/** A {@link GlobalIndexReader} wrapper for {@link FileIndexReader}. */ -public class FileIndexReaderWrapper implements GlobalIndexReader { - - private final FileIndexReader reader; - private final Function> transform; - private final Closeable closeable; - - public FileIndexReaderWrapper( - FileIndexReader reader, - Function> transform, - Closeable closeable) { - this.reader = reader; - this.transform = transform; - this.closeable = closeable; - } - - @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return transform.apply(reader.visitIsNotNull(fieldRef)); - } - - @Override - public Optional visitIsNull(FieldRef fieldRef) { - return transform.apply(reader.visitIsNull(fieldRef)); - } - - @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitStartsWith(fieldRef, literal)); - } - - @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitEndsWith(fieldRef, literal)); - } - - @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - 
return transform.apply(reader.visitContains(fieldRef, literal)); - } - - @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitLike(fieldRef, literal)); - } - - @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitLessThan(fieldRef, literal)); - } - - @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitGreaterOrEqual(fieldRef, literal)); - } - - @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitNotEqual(fieldRef, literal)); - } - - @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitLessOrEqual(fieldRef, literal)); - } - - @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitEqual(fieldRef, literal)); - } - - @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return transform.apply(reader.visitGreaterThan(fieldRef, literal)); - } - - @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return transform.apply(reader.visitIn(fieldRef, literals)); - } - - @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return transform.apply(reader.visitNotIn(fieldRef, literals)); - } - - @Override - public void close() throws IOException { - closeable.close(); - } -} diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexWriterWrapper.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexWriterWrapper.java deleted file mode 100644 index 536da1ac8c5a..000000000000 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/wrap/FileIndexWriterWrapper.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more 
contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.paimon.globalindex.wrap; - -import org.apache.paimon.fileindex.FileIndexWriter; -import org.apache.paimon.globalindex.GlobalIndexReader; -import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; -import org.apache.paimon.globalindex.ResultEntry; -import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; - -import java.io.OutputStream; -import java.util.Collections; -import java.util.List; - -/** A {@link GlobalIndexReader} wrapper for {@link FileIndexWriter}. 
*/ -public class FileIndexWriterWrapper implements GlobalIndexSingletonWriter { - - private final GlobalIndexFileWriter fileWriter; - private final FileIndexWriter writer; - private final String indexType; - private long count = 0; - - public FileIndexWriterWrapper( - GlobalIndexFileWriter fileWriter, FileIndexWriter writer, String indexType) { - this.fileWriter = fileWriter; - this.writer = writer; - this.indexType = indexType; - } - - @Override - public void write(Object key) { - count++; - writer.write(key); - } - - @Override - public List finish() { - if (count > 0) { - String fileName = fileWriter.newFileName(indexType); - try (OutputStream outputStream = fileWriter.newOutputStream(fileName)) { - outputStream.write(writer.serializedBytes()); - } catch (Exception e) { - throw new RuntimeException("Failed to write global index file: " + fileName, e); - } - return Collections.singletonList(new ResultEntry(fileName, count, null)); - } else { - return Collections.emptyList(); - } - } -} diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/FullTextSearch.java b/paimon-common/src/main/java/org/apache/paimon/predicate/FullTextSearch.java index 5b91d0bbc6d2..d25225ec69f7 100644 --- a/paimon-common/src/main/java/org/apache/paimon/predicate/FullTextSearch.java +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/FullTextSearch.java @@ -18,11 +18,7 @@ package org.apache.paimon.predicate; -import org.apache.paimon.globalindex.GlobalIndexReader; -import org.apache.paimon.globalindex.ScoredGlobalIndexResult; - import java.io.Serializable; -import java.util.Optional; /** FullTextSearch to perform full-text search on a text column. 
*/ public class FullTextSearch implements Serializable { @@ -32,8 +28,13 @@ public class FullTextSearch implements Serializable { private final String queryText; private final String fieldName; private final int limit; + private final String queryOperator; public FullTextSearch(String queryText, int limit, String fieldName) { + this(queryText, limit, fieldName, "or"); + } + + public FullTextSearch(String queryText, int limit, String fieldName, String queryOperator) { if (queryText == null || queryText.isEmpty()) { throw new IllegalArgumentException("Query text cannot be null or empty"); } @@ -46,6 +47,13 @@ public FullTextSearch(String queryText, int limit, String fieldName) { this.queryText = queryText; this.limit = limit; this.fieldName = fieldName; + String normalizedOperator = + queryOperator == null ? "or" : queryOperator.trim().toLowerCase(); + if (!"or".equals(normalizedOperator) && !"and".equals(normalizedOperator)) { + throw new IllegalArgumentException( + "Query operator must be 'or' or 'and', got: " + queryOperator); + } + this.queryOperator = normalizedOperator; } public String queryText() { @@ -60,13 +68,14 @@ public String fieldName() { return fieldName; } - public Optional visit(GlobalIndexReader visitor) { - return visitor.visitFullTextSearch(this); + public String queryOperator() { + return queryOperator; } @Override public String toString() { return String.format( - "FullTextSearch{field=%s, query='%s', limit=%d}", fieldName, queryText, limit); + "FullTextSearch{field=%s, query='%s', limit=%d, operator=%s}", + fieldName, queryText, limit, queryOperator); } } diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/VectorSearch.java b/paimon-common/src/main/java/org/apache/paimon/predicate/VectorSearch.java index d4bb6ed3e5dd..5e660ed17fe0 100644 --- a/paimon-common/src/main/java/org/apache/paimon/predicate/VectorSearch.java +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/VectorSearch.java @@ -18,15 +18,12 @@ package 
org.apache.paimon.predicate; -import org.apache.paimon.globalindex.GlobalIndexReader; -import org.apache.paimon.globalindex.ScoredGlobalIndexResult; import org.apache.paimon.utils.Range; import org.apache.paimon.utils.RoaringNavigableMap64; import javax.annotation.Nullable; import java.io.Serializable; -import java.util.Optional; /** VectorSearch to perform vector similarity search. * */ public class VectorSearch implements Serializable { @@ -91,10 +88,6 @@ public VectorSearch offsetRange(long from, long to) { return this; } - public Optional visit(GlobalIndexReader visitor) { - return visitor.visitVectorSearch(this); - } - @Override public String toString() { return String.format("FieldName(%s), Limit(%s)", fieldName, limit); diff --git a/paimon-common/src/main/java/org/apache/paimon/reader/DataEvolutionArray.java b/paimon-common/src/main/java/org/apache/paimon/reader/DataEvolutionArray.java index 8b89c1ce7f8a..0995e84e96ed 100644 --- a/paimon-common/src/main/java/org/apache/paimon/reader/DataEvolutionArray.java +++ b/paimon-common/src/main/java/org/apache/paimon/reader/DataEvolutionArray.java @@ -31,14 +31,31 @@ /** The array which is made up by several rows. */ public class DataEvolutionArray implements InternalArray { + /** Sentinel for "no fallback"; positions with rowOffsets[pos] < 0 stay null. */ + public static final long NO_MISSING_FIELD_FALLBACK = Long.MIN_VALUE; + private final InternalArray[] rows; private final int[] rowOffsets; private final int[] fieldOffsets; + /** + * Value to return from getLong(pos) when {@code rowOffsets[pos] < 0}. Used by data-evolution + * null-count arrays to encode "field not physically present in any file in the group" as "all + * rowCount rows are null" instead of "unknown stats". {@link #NO_MISSING_FIELD_FALLBACK} + * disables the fallback. 
+ */ + private final long missingFieldLong; + public DataEvolutionArray(int rowNumber, int[] rowOffsets, int[] fieldOffsets) { + this(rowNumber, rowOffsets, fieldOffsets, NO_MISSING_FIELD_FALLBACK); + } + + public DataEvolutionArray( + int rowNumber, int[] rowOffsets, int[] fieldOffsets, long missingFieldLong) { this.rows = new InternalArray[rowNumber]; this.rowOffsets = rowOffsets; this.fieldOffsets = fieldOffsets; + this.missingFieldLong = missingFieldLong; } public void setRow(int pos, InternalArray row) { @@ -73,6 +90,16 @@ private int offsetInRow(int pos) { @Override public boolean isNullAt(int pos) { + // rowOffsets[pos] == -1: field is absent from every file in the group, so every + // logical row is null for it; with missingFieldLong set this is encoded as a + // known count rather than "unknown stats" (isNullAt=false), so non-IS-NULL + // predicates can prune the file. + // rowOffsets[pos] == -2: field exists in a file but its stats were not captured + // (e.g. valueStatsCols did not include it). Treat as unknown stats so callers + // stay conservative. 
+ if (rowOffsets[pos] == -1 && missingFieldLong != NO_MISSING_FIELD_FALLBACK) { + return false; + } if (rowOffsets[pos] < 0) { return true; } @@ -101,6 +128,9 @@ public int getInt(int pos) { @Override public long getLong(int pos) { + if (rowOffsets[pos] == -1 && missingFieldLong != NO_MISSING_FIELD_FALLBACK) { + return missingFieldLong; + } return chooseArray(pos).getLong(offsetInRow(pos)); } diff --git a/paimon-common/src/main/java/org/apache/paimon/sst/BlockCache.java b/paimon-common/src/main/java/org/apache/paimon/sst/BlockCache.java index d391c8de69c1..9aae395acd60 100644 --- a/paimon-common/src/main/java/org/apache/paimon/sst/BlockCache.java +++ b/paimon-common/src/main/java/org/apache/paimon/sst/BlockCache.java @@ -20,6 +20,7 @@ import org.apache.paimon.fs.Path; import org.apache.paimon.fs.SeekableInputStream; +import org.apache.paimon.fs.VectoredReadable; import org.apache.paimon.io.cache.CacheKey; import org.apache.paimon.io.cache.CacheManager; import org.apache.paimon.io.cache.CacheManager.SegmentContainer; @@ -28,10 +29,10 @@ import java.io.Closeable; import java.io.IOException; -import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; /** Cache for block reading. 
*/ @@ -46,13 +47,19 @@ public BlockCache(Path filePath, SeekableInputStream input, CacheManager cacheMa this.filePath = filePath; this.input = input; this.cacheManager = cacheManager; - this.blocks = new HashMap<>(); + this.blocks = new ConcurrentHashMap<>(); } private byte[] readFrom(long offset, int length) throws IOException { byte[] buffer = new byte[length]; - input.seek(offset); - IOUtils.readFully(input, buffer); + if (input instanceof VectoredReadable) { + ((VectoredReadable) input).preadFully(offset, buffer, 0, length); + } else { + synchronized (input) { + input.seek(offset); + IOUtils.readFully(input, buffer); + } + } return buffer; } diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/FileBasedBloomFilter.java b/paimon-common/src/main/java/org/apache/paimon/utils/FileBasedBloomFilter.java index d6feca60ac2c..4e8bc2bbc923 100644 --- a/paimon-common/src/main/java/org/apache/paimon/utils/FileBasedBloomFilter.java +++ b/paimon-common/src/main/java/org/apache/paimon/utils/FileBasedBloomFilter.java @@ -21,6 +21,7 @@ import org.apache.paimon.annotation.VisibleForTesting; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.SeekableInputStream; +import org.apache.paimon.fs.VectoredReadable; import org.apache.paimon.io.cache.CacheCallback; import org.apache.paimon.io.cache.CacheKey; import org.apache.paimon.io.cache.CacheKey.PositionCacheKey; @@ -95,9 +96,15 @@ public boolean testHash(int hash) { private byte[] readBytes(CacheKey k) throws IOException { PositionCacheKey key = (PositionCacheKey) k; - input.seek(key.position()); byte[] bytes = new byte[key.length()]; - IOUtils.readFully(input, bytes); + if (input instanceof VectoredReadable) { + ((VectoredReadable) input).preadFully(key.position(), bytes, 0, key.length()); + } else { + synchronized (input) { + input.seek(key.position()); + IOUtils.readFully(input, bytes); + } + } return bytes; } diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/FileType.java 
b/paimon-common/src/main/java/org/apache/paimon/utils/FileType.java index df71b0248a13..b09281c685e0 100644 --- a/paimon-common/src/main/java/org/apache/paimon/utils/FileType.java +++ b/paimon-common/src/main/java/org/apache/paimon/utils/FileType.java @@ -34,7 +34,7 @@ * _SUCCESS, consumer, service files *
  • {@link #DATA}: data files and any unrecognized files (default) *
  • {@link #BUCKET_INDEX}: bucket level index files (Hash, DV) - *
  • {@link #GLOBAL_INDEX}: table level global index files (btree, bitmap, lumina, tantivy) + *
  • {@link #GLOBAL_INDEX}: table level global index files (btree, lumina, tantivy) *
  • {@link #FILE_INDEX}: data-file index files (bloom filter, bitmap, etc.) * */ diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/LazyField.java b/paimon-common/src/main/java/org/apache/paimon/utils/LazyField.java index 8e9eb9e2f11d..db776c415605 100644 --- a/paimon-common/src/main/java/org/apache/paimon/utils/LazyField.java +++ b/paimon-common/src/main/java/org/apache/paimon/utils/LazyField.java @@ -20,12 +20,11 @@ import java.util.function.Supplier; -/** A class to lazy initialized field. */ +/** A class to lazy initialized field. Thread-safe via double-checked locking. */ public class LazyField { + private volatile boolean initialized; private Supplier supplier; - - private boolean initialized; private T value; public LazyField(Supplier supplier) { @@ -33,14 +32,17 @@ public LazyField(Supplier supplier) { } public T get() { - if (!initialized) { - T t = supplier.get(); - value = t; - initialized = true; - supplier = null; // release the closure chain for GC - return t; + if (initialized) { + return value; + } + synchronized (this) { + if (!initialized) { + value = supplier.get(); + initialized = true; + supplier = null; + } + return value; } - return value; } public boolean initialized() { diff --git a/paimon-common/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory b/paimon-common/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory index 4c3fe70db9c5..1477dcba359e 100644 --- a/paimon-common/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory +++ b/paimon-common/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory @@ -13,5 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-org.apache.paimon.globalindex.bitmap.BitmapGlobalIndexerFactory -org.apache.paimon.globalindex.btree.BTreeGlobalIndexerFactory \ No newline at end of file +org.apache.paimon.globalindex.btree.BTreeGlobalIndexerFactory diff --git a/paimon-common/src/test/java/org/apache/paimon/data/BlobViewStructTest.java b/paimon-common/src/test/java/org/apache/paimon/data/BlobViewStructTest.java index 058876219fe8..0d40c1fe30df 100644 --- a/paimon-common/src/test/java/org/apache/paimon/data/BlobViewStructTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/data/BlobViewStructTest.java @@ -40,6 +40,17 @@ public void testSerializeAndDeserialize() { assertThat(deserialized.rowId()).isEqualTo(5L); } + @Test + public void testRejectUnknownDatabase() { + BlobViewStruct viewStruct = + new BlobViewStruct(Identifier.create(Identifier.UNKNOWN_DATABASE, "source"), 7, 5L); + + assertThatThrownBy(viewStruct::serialize) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Blob view upstream table identifier must include database name"); + } + @Test public void testRejectUnexpectedVersion() { BlobViewStruct viewStruct = diff --git a/paimon-common/src/test/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReaderTest.java b/paimon-common/src/test/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReaderTest.java new file mode 100644 index 000000000000..2fbdc2709c7d --- /dev/null +++ b/paimon-common/src/test/java/org/apache/paimon/fileindex/bitmap/ApplyBitmapIndexRecordReaderTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.fileindex.bitmap; + +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.fs.Path; +import org.apache.paimon.reader.FileRecordIterator; +import org.apache.paimon.reader.FileRecordReader; +import org.apache.paimon.utils.CloseableIterator; +import org.apache.paimon.utils.RoaringBitmap32; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for {@link ApplyBitmapIndexRecordReader}. 
*/ +public class ApplyBitmapIndexRecordReaderTest { + + private static final Path DUMMY_PATH = new Path("/dummy"); + + @Test + public void testLimitStopsAfterFirstBatchWithPreSelectedUnderlying() throws Exception { + // Simulates DataFileRecordReader#readBatchInternal: iterator.selection(selection) + int limit = 10; + int batchSize = 20; + int totalRows = 100; + RoaringBitmap32 limitBitmap = RoaringBitmap32.bitmapOfRange(0, totalRows).limit(limit); + BitmapIndexResult selection = new BitmapIndexResult(() -> limitBitmap); + PreSelectedCountingFileRecordReader underlying = + new PreSelectedCountingFileRecordReader(totalRows, batchSize, limitBitmap); + ApplyBitmapIndexRecordReader reader = + new ApplyBitmapIndexRecordReader(underlying, selection); + + List rows = readViaCloseableIterator(reader); + + assertThat(rows).containsExactlyElementsOf(range(limit)); + assertThat(underlying.readBatchCount()).isEqualTo(1); + } + + @Test + public void testSparseBitmapWithPreSelectedUnderlying() throws Exception { + int batchSize = 5; + int totalRows = 20; + RoaringBitmap32 sparseBitmap = RoaringBitmap32.bitmapOf(0, 2, 4, 6, 8); + BitmapIndexResult selection = new BitmapIndexResult(() -> sparseBitmap); + PreSelectedCountingFileRecordReader underlying = + new PreSelectedCountingFileRecordReader(totalRows, batchSize, sparseBitmap); + ApplyBitmapIndexRecordReader reader = + new ApplyBitmapIndexRecordReader(underlying, selection); + + List rows = readViaCloseableIterator(reader); + + assertThat(rows).containsExactly(0, 2, 4, 6, 8); + assertThat(underlying.readBatchCount()).isEqualTo(2); + } + + @Test + public void testLimitStopsAfterFirstBatch() throws Exception { + int limit = 10; + int batchSize = 20; + int totalRows = 100; + CountingFileRecordReader underlying = new CountingFileRecordReader(totalRows, batchSize); + BitmapIndexResult selection = + new BitmapIndexResult( + () -> RoaringBitmap32.bitmapOfRange(0, totalRows).limit(limit)); + ApplyBitmapIndexRecordReader reader = + new 
ApplyBitmapIndexRecordReader(underlying, selection); + + List rows = readViaCloseableIterator(reader); + + assertThat(rows).containsExactlyElementsOf(range(limit)); + assertThat(underlying.readBatchCount()).isEqualTo(1); + } + + @Test + public void testForEachRemainingStopsAfterLimit() throws Exception { + int limit = 10; + int batchSize = 20; + int totalRows = 100; + CountingFileRecordReader underlying = new CountingFileRecordReader(totalRows, batchSize); + BitmapIndexResult selection = + new BitmapIndexResult( + () -> RoaringBitmap32.bitmapOfRange(0, totalRows).limit(limit)); + ApplyBitmapIndexRecordReader reader = + new ApplyBitmapIndexRecordReader(underlying, selection); + + List rows = new ArrayList<>(); + reader.forEachRemaining(row -> rows.add(row.getInt(0))); + + assertThat(rows).containsExactlyElementsOf(range(limit)); + assertThat(underlying.readBatchCount()).isEqualTo(1); + } + + @Test + public void testSparseBitmapStillStopsAtLastPosition() throws Exception { + int batchSize = 5; + int totalRows = 20; + CountingFileRecordReader underlying = new CountingFileRecordReader(totalRows, batchSize); + BitmapIndexResult selection = + new BitmapIndexResult(() -> RoaringBitmap32.bitmapOf(0, 2, 4, 6, 8)); + ApplyBitmapIndexRecordReader reader = + new ApplyBitmapIndexRecordReader(underlying, selection); + + List rows = readViaCloseableIterator(reader); + + assertThat(rows).containsExactly(0, 2, 4, 6, 8); + assertThat(underlying.readBatchCount()).isEqualTo(2); + } + + private static List readViaCloseableIterator(ApplyBitmapIndexRecordReader reader) + throws Exception { + List rows = new ArrayList<>(); + try (CloseableIterator iterator = reader.toCloseableIterator()) { + while (iterator.hasNext()) { + rows.add(iterator.next().getInt(0)); + } + } + return rows; + } + + private static List range(int endExclusive) { + List result = new ArrayList<>(endExclusive); + for (int i = 0; i < endExclusive; i++) { + result.add(i); + } + return result; + } + + /** + * Simulates 
{@code DataFileRecordReader#readBatchInternal}, which applies {@code + * iterator.selection(selection)} before wrapping with {@link ApplyBitmapIndexRecordReader}. + */ + private static class PreSelectedCountingFileRecordReader + implements FileRecordReader { + + private final int totalRows; + private final int batchSize; + private final RoaringBitmap32 selectionBitmap; + private final AtomicInteger readBatchCount = new AtomicInteger(0); + private int nextBatchStart; + + private PreSelectedCountingFileRecordReader( + int totalRows, int batchSize, RoaringBitmap32 selectionBitmap) { + this.totalRows = totalRows; + this.batchSize = batchSize; + this.selectionBitmap = selectionBitmap; + } + + int readBatchCount() { + return readBatchCount.get(); + } + + @Override + public FileRecordIterator readBatch() { + readBatchCount.incrementAndGet(); + if (nextBatchStart >= totalRows) { + return null; + } + int batchStart = nextBatchStart; + int batchEnd = Math.min(nextBatchStart + batchSize, totalRows); + nextBatchStart = batchEnd; + return new PositionFileRecordIterator(batchStart, batchEnd).selection(selectionBitmap); + } + + @Override + public void close() {} + } + + private static class CountingFileRecordReader implements FileRecordReader { + + private final int totalRows; + private final int batchSize; + private final AtomicInteger readBatchCount = new AtomicInteger(0); + private int nextBatchStart; + + private CountingFileRecordReader(int totalRows, int batchSize) { + this.totalRows = totalRows; + this.batchSize = batchSize; + } + + int readBatchCount() { + return readBatchCount.get(); + } + + @Override + public FileRecordIterator readBatch() { + readBatchCount.incrementAndGet(); + if (nextBatchStart >= totalRows) { + return null; + } + int batchStart = nextBatchStart; + int batchEnd = Math.min(nextBatchStart + batchSize, totalRows); + nextBatchStart = batchEnd; + return new PositionFileRecordIterator(batchStart, batchEnd); + } + + @Override + public void close() {} + } + + 
private static class PositionFileRecordIterator implements FileRecordIterator { + + private final int end; + private int nextPosition; + private int returnedPosition = -1; + + private PositionFileRecordIterator(int start, int end) { + this.nextPosition = start; + this.end = end; + } + + @Override + public InternalRow next() { + if (nextPosition >= end) { + return null; + } + returnedPosition = nextPosition; + return GenericRow.of(nextPosition++); + } + + @Override + public long returnedPosition() { + return returnedPosition; + } + + @Override + public Path filePath() { + return DUMMY_PATH; + } + + @Override + public void releaseBatch() {} + } +} diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexEvaluatorTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexEvaluatorTest.java index 0eea0bb02b6e..45d542d921ac 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexEvaluatorTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexEvaluatorTest.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -63,14 +64,11 @@ private static RowType rowType() { } private static GlobalIndexResult resultOf(long... 
rowIds) { - return GlobalIndexResult.create( - () -> { - RoaringNavigableMap64 bm = new RoaringNavigableMap64(); - for (long id : rowIds) { - bm.add(id); - } - return bm; - }); + RoaringNavigableMap64 bm = new RoaringNavigableMap64(); + for (long id : rowIds) { + bm.add(id); + } + return GlobalIndexResult.create(bm); } private static GlobalIndexReader readerReturning(GlobalIndexResult result) { @@ -83,9 +81,7 @@ void testSingleFieldSequential() { GlobalIndexResult expected = resultOf(1, 2, 3); GlobalIndexEvaluator evaluator = new GlobalIndexEvaluator( - rowType, - fieldId -> Collections.singletonList(readerReturning(expected)), - null); + rowType, fieldId -> Collections.singletonList(readerReturning(expected))); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = builder.equal(0, 42); @@ -114,8 +110,7 @@ void testAndParallelMultipleFields() { rowType, fieldId -> Collections.singletonList( - readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = PredicateBuilder.and(builder.equal(0, 42), builder.equal(1, 99)); @@ -144,8 +139,7 @@ void testOrParallelMultipleFields() { rowType, fieldId -> Collections.singletonList( - readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = PredicateBuilder.or(builder.equal(0, 42), builder.equal(1, 99)); @@ -172,8 +166,7 @@ void testOrReturnsEmptyWhenChildUnsupported() { return Collections.singletonList(readerReturning(resultA)); } return Collections.emptyList(); - }, - executor); + }); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = PredicateBuilder.or(builder.equal(0, 42), builder.equal(1, 99)); @@ -201,8 +194,7 @@ void testAndWithEmptyResultShortCircuits() { rowType, fieldId -> Collections.singletonList( - 
readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = PredicateBuilder.and(builder.equal(0, 42), builder.equal(1, 99)); @@ -228,14 +220,19 @@ void testParallelUsesMultipleThreads() { Collections.singletonList( new StubGlobalIndexReader(resultOf(fieldId, fieldId + 10)) { @Override - public Optional visitEqual( - FieldRef fieldRef, Object literal) { - threadNames.put( - Thread.currentThread().getName(), true); - return super.visitEqual(fieldRef, literal); + public CompletableFuture> + visitEqual(FieldRef fieldRef, Object literal) { + return CompletableFuture.supplyAsync( + () -> { + threadNames.put( + Thread.currentThread() + .getName(), + true); + return Optional.ofNullable(result); + }, + executor); } - }), - executor); + })); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = @@ -260,8 +257,7 @@ void testNullExecutorFallsBackToSequential() { callCount.incrementAndGet(); return Collections.singletonList( readerReturning(resultOf(fieldId, fieldId + 10))); - }, - null); + }); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = PredicateBuilder.and(builder.equal(0, 1), builder.equal(1, 2)); @@ -292,8 +288,7 @@ void testNestedAndPredicateDoesNotDeadlockWithSmallPool() { rowType, fieldId -> Collections.singletonList( - readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); // and(a, b, c) builds as and(and(a, b), c) — nested binary tree PredicateBuilder builder = new PredicateBuilder(rowType); @@ -327,8 +322,7 @@ void testNestedOrPredicateDoesNotDeadlockWithSmallPool() { rowType, fieldId -> Collections.singletonList( - readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); // or(a, b, c) builds as or(or(a, b), c) — nested binary tree PredicateBuilder builder = new 
PredicateBuilder(rowType); @@ -363,8 +357,7 @@ void testMixedNestedPredicateDoesNotDeadlockWithSmallPool() { rowType, fieldId -> Collections.singletonList( - readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); // AND(OR(a, b), OR(a, c)) — mixed nesting, different compound types PredicateBuilder builder = new PredicateBuilder(rowType); @@ -402,8 +395,7 @@ void testDeepMixedNestedPredicateDoesNotDeadlockWithSmallPool() { rowType, fieldId -> Collections.singletonList( - readerReturning(fieldResults.get(fieldId))), - executor); + readerReturning(fieldResults.get(fieldId)))); // AND(OR(AND(a, b), c), OR(AND(a, c), b)) — deep mixed nesting PredicateBuilder builder = new PredicateBuilder(rowType); @@ -427,7 +419,7 @@ void testDeepMixedNestedPredicateDoesNotDeadlockWithSmallPool() { } @Test - void testSameFieldPredicatesNotAccessedConcurrently() { + void testSameFieldPredicatesAccessedConcurrently() { executor = Executors.newFixedThreadPool(4); RowType rowType = rowType(); @@ -437,143 +429,36 @@ void testSameFieldPredicatesNotAccessedConcurrently() { GlobalIndexReader concurrencyDetectingReader = new StubGlobalIndexReader(resultOf(1, 2, 3, 4, 5)) { @Override - public Optional visitEqual( + public CompletableFuture> visitEqual( FieldRef fieldRef, Object literal) { - int c = concurrency.incrementAndGet(); - maxConcurrency.updateAndGet(cur -> Math.max(cur, c)); - try { - Thread.sleep(50); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - concurrency.decrementAndGet(); - return super.visitEqual(fieldRef, literal); + return CompletableFuture.supplyAsync( + () -> { + int c = concurrency.incrementAndGet(); + maxConcurrency.updateAndGet(cur -> Math.max(cur, c)); + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + concurrency.decrementAndGet(); + return Optional.ofNullable(result); + }, + executor); } }; GlobalIndexEvaluator evaluator = new 
GlobalIndexEvaluator( - rowType, - fieldId -> Collections.singletonList(concurrencyDetectingReader), - executor); + rowType, fieldId -> Collections.singletonList(concurrencyDetectingReader)); - // AND(a=1, a=2, a=3) — all same field, must not run concurrently + // AND(a=1, a=2, a=3) — readers dispatch internally, concurrency comes from reader PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = PredicateBuilder.and(builder.equal(0, 1), builder.equal(0, 2), builder.equal(0, 3)); evaluator.evaluate(predicate); - assertThat(maxConcurrency.get()).isEqualTo(1); - evaluator.close(); - } - - @Test - void testMixedNestedSameFieldNotAccessedConcurrently() { - executor = Executors.newFixedThreadPool(4); - RowType rowType = rowType(); - - AtomicInteger concurrencyA = new AtomicInteger(0); - AtomicInteger maxConcurrencyA = new AtomicInteger(0); - - GlobalIndexReader concurrencyDetectingReaderA = - new StubGlobalIndexReader(resultOf(1, 2, 3, 4, 5)) { - @Override - public Optional visitEqual( - FieldRef fieldRef, Object literal) { - int c = concurrencyA.incrementAndGet(); - maxConcurrencyA.updateAndGet(cur -> Math.max(cur, c)); - try { - Thread.sleep(50); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - concurrencyA.decrementAndGet(); - return super.visitEqual(fieldRef, literal); - } - }; - - GlobalIndexEvaluator evaluator = - new GlobalIndexEvaluator( - rowType, - fieldId -> { - if (fieldId == 0) { - return Collections.singletonList(concurrencyDetectingReaderA); - } - return Collections.singletonList( - readerReturning(resultOf(1, 2, 3, 4, 5))); - }, - executor); - - // AND(OR(a=1, b=2), OR(a=3, c=4)) — field a appears in both OR subtrees - PredicateBuilder builder = new PredicateBuilder(rowType); - Predicate predicate = - PredicateBuilder.and( - PredicateBuilder.or(builder.equal(0, 1), builder.equal(1, 2)), - PredicateBuilder.or(builder.equal(0, 3), builder.equal(2, 4))); - - evaluator.evaluate(predicate); - - 
assertThat(maxConcurrencyA.get()).isEqualTo(1); - evaluator.close(); - } - - @Test - void testLazyResultNotMaterializedConcurrently() { - executor = Executors.newFixedThreadPool(4); - RowType rowType = rowType(); - - AtomicInteger concurrency = new AtomicInteger(0); - AtomicInteger maxConcurrency = new AtomicInteger(0); - - GlobalIndexReader lazyReader = - new StubGlobalIndexReader(null) { - @Override - public Optional visitEqual( - FieldRef fieldRef, Object literal) { - return Optional.of( - GlobalIndexResult.create( - () -> { - int c = concurrency.incrementAndGet(); - maxConcurrency.updateAndGet(cur -> Math.max(cur, c)); - try { - Thread.sleep(50); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - concurrency.decrementAndGet(); - RoaringNavigableMap64 bm = new RoaringNavigableMap64(); - bm.add(1); - bm.add(2); - bm.add(3); - return bm; - })); - } - }; - - GlobalIndexEvaluator evaluator = - new GlobalIndexEvaluator( - rowType, - fieldId -> { - if (fieldId == 0) { - return Collections.singletonList(lazyReader); - } - return Collections.singletonList( - readerReturning(resultOf(1, 2, 3, 4, 5))); - }, - executor); - - // AND(OR(a=1, b=2), OR(a=3, c=4)) — field a in both OR subtrees - // lazy results for field a must not be materialized concurrently - PredicateBuilder builder = new PredicateBuilder(rowType); - Predicate predicate = - PredicateBuilder.and( - PredicateBuilder.or(builder.equal(0, 1), builder.equal(1, 2)), - PredicateBuilder.or(builder.equal(0, 3), builder.equal(2, 4))); - - evaluator.evaluate(predicate); - - assertThat(maxConcurrency.get()).isEqualTo(1); + assertThat(maxConcurrency.get()).isGreaterThan(1); evaluator.close(); } @@ -591,8 +476,7 @@ void testMultipleReadersPerFieldCombinedWithAnd() { fieldId -> Arrays.asList( readerReturning(readerResult1), - readerReturning(readerResult2)), - executor); + readerReturning(readerResult2))); PredicateBuilder builder = new PredicateBuilder(rowType); Predicate predicate = 
builder.equal(0, 42); @@ -614,9 +498,7 @@ void testNonFieldLeafPredicateDoesNotThrow() { GlobalIndexEvaluator evaluator = new GlobalIndexEvaluator( - rowType, - fieldId -> Collections.singletonList(readerReturning(resultA)), - executor); + rowType, fieldId -> Collections.singletonList(readerReturning(resultA))); // Manually build AND(alwaysTrue, a=1) to bypass PredicateBuilder simplification PredicateBuilder builder = new PredicateBuilder(rowType); @@ -636,7 +518,7 @@ void testNonFieldLeafPredicateDoesNotThrow() { void testNullPredicate() { RowType rowType = rowType(); GlobalIndexEvaluator evaluator = - new GlobalIndexEvaluator(rowType, fieldId -> Collections.emptyList(), null); + new GlobalIndexEvaluator(rowType, fieldId -> Collections.emptyList()); Optional result = evaluator.evaluate(null); @@ -654,80 +536,92 @@ private static void assertBitmapContainsExactly( private static class StubGlobalIndexReader implements GlobalIndexReader { - private final GlobalIndexResult result; + protected final GlobalIndexResult result; StubGlobalIndexReader(GlobalIndexResult result) { this.result = result; } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return Optional.ofNullable(result); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.ofNullable(result)); } @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return Optional.ofNullable(result); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.ofNullable(result)); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitStartsWith( + FieldRef 
fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public 
CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexSerDeUtilsTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexSerDeUtilsTest.java index 2fec1e2410e7..7729620d29c5 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexSerDeUtilsTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexSerDeUtilsTest.java @@ -36,7 +36,7 @@ public class GlobalIndexSerDeUtilsTest { @Test public void testSerializeAndDeserializeGlobalIndexResult() throws IOException { RoaringNavigableMap64 bitmap = RoaringNavigableMap64.bitmapOf(1, 5, 10, 100, 1000); - GlobalIndexResult original = GlobalIndexResult.create(() -> bitmap); + GlobalIndexResult original = GlobalIndexResult.create(bitmap); byte[] serialized = serialize(original); GlobalIndexResult deserialized = deserialize(serialized); @@ -65,8 +65,7 @@ public void testSerializeAndDeserializeTopkGlobalIndexResult() throws IOExceptio scoreMap.put(10L, 0.7f); scoreMap.put(100L, 0.6f); - ScoredGlobalIndexResult original = - ScoredGlobalIndexResult.create(() -> bitmap, scoreMap::get); + ScoredGlobalIndexResult original = ScoredGlobalIndexResult.create(bitmap, scoreMap::get); byte[] serialized = serialize(original); GlobalIndexResult deserialized = deserialize(serialized); @@ -92,8 +91,7 @@ public void 
testSerializeAndDeserializeTopkWithLargeRowIds() throws IOException scoreMap.put(Integer.MAX_VALUE + 100L, 0.3f); scoreMap.put(Long.MAX_VALUE - 1, 0.1f); - ScoredGlobalIndexResult original = - ScoredGlobalIndexResult.create(() -> bitmap, scoreMap::get); + ScoredGlobalIndexResult original = ScoredGlobalIndexResult.create(bitmap, scoreMap::get); byte[] serialized = serialize(original); GlobalIndexResult deserialized = deserialize(serialized); diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/bitmapindex/BitmapGlobalIndexTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/bitmapindex/BitmapGlobalIndexTest.java deleted file mode 100644 index cd0277fcdcef..000000000000 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/bitmapindex/BitmapGlobalIndexTest.java +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.paimon.globalindex.bitmapindex; - -import org.apache.paimon.data.BinaryString; -import org.apache.paimon.fileindex.bitmap.BitmapFileIndex; -import org.apache.paimon.fs.FileIO; -import org.apache.paimon.fs.Path; -import org.apache.paimon.fs.PositionOutputStream; -import org.apache.paimon.fs.local.LocalFileIO; -import org.apache.paimon.globalindex.GlobalIndexIOMeta; -import org.apache.paimon.globalindex.GlobalIndexReader; -import org.apache.paimon.globalindex.GlobalIndexResult; -import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; -import org.apache.paimon.globalindex.bitmap.BitmapGlobalIndex; -import org.apache.paimon.globalindex.io.GlobalIndexFileReader; -import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; -import org.apache.paimon.options.Options; -import org.apache.paimon.predicate.FieldRef; -import org.apache.paimon.types.DataType; -import org.apache.paimon.types.DataTypes; -import org.apache.paimon.utils.RoaringBitmap32; -import org.apache.paimon.utils.RoaringNavigableMap64; - -import org.junit.Rule; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.junit.rules.TemporaryFolder; - -import java.io.File; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.UUID; -import java.util.function.Consumer; - -/** Tests for {@link BitmapGlobalIndex}. 
*/ -public class BitmapGlobalIndexTest { - - @TempDir private File tempDir; - - @Rule public TemporaryFolder folder = new TemporaryFolder(); - - @Test - public void testV1() throws Exception { - testIntType(BitmapFileIndex.VERSION_1); - testStringType(BitmapFileIndex.VERSION_1); - testBooleanType(BitmapFileIndex.VERSION_1); - testHighCardinality(BitmapFileIndex.VERSION_1, 1000000, 100000, null); - testStringTypeWithReusing(BitmapFileIndex.VERSION_1); - testAllNull(BitmapFileIndex.VERSION_1); - } - - @Test - public void testV2() throws Exception { - testIntType(BitmapFileIndex.VERSION_2); - testStringType(BitmapFileIndex.VERSION_2); - testBooleanType(BitmapFileIndex.VERSION_2); - testHighCardinality(BitmapFileIndex.VERSION_2, 1000000, 100000, null); - testStringTypeWithReusing(BitmapFileIndex.VERSION_2); - testAllNull(BitmapFileIndex.VERSION_2); - } - - private void testStringType(int version) throws Exception { - FieldRef fieldRef = new FieldRef(0, "", DataTypes.STRING()); - BinaryString a = BinaryString.fromString("a"); - BinaryString b = BinaryString.fromString("b"); - Object[] dataColumn = {a, null, b, null, a}; - GlobalIndexReader reader = - createTestReaderOnWriter( - version, - null, - DataTypes.STRING(), - writer -> { - for (Object o : dataColumn) { - writer.write(o); - } - }); - assert reader.visitEqual(fieldRef, a) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0, 4)); - assert reader.visitEqual(fieldRef, b) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(2)); - assert reader.visitIsNull(fieldRef) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(1, 3)); - assert reader.visitIn(fieldRef, Arrays.asList(a, b)) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0, 2, 4)); - assert reader.visitEqual(fieldRef, BinaryString.fromString("c")).get().results().isEmpty(); - } - - private void testIntType(int version) throws Exception { - FieldRef fieldRef = new FieldRef(0, "", DataTypes.INT()); - Object[] 
dataColumn = {0, 1, null}; - GlobalIndexReader reader = - createTestReaderOnWriter( - version, - null, - DataTypes.INT(), - writer -> { - for (Object o : dataColumn) { - writer.write(o); - } - }); - assert reader.visitEqual(fieldRef, 0) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0)); - assert reader.visitEqual(fieldRef, 1) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(1)); - assert reader.visitIsNull(fieldRef) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(2)); - assert reader.visitIn(fieldRef, Arrays.asList(0, 1, 2)) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0, 1)); - - assert reader.visitEqual(fieldRef, 2).get().results().isEmpty(); - } - - private void testBooleanType(int version) throws Exception { - FieldRef fieldRef = new FieldRef(0, "", DataTypes.BOOLEAN()); - Object[] dataColumn = {Boolean.TRUE, Boolean.FALSE, Boolean.TRUE, Boolean.FALSE, null}; - GlobalIndexReader reader = - createTestReaderOnWriter( - version, - null, - DataTypes.BOOLEAN(), - writer -> { - for (Object o : dataColumn) { - writer.write(o); - } - }); - assert reader.visitEqual(fieldRef, Boolean.TRUE) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0, 2)); - assert reader.visitIsNull(fieldRef) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(4)); - } - - private void testHighCardinality( - int version, int rowCount, int approxCardinality, Integer secondaryBlockSize) - throws Exception { - FieldRef fieldRef = new FieldRef(0, "", DataTypes.STRING()); - RoaringBitmap32 middleBm = new RoaringBitmap32(); - RoaringBitmap32 nullBm = new RoaringBitmap32(); - long time1 = System.currentTimeMillis(); - String prefix = "ssssssssss"; - GlobalIndexReader reader = - createTestReaderOnWriter( - version, - secondaryBlockSize, - DataTypes.STRING(), - writer -> { - for (int i = 0; i < rowCount; i++) { - - int sid = (int) (Math.random() * approxCardinality); - if (sid == approxCardinality / 2) { - 
middleBm.add(i); - } else if (Math.random() < 0.01) { - nullBm.add(i); - writer.write(null); - continue; - } - writer.write(BinaryString.fromString(prefix + sid)); - } - }); - System.out.println("write time: " + (System.currentTimeMillis() - time1)); - long time2 = System.currentTimeMillis(); - GlobalIndexResult result = - reader.visitEqual( - fieldRef, BinaryString.fromString(prefix + (approxCardinality / 2))) - .get(); - System.out.println("read time: " + (System.currentTimeMillis() - time2)); - assert result.results().equals(middleBm.toNavigable64()); - long time3 = System.currentTimeMillis(); - GlobalIndexResult resultNull = reader.visitIsNull(fieldRef).get(); - System.out.println("read null bitmap time: " + (System.currentTimeMillis() - time3)); - assert resultNull.results().equals(nullBm.toNavigable64()); - } - - private GlobalIndexReader createTestReaderOnWriter( - int writerVersion, - Integer indexBlockSize, - DataType dataType, - Consumer consumer) - throws Exception { - Options options = new Options(); - options.setInteger(BitmapFileIndex.VERSION, writerVersion); - if (indexBlockSize != null) { - options.setInteger(BitmapFileIndex.INDEX_BLOCK_SIZE, indexBlockSize); - } - BitmapFileIndex bitmapFileIndex = new BitmapFileIndex(dataType, options); - BitmapGlobalIndex bitmapGlobalIndex = new BitmapGlobalIndex(bitmapFileIndex); - final FileIO fileIO = new LocalFileIO(); - GlobalIndexFileWriter fileWriter = - new GlobalIndexFileWriter() { - @Override - public String newFileName(String prefix) { - return prefix + UUID.randomUUID(); - } - - @Override - public PositionOutputStream newOutputStream(String fileName) - throws IOException { - return fileIO.newOutputStream(new Path(tempDir.toString(), fileName), true); - } - }; - GlobalIndexSingletonWriter globalIndexWriter = bitmapGlobalIndex.createWriter(fileWriter); - consumer.accept(globalIndexWriter); - String fileName = globalIndexWriter.finish().get(0).fileName(); - Path path = new Path(tempDir.toString(), 
fileName); - long fileSize = fileIO.getFileSize(path); - - GlobalIndexFileReader fileReader = - meta -> fileIO.newInputStream(new Path(tempDir.toString(), meta.filePath())); - - GlobalIndexIOMeta globalIndexMeta = new GlobalIndexIOMeta(path, fileSize, null); - - return bitmapGlobalIndex.createReader( - fileReader, Collections.singletonList(globalIndexMeta)); - } - - private void testStringTypeWithReusing(int version) throws Exception { - FieldRef fieldRef = new FieldRef(0, "", DataTypes.STRING()); - BinaryString a = BinaryString.fromString("a"); - BinaryString b = BinaryString.fromString("b"); - BinaryString c = BinaryString.fromString("a"); - GlobalIndexReader reader = - createTestReaderOnWriter( - version, - null, - DataTypes.STRING(), - writer -> { - writer.write(a); - writer.write(null); - a.pointTo(b.getSegments(), b.getOffset(), b.getSizeInBytes()); - writer.write(null); - writer.write(a); - writer.write(null); - a.pointTo(c.getSegments(), c.getOffset(), c.getSizeInBytes()); - writer.write(null); - }); - assert reader.visitEqual(fieldRef, a) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0)); - assert reader.visitEqual(fieldRef, b) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(3)); - assert reader.visitIsNull(fieldRef) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(1, 2, 4, 5)); - assert reader.visitIn(fieldRef, Arrays.asList(a, b)) - .get() - .results() - .equals(RoaringNavigableMap64.bitmapOf(0, 3)); - assert reader.visitEqual(fieldRef, BinaryString.fromString("c")).get().results().isEmpty(); - } - - private void testAllNull(int version) throws Exception { - FieldRef fieldRef = new FieldRef(0, "", DataTypes.INT()); - Object[] dataColumn = {null, null, null}; - GlobalIndexReader reader = - createTestReaderOnWriter( - version, - null, - DataTypes.INT(), - writer -> { - for (Object o : dataColumn) { - writer.write(o); - } - }); - assert reader.visitIsNull(fieldRef) - .get() - .results() - 
.equals(RoaringNavigableMap64.bitmapOf(0, 1, 2)); - assert reader.visitIsNotNull(fieldRef).get().results().isEmpty(); - } -} diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/AbstractIndexReaderTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/AbstractIndexReaderTest.java index 23b8d77bf21e..1e81b7c74499 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/AbstractIndexReaderTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/AbstractIndexReaderTest.java @@ -167,33 +167,33 @@ public void testRangePredicate() throws Exception { Object literal = data.get(literalIdx).getKey(); // 1. test <= literal - result = reader.visitLessOrEqual(ref, literal).get(); + result = reader.visitLessOrEqual(ref, literal).join().get(); assertResult(result, filter(obj -> comparator.compare(obj, literal) <= 0)); // 2. test < literal - result = reader.visitLessThan(ref, literal).get(); + result = reader.visitLessThan(ref, literal).join().get(); assertResult(result, filter(obj -> comparator.compare(obj, literal) < 0)); // 3. test >= literal - result = reader.visitGreaterOrEqual(ref, literal).get(); + result = reader.visitGreaterOrEqual(ref, literal).join().get(); assertResult(result, filter(obj -> comparator.compare(obj, literal) >= 0)); // 4. test > literal - result = reader.visitGreaterThan(ref, literal).get(); + result = reader.visitGreaterThan(ref, literal).join().get(); assertResult(result, filter(obj -> comparator.compare(obj, literal) > 0)); // 5. test equal - result = reader.visitEqual(ref, literal).get(); + result = reader.visitEqual(ref, literal).join().get(); assertResult(result, filter(obj -> comparator.compare(obj, literal) == 0)); // 6. test not equal - result = reader.visitNotEqual(ref, literal).get(); + result = reader.visitNotEqual(ref, literal).join().get(); assertResult(result, filter(obj -> comparator.compare(obj, literal) != 0)); // 7. 
test between int betweenOffset = random.nextInt(dataNum - literalIdx); Object toLiteral = data.get(literalIdx + betweenOffset).getKey(); - result = reader.visitBetween(ref, literal, toLiteral).get(); + result = reader.visitBetween(ref, literal, toLiteral).join().get(); assertResult( result, filter( @@ -204,12 +204,12 @@ public void testRangePredicate() throws Exception { // 8. test < min Object literal7 = data.get(0).getKey(); - result = reader.visitLessThan(ref, literal7).get(); + result = reader.visitLessThan(ref, literal7).join().get(); Assertions.assertTrue(result.results().isEmpty()); // 9. test > max Object literal8 = data.get(dataNum - 1).getKey(); - result = reader.visitGreaterThan(ref, literal8).get(); + result = reader.visitGreaterThan(ref, literal8).join().get(); Assertions.assertTrue(result.results().isEmpty()); // 10. test between @@ -217,10 +217,10 @@ public void testRangePredicate() throws Exception { Integer min = (Integer) (data.get(0).getKey()); Integer max = (Integer) (data.get(dataNum - 1).getKey()); - result = reader.visitBetween(ref, min - 100, min - 1).get(); + result = reader.visitBetween(ref, min - 100, min - 1).join().get(); Assertions.assertTrue(result.results().isEmpty()); - result = reader.visitBetween(ref, max + 1, max + 100).get(); + result = reader.visitBetween(ref, max + 1, max + 100).join().get(); Assertions.assertTrue(result.results().isEmpty()); } } @@ -239,10 +239,10 @@ public void testIsNull() throws Exception { try (GlobalIndexReader reader = prepareDataAndCreateReader()) { GlobalIndexResult result; - result = reader.visitIsNull(ref).get(); + result = reader.visitIsNull(ref).join().get(); assertResult(result, filter(Objects::isNull)); - result = reader.visitIsNotNull(ref).get(); + result = reader.visitIsNotNull(ref).join().get(); assertResult(result, filter(Objects::nonNull)); } } @@ -264,11 +264,11 @@ public void testInPredicate() throws Exception { set.addAll(literals); // 1. 
test in - result = reader.visitIn(ref, literals).get(); + result = reader.visitIn(ref, literals).join().get(); assertResult(result, filter(set::contains)); // 2. test not in - result = reader.visitNotIn(ref, literals).get(); + result = reader.visitNotIn(ref, literals).join().get(); assertResult(result, filter(obj -> !set.contains(obj))); } } diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeIndexReaderTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeIndexReaderTest.java index e2c028286f72..3c5e9b07c0d6 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeIndexReaderTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeIndexReaderTest.java @@ -24,8 +24,11 @@ import org.junit.jupiter.api.extension.ExtendWith; +import java.util.Collections; import java.util.List; +import static org.apache.paimon.shade.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; + /** Test for {@link BTreeIndexReader} to read a single file. 
*/ @ExtendWith(ParameterizedTestExtension.class) public class BTreeIndexReaderTest extends AbstractIndexReaderTest { @@ -37,7 +40,7 @@ public BTreeIndexReaderTest(List args) { @Override protected GlobalIndexReader prepareDataAndCreateReader() throws Exception { GlobalIndexIOMeta written = writeData(data); - - return new BTreeIndexReader(keySerializer, fileReader, written, CACHE_MANAGER); + return globalIndexer.createReader( + fileReader, Collections.singletonList(written), newDirectExecutorService()); } } diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeThreadSafetyTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeThreadSafetyTest.java new file mode 100644 index 000000000000..af923d1f66d3 --- /dev/null +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/BTreeThreadSafetyTest.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.globalindex.btree; + +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.PositionOutputStream; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexParallelWriter; +import org.apache.paimon.globalindex.GlobalIndexReader; +import org.apache.paimon.globalindex.GlobalIndexResult; +import org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.globalindex.io.GlobalIndexFileReader; +import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.io.cache.CacheManager; +import org.apache.paimon.options.MemorySize; +import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.IntType; +import org.apache.paimon.utils.Pair; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Thread-safety tests for BTree global index readers. 
*/ +public class BTreeThreadSafetyTest { + + private static final CacheManager CACHE_MANAGER = new CacheManager(MemorySize.VALUE_8_MB); + + @TempDir java.nio.file.Path tempPath; + + private FileIO fileIO; + private GlobalIndexFileReader fileReader; + private GlobalIndexFileWriter fileWriter; + private BTreeGlobalIndexer globalIndexer; + private KeySerializer keySerializer; + private Comparator comparator; + private ExecutorService executor; + + private List> data; + private static final int DATA_NUM = 10000; + private static final int FILE_NUM = 10; + + @BeforeEach + void setUp() { + fileIO = LocalFileIO.create(); + fileWriter = + new GlobalIndexFileWriter() { + @Override + public String newFileName(String prefix) { + return "test-btree-" + UUID.randomUUID() + prefix; + } + + @Override + public PositionOutputStream newOutputStream(String fileName) + throws IOException { + return fileIO.newOutputStream( + new Path(new Path(tempPath.toUri()), fileName), true); + } + }; + fileReader = + meta -> + fileIO.newInputStream( + new Path(new Path(tempPath.toUri()), meta.filePath())); + + Options options = new Options(); + options.set(BTreeIndexOptions.BTREE_INDEX_CACHE_SIZE, MemorySize.ofMebiBytes(8)); + DataField dataField = new DataField(1, "id", new IntType()); + globalIndexer = new BTreeGlobalIndexer(dataField, options); + keySerializer = KeySerializer.create(new IntType()); + comparator = keySerializer.createComparator(); + + data = new ArrayList<>(DATA_NUM); + for (int i = 0; i < DATA_NUM; i++) { + data.add(Pair.of(i, (long) i)); + } + data.sort((p1, p2) -> comparator.compare(p1.getKey(), p2.getKey())); + } + + @AfterEach + void tearDown() { + if (executor != null) { + executor.shutdownNow(); + } + } + + @Test + void testNoDeadlockMoreFilesThanThreads() throws Exception { + executor = Executors.newFixedThreadPool(2); + List metas = writeMultipleFiles(); + + try (GlobalIndexReader reader = globalIndexer.createReader(fileReader, metas, executor)) { + FieldRef ref = new 
FieldRef(1, "id", new IntType()); + Optional result = + reader.visitIsNotNull(ref).get(10, TimeUnit.SECONDS); + assertThat(result).isPresent(); + assertThat(result.get().results().iterator().hasNext()).isTrue(); + } + } + + @Test + void testConcurrentVisitEqualOnSameReader() throws Exception { + executor = Executors.newFixedThreadPool(8); + List metas = writeMultipleFiles(); + + try (GlobalIndexReader reader = globalIndexer.createReader(fileReader, metas, executor)) { + FieldRef ref = new FieldRef(1, "id", new IntType()); + int numThreads = 32; + CountDownLatch latch = new CountDownLatch(numThreads); + ExecutorService queryPool = Executors.newFixedThreadPool(numThreads); + + List>> futures = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + int literal = (i * 313) % DATA_NUM; + futures.add( + queryPool.submit( + () -> { + latch.countDown(); + latch.await(); + Optional result = + reader.visitEqual(ref, literal) + .get(10, TimeUnit.SECONDS); + assertThat(result).isPresent(); + return toList(result.get()); + })); + } + + for (int i = 0; i < numThreads; i++) { + int literal = (i * 313) % DATA_NUM; + List expected = + data.stream() + .filter(p -> comparator.compare(p.getKey(), literal) == 0) + .map(Pair::getValue) + .collect(Collectors.toList()); + List actual = futures.get(i).get(15, TimeUnit.SECONDS); + assertThat(actual).containsExactlyInAnyOrderElementsOf(expected); + } + + queryPool.shutdown(); + queryPool.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + void testConcurrentMixedPredicates() throws Exception { + executor = Executors.newFixedThreadPool(4); + List metas = writeMultipleFiles(); + + try (GlobalIndexReader reader = globalIndexer.createReader(fileReader, metas, executor)) { + FieldRef ref = new FieldRef(1, "id", new IntType()); + int numThreads = 24; + CountDownLatch latch = new CountDownLatch(numThreads); + ExecutorService queryPool = Executors.newFixedThreadPool(numThreads); + + List>> futures = new ArrayList<>(); + for (int i = 0; 
i < numThreads; i++) { + int idx = i; + futures.add( + queryPool.submit( + () -> { + latch.countDown(); + latch.await(); + Optional result; + switch (idx % 4) { + case 0: + result = + reader.visitEqual(ref, idx * 100) + .get(10, TimeUnit.SECONDS); + break; + case 1: + result = + reader.visitLessThan(ref, idx * 100) + .get(10, TimeUnit.SECONDS); + break; + case 2: + result = + reader.visitGreaterThan(ref, idx * 100) + .get(10, TimeUnit.SECONDS); + break; + default: + result = + reader.visitBetween(ref, idx * 50, idx * 100) + .get(10, TimeUnit.SECONDS); + break; + } + assertThat(result).isPresent(); + return toList(result.get()); + })); + } + + for (int i = 0; i < numThreads; i++) { + int idx = i; + List expected; + switch (idx % 4) { + case 0: + expected = filter(obj -> comparator.compare(obj, idx * 100) == 0); + break; + case 1: + expected = filter(obj -> comparator.compare(obj, idx * 100) < 0); + break; + case 2: + expected = filter(obj -> comparator.compare(obj, idx * 100) > 0); + break; + default: + expected = + filter( + obj -> + comparator.compare(obj, idx * 50) >= 0 + && comparator.compare(obj, idx * 100) <= 0); + break; + } + List actual = futures.get(i).get(15, TimeUnit.SECONDS); + assertThat(actual).containsExactlyInAnyOrderElementsOf(expected); + } + + queryPool.shutdown(); + queryPool.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + void testLazyCreationOnlyOncePerFile() throws Exception { + AtomicInteger creationCount = new AtomicInteger(0); + + executor = Executors.newFixedThreadPool(8); + List metas = writeMultipleFiles(); + + GlobalIndexFileReader countingFileReader = + meta -> { + creationCount.incrementAndGet(); + return fileIO.newInputStream( + new Path(new Path(tempPath.toUri()), meta.filePath())); + }; + + try (GlobalIndexReader reader = + globalIndexer.createReader(countingFileReader, metas, executor)) { + FieldRef ref = new FieldRef(1, "id", new IntType()); + int numThreads = 32; + CountDownLatch latch = new 
CountDownLatch(numThreads); + ExecutorService queryPool = Executors.newFixedThreadPool(numThreads); + + List> futures = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + futures.add( + queryPool.submit( + () -> { + latch.countDown(); + try { + latch.await(); + reader.visitIsNotNull(ref).get(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + })); + } + + for (Future f : futures) { + f.get(15, TimeUnit.SECONDS); + } + + queryPool.shutdown(); + queryPool.awaitTermination(5, TimeUnit.SECONDS); + } + + // Each file should be opened at most once per reader creation + // (ConcurrentHashMap.computeIfAbsent). + // BTreeIndexReader opens the file once during construction. + assertThat(creationCount.get()).isLessThanOrEqualTo(FILE_NUM); + } + + @Test + void testSelectorPrunesCorrectly() throws Exception { + executor = Executors.newFixedThreadPool(4); + List metas = writeMultipleFiles(); + + Set openedFiles = new HashSet<>(); + GlobalIndexFileReader trackingFileReader = + meta -> { + synchronized (openedFiles) { + openedFiles.add(meta.filePath().getName()); + } + return fileIO.newInputStream( + new Path(new Path(tempPath.toUri()), meta.filePath())); + }; + + try (GlobalIndexReader reader = + globalIndexer.createReader(trackingFileReader, metas, executor)) { + FieldRef ref = new FieldRef(1, "id", new IntType()); + + // Query for value 5 — should only need to open the first file (keys 0-999) + Optional result = + reader.visitEqual(ref, 5).get(10, TimeUnit.SECONDS); + assertThat(result).isPresent(); + List hits = toList(result.get()); + assertThat(hits).contains(5L); + + // Should NOT have opened all files + assertThat(openedFiles.size()).isLessThan(FILE_NUM); + } + } + + private List writeMultipleFiles() throws IOException { + int recordPerFile = DATA_NUM / FILE_NUM; + List metas = new ArrayList<>(FILE_NUM); + int currentStart = 0; + while (currentStart < DATA_NUM) { + int nextStart = Math.min(currentStart + 
recordPerFile, DATA_NUM); + metas.add(writeData(data.subList(currentStart, nextStart))); + currentStart = nextStart; + } + return metas; + } + + private GlobalIndexIOMeta writeData(List> subData) throws IOException { + GlobalIndexParallelWriter indexWriter = globalIndexer.createWriter(fileWriter); + for (Pair pair : subData) { + indexWriter.write(pair.getKey(), pair.getValue()); + } + List results = indexWriter.finish(); + Assertions.assertEquals(1, results.size()); + ResultEntry resultEntry = results.get(0); + String fileName = resultEntry.fileName(); + return new GlobalIndexIOMeta( + new Path(new Path(tempPath.toUri()), fileName), + fileIO.getFileSize(new Path(new Path(tempPath.toUri()), fileName)), + resultEntry.meta()); + } + + private List filter(java.util.function.Predicate predicate) { + return data.stream() + .filter(pair -> predicate.test(pair.getKey())) + .map(Pair::getValue) + .collect(Collectors.toList()); + } + + private static List toList(GlobalIndexResult result) { + Iterator iter = result.results().iterator(); + List list = new ArrayList<>(); + while (iter.hasNext()) { + list.add(iter.next()); + } + return list; + } +} diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeIndexReaderTest.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeIndexReaderTest.java index fa5437211907..86e9e8227ed9 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeIndexReaderTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/btree/LazyFilteredBTreeIndexReaderTest.java @@ -18,22 +18,41 @@ package org.apache.paimon.globalindex.btree; +import org.apache.paimon.fs.Path; import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexParallelWriter; import org.apache.paimon.globalindex.GlobalIndexReader; +import org.apache.paimon.globalindex.GlobalIndexResult; +import 
org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.options.MemorySize; +import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.FieldRef; import org.apache.paimon.testutils.junit.parameterized.ParameterizedTestExtension; +import org.apache.paimon.types.DataField; import org.apache.paimon.utils.Pair; +import org.apache.paimon.utils.SemaphoredDelegatingExecutor; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; +import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Random; import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import static org.apache.paimon.shade.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static org.assertj.core.api.Assertions.assertThat; /** Test for {@link LazyFilteredBTreeReader} to read multiple files. */ @@ -47,8 +66,7 @@ public class LazyFilteredBTreeIndexReaderTest extends AbstractIndexReaderTest { @Override protected GlobalIndexReader prepareDataAndCreateReader() throws Exception { List written = writeData(); - - return globalIndexer.createReader(fileReader, written); + return globalIndexer.createReader(fileReader, written, newDirectExecutorService()); } private List writeData() throws Exception { @@ -132,4 +150,273 @@ public void testUnorderedIterator() throws Exception { .containsExactlyElementsOf(expectedMap.get(key)); } } + + /** + * Stress-test thread safety of {@link LazyFilteredBTreeReader} under concurrent access with + * aggressive cache eviction. 
Uses small block size (4KB) to create many cache blocks per file, + * tiny cache (64KB) to force frequent evictions, 20 index files sharing one CacheManager, and + * 16 concurrent threads each running 15 random queries of all types. This exercises: + * + *
      <ul> + *
    <li>ConcurrentHashMap reader cache in LazyFilteredBTreeReader + *
    <li>BlockCache.getBlock() check-then-act race under eviction pressure + *
    <li>CacheManager eviction callbacks racing with concurrent reads + *
    <li>BTreeIndexReader readLock contention across query types + *
    <li>LazyField initialization race for null bitmaps + *
    <li>SegmentContainer.accessCount non-atomic increment under contention + * </ul>
    + */ + @TestTemplate + public void testConcurrentAccess() throws Exception { + // Small block size → many blocks per file → more cache entries + // Tiny cache → aggressive eviction under concurrent load + Options stressOptions = new Options(); + stressOptions.set(BTreeIndexOptions.BTREE_INDEX_BLOCK_SIZE, MemorySize.ofKibiBytes(4)); + stressOptions.set(BTreeIndexOptions.BTREE_INDEX_CACHE_SIZE, MemorySize.ofKibiBytes(64)); + stressOptions.set(BTreeIndexOptions.BTREE_INDEX_HIGH_PRIORITY_POOL_RATIO, 0.1); + BTreeGlobalIndexer stressIndexer = + new BTreeGlobalIndexer(new DataField(1, "testField", dataType), stressOptions); + + // Inject null values at the tail to test isNull/isNotNull under concurrency + for (int i = dataNum - 1; i >= dataNum * 0.9; i--) { + data.get(i).setLeft(null); + } + + // Split into 20 files → 20 BTreeIndexReader instances sharing one CacheManager + int fileNum = 20; + List written = new ArrayList<>(fileNum); + int currentStart = 0; + int recordPerFile = dataNum / fileNum; + while (currentStart < dataNum) { + int nextStart = Math.min(currentStart + recordPerFile, dataNum); + written.add(writeDataWithIndexer(stressIndexer, data.subList(currentStart, nextStart))); + currentStart = nextStart; + } + + // Real multi-threaded executor for the reader's internal file-level parallelism + ExecutorService readerExecutor = Executors.newFixedThreadPool(8); + try (GlobalIndexReader reader = + stressIndexer.createReader(fileReader, written, readerExecutor)) { + FieldRef ref = new FieldRef(1, "testField", dataType); + + int concurrency = 16; + int queriesPerThread = 15; + CyclicBarrier barrier = new CyclicBarrier(concurrency); + ExecutorService testExecutor = Executors.newFixedThreadPool(concurrency); + List> futures = new ArrayList<>(); + + for (int t = 0; t < concurrency; t++) { + int threadId = t; + futures.add( + testExecutor.submit( + () -> { + try { + barrier.await(); + } catch (Exception e) { + throw new RuntimeException(e); + } + Random random = new 
Random(threadId); + for (int i = 0; i < queriesPerThread; i++) { + runRandomQuery(ref, random, reader); + } + })); + } + + testExecutor.shutdown(); + for (Future f : futures) { + f.get(120, TimeUnit.SECONDS); + } + } finally { + readerExecutor.shutdown(); + } + } + + private GlobalIndexIOMeta writeDataWithIndexer( + BTreeGlobalIndexer indexer, List> subData) throws IOException { + GlobalIndexParallelWriter indexWriter = indexer.createWriter(fileWriter); + for (Pair pair : subData) { + indexWriter.write(pair.getKey(), pair.getValue()); + } + List results = indexWriter.finish(); + Assertions.assertEquals(1, results.size()); + ResultEntry resultEntry = results.get(0); + String fileName = resultEntry.fileName(); + return new GlobalIndexIOMeta( + new Path(new Path(tempPath.toUri()), fileName), + fileIO.getFileSize(new Path(new Path(tempPath.toUri()), fileName)), + resultEntry.meta()); + } + + /** + * Regression test for deadlock when using {@link SemaphoredDelegatingExecutor}. Before the fix, + * {@code LazyFilteredBTreeReader.visitParallel} submitted tasks to the executor (acquiring + * permits), and the visitor inside {@code BTreeIndexReader.supplyResult} submitted nested tasks + * to the same executor (needing more permits), causing deadlock when permits were exhausted. + */ + @TestTemplate + public void testNoDeadlockWithSemaphoredExecutor() throws Exception { + List written = writeData(); + + // Permit count (2) is smaller than file count (10), which would have triggered + // the deadlock before the fix: Level 1 tasks hold permits while Level 2 tasks + // wait for permits from the same pool. 
+ ExecutorService baseExecutor = Executors.newFixedThreadPool(4); + ExecutorService semaphoredExecutor = + new SemaphoredDelegatingExecutor(baseExecutor, 2, false); + + try (GlobalIndexReader reader = + globalIndexer.createReader(fileReader, written, semaphoredExecutor)) { + FieldRef ref = new FieldRef(1, "testField", dataType); + + Random random = new Random(42); + for (int i = 0; i < 10; i++) { + int idx = random.nextInt(dataNum); + Object literal = data.get(idx).getKey(); + + GlobalIndexResult result = reader.visitEqual(ref, literal).join().get(); + assertResult(result, filter(obj -> comparator.compare(obj, literal) == 0)); + + result = reader.visitLessThan(ref, literal).join().get(); + assertResult(result, filter(obj -> comparator.compare(obj, literal) < 0)); + + List inLiterals = new ArrayList<>(); + for (int j = 0; j < 5; j++) { + inLiterals.add(data.get(random.nextInt(dataNum)).getKey()); + } + TreeSet inSet = new TreeSet<>(comparator); + inSet.addAll(inLiterals); + result = reader.visitIn(ref, inLiterals).join().get(); + assertResult(result, filter(inSet::contains)); + } + } finally { + baseExecutor.shutdown(); + } + } + + @SuppressWarnings("OptionalGetWithoutIsPresent") + private void runRandomQuery(FieldRef ref, Random random, GlobalIndexReader reader) { + int queryType = random.nextInt(11); + int idx = random.nextInt(dataNum); + Object literal = data.get(idx).getKey(); + GlobalIndexResult result; + + switch (queryType) { + case 0: // equal + if (literal == null) { + return; + } + result = reader.visitEqual(ref, literal).join().get(); + assertResult( + result, + filter(obj -> obj != null && comparator.compare(obj, literal) == 0)); + break; + case 1: // lessThan + if (literal == null) { + return; + } + result = reader.visitLessThan(ref, literal).join().get(); + assertResult( + result, filter(obj -> obj != null && comparator.compare(obj, literal) < 0)); + break; + case 2: // greaterOrEqual + if (literal == null) { + return; + } + result = 
reader.visitGreaterOrEqual(ref, literal).join().get(); + assertResult( + result, + filter(obj -> obj != null && comparator.compare(obj, literal) >= 0)); + break; + case 3: // lessOrEqual + if (literal == null) { + return; + } + result = reader.visitLessOrEqual(ref, literal).join().get(); + assertResult( + result, + filter(obj -> obj != null && comparator.compare(obj, literal) <= 0)); + break; + case 4: // greaterThan + if (literal == null) { + return; + } + result = reader.visitGreaterThan(ref, literal).join().get(); + assertResult( + result, filter(obj -> obj != null && comparator.compare(obj, literal) > 0)); + break; + case 5: // notEqual + if (literal == null) { + return; + } + result = reader.visitNotEqual(ref, literal).join().get(); + assertResult( + result, + filter(obj -> obj != null && comparator.compare(obj, literal) != 0)); + break; + case 6: // between + if (literal == null) { + return; + } + int toIdx = Math.min(idx + random.nextInt(200) + 1, dataNum - 1); + Object toLiteral = data.get(toIdx).getKey(); + if (toLiteral == null) { + return; + } + result = reader.visitBetween(ref, literal, toLiteral).join().get(); + assertResult( + result, + filter( + obj -> + obj != null + && comparator.compare(obj, toLiteral) <= 0 + && comparator.compare(obj, literal) >= 0)); + break; + case 7: // in — many sub-queries hitting cache concurrently + { + List inLiterals = new ArrayList<>(); + for (int j = 0; j < 20; j++) { + Object l = data.get(random.nextInt(dataNum)).getKey(); + if (l != null) { + inLiterals.add(l); + } + } + if (inLiterals.isEmpty()) { + return; + } + TreeSet inSet = new TreeSet<>(comparator); + inSet.addAll(inLiterals); + result = reader.visitIn(ref, inLiterals).join().get(); + assertResult(result, filter(obj -> obj != null && inSet.contains(obj))); + break; + } + case 8: // notIn + { + List notInLiterals = new ArrayList<>(); + for (int j = 0; j < 20; j++) { + Object l = data.get(random.nextInt(dataNum)).getKey(); + if (l != null) { + 
notInLiterals.add(l); + } + } + if (notInLiterals.isEmpty()) { + return; + } + TreeSet notInSet = new TreeSet<>(comparator); + notInSet.addAll(notInLiterals); + result = reader.visitNotIn(ref, notInLiterals).join().get(); + assertResult(result, filter(obj -> obj != null && !notInSet.contains(obj))); + break; + } + case 9: // isNull — tests LazyField concurrent init + result = reader.visitIsNull(ref).join().get(); + assertResult(result, filter(Objects::isNull)); + break; + case 10: // isNotNull + result = reader.visitIsNotNull(ref).join().get(); + assertResult(result, filter(Objects::nonNull)); + break; + default: + break; + } + } } diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexReader.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexReader.java index affc79e916dc..ec5c73cd7cce 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexReader.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexReader.java @@ -39,6 +39,7 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; +import java.util.concurrent.CompletableFuture; /** * Test full-text index reader that performs brute-force text matching. 
Loads all documents into @@ -62,7 +63,8 @@ public TestFullTextGlobalIndexReader( } @Override - public Optional visitFullTextSearch(FullTextSearch fullTextSearch) { + public CompletableFuture> visitFullTextSearch( + FullTextSearch fullTextSearch) { try { ensureLoaded(); } catch (IOException e) { @@ -73,7 +75,7 @@ public Optional visitFullTextSearch(FullTextSearch full int limit = fullTextSearch.limit(); int effectiveK = Math.min(limit, count); if (effectiveK <= 0) { - return Optional.empty(); + return CompletableFuture.completedFuture(Optional.empty()); } String[] queryTerms = queryText.toLowerCase(Locale.ROOT).split("\\s+"); @@ -96,7 +98,7 @@ public Optional visitFullTextSearch(FullTextSearch full } if (topK.isEmpty()) { - return Optional.empty(); + return CompletableFuture.completedFuture(Optional.empty()); } RoaringNavigableMap64 resultBitmap = new RoaringNavigableMap64(); @@ -106,7 +108,8 @@ public Optional visitFullTextSearch(FullTextSearch full scoreMap.put(row.rowId, row.score); } - return Optional.of(ScoredGlobalIndexResult.create(() -> resultBitmap, scoreMap::get)); + return CompletableFuture.completedFuture( + Optional.of(ScoredGlobalIndexResult.create(resultBitmap, scoreMap::get))); } private static float computeScore(String document, String[] queryTerms) { @@ -172,73 +175,85 @@ public void close() throws IOException { // =================== unsupported predicate operations ===================== @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> 
visitStartsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + 
public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } /** A row ID paired with its similarity score, used in the top-k min-heap. */ diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexer.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexer.java index d9d3d14ff702..a7fdc0d49da9 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexer.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/testfulltext/TestFullTextGlobalIndexer.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.util.List; +import java.util.concurrent.ExecutorService; import static org.apache.paimon.utils.Preconditions.checkArgument; @@ -55,7 +56,9 @@ public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) throws I @Override public GlobalIndexReader createReader( - GlobalIndexFileReader fileReader, List files) throws IOException { + GlobalIndexFileReader fileReader, + List files, + ExecutorService executor) { checkArgument(files.size() == 1, "Expected exactly one index file per shard"); return 
new TestFullTextGlobalIndexReader(fileReader, files.get(0)); } diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexReader.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexReader.java index dc7fb4b6a6c7..da7f533f8ce9 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexReader.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexReader.java @@ -37,6 +37,7 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; +import java.util.concurrent.CompletableFuture; /** * Test vector index reader that performs brute-force linear scan for similarity search. Loads all @@ -68,7 +69,8 @@ public TestVectorGlobalIndexReader( } @Override - public Optional visitVectorSearch(VectorSearch vectorSearch) { + public CompletableFuture> visitVectorSearch( + VectorSearch vectorSearch) { try { ensureLoaded(); } catch (IOException e) { @@ -86,7 +88,7 @@ public Optional visitVectorSearch(VectorSearch vectorSe int limit = vectorSearch.limit(); int effectiveK = Math.min(limit, count); if (effectiveK <= 0) { - return Optional.empty(); + return CompletableFuture.completedFuture(Optional.empty()); } RoaringNavigableMap64 includeRowIds = vectorSearch.includeRowIds(); @@ -115,7 +117,8 @@ public Optional visitVectorSearch(VectorSearch vectorSe scoreMap.put(row.rowId, row.score); } - return Optional.of(ScoredGlobalIndexResult.create(() -> resultBitmap, scoreMap::get)); + return CompletableFuture.completedFuture( + Optional.of(ScoredGlobalIndexResult.create(resultBitmap, scoreMap::get))); } private float computeScore(float[] query, float[] stored) { @@ -213,73 +216,85 @@ public void close() throws IOException { // =================== unsupported predicate operations ===================== @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return Optional.empty(); + 
public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitStartsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> 
visitNotEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } /** A row ID paired with its similarity score, used in the top-k min-heap. 
*/ diff --git a/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexer.java b/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexer.java index d6ee516f1225..d12da8ffcab2 100644 --- a/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexer.java +++ b/paimon-common/src/test/java/org/apache/paimon/globalindex/testvector/TestVectorGlobalIndexer.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.List; +import java.util.concurrent.ExecutorService; import static org.apache.paimon.utils.Preconditions.checkArgument; @@ -75,7 +76,9 @@ public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) throws I @Override public GlobalIndexReader createReader( - GlobalIndexFileReader fileReader, List files) throws IOException { + GlobalIndexFileReader fileReader, + List files, + ExecutorService executor) { checkArgument(files.size() == 1, "Expected exactly one index file per shard"); return new TestVectorGlobalIndexReader(fileReader, files.get(0), metric); } diff --git a/paimon-core/src/main/java/org/apache/paimon/AbstractFileStore.java b/paimon-core/src/main/java/org/apache/paimon/AbstractFileStore.java index 9aad0259cdf8..3aa399186e80 100644 --- a/paimon-core/src/main/java/org/apache/paimon/AbstractFileStore.java +++ b/paimon-core/src/main/java/org/apache/paimon/AbstractFileStore.java @@ -19,6 +19,7 @@ package org.apache.paimon; import org.apache.paimon.CoreOptions.ExternalPathStrategy; +import org.apache.paimon.catalog.Identifier; import org.apache.paimon.catalog.RenamingSnapshotCommit; import org.apache.paimon.catalog.SnapshotCommit; import org.apache.paimon.catalog.TableRollback; @@ -38,16 +39,19 @@ import org.apache.paimon.metastore.ChainTableOverwriteCommitCallback; import org.apache.paimon.metastore.TagPreviewCommitCallback; import org.apache.paimon.metastore.VisibilityWaitCallback; +import 
org.apache.paimon.operation.ChainTablePartitionExpire; import org.apache.paimon.operation.ChangelogDeletion; import org.apache.paimon.operation.FileStoreCommitImpl; import org.apache.paimon.operation.Lock; import org.apache.paimon.operation.ManifestsReader; +import org.apache.paimon.operation.NormalPartitionExpire; import org.apache.paimon.operation.PartitionExpire; import org.apache.paimon.operation.SnapshotDeletion; import org.apache.paimon.operation.TagDeletion; import org.apache.paimon.operation.commit.CommitRollback; import org.apache.paimon.operation.commit.ConflictDetection; import org.apache.paimon.partition.PartitionExpireStrategy; +import org.apache.paimon.partition.PartitionValuesTimeExpireStrategy; import org.apache.paimon.schema.SchemaManager; import org.apache.paimon.schema.TableSchema; import org.apache.paimon.service.ServiceManager; @@ -65,6 +69,7 @@ import org.apache.paimon.tag.TagAutoManager; import org.apache.paimon.tag.TagPreview; import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.ChainTableUtils; import org.apache.paimon.utils.ChangelogManager; import org.apache.paimon.utils.FileStorePathFactory; import org.apache.paimon.utils.IndexFilePathFactories; @@ -266,9 +271,17 @@ public CoreOptions options() { } @Override - public boolean mergeSchema(RowType rowType, boolean allowExplicitCast) { + public boolean mergeSchema( + RowType rowType, + boolean typeWidening, + boolean allowExplicitCast, + boolean caseSensitive) { return schemaManager.mergeSchema( - rowType, allowExplicitCast, catalogEnvironment.schemaModification()); + rowType, + typeWidening, + allowExplicitCast, + caseSensitive, + catalogEnvironment.schemaModification()); } @Override @@ -440,6 +453,10 @@ public PartitionExpire newPartitionExpire(String commitUser, FileStoreTable tabl return null; } + if (options.isChainTable()) { + return newChainTablePartitionExpire(table); + } + return newPartitionExpire( commitUser, table, @@ -459,12 +476,19 @@ public PartitionExpire 
newPartitionExpire( Duration expirationTime, Duration checkInterval, PartitionExpireStrategy expireStrategy) { + if (options.isChainTable()) { + checkArgument( + expireStrategy instanceof PartitionValuesTimeExpireStrategy, + "Chain table only supports 'values-time' partition expiration strategy."); + return newChainTablePartitionExpire(table, expirationTime, checkInterval); + } + PartitionModification partitionModification = null; if (options.partitionedTableInMetastore()) { partitionModification = catalogEnvironment.partitionModification(); } - return new PartitionExpire( + return new NormalPartitionExpire( expirationTime, checkInterval, expireStrategy, @@ -476,6 +500,60 @@ public PartitionExpire newPartitionExpire( options.partitionExpireBatchSize()); } + @Nullable + private ChainTablePartitionExpire newChainTablePartitionExpire(FileStoreTable table) { + Duration partitionExpireTime = options.partitionExpireTime(); + if (partitionExpireTime == null) { + return null; + } + return newChainTablePartitionExpire( + table, partitionExpireTime, options.partitionExpireCheckInterval()); + } + + @Nullable + private ChainTablePartitionExpire newChainTablePartitionExpire( + FileStoreTable table, Duration expirationTime, Duration checkInterval) { + if (partitionType().getFieldCount() == 0) { + return null; + } + FileStoreTable primaryTable = ChainTableUtils.resolveChainPrimaryTable(table); + FileStoreTable snapshotTable = + primaryTable.switchToBranch(options.scanFallbackSnapshotBranch()); + FileStoreTable deltaTable = primaryTable.switchToBranch(options.scanFallbackDeltaBranch()); + return new ChainTablePartitionExpire( + expirationTime, + checkInterval, + snapshotTable, + deltaTable, + options, + partitionType(), + options.endInputCheckPartitionExpire(), + options.partitionExpireMaxNum(), + options.partitionExpireBatchSize(), + newPartitionModificationForBranch(options.scanFallbackSnapshotBranch()), + newPartitionModificationForBranch(options.scanFallbackDeltaBranch())); + 
} + + @Nullable + private PartitionModification newPartitionModificationForBranch(String branchName) { + if (!options.partitionedTableInMetastore()) { + return null; + } + + Identifier identifier = catalogEnvironment.identifier(); + if (identifier == null) { + return catalogEnvironment.partitionModification(); + } + + Identifier branchIdentifier = + new Identifier( + identifier.getDatabaseName(), + identifier.getTableName(), + branchName, + identifier.getSystemTableName()); + return catalogEnvironment.copy(branchIdentifier).partitionModification(); + } + @Override public TagAutoManager newTagAutoManager(FileStoreTable table) { return TagAutoManager.create( diff --git a/paimon-core/src/main/java/org/apache/paimon/FileStore.java b/paimon-core/src/main/java/org/apache/paimon/FileStore.java index 98905b47e56c..1714eec2a983 100644 --- a/paimon-core/src/main/java/org/apache/paimon/FileStore.java +++ b/paimon-core/src/main/java/org/apache/paimon/FileStore.java @@ -117,7 +117,11 @@ PartitionExpire newPartitionExpire( ServiceManager newServiceManager(); - boolean mergeSchema(RowType rowType, boolean allowExplicitCast); + boolean mergeSchema( + RowType rowType, + boolean typeWidening, + boolean allowExplicitCast, + boolean caseSensitive); List createTagCallbacks(FileStoreTable table); diff --git a/paimon-core/src/main/java/org/apache/paimon/KeyValue.java b/paimon-core/src/main/java/org/apache/paimon/KeyValue.java index c314f21107a7..4c4ff61c24b2 100644 --- a/paimon-core/src/main/java/org/apache/paimon/KeyValue.java +++ b/paimon-core/src/main/java/org/apache/paimon/KeyValue.java @@ -89,6 +89,11 @@ public long sequenceNumber() { return sequenceNumber; } + public KeyValue setSequenceNumber(long sequenceNumber) { + this.sequenceNumber = sequenceNumber; + return this; + } + public RowKind valueKind() { return valueKind; } diff --git a/paimon-core/src/main/java/org/apache/paimon/append/DedicatedFormatRollingFileWriter.java 
b/paimon-core/src/main/java/org/apache/paimon/append/DedicatedFormatRollingFileWriter.java index ed89ad8cc414..3be374ff9dca 100644 --- a/paimon-core/src/main/java/org/apache/paimon/append/DedicatedFormatRollingFileWriter.java +++ b/paimon-core/src/main/java/org/apache/paimon/append/DedicatedFormatRollingFileWriter.java @@ -103,7 +103,7 @@ public class DedicatedFormatRollingFileWriter private static final long CHECK_ROLLING_RECORD_CNT = 1000L; // Core components - private final Supplier< + private final @Nullable Supplier< ProjectedFileWriter, DataFileMeta>> writerFactory; private final @Nullable Supplier blobWriterFactory; @@ -168,21 +168,25 @@ public DedicatedFormatRollingFileWriter( } } - this.writerFactory = - createNormalWriterFactory( - fileIO, - schemaId, - fileFormat, - fieldsInNormalFile, - writeSchema, - pathFactory, - seqNumCounterSupplier, - fileCompression, - statsCollectorFactories, - fileIndexOptions, - fileSource, - asyncFileWrite, - statsDenseStore); + if (fieldsInNormalFile.isEmpty()) { + this.writerFactory = null; + } else { + this.writerFactory = + createNormalWriterFactory( + fileIO, + schemaId, + fileFormat, + fieldsInNormalFile, + writeSchema, + pathFactory, + seqNumCounterSupplier, + fileCompression, + statsCollectorFactories, + fileIndexOptions, + fileSource, + asyncFileWrite, + statsDenseStore); + } if (context != null) { this.blobWriterFactory = @@ -353,7 +357,7 @@ public void write(InternalRow row) throws IOException { ? 
externalStorageBlobWriter.transformRow(row) : row; - if (currentWriter == null) { + if (writerFactory != null && currentWriter == null) { currentWriter = writerFactory.get(); } if ((blobWriter == null) && (blobWriterFactory != null)) { @@ -368,10 +372,12 @@ public void write(InternalRow row) throws IOException { if (vectorStoreWriter != null) { vectorStoreWriter.write(transformedRow); } - currentWriter.write(transformedRow); + if (currentWriter != null) { + currentWriter.write(transformedRow); + } recordCount++; - if (rollingFile()) { + if (currentWriter != null && rollingFile()) { closeCurrentWriter(); } } catch (Throwable e) { @@ -382,7 +388,7 @@ public void write(InternalRow row) throws IOException { /** Handles write exceptions by logging and cleaning up resources. */ private void handleWriteException(Throwable e) { - String filePath = (currentWriter == null) ? null : currentWriter.writer().path().toString(); + String filePath = currentWriter == null ? null : currentWriter.writer().path().toString(); LOG.warn("Exception occurs when writing file {}. Cleaning up.", filePath, e); abort(); } @@ -451,12 +457,12 @@ private boolean rollingFile() throws IOException { * @throws IOException if closing fails */ private void closeCurrentWriter() throws IOException { - if (currentWriter == null) { + if (currentWriter == null && blobWriter == null && vectorStoreWriter == null) { return; } // Close main writer and get metadata - DataFileMeta mainDataFileMeta = closeMainWriter(); + DataFileMeta mainDataFileMeta = currentWriter == null ? 
null : closeMainWriter(); // Close blob writer and process blob metadata List blobMetas = closeBlobWriter(); @@ -464,11 +470,13 @@ private void closeCurrentWriter() throws IOException { // Close vector-store writer and process vector-store metadata List vectorStoreMetas = closeVectorStoreWriter(); - // Validate consistency between main and blob files - validateFileConsistency(mainDataFileMeta, blobMetas, vectorStoreMetas); + if (mainDataFileMeta != null) { + // Validate consistency between main and blob files + validateFileConsistency(mainDataFileMeta, blobMetas, vectorStoreMetas); + results.add(mainDataFileMeta); + } // Add results to the results list - results.add(mainDataFileMeta); results.addAll(blobMetas); results.addAll(vectorStoreMetas); diff --git a/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinator.java b/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinator.java index fc65fe95b653..c41493564c5c 100644 --- a/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinator.java +++ b/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinator.java @@ -434,11 +434,36 @@ private List> blobFileGroupsToCompact(List blob private List> fileGroupsToCompact(List files) { List> result = new ArrayList<>(); List sortedFiles = new ArrayList<>(files); - sortedFiles.sort(comparingLong(DataFileMeta::nonNullFirstRowId)); + sortedFiles.sort( + comparingLong(DataFileMeta::nonNullFirstRowId) + .thenComparingLong(DataFileMeta::maxSequenceNumber)); + + RangeHelper rangeHelper = + new RangeHelper<>(DataFileMeta::nonNullRowIdRange); + List smallFileCandidates = new ArrayList<>(); + for (List rowRangeGroup : + rangeHelper.mergeOverlappingRanges(sortedFiles)) { + if (rowRangeGroup.size() >= BLOB_COMPACT_MIN_FILE_NUM) { + rowRangeGroup.sort( + comparingLong(DataFileMeta::nonNullFirstRowId) + 
.thenComparingLong(DataFileMeta::maxSequenceNumber)); + result.add(rowRangeGroup); + } else { + smallFileCandidates.add(rowRangeGroup.get(0)); + } + } + + result.addAll(smallFileGroupsToCompact(smallFileCandidates)); + result.sort(comparingLong(group -> group.get(0).nonNullFirstRowId())); + return result; + } + + private List> smallFileGroupsToCompact(List files) { + List> result = new ArrayList<>(); List continuousFiles = new ArrayList<>(); long expectedFirstRowId = -1; - for (DataFileMeta file : sortedFiles) { + for (DataFileMeta file : files) { if (file.fileSize() >= blobTargetFileSize) { addFileGroupsToCompact(result, continuousFiles); continuousFiles.clear(); diff --git a/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactTask.java b/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactTask.java index e56fe32f82a7..f8bdc959c3ce 100644 --- a/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactTask.java +++ b/paimon-core/src/main/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactTask.java @@ -45,6 +45,7 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.utils.FileStorePathFactory; import org.apache.paimon.utils.LongCounter; +import org.apache.paimon.utils.Range; import org.apache.paimon.utils.RecordWriter; import org.apache.paimon.utils.SetUtils; @@ -191,7 +192,7 @@ private CommitMessage doCompactBlobFiles(FileStoreTable table, String commitUser CoreOptions options = table.coreOptions(); List sortedCompactBefore = sortedByFirstRowId(compactBefore); DataField blobField = blobField(table, options, sortedCompactBefore); - checkRowIdsContinuous(sortedCompactBefore); + Range compactBeforeRange = checkContiguousRowRange(sortedCompactBefore); checkArgument( sortedCompactBefore.size() > 1, "Blob compaction task %s should contain at least two files to compact.", @@ -232,12 +233,11 @@ private CommitMessage doCompactBlobFiles(FileStoreTable 
table, String commitUser throw e; } - long firstRowId = sortedCompactBefore.get(0).nonNullFirstRowId(); long minSequenceId = minSequenceId(sortedCompactBefore); long maxSequenceId = maxSequenceId(sortedCompactBefore); DataFileMeta compactedFile = writer.result() - .assignFirstRowId(firstRowId) + .assignFirstRowId(compactBeforeRange.from) .assignSequenceNumber(minSequenceId, maxSequenceId); compactAfter.add(compactedFile); checkArgument(compactAfter.size() == 1, "Blob file compaction should produce one file."); @@ -326,46 +326,30 @@ private DataField blobField( return field; } - private void checkRowIdsContinuous(List files) { - checkArgument(!files.isEmpty(), "%s should not be empty.", "Blob compact before files"); - long expectedFirstRowId = files.get(0).nonNullFirstRowId(); - for (DataFileMeta file : files) { - long firstRowId = file.nonNullFirstRowId(); - checkArgument( - firstRowId == expectedFirstRowId, - "%s should be continuous and sorted by row id, expected %s but got %s in file %s.", - "Blob compact before files", - expectedFirstRowId, - firstRowId, - file); - expectedFirstRowId += file.rowCount(); - } + private Range checkContiguousRowRange(List files) { + checkArgument(!files.isEmpty(), "%s should not be empty.", "Blob compact files"); + List ranges = + files.stream().map(DataFileMeta::nonNullRowIdRange).collect(Collectors.toList()); + List merged = Range.sortAndMergeOverlap(ranges, true); + checkArgument( + merged.size() == 1, + "%s should have a contiguous row range, but got %s.", + "Blob compact files", + merged); + return merged.get(0); } private void checkSameRowRange( List compactBefore, List compactAfter) { + Range beforeRange = checkContiguousRowRange(compactBefore); + Range afterRange = checkContiguousRowRange(compactAfter); checkArgument( - !compactBefore.isEmpty(), - "%s compact before files should not be empty.", - "Blob compact files"); - checkArgument( - !compactAfter.isEmpty(), - "%s compact after files should not be empty.", - "Blob 
compact files"); - long beforeFirstRowId = compactBefore.get(0).nonNullFirstRowId(); - long afterFirstRowId = compactAfter.get(0).nonNullFirstRowId(); - long beforeRowCount = compactBefore.stream().mapToLong(DataFileMeta::rowCount).sum(); - long afterRowCount = compactAfter.stream().mapToLong(DataFileMeta::rowCount).sum(); - checkArgument( - beforeFirstRowId == afterFirstRowId && beforeRowCount == afterRowCount, + beforeRange.equals(afterRange), "%s compact after files should have the same row range as compact before files, " - + "before first row id is %s with row count %s, " - + "but after first row id is %s with row count %s.", + + "before range is %s, but after range is %s.", "Blob compact files", - beforeFirstRowId, - beforeRowCount, - afterFirstRowId, - afterRowCount); + beforeRange, + afterRange); } private long minSequenceId(List files) { diff --git a/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexReadThreadPool.java b/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexReadThreadPool.java new file mode 100644 index 000000000000..d202003e6678 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexReadThreadPool.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.globalindex; + +import org.apache.paimon.utils.SemaphoredDelegatingExecutor; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadPoolExecutor; + +import static org.apache.paimon.utils.ThreadPoolUtils.createCachedThreadPool; + +/** Shared thread pool for global index read operations. */ +public class GlobalIndexReadThreadPool { + + private static final String THREAD_NAME = "GLOBAL-INDEX-READ-POOL"; + + private static ThreadPoolExecutor executorService = + createCachedThreadPool(Runtime.getRuntime().availableProcessors(), THREAD_NAME); + + public static synchronized ExecutorService getExecutorService(int threadNum) { + if (threadNum == executorService.getMaximumPoolSize()) { + return executorService; + } + if (threadNum < executorService.getMaximumPoolSize()) { + return new SemaphoredDelegatingExecutor(executorService, threadNum, false); + } else { + executorService = createCachedThreadPool(threadNum, THREAD_NAME); + return executorService; + } + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexScanner.java b/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexScanner.java index 3e591380191a..975b28183331 100644 --- a/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexScanner.java +++ b/paimon-core/src/main/java/org/apache/paimon/globalindex/GlobalIndexScanner.java @@ -45,6 +45,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.function.IntFunction; import java.util.stream.Collectors; @@ -53,7 +54,6 @@ import static org.apache.paimon.predicate.PredicateVisitor.collectFieldNames; import static org.apache.paimon.table.source.snapshot.TimeTravelUtil.tryTravelOrLatest; import static 
org.apache.paimon.utils.Preconditions.checkNotNull; -import static org.apache.paimon.utils.ThreadPoolUtils.createCachedThreadPool; /** Scanner for shard-based global indexes. */ public class GlobalIndexScanner implements Closeable { @@ -70,10 +70,8 @@ public GlobalIndexScanner( IndexPathFactory indexPathFactory, Collection indexFiles) { this.options = options; - Integer threadNum = options.get(GLOBAL_INDEX_THREAD_NUM); - int parallelism = - threadNum != null ? threadNum : Runtime.getRuntime().availableProcessors(); - this.executor = createCachedThreadPool(parallelism, "GLOBAL-INDEX-POOL"); + this.executor = + GlobalIndexReadThreadPool.getExecutorService(options.get(GLOBAL_INDEX_THREAD_NUM)); this.indexPathFactory = indexPathFactory; GlobalIndexFileReader indexFileReader = meta -> fileIO.newInputStream(meta.filePath()); Map>>> indexMetas = new HashMap<>(); @@ -96,7 +94,7 @@ public GlobalIndexScanner( indexFileReader, indexMetas.get(fieldId), rowType.getField(fieldId)); - this.globalIndexEvaluator = new GlobalIndexEvaluator(rowType, readersFunction, executor); + this.globalIndexEvaluator = new GlobalIndexEvaluator(rowType, readersFunction); } public static Optional create( @@ -153,35 +151,37 @@ private Collection createReaders( } Set readers = new HashSet<>(); - try { - for (Map.Entry>> entry : indexMetas.entrySet()) { - String indexType = entry.getKey(); - Map> metas = entry.getValue(); - GlobalIndexerFactory globalIndexerFactory = - GlobalIndexerFactoryUtils.load(indexType); - GlobalIndexer globalIndexer = globalIndexerFactory.create(dataField, options); - - List unionReader = new ArrayList<>(); - for (Map.Entry> rangeMetas : metas.entrySet()) { - Range range = rangeMetas.getKey(); - List indexFileMetas = rangeMetas.getValue(); - - List globalMetas = - indexFileMetas.stream() - .map(this::toGlobalMeta) - .collect(Collectors.toList()); - GlobalIndexReader innerReader = - new OffsetGlobalIndexReader( - globalIndexer.createReader(indexFileReadWrite, globalMetas), - 
range.from, - range.to); - unionReader.add(innerReader); - } - - readers.add(new UnionGlobalIndexReader(unionReader)); + for (Map.Entry>> entry : indexMetas.entrySet()) { + String indexType = entry.getKey(); + Map> metas = entry.getValue(); + GlobalIndexerFactory globalIndexerFactory = GlobalIndexerFactoryUtils.load(indexType); + GlobalIndexer globalIndexer = globalIndexerFactory.create(dataField, options); + + List> futures = new ArrayList<>(metas.size()); + for (Map.Entry> rangeMetas : metas.entrySet()) { + Range range = rangeMetas.getKey(); + List indexFileMetas = rangeMetas.getValue(); + List globalMetas = + indexFileMetas.stream() + .map(this::toGlobalMeta) + .collect(Collectors.toList()); + futures.add( + CompletableFuture.supplyAsync( + () -> + new OffsetGlobalIndexReader( + globalIndexer.createReader( + indexFileReadWrite, globalMetas, executor), + range.from, + range.to), + executor)); } - } catch (IOException e) { - throw new RuntimeException("Failed to create global index reader", e); + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + List unionReader = new ArrayList<>(futures.size()); + for (CompletableFuture future : futures) { + unionReader.add(future.join()); + } + readers.add(new UnionGlobalIndexReader(unionReader)); } return readers; @@ -197,6 +197,5 @@ private GlobalIndexIOMeta toGlobalMeta(IndexFileMeta meta) { @Override public void close() throws IOException { globalIndexEvaluator.close(); - executor.shutdownNow(); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilder.java b/paimon-core/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilder.java index 9f143d6cf666..ad68d83eb340 100644 --- a/paimon-core/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilder.java +++ b/paimon-core/src/main/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilder.java @@ -74,8 +74,10 @@ import java.util.stream.IntStream; import static 
java.util.Collections.singletonList; +import static org.apache.paimon.format.blob.BlobFileFormat.isBlobFile; import static org.apache.paimon.globalindex.GlobalIndexBuilderUtils.createIndexWriter; import static org.apache.paimon.globalindex.GlobalIndexBuilderUtils.toIndexFileMetas; +import static org.apache.paimon.types.VectorType.isVectorStoreFile; import static org.apache.paimon.utils.Preconditions.checkArgument; /** Builder to build btree global index. */ @@ -143,7 +145,7 @@ public Optional>> scan() { if (snapshot == null) { return Optional.empty(); } - snapshotReader = snapshotReader.withSnapshot(snapshot); + snapshotReader = withManifestEntryFilter(snapshotReader.withSnapshot(snapshot)); Range dataRange = new Range(0, snapshot.nextRowId() - 1); return Optional.of( @@ -164,7 +166,7 @@ public Optional>> incrementalScan() { if (snapshot == null) { return Optional.empty(); } - snapshotReader = snapshotReader.withSnapshot(snapshot); + snapshotReader = withManifestEntryFilter(snapshotReader.withSnapshot(snapshot)); Preconditions.checkArgument(indexField != null, "indexField must be set before scan."); Range dataRange = new Range(0, snapshot.nextRowId() - 1); @@ -180,6 +182,13 @@ public Optional>> incrementalScan() { snapshotReader.read().dataSplits())); } + private SnapshotReader withManifestEntryFilter(SnapshotReader snapshotReader) { + return snapshotReader.withManifestEntryFilter( + entry -> + !isBlobFile(entry.file().fileName()) + && !isVectorStoreFile(entry.file().fileName())); + } + private List indexedRowRanges(Snapshot snapshot) { List ranges = new ArrayList<>(); for (IndexManifestEntry entry : diff --git a/paimon-core/src/main/java/org/apache/paimon/index/IndexFileHandler.java b/paimon-core/src/main/java/org/apache/paimon/index/IndexFileHandler.java index b211dd593c0d..cb9525cc5c87 100644 --- a/paimon-core/src/main/java/org/apache/paimon/index/IndexFileHandler.java +++ b/paimon-core/src/main/java/org/apache/paimon/index/IndexFileHandler.java @@ -20,11 +20,13 
@@ import org.apache.paimon.Snapshot; import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.InternalRow; import org.apache.paimon.deletionvectors.DeletionVector; import org.apache.paimon.deletionvectors.DeletionVectorsIndexFile; import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; import org.apache.paimon.manifest.IndexManifestEntry; +import org.apache.paimon.manifest.IndexManifestEntrySerializer; import org.apache.paimon.manifest.IndexManifestFile; import org.apache.paimon.options.MemorySize; import org.apache.paimon.utils.Filter; @@ -37,10 +39,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import static org.apache.paimon.index.HashIndexFile.HASH_INDEX; @@ -149,6 +153,20 @@ public Map, List> scan( return result; } + public Map, List> scanBuckets( + Snapshot snapshot, String indexType, Set> buckets) { + if (buckets.isEmpty()) { + return Collections.emptyMap(); + } + + Map, List> result = new HashMap<>(); + for (IndexManifestEntry file : scanBucketEntries(snapshot, indexType, buckets)) { + result.computeIfAbsent(Pair.of(file.partition(), file.bucket()), k -> new ArrayList<>()) + .add(file.indexFile()); + } + return result; + } + public List scanEntries() { Snapshot snapshot = snapshotManager.latestSnapshot(); if (snapshot == null || snapshot.indexManifest() == null) { @@ -184,6 +202,37 @@ public List scanEntries( return result; } + public List scanBucketEntries( + Snapshot snapshot, String indexType, Set> buckets) { + if (snapshot == null || buckets.isEmpty()) { + return Collections.emptyList(); + } + String indexManifest = snapshot.indexManifest(); + if (indexManifest == null) { + return Collections.emptyList(); + } + + Function partitionGetter = + IndexManifestEntrySerializer.partitionGetter(); + Function bucketGetter = 
IndexManifestEntrySerializer.bucketGetter(); + Function indexTypeGetter = + IndexManifestEntrySerializer.indexTypeGetter(); + Map> bucketsByPartition = new HashMap<>(); + for (Pair bucket : buckets) { + bucketsByPartition + .computeIfAbsent(bucket.getLeft(), k -> new HashSet<>()) + .add(bucket.getRight()); + } + Filter rowFilter = + row -> + indexType.equals(indexTypeGetter.apply(row)) + && bucketsByPartition + .getOrDefault( + partitionGetter.apply(row), Collections.emptySet()) + .contains(bucketGetter.apply(row)); + return indexManifestFile.read(indexManifest, null, rowFilter, Filter.alwaysTrue()); + } + public Path indexManifestFilePath(String indexManifest) { return indexManifestFile.indexManifestFilePath(indexManifest); } diff --git a/paimon-core/src/main/java/org/apache/paimon/io/DataFileMeta.java b/paimon-core/src/main/java/org/apache/paimon/io/DataFileMeta.java index 23b34947bcc9..a8cdc031e134 100644 --- a/paimon-core/src/main/java/org/apache/paimon/io/DataFileMeta.java +++ b/paimon-core/src/main/java/org/apache/paimon/io/DataFileMeta.java @@ -277,8 +277,15 @@ static DataFileMeta create( SimpleStats valueStats(); + /** + * Minimum sequence number of records in this file. When {@code sequence.snapshot-ordering} is + * enabled for a primary-key table, this field is repurposed to carry the commit snapshot id + * instead of the per-record sequence number range (the snapshot id is stamped into it at commit + * time by {@code FileStoreCommitImpl}). 
+ */ long minSequenceNumber(); + /** @see #minSequenceNumber() */ long maxSequenceNumber(); long schemaId(); diff --git a/paimon-core/src/main/java/org/apache/paimon/io/KeyValueDataFileRecordReader.java b/paimon-core/src/main/java/org/apache/paimon/io/KeyValueDataFileRecordReader.java index 6cf08769703f..2538c08a2127 100644 --- a/paimon-core/src/main/java/org/apache/paimon/io/KeyValueDataFileRecordReader.java +++ b/paimon-core/src/main/java/org/apache/paimon/io/KeyValueDataFileRecordReader.java @@ -36,12 +36,26 @@ public class KeyValueDataFileRecordReader implements FileRecordReader private final FileRecordReader reader; private final KeyValueSerializer serializer; private final int level; + private final boolean overrideSequenceWithSnapshotId; + private final long snapshotId; public KeyValueDataFileRecordReader( FileRecordReader reader, RowType keyType, RowType valueType, int level) { + this(reader, keyType, valueType, level, false, KeyValue.UNKNOWN_SEQUENCE); + } + + public KeyValueDataFileRecordReader( + FileRecordReader reader, + RowType keyType, + RowType valueType, + int level, + boolean overrideSequenceWithSnapshotId, + long snapshotId) { this.reader = reader; this.serializer = new KeyValueSerializer(keyType, valueType); this.level = level; + this.overrideSequenceWithSnapshotId = overrideSequenceWithSnapshotId; + this.snapshotId = snapshotId; } @Nullable @@ -53,10 +67,20 @@ public FileRecordIterator readBatch() throws IOException { } return iterator.transform( - internalRow -> - internalRow == null - ? null - : serializer.fromRow(internalRow).setLevel(level)); + internalRow -> { + if (internalRow == null) { + return null; + } + KeyValue kv = serializer.fromRow(internalRow).setLevel(level); + // In snapshot-ordering mode, an APPEND file's on-disk per-record sequence + // numbers are stale; we override them with the commit snapshot id so later + // snapshots win during merge. Any read path bypassing this override would + // order APPEND records incorrectly. 
+ if (overrideSequenceWithSnapshotId) { + kv.setSequenceNumber(snapshotId); + } + return kv; + }); } @Override diff --git a/paimon-core/src/main/java/org/apache/paimon/io/KeyValueFileReaderFactory.java b/paimon-core/src/main/java/org/apache/paimon/io/KeyValueFileReaderFactory.java index fc505e19c270..5f7e3741927e 100644 --- a/paimon-core/src/main/java/org/apache/paimon/io/KeyValueFileReaderFactory.java +++ b/paimon-core/src/main/java/org/apache/paimon/io/KeyValueFileReaderFactory.java @@ -30,6 +30,7 @@ import org.apache.paimon.format.OrcFormatReaderContext; import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; +import org.apache.paimon.manifest.FileSource; import org.apache.paimon.partition.PartitionUtils; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.reader.FileRecordReader; @@ -42,6 +43,7 @@ import org.apache.paimon.utils.AsyncRecordReader; import org.apache.paimon.utils.FileStorePathFactory; import org.apache.paimon.utils.FormatReaderMapping; +import org.apache.paimon.utils.Preconditions; import javax.annotation.Nullable; @@ -68,6 +70,7 @@ public class KeyValueFileReaderFactory implements FileReaderFactory { private final long asyncThreshold; private final boolean ignoreCorruptFiles; private final boolean ignoreLostFiles; + private final boolean snapshotSequenceOrdering; private final Map formatReaderMappings; private final BinaryRow partition; private final DeletionVector.Factory dvFactory; @@ -93,6 +96,7 @@ protected KeyValueFileReaderFactory( this.asyncThreshold = coreOptions.fileReaderAsyncThreshold().getBytes(); this.ignoreCorruptFiles = coreOptions.scanIgnoreCorruptFile(); this.ignoreLostFiles = coreOptions.scanIgnoreLostFile(); + this.snapshotSequenceOrdering = coreOptions.snapshotSequenceOrdering(); this.partition = partition; this.formatReaderMappings = new HashMap<>(); this.dvFactory = dvFactory; @@ -168,7 +172,26 @@ private FileRecordReader createRecordReader( new ApplyDeletionVectorReader(fileRecordReader, 
deletionVector.get()); } - return new KeyValueDataFileRecordReader(fileRecordReader, keyType, valueType, file.level()); + // In snapshot-ordering mode, APPEND files carry the commit snapshot id in + // minSequenceNumber (stamped at commit time); override per-record sequence with it so + // later snapshots win during merge. COMPACT files already carry the snapshot id in their + // per-record _SEQUENCE_NUMBER and are left untouched. + boolean overrideSequenceWithSnapshotId = false; + if (snapshotSequenceOrdering) { + Preconditions.checkState( + file.fileSource().isPresent(), + "sequence.snapshot-ordering requires data files with fileSource metadata. " + + "This option is only safe for newly-created tables or empty tables. " + + "Legacy files without fileSource cannot be ordered by commit snapshot id."); + overrideSequenceWithSnapshotId = file.fileSource().get() == FileSource.APPEND; + } + return new KeyValueDataFileRecordReader( + fileRecordReader, + keyType, + valueType, + file.level(), + overrideSequenceWithSnapshotId, + file.minSequenceNumber()); } public static Builder builder( diff --git a/paimon-core/src/main/java/org/apache/paimon/manifest/IndexManifestEntrySerializer.java b/paimon-core/src/main/java/org/apache/paimon/manifest/IndexManifestEntrySerializer.java index 98ab4df6f136..60113adff1c9 100644 --- a/paimon-core/src/main/java/org/apache/paimon/manifest/IndexManifestEntrySerializer.java +++ b/paimon-core/src/main/java/org/apache/paimon/manifest/IndexManifestEntrySerializer.java @@ -18,6 +18,7 @@ package org.apache.paimon.manifest; +import org.apache.paimon.data.BinaryRow; import org.apache.paimon.data.GenericArray; import org.apache.paimon.data.GenericRow; import org.apache.paimon.data.InternalRow; @@ -25,6 +26,8 @@ import org.apache.paimon.index.IndexFileMeta; import org.apache.paimon.utils.VersionedObjectSerializer; +import java.util.function.Function; + import static org.apache.paimon.data.BinaryString.fromString; import static 
org.apache.paimon.index.IndexFileMetaSerializer.dvMetasToRowArrayData; import static org.apache.paimon.index.IndexFileMetaSerializer.rowArrayDataToDvMetas; @@ -104,4 +107,16 @@ public IndexManifestEntry convertFrom(int version, InternalRow row) { row.isNullAt(8) ? null : row.getString(8).toString(), globalIndexMeta)); } + + public static Function partitionGetter() { + return row -> deserializeBinaryRow(row.getBinary(2)); + } + + public static Function bucketGetter() { + return row -> row.getInt(3); + } + + public static Function indexTypeGetter() { + return row -> row.getString(4).toString(); + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java index 5d04440c9950..6cc551baec76 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java @@ -65,6 +65,9 @@ public class PartialUpdateMergeFunction implements MergeFunction { public static final String SEQUENCE_GROUP = "sequence-group"; + private static final String SEQUENCE_GROUP_PK_ERROR = + "The sequence-group '%s' contains primary key field '%s', " + + "which is not allowed. 
Primary key columns cannot be put in sequence-group."; private final InternalRow.FieldGetter[] getters; private final boolean ignoreDelete; @@ -428,6 +431,14 @@ private Factory(Options options, RowType rowType, List primaryKeys) { .map(fieldName -> requireField(fieldName, fieldNames)) .forEach( field -> { + String protectedFieldName = fieldNames.get(field); + if (primaryKeys.contains(protectedFieldName)) { + throw new IllegalArgumentException( + String.format( + SEQUENCE_GROUP_PK_ERROR, + k, + protectedFieldName)); + } if (fieldSeqComparators.containsKey(field)) { throw new IllegalArgumentException( String.format( @@ -435,13 +446,17 @@ private Factory(Options options, RowType rowType, List primaryKeys) { fieldNames.get(field), k)); } fieldSeqComparators.put(field, userDefinedSeqComparator); - fieldsProtectedBySequenceGroup.add(fieldNames.get(field)); + fieldsProtectedBySequenceGroup.add(protectedFieldName); }); // add self for (int index : sequenceFields) { - allSequenceFields.add(fieldNames.get(index)); String fieldName = fieldNames.get(index); + if (primaryKeys.contains(fieldName)) { + throw new IllegalArgumentException( + String.format(SEQUENCE_GROUP_PK_ERROR, k, fieldName)); + } + allSequenceFields.add(fieldName); fieldSeqComparators.put(index, userDefinedSeqComparator); sequenceGroupMap.put(fieldName, index); } diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldNestedUpdateAgg.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldNestedUpdateAgg.java index b2848b35b410..13b19e6e9ba1 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldNestedUpdateAgg.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldNestedUpdateAgg.java @@ -19,24 +19,29 @@ package org.apache.paimon.mergetree.compact.aggregate; import org.apache.paimon.codegen.Projection; +import org.apache.paimon.codegen.RecordComparator; import 
org.apache.paimon.codegen.RecordEqualiser; import org.apache.paimon.data.BinaryRow; import org.apache.paimon.data.GenericArray; import org.apache.paimon.data.InternalArray; import org.apache.paimon.data.InternalRow; import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; import org.apache.paimon.types.RowType; import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.apache.paimon.codegen.CodeGenUtils.newProjection; +import static org.apache.paimon.codegen.CodeGenUtils.newRecordComparator; import static org.apache.paimon.codegen.CodeGenUtils.newRecordEqualiser; import static org.apache.paimon.options.ConfigOptions.key; +import static org.apache.paimon.utils.Preconditions.checkArgument; import static org.apache.paimon.utils.Preconditions.checkNotNull; /** @@ -52,10 +57,23 @@ public class FieldNestedUpdateAgg extends FieldAggregator { @Nullable private final Projection keyProjection; @Nullable private final RecordEqualiser elementEqualiser; + @Nullable private final Projection sequenceProjection; + @Nullable private final RecordComparator sequenceComparator; + private final boolean hasSequenceField; + private final int countLimit; public FieldNestedUpdateAgg( String name, ArrayType dataType, List nestedKey, int countLimit) { + this(name, dataType, nestedKey, Collections.emptyList(), countLimit); + } + + public FieldNestedUpdateAgg( + String name, + ArrayType dataType, + List nestedKey, + List nestedSequenceField, + int countLimit) { super(name, dataType); RowType nestedType = (RowType) dataType.getElementType(); this.nestedFields = nestedType.getFieldCount(); @@ -67,6 +85,31 @@ public FieldNestedUpdateAgg( this.elementEqualiser = null; } + // If nestedSequenceField is set, we need to compare sequence fields to determine + // whether to update. Only update when the new sequence is greater than the old one. 
+ if (!nestedSequenceField.isEmpty()) { + checkArgument( + this.keyProjection != null, + "nested-sequence-field requires nested-key to be set."); + this.sequenceProjection = newProjection(nestedType, nestedSequenceField); + this.hasSequenceField = true; + + // Extract the data types of the sequence fields to generate a native record comparator + int sequenceFields = nestedSequenceField.size(); + List seqTypes = new ArrayList<>(sequenceFields); + int[] sortFields = new int[sequenceFields]; + for (int i = 0; i < sequenceFields; i++) { + String fieldName = nestedSequenceField.get(i); + seqTypes.add(nestedType.getTypeAt(nestedType.getFieldIndex(fieldName))); + sortFields[i] = i; + } + this.sequenceComparator = newRecordComparator(seqTypes, sortFields); + } else { + this.sequenceProjection = null; + this.sequenceComparator = null; + this.hasSequenceField = false; + } + // If deduplicate key is set, we don't guarantee that the result is exactly right this.countLimit = countLimit; } @@ -94,7 +137,15 @@ public Object agg(Object accumulator, Object inputField) { Map map = new HashMap<>(); for (InternalRow row : rows) { BinaryRow key = keyProjection.apply(row).copy(); - map.put(key, row); + if (hasSequenceField) { + // When a sequence field is configured, update unless the new sequence is smaller + InternalRow existing = map.get(key); + if (existing == null || compareSequence(row, existing) >= 0) { + map.put(key, row); + } + } else { + map.put(key, row); + } } rows = new ArrayList<>(map.values()); @@ -146,6 +197,22 @@ public Object retract(Object accumulator, Object retractField) { } } + private int compareSequence(InternalRow newRow, InternalRow oldRow) { + checkNotNull( + sequenceComparator, + "sequenceComparator should not be null when hasSequenceField is true."); + checkNotNull( + sequenceProjection, + "sequenceProjection should not be null when hasSequenceField is true."); + + // Project the rows into sub-rows containing only sequence fields + BinaryRow newSeqRow = 
sequenceProjection.apply(newRow).copy(); + BinaryRow oldSeqRow = sequenceProjection.apply(oldRow).copy(); + + // Triggers native CodeGen comparison (Nulls First by default) + return sequenceComparator.compare(newSeqRow, oldSeqRow); + } + private void addNonNullRows(InternalArray array, List rows) { for (int i = 0; i < array.size(); i++) { if (array.isNullAt(i)) { diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/factory/FieldNestedUpdateAggFactory.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/factory/FieldNestedUpdateAggFactory.java index 070931e01135..ca0d9614986a 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/factory/FieldNestedUpdateAggFactory.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/factory/FieldNestedUpdateAggFactory.java @@ -39,6 +39,7 @@ public FieldNestedUpdateAgg create(DataType fieldType, CoreOptions options, Stri return createFieldNestedUpdateAgg( fieldType, options.fieldNestedUpdateAggNestedKey(field), + options.fieldNestedUpdateAggNestedSequenceField(field), options.fieldNestedUpdateAggCountLimit(field)); } @@ -48,17 +49,25 @@ public String identifier() { } private FieldNestedUpdateAgg createFieldNestedUpdateAgg( - DataType fieldType, List nestedKey, int countLimit) { + DataType fieldType, + List nestedKey, + List nestedSequenceField, + int countLimit) { if (nestedKey == null) { nestedKey = Collections.emptyList(); } + if (nestedSequenceField == null) { + nestedSequenceField = Collections.emptyList(); + } + String typeErrorMsg = "Data type for nested table column must be 'Array' but was '%s'."; checkArgument(fieldType instanceof ArrayType, typeErrorMsg, fieldType); ArrayType arrayType = (ArrayType) fieldType; checkArgument(arrayType.getElementType() instanceof RowType, typeErrorMsg, fieldType); - return new FieldNestedUpdateAgg(identifier(), arrayType, nestedKey, countLimit); + return new 
FieldNestedUpdateAgg( + identifier(), arrayType, nestedKey, nestedSequenceField, countLimit); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/metastore/ChainTableCommitPreCallback.java b/paimon-core/src/main/java/org/apache/paimon/metastore/ChainTableCommitPreCallback.java index 0b26cb637540..03399a4fde77 100644 --- a/paimon-core/src/main/java/org/apache/paimon/metastore/ChainTableCommitPreCallback.java +++ b/paimon-core/src/main/java/org/apache/paimon/metastore/ChainTableCommitPreCallback.java @@ -36,6 +36,7 @@ import org.apache.paimon.table.sink.CommitPreCallback; import org.apache.paimon.table.source.snapshot.SnapshotReader; import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.ChainPartitionProjector; import org.apache.paimon.utils.ChainTableUtils; import org.apache.paimon.utils.InternalRowPartitionComputer; import org.apache.paimon.utils.RowDataToObjectArrayConverter; @@ -43,6 +44,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -111,33 +113,59 @@ public void call( partitionType, table.schema().partitionKeys().toArray(new String[0]), coreOptions.legacyPartitionName()); - RecordComparator partitionComparator = - CodeGenUtils.newRecordComparator(partitionType.getFieldTypes()); + + List chainKeys = + ChainTableUtils.chainPartitionKeys(coreOptions, table.schema().partitionKeys()); + int chainFieldCount = chainKeys.size(); + ChainPartitionProjector projector = + new ChainPartitionProjector(partitionType, chainFieldCount); + int groupFieldCount = projector.groupFieldCount(); + RecordComparator chainComparator = + CodeGenUtils.newRecordComparator(projector.chainPartitionType().getFieldTypes()); + List snapshotPartitions = table.newSnapshotReader().partitionEntries().stream() .map(PartitionEntry::partition) - .sorted(partitionComparator) .collect(Collectors.toList()); SnapshotReader deltaSnapshotReader = 
deltaTable.newSnapshotReader(); PredicateBuilder builder = new PredicateBuilder(partitionType); for (BinaryRow partition : changedPartitions) { + BinaryRow partitionGroup = projector.extractGroupPartition(partition); + BinaryRow partitionChain = projector.extractChainPartition(partition); + + List sameGroupSnapshots = + filterSameGroup(snapshotPartitions, partitionGroup, projector); + sameGroupSnapshots.sort( + (a, b) -> + chainComparator.compare( + projector.extractChainPartition(a), + projector.extractChainPartition(b))); + Optional preSnapshotPartition = - findPreSnapshotPartition(snapshotPartitions, partition, partitionComparator); + findPreSnapshotInGroup( + sameGroupSnapshots, partitionChain, chainComparator, projector); Optional nextSnapshotPartition = - findNextSnapshotPartition(snapshotPartitions, partition, partitionComparator); + findNextSnapshotInGroup( + sameGroupSnapshots, partitionChain, chainComparator, projector); + Predicate deltaFollowingPredicate = - ChainTableUtils.createTriangularPredicate( - partition, partitionConverter, builder::equal, builder::greaterThan); + ChainTableUtils.createGroupChainPredicate( + partition, + partitionConverter, + groupFieldCount, + builder::equal, + builder::greaterThan); List deltaFollowingPartitions = deltaSnapshotReader.withPartitionFilter(deltaFollowingPredicate) .partitionEntries().stream() .map(PartitionEntry::partition) .filter( deltaPartition -> - isBeforeNextSnapshotPartition( + isBeforeNextSnapshotInGroup( deltaPartition, nextSnapshotPartition, - partitionComparator)) + chainComparator, + projector)) .collect(Collectors.toList()); boolean canDrop = deltaFollowingPartitions.isEmpty() || preSnapshotPartition.isPresent(); @@ -159,13 +187,26 @@ private boolean isPureDeleteCommit( && indexFiles.stream().allMatch(f -> f.kind() == FileKind.DELETE); } - private Optional findPreSnapshotPartition( - List snapshotPartitions, - BinaryRow partition, - RecordComparator partitionComparator) { + private List 
filterSameGroup( + List partitions, BinaryRow groupKey, ChainPartitionProjector projector) { + List result = new ArrayList<>(); + for (BinaryRow partition : partitions) { + if (projector.extractGroupPartition(partition).equals(groupKey)) { + result.add(partition); + } + } + return result; + } + + private Optional findPreSnapshotInGroup( + List sortedSameGroupPartitions, + BinaryRow targetChain, + RecordComparator chainComparator, + ChainPartitionProjector projector) { BinaryRow pre = null; - for (BinaryRow snapshotPartition : snapshotPartitions) { - if (partitionComparator.compare(snapshotPartition, partition) < 0) { + for (BinaryRow snapshotPartition : sortedSameGroupPartitions) { + BinaryRow chain = projector.extractChainPartition(snapshotPartition); + if (chainComparator.compare(chain, targetChain) < 0) { pre = snapshotPartition; } else { break; @@ -174,24 +215,31 @@ private Optional findPreSnapshotPartition( return Optional.ofNullable(pre); } - private Optional findNextSnapshotPartition( - List snapshotPartitions, - BinaryRow partition, - RecordComparator partitionComparator) { - for (BinaryRow snapshotPartition : snapshotPartitions) { - if (partitionComparator.compare(snapshotPartition, partition) > 0) { + private Optional findNextSnapshotInGroup( + List sortedSameGroupPartitions, + BinaryRow targetChain, + RecordComparator chainComparator, + ChainPartitionProjector projector) { + for (BinaryRow snapshotPartition : sortedSameGroupPartitions) { + BinaryRow chain = projector.extractChainPartition(snapshotPartition); + if (chainComparator.compare(chain, targetChain) > 0) { return Optional.of(snapshotPartition); } } return Optional.empty(); } - private boolean isBeforeNextSnapshotPartition( + private boolean isBeforeNextSnapshotInGroup( BinaryRow partition, Optional nextSnapshotPartition, - RecordComparator partitionComparator) { - return !nextSnapshotPartition.isPresent() - || partitionComparator.compare(partition, nextSnapshotPartition.get()) < 0; + 
RecordComparator chainComparator, + ChainPartitionProjector projector) { + if (!nextSnapshotPartition.isPresent()) { + return true; + } + BinaryRow partitionChain = projector.extractChainPartition(partition); + BinaryRow nextChain = projector.extractChainPartition(nextSnapshotPartition.get()); + return chainComparator.compare(partitionChain, nextChain) < 0; } private String generatePartitionValues( diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/ChainTablePartitionExpire.java b/paimon-core/src/main/java/org/apache/paimon/operation/ChainTablePartitionExpire.java new file mode 100644 index 000000000000..0e70854363ad --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/operation/ChainTablePartitionExpire.java @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.operation; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.annotation.VisibleForTesting; +import org.apache.paimon.catalog.Catalog; +import org.apache.paimon.codegen.CodeGenUtils; +import org.apache.paimon.codegen.RecordComparator; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.manifest.PartitionEntry; +import org.apache.paimon.partition.PartitionTimeExtractor; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.PartitionModification; +import org.apache.paimon.table.sink.BatchTableCommit; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.ChainPartitionProjector; +import org.apache.paimon.utils.ChainTableUtils; +import org.apache.paimon.utils.InternalRowPartitionComputer; + +import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.Duration; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +/** + * Partition expiration for chain tables. + * + *

    <p>Chain tables store data across snapshot and delta branches. A delta partition depends on its + * nearest earlier snapshot partition as an anchor for merge-on-read. This class expires partitions + * in "segments" defined by consecutive snapshot partitions to maintain chain integrity. + * + *

    <p>A segment consists of one snapshot partition and all delta partitions whose time falls between + * that snapshot and the next snapshot in sorted order. The segment is the atomic unit of + * expiration: either the entire segment (snapshot + deltas) is expired, or nothing in it is. + * + *

    <p>Algorithm per group: + * + *

      + *
    1. List all snapshot branch partitions sorted by chain partition time. + *
    2. Filter to those before the cutoff ({@code now - expirationTime}). + *
    3. If fewer than 2 snapshots are before the cutoff, nothing can be expired (the last one must + * be kept as anchor). + *
    4. The most recent snapshot before the cutoff is the anchor (kept). All earlier snapshots form + * expirable segments together with their associated delta partitions. + *
    5. The number of segments expired is limited by {@code maxExpireNum}. + *
    6. Delta partitions are dropped first, then snapshot partitions, so that {@code + * ChainTableCommitPreCallback} validation passes. + *
    + */ +public class ChainTablePartitionExpire implements PartitionExpire { + + private static final Logger LOG = LoggerFactory.getLogger(ChainTablePartitionExpire.class); + + private final Duration expirationTime; + private final Duration checkInterval; + private final FileStoreTable snapshotTable; + private final FileStoreTable deltaTable; + private final PartitionTimeExtractor timeExtractor; + private final ChainPartitionProjector projector; + private final RecordComparator chainPartitionComparator; + private final InternalRowPartitionComputer partitionComputer; + private final List partitionKeys; + private final List chainPartitionKeys; + private final boolean endInputCheckPartitionExpire; + private final int maxExpireNum; + private final int expireBatchSize; + @Nullable private final PartitionModification snapshotPartitionModification; + @Nullable private final PartitionModification deltaPartitionModification; + private LocalDateTime lastCheck; + + public ChainTablePartitionExpire( + Duration expirationTime, + Duration checkInterval, + FileStoreTable snapshotTable, + FileStoreTable deltaTable, + CoreOptions options, + RowType partitionType, + boolean endInputCheckPartitionExpire, + int maxExpireNum, + int expireBatchSize, + @Nullable PartitionModification snapshotPartitionModification, + @Nullable PartitionModification deltaPartitionModification) { + this.expirationTime = expirationTime; + this.checkInterval = checkInterval; + this.snapshotTable = snapshotTable; + this.deltaTable = deltaTable; + this.partitionKeys = partitionType.getFieldNames(); + this.maxExpireNum = maxExpireNum; + this.expireBatchSize = expireBatchSize; + this.snapshotPartitionModification = snapshotPartitionModification; + this.deltaPartitionModification = deltaPartitionModification; + + List allPartitionKeys = partitionType.getFieldNames(); + this.chainPartitionKeys = ChainTableUtils.chainPartitionKeys(options, allPartitionKeys); + int chainFieldCount = chainPartitionKeys.size(); + 
this.projector = new ChainPartitionProjector(partitionType, chainFieldCount); + this.chainPartitionComparator = + CodeGenUtils.newRecordComparator(projector.chainPartitionType().getFieldTypes()); + this.timeExtractor = + new PartitionTimeExtractor( + options.partitionTimestampPattern(), options.partitionTimestampFormatter()); + this.partitionComputer = + new InternalRowPartitionComputer( + options.partitionDefaultName(), + partitionType, + allPartitionKeys.toArray(new String[0]), + options.legacyPartitionName()); + this.endInputCheckPartitionExpire = endInputCheckPartitionExpire; + + long rndSeconds = 0; + long checkIntervalSeconds = checkInterval.toMillis() / 1000; + if (checkIntervalSeconds > 0) { + rndSeconds = ThreadLocalRandom.current().nextLong(checkIntervalSeconds); + } + this.lastCheck = LocalDateTime.now().minusSeconds(rndSeconds); + } + + @Override + public List> expire(long commitIdentifier) { + return expire(LocalDateTime.now(), commitIdentifier); + } + + @Override + public boolean isValueExpiration() { + return true; + } + + @Override + public boolean isValueAllExpired(Collection partitions) { + return isValueAllExpired(partitions, LocalDateTime.now()); + } + + @VisibleForTesting + boolean isValueAllExpired(Collection partitions, LocalDateTime now) { + LocalDateTime expireDateTime = now.minus(expirationTime); + for (BinaryRow partition : partitions) { + LocalDateTime partTime = extractPartitionTime(partition); + if (partTime == null || !expireDateTime.isAfter(partTime)) { + return false; + } + } + + // All partitions are time-wise before cutoff, but chain table retains anchors + // (the most recent snapshot before cutoff per group) and their segment's deltas. + // Compute per-group retain boundary: partitions at or after the boundary are retained. 
+ Map retainBoundary = computeGroupRetainBoundary(expireDateTime); + for (BinaryRow partition : partitions) { + BinaryRow groupKey = projector.extractGroupPartition(partition); + LocalDateTime boundary = retainBoundary.get(groupKey); + if (boundary == null) { + return false; + } + LocalDateTime partTime = extractPartitionTime(partition); + if (partTime != null && !partTime.isBefore(boundary)) { + return false; + } + } + return true; + } + + /** + * For each group that has snapshot partitions, compute the time boundary at or above which + * partitions are retained (not expired). Returns {@link LocalDateTime#MIN} for groups where + * fewer than 2 snapshots fall before the cutoff (nothing can be expired). Groups with no + * snapshot partitions at all (delta-only) are not included in the result; without a snapshot + * anchor, their earlier deltas may still be required by later delta-only chain reads. + */ + private Map computeGroupRetainBoundary(LocalDateTime cutoffTime) { + List snapshotEntries = snapshotTable.newSnapshotReader().partitionEntries(); + Map> groupedSnapshots = groupByGroupKey(snapshotEntries); + + Map boundaries = new HashMap<>(); + for (Map.Entry> entry : groupedSnapshots.entrySet()) { + BinaryRow groupKey = entry.getKey(); + int countBeforeCutoff = 0; + LocalDateTime latestBeforeCutoff = null; + for (BinaryRow snapshot : entry.getValue()) { + LocalDateTime time = extractPartitionTime(snapshot); + if (time != null && cutoffTime.isAfter(time)) { + countBeforeCutoff++; + if (latestBeforeCutoff == null || time.isAfter(latestBeforeCutoff)) { + latestBeforeCutoff = time; + } + } + } + if (countBeforeCutoff < 2) { + boundaries.put(groupKey, LocalDateTime.MIN); + } else { + boundaries.put(groupKey, latestBeforeCutoff); + } + } + return boundaries; + } + + @VisibleForTesting + void setLastCheck(LocalDateTime time) { + lastCheck = time; + } + + @VisibleForTesting + List> expire(LocalDateTime now, long commitIdentifier) { + if (checkInterval.isZero() + || 
now.isAfter(lastCheck.plus(checkInterval)) + || (endInputCheckPartitionExpire && Long.MAX_VALUE == commitIdentifier)) { + List> expired = doExpire(now.minus(expirationTime)); + lastCheck = now; + return expired; + } + return null; + } + + private List> doExpire(LocalDateTime cutoffTime) { + List snapshotPartitions = + snapshotTable.newSnapshotReader().partitionEntries(); + List deltaPartitions = deltaTable.newSnapshotReader().partitionEntries(); + + Map> groupedSnapshots = groupByGroupKey(snapshotPartitions); + Map> groupedDeltas = groupByGroupKey(deltaPartitions); + + List snapshotPartitionsToExpire = new ArrayList<>(); + List deltaPartitionsToExpire = new ArrayList<>(); + + for (Map.Entry> entry : groupedSnapshots.entrySet()) { + BinaryRow groupKey = entry.getKey(); + List groupSnapshots = entry.getValue(); + + groupSnapshots.sort( + (a, b) -> + chainPartitionComparator.compare( + projector.extractChainPartition(a), + projector.extractChainPartition(b))); + + List snapshotsBeforeCutoff = new ArrayList<>(); + for (BinaryRow partition : groupSnapshots) { + LocalDateTime partTime = extractPartitionTime(partition); + if (partTime != null && cutoffTime.isAfter(partTime)) { + snapshotsBeforeCutoff.add(partition); + } + } + + if (snapshotsBeforeCutoff.size() < 2) { + continue; + } + + // Anchor = most recent snapshot before cutoff, kept as merge base + int anchorIndex = snapshotsBeforeCutoff.size() - 1; + + // Expirable snapshots: all before anchor, oldest first. + // Each forms a segment with its associated deltas. 
+ int segmentsToExpire = Math.min(anchorIndex, maxExpireNum); + + List groupDeltas = groupedDeltas.get(groupKey); + + for (int i = 0; i < segmentsToExpire; i++) { + BinaryRow segmentSnapshot = snapshotsBeforeCutoff.get(i); + snapshotPartitionsToExpire.add(segmentSnapshot); + + if (groupDeltas != null) { + // Segment boundary: from this snapshot's time up to the next snapshot's time + LocalDateTime segmentStart = extractPartitionTime(segmentSnapshot); + BinaryRow nextSnapshot = snapshotsBeforeCutoff.get(i + 1); + LocalDateTime segmentEnd = extractPartitionTime(nextSnapshot); + + if (segmentStart != null && segmentEnd != null) { + for (BinaryRow deltaPartition : groupDeltas) { + LocalDateTime deltaTime = extractPartitionTime(deltaPartition); + if (deltaTime != null + && !deltaTime.isBefore(segmentStart) + && deltaTime.isBefore(segmentEnd)) { + deltaPartitionsToExpire.add(deltaPartition); + } + } + } + } + } + + // Also collect orphan deltas before the earliest expired snapshot + if (segmentsToExpire > 0 && groupDeltas != null) { + LocalDateTime firstSnapshotTime = + extractPartitionTime(snapshotsBeforeCutoff.get(0)); + if (firstSnapshotTime != null) { + for (BinaryRow deltaPartition : groupDeltas) { + LocalDateTime deltaTime = extractPartitionTime(deltaPartition); + if (deltaTime != null && deltaTime.isBefore(firstSnapshotTime)) { + deltaPartitionsToExpire.add(deltaPartition); + } + } + } + } + } + + if (snapshotPartitionsToExpire.isEmpty() && deltaPartitionsToExpire.isEmpty()) { + return new ArrayList<>(); + } + + List> deltaSpecs = toPartitionSpecs(deltaPartitionsToExpire); + List> snapshotSpecs = toPartitionSpecs(snapshotPartitionsToExpire); + List> allExpired = new ArrayList<>(); + + if (!deltaSpecs.isEmpty()) { + LOG.info("Chain table expire delta partitions: {}", deltaSpecs); + batchDropPartitions(deltaTable, deltaSpecs, deltaPartitionModification); + allExpired.addAll(deltaSpecs); + } + + if (!snapshotSpecs.isEmpty()) { + LOG.info("Chain table expire snapshot 
partitions: {}", snapshotSpecs); + batchDropPartitions(snapshotTable, snapshotSpecs, snapshotPartitionModification); + allExpired.addAll(snapshotSpecs); + } + + return allExpired; + } + + private void batchDropPartitions( + FileStoreTable table, + List> partitionSpecs, + @Nullable PartitionModification partitionModification) { + if (partitionModification != null) { + try { + if (expireBatchSize > 0 && expireBatchSize < partitionSpecs.size()) { + for (List> batch : + Lists.partition(partitionSpecs, expireBatchSize)) { + partitionModification.dropPartitions(batch); + partitionModification.dropPartitions(toDonePartitions(batch)); + } + } else { + partitionModification.dropPartitions(partitionSpecs); + partitionModification.dropPartitions(toDonePartitions(partitionSpecs)); + } + } catch (Catalog.TableNotExistException e) { + throw new RuntimeException(e); + } + } else { + if (expireBatchSize > 0 && expireBatchSize < partitionSpecs.size()) { + for (List> batch : + Lists.partition(partitionSpecs, expireBatchSize)) { + dropPartitions(table, batch); + } + } else { + dropPartitions(table, partitionSpecs); + } + } + } + + private List> toDonePartitions( + List> expiredPartitions) { + List> donePartitions = new ArrayList<>(expiredPartitions.size()); + for (Map partition : expiredPartitions) { + LinkedHashMap donePartition = new LinkedHashMap<>(partition); + Map.Entry lastEntry = null; + for (Map.Entry entry : donePartition.entrySet()) { + lastEntry = entry; + } + if (lastEntry != null) { + donePartition.put(lastEntry.getKey(), lastEntry.getValue() + ".done"); + donePartitions.add(donePartition); + } + } + return donePartitions; + } + + private Map> groupByGroupKey(List partitionEntries) { + Map> grouped = new LinkedHashMap<>(); + for (PartitionEntry entry : partitionEntries) { + BinaryRow fullPartition = entry.partition(); + BinaryRow groupKey = projector.extractGroupPartition(fullPartition); + grouped.computeIfAbsent(groupKey, k -> new ArrayList<>()).add(fullPartition); + } 
+ return grouped; + } + + private LocalDateTime extractPartitionTime(BinaryRow partition) { + try { + LinkedHashMap partValues = + partitionComputer.generatePartValues(partition); + List chainValues = new ArrayList<>(); + for (String key : chainPartitionKeys) { + chainValues.add(partValues.get(key)); + } + return timeExtractor.extract(chainPartitionKeys, chainValues); + } catch (Exception e) { + LOG.warn("Failed to extract partition time from {}", partition, e); + return null; + } + } + + private List> toPartitionSpecs(List partitions) { + return partitions.stream() + .map( + p -> { + LinkedHashMap values = + partitionComputer.generatePartValues(p); + Map spec = new LinkedHashMap<>(); + for (String key : partitionKeys) { + String value = values.get(key); + if (value != null) { + spec.put(key, value); + } + } + return spec; + }) + .collect(Collectors.toList()); + } + + private void dropPartitions(FileStoreTable table, List> partitionSpecs) { + try (BatchTableCommit commit = table.newBatchWriteBuilder().newCommit()) { + commit.truncatePartitions(partitionSpecs); + } catch (Exception e) { + throw new RuntimeException( + String.format( + "Failed to drop partitions from %s: %s.", table.name(), partitionSpecs), + e); + } + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/DataEvolutionFileStoreScan.java b/paimon-core/src/main/java/org/apache/paimon/operation/DataEvolutionFileStoreScan.java index 514afff296bb..38ab9b0db02f 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/DataEvolutionFileStoreScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/DataEvolutionFileStoreScan.java @@ -46,11 +46,13 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; +import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Queue; +import java.util.Set; import 
java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; @@ -60,16 +62,18 @@ import static org.apache.paimon.format.blob.BlobFileFormat.isBlobFile; import static org.apache.paimon.manifest.ManifestFileMeta.allContainsRowId; import static org.apache.paimon.types.VectorType.isVectorStoreFile; -import static org.apache.paimon.utils.Preconditions.checkNotNull; /** {@link FileStoreScan} for data-evolution enabled table. */ public class DataEvolutionFileStoreScan extends AppendOnlyFileStoreScan { - private final ConcurrentMap>, List> fileFields; - private boolean dropStats = false; @Nullable private RowType readType; + // Cache file's physical field id set per (schemaId, writeCols) to avoid recomputing during + // per-file column pruning in postFilterManifestEntries. + private final ConcurrentMap>, Set> fileFieldIdsCache = + new ConcurrentHashMap<>(); + public DataEvolutionFileStoreScan( ManifestsReader manifestsReader, BucketSelectConverter bucketSelectConverter, @@ -90,8 +94,6 @@ public DataEvolutionFileStoreScan( false, deletionVectorsEnabled, true); - - this.fileFields = new ConcurrentHashMap<>(); } @Override @@ -175,21 +177,22 @@ public Iterator readManifestEntries( @Override protected boolean postFilterManifestEntriesEnabled() { - return inputFilter != null; + // Always enable post-filtering. The list filterByStats handles predicate-based pruning + // and pruneByReadType strips per-file columns that are not requested — both + // need row-id-range grouping that single filterByStats(ManifestEntry) cannot see. 
+ return inputFilter != null || readType != null; } @Override protected List postFilterManifestEntries(List entries) { - checkNotNull(inputFilter); - // group by row id range RangeHelper rangeHelper = new RangeHelper<>(e -> e.file().nonNullRowIdRange()); List> splitByRowId = rangeHelper.mergeOverlappingRanges(entries); return splitByRowId.stream() - .filter(this::filterByStats) - .flatMap(Collection::stream) + .filter(group -> inputFilter == null || filterByStats(group)) + .flatMap(group -> pruneByReadType(group).stream()) .map(entry -> dropStats ? dropStats(entry) : entry) .collect(Collectors.toList()); } @@ -200,6 +203,62 @@ private boolean filterByStats(List entries) { stats.rowCount(), stats.minValues(), stats.maxValues(), stats.nullCounts()); } + /** + * Per-file column pruning within a row-id-range group: drop files whose physical columns have + * no overlap with the query's {@code readType}. Necessary for columnar-split DE scenarios where + * a logical row is reconstructed from multiple files in the same row id range — a query that + * does not reference a file's columns has no reason to read it. + * + *

    When every file in the group lacks a requested column (e.g. an ADD COLUMN projection over + * a row-disjoint pre-ALTER group), one file is kept as a row-count representative so the reader + * can emit the right number of NULL-filled rows. + */ + private List pruneByReadType(List group) { + if (readType == null || group.size() <= 1) { + return group; + } + Set readFieldIds = new HashSet<>(); + for (DataField f : readType.getFields()) { + readFieldIds.add(f.id()); + } + List kept = new ArrayList<>(group.size()); + for (ManifestEntry entry : group) { + Set fileIds = fileFieldIdsForEntry(entry); + for (int id : readFieldIds) { + if (fileIds.contains(id)) { + kept.add(entry); + break; + } + } + } + // Group must contribute at least one file so the reader sees rowCount and can NULL-fill + // missing columns for the projection's rows. + return kept.isEmpty() ? Collections.singletonList(group.get(0)) : kept; + } + + private Set fileFieldIdsForEntry(ManifestEntry entry) { + return fileFieldIdsCache.computeIfAbsent( + Pair.of(entry.file().schemaId(), entry.file().writeCols()), + pair -> computeFileFieldIds(this::scanTableSchema, entry.file())); + } + + /** + * Field ids of the columns physically present in {@code file}, resolved through the file's own + * schema (i.e. the schema the file was written under). Field id, not field name, is the stable + * identity across schemas — necessary so a renamed column matches an old file written under the + * pre-rename name. + */ + @VisibleForTesting + static Set computeFileFieldIds( + Function scanTableSchema, DataFileMeta file) { + Set ids = new HashSet<>(); + for (DataField f : + scanTableSchema.apply(file.schemaId()).project(file.writeCols()).fields()) { + ids.add(f.id()); + } + return ids; + } + /** TODO: Optimize implementation of this method. 
*/ @VisibleForTesting static EvolutionStats evolutionStats( @@ -279,16 +338,20 @@ static EvolutionStats evolutionStats( } } + long groupRowCount = metas.get(0).file().rowCount(); DataEvolutionRow finalMin = new DataEvolutionRow(metas.size(), rowOffsets, fieldOffsets); DataEvolutionRow finalMax = new DataEvolutionRow(metas.size(), rowOffsets, fieldOffsets); + // For null-count specifically, a field absent from every file in the group means every + // logical row is null for that field — encode as groupRowCount so stats predicates can + // prune non-null comparisons (e.g. `extra2 = 'x'`) instead of falling back to + // "unknown stats -> keep" in LeafPredicate.test. DataEvolutionArray finalNullCounts = - new DataEvolutionArray(metas.size(), rowOffsets, fieldOffsets); + new DataEvolutionArray(metas.size(), rowOffsets, fieldOffsets, groupRowCount); finalMin.setRows(min); finalMax.setRows(max); finalNullCounts.setRows(nullCounts); - return new EvolutionStats( - metas.get(0).file().rowCount(), finalMin, finalMax, finalNullCounts); + return new EvolutionStats(groupRowCount, finalMin, finalMax, finalNullCounts); } /** Note: Keep this thread-safe. */ @@ -296,27 +359,13 @@ static EvolutionStats evolutionStats( protected boolean filterByStats(ManifestEntry entry) { DataFileMeta file = entry.file(); - if (readType != null) { - boolean containsReadCol = false; - List fileFieldNmes = - fileFields.computeIfAbsent( - Pair.of(file.schemaId(), file.writeCols()), - pair -> - scanTableSchema(file.schemaId()) - .project(file.writeCols()) - .logicalRowType() - .getFieldNames()); - - for (String field : readType.getFieldNames()) { - if (fileFieldNmes.contains(field)) { - containsReadCol = true; - break; - } - } - if (!containsReadCol) { - return false; - } - } + // Do not drop a file based on read-column intersection. 
For data-evolution + // tables a field absent from a file is an implicit NULL across rowCount() + // rows, and predicates such as `new_col IS NULL` should still match those + // rows. Predicate-based stats pruning runs in + // filterByStats(List), which evolves stats per file via + // DataEvolutionRow / DataEvolutionArray and correctly reports missing + // fields as null. // If rowRanges is null, all entries should be kept if (this.rowRangeIndex == null) { diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreCommitImpl.java b/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreCommitImpl.java index 3f9fdb9f1c0c..7cca259cbf9f 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreCommitImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/FileStoreCommitImpl.java @@ -55,6 +55,8 @@ import org.apache.paimon.operation.commit.SuccessCommitResult; import org.apache.paimon.operation.metrics.CommitMetrics; import org.apache.paimon.operation.metrics.CommitStats; +import org.apache.paimon.options.MemorySize; +import org.apache.paimon.options.Options; import org.apache.paimon.partition.PartitionPredicate; import org.apache.paimon.partition.PartitionStatistics; import org.apache.paimon.predicate.Predicate; @@ -157,6 +159,7 @@ public class FileStoreCommitImpl implements FileStoreCommit { private boolean ignoreEmptyCommit; private CommitMetrics commitMetrics; private boolean appendCommitCheckConflict = false; + private long lastCommittedSnapshotId = -1L; public FileStoreCommitImpl( SnapshotCommit snapshotCommit, @@ -374,7 +377,8 @@ public int commit(ManifestCommittable committable, boolean checkAppendFiles) { changes.compactChangelog, commitDuration, generatedSnapshot, - attempts); + attempts, + lastCommittedSnapshotId); } } return generatedSnapshot; @@ -387,7 +391,8 @@ private void reportCommit( List compactChangelogFiles, long commitDuration, int generatedSnapshots, - int attempts) { + int attempts, + long 
lastCommittedSnapshotId) { CommitStats commitStats = new CommitStats( appendTableFiles, @@ -396,7 +401,8 @@ private void reportCommit( compactChangelogFiles, commitDuration, generatedSnapshots, - attempts); + attempts, + lastCommittedSnapshotId); commitMetrics.reportCommit(commitStats); } @@ -538,7 +544,8 @@ public int overwritePartition( emptyList(), commitDuration, generatedSnapshot, - attempts); + attempts, + lastCommittedSnapshotId); } } return generatedSnapshot; @@ -829,6 +836,7 @@ CommitResult tryCommitOnce( if (snapshot.commitUser().equals(commitUser) && snapshot.commitIdentifier() == identifier && snapshot.commitKind() == commitKind) { + lastCommittedSnapshotId = snapshot.id(); return new SuccessCommitResult(); } } @@ -964,13 +972,7 @@ CommitResult tryCommitOnce( // try to merge old manifest files to create base manifest list mergeAfterManifests = ManifestFileMerger.merge( - mergeBeforeManifests, - manifestFile, - options.manifestTargetSize().getBytes(), - options.manifestMergeMinCount(), - options.manifestFullCompactionThresholdSize().getBytes(), - partitionType, - options.scanManifestParallelism()); + mergeBeforeManifests, manifestFile, partitionType, options); baseManifestList = manifestList.write(mergeAfterManifests); if (options.rowTrackingEnabled()) { @@ -989,6 +991,10 @@ CommitResult tryCommitOnce( deltaFiles = assigned.assignedEntries; } + if (options.snapshotSequenceOrdering()) { + deltaFiles = stampSequenceWithSnapshotId(newSnapshotId, commitKind, deltaFiles); + } + // the added records subtract the deleted records from long deltaRecordCount = recordCountAdd(deltaFiles) - recordCountDelete(deltaFiles); long totalRecordCount = previousTotalRecordCount + deltaRecordCount; @@ -1102,6 +1108,7 @@ CommitResult tryCommitOnce( if (strictModeChecker != null) { strictModeChecker.update(newSnapshotId); } + lastCommittedSnapshotId = newSnapshotId; CommitCallback.Context context = new CommitCallback.Context( finalBaseFiles, finalDeltaFiles, indexFiles, 
newSnapshot, identifier); @@ -1190,16 +1197,16 @@ private boolean compactManifestOnce() { manifestList.readDataManifests(latestSnapshot); List mergeAfterManifests; - // the fist trial + // the first trial: use copied options with forced full compaction settings + Options compactOptions = Options.fromMap(options.toMap()); + compactOptions.set(CoreOptions.MANIFEST_MERGE_MIN_COUNT, 1); + compactOptions.set(CoreOptions.MANIFEST_FULL_COMPACTION_FILE_SIZE, MemorySize.ofBytes(1)); mergeAfterManifests = ManifestFileMerger.merge( mergeBeforeManifests, manifestFile, - options.manifestTargetSize().getBytes(), - 1, - 1, partitionType, - options.scanManifestParallelism()); + new CoreOptions(compactOptions)); if (new HashSet<>(mergeBeforeManifests).equals(new HashSet<>(mergeAfterManifests))) { // no need to commit this snapshot, because no compact were happened @@ -1265,4 +1272,31 @@ public void close() { IOUtils.closeAllQuietly(commitCallbacks); IOUtils.closeQuietly(snapshotCommit); } + + /** + * Stamps the commit snapshot id into {@link DataFileMeta#minSequenceNumber()} / {@link + * DataFileMeta#maxSequenceNumber()} of APPEND files, reusing these fields instead of adding a + * new one (same pattern as {@link RowTrackingCommitUtils#assignRowTracking}). COMPACT files are + * returned unchanged: their input was read through the override path, so their per-record + * {@code _SEQUENCE_NUMBER} already carries the snapshot id. + * + *

    All records of a snapshot share one id, so intra-snapshot order is not preserved. This is + * accepted: the default spillable writer collapses a commit's writes through the merge function + * to one record per key before flush, and the feature targets cross-snapshot ordering only. + */ + private static List stampSequenceWithSnapshotId( + long snapshotId, CommitKind commitKind, List files) { + if (commitKind == CommitKind.COMPACT) { + return files; + } + List result = new ArrayList<>(files.size()); + for (ManifestEntry entry : files) { + if (entry.kind() == FileKind.ADD) { + result.add(entry.assignSequenceNumber(snapshotId, snapshotId)); + } else { + result.add(entry); + } + } + return result; + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/ManifestAdjacentSortedRun.java b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestAdjacentSortedRun.java new file mode 100644 index 000000000000..ca0797c2139c --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestAdjacentSortedRun.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.operation; + +import org.apache.paimon.manifest.ManifestFileMeta; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * A {@code ManifestAdjacentSortedRun} is a list of {@link ManifestFileMeta}s sorted by a single + * partition field (the configured manifest sort field). The intervals {@code + * [partitionStats.minValues[k], partitionStats.maxValues[k]]} of these manifests do not overlap on + * field {@code k}, where {@code k} is the configured sort field index. + * + *

    Boundary Equality: Files with boundary-touching intervals (min == previous.max) are + * considered non-overlapping and can be placed in the same SortedRun. This reduces the number of + * runs and improves compaction efficiency. However, such files may be separated into different + * Sections during splitIntoSections to avoid merge-sort overhead. + */ +public class ManifestAdjacentSortedRun { + + private int level; + private final List files; + private final long totalSize; + + private ManifestAdjacentSortedRun(List files) { + this.level = -1; + this.files = Collections.unmodifiableList(files); + long size = 0L; + for (ManifestFileMeta file : files) { + size += file.fileSize(); + } + this.totalSize = size; + } + + /** + * Build a {@code ManifestAdjacentSortedRun} from an already-sorted list. The caller MUST + * guarantee that {@code sortedFiles} is sorted ascending on the configured sort field's min + * value, and that intervals do not overlap on that field. + */ + public static ManifestAdjacentSortedRun fromSorted(List sortedFiles) { + return new ManifestAdjacentSortedRun(sortedFiles); + } + + public List files() { + return files; + } + + public long totalSize() { + return totalSize; + } + + public int level() { + return level; + } + + public void setLevel(int level) { + this.level = level; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ManifestAdjacentSortedRun)) { + return false; + } + ManifestAdjacentSortedRun that = (ManifestAdjacentSortedRun) o; + return level == that.level && files.equals(that.files); + } + + @Override + public int hashCode() { + return Objects.hash(level, files); + } + + @Override + public String toString() { + return "ManifestAdjacentSortedRun{level=" + + level + + ", files=[" + + files.stream().map(ManifestFileMeta::fileName).collect(Collectors.joining(", ")) + + "]}"; + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileMerger.java 
b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileMerger.java index cdcad1ed3e84..f899aa71786f 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileMerger.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileMerger.java @@ -18,6 +18,7 @@ package org.apache.paimon.operation; +import org.apache.paimon.CoreOptions; import org.apache.paimon.data.BinaryRow; import org.apache.paimon.io.RollingFileWriter; import org.apache.paimon.manifest.FileEntry; @@ -48,7 +49,7 @@ import static org.apache.paimon.utils.ManifestReadThreadPool.sequentialBatchedExecute; import static org.apache.paimon.utils.Preconditions.checkArgument; -/** Util for merging manifest files. */ +/** Manifest file merger with standard merge logic and optional sort rewrite. */ public class ManifestFileMerger { private static final Logger LOG = LoggerFactory.getLogger(ManifestFileMerger.class); @@ -62,33 +63,44 @@ public class ManifestFileMerger { public static List merge( List input, ManifestFile manifestFile, - long suggestedMetaSize, - int suggestedMinMetaCount, - long manifestFullCompactionSize, RowType partitionType, - @Nullable Integer manifestReadParallelism) { + CoreOptions options) { + // Extract configuration from options + long suggestedMetaSize = options.manifestTargetSize().getBytes(); + int suggestedMinMetaCount = options.manifestMergeMinCount(); + long manifestFullCompactionSize = options.manifestFullCompactionThresholdSize().getBytes(); + Integer manifestReadParallelism = options.scanManifestParallelism(); + // these are the newly created manifest files, clean them up if exception occurs List newFilesForAbort = new ArrayList<>(); try { - Optional> fullCompacted = - tryFullCompaction( - input, - newFilesForAbort, - manifestFile, - suggestedMetaSize, - manifestFullCompactionSize, - partitionType, - manifestReadParallelism); - return fullCompacted.orElseGet( - () -> - tryMinorCompaction( - input, - newFilesForAbort, - manifestFile, 
- suggestedMetaSize, - suggestedMinMetaCount, - manifestReadParallelism)); + // If manifest sort is enabled and the table has partition fields, delegate to + // ManifestFileSorter.trySortCompaction + if (options.manifestSortEnabled() && partitionType.getFieldCount() > 0) { + return ManifestFileSorter.trySortCompaction( + input, newFilesForAbort, manifestFile, partitionType, options); + } else { + // Otherwise try full compaction first, then minor compaction if needed + Optional> fullCompacted = + tryFullCompaction( + input, + newFilesForAbort, + manifestFile, + suggestedMetaSize, + manifestFullCompactionSize, + partitionType, + manifestReadParallelism); + return fullCompacted.orElseGet( + () -> + tryMinorCompaction( + input, + newFilesForAbort, + manifestFile, + suggestedMetaSize, + suggestedMinMetaCount, + manifestReadParallelism)); + } } catch (Throwable e) { // exception occurs, clean up and rethrow for (ManifestFileMeta manifest : newFilesForAbort) { @@ -234,7 +246,6 @@ public static Optional> tryFullCompaction( } // 2.2.
merge - if (toBeMerged.size() <= 1) { return Optional.empty(); } @@ -295,7 +306,7 @@ private static FullCompactionReadResult readForFullCompaction( return new FullCompactionReadResult(file, requireChange, entries); } - private static Set computeDeletePartitions(Set deleteEntries) { + static Set computeDeletePartitions(Set deleteEntries) { Set partitions = new HashSet<>(); for (FileEntry.Identifier identifier : deleteEntries) { partitions.add(identifier.partition); @@ -303,13 +314,13 @@ private static Set computeDeletePartitions(Set return partitions; } - private static class FullCompactionReadResult { + static class FullCompactionReadResult { - private final ManifestFileMeta file; - private final boolean requireChange; - private final List entries; + final ManifestFileMeta file; + final boolean requireChange; + final List entries; - private FullCompactionReadResult( + FullCompactionReadResult( ManifestFileMeta file, boolean requireChange, List entries) { this.file = file; this.requireChange = requireChange; diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileSorter.java b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileSorter.java new file mode 100644 index 000000000000..39ef0bab5299 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestFileSorter.java @@ -0,0 +1,1149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.operation; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.codegen.CodeGenUtils; +import org.apache.paimon.codegen.RecordComparator; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.io.RollingFileWriter; +import org.apache.paimon.manifest.FileEntry; +import org.apache.paimon.manifest.FileKind; +import org.apache.paimon.manifest.ManifestEntry; +import org.apache.paimon.manifest.ManifestFile; +import org.apache.paimon.manifest.ManifestFileMeta; +import org.apache.paimon.partition.PartitionPredicate; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.Filter; +import org.apache.paimon.utils.Pair; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.function.Function; + +import static java.util.Collections.singletonList; +import static org.apache.paimon.utils.ManifestReadThreadPool.sequentialBatchedExecute; + +/** Manifest file sorter that sorts and rewrites manifest files by a configured partition field. 
*/ +public class ManifestFileSorter { + + private static final Logger LOG = LoggerFactory.getLogger(ManifestFileSorter.class); + + /** Context object that carries shared state across compaction methods. */ + static class CompactionContext { + final boolean fullCompaction; + final RecordComparator fieldComparator; + final Set deleteEntries; + final Map defaultCompactionMap; + final List levelRuns; + final List pickedRuns; + + CompactionContext( + boolean fullCompaction, + RecordComparator fieldComparator, + Set deleteEntries, + Map defaultCompactionMap, + List levelRuns, + List pickedRuns) { + this.fullCompaction = fullCompaction; + this.fieldComparator = fieldComparator; + this.deleteEntries = deleteEntries; + this.defaultCompactionMap = defaultCompactionMap; + this.levelRuns = levelRuns; + this.pickedRuns = pickedRuns; + } + } + + /** Result of classifying manifest files. */ + private static class ClassifyResult { + final List lsmFiles; + final Set deleteEntries; + final Map defaultCompactionMap; + + ClassifyResult( + List lsmFiles, + Set deleteEntries, + Map defaultCompactionMap) { + this.lsmFiles = lsmFiles; + this.deleteEntries = deleteEntries; + this.defaultCompactionMap = defaultCompactionMap; + } + } + + /** + * Try to sort-rewrite the merged manifest list by a configured partition field. If the sort + * field cannot be resolved, an {@link IllegalArgumentException} is thrown. + * + *

    Dispatches to {@link #tryFullCompaction} when totalDeltaFileSize >= + * fullCompactionThreshold, or to {@link #tryMinorCompaction} otherwise. + */ + static List trySortCompaction( + List input, + List newFilesForAbort, + ManifestFile manifestFile, + RowType partitionType, + CoreOptions options) + throws Exception { + String sortPartitionField = options.manifestSortPartitionField(); + long suggestedMetaSize = options.manifestTargetSize().getBytes(); + int suggestedMinMetaCount = options.manifestMergeMinCount(); + long fullCompactionThreshold = options.manifestFullCompactionThresholdSize().getBytes(); + long maxRewriteSize = options.manifestSortMaxRewriteSize(); + int maxSizeAmplificationPercent = options.maxSizeAmplificationPercent(); + int sortedRunSizeRatio = options.sortedRunSizeRatio(); + Integer manifestReadParallelism = options.scanManifestParallelism(); + + Optional> fullCompacted = + tryFullCompaction( + input, + newFilesForAbort, + manifestFile, + partitionType, + sortPartitionField, + suggestedMetaSize, + suggestedMinMetaCount, + fullCompactionThreshold, + maxRewriteSize, + maxSizeAmplificationPercent, + sortedRunSizeRatio, + manifestReadParallelism); + if (fullCompacted.isPresent()) { + return fullCompacted.get(); + } + return tryMinorCompaction( + input, + newFilesForAbort, + manifestFile, + partitionType, + sortPartitionField, + suggestedMetaSize, + suggestedMinMetaCount, + maxRewriteSize, + maxSizeAmplificationPercent, + sortedRunSizeRatio, + manifestReadParallelism); + } + + /** + * Full compaction path: totalDeltaFileSize >= fullCompactionThreshold. + * + *

    Does not build index mapping. sortAndRewriteSection writes all entries (ADD+DELETE merged) + * together without separating them. + */ + private static Optional> tryFullCompaction( + List input, + List newFilesForAbort, + ManifestFile manifestFile, + RowType partitionType, + String sortPartitionField, + long suggestedMetaSize, + int suggestedMinMetaCount, + long fullCompactionThreshold, + long maxRewriteSize, + int maxSizeAmplificationPercent, + int sortedRunSizeRatio, + @Nullable Integer manifestReadParallelism) + throws Exception { + // Step 1: Check if full compaction threshold is met + long totalDeltaFileSize = 0; + for (ManifestFileMeta file : input) { + if (file.numDeletedFiles() > 0 || file.fileSize() < suggestedMetaSize) { + totalDeltaFileSize += file.fileSize(); + } + } + if (totalDeltaFileSize < fullCompactionThreshold) { + return Optional.empty(); + } + // Step 2: Prepare compaction context + CompactionContext ctx = + prepareCompaction( + input, + true, + manifestFile, + partitionType, + sortPartitionField, + suggestedMetaSize, + maxSizeAmplificationPercent, + sortedRunSizeRatio, + manifestReadParallelism); + List levelRuns = ctx.levelRuns; + List pickedRuns = ctx.pickedRuns; + + if (pickedRuns.isEmpty() && ctx.defaultCompactionMap.isEmpty()) { + LOG.debug( + "Manifest sort full compact skipped: no runs picked and no defaultCompaction files."); + return Optional.empty(); + } + + LOG.info( + "Manifest sort full compact: input={} files, lsm={} runs, picked={} runs, " + + "defaultCompaction={} files.", + input.size(), + levelRuns.size(), + pickedRuns.size(), + ctx.defaultCompactionMap.size()); + + // Step 3: Collect reused files (not picked) and picked files + Set pickedSet = new HashSet<>(pickedRuns); + List result = new ArrayList<>(); + for (ManifestAdjacentSortedRun run : levelRuns) { + if (!pickedSet.contains(run)) { + result.addAll(run.files()); + } + } + List pickedFiles = new ArrayList<>(); + for (ManifestAdjacentSortedRun run : pickedRuns) { + 
pickedFiles.addAll(run.files()); + } + pickedFiles.addAll(ctx.defaultCompactionMap.keySet()); + + // Step 4: Split into sections and merge small adjacent sections + List

    sections = + splitIntoSections(pickedFiles, ctx.fieldComparator, ctx.defaultCompactionMap); + sections = mergeSmallAdjacentSections(sections, suggestedMetaSize); + + LOG.info( + "Manifest sort full compact: pickedFiles={}, sections={}.", + pickedFiles.size(), + sections.size()); + + // Step 5: Rewrite sections + FullCompactOutput output = new FullCompactOutput(result); + rewriteSections( + sections, + output, + newFilesForAbort, + ctx, + manifestFile, + suggestedMetaSize, + suggestedMinMetaCount, + maxRewriteSize, + manifestReadParallelism); + + LOG.info( + "Manifest sort full compact completed: input={}, resultFiles={}.", + input.size(), + result.size()); + return Optional.of(result); + } + + /** + * Minor compaction path: totalDeltaFileSize < fullCompactionThreshold. + * + *

    Builds index mapping to preserve original positions. sortAndRewriteSection separates ADD + * and DELETE entries, placing ADD at result[minIdx] and DELETE at result[maxIdx]. + */ + private static List tryMinorCompaction( + List input, + List newFilesForAbort, + ManifestFile manifestFile, + RowType partitionType, + String sortPartitionField, + long suggestedMetaSize, + int suggestedMinMetaCount, + long maxRewriteSize, + int maxSizeAmplificationPercent, + int sortedRunSizeRatio, + @Nullable Integer manifestReadParallelism) + throws Exception { + // Step 1: Prepare compaction context (early-return if nothing to compact) + CompactionContext ctx = + prepareCompaction( + input, + false, + manifestFile, + partitionType, + sortPartitionField, + suggestedMetaSize, + maxSizeAmplificationPercent, + sortedRunSizeRatio, + manifestReadParallelism); + List levelRuns = ctx.levelRuns; + List pickedRuns = ctx.pickedRuns; + + if (pickedRuns.isEmpty() && ctx.defaultCompactionMap.isEmpty()) { + LOG.debug( + "Manifest sort minor compact skipped: no runs picked and no defaultCompaction files."); + return input; + } + + LOG.info( + "Manifest sort minor compact: input={} files, lsm={} runs, picked={} runs, " + + "defaultCompaction={} files.", + input.size(), + levelRuns.size(), + pickedRuns.size(), + ctx.defaultCompactionMap.size()); + + // Step 2: Build fileName -> index mapping and initialize 2D result + Map fileNameToIndex = new HashMap<>(); + List> result = new ArrayList<>(input.size()); + for (int i = 0; i < input.size(); i++) { + fileNameToIndex.put(input.get(i).fileName(), i); + result.add(new ArrayList<>()); + } + + // Step 3: Collect reused files and picked files + Set pickedSet = new HashSet<>(pickedRuns); + for (ManifestAdjacentSortedRun run : levelRuns) { + if (!pickedSet.contains(run)) { + for (ManifestFileMeta file : run.files()) { + Integer idx = fileNameToIndex.get(file.fileName()); + if (idx != null) { + result.get(idx).add(file); + } + } + } + } + + List pickedFiles = 
new ArrayList<>(); + for (ManifestAdjacentSortedRun run : pickedRuns) { + pickedFiles.addAll(run.files()); + } + pickedFiles.addAll(ctx.defaultCompactionMap.keySet()); + + // Step 4: Compute index range + int minIdx = Integer.MAX_VALUE; + int maxIdx = Integer.MIN_VALUE; + for (ManifestFileMeta meta : pickedFiles) { + Integer idx = fileNameToIndex.get(meta.fileName()); + if (idx != null) { + minIdx = Math.min(minIdx, idx); + maxIdx = Math.max(maxIdx, idx); + } + } + Pair indexRange = Pair.of(minIdx, maxIdx); + + // Step 5: Split into sections and merge small adjacent sections + List

    sections = + splitIntoSections(pickedFiles, ctx.fieldComparator, ctx.defaultCompactionMap); + sections = mergeSmallAdjacentSections(sections, suggestedMetaSize); + + LOG.info( + "Manifest sort minor compact: pickedFiles={}, sections={}.", + pickedFiles.size(), + sections.size()); + + // Step 6: Rewrite sections + MinorCompactOutput output = new MinorCompactOutput(result, indexRange, fileNameToIndex); + rewriteSections( + sections, + output, + newFilesForAbort, + ctx, + manifestFile, + suggestedMetaSize, + suggestedMinMetaCount, + maxRewriteSize, + manifestReadParallelism); + + // Step 7: Flatten 2D result into a single list + List flatResult = new ArrayList<>(); + for (List subList : result) { + flatResult.addAll(subList); + } + + LOG.info( + "Manifest sort minor compact completed: input={}, resultFiles={}.", + input.size(), + flatResult.size()); + return flatResult; + } + + /** + * Prepare compaction context: resolve sort field, classify manifests, build level runs, and + * pick runs for compaction. + * + * @return CompactionContext containing all shared state + */ + private static CompactionContext prepareCompaction( + List input, + boolean fullCompaction, + ManifestFile manifestFile, + RowType partitionType, + String sortPartitionField, + long suggestedMetaSize, + int maxSizeAmplificationPercent, + int sortedRunSizeRatio, + @Nullable Integer manifestReadParallelism) { + + // Step 1: Resolve sort field and build comparator for partition ordering. + String sortField = resolveSortField(sortPartitionField, partitionType); + if (sortField == null) { + throw new IllegalArgumentException( + "Cannot resolve sort field for manifest sort rewrite."); + } + int sortFieldIndex = partitionType.getFieldNames().indexOf(sortField); + RecordComparator fieldComparator = + CodeGenUtils.newRecordComparator( + partitionType.getFieldTypes(), new int[] {sortFieldIndex}); + + // Step 2: Classify manifests into LSM files and collect delete entries. 
+ ClassifyResult classifyResult = + classifyManifests( + input, + fullCompaction, + manifestFile, + partitionType, + suggestedMetaSize, + manifestReadParallelism); + List lsmFiles = classifyResult.lsmFiles; + + // Step 3: Build level-sorted runs from LSM files based on partition order. + List levelRuns = + lsmFiles.isEmpty() + ? new ArrayList<>() + : buildLevelSortedRuns(lsmFiles, fieldComparator); + + // Step 4: Pick runs for compaction using size amplification and ratio strategy. + ManifestPickStrategy pickStrategy = + new ManifestPickStrategy(maxSizeAmplificationPercent, sortedRunSizeRatio); + List pickedRuns = pickStrategy.pick(levelRuns); + + return new CompactionContext( + fullCompaction, + fieldComparator, + classifyResult.deleteEntries, + classifyResult.defaultCompactionMap, + levelRuns, + pickedRuns); + } + + /** + * Classify manifest files into default-compaction group and LSM group. + * + *

    Full compaction: small files and files overlapping delete partitions go into + * defaultCompactionMap; the rest are returned as lsmFiles. + * + *

    Non-full compaction: small files go to defaultCompactionMap for minor-style merge; the + * rest are returned as lsmFiles. + * + * @return ClassifyResult containing lsmFiles, deleteEntries, and defaultCompactionMap + */ + private static ClassifyResult classifyManifests( + List input, + boolean fullCompaction, + ManifestFile manifestFile, + RowType partitionType, + long suggestedMetaSize, + @Nullable Integer manifestReadParallelism) { + // Initialize classification containers and read delete entries + Map classifiedDefaultMap = new LinkedHashMap<>(); + List lsmFiles = new LinkedList<>(input); + Set classifiedDeleteEntries = Collections.emptySet(); + PartitionPredicate predicate = null; + if (fullCompaction) { + classifiedDeleteEntries = + FileEntry.readDeletedEntries(manifestFile, input, manifestReadParallelism); + + // Build partition predicate from delete entries for overlap detection + if (classifiedDeleteEntries.isEmpty()) { + predicate = PartitionPredicate.ALWAYS_FALSE; + } else { + if (partitionType.getFieldCount() > 0) { + Set deletePartitions = + ManifestFileMerger.computeDeletePartitions(classifiedDeleteEntries); + predicate = PartitionPredicate.fromMultiple(partitionType, deletePartitions); + } else { + predicate = PartitionPredicate.ALWAYS_TRUE; + } + } + } + + // Classify each file based on size and delete-partition overlap + Iterator iterator = lsmFiles.iterator(); + while (iterator.hasNext()) { + ManifestFileMeta file = iterator.next(); + boolean small = file.fileSize() < suggestedMetaSize; + boolean inDeleteRange = + predicate != null + && predicate.test( + file.numAddedFiles() + file.numDeletedFiles(), + file.partitionStats().minValues(), + file.partitionStats().maxValues(), + file.partitionStats().nullCounts()); + if (small || inDeleteRange) { + iterator.remove(); + classifiedDefaultMap.put(file, inDeleteRange); + } + } + + return new ClassifyResult(lsmFiles, classifiedDeleteEntries, classifiedDefaultMap); + } + + /** + * Build level-sorted runs 
from a list of manifest files. Sorts files by min partition value, + * greedy-scans to build non-overlapping SortedRuns, then assigns levels by totalSize (Top-4 + * largest to level 1~4, rest to level 0). + */ + static List buildLevelSortedRuns( + List input, RecordComparator fieldComparator) { + // Step 1: Sort by min value (if equal, then by max value) + input.sort( + (a, b) -> { + int cmp = + fieldComparator.compare( + a.partitionStats().minValues(), b.partitionStats().minValues()); + if (cmp != 0) { + return cmp; + } + return fieldComparator.compare( + a.partitionStats().maxValues(), b.partitionStats().maxValues()); + }); + + // Step 2: Interval graph coloring algorithm - assign files to runs + // Use priority queue to track runs by their max values + PriorityQueue> runs = + new PriorityQueue<>( + (r1, r2) -> { + ManifestFileMeta last1 = r1.get(r1.size() - 1); + ManifestFileMeta last2 = r2.get(r2.size() - 1); + return fieldComparator.compare( + last1.partitionStats().maxValues(), + last2.partitionStats().maxValues()); + }); + + for (ManifestFileMeta file : input) { + List earliestRun = runs.poll(); + if (earliestRun == null) { + // No existing runs, create a new one + List newRun = new ArrayList<>(); + newRun.add(file); + runs.offer(newRun); + } else if (fieldComparator.compare( + file.partitionStats().minValues(), + earliestRun.get(earliestRun.size() - 1).partitionStats().maxValues()) + >= 0) { + // Current file's min >= run's max, append to this run + // Note: When min == max (boundary equality), files are considered + // non-overlapping and can be placed in the same SortedRun. This allows + // building fewer SortedRuns, improving compaction efficiency while + // maintaining correct sort order. However, these files may later be separated + // into different Sections during splitIntoSections to avoid merge-sort overhead. + // + // See ManifestAdjacentSortedRun class comment for the full boundary equality + // semantics. 
+ earliestRun.add(file); + runs.offer(earliestRun); + } else { + // Overlap detected, put the run back and create a new one + runs.offer(earliestRun); + List newRun = new ArrayList<>(); + newRun.add(file); + runs.offer(newRun); + } + } + + // Step 3: Convert to ManifestAdjacentSortedRun list + List result = new ArrayList<>(); + while (!runs.isEmpty()) { + result.add(ManifestAdjacentSortedRun.fromSorted(runs.poll())); + } + + // Step 4: Sort by totalSize and assign levels + result.sort(Comparator.comparingLong(ManifestAdjacentSortedRun::totalSize)); + int n = result.size(); + int maxLevel = ManifestPickStrategy.MAX_LEVEL; + for (int i = 0; i < n; i++) { + if (i >= n - maxLevel) { + result.get(i).setLevel(i - (n - maxLevel) + 1); + } else { + result.get(i).setLevel(0); + } + } + return result; + } + + /** + * Split picked files into sections. Files with overlapping sort-key intervals go into the same + * section. Each section is built with pre-computed totalSize and hasDefaultCompactMeta. + */ + static List

    splitIntoSections( + List pickedFiles, + RecordComparator fieldComparator, + Map defaultCompactionMap) { + pickedFiles.sort( + (a, b) -> { + int cmp = + fieldComparator.compare( + a.partitionStats().minValues(), b.partitionStats().minValues()); + if (cmp != 0) { + return cmp; + } + return fieldComparator.compare( + a.partitionStats().maxValues(), b.partitionStats().maxValues()); + }); + + List
    sections = new ArrayList<>(); + List currentFiles = new ArrayList<>(); + long currentTotalSize = 0; + boolean currentHasDefault = false; + ManifestFileMeta first = pickedFiles.get(0); + currentFiles.add(first); + currentTotalSize += first.fileSize(); + currentHasDefault = defaultCompactionMap.containsKey(first); + BinaryRow sectionMaxBound = first.partitionStats().maxValues(); + + for (int i = 1; i < pickedFiles.size(); i++) { + ManifestFileMeta file = pickedFiles.get(i); + // Note: Boundary equality (file.min == sectionMaxBound) results in separate + // sections. This design choice balances three factors: + // 1. Avoid merge-sort overhead: Files with non-overlapping boundaries can be processed + // independently without merge-sort, improving performance. + // 2. Maintain partition filtering capability: Each section has a distinct key range, + // enabling efficient partition pruning during queries. + // 3. Preserve ordering invariant: Separating boundary-touching files into different + // sections + // does not break the global sort order, as they are still processed in ascending + // order. + // + // IMPORTANT: While boundary-touching files are separated into different Sections here, + // they may be placed in the same SortedRun during buildLevelSortedRuns (which uses >= 0 + // comparison). This dual behavior is intentional and documented in class comments. 
+ if (fieldComparator.compare(file.partitionStats().minValues(), sectionMaxBound) >= 0) { + sections.add(new Section(currentFiles, currentTotalSize, currentHasDefault)); + currentFiles = new ArrayList<>(); + currentTotalSize = 0; + currentFiles.add(file); + currentTotalSize += file.fileSize(); + currentHasDefault = defaultCompactionMap.containsKey(file); + sectionMaxBound = file.partitionStats().maxValues(); + } else { + currentFiles.add(file); + currentTotalSize += file.fileSize(); + if (!currentHasDefault && defaultCompactionMap.containsKey(file)) { + currentHasDefault = true; + } + if (fieldComparator.compare(file.partitionStats().maxValues(), sectionMaxBound) + > 0) { + sectionMaxBound = file.partitionStats().maxValues(); + } + } + } + sections.add(new Section(currentFiles, currentTotalSize, currentHasDefault)); + return sections; + } + + /** + * Merge small adjacent sections to avoid producing too many small rewrite batches. If either + * the pending section or the current section total size is smaller than {@code + * suggestedMetaSize}, they are combined into a single section. + */ + private static List
    mergeSmallAdjacentSections( + List
    sections, long suggestedMetaSize) { + List
    merged = new ArrayList<>(); + Section pending = null; + + for (Section section : sections) { + if (pending == null) { + pending = section; + } else { + if (pending.totalSize < suggestedMetaSize + || section.totalSize < suggestedMetaSize) { + pending = Section.merge(pending, section); + } else { + merged.add(pending); + pending = section; + } + } + } + if (pending != null) { + merged.add(pending); + } + return merged; + } + + /** + * Rewrite sections with budget control. + * + *

    Semantics of manifest-sort.max-rewrite-size: This budget applies only to the sorted + * rewrite portion. When adding a section would exceed the limit: + * + *

      + *
    • First overflow: The current section is split. The rewritable part is sorted and + * rewritten. The remaining part is appended back to the sections queue for later + * processing. + *
    • Subsequent overflows: If the section has files in defaultCompactionMap (needs default + * compaction), rewriteSubSegments is called to process it in smaller chunks. Otherwise, + * the section's files are kept unchanged. + *
    + * + *

    This design ensures that the budget only limits the aggressive sort rewrite, while still + * allowing necessary cleanup operations (delete entry elimination, small file merge) through + * the rewriteSubSegments fallback path. + */ + private static void rewriteSections( + List

    sections, + RewriteOutput output, + List sortNewFiles, + CompactionContext ctx, + ManifestFile manifestFile, + long suggestedMetaSize, + int suggestedMinMetaCount, + long maxRewriteSize, + @Nullable Integer manifestReadParallelism) + throws Exception { + long processedSize = 0; + boolean reachedLimit = false; + + for (int i = 0; i < sections.size(); i++) { + Section section = sections.get(i); + if (section.files.size() == 1) { + sortAndRewriteSection( + section.files, + output, + sortNewFiles, + ctx, + manifestFile, + manifestReadParallelism); + continue; + } + + if (processedSize + section.totalSize <= maxRewriteSize) { + processedSize += section.totalSize; + sortAndRewriteSection( + section.files, + output, + sortNewFiles, + ctx, + manifestFile, + manifestReadParallelism); + } else if (!reachedLimit) { + long rewriteTotalSize = maxRewriteSize - processedSize; + processedSize += section.totalSize; + List rewriteFiles = new ArrayList<>(); + List remainingFiles = new ArrayList<>(); + long rewriteSize = 0; + long remainingSize = 0; + boolean remainingHasDefault = false; + + for (ManifestFileMeta file : section.files) { + if (rewriteSize + file.fileSize() <= rewriteTotalSize) { + rewriteFiles.add(file); + rewriteSize += file.fileSize(); + } else { + remainingFiles.add(file); + remainingSize += file.fileSize(); + if (ctx.defaultCompactionMap.containsKey(file)) { + remainingHasDefault = true; + } + } + } + + sortAndRewriteSection( + rewriteFiles, + output, + sortNewFiles, + ctx, + manifestFile, + manifestReadParallelism); + + if (!remainingFiles.isEmpty()) { + Section remainingSection = + new Section(remainingFiles, remainingSize, remainingHasDefault); + // global manifest file metas order by sort key is not a required invariant + sections.add(remainingSection); + } + reachedLimit = true; + } else if (section.hasDefaultCompactMeta) { + rewriteSubSegments( + section.files, + output, + sortNewFiles, + ctx, + manifestFile, + suggestedMetaSize, + suggestedMinMetaCount, 
+ manifestReadParallelism); + } else { + output.addAllUnchanged(section.files); + } + } + } + + /** + * Rewrite a section in smaller sub-segments when it exceeds the sort rewrite budget. + * + *

    Semantic difference from old minor merge: In the old ManifestFileMerger path, the + * trailing candidates are kept unchanged when their count is below manifest.merge-min-count. In + * this sort path, rewriteSubSegments runs for any section holding a defaultCompactionMap file, + * regardless of the manifest count. This is because files in defaultCompactionMap either: + * + *

      + *
    • Are small files needing consolidation + *
    • Contain delete entries that must be eliminated + *
    + * + *

    The manifest.merge-min-count threshold is still applied to the final sub-segment's tail, + * acting as a conservative gate to avoid unnecessary rewrite when there are no delete entries + * and the tail is too small. + */ + private static void rewriteSubSegments( + List section, + RewriteOutput output, + List sortNewFiles, + CompactionContext ctx, + ManifestFile manifestFile, + long suggestedMetaSize, + int suggestedMinMetaCount, + @Nullable Integer manifestReadParallelism) + throws Exception { + List subSegment = new ArrayList<>(); + long subSegmentSize = 0; + for (ManifestFileMeta m : section) { + subSegmentSize += m.fileSize(); + subSegment.add(m); + + if (subSegmentSize >= suggestedMetaSize) { + sortAndRewriteSection( + subSegment, + output, + sortNewFiles, + ctx, + manifestFile, + manifestReadParallelism); + subSegment.clear(); + subSegmentSize = 0; + } + } + // Flush tail only if delete entries exist or file count >= minCount. + if (!subSegment.isEmpty()) { + if (!ctx.deleteEntries.isEmpty() || subSegment.size() >= suggestedMinMetaCount) { + sortAndRewriteSection( + subSegment, + output, + sortNewFiles, + ctx, + manifestFile, + manifestReadParallelism); + } else { + output.addAllUnchanged(subSegment); + } + } + } + + /** + * Sort and rewrite a section. Dispatches to full or minor compact path. + * + *

    sortNewFiles is the same reference as newFilesForAbort, ensuring newly written files are + * cleaned up on exception by the caller's catch block. + */ + private static void sortAndRewriteSection( + List section, + RewriteOutput output, + List sortNewFiles, + CompactionContext ctx, + ManifestFile manifestFile, + @Nullable Integer manifestReadParallelism) + throws Exception { + // Skip rewrite for single file not in delete-range. + if (section.size() == 1 && !ctx.defaultCompactionMap.getOrDefault(section.get(0), false)) { + output.addUnchanged(section.get(0)); + return; + } + + if (ctx.fullCompaction) { + sortAndRewriteFull( + section, output, sortNewFiles, ctx, manifestFile, manifestReadParallelism); + } else { + sortAndRewriteMinor( + section, output, sortNewFiles, ctx, manifestFile, manifestReadParallelism); + } + } + + /** + * Full compaction path: read all surviving entries (ADD merged with DELETE), sort them + * together, and write to output as a single sorted stream. + */ + private static void sortAndRewriteFull( + List section, + RewriteOutput output, + List sortNewFiles, + CompactionContext ctx, + ManifestFile manifestFile, + @Nullable Integer manifestReadParallelism) + throws Exception { + // Read surviving ADD entries: filter out entries cancelled by deleteEntries. 
+ Function> reader = + meta -> { + List batch = new ArrayList<>(); + for (ManifestEntry entry : + manifestFile.read( + meta.fileName(), + meta.fileSize(), + FileEntry.addFilter(), + Filter.alwaysTrue())) { + if (!ctx.deleteEntries.contains(entry.identifier())) { + batch.add(entry); + } + } + return batch; + }; + + List entries = new ArrayList<>(); + for (ManifestEntry entry : + sequentialBatchedExecute(reader, section, manifestReadParallelism)) { + entries.add(entry); + } + + if (!entries.isEmpty()) { + List sorted = + sortAndWriteEntries(entries, ctx.fieldComparator, manifestFile); + output.addSortedFiles(sorted); + sortNewFiles.addAll(sorted); + } + } + + /** + * Minor compaction path: read entries with ADD/DELETE classified in a single pass per file, + * then sort each group independently and write them to output. + * + *

    Each file is read in parallel (via sequentialBatchedExecute). The reader classifies + * entries into ADD and DELETE within each file, returning a Pair. Results are merged in the + * main thread. + */ + private static void sortAndRewriteMinor( + List section, + RewriteOutput output, + List sortNewFiles, + CompactionContext ctx, + ManifestFile manifestFile, + @Nullable Integer manifestReadParallelism) + throws Exception { + // Read and classify ADD/DELETE in one pass per file. + Function, List>>> reader = + meta -> { + List addBatch = new ArrayList<>(); + List deleteBatch = new ArrayList<>(); + for (ManifestEntry entry : + manifestFile.read(meta.fileName(), meta.fileSize())) { + if (entry.kind() == FileKind.ADD) { + addBatch.add(entry); + } else { + deleteBatch.add(entry); + } + } + return singletonList(Pair.of(addBatch, deleteBatch)); + }; + + Map addMap = new HashMap<>(); + List minorDeleteEntries = new ArrayList<>(); + for (Pair, List> pair : + sequentialBatchedExecute(reader, section, manifestReadParallelism)) { + for (ManifestEntry entry : pair.getLeft()) { + addMap.put(entry.identifier(), entry); + } + minorDeleteEntries.addAll(pair.getRight()); + } + + // Cancel out ADD+DELETE pairs with the same identifier within the section. + minorDeleteEntries.removeIf( + manifestEntry -> addMap.remove(manifestEntry.identifier()) != null); + List addEntries = new ArrayList<>(addMap.values()); + + if (!addEntries.isEmpty()) { + List sorted = + sortAndWriteEntries(addEntries, ctx.fieldComparator, manifestFile); + output.addSortedFiles(sorted); + sortNewFiles.addAll(sorted); + } + + if (!minorDeleteEntries.isEmpty()) { + List sorted = + sortAndWriteEntries(minorDeleteEntries, ctx.fieldComparator, manifestFile); + output.addDeleteFiles(sorted); + sortNewFiles.addAll(sorted); + } + } + + /** Sort entries and write them to a new manifest file with proper error handling. 
*/ + private static List sortAndWriteEntries( + List entries, + RecordComparator fieldComparator, + ManifestFile manifestFile) + throws Exception { + entries.sort((a, b) -> compareSortKey(a, b, fieldComparator)); + RollingFileWriter writer = + manifestFile.createRollingWriter(); + Exception exception = null; + try { + writer.write(entries); + } catch (Exception e) { + exception = e; + } finally { + if (exception != null) { + writer.abort(); + throw exception; + } + writer.close(); + } + return writer.result(); + } + + /** + * Compare two {@link ManifestEntry}s by the composite key {@code (sort-field, kind, fileName)}. + * {@code fileName} is used as the tie-breaker so that all entries sharing the same sort-field + * value AND the same data file are emitted contiguously. + */ + static int compareSortKey(ManifestEntry a, ManifestEntry b, RecordComparator fieldComparator) { + int c = fieldComparator.compare(a.partition(), b.partition()); + if (c != 0) { + return c; + } + // ADD before DELETE + int kindCmp = a.kind().compareTo(b.kind()); + if (kindCmp != 0) { + return kindCmp; + } + return a.file().fileName().compareTo(b.file().fileName()); + } + + /** + * Resolve the partition field to sort manifests by. + * + *

    Resolution rules: + * + *

      + *
    1. If {@code manifest-sort.partition-field} is configured, return that value. + *
    2. Otherwise, default to the first partition field. + *
    + */ + static String resolveSortField(String sortPartitionField, RowType partitionType) { + if (sortPartitionField != null && !sortPartitionField.isEmpty()) { + return sortPartitionField; + } + return partitionType.getFieldNames().get(0); + } + + /** Strategy interface for writing compaction results. */ + interface RewriteOutput { + void addUnchanged(ManifestFileMeta file); + + void addAllUnchanged(List files); + + void addSortedFiles(List files); + + void addDeleteFiles(List files); + } + + private static class FullCompactOutput implements RewriteOutput { + private final List result; + + FullCompactOutput(List result) { + this.result = result; + } + + @Override + public void addUnchanged(ManifestFileMeta file) { + result.add(file); + } + + @Override + public void addAllUnchanged(List files) { + result.addAll(files); + } + + @Override + public void addSortedFiles(List files) { + result.addAll(files); + } + + @Override + public void addDeleteFiles(List files) { + result.addAll(files); + } + } + + private static class MinorCompactOutput implements RewriteOutput { + private final List> result; + private final Pair indexRange; + private final Map fileNameToIndex; + + MinorCompactOutput( + List> result, + Pair indexRange, + Map fileNameToIndex) { + this.result = result; + this.indexRange = indexRange; + this.fileNameToIndex = fileNameToIndex; + } + + @Override + public void addUnchanged(ManifestFileMeta file) { + Integer idx = fileNameToIndex.get(file.fileName()); + result.get(idx).add(file); + } + + @Override + public void addAllUnchanged(List files) { + for (ManifestFileMeta file : files) { + addUnchanged(file); + } + } + + @Override + public void addSortedFiles(List files) { + result.get(indexRange.getLeft()).addAll(files); + } + + @Override + public void addDeleteFiles(List files) { + result.get(indexRange.getRight()).addAll(files); + } + } + + /** A section of manifest files with pre-computed metadata. 
*/ + static class Section { + final List files; + final long totalSize; + final boolean hasDefaultCompactMeta; + + Section(List files, long totalSize, boolean hasDefaultCompactMeta) { + this.files = files; + this.totalSize = totalSize; + this.hasDefaultCompactMeta = hasDefaultCompactMeta; + } + + /** Create a merged section from two sections. */ + static Section merge(Section a, Section b) { + List merged = new ArrayList<>(a.files); + merged.addAll(b.files); + return new Section( + merged, + a.totalSize + b.totalSize, + a.hasDefaultCompactMeta || b.hasDefaultCompactMeta); + } + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/ManifestPickStrategy.java b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestPickStrategy.java new file mode 100644 index 000000000000..519c49676ce3 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/operation/ManifestPickStrategy.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.operation; + +import org.apache.paimon.utils.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +/** + * Pick strategy for manifest LSM Tree compaction. + * + *

    Strategy priority: + * + *

      + *
    1. SizeAmp: if all lower-level runs' total size exceeds {@code sizeAmpThreshold} percent + * of the highest-level run's size, trigger full compaction (pick all runs). + *
    2. SizeRatio: from low to high, pick each run whose size is at most the size picked so + * far amplified by {@code sizeRatioThreshold} percent. + *
    3. Forced pick: level0 and level1 runs are always picked. + *
    + */ +public class ManifestPickStrategy { + + public static final int MAX_LEVEL = 4; + + private final int sizeAmpThreshold; + private final int sizeRatioThreshold; + + public ManifestPickStrategy(int sizeAmpThreshold, int sizeRatioThreshold) { + Preconditions.checkArgument(sizeAmpThreshold > 0, "sizeAmpThreshold must be positive"); + Preconditions.checkArgument(sizeRatioThreshold > 0, "sizeRatioThreshold must be positive"); + this.sizeAmpThreshold = sizeAmpThreshold; + this.sizeRatioThreshold = sizeRatioThreshold; + } + + /** + * Pick runs that need compaction from the given level runs. + * + * @param levelRuns runs with assigned levels (level 0~4) + * @return list of picked runs to compact + */ + public List pick(List levelRuns) { + if (levelRuns.isEmpty() || levelRuns.size() <= MAX_LEVEL) { + return new ArrayList<>(); + } + + // Try SizeAmp first + List sizeAmpResult = pickForSizeAmp(levelRuns); + if (sizeAmpResult != null) { + return sizeAmpResult; + } + + // SizeRatio + forced pick + return pickForSizeRatioAndForce(levelRuns); + } + + /** + * SizeAmp check: if all lower-level (0~3) runs' total size exceeds the highest-level run's size + * by more than {@code sizeAmpThreshold} percent, pick all runs for full compaction. + * + *

    Formula (consistent with {@code UniversalCompaction#pickForSizeAmp}): {@code + * lowerLevelTotalSize * 100 > sizeAmpThreshold * highestRunSize} + */ + private List pickForSizeAmp( + List levelRuns) { + if (levelRuns.isEmpty()) { + return null; + } + + // The last run has the highest level (set by buildLevelSortedRuns) + ManifestAdjacentSortedRun highestRun = levelRuns.get(levelRuns.size() - 1); + int maxLevel = highestRun.level(); + + if (maxLevel <= 0) { + return null; + } + + long lowerLevelTotalSize = 0; + for (ManifestAdjacentSortedRun run : levelRuns) { + if (run.level() < maxLevel) { + lowerLevelTotalSize += run.totalSize(); + } + } + + // size amplification = percentage of additional size + if (lowerLevelTotalSize * 100 > (long) sizeAmpThreshold * highestRun.totalSize()) { + return new ArrayList<>(levelRuns); + } + return null; + } + + /** + * SizeRatio + forced pick. + * + *

      + *
    • Level0 and level1 are always picked. + *
    • From low to high, if the cumulative picked size with ratio amplification covers the + * next run's size, continue picking. + *
    + * + *

    Formula (consistent with {@code UniversalCompaction#pickForSizeRatio}): {@code pickedSize + * * (100.0 + sizeRatioThreshold) / 100.0 >= nextRunSize} + */ + private List pickForSizeRatioAndForce( + List levelRuns) { + // levelRuns is already sorted by level ascending (set by buildLevelSortedRuns) + List picked = new ArrayList<>(); + + // Always pick the first run to guarantee a non-empty result. + picked.add(levelRuns.get(0)); + long pickedSize = levelRuns.get(0).totalSize(); + + // From the second run onward: forced pick level0/level1, then SizeRatio for the rest. + for (int i = 1; i < levelRuns.size(); i++) { + ManifestAdjacentSortedRun run = levelRuns.get(i); + if (run.level() <= 1) { + picked.add(run); + pickedSize += run.totalSize(); + } else { + long nextRunSize = run.totalSize(); + if (pickedSize * (100 + sizeRatioThreshold) >= nextRunSize * 100L) { + picked.add(run); + pickedSize += nextRunSize; + } + } + } + if (picked.size() == 1) { + return new ArrayList<>(); + } + return picked; + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/NormalPartitionExpire.java b/paimon-core/src/main/java/org/apache/paimon/operation/NormalPartitionExpire.java new file mode 100644 index 000000000000..5a23f538a9f3 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/operation/NormalPartitionExpire.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.operation; + +import org.apache.paimon.annotation.VisibleForTesting; +import org.apache.paimon.catalog.Catalog; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.manifest.PartitionEntry; +import org.apache.paimon.partition.PartitionExpireStrategy; +import org.apache.paimon.partition.PartitionValuesTimeExpireStrategy; +import org.apache.paimon.table.PartitionModification; + +import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.Duration; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +/** Expire partitions. 
*/ +public class NormalPartitionExpire implements PartitionExpire { + + private static final Logger LOG = LoggerFactory.getLogger(NormalPartitionExpire.class); + + private static final String DELIMITER = ","; + + private final Duration expirationTime; + private final Duration checkInterval; + private final FileStoreScan scan; + private final FileStoreCommit commit; + @Nullable private final PartitionModification partitionModification; + private LocalDateTime lastCheck; + private final PartitionExpireStrategy strategy; + private final boolean endInputCheckPartitionExpire; + private final int maxExpireNum; + private final int expireBatchSize; + + public NormalPartitionExpire( + Duration expirationTime, + Duration checkInterval, + PartitionExpireStrategy strategy, + FileStoreScan scan, + FileStoreCommit commit, + @Nullable PartitionModification partitionModification, + boolean endInputCheckPartitionExpire, + int maxExpireNum, + int expireBatchSize) { + this.expirationTime = expirationTime; + this.checkInterval = checkInterval; + this.strategy = strategy; + this.scan = scan; + this.commit = commit; + this.partitionModification = partitionModification; + // Avoid the execution time of stream jobs from being too short and preventing partition + // expiration + long rndSeconds = 0; + long checkIntervalSeconds = checkInterval.toMillis() / 1000; + if (checkIntervalSeconds > 0) { + rndSeconds = ThreadLocalRandom.current().nextLong(checkIntervalSeconds); + } + this.lastCheck = LocalDateTime.now().minusSeconds(rndSeconds); + this.endInputCheckPartitionExpire = endInputCheckPartitionExpire; + this.maxExpireNum = maxExpireNum; + this.expireBatchSize = expireBatchSize; + } + + public NormalPartitionExpire( + Duration expirationTime, + Duration checkInterval, + PartitionExpireStrategy strategy, + FileStoreScan scan, + FileStoreCommit commit, + @Nullable PartitionModification partitionModification, + int maxExpireNum, + int expireBatchSize) { + this( + expirationTime, + 
checkInterval, + strategy, + scan, + commit, + partitionModification, + false, + maxExpireNum, + expireBatchSize); + } + + @Override + public List> expire(long commitIdentifier) { + return expire(LocalDateTime.now(), commitIdentifier); + } + + @Override + public boolean isValueExpiration() { + return strategy instanceof PartitionValuesTimeExpireStrategy; + } + + @Override + public boolean isValueAllExpired(Collection partitions) { + PartitionValuesTimeExpireStrategy valuesStrategy = + (PartitionValuesTimeExpireStrategy) strategy; + LocalDateTime expireDateTime = LocalDateTime.now().minus(expirationTime); + for (BinaryRow partition : partitions) { + if (!valuesStrategy.isExpired(expireDateTime, partition)) { + return false; + } + } + return true; + } + + @VisibleForTesting + void setLastCheck(LocalDateTime time) { + lastCheck = time; + } + + @VisibleForTesting + List> expire(LocalDateTime now, long commitIdentifier) { + if (checkInterval.isZero() + || now.isAfter(lastCheck.plus(checkInterval)) + || (endInputCheckPartitionExpire && Long.MAX_VALUE == commitIdentifier)) { + List> expired = + doExpire(now.minus(expirationTime), commitIdentifier); + lastCheck = now; + return expired; + } + return null; + } + + private List> doExpire( + LocalDateTime expireDateTime, long commitIdentifier) { + List partitionEntries = + strategy.selectExpiredPartitions(scan, expireDateTime); + List> expiredPartValues = new ArrayList<>(partitionEntries.size()); + for (PartitionEntry partition : partitionEntries) { + Object[] array = strategy.convertPartition(partition.partition()); + expiredPartValues.add(strategy.toPartitionValue(array)); + } + + List> expired = new ArrayList<>(); + if (!expiredPartValues.isEmpty()) { + // convert partition value to partition string, and limit the partition num + expired = convertToPartitionString(expiredPartValues); + LOG.info("Expire Partitions: {}", expired); + if (expireBatchSize > 0 && expireBatchSize < expired.size()) { + Lists.partition(expired, 
expireBatchSize) + .forEach( + expiredBatchPartitions -> + doBatchExpire(expiredBatchPartitions, commitIdentifier)); + } else { + doBatchExpire(expired, commitIdentifier); + } + } + return expired; + } + + private void doBatchExpire( + List> expiredBatchPartitions, long commitIdentifier) { + if (partitionModification != null) { + try { + partitionModification.dropPartitions(expiredBatchPartitions); + // also drop corresponding .done partitions + partitionModification.dropPartitions(toDonePartitions(expiredBatchPartitions)); + } catch (Catalog.TableNotExistException e) { + throw new RuntimeException(e); + } + } else { + // .done partitions only exist when partitionModification != null + // (metastore.partitioned-table = true), so no need to handle them here + commit.dropPartitions(expiredBatchPartitions, commitIdentifier); + } + } + + private List> toDonePartitions( + List> expiredPartitions) { + List> donePartitions = new ArrayList<>(expiredPartitions.size()); + for (Map partition : expiredPartitions) { + LinkedHashMap donePartition = new LinkedHashMap<>(partition); + // append .done suffix to the last partition field value + Map.Entry lastEntry = null; + for (Map.Entry entry : donePartition.entrySet()) { + lastEntry = entry; + } + if (lastEntry != null) { + donePartition.put(lastEntry.getKey(), lastEntry.getValue() + ".done"); + donePartitions.add(donePartition); + } + } + return donePartitions; + } + + private List> convertToPartitionString( + List> expiredPartValues) { + return expiredPartValues.stream() + .map(values -> String.join(DELIMITER, values)) + .sorted() + // Use split(DELIMITER, -1) to preserve trailing empty strings + .map(s -> s.split(DELIMITER, -1)) + .map(strategy::toPartitionString) + .limit(Math.min(expiredPartValues.size(), maxExpireNum)) + .collect(Collectors.toList()); + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/PartitionExpire.java b/paimon-core/src/main/java/org/apache/paimon/operation/PartitionExpire.java 
index 74b7850ed73d..bb1487735b27 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/PartitionExpire.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/PartitionExpire.java @@ -18,208 +18,38 @@ package org.apache.paimon.operation; -import org.apache.paimon.annotation.VisibleForTesting; -import org.apache.paimon.catalog.Catalog; import org.apache.paimon.data.BinaryRow; -import org.apache.paimon.manifest.PartitionEntry; -import org.apache.paimon.partition.PartitionExpireStrategy; -import org.apache.paimon.partition.PartitionValuesTimeExpireStrategy; -import org.apache.paimon.table.PartitionModification; - -import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import javax.annotation.Nullable; -import java.time.Duration; -import java.time.LocalDateTime; -import java.util.ArrayList; import java.util.Collection; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.Collectors; - -/** Expire partitions. 
*/ -public class PartitionExpire { - - private static final Logger LOG = LoggerFactory.getLogger(PartitionExpire.class); - - private static final String DELIMITER = ","; - - private final Duration expirationTime; - private final Duration checkInterval; - private final FileStoreScan scan; - private final FileStoreCommit commit; - @Nullable private final PartitionModification partitionModification; - private LocalDateTime lastCheck; - private final PartitionExpireStrategy strategy; - private final boolean endInputCheckPartitionExpire; - private final int maxExpireNum; - private final int expireBatchSize; - - public PartitionExpire( - Duration expirationTime, - Duration checkInterval, - PartitionExpireStrategy strategy, - FileStoreScan scan, - FileStoreCommit commit, - @Nullable PartitionModification partitionModification, - boolean endInputCheckPartitionExpire, - int maxExpireNum, - int expireBatchSize) { - this.expirationTime = expirationTime; - this.checkInterval = checkInterval; - this.strategy = strategy; - this.scan = scan; - this.commit = commit; - this.partitionModification = partitionModification; - // Avoid the execution time of stream jobs from being too short and preventing partition - // expiration - long rndSeconds = 0; - long checkIntervalSeconds = checkInterval.toMillis() / 1000; - if (checkIntervalSeconds > 0) { - rndSeconds = ThreadLocalRandom.current().nextLong(checkIntervalSeconds); - } - this.lastCheck = LocalDateTime.now().minusSeconds(rndSeconds); - this.endInputCheckPartitionExpire = endInputCheckPartitionExpire; - this.maxExpireNum = maxExpireNum; - this.expireBatchSize = expireBatchSize; - } - - public PartitionExpire( - Duration expirationTime, - Duration checkInterval, - PartitionExpireStrategy strategy, - FileStoreScan scan, - FileStoreCommit commit, - @Nullable PartitionModification partitionModification, - int maxExpireNum, - int expireBatchSize) { - this( - expirationTime, - checkInterval, - strategy, - scan, - commit, - 
partitionModification, - false, - maxExpireNum, - expireBatchSize); - } - - public List> expire(long commitIdentifier) { - return expire(LocalDateTime.now(), commitIdentifier); - } - - public boolean isValueExpiration() { - return strategy instanceof PartitionValuesTimeExpireStrategy; - } - public boolean isValueAllExpired(Collection partitions) { - PartitionValuesTimeExpireStrategy valuesStrategy = - (PartitionValuesTimeExpireStrategy) strategy; - LocalDateTime expireDateTime = LocalDateTime.now().minus(expirationTime); - for (BinaryRow partition : partitions) { - if (!valuesStrategy.isExpired(expireDateTime, partition)) { - return false; - } - } - return true; - } - - @VisibleForTesting - void setLastCheck(LocalDateTime time) { - lastCheck = time; - } - - @VisibleForTesting - List> expire(LocalDateTime now, long commitIdentifier) { - if (checkInterval.isZero() - || now.isAfter(lastCheck.plus(checkInterval)) - || (endInputCheckPartitionExpire && Long.MAX_VALUE == commitIdentifier)) { - List> expired = - doExpire(now.minus(expirationTime), commitIdentifier); - lastCheck = now; - return expired; - } - return null; - } - - private List> doExpire( - LocalDateTime expireDateTime, long commitIdentifier) { - List partitionEntries = - strategy.selectExpiredPartitions(scan, expireDateTime); - List> expiredPartValues = new ArrayList<>(partitionEntries.size()); - for (PartitionEntry partition : partitionEntries) { - Object[] array = strategy.convertPartition(partition.partition()); - expiredPartValues.add(strategy.toPartitionValue(array)); - } - - List> expired = new ArrayList<>(); - if (!expiredPartValues.isEmpty()) { - // convert partition value to partition string, and limit the partition num - expired = convertToPartitionString(expiredPartValues); - LOG.info("Expire Partitions: {}", expired); - if (expireBatchSize > 0 && expireBatchSize < expired.size()) { - Lists.partition(expired, expireBatchSize) - .forEach( - expiredBatchPartitions -> - 
doBatchExpire(expiredBatchPartitions, commitIdentifier)); - } else { - doBatchExpire(expired, commitIdentifier); - } - } - return expired; - } - - private void doBatchExpire( - List> expiredBatchPartitions, long commitIdentifier) { - if (partitionModification != null) { - try { - partitionModification.dropPartitions(expiredBatchPartitions); - // also drop corresponding .done partitions - partitionModification.dropPartitions(toDonePartitions(expiredBatchPartitions)); - } catch (Catalog.TableNotExistException e) { - throw new RuntimeException(e); - } - } else { - // .done partitions only exist when partitionModification != null - // (metastore.partitioned-table = true), so no need to handle them here - commit.dropPartitions(expiredBatchPartitions, commitIdentifier); - } - } - - private List> toDonePartitions( - List> expiredPartitions) { - List> donePartitions = new ArrayList<>(expiredPartitions.size()); - for (Map partition : expiredPartitions) { - LinkedHashMap donePartition = new LinkedHashMap<>(partition); - // append .done suffix to the last partition field value - Map.Entry lastEntry = null; - for (Map.Entry entry : donePartition.entrySet()) { - lastEntry = entry; - } - if (lastEntry != null) { - donePartition.put(lastEntry.getKey(), lastEntry.getValue() + ".done"); - donePartitions.add(donePartition); - } - } - return donePartitions; - } - - private List> convertToPartitionString( - List> expiredPartValues) { - return expiredPartValues.stream() - .map(values -> String.join(DELIMITER, values)) - .sorted() - // Use split(DELIMITER, -1) to preserve trailing empty strings - .map(s -> s.split(DELIMITER, -1)) - .map(strategy::toPartitionString) - .limit(Math.min(expiredPartValues.size(), maxExpireNum)) - .collect(Collectors.toList()); - } +/** + * Common interface for partition expiration strategies. + * + *

<p>Implementations include {@link NormalPartitionExpire} for normal tables and {@link + * ChainTablePartitionExpire} for chain tables that require segment-based expiration across snapshot + * and delta branches. + */ +public interface PartitionExpire { + + /** + * Expire partitions that are older than the configured expiration time. + * + * @return the list of expired partition specs, or null if the check interval has not elapsed + */ + @Nullable + List> expire(long commitIdentifier); + + /** Whether this expiration uses values-time strategy. */ + boolean isValueExpiration(); + + /** + * Check whether all given partitions are expired according to the values-time strategy. + * + *

    Only valid when {@link #isValueExpiration()} returns true. + */ + boolean isValueAllExpired(Collection partitions); } diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitMetrics.java b/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitMetrics.java index c89bae8c9aa9..874635a9fde8 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitMetrics.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitMetrics.java @@ -74,6 +74,7 @@ public MetricGroup getMetricGroup() { public static final String LAST_COMPACTION_INPUT_FILE_SIZE = "lastCompactionInputFileSize"; public static final String LAST_COMPACTION_OUTPUT_FILE_SIZE = "lastCompactionOutputFileSize"; + public static final String LAST_COMMITTED_SNAPSHOT_ID = "lastCommittedSnapshotId"; private void registerGenericCommitMetrics() { metricGroup.gauge( @@ -126,6 +127,9 @@ private void registerGenericCommitMetrics() { metricGroup.gauge( LAST_COMPACTION_OUTPUT_FILE_SIZE, () -> latestCommit == null ? 0L : latestCommit.getCompactionOutputFileSize()); + metricGroup.gauge( + LAST_COMMITTED_SNAPSHOT_ID, + () -> latestCommit == null ? 
-1L : latestCommit.getLastCommittedSnapshotId()); } public void reportCommit(CommitStats commitStats) { diff --git a/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitStats.java b/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitStats.java index 11d6854270b6..ed53b611c2c5 100644 --- a/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitStats.java +++ b/paimon-core/src/main/java/org/apache/paimon/operation/metrics/CommitStats.java @@ -53,6 +53,7 @@ public class CommitStats { private final long generatedSnapshots; private final long numPartitionsWritten; private final long numBucketsWritten; + private final long lastCommittedSnapshotId; public CommitStats( List appendTableFiles, @@ -61,7 +62,8 @@ public CommitStats( List compactChangelogFiles, long commitDuration, int generatedSnapshots, - int attempts) { + int attempts, + long lastCommittedSnapshotId) { List addedTableFiles = appendTableFiles.stream() .filter(f -> FileKind.ADD.equals(f.kind())) @@ -110,6 +112,7 @@ public CommitStats( this.duration = commitDuration; this.generatedSnapshots = generatedSnapshots; this.attempts = attempts; + this.lastCommittedSnapshotId = lastCommittedSnapshotId; } @VisibleForTesting @@ -236,4 +239,9 @@ public long getCompactionInputFileSize() { public long getCompactionOutputFileSize() { return compactionOutputFileSize; } + + @VisibleForTesting + protected long getLastCommittedSnapshotId() { + return lastCommittedSnapshotId; + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/privilege/PrivilegedFileStore.java b/paimon-core/src/main/java/org/apache/paimon/privilege/PrivilegedFileStore.java index 6ced52bbd1a0..a1f6d5fbb295 100644 --- a/paimon-core/src/main/java/org/apache/paimon/privilege/PrivilegedFileStore.java +++ b/paimon-core/src/main/java/org/apache/paimon/privilege/PrivilegedFileStore.java @@ -219,9 +219,13 @@ public ServiceManager newServiceManager() { } @Override - public boolean mergeSchema(RowType rowType, 
boolean allowExplicitCast) { + public boolean mergeSchema( + RowType rowType, + boolean typeWidening, + boolean allowExplicitCast, + boolean caseSensitive) { privilegeChecker.assertCanInsert(identifier); - return wrapped.mergeSchema(rowType, allowExplicitCast); + return wrapped.mergeSchema(rowType, typeWidening, allowExplicitCast, caseSensitive); } @Override diff --git a/paimon-core/src/main/java/org/apache/paimon/rest/RESTCatalog.java b/paimon-core/src/main/java/org/apache/paimon/rest/RESTCatalog.java index 7e2f6cfd2743..9d76cfdf4fef 100644 --- a/paimon-core/src/main/java/org/apache/paimon/rest/RESTCatalog.java +++ b/paimon-core/src/main/java/org/apache/paimon/rest/RESTCatalog.java @@ -761,15 +761,7 @@ public void dropBranch(Identifier identifier, String branch) throws BranchNotExi @Override public void renameBranch(Identifier identifier, String fromBranch, String toBranch) throws BranchNotExistException, BranchAlreadyExistException { - try { - api.renameBranch(identifier, fromBranch, toBranch); - } catch (NoSuchResourceException e) { - throw new BranchNotExistException(identifier, fromBranch, e); - } catch (AlreadyExistsException e) { - throw new BranchAlreadyExistException(identifier, toBranch, e); - } catch (ForbiddenException e) { - throw new TableNoPermissionException(identifier, e); - } + throw new UnsupportedOperationException(); } @Override diff --git a/paimon-core/src/main/java/org/apache/paimon/schema/ColumnDirectiveUtils.java b/paimon-core/src/main/java/org/apache/paimon/schema/ColumnDirectiveUtils.java new file mode 100644 index 000000000000..d523bbb37047 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/schema/ColumnDirectiveUtils.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.schema; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.options.ConfigOption; +import org.apache.paimon.options.FallbackKey; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.BlobType; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.DataTypeRoot; +import org.apache.paimon.types.VectorType; +import org.apache.paimon.utils.Preconditions; +import org.apache.paimon.utils.StringUtils; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Utilities for column comment directives (BLOB / VECTOR type conversion via ADD COLUMN). */ +public final class ColumnDirectiveUtils { + + public static final String BLOB_FIELD_DIRECTIVE = "__BLOB_FIELD"; + public static final String BLOB_DESCRIPTOR_FIELD_DIRECTIVE = "__BLOB_DESCRIPTOR_FIELD"; + public static final String BLOB_VIEW_FIELD_DIRECTIVE = "__BLOB_VIEW_FIELD"; + public static final String BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE = + "__BLOB_EXTERNAL_STORAGE_FIELD"; + public static final String VECTOR_FIELD_DIRECTIVE = "__VECTOR_FIELD"; + + private ColumnDirectiveUtils() {} + + /** + * Parses the comment of an {@code ALTER TABLE ADD COLUMN} statement. 
Returns {@code null} when + * the comment is a regular user comment; returns a {@link ParsedDirective} when the comment + * begins with a supported directive. Throws {@link IllegalArgumentException} when the comment + * begins with {@code __BLOB} or {@code __VECTOR} but is not one of the supported directives. + */ + @Nullable + public static ParsedDirective parseAddColumnComment(@Nullable String comment) { + if (comment == null) { + return null; + } + comment = StringUtils.trim(comment); + if (comment.startsWith("__VECTOR")) { + return parseVectorDirective(comment); + } + if (!comment.startsWith("__BLOB")) { + return null; + } + // match longer prefixes first + String optionKey = matchDirective(comment, BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE); + String marker = BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE; + if (optionKey == null) { + optionKey = matchDirective(comment, BLOB_VIEW_FIELD_DIRECTIVE); + marker = BLOB_VIEW_FIELD_DIRECTIVE; + } + if (optionKey == null) { + optionKey = matchDirective(comment, BLOB_DESCRIPTOR_FIELD_DIRECTIVE); + marker = BLOB_DESCRIPTOR_FIELD_DIRECTIVE; + } + if (optionKey == null) { + optionKey = matchDirective(comment, BLOB_FIELD_DIRECTIVE); + marker = BLOB_FIELD_DIRECTIVE; + } + Preconditions.checkArgument( + optionKey != null, + "Unsupported BLOB directive in column comment: '%s'. Supported directives are " + + "'%s', '%s', '%s' and '%s'.", + comment, + BLOB_FIELD_DIRECTIVE, + BLOB_DESCRIPTOR_FIELD_DIRECTIVE, + BLOB_VIEW_FIELD_DIRECTIVE, + BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE); + String realComment = + comment.length() == marker.length() + ? 
null + : comment.substring(marker.length() + 1).trim(); + if (realComment != null && realComment.isEmpty()) { + realComment = null; + } + return new ParsedDirective(optionKey, realComment, false, 0); + } + + private static ParsedDirective parseVectorDirective(String comment) { + String optionKey = matchDirective(comment, VECTOR_FIELD_DIRECTIVE); + Preconditions.checkArgument( + optionKey != null, + "Unsupported VECTOR directive in column comment: '%s'. Supported directive is '%s'.", + comment, + VECTOR_FIELD_DIRECTIVE); + if (comment.length() == VECTOR_FIELD_DIRECTIVE.length()) { + throw new IllegalArgumentException( + String.format( + "VECTOR directive '%s' requires a dimension, e.g. '%s;128' or '%s;128; my comment'.", + comment, VECTOR_FIELD_DIRECTIVE, VECTOR_FIELD_DIRECTIVE)); + } + String rest = comment.substring(VECTOR_FIELD_DIRECTIVE.length() + 1); + int dimEnd = rest.indexOf(';'); + String dimStr; + String realComment; + if (dimEnd < 0) { + dimStr = rest.trim(); + realComment = null; + } else { + dimStr = rest.substring(0, dimEnd).trim(); + realComment = rest.substring(dimEnd + 1).trim(); + if (realComment.isEmpty()) { + realComment = null; + } + } + int dim; + try { + dim = Integer.parseInt(dimStr); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format( + "Expected an integer dimension after '%s;', but got: '%s'.", + VECTOR_FIELD_DIRECTIVE, dimStr)); + } + Preconditions.checkArgument(dim >= 1, "Vector dimension must be >= 1, but got: %s.", dim); + return new ParsedDirective(optionKey, realComment, true, dim); + } + + @Nullable + private static String matchDirective(String comment, String marker) { + if (!comment.startsWith(marker)) { + return null; + } + if (comment.length() == marker.length()) { + return optionKeyFor(marker); + } + return comment.charAt(marker.length()) == ';' ? 
optionKeyFor(marker) : null; + } + + private static String optionKeyFor(String marker) { + if (BLOB_FIELD_DIRECTIVE.equals(marker)) { + return CoreOptions.BLOB_FIELD.key(); + } else if (BLOB_DESCRIPTOR_FIELD_DIRECTIVE.equals(marker)) { + return CoreOptions.BLOB_DESCRIPTOR_FIELD.key(); + } else if (BLOB_VIEW_FIELD_DIRECTIVE.equals(marker)) { + return CoreOptions.BLOB_VIEW_FIELD.key(); + } else if (BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE.equals(marker)) { + return CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key(); + } else if (VECTOR_FIELD_DIRECTIVE.equals(marker)) { + return CoreOptions.VECTOR_FIELD.key(); + } else { + throw new IllegalArgumentException("Unsupported directive: " + marker); + } + } + + /** + * One-stop method for ADD COLUMN: parses the comment, converts the type, and appends the field + * to the corresponding option. Returns {@code null} when the comment is not a directive. + */ + @Nullable + public static ConvertedColumn applyAddColumnDirective( + @Nullable String comment, + String fieldName, + DataType sourceType, + Map options) { + ParsedDirective directive = parseAddColumnComment(comment); + if (directive == null) { + return null; + } + DataType newType = convertType(directive, fieldName, sourceType); + modifyFieldOptions(directive.optionKey(), fieldName, options); + if (CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key().equals(directive.optionKey())) { + modifyFieldOptions(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), fieldName, options); + } + return new ConvertedColumn(newType, directive.realComment()); + } + + /** + * Process comment directives on every field of a {@link Schema} used for CREATE TABLE. Fields + * whose comment matches a directive get their type converted and the corresponding option + * appended; the directive prefix is stripped from the stored comment. 
+ */ + public static Schema applyDirectives(Schema schema) { + List fields = schema.fields(); + Map options = new HashMap<>(schema.options()); + List newFields = new ArrayList<>(fields.size()); + boolean changed = false; + + for (DataField field : fields) { + ConvertedColumn converted = + applyAddColumnDirective( + field.description(), field.name(), field.type(), options); + if (converted == null) { + newFields.add(field); + } else { + changed = true; + newFields.add( + new DataField( + field.id(), field.name(), converted.type(), converted.comment())); + } + } + + if (!changed) { + return schema; + } + return new Schema( + newFields, schema.partitionKeys(), schema.primaryKeys(), options, schema.comment()); + } + + private static DataType convertType( + ParsedDirective directive, String fieldName, DataType sourceType) { + if (directive.isVector()) { + Preconditions.checkArgument( + sourceType.getTypeRoot() == DataTypeRoot.ARRAY, + "Column %s declared with a VECTOR directive must be of ARRAY type, but was %s.", + fieldName, + sourceType); + DataType elementType = ((ArrayType) sourceType).getElementType(); + return new VectorType(sourceType.isNullable(), directive.vectorDim(), elementType); + } else { + DataTypeRoot root = sourceType.getTypeRoot(); + Preconditions.checkArgument( + root == DataTypeRoot.VARBINARY + || root == DataTypeRoot.BINARY + || root == DataTypeRoot.BLOB, + "Column %s declared with a BLOB directive must be of BYTES, " + + "BINARY or BLOB type, but was %s.", + fieldName, + sourceType); + return new BlobType(sourceType.isNullable()); + } + } + + /** + * Append {@code fieldName} to the comma-separated option identified by {@code optionKey}. If + * the canonical key is empty but a fallback key holds the value (e.g. legacy {@code + * blob.stored-descriptor-fields}), the fallback value is migrated to the canonical key before + * appending so old entries are not shadowed. 
+ */ + public static void modifyFieldOptions( + String optionKey, String fieldName, Map options) { + ConfigOption option; + if (CoreOptions.BLOB_FIELD.key().equals(optionKey)) { + option = CoreOptions.BLOB_FIELD; + } else if (CoreOptions.BLOB_DESCRIPTOR_FIELD.key().equals(optionKey)) { + option = CoreOptions.BLOB_DESCRIPTOR_FIELD; + } else if (CoreOptions.BLOB_VIEW_FIELD.key().equals(optionKey)) { + option = CoreOptions.BLOB_VIEW_FIELD; + } else if (CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key().equals(optionKey)) { + option = CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD; + } else if (CoreOptions.VECTOR_FIELD.key().equals(optionKey)) { + option = CoreOptions.VECTOR_FIELD; + } else { + throw new IllegalArgumentException("Unsupported directive: " + optionKey); + } + + String existing = options.get(optionKey); + if (existing == null || existing.isEmpty()) { + for (FallbackKey fk : option.fallbackKeys()) { + String fallbackValue = options.remove(fk.getKey()); + if (fallbackValue != null && !fallbackValue.isEmpty()) { + existing = fallbackValue; + break; + } + } + } + String newValue = existing == null ? fieldName : existing + "," + fieldName; + options.put(optionKey, newValue); + } + + private static final ConfigOption[] BLOB_OPTIONS = + new ConfigOption[] { + CoreOptions.BLOB_FIELD, + CoreOptions.BLOB_DESCRIPTOR_FIELD, + CoreOptions.BLOB_VIEW_FIELD, + CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD + }; + + private static final ConfigOption[] VECTOR_OPTIONS = + new ConfigOption[] {CoreOptions.VECTOR_FIELD}; + + /** + * Remove directive-managed options when a BLOB or VECTOR column is dropped. Only acts on BLOB + * or VECTOR type columns; other types are ignored. 
+ */ + public static void removeDroppedDirectiveOptions( + String fieldName, DataTypeRoot typeRoot, Map options) { + if (typeRoot == DataTypeRoot.BLOB) { + for (ConfigOption option : BLOB_OPTIONS) { + removeFromCsvOption(option.key(), fieldName, options); + for (FallbackKey fk : option.fallbackKeys()) { + removeFromCsvOption(fk.getKey(), fieldName, options); + } + } + } else if (typeRoot == DataTypeRoot.VECTOR) { + for (ConfigOption option : VECTOR_OPTIONS) { + removeFromCsvOption(option.key(), fieldName, options); + for (FallbackKey fk : option.fallbackKeys()) { + removeFromCsvOption(fk.getKey(), fieldName, options); + } + } + options.remove(String.format("field.%s.vector-dim", fieldName)); + } + } + + private static void removeFromCsvOption( + String key, String fieldName, Map options) { + String existing = options.get(key); + if (existing == null || existing.isEmpty()) { + return; + } + StringBuilder sb = new StringBuilder(); + for (String v : existing.split(",")) { + String trimmed = v.trim(); + if (trimmed.isEmpty() || trimmed.equals(fieldName)) { + continue; + } + if (sb.length() > 0) { + sb.append(','); + } + sb.append(trimmed); + } + if (sb.length() == 0) { + options.remove(key); + } else { + options.put(key, sb.toString()); + } + } + + /** Result of {@link #applyAddColumnDirective}: the converted type and effective comment. */ + public static final class ConvertedColumn { + private final DataType type; + @Nullable private final String comment; + + private ConvertedColumn(DataType type, @Nullable String comment) { + this.type = type; + this.comment = comment; + } + + public DataType type() { + return type; + } + + @Nullable + public String comment() { + return comment; + } + } + + /** Parsed directive: the option key to update, user-facing comment, and vector metadata. 
*/ + public static final class ParsedDirective { + private final String optionKey; + @Nullable private final String realComment; + private final boolean isVector; + private final int vectorDim; + + private ParsedDirective( + String optionKey, @Nullable String realComment, boolean isVector, int vectorDim) { + this.optionKey = optionKey; + this.realComment = realComment; + this.isVector = isVector; + this.vectorDim = vectorDim; + } + + public String optionKey() { + return optionKey; + } + + @Nullable + public String realComment() { + return realComment; + } + + public boolean isVector() { + return isVector; + } + + public int vectorDim() { + return vectorDim; + } + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaManager.java b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaManager.java index 670960889d5c..95c6f04514db 100644 --- a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaManager.java +++ b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaManager.java @@ -25,6 +25,7 @@ import org.apache.paimon.catalog.Identifier; import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; +import org.apache.paimon.schema.ColumnDirectiveUtils.ConvertedColumn; import org.apache.paimon.schema.SchemaChange.AddColumn; import org.apache.paimon.schema.SchemaChange.DropColumn; import org.apache.paimon.schema.SchemaChange.RemoveOption; @@ -101,6 +102,8 @@ import static org.apache.paimon.catalog.Identifier.DEFAULT_MAIN_BRANCH; import static org.apache.paimon.catalog.Identifier.UNKNOWN_DATABASE; import static org.apache.paimon.mergetree.compact.PartialUpdateMergeFunction.SEQUENCE_GROUP; +import static org.apache.paimon.schema.ColumnDirectiveUtils.applyAddColumnDirective; +import static org.apache.paimon.schema.ColumnDirectiveUtils.applyDirectives; import static org.apache.paimon.utils.DefaultValueUtils.validateDefaultValue; import static org.apache.paimon.utils.FileUtils.listVersionedFiles; import static 
org.apache.paimon.utils.Preconditions.checkArgument; @@ -196,6 +199,7 @@ public TableSchema createTable(Schema schema, boolean externalTable) throws Exce } } + schema = applyDirectives(schema); TableSchema newSchema = TableSchema.create(0, schema); // validate table from creating table @@ -342,8 +346,27 @@ public static TableSchema generateTableSchema( "Column %s cannot specify NOT NULL in the %s table.", String.join(".", addColumn.fieldNames()), lazyIdentifier.get().getFullName()); + + ConvertedColumn converted = + applyAddColumnDirective( + addColumn.description(), + addColumn.fieldNames()[0], + addColumn.dataType(), + newOptions); + DataType requestedDataType = addColumn.dataType(); + String effectiveComment = addColumn.description(); + if (converted != null) { + Preconditions.checkArgument( + addColumn.fieldNames().length == 1, + "Comment directive cannot be used on a nested column %s.", + String.join(".", addColumn.fieldNames())); + requestedDataType = converted.type(); + effectiveComment = converted.comment(); + } + int id = highestFieldId.incrementAndGet(); - DataType dataType = ReassignFieldId.reassign(addColumn.dataType(), highestFieldId); + DataType dataType = ReassignFieldId.reassign(requestedDataType, highestFieldId); + String storedComment = effectiveComment; new NestedColumnModifier(addColumn.fieldNames(), lazyIdentifier) { @Override protected void updateLastColumn( @@ -352,8 +375,7 @@ protected void updateLastColumn( Catalog.ColumnNotExistException { assertColumnNotExists(newFields, fieldName, lazyIdentifier); - DataField dataField = - new DataField(id, fieldName, dataType, addColumn.description()); + DataField dataField = new DataField(id, fieldName, dataType, storedComment); // key: name ; value : index Map map = new HashMap<>(); @@ -435,6 +457,16 @@ protected void updateLastColumn( } else if (change instanceof DropColumn) { DropColumn drop = (DropColumn) change; dropColumnValidation(oldTableSchema, drop); + if (drop.fieldNames().length == 1) { + 
String dropName = drop.fieldNames()[0]; + newFields.stream() + .filter(f -> f.name().equals(dropName)) + .findFirst() + .ifPresent( + f -> + ColumnDirectiveUtils.removeDroppedDirectiveOptions( + dropName, f.type().getTypeRoot(), newOptions)); + } new NestedColumnModifier(drop.fieldNames(), lazyIdentifier) { @Override protected void updateLastColumn( @@ -451,6 +483,8 @@ protected void updateLastColumn( UpdateColumnType update = (UpdateColumnType) change; assertNotUpdatingPartitionKeys(oldTableSchema, update.fieldNames(), "update"); assertNotUpdatingPrimaryKeys(oldTableSchema, update.fieldNames(), "update"); + assertNotChangingBlobColumnType( + newFields, update.fieldNames(), update.newDataType()); updateNestedColumn( newFields, update.fieldNames(), @@ -716,22 +750,33 @@ private static void checkMoveIndexEqual(SchemaChange.Move move, int fieldIndex, } } + /** + * Merge {@code rowType} into the current schema (via {@link SchemaMergingUtils#mergeSchemas}) + * and persist the result. Returns {@code true} if the schema changed and was committed, {@code + * false} if the merge was a no-op. See {@code SchemaMergingUtils} for how {@code typeWidening} + * / {@code allowExplicitCast} drive existing-column type evolution. 
+ */ public boolean mergeSchema( RowType rowType, + boolean typeWidening, boolean allowExplicitCast, + boolean caseSensitive, @Nullable SchemaModification schemaModification) { TableSchema current = latest().orElseThrow( () -> new RuntimeException( "It requires that the current schema to exist when calling 'mergeSchema'")); - TableSchema update = SchemaMergingUtils.mergeSchemas(current, rowType, allowExplicitCast); + TableSchema update = + SchemaMergingUtils.mergeSchemas( + current, rowType, typeWidening, allowExplicitCast, caseSensitive); if (current.equals(update)) { return false; } try { if (schemaModification != null) { - List changes = SchemaMergingUtils.diffSchemaChanges(current, update); + List changes = + SchemaMergingUtils.diffSchemaChanges(current, update, caseSensitive); schemaModification.alterSchema(changes); return true; } else { @@ -923,6 +968,28 @@ private static void assertNotRenamingBlobColumn(List fields, String[] } } + private static void assertNotChangingBlobColumnType( + List fields, String[] fieldNames, DataType newType) { + if (fieldNames.length > 1) { + return; + } + String fieldName = fieldNames[0]; + for (DataField field : fields) { + if (!field.name().equals(fieldName)) { + continue; + } + boolean wasBlob = field.type().is(DataTypeRoot.BLOB); + boolean willBeBlob = newType.is(DataTypeRoot.BLOB); + if (wasBlob || willBeBlob) { + throw new UnsupportedOperationException( + String.format( + "Cannot change column type involving BLOB: [%s] %s -> %s", + fieldName, field.type(), newType)); + } + return; + } + } + private abstract static class NestedColumnModifier { private final String[] updateFieldNames; diff --git a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java index 990087d792f9..d444d89ca4ad 100644 --- a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java +++ 
b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java @@ -30,11 +30,11 @@ import org.apache.paimon.types.RowType; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; -import java.util.stream.Collectors; /** The util class for merging the schemas. */ public class SchemaMergingUtils { @@ -43,7 +43,11 @@ public class SchemaMergingUtils { public static final String MAP_VALUE_FIELD_NAME = "value"; public static TableSchema mergeSchemas( - TableSchema currentTableSchema, RowType targetType, boolean allowExplicitCast) { + TableSchema currentTableSchema, + RowType targetType, + boolean typeWidening, + boolean allowExplicitCast, + boolean caseSensitive) { RowType currentType = currentTableSchema.logicalRowType(); if (currentType.equals(targetType)) { return currentTableSchema; @@ -51,7 +55,13 @@ public static TableSchema mergeSchemas( AtomicInteger highestFieldId = new AtomicInteger(currentTableSchema.highestFieldId()); RowType newRowType = - mergeSchemas(currentType, targetType, highestFieldId, allowExplicitCast); + mergeSchemas( + currentType, + targetType, + highestFieldId, + typeWidening, + allowExplicitCast, + caseSensitive); if (newRowType.equals(currentType)) { // It happens if the `targetType` only changes `nullability` but we always respect the // current's. @@ -72,29 +82,40 @@ public static RowType mergeSchemas( RowType tableSchema, RowType dataSchema, AtomicInteger highestFieldId, - boolean allowExplicitCast) { - return (RowType) merge(tableSchema, dataSchema, highestFieldId, allowExplicitCast); + boolean typeWidening, + boolean allowExplicitCast, + boolean caseSensitive) { + return (RowType) + merge( + tableSchema, + dataSchema, + highestFieldId, + typeWidening, + allowExplicitCast, + caseSensitive); } /** - * Merge the base data type and the update data type if possible. 
+ * Merge the base (target) data type with the update (incoming) data type. * - * <p>For RowType, find the fields which exists in both the base schema and the update schema, - * and try to merge them by calling the method iteratively; remain those fields that are only in - * the base schema and append those fields that are only in the update schema. - * - * <p>For other complex type, try to merge the element types. - * - * <p>For primitive data type, we treat that's compatible if the original type can be safely - * cast to the new type. + * <ul> + * <li>RowType: merge existing fields recursively, keep base-only fields, append update-only + * fields as new columns. + * <li>Complex types (Array/Map/Multiset): recursively merge element/value types. + * <li>Leaf types when {@code typeWidening=false} (default): keep the base type unchanged — + * incoming data is cast to it by the alignment layer. + * <li>Leaf types when {@code typeWidening=true}: widen the base type to the update type if + * the cast is safe (or explicit when {@code allowExplicitCast=true}). + * </ul>
    */ public static DataType merge( DataType base0, DataType update0, AtomicInteger highestFieldId, - boolean allowExplicitCast) { - // Here we try to merge the base0 and update0 without regard to the nullability, - // and set the base0's nullability to the return's. + boolean typeWidening, + boolean allowExplicitCast, + boolean caseSensitive) { + // Compare ignoring nullability; the base's nullability flows to the result. DataType base = base0.copy(true); DataType update = update0.copy(true); @@ -103,45 +124,38 @@ public static DataType merge( } else if (base instanceof RowType && update instanceof RowType) { List baseFields = ((RowType) base).getFields(); List updateFields = ((RowType) update).getFields(); - Map updateFieldMap = - updateFields.stream() - .collect(Collectors.toMap(DataField::name, Function.identity())); - List updatedFields = - baseFields.stream() - .map( - baseField -> { - if (updateFieldMap.containsKey(baseField.name())) { - DataField updateField = - updateFieldMap.get(baseField.name()); - DataType updatedDataType = - merge( - baseField.type(), - updateField.type(), - highestFieldId, - allowExplicitCast); - return new DataField( - baseField.id(), - baseField.name(), - updatedDataType, - baseField.description(), - baseField.defaultValue()); - } else { - return baseField; - } - }) - .collect(Collectors.toList()); + Map updateFieldMap = buildFieldMap(updateFields, caseSensitive); + List updatedFields = new ArrayList<>(); + for (DataField baseField : baseFields) { + if (updateFieldMap.containsKey(baseField.name())) { + DataField updateField = updateFieldMap.get(baseField.name()); + DataType updatedDataType = + merge( + baseField.type(), + updateField.type(), + highestFieldId, + typeWidening, + allowExplicitCast, + caseSensitive); + updatedFields.add( + new DataField( + baseField.id(), + baseField.name(), + updatedDataType, + baseField.description(), + baseField.defaultValue())); + } else { + updatedFields.add(baseField); + } + } - Map baseFieldMap 
= - baseFields.stream() - .collect(Collectors.toMap(DataField::name, Function.identity())); - List newFields = - updateFields.stream() - .filter(field -> !baseFieldMap.containsKey(field.name())) - .map(field -> assignIdForNewField(field, highestFieldId)) - .map(field -> field.copy(true)) - .collect(Collectors.toList()); + Map baseFieldMap = buildFieldMap(baseFields, caseSensitive); + for (DataField field : updateFields) { + if (!baseFieldMap.containsKey(field.name())) { + updatedFields.add(assignIdForNewField(field, highestFieldId).copy(true)); + } + } - updatedFields.addAll(newFields); return new RowType(base0.isNullable(), updatedFields); } else if (base instanceof MapType && update instanceof MapType) { return new MapType( @@ -150,12 +164,16 @@ public static DataType merge( ((MapType) base).getKeyType(), ((MapType) update).getKeyType(), highestFieldId, - allowExplicitCast), + typeWidening, + allowExplicitCast, + caseSensitive), merge( ((MapType) base).getValueType(), ((MapType) update).getValueType(), highestFieldId, - allowExplicitCast)); + typeWidening, + allowExplicitCast, + caseSensitive)); } else if (base instanceof ArrayType && update instanceof ArrayType) { return new ArrayType( base0.isNullable(), @@ -163,7 +181,9 @@ public static DataType merge( ((ArrayType) base).getElementType(), ((ArrayType) update).getElementType(), highestFieldId, - allowExplicitCast)); + typeWidening, + allowExplicitCast, + caseSensitive)); } else if (base instanceof MultisetType && update instanceof MultisetType) { return new MultisetType( base0.isNullable(), @@ -171,7 +191,13 @@ public static DataType merge( ((MultisetType) base).getElementType(), ((MultisetType) update).getElementType(), highestFieldId, - allowExplicitCast)); + typeWidening, + allowExplicitCast, + caseSensitive)); + } else if (!typeWidening) { + // Default: keep the existing leaf type — only column additions evolve the schema. + // Incoming values are cast to this type by the alignment layer. 
+ return base0; } else if (base instanceof DecimalType && update instanceof DecimalType) { if (((DecimalType) base).getScale() == ((DecimalType) update).getScale()) { return new DecimalType( @@ -243,13 +269,14 @@ private static DataField assignIdForNewField(DataField field, AtomicInteger high * This supports detecting added columns and type changes (including nested structs). */ public static List diffSchemaChanges( - TableSchema oldSchema, TableSchema newSchema) { + TableSchema oldSchema, TableSchema newSchema, boolean caseSensitive) { List changes = new ArrayList<>(); diffFields( oldSchema.logicalRowType().getFields(), newSchema.logicalRowType().getFields(), new String[0], - changes); + changes, + caseSensitive); return changes; } @@ -257,9 +284,9 @@ private static void diffFields( List oldFields, List newFields, String[] parentNames, - List changes) { - Map oldFieldMap = - oldFields.stream().collect(Collectors.toMap(DataField::name, Function.identity())); + List changes, + boolean caseSensitive) { + Map oldFieldMap = buildFieldMap(oldFields, caseSensitive); for (DataField newField : newFields) { String[] fieldNames = appendFieldName(parentNames, newField.name()); @@ -271,7 +298,7 @@ private static void diffFields( fieldNames, newField.type(), newField.description(), null)); } else if (!oldField.type().equals(newField.type()) && !diffNestedTypeChanges( - oldField.type(), newField.type(), fieldNames, changes)) { + oldField.type(), newField.type(), fieldNames, changes, caseSensitive)) { changes.add(SchemaChange.updateColumnType(fieldNames, newField.type(), true)); } } @@ -282,9 +309,15 @@ private static void diffFields( * changes. Returns false to let the caller fall back to {@link SchemaChange.UpdateColumnType}. 
*/ private static boolean diffNestedTypeChanges( - DataType oldType, DataType newType, String[] fieldNames, List changes) { + DataType oldType, + DataType newType, + String[] fieldNames, + List changes, + boolean caseSensitive) { List stagedChanges = new ArrayList<>(); - boolean handled = diffNestedTypeChangesInner(oldType, newType, fieldNames, stagedChanges); + boolean handled = + diffNestedTypeChangesInner( + oldType, newType, fieldNames, stagedChanges, caseSensitive); if (handled) { changes.addAll(stagedChanges); } @@ -292,21 +325,26 @@ private static boolean diffNestedTypeChanges( } private static boolean diffNestedTypeChangesInner( - DataType oldType, DataType newType, String[] fieldNames, List changes) { + DataType oldType, + DataType newType, + String[] fieldNames, + List changes, + boolean caseSensitive) { if (oldType instanceof RowType && newType instanceof RowType) { List oldFields = ((RowType) oldType).getFields(); List newFields = ((RowType) newType).getFields(); - if (hasRemovedFields(oldFields, newFields)) { + if (hasRemovedFields(oldFields, newFields, caseSensitive)) { return false; } - diffFields(oldFields, newFields, fieldNames, changes); + diffFields(oldFields, newFields, fieldNames, changes, caseSensitive); return true; } else if (oldType instanceof ArrayType && newType instanceof ArrayType) { return diffNestedTypeChanges( ((ArrayType) oldType).getElementType(), ((ArrayType) newType).getElementType(), appendFieldName(fieldNames, ARRAY_ELEMENT_FIELD_NAME), - changes); + changes, + caseSensitive); } else if (oldType instanceof MapType && newType instanceof MapType) { MapType oldMapType = (MapType) oldType; MapType newMapType = (MapType) newType; @@ -317,14 +355,15 @@ private static boolean diffNestedTypeChangesInner( oldMapType.getValueType(), newMapType.getValueType(), appendFieldName(fieldNames, MAP_VALUE_FIELD_NAME), - changes); + changes, + caseSensitive); } return false; } - private static boolean hasRemovedFields(List oldFields, List 
newFields) { - Map newFieldMap = - newFields.stream().collect(Collectors.toMap(DataField::name, Function.identity())); + private static boolean hasRemovedFields( + List oldFields, List newFields, boolean caseSensitive) { + Map newFieldMap = buildFieldMap(newFields, caseSensitive); for (DataField oldField : oldFields) { if (!newFieldMap.containsKey(oldField.name())) { return true; @@ -333,6 +372,16 @@ private static boolean hasRemovedFields(List oldFields, List buildFieldMap( + List fields, boolean caseSensitive) { + Map map = + caseSensitive ? new HashMap<>() : new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + for (DataField field : fields) { + map.put(field.name(), field); + } + return map; + } + private static String[] appendFieldName(String[] parentNames, String fieldName) { String[] result = new String[parentNames.length + 1]; System.arraycopy(parentNames, 0, result, 0, parentNames.length); diff --git a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java index 3d02ede61743..c600927af0f8 100644 --- a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java +++ b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java @@ -78,6 +78,7 @@ import static org.apache.paimon.CoreOptions.SNAPSHOT_NUM_RETAINED_MIN; import static org.apache.paimon.CoreOptions.STREAMING_READ_OVERWRITE; import static org.apache.paimon.format.FileFormat.vectorFileFormat; +import static org.apache.paimon.schema.TableSchema.PAIMON_07_VERSION; import static org.apache.paimon.table.PrimaryKeyTableUtils.createMergeFunctionFactory; import static org.apache.paimon.table.SpecialFields.KEY_FIELD_PREFIX; import static org.apache.paimon.table.SpecialFields.SYSTEM_FIELD_NAMES; @@ -104,6 +105,18 @@ public class SchemaValidation { * @param schema the schema to be validated */ public static void validateTableSchema(TableSchema schema) { + validateTableSchema(schema, 
Collections.emptySet()); + } + + /** + * Validate the {@link TableSchema} and {@link CoreOptions}. + * + * @param schema the schema to be validated + * @param dynamicOptionKeys option keys that are overridden dynamically at runtime (e.g. by + * dedicated compaction jobs) and should therefore be excluded from certain static + * validations such as the {@code write-only} requirement for snapshot ordering + */ + public static void validateTableSchema(TableSchema schema, Set dynamicOptionKeys) { CoreOptions options = new CoreOptions(schema.options()); validateOnlyContainPrimitiveType(schema.fields(), schema.primaryKeys(), "primary key"); @@ -287,6 +300,10 @@ public static void validateTableSchema(TableSchema schema) { "deletion-vectors.merge-on-read requires deletion-vectors.enabled to be true."); } + if (options.snapshotSequenceOrdering()) { + validateSnapshotSequenceOrdering(schema, options, dynamicOptionKeys); + } + // vector field names must point to vector type Set fieldNamesSpecifiedAsVector = options.vectorField(); schema.fields() @@ -318,6 +335,8 @@ public static void validateTableSchema(TableSchema schema) { validateChangelogReadSequenceNumber(schema, options); validatePkClusteringOverride(options); + + validateManifestSort(schema, options); } public static void validateFallbackBranch(SchemaManager schemaManager, TableSchema schema) { @@ -611,6 +630,32 @@ private static void validateFileIndex(TableSchema schema) { } } + private static void validateSnapshotSequenceOrdering( + TableSchema schema, CoreOptions options, Set dynamicOptionKeys) { + checkArgument( + !schema.primaryKeys().isEmpty(), + "%s = true requires a primary-key table; append-only tables cannot use " + + "snapshot-based sequence ordering.", + CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key()); + checkArgument( + options.sequenceField().isEmpty(), + "%s = true is mutually exclusive with %s; the snapshot id is the sole tiebreaker.", + CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key(), + 
CoreOptions.SEQUENCE_FIELD.key()); + // Skip writeOnly check when write-only is dynamically overridden (e.g. by dedicated + // compact jobs that override write-only=false at runtime). + if (!dynamicOptionKeys.contains(CoreOptions.WRITE_ONLY.key())) { + checkArgument( + options.writeOnly(), + "%s = true requires %s = true. Snapshot ordering relies on snapshot id to " + + "determine record order, but inline compaction happens before " + + "snapshot creation — files have not been stamped with the correct " + + "snapshot id yet. Use dedicated compaction job instead.", + CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key(), + CoreOptions.WRITE_ONLY.key()); + } + } + private static void validateForDeletionVectors(CoreOptions options) { checkArgument( options.changelogProducer() == ChangelogProducer.NONE @@ -681,7 +726,9 @@ private static void validateBucket(TableSchema schema, CoreOptions options) { } else if (bucket < 1 && !isPostponeBucketTable(schema, bucket)) { throw new RuntimeException("The number of buckets needs to be greater than 0."); } else { - if (schema.primaryKeys().isEmpty() && schema.bucketKeys().isEmpty()) { + if (schema.primaryKeys().isEmpty() + && schema.bucketKeys().isEmpty() + && (bucket != 1 || schema.version() != PAIMON_07_VERSION)) { throw new RuntimeException( "You should define a 'bucket-key' for bucketed append mode."); } @@ -947,6 +994,14 @@ public static void validateChainTable(TableSchema schema, CoreOptions options) { options.partitionTimestampFormatter() != null, "Partition timestamp formatter is required for chain table."); + if (options.partitionExpireTime() != null) { + Preconditions.checkArgument( + "values-time".equals(options.partitionExpireStrategy()), + "Chain table only supports 'values-time' partition expiration strategy, " + + "but found '%s'.", + options.partitionExpireStrategy()); + } + // validate chain-table.chain-partition-keys List chainPartKeys = options.chainTableChainPartitionKeys(); if (chainPartKeys != null) { @@ -1021,4 
+1076,22 @@ public static void validatePkClusteringOverride(CoreOptions options) { } } } + + private static void validateManifestSort(TableSchema schema, CoreOptions options) { + if (options.manifestSortEnabled()) { + checkArgument( + !schema.partitionKeys().isEmpty(), + "Cannot enable '%s' for non-partition table.", + CoreOptions.MANIFEST_SORT_ENABLED.key()); + String sortPartitionField = options.manifestSortPartitionField(); + if (sortPartitionField != null && !sortPartitionField.isEmpty()) { + checkArgument( + schema.partitionKeys().contains(sortPartitionField), + "'%s' = '%s' is not a partition field. Available partition fields: %s.", + CoreOptions.MANIFEST_SORT_PARTITION_FIELD.key(), + sortPartitionField, + schema.partitionKeys()); + } + } + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java b/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java index 31963420e71e..5521295ea31b 100644 --- a/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java +++ b/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java @@ -178,7 +178,11 @@ public boolean write(InternalRow record) throws IOException { } if (inMemorySortBuffer.isEmpty()) { // did not fit in a fresh buffer, must be large... - throw new IOException("The record exceeds the maximum size of a sort buffer."); + throw new IOException( + "The record exceeds the maximum size of a Paimon write sort buffer. " + + "A single serialized record cannot fit into an empty write buffer. 
" + + "Please check whether the input contains an oversized row, " + + "or increase the table option 'write-buffer-size'."); } else { spill(); diff --git a/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java b/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java index df4224915927..0249a8e6b397 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java @@ -374,7 +374,7 @@ protected FileStoreTable copyInternal( } // validate schema with new options - SchemaValidation.validateTableSchema(newTableSchema); + SchemaValidation.validateTableSchema(newTableSchema, dynamicOptions.keySet()); return copy(newTableSchema); } @@ -383,9 +383,10 @@ protected FileStoreTable copyInternal( public FileStoreTable copyWithLatestSchema() { Optional optionalLatestSchema = schemaManager().latest(); if (optionalLatestSchema.isPresent()) { - Map options = tableSchema.options(); - TableSchema newTableSchema = optionalLatestSchema.get(); - newTableSchema = newTableSchema.copy(options); + TableSchema latestSchema = optionalLatestSchema.get(); + Map mergedOptions = new HashMap<>(latestSchema.options()); + mergedOptions.putAll(tableSchema.options()); + TableSchema newTableSchema = latestSchema.copy(mergedOptions); SchemaValidation.validateTableSchema(newTableSchema); return copy(newTableSchema); } else { diff --git a/paimon-core/src/main/java/org/apache/paimon/table/KnownSplitsTable.java b/paimon-core/src/main/java/org/apache/paimon/table/KnownSplitsTable.java index d76a6d700c5f..64b1952d1ec9 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/KnownSplitsTable.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/KnownSplitsTable.java @@ -21,6 +21,7 @@ import org.apache.paimon.fs.FileIO; import org.apache.paimon.table.source.InnerTableRead; import org.apache.paimon.table.source.InnerTableScan; +import 
org.apache.paimon.table.source.ReadBuilder; import org.apache.paimon.table.source.Split; import org.apache.paimon.types.RowType; @@ -48,6 +49,14 @@ public Split[] splits() { return splits; } + @Override + public ReadBuilder newReadBuilder() { + // ReadonlyTable's default implementation creates ReadBuilderImpl(this), which makes the + // ReadBuilder capture this KnownSplitsTable. ReadBuilder should not carry all splits, so + // delegate to origin. + return origin.newReadBuilder(); + } + @Override public String name() { return origin.name(); diff --git a/paimon-core/src/main/java/org/apache/paimon/table/format/FormatTableScan.java b/paimon-core/src/main/java/org/apache/paimon/table/format/FormatTableScan.java index 9a561a6bd1da..9bbd64ccdf9c 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/format/FormatTableScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/format/FormatTableScan.java @@ -19,7 +19,10 @@ package org.apache.paimon.table.format; import org.apache.paimon.CoreOptions; +import org.apache.paimon.casting.CastExecutor; +import org.apache.paimon.casting.CastExecutors; import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.BinaryString; import org.apache.paimon.data.GenericRow; import org.apache.paimon.data.serializer.InternalRowSerializer; import org.apache.paimon.format.csv.CsvOptions; @@ -43,7 +46,9 @@ import org.apache.paimon.table.source.InnerTableScan; import org.apache.paimon.table.source.Split; import org.apache.paimon.table.source.TableScan; +import org.apache.paimon.types.DataType; import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VarCharType; import org.apache.paimon.utils.InternalRowPartitionComputer; import org.apache.paimon.utils.Pair; import org.apache.paimon.utils.PartitionPathUtils; @@ -245,7 +250,8 @@ protected static Pair computeScanPathAndLevel( Map equalityPrefix = extractLeadingEqualityPartitionSpecWhenOnlyAnd( partitionKeys, - ((DefaultPartitionPredicate) 
partitionFilter).predicate()); + ((DefaultPartitionPredicate) partitionFilter).predicate(), + partitionType); if (!equalityPrefix.isEmpty()) { // Use optimized scan for specific partition path String partitionPath = @@ -313,7 +319,7 @@ private boolean preferToSplitFile(FileStatus file) { } public static Map extractLeadingEqualityPartitionSpecWhenOnlyAnd( - List partitionKeys, Predicate predicate) { + List partitionKeys, Predicate predicate, RowType partitionType) { List predicates = PredicateBuilder.splitAnd(predicate); Map equals = new HashMap<>(); for (Predicate sub : predicates) { @@ -324,7 +330,10 @@ public static Map extractLeadingEqualityPartitionSpecWhenOnlyAnd LeafFunction function = ((LeafPredicate) sub).function(); String field = fieldRef.name(); if (function instanceof Equal && partitionKeys.contains(field)) { - equals.put(field, ((LeafPredicate) sub).literals().get(0).toString()); + equals.put( + field, + partitionLiteralToString( + fieldRef.type(), ((LeafPredicate) sub).literals().get(0))); } } } @@ -339,4 +348,16 @@ public static Map extractLeadingEqualityPartitionSpecWhenOnlyAnd } return result; } + + private static String partitionLiteralToString(DataType type, Object literal) { + if (literal == null) { + return null; + } + + CastExecutor executor = + (CastExecutor) + CastExecutors.resolve(type, VarCharType.STRING_TYPE); + BinaryString value = executor.cast(literal); + return value == null ? 
null : value.toString(); + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/BlobViewResolvingRow.java b/paimon-core/src/main/java/org/apache/paimon/table/source/BlobViewResolvingRow.java index 842b1f27297e..f18a659b0907 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/BlobViewResolvingRow.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/BlobViewResolvingRow.java @@ -66,6 +66,15 @@ public void setRowKind(RowKind kind) { @Override public boolean isNullAt(int pos) { + if (wrapped.isNullAt(pos)) { + return true; + } + if (blobViewFields.contains(pos)) { + Blob blob = wrapped.getBlob(pos); + if (blob instanceof BlobView) { + return resolver.resolvesToNull((BlobView) blob); + } + } return wrapped.isNullAt(pos); } @@ -134,6 +143,9 @@ public Blob getBlob(int pos) { Blob blob = wrapped.getBlob(pos); if (blobViewFields.contains(pos) && blob instanceof BlobView) { BlobView blobView = (BlobView) blob; + if (resolver.resolvesToNull(blobView)) { + return null; + } if (!blobView.isResolved()) { resolver.resolve(blobView); } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DvAwareStats.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DvAwareStats.java new file mode 100644 index 000000000000..2d11500ab7a8 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DvAwareStats.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.io.DataFileMeta; + +import javax.annotation.Nullable; + +/** Utilities for DV-aware file statistics. */ +public final class DvAwareStats { + + private DvAwareStats() {} + + /** Returns whether file stats remain tight after applying delete metadata. */ + public static boolean isTightBounds(DataFileMeta file, @Nullable DeletionFile dv) { + if (file.deleteRowCount().orElse(0L) > 0L) { + return false; + } + if (dv == null) { + return true; + } + Long card = dv.cardinality(); + return card != null && card == 0L; + } +} diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextReadImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextReadImpl.java index f58f8f26ead4..66e509de8999 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextReadImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextReadImpl.java @@ -20,6 +20,7 @@ import org.apache.paimon.fs.FileIO; import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexReadThreadPool; import org.apache.paimon.globalindex.GlobalIndexReader; import org.apache.paimon.globalindex.GlobalIndexResult; import org.apache.paimon.globalindex.GlobalIndexer; @@ -33,15 +34,15 @@ import org.apache.paimon.predicate.FullTextSearch; import org.apache.paimon.table.FileStoreTable; import org.apache.paimon.types.DataField; +import org.apache.paimon.utils.IOUtils; -import java.io.IOException; import java.util.ArrayList; 
-import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; -import static java.util.Collections.singletonList; -import static org.apache.paimon.utils.ManifestReadThreadPool.randomlyExecuteSequentialReturn; +import static org.apache.paimon.CoreOptions.GLOBAL_INDEX_THREAD_NUM; import static org.apache.paimon.utils.Preconditions.checkNotNull; /** Implementation for {@link FullTextRead}. */ @@ -51,13 +52,24 @@ public class FullTextReadImpl implements FullTextRead { private final int limit; private final DataField textColumn; private final String queryText; + private final String queryOperator; public FullTextReadImpl( FileStoreTable table, int limit, DataField textColumn, String queryText) { + this(table, limit, textColumn, queryText, "or"); + } + + public FullTextReadImpl( + FileStoreTable table, + int limit, + DataField textColumn, + String queryText, + String queryOperator) { this.table = table; this.limit = limit; this.textColumn = textColumn; this.queryText = queryText; + this.queryOperator = queryOperator; } @Override @@ -66,29 +78,33 @@ public GlobalIndexResult read(List splits) { return GlobalIndexResult.createEmpty(); } - Integer threadNum = table.coreOptions().globalIndexThreadNum(); - String indexType = splits.get(0).fullTextIndexFiles().get(0).indexType(); GlobalIndexer globalIndexer = GlobalIndexerFactoryUtils.load(indexType) .create(textColumn, table.coreOptions().toConfiguration()); IndexPathFactory indexPathFactory = table.store().pathFactory().globalIndexFileFactory(); - Iterator> resultIterators = - randomlyExecuteSequentialReturn( - split -> - singletonList( - eval( - globalIndexer, - indexPathFactory, - split.rowRangeStart(), - split.rowRangeEnd(), - split.fullTextIndexFiles())), - splits, - threadNum); + + int parallelism = table.coreOptions().toConfiguration().get(GLOBAL_INDEX_THREAD_NUM); + ExecutorService executor = 
GlobalIndexReadThreadPool.getExecutorService(parallelism); + + List>> futures = + new ArrayList<>(splits.size()); + for (FullTextSearchSplit split : splits) { + futures.add( + eval( + globalIndexer, + indexPathFactory, + split.rowRangeStart(), + split.rowRangeEnd(), + split.fullTextIndexFiles(), + executor)); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty(); - while (resultIterators.hasNext()) { - Optional next = resultIterators.next(); + for (CompletableFuture> f : futures) { + Optional next = f.join(); if (next.isPresent()) { result = result.or(next.get()); } @@ -97,12 +113,13 @@ public GlobalIndexResult read(List splits) { return result.topK(limit); } - private Optional eval( + private CompletableFuture> eval( GlobalIndexer globalIndexer, IndexPathFactory indexPathFactory, long rowRangeStart, long rowRangeEnd, - List fullTextIndexFiles) { + List fullTextIndexFiles, + ExecutorService executor) { List indexIOMetaList = new ArrayList<>(); for (IndexFileMeta indexFile : fullTextIndexFiles) { GlobalIndexMeta meta = checkNotNull(indexFile.globalIndexMeta()); @@ -115,13 +132,12 @@ private Optional eval( @SuppressWarnings("resource") FileIO fileIO = table.fileIO(); GlobalIndexFileReader indexFileReader = m -> fileIO.newInputStream(m.filePath()); - try (GlobalIndexReader reader = - globalIndexer.createReader(indexFileReader, indexIOMetaList)) { - FullTextSearch fullTextSearch = new FullTextSearch(queryText, limit, textColumn.name()); - return new OffsetGlobalIndexReader(reader, rowRangeStart, rowRangeEnd) - .visitFullTextSearch(fullTextSearch); - } catch (IOException e) { - throw new RuntimeException(e); - } + GlobalIndexReader reader = + globalIndexer.createReader(indexFileReader, indexIOMetaList, executor); + FullTextSearch fullTextSearch = + new FullTextSearch(queryText, limit, textColumn.name(), queryOperator); + return new OffsetGlobalIndexReader(reader, 
rowRangeStart, rowRangeEnd) + .visitFullTextSearch(fullTextSearch) + .whenComplete((r, t) -> IOUtils.closeQuietly(reader)); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilder.java b/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilder.java index 731962d0cfcd..7ce4f11e1c5c 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilder.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilder.java @@ -34,6 +34,9 @@ public interface FullTextSearchBuilder extends Serializable { /** The query text to search. */ FullTextSearchBuilder withQueryText(String queryText); + /** The default query operator. Supported values are 'or' and 'and'. */ + FullTextSearchBuilder withQueryOperator(String queryOperator); + /** Create full-text scan to scan index files. */ FullTextScan newFullTextScan(); diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilderImpl.java index 1165fcf2052a..33651e8045aa 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilderImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/FullTextSearchBuilderImpl.java @@ -35,6 +35,7 @@ public class FullTextSearchBuilderImpl implements FullTextSearchBuilder { private int limit; private DataField textColumn; private String queryText; + private String queryOperator = "or"; public FullTextSearchBuilderImpl(InnerTable table) { this.table = (FileStoreTable) table; @@ -58,6 +59,12 @@ public FullTextSearchBuilder withQueryText(String queryText) { return this; } + @Override + public FullTextSearchBuilder withQueryOperator(String queryOperator) { + this.queryOperator = queryOperator; + return this; + } + @Override public FullTextScan newFullTextScan() { checkNotNull(textColumn, "Text column must be set via 
withTextColumn()"); @@ -69,6 +76,6 @@ public FullTextRead newFullTextRead() { checkArgument(limit > 0, "Limit must be positive, set via withLimit()"); checkNotNull(textColumn, "Text column must be set via withTextColumn()"); checkNotNull(queryText, "Query text must be set via withQueryText()"); - return new FullTextReadImpl(table, limit, textColumn, queryText); + return new FullTextReadImpl(table, limit, textColumn, queryText, queryOperator); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java b/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java index 65793014f10d..cc9fc936cb60 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/PushDownUtils.java @@ -30,6 +30,7 @@ import org.apache.paimon.types.TinyIntType; import java.util.HashSet; +import java.util.List; import java.util.Set; import static org.apache.paimon.utils.ListUtils.isNullOrEmpty; @@ -78,4 +79,21 @@ public static boolean minmaxAvailable(Split split, Set columns) { valueStatsCols == null || new HashSet<>(valueStatsCols).containsAll(columns)); } + + /** Returns whether every data file in the split has tight stats. */ + public static boolean tightBoundsAvailable(Split split) { + if (!(split instanceof DataSplit)) { + return false; + } + DataSplit dataSplit = (DataSplit) split; + List files = dataSplit.dataFiles(); + List dvs = dataSplit.deletionFiles().orElse(null); + for (int i = 0; i < files.size(); i++) { + DeletionFile dv = dvs == null ? 
null : dvs.get(i); + if (!DvAwareStats.isTightBounds(files.get(i), dv)) { + return false; + } + } + return true; + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java index 6971bb908409..a3402c3f1d66 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java @@ -20,6 +20,7 @@ import org.apache.paimon.fs.FileIO; import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexReadThreadPool; import org.apache.paimon.globalindex.GlobalIndexReader; import org.apache.paimon.globalindex.GlobalIndexResult; import org.apache.paimon.globalindex.GlobalIndexScanner; @@ -35,6 +36,7 @@ import org.apache.paimon.predicate.VectorSearch; import org.apache.paimon.table.FileStoreTable; import org.apache.paimon.types.DataField; +import org.apache.paimon.utils.IOUtils; import org.apache.paimon.utils.RoaringNavigableMap64; import javax.annotation.Nullable; @@ -42,14 +44,14 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; -import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.TreeSet; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; -import static java.util.Collections.singletonList; -import static org.apache.paimon.utils.ManifestReadThreadPool.randomlyExecuteSequentialReturn; +import static org.apache.paimon.CoreOptions.GLOBAL_INDEX_THREAD_NUM; import static org.apache.paimon.utils.Preconditions.checkNotNull; /** Implementation for {@link VectorRead}. 
*/ @@ -81,30 +83,35 @@ public GlobalIndexResult read(List splits) { } RoaringNavigableMap64 preFilter = preFilter(splits).orElse(null); - Integer threadNum = table.coreOptions().globalIndexThreadNum(); String indexType = splits.get(0).vectorIndexFiles().get(0).indexType(); GlobalIndexer globalIndexer = GlobalIndexerFactoryUtils.load(indexType) .create(vectorColumn, table.coreOptions().toConfiguration()); IndexPathFactory indexPathFactory = table.store().pathFactory().globalIndexFileFactory(); - Iterator> resultIterators = - randomlyExecuteSequentialReturn( - split -> - singletonList( - eval( - globalIndexer, - indexPathFactory, - split.rowRangeStart(), - split.rowRangeEnd(), - split.vectorIndexFiles(), - preFilter)), - splits, - threadNum); + + int parallelism = table.coreOptions().toConfiguration().get(GLOBAL_INDEX_THREAD_NUM); + ExecutorService executor = GlobalIndexReadThreadPool.getExecutorService(parallelism); + + List>> futures = + new ArrayList<>(splits.size()); + for (VectorSearchSplit split : splits) { + futures.add( + eval( + globalIndexer, + indexPathFactory, + split.rowRangeStart(), + split.rowRangeEnd(), + split.vectorIndexFiles(), + preFilter, + executor)); + } + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty(); - while (resultIterators.hasNext()) { - Optional next = resultIterators.next(); + for (CompletableFuture> f : futures) { + Optional next = f.join(); if (next.isPresent()) { result = result.or(next.get()); } @@ -132,13 +139,14 @@ private Optional preFilter(List splits } } - private Optional eval( + private CompletableFuture> eval( GlobalIndexer globalIndexer, IndexPathFactory indexPathFactory, long rowRangeStart, long rowRangeEnd, List vectorIndexFiles, - @Nullable RoaringNavigableMap64 includeRowIds) { + @Nullable RoaringNavigableMap64 includeRowIds, + ExecutorService executor) { List indexIOMetaList = new ArrayList<>(); for (IndexFileMeta 
indexFile : vectorIndexFiles) { GlobalIndexMeta meta = checkNotNull(indexFile.globalIndexMeta()); @@ -151,15 +159,13 @@ private Optional eval( @SuppressWarnings("resource") FileIO fileIO = table.fileIO(); GlobalIndexFileReader indexFileReader = m -> fileIO.newInputStream(m.filePath()); - try (GlobalIndexReader reader = - globalIndexer.createReader(indexFileReader, indexIOMetaList)) { - VectorSearch vectorSearch = - new VectorSearch(vector, limit, vectorColumn.name()) - .withIncludeRowIds(includeRowIds); - return new OffsetGlobalIndexReader(reader, rowRangeStart, rowRangeEnd) - .visitVectorSearch(vectorSearch); - } catch (IOException e) { - throw new RuntimeException(e); - } + GlobalIndexReader reader = + globalIndexer.createReader(indexFileReader, indexIOMetaList, executor); + VectorSearch vectorSearch = + new VectorSearch(vector, limit, vectorColumn.name()) + .withIncludeRowIds(includeRowIds); + return new OffsetGlobalIndexReader(reader, rowRangeStart, rowRangeEnd) + .visitVectorSearch(vectorSearch) + .whenComplete((r, t) -> IOUtils.closeQuietly(reader)); } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java index d033e427a6dd..0a752f256467 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/snapshot/SnapshotReaderImpl.java @@ -630,15 +630,16 @@ private Set> toPartBuckets( private Map, Map> scanDvIndex( @Nullable Snapshot snapshot, Set> buckets) { - if (snapshot == null || snapshot.indexManifest() == null) { + if (snapshot == null || snapshot.indexManifest() == null || buckets.isEmpty()) { return Collections.emptyMap(); } Map, Map> result = new HashMap<>(); Path indexManifestPath = indexFileHandler.indexManifestFilePath(snapshot.indexManifest()); + Set> remainingBuckets = new HashSet<>(buckets); // 1. 
read from cache if (dvMetaCache != null) { - Iterator> iterator = buckets.iterator(); + Iterator> iterator = remainingBuckets.iterator(); while (iterator.hasNext()) { Pair next = iterator.next(); BinaryRow partition = next.getLeft(); @@ -658,31 +659,58 @@ private Map, Map> scanDvIndex( } } } + if (remainingBuckets.isEmpty()) { + return result; + } // 2. read from file system Map, List> partitionFileMetas = - indexFileHandler.scan( - snapshot, - DELETION_VECTORS_INDEX, - buckets.stream().map(Pair::getLeft).collect(Collectors.toSet())); + dvMetaCache == null + ? indexFileHandler.scanBuckets( + snapshot, DELETION_VECTORS_INDEX, remainingBuckets) + : indexFileHandler.scan( + snapshot, + DELETION_VECTORS_INDEX, + remainingBuckets.stream() + .map(Pair::getLeft) + .collect(Collectors.toSet())); partitionFileMetas.forEach( (entry, indexFileMetas) -> { - Map deletionFiles = - toDeletionFiles(entry, indexFileMetas); - if (dvMetaCache != null) { - dvMetaCache.put( + Pair partitionBucket = entry; + if (remainingBuckets.contains(entry)) { + Map deletionFiles = + toDeletionFiles(partitionBucket, indexFileMetas); + result.put(partitionBucket, deletionFiles); + if (dvMetaCache != null) { + dvMetaCache.put( + indexManifestPath, + partitionBucket.getLeft(), + partitionBucket.getRight(), + deletionFiles); + } + } else if (dvMetaCache != null) { + dvMetaCache.putLazy( indexManifestPath, - entry.getLeft(), - entry.getRight(), - deletionFiles); - } - if (buckets.contains(entry)) { - result.put(entry, deletionFiles); + partitionBucket.getLeft(), + partitionBucket.getRight(), + deletionFileNumber(indexFileMetas), + () -> toDeletionFiles(partitionBucket, indexFileMetas)); } }); return result; } + private int deletionFileNumber(List fileMetas) { + int count = 0; + for (IndexFileMeta indexFile : fileMetas) { + LinkedHashMap dvRanges = indexFile.dvRanges(); + if (dvRanges != null) { + count += dvRanges.size(); + } + } + return count; + } + private Map toDeletionFiles( Pair partitionBucket, 
List fileMetas) { Map deletionFiles = new HashMap<>(); diff --git a/paimon-core/src/main/java/org/apache/paimon/table/system/CompactBucketsTable.java b/paimon-core/src/main/java/org/apache/paimon/table/system/CompactBucketsTable.java index dad19ff26327..f18d184b8437 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/system/CompactBucketsTable.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/system/CompactBucketsTable.java @@ -263,8 +263,6 @@ public RecordReader createReader(Split split) throws IOException { DataSplit dataSplit = (DataSplit) split; // in case of schema evolution for (DataFileMeta file : dataSplit.dataFiles()) { - System.out.println( - "File schema id " + file.schemaId() + ", base schema id " + baseSchemaId); if (file.schemaId() > baseSchemaId) { throw new RuntimeException( String.format( diff --git a/paimon-core/src/main/java/org/apache/paimon/utils/BlobViewLookup.java b/paimon-core/src/main/java/org/apache/paimon/utils/BlobViewLookup.java index f570a2bb85b6..76e4a084ac88 100644 --- a/paimon-core/src/main/java/org/apache/paimon/utils/BlobViewLookup.java +++ b/paimon-core/src/main/java/org/apache/paimon/utils/BlobViewLookup.java @@ -26,6 +26,7 @@ import org.apache.paimon.catalog.Identifier; import org.apache.paimon.data.Blob; import org.apache.paimon.data.BlobDescriptor; +import org.apache.paimon.data.BlobView; import org.apache.paimon.data.BlobViewResolver; import org.apache.paimon.data.BlobViewStruct; import org.apache.paimon.data.InternalRow; @@ -36,13 +37,16 @@ import org.apache.paimon.types.DataField; import org.apache.paimon.types.RowType; +import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; import 
java.util.concurrent.ExecutorCompletionService; @@ -71,40 +75,67 @@ static BlobViewResolver createResolver( CatalogContext catalogContext, List viewStructs, CatalogLoader catalogLoader) { - Map cached = - preloadDescriptors(catalogContext, viewStructs, catalogLoader); - Map cache = new HashMap<>(); - return blobView -> { - BlobViewStruct viewStruct = blobView.viewStruct(); - BlobDescriptor descriptor = cached.get(viewStruct); - if (descriptor == null) { - throw new IllegalStateException( - "BlobViewStruct not found in preloaded cache: " - + viewStruct - + ". Cache keys: " - + cached.keySet()); + PreloadedBlobViews cached = preloadDescriptors(catalogContext, viewStructs, catalogLoader); + return new BlobViewResolver() { + + private final Map cache = new HashMap<>(); + + @Override + public void resolve(BlobView blobView) { + BlobViewStruct viewStruct = blobView.viewStruct(); + BlobDescriptor descriptor = cached.descriptor(viewStruct); + if (descriptor == null) { + if (cached.resolvesToNull(viewStruct)) { + throw new IllegalStateException( + "BlobViewStruct resolves to a null blob value: " + viewStruct); + } + throw missingBlobViewStruct(viewStruct, cached); + } + UriReader uriReader = + cache.computeIfAbsent( + viewStruct.identifier(), + identifier -> { + try (Catalog catalog = catalogLoader.create(catalogContext)) { + return UriReader.fromFile( + catalog.getTable(identifier).fileIO()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + blobView.resolve(uriReader, descriptor); + } + + @Override + public boolean resolvesToNull(BlobView blobView) { + BlobViewStruct viewStruct = blobView.viewStruct(); + if (cached.resolvesToNull(viewStruct)) { + return true; + } + if (cached.descriptor(viewStruct) == null) { + throw missingBlobViewStruct(viewStruct, cached); + } + return false; } - UriReader uriReader = - cache.computeIfAbsent( - viewStruct.identifier(), - identifier -> { - try (Catalog catalog = catalogLoader.create(catalogContext)) { - return 
UriReader.fromFile( - catalog.getTable(identifier).fileIO()); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - blobView.resolve(uriReader, descriptor); }; } - private static Map preloadDescriptors( + private static IllegalStateException missingBlobViewStruct( + BlobViewStruct viewStruct, PreloadedBlobViews cached) { + return new IllegalStateException( + "BlobViewStruct not found in preloaded cache: " + + viewStruct + + ". Descriptor cache keys: " + + cached.descriptorKeys() + + ". Null-value cache keys: " + + cached.nullValueKeys()); + } + + private static PreloadedBlobViews preloadDescriptors( CatalogContext catalogContext, List viewStructs, CatalogLoader catalogLoader) { if (viewStructs.isEmpty()) { - return Collections.emptyMap(); + return PreloadedBlobViews.empty(); } Map grouped = groupReferencesByTable(viewStructs); @@ -148,7 +179,7 @@ private static Map groupReferencesByTable( return grouped; } - private static Map loadReferencedDescriptors( + private static PreloadedBlobViews loadReferencedDescriptors( CatalogContext catalogContext, Collection grouped, ExecutorService executor, @@ -160,9 +191,9 @@ private static Map loadReferencedDescriptors( } long targetRowsPerTask = targetRowsPerTask(plans); - CompletionService> completionService = + CompletionService completionService = new ExecutorCompletionService<>(executor); - List>> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); for (TableReadPlan plan : plans) { for (List rangeChunk : splitRowRanges(plan.rowRanges, targetRowsPerTask)) { @@ -189,13 +220,13 @@ private static Map loadReferencedDescriptors( } } - Map resolved = new HashMap<>(); + PreloadedBlobViews resolved = new PreloadedBlobViews(); try { for (int i = 0; i < futures.size(); i++) { resolved.putAll(completionService.take().get()); } } catch (Exception e) { - for (Future> future : futures) { + for (Future future : futures) { 
future.cancel(true); } throw e; @@ -243,7 +274,7 @@ private static TableReadPlan createTableReadPlan( } } - private static Map loadTableDescriptorChunk( + private static PreloadedBlobViews loadTableDescriptorChunk( CatalogContext catalogContext, Identifier identifier, List fields, @@ -252,7 +283,7 @@ private static Map loadTableDescriptorChunk( CatalogLoader catalogLoader) throws Exception { try (Catalog catalog = catalogLoader.create(catalogContext)) { - Map resolved = new HashMap<>(); + PreloadedBlobViews resolved = new PreloadedBlobViews(); Table table = catalog.getTable(identifier) .copy( @@ -271,12 +302,18 @@ private static Map loadTableDescriptorChunk( while ((row = batch.next()) != null) { long rowId = row.getLong(fields.size()); for (int i = 0; i < fields.size(); i++) { + BlobViewStruct viewStruct = + new BlobViewStruct( + identifier, fields.get(i).fieldId, rowId); + if (row.isNullAt(i)) { + resolved.putNull(viewStruct); + continue; + } Blob blob = row.getBlob(i); - if (blob != null) { - resolved.put( - new BlobViewStruct( - identifier, fields.get(i).fieldId, rowId), - blob.toDescriptor()); + if (blob == null) { + resolved.putNull(viewStruct); + } else { + resolved.putDescriptor(viewStruct, blob.toDescriptor()); } } } @@ -400,5 +437,56 @@ private TableReadPlan( } } + private static class PreloadedBlobViews implements Serializable { + + private static final long serialVersionUID = 1L; + + private final Map descriptors; + private final Set nullValues; + + private PreloadedBlobViews() { + this(new HashMap<>(), new HashSet<>()); + } + + private PreloadedBlobViews( + Map descriptors, Set nullValues) { + this.descriptors = descriptors; + this.nullValues = nullValues; + } + + private static PreloadedBlobViews empty() { + return new PreloadedBlobViews(Collections.emptyMap(), Collections.emptySet()); + } + + private BlobDescriptor descriptor(BlobViewStruct viewStruct) { + return descriptors.get(viewStruct); + } + + private boolean resolvesToNull(BlobViewStruct 
viewStruct) { + return nullValues.contains(viewStruct); + } + + private void putDescriptor(BlobViewStruct viewStruct, BlobDescriptor descriptor) { + descriptors.put(viewStruct, descriptor); + } + + private void putNull(BlobViewStruct viewStruct) { + nullValues.add(viewStruct); + } + + private void putAll(PreloadedBlobViews other) { + descriptors.putAll(other.descriptors); + nullValues.addAll(other.nullValues); + } + + private Set descriptorKeys() { + return descriptors.keySet(); + } + + private Set nullValueKeys() { + return nullValues; + } + } + private BlobViewLookup() {} } diff --git a/paimon-core/src/main/java/org/apache/paimon/utils/DVMetaCache.java b/paimon-core/src/main/java/org/apache/paimon/utils/DVMetaCache.java index 34b92f9bc3f0..b81241257b0f 100644 --- a/paimon-core/src/main/java/org/apache/paimon/utils/DVMetaCache.java +++ b/paimon-core/src/main/java/org/apache/paimon/utils/DVMetaCache.java @@ -29,11 +29,12 @@ import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; /** Cache for deletion vector meta. */ public class DVMetaCache { - private final Cache> cache; + private final Cache cache; public DVMetaCache(long maxValueNumber) { this.cache = @@ -45,20 +46,95 @@ public DVMetaCache(long maxValueNumber) { .build(); } - private static int weigh(DVMetaCacheKey cacheKey, Map cacheValue) { - return cacheValue.size() + 1; + private static int weigh(DVMetaCacheKey cacheKey, DVMetaCacheValue cacheValue) { + return cacheValue.weight(); } @Nullable public Map read(Path manifestPath, BinaryRow partition, int bucket) { DVMetaCacheKey cacheKey = new DVMetaCacheKey(manifestPath, partition, bucket); - return this.cache.getIfPresent(cacheKey); + DVMetaCacheValue cacheValue = this.cache.getIfPresent(cacheKey); + return cacheValue == null ? 
null : cacheValue.get(); } public void put( Path path, BinaryRow partition, int bucket, Map dvFilesMap) { DVMetaCacheKey key = new DVMetaCacheKey(path, partition, bucket); - this.cache.put(key, dvFilesMap); + this.cache.put(key, DVMetaCacheValue.eager(dvFilesMap)); + } + + public void putLazy( + Path path, + BinaryRow partition, + int bucket, + int valueNumber, + Supplier> dvFilesSupplier) { + DVMetaCacheKey key = new DVMetaCacheKey(path, partition, bucket); + this.cache.put(key, DVMetaCacheValue.lazy(valueNumber, dvFilesSupplier)); + } + + /** Cache value for deletion vector meta at bucket level. */ + private static final class DVMetaCacheValue { + + private final int weight; + private final DeletionFilesField deletionFilesField; + + private DVMetaCacheValue(int weight, DeletionFilesField deletionFilesField) { + this.weight = weight; + this.deletionFilesField = deletionFilesField; + } + + private static DVMetaCacheValue eager(Map deletionFiles) { + return new DVMetaCacheValue( + deletionFiles.size() + 1, new ExistingDeletionFilesField(deletionFiles)); + } + + private static DVMetaCacheValue lazy( + int valueNumber, Supplier> deletionFilesSupplier) { + return new DVMetaCacheValue( + valueNumber + 1, new LazyDeletionFilesField(deletionFilesSupplier)); + } + + private int weight() { + return weight; + } + + private Map get() { + return deletionFilesField.get(); + } + } + + private interface DeletionFilesField { + + Map get(); + } + + private static final class ExistingDeletionFilesField implements DeletionFilesField { + + private final Map deletionFiles; + + private ExistingDeletionFilesField(Map deletionFiles) { + this.deletionFiles = deletionFiles; + } + + @Override + public Map get() { + return deletionFiles; + } + } + + private static final class LazyDeletionFilesField implements DeletionFilesField { + + private final LazyField> deletionFiles; + + private LazyDeletionFilesField(Supplier> deletionFilesSupplier) { + this.deletionFiles = new 
LazyField<>(deletionFilesSupplier); + } + + @Override + public synchronized Map get() { + return deletionFiles.get(); + } } /** Cache key for deletion vector meta at bucket level. */ diff --git a/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java b/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java index 86cea365c7a1..379331a1c221 100644 --- a/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java +++ b/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java @@ -997,6 +997,68 @@ protected GenericRow createRow3ColsWithKind(RowKind rowKind, Object... values) { return GenericRow.ofKind(rowKind, values[0], values[1], values[2]); } + /** Java writes a ROW-format append-only table for Python to read (Java→Python E2E). */ + @Test + @EnabledIfSystemProperty(named = "run.e2e.tests", matches = "true") + public void testJavaWriteRowAppendTable() throws Exception { + Identifier identifier = identifier("mixed_test_append_tablej_row"); + catalog.dropTable(identifier, true); + Schema schema = + Schema.newBuilder() + .column("id", DataTypes.INT()) + .column("name", DataTypes.STRING()) + .column("value", DataTypes.DOUBLE()) + .option("file.format", "row") + .option("bucket", "-1") + .build(); + + catalog.createTable(identifier, schema, false); + FileStoreTable table = (FileStoreTable) catalog.getTable(identifier); + + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = writeBuilder.newWrite(); + BatchTableCommit commit = writeBuilder.newCommit()) { + write.write(GenericRow.of(1, BinaryString.fromString("Apple"), 1.5)); + write.write(GenericRow.of(2, BinaryString.fromString("Banana"), 0.8)); + write.write(GenericRow.of(3, BinaryString.fromString("Carrot"), 0.6)); + write.write(GenericRow.of(4, BinaryString.fromString("Broccoli"), 1.2)); + write.write(GenericRow.of(5, BinaryString.fromString("Chicken"), 5.0)); + write.write(GenericRow.of(6, BinaryString.fromString("Beef"), 8.0)); + 
commit.commit(write.prepareCommit()); + } + + List splits = new ArrayList<>(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newRead(); + List res = + getResult( + read, + splits, + row -> DataFormatTestUtil.toStringNoRowKind(row, table.rowType())); + assertThat(res).hasSize(6); + LOG.info("testJavaWriteRowAppendTable: wrote and read back {} ROW-format rows", res.size()); + } + + /** Java reads a ROW-format append-only table written by Python (Python→Java E2E). */ + @Test + @EnabledIfSystemProperty(named = "run.e2e.tests", matches = "true") + public void testReadRowAppendTable() throws Exception { + Identifier identifier = identifier("mixed_test_append_tablep_row"); + Table table = catalog.getTable(identifier); + FileStoreTable fileStoreTable = (FileStoreTable) table; + List splits = + new ArrayList<>(fileStoreTable.newSnapshotReader().read().dataSplits()); + TableRead read = fileStoreTable.newRead(); + List res = + getResult( + read, + splits, + row -> DataFormatTestUtil.toStringNoRowKind(row, table.rowType())); + assertThat(res).hasSize(6); + LOG.info( + "testReadRowAppendTable: Java read {} ROW-format rows written by Python", + res.size()); + } + /** Java writes a VARIANT-column table for Python to read (Java→Python E2E). 
*/ @Test @EnabledIfSystemProperty(named = "run.e2e.tests", matches = "true") diff --git a/paimon-core/src/test/java/org/apache/paimon/append/BlobTableTest.java b/paimon-core/src/test/java/org/apache/paimon/append/BlobTableTest.java index 145cf0b04d33..3f12ce039fcb 100644 --- a/paimon-core/src/test/java/org/apache/paimon/append/BlobTableTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/append/BlobTableTest.java @@ -26,6 +26,7 @@ import org.apache.paimon.data.Blob; import org.apache.paimon.data.BlobData; import org.apache.paimon.data.BlobDescriptor; +import org.apache.paimon.data.BlobPlaceholder; import org.apache.paimon.data.BlobView; import org.apache.paimon.data.BlobViewStruct; import org.apache.paimon.data.GenericRow; @@ -47,6 +48,7 @@ import org.apache.paimon.table.sink.BatchTableWrite; import org.apache.paimon.table.sink.BatchWriteBuilder; import org.apache.paimon.table.sink.CommitMessage; +import org.apache.paimon.table.sink.CommitMessageImpl; import org.apache.paimon.table.sink.StreamTableWrite; import org.apache.paimon.table.sink.StreamWriteBuilder; import org.apache.paimon.table.source.EndOfScanException; @@ -152,6 +154,91 @@ public void testBasic() throws Exception { assertThat(integer.get()).isEqualTo(1000); } + @Test + public void testUpdateBlobColumn() throws Exception { + createTableDefault(); + + byte[] blob0 = "blob-0".getBytes(); + byte[] blob1 = "blob-1".getBytes(); + byte[] blob2 = "blob-2".getBytes(); + writeDataDefault( + Arrays.asList( + GenericRow.of(0, BinaryString.fromString("row-0"), new BlobData(blob0)), + GenericRow.of(1, BinaryString.fromString("row-1"), new BlobData(blob1)), + GenericRow.of(2, BinaryString.fromString("row-2"), new BlobData(blob2)))); + + byte[] updatedBlob1 = "updated-blob-1".getBytes(); + FileStoreTable table = getTableDefault(); + RowType blobWriteType = table.schema().logicalRowType().project("f2"); + BatchWriteBuilder builder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = 
builder.newWrite().withWriteType(blobWriteType); + BatchTableCommit commit = builder.newCommit()) { + write.write(GenericRow.of(BlobPlaceholder.INSTANCE)); + write.write(GenericRow.of(new BlobData(updatedBlob1))); + write.write(GenericRow.of(BlobPlaceholder.INSTANCE)); + + List commitMessages = write.prepareCommit(); + assignFirstRowId(commitMessages, 0L); + commit.commit(commitMessages); + } + + Map actual = new HashMap<>(); + readDefault(row -> actual.put(row.getInt(0), row.getBlob(2).toData())); + + assertThat(actual.size()).isEqualTo(3); + assertThat(actual.get(0)).isEqualTo(blob0); + assertThat(actual.get(1)).isEqualTo(updatedBlob1); + assertThat(actual.get(2)).isEqualTo(blob2); + } + + @Test + public void testCompactUpdatedBlobColumn() throws Exception { + createTableDefault(); + + byte[] blob0 = "blob-0".getBytes(); + byte[] blob1 = "blob-1".getBytes(); + byte[] blob2 = "blob-2".getBytes(); + writeDataDefault( + Arrays.asList( + GenericRow.of(0, BinaryString.fromString("row-0"), new BlobData(blob0)), + GenericRow.of(1, BinaryString.fromString("row-1"), new BlobData(blob1)), + GenericRow.of(2, BinaryString.fromString("row-2"), new BlobData(blob2)))); + + byte[] updatedBlob1 = "updated-blob-1".getBytes(); + FileStoreTable table = getTableDefault(); + RowType blobWriteType = table.schema().logicalRowType().project("f2"); + BatchWriteBuilder builder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite().withWriteType(blobWriteType); + BatchTableCommit commit = builder.newCommit()) { + write.write(GenericRow.of(BlobPlaceholder.INSTANCE)); + write.write(GenericRow.of(new BlobData(updatedBlob1))); + write.write(GenericRow.of(BlobPlaceholder.INSTANCE)); + + List commitMessages = write.prepareCommit(); + assignFirstRowId(commitMessages, 0L); + commit.commit(commitMessages); + } + + DataEvolutionCompactCoordinator coordinator = + new DataEvolutionCompactCoordinator(table, true, false); + List tasks = coordinator.plan(); + 
assertThat(tasks.stream().anyMatch(DataEvolutionCompactTask::isBlobTask)).isTrue(); + + List compactMessages = new ArrayList<>(); + for (DataEvolutionCompactTask task : tasks) { + compactMessages.add(task.doCompact(table, commitUser)); + } + commitDefault(compactMessages); + + Map actual = new HashMap<>(); + readDefault(row -> actual.put(row.getInt(0), row.getBlob(2).toData())); + + assertThat(actual.size()).isEqualTo(3); + assertThat(actual.get(0)).isEqualTo(blob0); + assertThat(actual.get(1)).isEqualTo(updatedBlob1); + assertThat(actual.get(2)).isEqualTo(blob2); + } + @Test public void testWriteByInputStream() throws Exception { createTableDefault(); @@ -813,7 +900,8 @@ public void testBlobViewE2E() throws Exception { GenericRow.of( 1, BinaryString.fromString("row1"), new BlobData(imageBytes1)), GenericRow.of( - 2, BinaryString.fromString("row2"), new BlobData(imageBytes2)))); + 2, BinaryString.fromString("row2"), new BlobData(imageBytes2)), + GenericRow.of(3, BinaryString.fromString("row3"), null))); int imageFieldId = upstreamTable.rowType().getFields().stream() @@ -829,6 +917,7 @@ public void testBlobViewE2E() throws Exception { Map idToBlob = new HashMap<>(); idToBlob.put(1, imageBytes1); idToBlob.put(2, imageBytes2); + idToBlob.put(3, null); rowIdReader .newRead() .createReader(rowIdReader.newScan().plan()) @@ -837,7 +926,7 @@ public void testBlobViewE2E() throws Exception { int id = row.getInt(0); idToRowId.put(id, row.getLong(1)); }); - assertThat(idToRowId.size()).isEqualTo(2); + assertThat(idToRowId.size()).isEqualTo(3); String downstreamTableName = "DownstreamView"; Schema.Builder downstreamSchema = Schema.newBuilder(); @@ -871,7 +960,15 @@ public void testBlobViewE2E() throws Exception { new BlobViewStruct( Identifier.fromString(upstreamFullName), imageFieldId, - idToRowId.get(2)))))); + idToRowId.get(2)))), + GenericRow.of( + 3, + BinaryString.fromString("label3"), + Blob.fromView( + new BlobViewStruct( + Identifier.fromString(upstreamFullName), + 
imageFieldId, + idToRowId.get(3)))))); ReadBuilder downstreamReadBuilder = downstreamTable.newReadBuilder(); downstreamReadBuilder @@ -880,6 +977,11 @@ public void testBlobViewE2E() throws Exception { .forEachRemaining( row -> { int id = row.getInt(0); + if (idToBlob.get(id) == null) { + assertThat(row.isNullAt(2)).isTrue(); + assertThat(row.getBlob(2)).isNull(); + return; + } Blob blob = row.getBlob(2); assertThat(blob).isInstanceOf(BlobView.class); assertThat(((BlobView) blob).isResolved()).isTrue(); @@ -1715,6 +1817,22 @@ private void writeRows(Table table, Iterable rows) throws Exception commit.close(); } + private static void assignFirstRowId(List commitMessages, long firstRowId) { + commitMessages.forEach( + commitMessage -> { + CommitMessageImpl impl = (CommitMessageImpl) commitMessage; + List newFiles = + new ArrayList<>(impl.newFilesIncrement().newFiles()); + impl.newFilesIncrement().newFiles().clear(); + impl.newFilesIncrement() + .newFiles() + .addAll( + newFiles.stream() + .map(file -> file.assignFirstRowId(firstRowId)) + .collect(Collectors.toList())); + }); + } + private void createThreeTypeBlobTable() throws Exception { Schema.Builder schemaBuilder = Schema.newBuilder(); schemaBuilder.column("f0", DataTypes.INT()); diff --git a/paimon-core/src/test/java/org/apache/paimon/append/MultipleBlobTableTest.java b/paimon-core/src/test/java/org/apache/paimon/append/MultipleBlobTableTest.java index 1cbaf903f0e4..f230216a58a6 100644 --- a/paimon-core/src/test/java/org/apache/paimon/append/MultipleBlobTableTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/append/MultipleBlobTableTest.java @@ -422,7 +422,8 @@ public void testAddBlobColumnThenProjectBothBlobs() throws Exception { // Add new blob column f3 catalog.alterTable( identifier(), - Collections.singletonList(SchemaChange.addColumn("f3", DataTypes.BLOB())), + Collections.singletonList( + SchemaChange.addColumn("f3", DataTypes.BLOB(), "__BLOB_FIELD", null)), false); // Write more data with both f2 
and f3 diff --git a/paimon-core/src/test/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinatorTest.java b/paimon-core/src/test/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinatorTest.java index f2f3a1675670..fe1b2dfc3ead 100644 --- a/paimon-core/src/test/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinatorTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/append/dataevolution/DataEvolutionCompactCoordinatorTest.java @@ -191,6 +191,24 @@ public void testCompactPlannerWithBlobFiles() { entries.get(5).file()); } + @Test + public void testCompactPlannerWithUpdatedBlobFiles() { + List entries = new ArrayList<>(); + entries.add(makeEntry("file1.parquet", 0L, 3L, 100)); + entries.add(makeBlobEntry("old.blob", 0L, 3L, 100, 0, "pic")); + entries.add(makeBlobEntry("updated.blob", 0L, 3L, 100, 1, "pic")); + + DataEvolutionCompactCoordinator.CompactPlanner planner = + blobPlanner(1024, 1, 2, rowType(new DataField(1, "pic", DataTypes.BLOB()))); + + List tasks = planner.compactPlan(entries); + + assertThat(tasks).hasSize(1); + assertThat(tasks.get(0).isBlobTask()).isTrue(); + assertThat(tasks.get(0).compactBefore()) + .containsExactly(entries.get(1).file(), entries.get(2).file()); + } + @Test public void testCompactPlannerDoesNotCompactBlobFilesAcrossDataFiles() { List entries = new ArrayList<>(); diff --git a/paimon-core/src/test/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilderTest.java b/paimon-core/src/test/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilderTest.java index 1010b8fc79d7..0c55221a50f7 100644 --- a/paimon-core/src/test/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilderTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/globalindex/btree/BTreeGlobalIndexBuilderTest.java @@ -22,11 +22,14 @@ import org.apache.paimon.Snapshot; import org.apache.paimon.data.BinaryRow; import org.apache.paimon.data.BinaryString; +import 
org.apache.paimon.data.BlobData; import org.apache.paimon.data.GenericRow; import org.apache.paimon.index.GlobalIndexMeta; import org.apache.paimon.index.IndexFileHandler; import org.apache.paimon.index.IndexFileMeta; +import org.apache.paimon.io.DataFileMeta; import org.apache.paimon.manifest.IndexManifestEntry; +import org.apache.paimon.manifest.ManifestEntry; import org.apache.paimon.memory.MemorySlice; import org.apache.paimon.partition.PartitionPredicate; import org.apache.paimon.predicate.Predicate; @@ -296,6 +299,53 @@ public void testIncrementalScanWithPartitionPredicate() throws Exception { "incrementalScan should only return the new rows in partition p0"); } + @Test + public void testScanFiltersBlobFilesByManifestEntryFilter() throws Exception { + Schema.Builder schemaBuilder = Schema.newBuilder(); + schemaBuilder.column("dt", DataTypes.STRING()); + schemaBuilder.column("f0", DataTypes.INT()); + schemaBuilder.column("f1", DataTypes.BLOB()); + schemaBuilder.option(CoreOptions.ROW_TRACKING_ENABLED.key(), "true"); + schemaBuilder.option(CoreOptions.DATA_EVOLUTION_ENABLED.key(), "true"); + schemaBuilder.option(CoreOptions.BLOB_TARGET_FILE_SIZE.key(), "1 b"); + schemaBuilder.partitionKeys(Collections.singletonList("dt")); + + catalog.createTable(identifier("BlobTable"), schemaBuilder.build(), false); + FileStoreTable table = getTable(identifier("BlobTable")); + + byte[] blobBytes = new byte[] {1}; + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = writeBuilder.newWrite()) { + for (int i = 0; i < 10; i++) { + write.write( + GenericRow.of(BinaryString.fromString("p0"), i, new BlobData(blobBytes))); + } + try (BatchTableCommit commit = writeBuilder.newCommit()) { + commit.commit(write.prepareCommit()); + } + } + + Assertions.assertTrue( + containsBlobFile(table.store().newScan().plan().files()), + "Test table should contain blob manifest entries."); + + BTreeGlobalIndexBuilder builder = new 
BTreeGlobalIndexBuilder(table).withIndexField("f0"); + assertNoBlobFiles( + builder.scan() + .map(Pair::getRight) + .orElseThrow( + () -> + new IllegalStateException( + "Expected scan result for blob table."))); + assertNoBlobFiles( + builder.incrementalScan() + .map(Pair::getRight) + .orElseThrow( + () -> + new IllegalStateException( + "Expected incremental scan result for blob table."))); + } + private Map>> gatherIndexMetas(FileStoreTable table) { IndexFileHandler handler = table.store().newIndexFileHandler(); @@ -319,6 +369,26 @@ private Map>> gatherIndexMetas(FileStore return metasByParts; } + private boolean containsBlobFile(List entries) { + for (ManifestEntry entry : entries) { + if ("blob".equals(entry.file().fileFormat())) { + return true; + } + } + return false; + } + + private void assertNoBlobFiles(List splits) { + for (DataSplit split : splits) { + for (DataFileMeta file : split.dataFiles()) { + Assertions.assertNotEquals( + "blob", + file.fileFormat(), + "BTree global index scan should not include blob files."); + } + } + } + private void assertFilesNonOverlapping( BinaryRow partition, List> metas) { if (metas.isEmpty()) { diff --git a/paimon-core/src/test/java/org/apache/paimon/index/IndexFileHandlerTest.java b/paimon-core/src/test/java/org/apache/paimon/index/IndexFileHandlerTest.java index 24972642c493..70d5881e3405 100644 --- a/paimon-core/src/test/java/org/apache/paimon/index/IndexFileHandlerTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/index/IndexFileHandlerTest.java @@ -19,25 +19,39 @@ package org.apache.paimon.index; import org.apache.paimon.CoreOptions; +import org.apache.paimon.Snapshot; +import org.apache.paimon.TestAppendFileStore; import org.apache.paimon.data.BinaryRow; import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.io.CompactIncrement; +import org.apache.paimon.io.DataIncrement; import org.apache.paimon.manifest.FileKind; 
import org.apache.paimon.manifest.IndexManifestEntry; import org.apache.paimon.options.MemorySize; +import org.apache.paimon.table.sink.CommitMessageImpl; import org.apache.paimon.types.RowType; import org.apache.paimon.utils.FileStorePathFactory; import org.apache.paimon.utils.IndexFilePathFactories; +import org.apache.paimon.utils.Pair; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.UUID; +import static org.apache.paimon.deletionvectors.DeletionVectorsIndexFile.DELETION_VECTORS_INDEX; +import static org.apache.paimon.index.HashIndexFile.HASH_INDEX; import static org.assertj.core.api.Assertions.assertThat; /** @@ -110,4 +124,52 @@ void testExistsAndDeleteIndexFile(String indexType) throws IOException { assertThat(handler.existsIndexFile(entry)).isFalse(); } + + @Test + void testScanBucketsOnlyReturnsRequestedBuckets() throws Exception { + TestAppendFileStore store = + TestAppendFileStore.createAppendStore(tempPath, new HashMap<>()); + Map> bucket0Dvs = new HashMap<>(); + bucket0Dvs.put("f0", Arrays.asList(1, 2)); + Map> bucket1Dvs = new HashMap<>(); + bucket1Dvs.put("f1", Collections.singletonList(3)); + IndexFileMeta hashIndex = + store.newIndexFileHandler() + .hashIndex(BinaryRow.EMPTY_ROW, 1) + .write(new int[] {1, 2, 3}); + store.commit( + store.writeDVIndexFiles(BinaryRow.EMPTY_ROW, 0, bucket0Dvs), + store.writeDVIndexFiles(BinaryRow.EMPTY_ROW, 1, bucket1Dvs), + new CommitMessageImpl( + BinaryRow.EMPTY_ROW, + 1, + 1, + DataIncrement.indexIncrement(Collections.singletonList(hashIndex)), + CompactIncrement.emptyIncrement())); + + Snapshot snapshot = store.snapshotManager().latestSnapshot(); + IndexFileHandler 
indexFileHandler = store.newIndexFileHandler(); + assertThat( + indexFileHandler.scanBuckets( + snapshot, DELETION_VECTORS_INDEX, Collections.emptySet())) + .isEmpty(); + + Map, List> scanned = + indexFileHandler.scanBuckets( + snapshot, + DELETION_VECTORS_INDEX, + Collections.singleton(Pair.of(BinaryRow.EMPTY_ROW, 1))); + + assertThat(scanned).containsOnlyKeys(Pair.of(BinaryRow.EMPTY_ROW, 1)); + assertThat(scanned.get(Pair.of(BinaryRow.EMPTY_ROW, 1))) + .extracting(IndexFileMeta::dvRanges) + .allSatisfy(dvRanges -> assertThat(dvRanges).containsOnlyKeys("f1")); + + assertThat( + indexFileHandler.scanBuckets( + snapshot, + HASH_INDEX, + Collections.singleton(Pair.of(BinaryRow.EMPTY_ROW, 1)))) + .containsOnlyKeys(Pair.of(BinaryRow.EMPTY_ROW, 1)); + } } diff --git a/paimon-core/src/test/java/org/apache/paimon/manifest/ManifestFileMetaTest.java b/paimon-core/src/test/java/org/apache/paimon/manifest/ManifestFileMetaTest.java index 36b0d15f114f..75a1ab0a84df 100644 --- a/paimon-core/src/test/java/org/apache/paimon/manifest/ManifestFileMetaTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/manifest/ManifestFileMetaTest.java @@ -18,16 +18,26 @@ package org.apache.paimon.manifest; +import org.apache.paimon.CoreOptions; import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.BinaryRowWriter; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.FileIOFinder; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.SeekableInputStream; import org.apache.paimon.fs.SeekableInputStreamWrapper; import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.io.DataFileMeta; import org.apache.paimon.operation.ManifestFileMerger; +import org.apache.paimon.options.Options; import org.apache.paimon.partition.PartitionPredicate; +import org.apache.paimon.schema.SchemaManager; +import org.apache.paimon.stats.StatsTestUtils; import org.apache.paimon.types.IntType; import 
org.apache.paimon.types.RowType; import org.apache.paimon.utils.FailingFileIO; +import org.apache.paimon.utils.FileStorePathFactory; import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; @@ -42,6 +52,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -84,9 +95,16 @@ public void testMergeWithoutFullCompaction(int numLastBits) { createData(numLastBits, input, expected); // no trigger Full Compaction + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "9223372036854775807B"); List actual = ManifestFileMerger.merge( - input, manifestFile, 500, 3, Long.MAX_VALUE, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertThat(actual).hasSameSizeAs(expected); // these two manifest files are merged from the input @@ -118,14 +136,16 @@ private void testCleanUp(List input, long fullCompactionThresh ManifestFile failingManifestFile = createManifestFile(FailingFileIO.getFailingPath(failingName, tempDir.toString())); try { + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set( + "manifest.full-compaction-threshold-size", fullCompactionThreshold + "B"); ManifestFileMerger.merge( input, failingManifestFile, - 500, - 3, - fullCompactionThreshold, getPartitionType(), - null); + CoreOptions.fromMap(testOptions.toMap())); } catch (Throwable e) { assertThat(e).hasRootCauseExactlyInstanceOf(FailingFileIO.ArtificialException.class); // old files should be kept untouched, while new files should be cleaned up @@ -156,9 +176,16 @@ public void testMerge() { // delta with delete apply partition 1,2 
addDeltaManifests(input, true); // trigger full compaction + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, 500, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); // 1st Manifest don't need to Merge assertSameContent(input.get(0), merged.get(0), manifestFile); @@ -173,9 +200,16 @@ public void testMergeWithoutDelta() { // base List input = createBaseManifestFileMetas(true); + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, 500, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertEquivalentEntries(input, merged); assertThat(merged).hasSameElementsAs(input); @@ -186,9 +220,16 @@ public void testMergeWithoutDelta() { ManifestFileMeta delta = makeManifest(makeEntry(true, "A", 1), makeEntry(false, "A", 1)); input1.add(delta); + Options testOptions1 = new Options(); + testOptions1.set("manifest.target-file-size", "500B"); + testOptions1.set("manifest.merge-min-count", "3"); + testOptions1.set("manifest.full-compaction-threshold-size", "200B"); List merged1 = ManifestFileMerger.merge( - input1, manifestFile, 500, 3, 200, getPartitionType(), null); + input1, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions1.toMap())); assertThat(base).hasSameElementsAs(merged1); assertEquivalentEntries(input1, merged1); @@ -198,9 +239,16 @@ public void testMergeWithoutDelta() { public void testMergeWithoutBase() { List input = new ArrayList<>(); 
addDeltaManifests(input, true); + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, 500, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertEquivalentEntries(input, merged); } @@ -225,9 +273,16 @@ public void testMergeWithoutDeleteFile() { input.add(makeManifest(makeEntry(true, "F"))); input.add(makeManifest(makeEntry(true, "G"))); + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, 500, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertEquivalentEntries(input, merged); } @@ -489,9 +544,16 @@ public void testMergeFullCompactionWithoutDeleteFile() { input.add(makeManifest(makeEntry(true, "F"))); input.add(makeManifest(makeEntry(true, "G"))); + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", threshold + "B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, threshold, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertEquivalentEntries( input.stream() .filter(f -> !baseFiles.contains(f.fileName())) @@ -819,4 +881,621 @@ private void beforeFirstRead() throws IOException { } } } + + // ==================== Manifest Sort Tests ==================== + + /** + * Test manifest sort with overlapping partition ranges. 
Each manifest contains entries spanning + * multiple partitions, creating overlapping intervals that require sort rewrite to resolve. + * After sort rewrite, all surviving ADD entries should be sorted by partition field. + */ + @Test + public void testManifestSortWithOverlappingPartitions() { + List input = new ArrayList<>(); + + // manifest-A: partitions [5, 13] + List entriesA = new ArrayList<>(); + for (int p = 5; p <= 13; p++) { + entriesA.add(makeEntry(true, String.format("A-p%d", p), p)); + } + input.add(makeManifest(entriesA.toArray(new ManifestEntry[0]))); + + // manifest-B: partitions [0, 9] + List entriesB = new ArrayList<>(); + for (int p = 0; p <= 9; p++) { + entriesB.add(makeEntry(true, String.format("B-p%d", p), p)); + } + input.add(makeManifest(entriesB.toArray(new ManifestEntry[0]))); + + // manifest-C: partitions [3, 7] -- overlaps with A and B + List entriesC = new ArrayList<>(); + for (int p = 3; p <= 7; p++) { + entriesC.add(makeEntry(true, String.format("C-p%d", p), p)); + } + input.add(makeManifest(entriesC.toArray(new ManifestEntry[0]))); + + // manifest-D: partitions [8, 12] -- overlaps with A + List entriesD = new ArrayList<>(); + for (int p = 8; p <= 12; p++) { + entriesD.add(makeEntry(true, String.format("D-p%d", p), p)); + } + input.add(makeManifest(entriesD.toArray(new ManifestEntry[0]))); + + // manifest-E: partitions [1, 6] -- overlaps with B and C + List entriesE = new ArrayList<>(); + for (int p = 1; p <= 6; p++) { + entriesE.add(makeEntry(true, String.format("E-p%d", p), p)); + } + input.add(makeManifest(entriesE.toArray(new ManifestEntry[0]))); + + // manifest-F: partitions [4, 14] -- overlaps with D + List entriesF = new ArrayList<>(); + for (int p = 4; p <= 14; p++) { + entriesF.add(makeEntry(true, String.format("F-p%d", p), p)); + } + input.add(makeManifest(entriesF.toArray(new ManifestEntry[0]))); + + Options testOptions = new Options(); + testOptions.set("manifest-sort.enabled", "true"); + List merged = + 
ManifestFileMerger.merge( + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); + + // Verify entries are equivalent (no data loss) + assertEquivalentEntries(input, merged); + + // Verify all entries within each output manifest are sorted by partition + for (ManifestFileMeta meta : merged) { + List entries = manifestFile.read(meta.fileName(), meta.fileSize()); + for (int i = 1; i < entries.size(); i++) { + int prevPartition = entries.get(i - 1).partition().getInt(0); + int currPartition = entries.get(i).partition().getInt(0); + assertThat(currPartition) + .as("Entries within a manifest should be sorted by partition") + .isGreaterThanOrEqualTo(prevPartition); + } + } + } + + /** + * Test that sort rewrite correctly eliminates DELETE entries and their corresponding ADD + * entries. The key condition is that totalDeltaFileSize must reach manifestFullCompactionSize + * to trigger the full compaction path inside trySortRewrite, which reads deleteEntries and + * passes them to sortAndRewriteSection for elimination. + * + *

    Design: + * + *

    +     *   - Base manifests with overlapping partitions (all ADD, small enough to be "mustChange"
    +     *     since fileSize < suggestedMetaSize):
    +     *     manifest-A: partitions [0, 4] with entries A-p0..A-p4
    +     *     manifest-B: partitions [2, 6] with entries B-p2..B-p6 (overlaps A)
    +     *     manifest-C: partitions [5, 9] with entries C-p5..C-p9 (overlaps B)
    +     *   - Delta manifests with DELETE entries (cancel some ADD entries):
    +     *     manifest-D: DELETE A-p2, DELETE B-p4, ADD new-p2, ADD new-p4
    +     *     manifest-E: DELETE C-p7, ADD new-p7
    +     *   - After sort rewrite: A-p2, B-p4, C-p7 should be eliminated,
    +     *     replaced by new-p2, new-p4, new-p7. Output should only contain ADD entries,
    +     *     sorted by partition.
    +     * 
    + */ + @Test + public void testManifestSortEliminatesDeleteEntries() { + List input = new ArrayList<>(); + + // manifest-A: partitions [0, 4] + List entriesA = new ArrayList<>(); + for (int p = 0; p <= 4; p++) { + entriesA.add(makeEntry(true, String.format("A-p%d", p), p)); + } + input.add(makeManifest(entriesA.toArray(new ManifestEntry[0]))); + + // manifest-B: partitions [2, 6] -- overlaps A + List entriesB = new ArrayList<>(); + for (int p = 2; p <= 6; p++) { + entriesB.add(makeEntry(true, String.format("B-p%d", p), p)); + } + input.add(makeManifest(entriesB.toArray(new ManifestEntry[0]))); + + // manifest-C: partitions [5, 9] -- overlaps B + List entriesC = new ArrayList<>(); + for (int p = 5; p <= 9; p++) { + entriesC.add(makeEntry(true, String.format("C-p%d", p), p)); + } + input.add(makeManifest(entriesC.toArray(new ManifestEntry[0]))); + + // manifest-D: DELETE A-p2, DELETE B-p4, ADD new-p2, ADD new-p4 + input.add( + makeManifest( + makeEntry(false, "A-p2", 2), + makeEntry(false, "B-p4", 4), + makeEntry(true, "new-p2", 2), + makeEntry(true, "new-p4", 4))); + + // manifest-E: DELETE C-p7, ADD new-p7 + input.add(makeManifest(makeEntry(false, "C-p7", 7), makeEntry(true, "new-p7", 7))); + + Options testOptions = new Options(); + testOptions.set("manifest-sort.enabled", "true"); + testOptions.set("manifest.full-compaction-threshold-size", "10B"); + + List merged = + ManifestFileMerger.merge( + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); + + // Collect all output entries + List allOutputEntries = new ArrayList<>(); + for (ManifestFileMeta meta : merged) { + allOutputEntries.addAll(manifestFile.read(meta.fileName(), meta.fileSize())); + } + + // Verify: no DELETE entries in output (all DELETE pairs eliminated) + long deleteCount = + allOutputEntries.stream().filter(e -> e.kind() == FileKind.DELETE).count(); + assertThat(deleteCount).as("Sort rewrite should eliminate all DELETE entries").isEqualTo(0); + + // Verify: 
the deleted ADD entries (A-p2, B-p4, C-p7) are NOT in output + Set outputFileNames = + allOutputEntries.stream().map(e -> e.file().fileName()).collect(Collectors.toSet()); + assertThat(outputFileNames).doesNotContain("A-p2", "B-p4", "C-p7"); + + // Verify: the replacement entries (new-p2, new-p4, new-p7) ARE in output + assertThat(outputFileNames).contains("new-p2", "new-p4", "new-p7"); + + // Verify: all surviving entries match what FileEntry.mergeEntries would produce + assertEquivalentEntries(input, merged); + + // Verify entries within each output manifest are sorted by partition + for (ManifestFileMeta meta : merged) { + List entries = manifestFile.read(meta.fileName(), meta.fileSize()); + for (int i = 1; i < entries.size(); i++) { + int prevPartition = entries.get(i - 1).partition().getInt(0); + int currPartition = entries.get(i).partition().getInt(0); + assertThat(currPartition) + .as("Entries within manifest should be sorted by partition") + .isGreaterThanOrEqualTo(prevPartition); + } + } + } + + /** + * Test manifest sort with a multi-field partition type. + * + *

    Setup: partition=(region INT, dt INT, hour INT), sort by dt (field index=1). 9 manifest + * files form 6 overlapping sorted runs by dt range: + * + *

    +     *   Run1: 3 files, dt=[0,15],[3,5],[6,8]
    +     *   Run2: 2 files, dt=[1,8],[5,7]
    +     *   Run3: 1 file,  dt=[0,9]
    +     *   Run4: 1 file,  dt=[5,14]
    +     *   Run5: 1 file,  dt=[8,15]
    +     *   Run6: 1 file,  dt=[4,12]
    +     * 
    + * + *

    Verifies: 1) no data loss after sort-rewrite, 2) entries within each output manifest are + * sorted by dt. + */ + @Test + public void testManifestSortWithMultiplePartitions() { + // Use a 3-field partition type: (region INT, dt INT, hour INT) + RowType multiPartitionType = RowType.of(new IntType(), new IntType(), new IntType()); + + // Create a dedicated ManifestFile for the 3-field partition type + Path path = new Path(tempDir.toString()); + FileIO fileIO = FileIOFinder.find(path); + ManifestFile multiPartManifestFile = + new ManifestFile.Factory( + fileIO, + new SchemaManager(fileIO, path), + multiPartitionType, + avro, + "zstd", + new FileStorePathFactory( + path, + multiPartitionType, + "default", + CoreOptions.FILE_FORMAT.defaultValue(), + CoreOptions.DATA_FILE_PREFIX.defaultValue(), + CoreOptions.CHANGELOG_FILE_PREFIX.defaultValue(), + CoreOptions.PARTITION_GENERATE_LEGACY_NAME.defaultValue(), + CoreOptions.FILE_SUFFIX_INCLUDE_COMPRESSION.defaultValue(), + CoreOptions.FILE_COMPRESSION.defaultValue(), + null, + null, + CoreOptions.ExternalPathStrategy.NONE, + null, + false, + null), + Long.MAX_VALUE, + null) + .create(); + + List input = new ArrayList<>(); + + // Run1 + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r1a-p0", 10, 0, 1), + makeMultiPartEntry(true, "r1a-p1", 20, 1, 2), + makeMultiPartEntry(true, "r1a-p2", 30, 15, 3))) + .get(0)); + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r1b-p3", 10, 3, 4), + makeMultiPartEntry(true, "r1b-p4", 20, 4, 5), + makeMultiPartEntry(true, "r1b-p5", 30, 5, 6))) + .get(0)); + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r1c-p6", 10, 6, 7), + makeMultiPartEntry(true, "r1c-p7", 20, 7, 8), + makeMultiPartEntry(true, "r1c-p8", 30, 8, 9))) + .get(0)); + + // Run2 + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r2a-p1", 5, 1, 10), + 
makeMultiPartEntry(true, "r2a-p2", 15, 2, 11), + makeMultiPartEntry(true, "r2a-p3", 25, 3, 12), + makeMultiPartEntry(true, "r2a-p4", 35, 8, 13))) + .get(0)); + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r2b-p5", 5, 5, 14), + makeMultiPartEntry(true, "r2b-p6", 15, 6, 15), + makeMultiPartEntry(true, "r2b-p7", 25, 7, 16))) + .get(0)); + + // Run3 + List run3Entries = new ArrayList<>(); + for (int p = 0; p <= 9; p++) { + run3Entries.add(makeMultiPartEntry(true, String.format("r3-p%d", p), 99, p, p + 20)); + } + input.add(multiPartManifestFile.write(run3Entries).get(0)); + + // Run4 + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r4a-p10", 10, 5, 30), + makeMultiPartEntry(true, "r4a-p11", 20, 11, 31), + makeMultiPartEntry(true, "r4a-p12", 30, 12, 32), + makeMultiPartEntry(true, "r4a-p13", 40, 13, 33), + makeMultiPartEntry(true, "r4a-p14", 50, 14, 34))) + .get(0)); + + // Run5 + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r5a-p11", 11, 8, 40), + makeMultiPartEntry(true, "r5a-p12", 21, 12, 41), + makeMultiPartEntry(true, "r5a-p13", 31, 13, 42), + makeMultiPartEntry(true, "r5a-p14", 41, 14, 43), + makeMultiPartEntry(true, "r5a-p15", 51, 15, 44))) + .get(0)); + + // Run6 + input.add( + multiPartManifestFile + .write( + Arrays.asList( + makeMultiPartEntry(true, "r6a-p7", 7, 4, 50), + makeMultiPartEntry(true, "r6a-p8", 17, 8, 51), + makeMultiPartEntry(true, "r6a-p9", 27, 9, 52), + makeMultiPartEntry(true, "r6a-p10", 37, 10, 53), + makeMultiPartEntry(true, "r6a-p11", 47, 11, 54), + makeMultiPartEntry(true, "r6a-p12", 57, 12, 55))) + .get(0)); + + Options testOptions = new Options(); + testOptions.set("manifest-sort.enabled", "true"); + // Sort by the second partition field "f1" (dt) + testOptions.set("manifest-sort.partition-field", "f1"); + List merged = + ManifestFileMerger.merge( + input, + multiPartManifestFile, + multiPartitionType, + 
CoreOptions.fromMap(testOptions.toMap())); + + // Verify no data loss + List inputEntries = + input.stream() + .flatMap( + f -> + multiPartManifestFile.read(f.fileName(), f.fileSize()) + .stream()) + .collect(Collectors.toList()); + List entryBeforeMerge = + FileEntry.mergeEntries(inputEntries).stream() + .filter(entry -> entry.kind() == FileKind.ADD) + .map(entry -> entry.kind() + "-" + entry.file().fileName()) + .collect(Collectors.toList()); + List entryAfterMerge = new ArrayList<>(); + for (ManifestFileMeta meta : merged) { + for (ManifestEntry entry : + multiPartManifestFile.read(meta.fileName(), meta.fileSize())) { + entryAfterMerge.add(entry.kind() + "-" + entry.file().fileName()); + } + } + assertThat(entryBeforeMerge).hasSameElementsAs(entryAfterMerge); + + // Verify entries within each output manifest are sorted by the second field (dt) + for (ManifestFileMeta meta : merged) { + List entries = + multiPartManifestFile.read(meta.fileName(), meta.fileSize()); + for (int i = 1; i < entries.size(); i++) { + int prevDt = entries.get(i - 1).partition().getInt(1); + int currDt = entries.get(i).partition().getInt(1); + assertThat(currDt) + .as("Entries within manifest should be sorted by partition") + .isGreaterThanOrEqualTo(prevDt); + } + } + } + + /** + * Test that when manifest-sort.max-rewrite-size budget is exceeded in the middle of a section, + * the remaining files are appended to the tail and the final manifest order is preserved. + * + *

    Design: + * + *

    +     *   - Create a large section with overlapping partition ranges that exceeds the budget
    +     *   - Set a small manifest-sort.max-rewrite-size to force budget split
    +     *   - Verify that after merge, all manifests are globally sorted by partition field
    +     *   - Verify that entries are equivalent (no data loss)
    +     * 
    + */ + @Test + public void testManifestSortBudgetSplitPreservesOrder() { + // Create manifests with overlapping ranges, large enough to exceed budget + List input = new ArrayList<>(); + + // Manifest A: partitions [0, 10] - large size + List entriesA = new ArrayList<>(); + for (int p = 0; p <= 10; p++) { + entriesA.add(makeEntry(true, String.format("A-p%d", p), p)); + } + ManifestFileMeta manifestA = makeManifest(entriesA.toArray(new ManifestEntry[0])); + // Manually increase file size to simulate large manifest + input.add( + new ManifestFileMeta( + manifestA.fileName(), + 100, + manifestA.numAddedFiles(), + manifestA.numDeletedFiles(), + manifestA.partitionStats(), + manifestA.schemaId(), + manifestA.minBucket(), + manifestA.maxBucket(), + manifestA.minLevel(), + manifestA.maxLevel(), + manifestA.minRowId(), + manifestA.maxRowId())); + + // Manifest B: partitions [5, 15] - overlaps with A + List entriesB = new ArrayList<>(); + for (int p = 5; p <= 15; p++) { + entriesB.add(makeEntry(true, String.format("B-p%d", p), p)); + } + ManifestFileMeta manifestB = makeManifest(entriesB.toArray(new ManifestEntry[0])); + input.add( + new ManifestFileMeta( + manifestB.fileName(), + 100, + manifestB.numAddedFiles(), + manifestB.numDeletedFiles(), + manifestB.partitionStats(), + manifestB.schemaId(), + manifestB.minBucket(), + manifestB.maxBucket(), + manifestB.minLevel(), + manifestB.maxLevel(), + manifestB.minRowId(), + manifestB.maxRowId())); + + // Manifest C: partitions [10, 20] - overlaps with B + List entriesC = new ArrayList<>(); + for (int p = 10; p <= 20; p++) { + entriesC.add(makeEntry(true, String.format("C-p%d", p), p)); + } + ManifestFileMeta manifestC = makeManifest(entriesC.toArray(new ManifestEntry[0])); + input.add( + new ManifestFileMeta( + manifestC.fileName(), + 100, + manifestC.numAddedFiles(), + manifestC.numDeletedFiles(), + manifestC.partitionStats(), + manifestC.schemaId(), + manifestC.minBucket(), + manifestC.maxBucket(), + manifestC.minLevel(), + 
manifestC.maxLevel(), + manifestC.minRowId(), + manifestC.maxRowId())); + + // Set small budget to force split + Options testOptions = new Options(); + testOptions.set("manifest-sort.enabled", "true"); + testOptions.set("manifest-sort.max-rewrite-size", "150B"); // Total input size is 300B + + List merged = + ManifestFileMerger.merge( + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); + + // Verify entries are equivalent + assertEquivalentEntries(input, merged); + + // Verify global ordering: all manifests sorted by partition min value + for (int i = 1; i < merged.size(); i++) { + BinaryRow prevMin = merged.get(i - 1).partitionStats().minValues(); + BinaryRow currMin = merged.get(i).partitionStats().minValues(); + assertThat(currMin.getInt(0)) + .as("Manifests should be globally sorted by partition field") + .isGreaterThanOrEqualTo(prevMin.getInt(0)); + } + + // Verify entries within each manifest are sorted + for (ManifestFileMeta meta : merged) { + List entries = manifestFile.read(meta.fileName(), meta.fileSize()); + for (int i = 1; i < entries.size(); i++) { + int prevPartition = entries.get(i - 1).partition().getInt(0); + int currPartition = entries.get(i).partition().getInt(0); + assertThat(currPartition) + .as("Entries within manifest should be sorted by partition") + .isGreaterThanOrEqualTo(prevPartition); + } + } + } + + /** + * Test boundary equality (min == previous.max) handling in both SortedRun construction and + * Section splitting. Boundary-touching files should be allowed in the same SortedRun but may be + * separated into different Sections. + * + *

    Design: + * + *

    +     *   - Create manifests with boundary-touching partition ranges
    +     *   - Manifest A: [0, 5]
    +     *   - Manifest B: [5, 10] (min == A.max, boundary touching)
    +     *   - Manifest C: [10, 15] (min == B.max, boundary touching)
    +     *   - Verify they can be in the same SortedRun (>= comparison)
    +     *   - Verify they may be split into different Sections (>= comparison with comment)
    +     * 
    + */ + @Test + public void testBoundaryEqualityHandling() { + List input = new ArrayList<>(); + + // Manifest A: partitions [0, 5] + List entriesA = new ArrayList<>(); + for (int p = 0; p <= 5; p++) { + entriesA.add(makeEntry(true, String.format("A-p%d", p), p)); + } + input.add(makeManifest(entriesA.toArray(new ManifestEntry[0]))); + + // Manifest B: partitions [5, 10] - boundary touches A (min == A.max) + List entriesB = new ArrayList<>(); + for (int p = 5; p <= 10; p++) { + entriesB.add(makeEntry(true, String.format("B-p%d", p), p)); + } + input.add(makeManifest(entriesB.toArray(new ManifestEntry[0]))); + + // Manifest C: partitions [10, 15] - boundary touches B (min == B.max) + List entriesC = new ArrayList<>(); + for (int p = 10; p <= 15; p++) { + entriesC.add(makeEntry(true, String.format("C-p%d", p), p)); + } + input.add(makeManifest(entriesC.toArray(new ManifestEntry[0]))); + + Options testOptions = new Options(); + testOptions.set("manifest-sort.enabled", "true"); + + List merged = + ManifestFileMerger.merge( + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); + + // Verify entries are equivalent + assertEquivalentEntries(input, merged); + + // Verify all manifests maintain global sort order + for (int i = 1; i < merged.size(); i++) { + BinaryRow prevMin = merged.get(i - 1).partitionStats().minValues(); + BinaryRow prevMax = merged.get(i - 1).partitionStats().maxValues(); + BinaryRow currMin = merged.get(i).partitionStats().minValues(); + + // Boundary-touching is allowed: currMin >= prevMin + assertThat(currMin.getInt(0)) + .as("Global order should be maintained with boundary-touching allowed") + .isGreaterThanOrEqualTo(prevMin.getInt(0)); + + // Log boundary equality cases for documentation + if (currMin.getInt(0) == prevMax.getInt(0)) { + System.out.println( + String.format( + "Boundary equality detected: manifest[%d].min=%d == manifest[%d].max=%d", + i, currMin.getInt(0), i - 1, prevMax.getInt(0))); + } + } + 
+ // Verify entries within each manifest are sorted + for (ManifestFileMeta meta : merged) { + List entries = manifestFile.read(meta.fileName(), meta.fileSize()); + for (int i = 1; i < entries.size(); i++) { + int prevPartition = entries.get(i - 1).partition().getInt(0); + int currPartition = entries.get(i).partition().getInt(0); + assertThat(currPartition) + .as("Entries within manifest should be sorted by partition") + .isGreaterThanOrEqualTo(prevPartition); + } + } + } + + /** Create a ManifestEntry with a 3-field partition row (region, dt, hour). */ + private ManifestEntry makeMultiPartEntry( + boolean isAdd, String fileName, int region, int dt, int hour) { + BinaryRow binaryRow = new BinaryRow(3); + BinaryRowWriter writer = new BinaryRowWriter(binaryRow); + writer.writeInt(0, region); + writer.writeInt(1, dt); + writer.writeInt(2, hour); + writer.complete(); + + return ManifestEntry.create( + isAdd ? FileKind.ADD : FileKind.DELETE, + binaryRow, + 0, + 0, + DataFileMeta.create( + fileName, + 0, + 0, + binaryRow, + binaryRow, + StatsTestUtils.newEmptySimpleStats(), + StatsTestUtils.newEmptySimpleStats(), + 0, + 0, + 0, + 0, + Collections.emptyList(), + Timestamp.fromEpochMillis(200000), + 0L, + null, + FileSource.APPEND, + null, + null, + null, + null)); + } } diff --git a/paimon-core/src/test/java/org/apache/paimon/manifest/NoPartitionManifestFileMetaTest.java b/paimon-core/src/test/java/org/apache/paimon/manifest/NoPartitionManifestFileMetaTest.java index 591b3206518d..66465f1e7531 100644 --- a/paimon-core/src/test/java/org/apache/paimon/manifest/NoPartitionManifestFileMetaTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/manifest/NoPartitionManifestFileMetaTest.java @@ -18,7 +18,9 @@ package org.apache.paimon.manifest; +import org.apache.paimon.CoreOptions; import org.apache.paimon.operation.ManifestFileMerger; +import org.apache.paimon.options.Options; import org.apache.paimon.types.RowType; import org.junit.jupiter.api.BeforeEach; @@ -49,9 +51,16 
@@ public void testMerge() { List input = createBaseManifestFileMetas(false); addDeltaManifests(input, false); + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", "500B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, 500, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertEquivalentEntries(input, merged); // the first one is not deleted, it should not be merged @@ -89,9 +98,16 @@ public void testMergeFullCompactionWithoutDeleteFile() { input.add(makeManifest(makeEntry(true, "F", null))); input.add(makeManifest(makeEntry(true, "G", null))); + Options testOptions = new Options(); + testOptions.set("manifest.target-file-size", threshold + "B"); + testOptions.set("manifest.merge-min-count", "3"); + testOptions.set("manifest.full-compaction-threshold-size", "200B"); List merged = ManifestFileMerger.merge( - input, manifestFile, threshold, 3, 200, getPartitionType(), null); + input, + manifestFile, + getPartitionType(), + CoreOptions.fromMap(testOptions.toMap())); assertEquivalentEntries( input.stream() .filter(f -> !baseFiles.contains(f.fileName())) diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java index f4f2d28f7578..050f4b855b39 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java @@ -887,6 +887,69 @@ public void testAggregationWithoutSequenceGroup() { "Must use sequence group for aggregation functions but not found for field f1.")); } + @Test + public void testSequenceGroupCannotContainPrimaryKey() 
{ + // Issue #7052: Putting a primary key column in sequence-group should be forbidden + // as it causes Parquet decoding failures during compaction + Options options = new Options(); + options.set("fields.f0.sequence-group", "f1,f2"); + RowType rowType = + RowType.of(DataTypes.INT(), DataTypes.INT(), DataTypes.INT(), DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f0"))) + .hasMessageContaining( + "The sequence-group 'fields.f0.sequence-group' contains primary key field 'f0', " + + "which is not allowed. Primary key columns cannot be put in sequence-group."); + } + + @Test + public void testMultiSequenceFieldsCannotContainPrimaryKey() { + // Issue #7052: Multi-field sequence-group also cannot contain primary key columns + // The sequence fields (f2,f3) are the "self" part, they must not contain PKs + Options options = new Options(); + options.set("fields.f2,f3.sequence-group", "f0,f4"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f2"))) + .hasMessageContaining( + "The sequence-group 'fields.f2,f3.sequence-group' contains primary key field 'f2', " + + "which is not allowed. 
Primary key columns cannot be put in sequence-group."); + } + + @Test + public void testPrimaryKeyCannotBeInSequenceGroupValue() { + // Issue #7052: A primary key column appearing in the value part of sequence-group + // is forbidden — f2 is the PK and appears in the sequence-group's value list + Options options = new Options(); + options.set("fields.f4.sequence-group", "f1,f2"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f2"))) + .hasMessageContaining( + "The sequence-group 'fields.f4.sequence-group' contains primary key field 'f2', " + + "which is not allowed. Primary key columns cannot be put in sequence-group."); + } + @Test public void testDeleteReproduceCorrectSequenceNumber() { Options options = new Options(); diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeSnapshotOrderingTest.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeSnapshotOrderingTest.java new file mode 100644 index 000000000000..0e29efd17102 --- /dev/null +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeSnapshotOrderingTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.mergetree.compact; + +import org.apache.paimon.CoreOptions.SortEngine; +import org.apache.paimon.KeyValue; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowKind; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for snapshot-ordering in sort-merge readers. With {@code sequence.snapshot-ordering}, the + * commit snapshot id is carried in each record's {@code sequenceNumber} (stamped at read time for + * APPEND files), so the sort-merge readers need no snapshot-specific branch: comparing by {@code + * sequenceNumber} already makes records from later snapshots win. + */ +public class SortMergeSnapshotOrderingTest { + + private static final Comparator KEY_COMPARATOR = + (a, b) -> Integer.compare(a.getInt(0), b.getInt(0)); + + private static final RowType VALUE_TYPE = RowType.of(DataTypes.INT()); + + @ParameterizedTest + @EnumSource(SortEngine.class) + public void testLaterSnapshotWins(SortEngine sortEngine) throws IOException { + // seq carries the snapshot id: snapshot 6 wins over snapshot 5. 
+ KeyValue winner = merge(sortEngine, kv(1, 5, 999), kv(1, 6, 1)); + assertThat(winner.value().getInt(0)).isEqualTo(1); + assertThat(winner.sequenceNumber()).isEqualTo(6); + } + + @ParameterizedTest + @EnumSource(SortEngine.class) + public void testHigherSequenceWins(SortEngine sortEngine) throws IOException { + KeyValue winner = merge(sortEngine, kv(1, 100, 100), kv(1, 50, 50)); + assertThat(winner.value().getInt(0)).isEqualTo(100); + } + + private static KeyValue kv(int key, long seq, int value) { + return new KeyValue() + .replace(GenericRow.of(key), seq, RowKind.INSERT, GenericRow.of(value)); + } + + private static KeyValue merge(SortEngine sortEngine, KeyValue... kvs) throws IOException { + List> readers = new ArrayList<>(); + for (KeyValue kv : kvs) { + readers.add(new SingleKvReader(kv)); + } + + MergeFunctionWrapper wrapper = + new ReducerMergeFunctionWrapper(DeduplicateMergeFunction.factory().create()); + + RecordReader reader = + SortMergeReader.createSortMergeReader( + readers, KEY_COMPARATOR, null, wrapper, sortEngine); + + RecordReader.RecordIterator batch = reader.readBatch(); + assertThat(batch).isNotNull(); + KeyValue result = batch.next(); + assertThat(result).isNotNull(); + assertThat(batch.next()).isNull(); + batch.releaseBatch(); + reader.close(); + return result; + } + + private static class SingleKvReader implements RecordReader { + private KeyValue kv; + + SingleKvReader(KeyValue kv) { + this.kv = kv; + } + + @Nullable + @Override + public RecordIterator readBatch() { + if (kv == null) { + return null; + } + KeyValue toReturn = kv; + kv = null; + return new RecordIterator() { + private boolean returned = false; + + @Nullable + @Override + public KeyValue next() { + if (returned) { + return null; + } + returned = true; + return toReturn; + } + + @Override + public void releaseBatch() {} + }; + } + + @Override + public void close() {} + } +} diff --git 
a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java index 1bdddcb846e1..a6a2794c36cf 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java @@ -749,6 +749,252 @@ public void testFieldNestedAppendAggWithCountLimit() { .containsExactlyInAnyOrderElementsOf(Arrays.asList(row(0, 1, "B"), row(0, 1, "b"))); } + @Test + public void testFieldNestedUpdateAggWithSequenceField() { + DataType elementRowType = + DataTypes.ROW( + DataTypes.FIELD(0, "k0", DataTypes.INT()), + DataTypes.FIELD(1, "k1", DataTypes.INT()), + DataTypes.FIELD(2, "v", DataTypes.STRING()), + DataTypes.FIELD(3, "seq", DataTypes.INT())); + FieldNestedUpdateAgg agg = + new FieldNestedUpdateAgg( + FieldNestedUpdateAggFactory.NAME, + DataTypes.ARRAY(elementRowType), + Arrays.asList("k0", "k1"), + Collections.singletonList("seq"), + Integer.MAX_VALUE); + + InternalArray accumulator; + InternalArray.ElementGetter elementGetter = + InternalArray.createElementGetter(elementRowType); + + InternalRow current = row(0, 0, "A", 1); + accumulator = (InternalArray) agg.agg(null, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf(Collections.singletonList(current)); + + current = row(0, 1, "B", 2); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList(row(0, 0, "A", 1), row(0, 1, "B", 2))); + + current = row(0, 1, "b", 3); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList(row(0, 0, "A", 1), row(0, 1, "b", 
3))); + + current = row(0, 1, "B_late", 2); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList(row(0, 0, "A", 1), row(0, 1, "b", 3))); + + current = row(0, 1, "b", 3); + accumulator = (InternalArray) agg.retract(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf(Collections.singletonList(row(0, 0, "A", 1))); + } + + @Test + public void testFieldNestedUpdateAggWithMultipleSequenceFields() { + DataType elementRowType = + DataTypes.ROW( + DataTypes.FIELD(0, "k0", DataTypes.INT()), + DataTypes.FIELD(1, "k1", DataTypes.INT()), + DataTypes.FIELD(2, "v", DataTypes.STRING()), + DataTypes.FIELD(3, "seq", DataTypes.INT()), + DataTypes.FIELD(4, "ts", DataTypes.TIMESTAMP(3))); + + FieldNestedUpdateAgg agg = + new FieldNestedUpdateAgg( + FieldNestedUpdateAggFactory.NAME, + DataTypes.ARRAY(elementRowType), + Arrays.asList("k0", "k1"), + Arrays.asList("seq", "ts"), + Integer.MAX_VALUE); + + InternalArray accumulator = null; + InternalArray.ElementGetter elementGetter = + InternalArray.createElementGetter(elementRowType); + + org.apache.paimon.data.Timestamp ts1 = + org.apache.paimon.data.Timestamp.fromEpochMillis(1000L); + org.apache.paimon.data.Timestamp ts2 = + org.apache.paimon.data.Timestamp.fromEpochMillis(2000L); + org.apache.paimon.data.Timestamp ts3 = + org.apache.paimon.data.Timestamp.fromEpochMillis(3000L); + + InternalRow current = row(1, 0, "A", 1, ts2); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + current = row(0, 1, "B", 2, ts1); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList(row(1, 0, "A", 1, ts2), row(0, 1, "B", 2, ts1))); + + current = row(1, 1, "C", 1, ts2); + accumulator = 
(InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A", 1, ts2), + row(0, 1, "B", 2, ts1), + row(1, 1, "C", 1, ts2))); + + current = row(1, 0, "A_late_updated_by_ts", 1, ts1); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A", 1, ts2), + row(0, 1, "B", 2, ts1), + row(1, 1, "C", 1, ts2))); + + current = row(1, 0, "A_updated_by_ts", 1, ts3); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A_updated_by_ts", 1, ts3), + row(0, 1, "B", 2, ts1), + row(1, 1, "C", 1, ts2))); + + // Try to update with a smaller 1st seq, even if the 2nd seq (ts) is larger + // Result: Should be IGNORED because the 1st seq field (1 < 2) takes higher priority. + current = row(0, 1, "b_ignored", 1, ts3); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A_updated_by_ts", 1, ts3), + row(0, 1, "B", 2, ts1), + row(1, 1, "C", 1, ts2))); + + // Update with the SAME 1st seq, but a larger 2nd seq (ts) + // Result: Should be SUCCESSFULLY UPDATED because seq (2 == 2) and ts (ts2 > ts1). 
+ current = row(0, 1, "B_updated_by_ts", 2, ts2); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A_updated_by_ts", 1, ts3), + row(0, 1, "B_updated_by_ts", 2, ts2), + row(1, 1, "C", 1, ts2))); + + // Update with a larger 1st seq, even if the 2nd seq (ts) is smaller + // Result: Should be SUCCESSFULLY UPDATED because the 1st seq field (3 > 2) wins. + current = row(0, 1, "B_updated_by_seq", 3, ts1); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A_updated_by_ts", 1, ts3), + row(0, 1, "B_updated_by_seq", 3, ts1), + row(1, 1, "C", 1, ts2))); + + // Retract the latest row matching the current state + current = row(0, 1, "B_updated_by_seq", 3, ts1); + accumulator = (InternalArray) agg.retract(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList( + row(1, 0, "A_updated_by_ts", 1, ts3), row(1, 1, "C", 1, ts2))); + } + + @Test + public void testFieldNestedUpdateAggWithSequenceFieldWithoutNestedKey() { + DataType elementRowType = + DataTypes.ROW( + DataTypes.FIELD(0, "k0", DataTypes.INT()), + DataTypes.FIELD(1, "k1", DataTypes.INT()), + DataTypes.FIELD(2, "v", DataTypes.STRING()), + DataTypes.FIELD(3, "seq", DataTypes.INT())); + + org.assertj.core.api.Assertions.assertThatThrownBy( + () -> + new FieldNestedUpdateAgg( + FieldNestedUpdateAggFactory.NAME, + DataTypes.ARRAY(elementRowType), + Collections.emptyList(), + Collections.singletonList("seq"), + Integer.MAX_VALUE)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("nested-sequence-field requires nested-key to be set."); + } + + @Test + public void testFieldNestedUpdateAggWithCountLimitWithSequenceFieldWithoutNestedKey() { + 
DataType elementRowType = + DataTypes.ROW( + DataTypes.FIELD(0, "k0", DataTypes.INT()), + DataTypes.FIELD(1, "k1", DataTypes.INT()), + DataTypes.FIELD(2, "v", DataTypes.STRING()), + DataTypes.FIELD(3, "seq", DataTypes.INT())); + + // Verify that the same precondition check applies even when a count limit is specified + org.assertj.core.api.Assertions.assertThatThrownBy( + () -> + new FieldNestedUpdateAgg( + FieldNestedUpdateAggFactory.NAME, + DataTypes.ARRAY(elementRowType), + Collections.emptyList(), + Collections.singletonList("seq"), + 2)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("nested-sequence-field requires nested-key to be set."); + } + + @Test + public void testFieldNestedUpdateAggWithCountLimitWithSequenceField() { + DataType elementRowType = + DataTypes.ROW( + DataTypes.FIELD(0, "k0", DataTypes.INT()), + DataTypes.FIELD(1, "k1", DataTypes.INT()), + DataTypes.FIELD(2, "v", DataTypes.STRING()), + DataTypes.FIELD(3, "seq", DataTypes.INT())); + + FieldNestedUpdateAgg agg = + new FieldNestedUpdateAgg( + FieldNestedUpdateAggFactory.NAME, + DataTypes.ARRAY(elementRowType), + Arrays.asList("k0", "k1"), + Collections.singletonList("seq"), + 2); // Enforce count limit = 2 + + InternalArray accumulator = null; + InternalArray.ElementGetter elementGetter = + InternalArray.createElementGetter(elementRowType); + + InternalRow current = row(0, 1, "B", 1); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf(Collections.singletonList(row(0, 1, "B", 1))); + + current = row(0, 1, "B_updated", 2); + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + // The existing row should be updated, and the total size remains 1 + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Collections.singletonList(row(0, 1, "B_updated", 2))); + + current = row(1, 2, "C", 3); // Different nested 
key (1, 2) + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList(row(0, 1, "B_updated", 2), row(1, 2, "C", 3))); + + current = row(0, 3, "D", 4); // Another different nested key (0, 3) + accumulator = (InternalArray) agg.agg(accumulator, singletonArray(current)); + + // count limit is 2, so the third element will be dropped + assertThat(unnest(accumulator, elementGetter)) + .containsExactlyInAnyOrderElementsOf( + Arrays.asList(row(0, 1, "B_updated", 2), row(1, 2, "C", 3))); + } + private List unnest(InternalArray array, InternalArray.ElementGetter elementGetter) { return IntStream.range(0, array.size()) .mapToObj(i -> elementGetter.getElementOrNull(array, i)) @@ -763,6 +1009,15 @@ private InternalRow row(Integer k0, Integer k1, String v) { return GenericRow.of(k0, k1, BinaryString.fromString(v)); } + private InternalRow row(Integer k0, Integer k1, String v, Integer seq) { + return GenericRow.of(k0, k1, BinaryString.fromString(v), seq); + } + + private InternalRow row( + Object k0, Object k1, String v, Object seq, org.apache.paimon.data.Timestamp ts) { + return GenericRow.of(k0, k1, BinaryString.fromString(v), seq, ts); + } + @Test public void testFieldCollectAggWithDistinct() { FieldCollectAgg agg = diff --git a/paimon-core/src/test/java/org/apache/paimon/operation/ChainTablePartitionExpireTest.java b/paimon-core/src/test/java/org/apache/paimon/operation/ChainTablePartitionExpireTest.java new file mode 100644 index 000000000000..5ff5edbf71b4 --- /dev/null +++ b/paimon-core/src/test/java/org/apache/paimon/operation/ChainTablePartitionExpireTest.java @@ -0,0 +1,786 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.operation; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.catalog.Catalog; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.manifest.PartitionEntry; +import org.apache.paimon.options.Options; +import org.apache.paimon.partition.PartitionStatistics; +import org.apache.paimon.partition.PartitionUpdateTimeExpireStrategy; +import org.apache.paimon.schema.Schema; +import org.apache.paimon.schema.SchemaChange; +import org.apache.paimon.schema.SchemaManager; +import org.apache.paimon.schema.TableSchema; +import org.apache.paimon.table.CatalogEnvironment; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.FileStoreTableFactory; +import org.apache.paimon.table.PartitionModification; +import org.apache.paimon.table.sink.CommitMessage; +import org.apache.paimon.table.sink.StreamTableWrite; +import org.apache.paimon.table.sink.TableCommitImpl; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.time.Duration; +import java.time.LocalDateTime; +import 
java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Test for {@link ChainTablePartitionExpire}. */ +public class ChainTablePartitionExpireTest { + + @TempDir java.nio.file.Path tempDir; + + private String commitUser; + + @BeforeEach + public void before() { + commitUser = UUID.randomUUID().toString(); + } + + @Test + public void testExplicitPartitionExpireRejectsUpdateTimeStrategy() throws Exception { + Path tablePath = tablePath("explicit_update_time_strategy"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + + assertThatThrownBy( + () -> + mainTable + .store() + .newPartitionExpire( + commitUser, + mainTable, + Duration.ofDays(1), + Duration.ZERO, + new PartitionUpdateTimeExpireStrategy( + CoreOptions.fromMap(Collections.emptyMap()), + mainTable.schema().logicalPartitionType()))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Chain table only supports 'values-time' partition expiration strategy"); + } + + @Test + public void testExpireWithSinglePartitionKey() throws Exception { + Path tablePath = tablePath("simple_expire"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250201", "v2"); + write(snapshotTable, "20250301", "v3"); + + write(deltaTable, "20250110", "v4"); + write(deltaTable, "20250115", "v5"); + write(deltaTable, "20250210", "v6"); + write(deltaTable, "20250215", "v7"); + write(deltaTable, "20250315", "v8"); + + snapshotTable = 
loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + assertThat(listPartitions(snapshotTable)) + .containsExactlyInAnyOrder("20250101", "20250201", "20250301"); + assertThat(listPartitions(deltaTable)) + .containsExactlyInAnyOrder( + "20250110", "20250115", "20250210", "20250215", "20250315"); + + // cutoff = 2025-03-31 - 20d = 2025-03-11 + // Snapshots before cutoff: 20250101, 20250201, 20250301 (3 snapshots) + // Anchor = 20250301 (kept), expire 2 segments: + // Segment0: S(20250101), d(20250110), d(20250115) + // Segment1: S(20250201), d(20250210), d(20250215) + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(20), false); + expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + List> expired = + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(expired).isNotNull(); + assertThat(expired).isNotEmpty(); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + assertThat(listPartitions(snapshotTable)).containsExactlyInAnyOrder("20250301"); + assertThat(listPartitions(deltaTable)).containsExactlyInAnyOrder("20250315"); + } + + @Test + public void testNoExpireWhenOnlyOneSnapshotBeforeCutoff() throws Exception { + Path tablePath = tablePath("no_expire_one_snapshot"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250201", "v1"); + write(snapshotTable, "20250315", "v2"); + write(deltaTable, "20250205", "v3"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + 
expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + List> expired = + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(expired).isEmpty(); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + assertThat(listPartitions(snapshotTable)).containsExactlyInAnyOrder("20250201", "20250315"); + assertThat(listPartitions(deltaTable)).containsExactlyInAnyOrder("20250205"); + } + + @Test + public void testExpireMultipleSegments() throws Exception { + Path tablePath = tablePath("multi_segments"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250115", "v2"); + write(snapshotTable, "20250201", "v3"); + write(snapshotTable, "20250315", "v4"); + + write(deltaTable, "20250105", "v5"); + write(deltaTable, "20250120", "v6"); + write(deltaTable, "20250210", "v7"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + // Snapshots before cutoff: 20250101, 20250115, 20250201 (3 snapshots) + // Anchor = 20250201 (kept), expire S(20250101), S(20250115) + // Delta before anchor: d(20250105), d(20250120) + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + assertThat(listPartitions(snapshotTable)).containsExactlyInAnyOrder("20250201", "20250315"); + 
assertThat(listPartitions(deltaTable)).containsExactlyInAnyOrder("20250210"); + } + + @Test + public void testNoExpireWhenNoSnapshotsBeforeCutoff() throws Exception { + Path tablePath = tablePath("no_expire_no_snapshot"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250315", "v1"); + write(snapshotTable, "20250320", "v2"); + write(deltaTable, "20250316", "v3"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + // No snapshots before cutoff + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + List> expired = + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(expired).isEmpty(); + } + + @Test + public void testCheckIntervalPreventsExpire() throws Exception { + Path tablePath = tablePath("check_interval"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250201", "v2"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + ChainTablePartitionExpire expire = + new ChainTablePartitionExpire( + Duration.ofDays(30), + Duration.ofDays(1), + snapshotTable, + deltaTable, + CoreOptions.fromMap(buildOptions(Duration.ofDays(30), false)), + snapshotTable.schema().logicalPartitionType(), + false, + Integer.MAX_VALUE, + 0, + null, + null); + 
expire.setLastCheck(LocalDateTime.of(2025, 3, 31, 0, 0)); + + List> expired = + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(expired).isNull(); + } + + @Test + public void testMaxExpireNumLimitsSegments() throws Exception { + Path tablePath = tablePath("max_expire_num"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + // Snapshots: S(0101), S(0115), S(0201), S(0315) + // cutoff = 03-01, anchor = S(0201) + // Segments to expire: Segment1={S(0101), d(0105)}, Segment2={S(0115), d(0120)} + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250115", "v2"); + write(snapshotTable, "20250201", "v3"); + write(snapshotTable, "20250315", "v4"); + + write(deltaTable, "20250105", "v5"); + write(deltaTable, "20250120", "v6"); + write(deltaTable, "20250210", "v7"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // maxExpireNum=1 means only 1 segment: Segment1={S(0101), d(0105)} + ChainTablePartitionExpire expire = + new ChainTablePartitionExpire( + Duration.ofDays(30), + Duration.ZERO, + snapshotTable, + deltaTable, + CoreOptions.fromMap(buildOptions(Duration.ofDays(30), false)), + snapshotTable.schema().logicalPartitionType(), + false, + 1, + 0, + null, + null); + expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + List> expired = + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(expired).isNotNull(); + // 1 segment = S(0101) + d(0105) = 2 partitions + assertThat(expired).hasSize(2); + + // Verify: S(0101) expired, S(0115) still exists (not expired, was in segment 2) + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + 
assertThat(listPartitions(snapshotTable)) + .containsExactlyInAnyOrder("20250115", "20250201", "20250315"); + // d(0105) expired, d(0120) and d(0210) kept + assertThat(listPartitions(deltaTable)).containsExactlyInAnyOrder("20250120", "20250210"); + } + + @Test + public void testExpireWithGroupPartition() throws Exception { + Path tablePath = tablePath("group_partition"); + createChainTable(tablePath, true); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + // Group "US": snapshots 0101, 0201, 0301 + writeGrouped(snapshotTable, "US", "20250101", "v1"); + writeGrouped(snapshotTable, "US", "20250201", "v2"); + writeGrouped(snapshotTable, "US", "20250301", "v3"); + // Group "US": deltas 0110, 0210 + writeGrouped(deltaTable, "US", "20250110", "d1"); + writeGrouped(deltaTable, "US", "20250210", "d2"); + + // Group "EU": only one snapshot before cutoff, so nothing should expire + writeGrouped(snapshotTable, "EU", "20250215", "v4"); + writeGrouped(snapshotTable, "EU", "20250320", "v5"); + writeGrouped(deltaTable, "EU", "20250220", "d3"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + // Group "US": snapshots before cutoff = [0101, 0201]. Anchor = 0201 (kept). + // Expire: S(0101), delta(0110) (before anchor 0201) + // Keep: S(0201), S(0301), delta(0210) + // Group "EU": snapshots before cutoff = [0215] (only 1). Nothing expired. 
+ ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), true); + expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + List> expired = + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(expired).isNotNull(); + assertThat(expired).hasSize(2); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + List snapshotParts = listGroupedPartitions(snapshotTable); + List deltaParts = listGroupedPartitions(deltaTable); + + assertThat(snapshotParts).contains("US|20250201", "US|20250301"); + assertThat(snapshotParts).doesNotContain("US|20250101"); + assertThat(snapshotParts).contains("EU|20250215", "EU|20250320"); + + assertThat(deltaParts).contains("US|20250210"); + assertThat(deltaParts).doesNotContain("US|20250110"); + assertThat(deltaParts).contains("EU|20250220"); + } + + @Test + public void testUsesBranchSpecificPartitionModifications() throws Exception { + Path tablePath = tablePath("branch_specific_partition_modification"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250201", "v2"); + write(snapshotTable, "20250301", "v3"); + write(deltaTable, "20250110", "d1"); + write(deltaTable, "20250210", "d2"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + RecordingPartitionModification snapshotModification = new RecordingPartitionModification(); + RecordingPartitionModification deltaModification = new RecordingPartitionModification(); + + ChainTablePartitionExpire expire = + new ChainTablePartitionExpire( + Duration.ofDays(30), + Duration.ZERO, + snapshotTable, + deltaTable, + 
CoreOptions.fromMap(buildOptions(Duration.ofDays(30), false)), + snapshotTable.schema().logicalPartitionType(), + false, + Integer.MAX_VALUE, + 0, + snapshotModification, + deltaModification); + expire.setLastCheck(LocalDateTime.of(2025, 1, 1, 0, 0)); + expire.expire(LocalDateTime.of(2025, 3, 31, 0, 0), Long.MAX_VALUE); + + assertThat(deltaModification.droppedValues("dt")) + .contains("20250110", "20250110.done") + .doesNotContain("20250101", "20250101.done"); + assertThat(snapshotModification.droppedValues("dt")) + .contains("20250101", "20250101.done") + .doesNotContain("20250110", "20250110.done"); + } + + @Test + public void testIsValueAllExpiredReturnsFalseForAnchor() throws Exception { + Path tablePath = tablePath("value_expired_anchor"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + // S(0101), S(0201), S(0301) + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250201", "v2"); + write(snapshotTable, "20250301", "v3"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // expirationTime = 30d, "now" = 2025-03-31 → cutoff = 2025-03-01 + // Snapshots before cutoff: S(0101), S(0201). Anchor = S(0201) (kept). 
+ ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + LocalDateTime now = LocalDateTime.of(2025, 3, 31, 0, 0); + + BinaryRow anchor0201 = findPartition(snapshotTable, "20250201"); + BinaryRow expired0101 = findPartition(snapshotTable, "20250101"); + + // Anchor partition alone → not all expired (anchor is retained) + assertThat(expire.isValueAllExpired(Collections.singletonList(anchor0201), now)).isFalse(); + + // Truly expired partition alone → all expired + assertThat(expire.isValueAllExpired(Collections.singletonList(expired0101), now)).isTrue(); + + // Mix of anchor + expired → not all expired + assertThat(expire.isValueAllExpired(Arrays.asList(expired0101, anchor0201), now)).isFalse(); + } + + @Test + public void testIsValueAllExpiredReturnsFalseWhenTooFewSnapshots() throws Exception { + Path tablePath = tablePath("value_expired_few_snapshots"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + // Only 1 snapshot before cutoff + write(snapshotTable, "20250201", "v1"); + write(snapshotTable, "20250315", "v2"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + // Only S(0201) before cutoff → < 2, nothing can expire + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + LocalDateTime now = LocalDateTime.of(2025, 3, 31, 0, 0); + + BinaryRow partition0201 = findPartition(snapshotTable, "20250201"); + assertThat(expire.isValueAllExpired(Collections.singletonList(partition0201), now)) + .isFalse(); + } + + @Test + public void testIsValueAllExpiredReturnsFalseForDeltaOnlyGroup() throws Exception { + Path tablePath = tablePath("value_expired_delta_only"); 
+ createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(deltaTable, "20250101", "d1"); + write(deltaTable, "20250201", "d2"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + // No snapshot boundary exists, so delta-only partitions are retained. + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + LocalDateTime now = LocalDateTime.of(2025, 3, 31, 0, 0); + + BinaryRow deltaPartition = findPartition(deltaTable, "20250101"); + assertThat(expire.isValueAllExpired(Collections.singletonList(deltaPartition), now)) + .isFalse(); + } + + @Test + public void testIsValueAllExpiredWithGroupPartitions() throws Exception { + Path tablePath = tablePath("value_expired_groups"); + createChainTable(tablePath, true); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + // Group "US": 3 snapshots, anchor = S(US,0201) + writeGrouped(snapshotTable, "US", "20250101", "v1"); + writeGrouped(snapshotTable, "US", "20250201", "v2"); + writeGrouped(snapshotTable, "US", "20250301", "v3"); + + // Group "EU": only 1 snapshot before cutoff → nothing expires + writeGrouped(snapshotTable, "EU", "20250215", "v4"); + writeGrouped(snapshotTable, "EU", "20250320", "v5"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), true); + LocalDateTime now = LocalDateTime.of(2025, 3, 31, 0, 0); 
+ + BinaryRow usExpired = findGroupedPartition(snapshotTable, "US", "20250101"); + BinaryRow usAnchor = findGroupedPartition(snapshotTable, "US", "20250201"); + BinaryRow euRetained = findGroupedPartition(snapshotTable, "EU", "20250215"); + + // US expired partition → truly expired + assertThat(expire.isValueAllExpired(Collections.singletonList(usExpired), now)).isTrue(); + + // US anchor → retained + assertThat(expire.isValueAllExpired(Collections.singletonList(usAnchor), now)).isFalse(); + + // EU partition (< 2 snapshots before cutoff) → retained + assertThat(expire.isValueAllExpired(Collections.singletonList(euRetained), now)).isFalse(); + + // Mix across groups: US expired + EU retained + assertThat(expire.isValueAllExpired(Arrays.asList(usExpired, euRetained), now)).isFalse(); + } + + @Test + public void testIsValueAllExpiredReturnsFalseForPartitionsAfterCutoff() throws Exception { + Path tablePath = tablePath("value_expired_after_cutoff"); + createChainTable(tablePath, false); + FileStoreTable mainTable = loadTable(tablePath); + FileStoreTable snapshotTable = mainTable.switchToBranch("snapshot"); + FileStoreTable deltaTable = mainTable.switchToBranch("delta"); + + write(snapshotTable, "20250101", "v1"); + write(snapshotTable, "20250315", "v2"); + + snapshotTable = loadTable(tablePath).switchToBranch("snapshot"); + deltaTable = loadTable(tablePath).switchToBranch("delta"); + + // cutoff = 2025-03-31 - 30d = 2025-03-01 + // S(0315) is after cutoff → not expired at all + ChainTablePartitionExpire expire = + newChainExpire(snapshotTable, deltaTable, Duration.ofDays(30), false); + LocalDateTime now = LocalDateTime.of(2025, 3, 31, 0, 0); + + BinaryRow afterCutoff = findPartition(snapshotTable, "20250315"); + assertThat(expire.isValueAllExpired(Collections.singletonList(afterCutoff), now)).isFalse(); + } + + // ========== Helper methods ========== + + private BinaryRow findPartition(FileStoreTable table, String dtValue) { + return 
table.newSnapshotReader().partitionEntries().stream() + .map(PartitionEntry::partition) + .filter(p -> p.getString(0).toString().equals(dtValue)) + .findFirst() + .orElseThrow(() -> new RuntimeException("Partition " + dtValue + " not found")); + } + + private BinaryRow findGroupedPartition(FileStoreTable table, String region, String dt) { + return table.newSnapshotReader().partitionEntries().stream() + .map(PartitionEntry::partition) + .filter( + p -> + p.getString(0).toString().equals(region) + && p.getString(1).toString().equals(dt)) + .findFirst() + .orElseThrow( + () -> + new RuntimeException( + "Partition " + region + "|" + dt + " not found")); + } + + private Path tablePath(String tableName) { + return new Path(tempDir.toUri().toString(), tableName); + } + + private void createChainTable(Path tablePath, boolean withGroupPartition) throws Exception { + LocalFileIO fileIO = LocalFileIO.create(); + SchemaManager schemaManager = new SchemaManager(fileIO, tablePath); + + Map options = new HashMap<>(); + options.put(CoreOptions.BUCKET.key(), "1"); + options.put("merge-engine", "deduplicate"); + options.put("sequence.field", "v"); + + Schema schema; + if (withGroupPartition) { + schema = + new Schema( + RowType.of( + new org.apache.paimon.types.DataType[] { + DataTypes.STRING(), + DataTypes.STRING(), + DataTypes.STRING(), + DataTypes.STRING() + }, + new String[] {"region", "dt", "pk", "v"}) + .getFields(), + Arrays.asList("region", "dt"), + Arrays.asList("pk", "region", "dt"), + options, + ""); + } else { + schema = + new Schema( + RowType.of( + new org.apache.paimon.types.DataType[] { + DataTypes.STRING(), + DataTypes.STRING(), + DataTypes.STRING() + }, + new String[] {"dt", "pk", "v"}) + .getFields(), + Collections.singletonList("dt"), + Arrays.asList("pk", "dt"), + options, + ""); + } + schemaManager.createTable(schema); + + FileStoreTable mainTable = loadTable(tablePath); + mainTable.createBranch("snapshot"); + mainTable.createBranch("delta"); + + List 
chainTableOptions = + Arrays.asList( + SchemaChange.setOption("chain-table.enabled", "true"), + SchemaChange.setOption("scan.fallback-snapshot-branch", "snapshot"), + SchemaChange.setOption("scan.fallback-delta-branch", "delta"), + SchemaChange.setOption("partition.timestamp-pattern", "$dt"), + SchemaChange.setOption("partition.timestamp-formatter", "yyyyMMdd")); + if (withGroupPartition) { + chainTableOptions = new java.util.ArrayList<>(chainTableOptions); + chainTableOptions.add(SchemaChange.setOption("chain-table.chain-partition-keys", "dt")); + } + schemaManager.commitChanges(chainTableOptions); + new SchemaManager(fileIO, tablePath, "snapshot").commitChanges(chainTableOptions); + new SchemaManager(fileIO, tablePath, "delta").commitChanges(chainTableOptions); + } + + private FileStoreTable loadTable(Path tablePath) { + LocalFileIO fileIO = LocalFileIO.create(); + Options options = new Options(); + options.set(CoreOptions.PATH, tablePath.toString()); + String branchName = CoreOptions.branch(options.toMap()); + TableSchema tableSchema = new SchemaManager(fileIO, tablePath, branchName).latest().get(); + return FileStoreTableFactory.create( + fileIO, tablePath, tableSchema, CatalogEnvironment.empty()); + } + + private void write(FileStoreTable table, String dt, String v) throws Exception { + StreamTableWrite write = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "true")) + .newWrite(commitUser); + write.write( + GenericRow.of( + BinaryString.fromString(dt), + BinaryString.fromString(v), + BinaryString.fromString(v))); + TableCommitImpl commit = table.newCommit(commitUser); + List commitMessages = write.prepareCommit(true, 0); + commit.commit(0, commitMessages); + write.close(); + commit.close(); + } + + private void writeGrouped(FileStoreTable table, String region, String dt, String v) + throws Exception { + StreamTableWrite write = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "true")) + .newWrite(commitUser); + 
write.write( + GenericRow.of( + BinaryString.fromString(region), + BinaryString.fromString(dt), + BinaryString.fromString(v), + BinaryString.fromString(v))); + TableCommitImpl commit = table.newCommit(commitUser); + List commitMessages = write.prepareCommit(true, 0); + commit.commit(0, commitMessages); + write.close(); + commit.close(); + } + + private List listPartitions(FileStoreTable table) { + return table.newSnapshotReader().partitionEntries().stream() + .map(PartitionEntry::partition) + .map(p -> p.getString(0).toString()) + .sorted() + .collect(Collectors.toList()); + } + + private List listGroupedPartitions(FileStoreTable table) { + return table.newSnapshotReader().partitionEntries().stream() + .map(PartitionEntry::partition) + .map(p -> p.getString(0).toString() + "|" + p.getString(1).toString()) + .sorted() + .collect(Collectors.toList()); + } + + private Map buildOptions(Duration expirationTime, boolean withGroupPartition) { + Map opts = new HashMap<>(); + opts.put("partition.timestamp-pattern", "$dt"); + opts.put("partition.timestamp-formatter", "yyyyMMdd"); + opts.put("scan.fallback-snapshot-branch", "snapshot"); + opts.put("scan.fallback-delta-branch", "delta"); + opts.put(CoreOptions.PARTITION_EXPIRATION_TIME.key(), expirationTime.toDays() + " d"); + if (withGroupPartition) { + opts.put("chain-table.chain-partition-keys", "dt"); + } + return opts; + } + + private ChainTablePartitionExpire newChainExpire( + FileStoreTable snapshotTable, + FileStoreTable deltaTable, + Duration expirationTime, + boolean withGroupPartition) { + return new ChainTablePartitionExpire( + expirationTime, + Duration.ZERO, + snapshotTable, + deltaTable, + CoreOptions.fromMap(buildOptions(expirationTime, withGroupPartition)), + snapshotTable.schema().logicalPartitionType(), + false, + Integer.MAX_VALUE, + 0, + null, + null); + } + + private static class RecordingPartitionModification implements PartitionModification { + + private final List> droppedPartitions = new 
ArrayList<>(); + + @Override + public void createPartitions(List> partitions) + throws Catalog.TableNotExistException {} + + @Override + public void dropPartitions(List> partitions) + throws Catalog.TableNotExistException { + for (Map partition : partitions) { + droppedPartitions.add(new HashMap<>(partition)); + } + } + + @Override + public void alterPartitions(List partitions) + throws Catalog.TableNotExistException {} + + @Override + public void close() throws Exception {} + + private List droppedValues(String key) { + return droppedPartitions.stream() + .map(partition -> partition.get(key)) + .collect(Collectors.toList()); + } + } +} diff --git a/paimon-core/src/test/java/org/apache/paimon/operation/FileStoreCommitTest.java b/paimon-core/src/test/java/org/apache/paimon/operation/FileStoreCommitTest.java index 71eb081de89f..e13b11d474b6 100644 --- a/paimon-core/src/test/java/org/apache/paimon/operation/FileStoreCommitTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/operation/FileStoreCommitTest.java @@ -807,7 +807,7 @@ public void testWriteStats() throws Exception { ArrayList newFields = new ArrayList<>(TestKeyValueGenerator.DEFAULT_ROW_TYPE.getFields()); newFields.add(new DataField(-1, "newField", DataTypes.INT())); - store.mergeSchema(new RowType(false, newFields), true); + store.mergeSchema(new RowType(false, newFields), true, true, true); store.commitData(generateDataList(10), gen::getPartition, kv -> 0); readStats = statsFileHandler.readStats(); assertThat(readStats).isEmpty(); diff --git a/paimon-core/src/test/java/org/apache/paimon/operation/PartitionExpireTest.java b/paimon-core/src/test/java/org/apache/paimon/operation/PartitionExpireTest.java index 6af293bea1b4..e6abe773ebfa 100644 --- a/paimon-core/src/test/java/org/apache/paimon/operation/PartitionExpireTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/operation/PartitionExpireTest.java @@ -86,7 +86,7 @@ import static org.assertj.core.api.Assertions.assertThatCode; import static 
org.assertj.core.api.Assertions.assertThatThrownBy; -/** Test for {@link PartitionExpire}. */ +/** Test for {@link NormalPartitionExpire}. */ public class PartitionExpireTest { @TempDir java.nio.file.Path tempDir; @@ -189,7 +189,7 @@ public void testIllegalPartition() throws Exception { write("20230103", "31"); write("20230103", "32"); write("20230105", "51"); - PartitionExpire expire = newExpire(); + NormalPartitionExpire expire = newExpire(); expire.setLastCheck(date(1)); Assertions.assertDoesNotThrow(() -> expire.expire(date(8), Long.MAX_VALUE)); assertThat(read()).containsExactlyInAnyOrder("abcd:12"); @@ -215,7 +215,7 @@ public void testBatchExpire() throws Exception { write("20230103", "31"); write("20230103", "32"); write("20230105", "51"); - PartitionExpire expire = newExpire(); + NormalPartitionExpire expire = newExpire(); expire.setLastCheck(date(1)); Assertions.assertDoesNotThrow(() -> expire.expire(date(8), Long.MAX_VALUE)); @@ -246,7 +246,7 @@ public void testExpireWithNullOrEmptyPartition() throws Exception { write("20230103", "32"); write("20230105", "51"); - PartitionExpire expire = newExpire(); + NormalPartitionExpire expire = newExpire(); expire.setLastCheck(date(1)); Assertions.assertDoesNotThrow(() -> expire.expire(date(6), Long.MAX_VALUE)); @@ -272,7 +272,7 @@ public void test() throws Exception { write("20230103", "32"); write("20230105", "51"); - PartitionExpire expire = newExpire(); + NormalPartitionExpire expire = newExpire(); expire.setLastCheck(date(1)); expire.expire(date(3), Long.MAX_VALUE); @@ -319,7 +319,7 @@ public void testDonePartitionExpire() throws Exception { doneAction.markDone("f0=20230103"); doneAction.markDone("f0=20230108"); - PartitionExpire expire = newExpire(); + NormalPartitionExpire expire = newExpire(); expire.setLastCheck(date(1)); expire.expire(date(8), Long.MAX_VALUE); @@ -423,7 +423,7 @@ public void testDeleteExpiredPartition() throws Exception { List commitMessages = write("20230101", "11"); write("20230105", 
"51"); - PartitionExpire expire = newExpire(); + NormalPartitionExpire expire = newExpire(); expire.setLastCheck(date(1)); expire.expire(date(5), Long.MAX_VALUE); assertThat(read()).containsExactlyInAnyOrder("20230105:51"); @@ -471,9 +471,9 @@ private List write(String f0, String f1) throws Exception { return commitMessages; } - private PartitionExpire newExpire() { + private NormalPartitionExpire newExpire() { FileStoreTable table = newExpireTable(); - return table.store().newPartitionExpire("", table); + return (NormalPartitionExpire) table.store().newPartitionExpire("", table); } private FileStoreTable newExpireTable() { diff --git a/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitMetricsTest.java b/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitMetricsTest.java index 6a79a0ae5807..9b475f34c8cc 100644 --- a/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitMetricsTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitMetricsTest.java @@ -100,6 +100,10 @@ public void testMetricsAreUpdated() { registeredGenericMetrics.get( CommitMetrics.LAST_CHANGELOG_RECORDS_COMMIT_COMPACTED); + Gauge lastCommittedSnapshotId = + (Gauge) + registeredGenericMetrics.get(CommitMetrics.LAST_COMMITTED_SNAPSHOT_ID); + assertThat(lastCommitDuration.getValue()).isEqualTo(0); assertThat(commitDuration.getCount()).isEqualTo(0); assertThat(commitDuration.getStatistics().size()).isEqualTo(0); @@ -117,6 +121,7 @@ public void testMetricsAreUpdated() { assertThat(lastChangelogRecordsAppended.getValue()).isEqualTo(0); assertThat(lastDeltaRecordsCompacted.getValue()).isEqualTo(0); assertThat(lastChangelogRecordsCompacted.getValue()).isEqualTo(0); + assertThat(lastCommittedSnapshotId.getValue()).isEqualTo(-1); // report once reportOnce(commitMetrics); @@ -145,6 +150,7 @@ public void testMetricsAreUpdated() { assertThat(lastChangelogRecordsAppended.getValue()).isEqualTo(503); 
assertThat(lastDeltaRecordsCompacted.getValue()).isEqualTo(613); assertThat(lastChangelogRecordsCompacted.getValue()).isEqualTo(512); + assertThat(lastCommittedSnapshotId.getValue()).isEqualTo(42); // report again reportAgain(commitMetrics); @@ -173,6 +179,7 @@ public void testMetricsAreUpdated() { assertThat(lastChangelogRecordsAppended.getValue()).isEqualTo(213); assertThat(lastDeltaRecordsCompacted.getValue()).isEqualTo(506); assertThat(lastChangelogRecordsCompacted.getValue()).isEqualTo(601); + assertThat(lastCommittedSnapshotId.getValue()).isEqualTo(99); } private void reportOnce(CommitMetrics commitMetrics) { @@ -199,7 +206,8 @@ private void reportOnce(CommitMetrics commitMetrics) { compactChangelogFiles, 200, 2, - 1); + 1, + 42L); commitMetrics.reportCommit(commitStats); } @@ -228,7 +236,8 @@ private void reportAgain(CommitMetrics commitMetrics) { compactChangelogFiles, 500, 1, - 2); + 2, + 99L); commitMetrics.reportCommit(commitStats); } diff --git a/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitStatsTest.java b/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitStatsTest.java index e4a2a7fd22ed..e705dfd41012 100644 --- a/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitStatsTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/operation/metrics/CommitStatsTest.java @@ -78,7 +78,8 @@ public void testFailedAppendSnapshot() { Collections.emptyList(), 0, 0, - 1); + 1, + -1L); assertThat(commitStats.getTableFilesAdded()).isEqualTo(0); assertThat(commitStats.getTableFilesDeleted()).isEqualTo(0); assertThat(commitStats.getTableFilesAppended()).isEqualTo(0); @@ -94,6 +95,7 @@ public void testFailedAppendSnapshot() { assertThat(commitStats.getNumBucketsWritten()).isEqualTo(0); assertThat(commitStats.getDuration()).isEqualTo(0); assertThat(commitStats.getAttempts()).isEqualTo(1); + assertThat(commitStats.getLastCommittedSnapshotId()).isEqualTo(-1); } @Test @@ -106,7 +108,8 @@ public void 
testFailedCompactSnapshot() { Collections.emptyList(), 3000, 1, - 2); + 2, + 5L); assertThat(commitStats.getTableFilesAdded()).isEqualTo(2); assertThat(commitStats.getTableFilesDeleted()).isEqualTo(0); assertThat(commitStats.getTableFilesAppended()).isEqualTo(2); @@ -122,6 +125,7 @@ public void testFailedCompactSnapshot() { assertThat(commitStats.getNumBucketsWritten()).isEqualTo(2); assertThat(commitStats.getDuration()).isEqualTo(3000); assertThat(commitStats.getAttempts()).isEqualTo(2); + assertThat(commitStats.getLastCommittedSnapshotId()).isEqualTo(5); } @Test @@ -134,7 +138,8 @@ public void testSucceedAllSnapshot() { compactChangelogFiles, 3000, 2, - 2); + 2, + 10L); assertThat(commitStats.getTableFilesAdded()).isEqualTo(4); assertThat(commitStats.getTableFilesDeleted()).isEqualTo(1); assertThat(commitStats.getTableFilesAppended()).isEqualTo(2); @@ -150,5 +155,6 @@ public void testSucceedAllSnapshot() { assertThat(commitStats.getNumBucketsWritten()).isEqualTo(3); assertThat(commitStats.getDuration()).isEqualTo(3000); assertThat(commitStats.getAttempts()).isEqualTo(2); + assertThat(commitStats.getLastCommittedSnapshotId()).isEqualTo(10); } } diff --git a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java index 74e4f4e465a9..af5d94e3f632 100644 --- a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java +++ b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogServer.java @@ -61,7 +61,6 @@ import org.apache.paimon.rest.requests.CreateViewRequest; import org.apache.paimon.rest.requests.ListPartitionsByNamesRequest; import org.apache.paimon.rest.requests.MarkDonePartitionsRequest; -import org.apache.paimon.rest.requests.RenameBranchRequest; import org.apache.paimon.rest.requests.RenameTableRequest; import org.apache.paimon.rest.requests.ReplaceTableRequest; import org.apache.paimon.rest.requests.ResetConsumerRequest; @@ -1888,31 +1887,7 @@ 
private MockResponse branchApiHandle( case "POST": if (resources.length == 6) { branch = RESTUtil.decodeString(resources[4]); - if ("rename".equals(resources[5])) { - // Rename branch: /branches/{branch}/rename - RenameBranchRequest requestBody = - RESTApi.fromJson(data, RenameBranchRequest.class); - String toBranch = requestBody.toBranch(); - table.renameBranch(branch, toBranch); - // Update store for renamed branch - Identifier fromBranchIdentifier = - new Identifier( - identifier.getDatabaseName(), - identifier.getTableName(), - branch); - Identifier toBranchIdentifier = - new Identifier( - identifier.getDatabaseName(), - identifier.getTableName(), - toBranch); - tableLatestSnapshotStore.put( - toBranchIdentifier.getFullName(), - tableLatestSnapshotStore.get( - fromBranchIdentifier.getFullName())); - tableMetadataStore.put( - toBranchIdentifier.getFullName(), - tableMetadataStore.get(fromBranchIdentifier.getFullName())); - } else if ("forward".equals(resources[5])) { + if ("forward".equals(resources[5])) { // Fast forward branch branchManager.fastForward(branch); branchIdentifier = diff --git a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java index 8cefa2ec7b42..6ff873aed17d 100644 --- a/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/rest/RESTCatalogTest.java @@ -2161,25 +2161,7 @@ void testBranches() throws Exception { () -> restCatalog.createBranch(identifier, "my_branch", null)); assertThat(restCatalog.listBranches(identifier)).containsOnly("my_branch"); - // Test rename branch - restCatalog.renameBranch(identifier, "my_branch", "renamed_branch"); - assertThat(restCatalog.listBranches(identifier)).containsOnly("renamed_branch"); - assertThat(restCatalog.getTable(new Identifier(databaseName, "table", "renamed_branch"))) - .isNotNull(); - - // Test rename to existing branch should fail - 
restCatalog.createBranch(identifier, "another_branch", null); - assertThrows( - Catalog.BranchAlreadyExistException.class, - () -> restCatalog.renameBranch(identifier, "renamed_branch", "another_branch")); - - // Test rename non-existent branch should fail - assertThrows( - Catalog.BranchNotExistException.class, - () -> restCatalog.renameBranch(identifier, "non_existent_branch", "new_branch")); - - restCatalog.dropBranch(identifier, "renamed_branch"); - restCatalog.dropBranch(identifier, "another_branch"); + restCatalog.dropBranch(identifier, "my_branch"); assertThrows( Catalog.BranchNotExistException.class, diff --git a/paimon-core/src/test/java/org/apache/paimon/schema/ColumnDirectiveUtilsTest.java b/paimon-core/src/test/java/org/apache/paimon/schema/ColumnDirectiveUtilsTest.java new file mode 100644 index 000000000000..2d8494af3b11 --- /dev/null +++ b/paimon-core/src/test/java/org/apache/paimon/schema/ColumnDirectiveUtilsTest.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.schema; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.DataTypeRoot; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VectorType; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Test for {@link ColumnDirectiveUtils}. */ +public class ColumnDirectiveUtilsTest { + + // -- applyAddColumnDirective (single-column, the main API for ADD COLUMN) -- + + @Test + public void testNonDirectiveCommentReturnsNull() { + Map opts = new HashMap<>(); + assertThat( + ColumnDirectiveUtils.applyAddColumnDirective( + null, "col", DataTypes.BYTES(), opts)) + .isNull(); + assertThat(ColumnDirectiveUtils.applyAddColumnDirective("", "col", DataTypes.BYTES(), opts)) + .isNull(); + assertThat( + ColumnDirectiveUtils.applyAddColumnDirective( + "normal comment", "col", DataTypes.BYTES(), opts)) + .isNull(); + assertThat(opts).isEmpty(); + } + + @Test + public void testBlobFieldDirective() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_FIELD; profile picture", "pic", DataTypes.BYTES(), opts); + + assertThat(result).isNotNull(); + assertThat(result.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(result.comment()).isEqualTo("profile picture"); + assertThat(opts).containsEntry(CoreOptions.BLOB_FIELD.key(), "pic"); + } + + @Test + public void testBlobDescriptorFieldDirective() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_DESCRIPTOR_FIELD; desc text", "desc_col", DataTypes.BYTES(), opts); + + assertThat(result).isNotNull(); + 
assertThat(result.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(result.comment()).isEqualTo("desc text"); + assertThat(opts).containsEntry(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), "desc_col"); + } + + @Test + public void testBlobViewFieldDirective() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_VIEW_FIELD; view comment", "view_col", DataTypes.BYTES(), opts); + + assertThat(result).isNotNull(); + assertThat(result.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(result.comment()).isEqualTo("view comment"); + assertThat(opts).containsEntry(CoreOptions.BLOB_VIEW_FIELD.key(), "view_col"); + } + + @Test + public void testBlobExternalStorageFieldDirective() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_EXTERNAL_STORAGE_FIELD; external video", + "video", + DataTypes.BYTES(), + opts); + + assertThat(result).isNotNull(); + assertThat(result.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(result.comment()).isEqualTo("external video"); + assertThat(opts).containsEntry(CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key(), "video"); + assertThat(opts).containsEntry(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), "video"); + } + + @Test + public void testVectorFieldDirective() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__VECTOR_FIELD;128; embedding vector", + "emb", + DataTypes.ARRAY(DataTypes.FLOAT()), + opts); + + assertThat(result).isNotNull(); + assertThat(result.type().getTypeRoot()).isEqualTo(DataTypeRoot.VECTOR); + VectorType vectorType = (VectorType) result.type(); + assertThat(vectorType.getLength()).isEqualTo(128); + assertThat(vectorType.getElementType()).isEqualTo(DataTypes.FLOAT()); + assertThat(result.comment()).isEqualTo("embedding vector"); + 
assertThat(opts).containsEntry(CoreOptions.VECTOR_FIELD.key(), "emb"); + } + + @Test + public void testVectorFieldDirectiveWithoutComment() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__VECTOR_FIELD;64", "emb", DataTypes.ARRAY(DataTypes.DOUBLE()), opts); + + assertThat(result).isNotNull(); + VectorType vectorType = (VectorType) result.type(); + assertThat(vectorType.getLength()).isEqualTo(64); + assertThat(vectorType.getElementType()).isEqualTo(DataTypes.DOUBLE()); + assertThat(result.comment()).isNull(); + } + + @Test + public void testBlobDirectiveAppendsToExistingOption() { + Map opts = new HashMap<>(); + opts.put(CoreOptions.BLOB_FIELD.key(), "existing"); + + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_FIELD", "new_col", DataTypes.BYTES(), opts); + + assertThat(opts).containsEntry(CoreOptions.BLOB_FIELD.key(), "existing,new_col"); + } + + @Test + public void testBareDirectiveWithoutComment() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_FIELD", "col", DataTypes.BYTES(), opts); + + assertThat(result).isNotNull(); + assertThat(result.comment()).isNull(); + } + + @Test + public void testBlobDirectiveWithBlobSourceType() { + Map opts = new HashMap<>(); + ColumnDirectiveUtils.ConvertedColumn result = + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_FIELD", "col", DataTypes.BLOB(), opts); + + assertThat(result).isNotNull(); + assertThat(result.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + } + + // -- applyAddColumnDirective error cases -- + + @Test + public void testBlobDirectiveRejectsNonBinaryType() { + assertThatThrownBy( + () -> + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_FIELD", "col", DataTypes.INT(), new HashMap<>())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be of BYTES, BINARY or BLOB type"); + } + + 
@Test + public void testVectorDirectiveRejectsNonArrayType() { + assertThatThrownBy( + () -> + ColumnDirectiveUtils.applyAddColumnDirective( + "__VECTOR_FIELD;128", + "col", + DataTypes.INT(), + new HashMap<>())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be of ARRAY type"); + } + + @Test + public void testVectorDirectiveRequiresDimension() { + assertThatThrownBy( + () -> + ColumnDirectiveUtils.applyAddColumnDirective( + "__VECTOR_FIELD", + "col", + DataTypes.ARRAY(DataTypes.FLOAT()), + new HashMap<>())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("requires a dimension"); + } + + @Test + public void testVectorDirectiveRejectsNonIntegerDimension() { + assertThatThrownBy( + () -> + ColumnDirectiveUtils.applyAddColumnDirective( + "__VECTOR_FIELD;abc", + "col", + DataTypes.ARRAY(DataTypes.FLOAT()), + new HashMap<>())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Expected an integer dimension"); + } + + @Test + public void testUnknownBlobDirectiveRejected() { + assertThatThrownBy( + () -> + ColumnDirectiveUtils.applyAddColumnDirective( + "__BLOB_UNKNOWN", + "col", + DataTypes.BYTES(), + new HashMap<>())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported BLOB directive"); + } + + @Test + public void testUnknownVectorDirectiveRejected() { + assertThatThrownBy( + () -> + ColumnDirectiveUtils.applyAddColumnDirective( + "__VECTOR_UNKNOWN", + "col", + DataTypes.ARRAY(DataTypes.FLOAT()), + new HashMap<>())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported VECTOR directive"); + } + + // -- applyDirectives (Schema-level, used by CREATE TABLE) -- + + @Test + public void testApplyDirectivesNoDirectives() { + Schema original = + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField(1, "v", DataTypes.STRING()) + }) + .getFields(), + Collections.emptyList(), + 
Collections.emptyList(), + new HashMap<>(), + ""); + + Schema result = ColumnDirectiveUtils.applyDirectives(original); + assertThat(result).isSameAs(original); + } + + @Test + public void testApplyDirectivesMixedFields() { + Map options = new HashMap<>(); + Schema schema = + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField( + 1, + "pic", + DataTypes.BYTES(), + "__BLOB_FIELD; picture"), + new DataField( + 2, + "emb", + DataTypes.ARRAY(DataTypes.FLOAT()), + "__VECTOR_FIELD;64; my embedding"), + new DataField( + 3, "normal", DataTypes.STRING(), "keep me") + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + ""); + + Schema result = ColumnDirectiveUtils.applyDirectives(schema); + + assertThat(result).isNotSameAs(schema); + assertThat(result.fields()).hasSize(4); + + DataField pic = result.fields().get(1); + assertThat(pic.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(pic.description()).isEqualTo("picture"); + + DataField emb = result.fields().get(2); + assertThat(emb.type().getTypeRoot()).isEqualTo(DataTypeRoot.VECTOR); + assertThat(emb.description()).isEqualTo("my embedding"); + + DataField normal = result.fields().get(3); + assertThat(normal.type()).isEqualTo(DataTypes.STRING()); + assertThat(normal.description()).isEqualTo("keep me"); + + assertThat(result.options()).containsEntry(CoreOptions.BLOB_FIELD.key(), "pic"); + assertThat(result.options()).containsEntry(CoreOptions.VECTOR_FIELD.key(), "emb"); + } + + // -- removeDroppedDirectiveOptions -- + + @Test + public void testRemoveDroppedBlobOptions() { + Map opts = new HashMap<>(); + opts.put(CoreOptions.BLOB_FIELD.key(), "a,b"); + opts.put(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), "b,c"); + opts.put(CoreOptions.BLOB_VIEW_FIELD.key(), "b"); + opts.put(CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key(), "b"); + opts.put("blob.stored-descriptor-fields", "b,legacy"); + opts.put(CoreOptions.VECTOR_FIELD.key(), "v"); + 
+ ColumnDirectiveUtils.removeDroppedDirectiveOptions("b", DataTypeRoot.BLOB, opts); + + assertThat(opts).containsEntry(CoreOptions.BLOB_FIELD.key(), "a"); + assertThat(opts).containsEntry(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), "c"); + assertThat(opts).doesNotContainKey(CoreOptions.BLOB_VIEW_FIELD.key()); + assertThat(opts).doesNotContainKey(CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key()); + assertThat(opts).containsEntry("blob.stored-descriptor-fields", "legacy"); + assertThat(opts).containsEntry(CoreOptions.VECTOR_FIELD.key(), "v"); + } + + @Test + public void testRemoveDroppedVectorOptions() { + Map opts = new HashMap<>(); + opts.put(CoreOptions.BLOB_FIELD.key(), "a"); + opts.put(CoreOptions.VECTOR_FIELD.key(), "emb,emb2"); + opts.put("field.emb.vector-dim", "128"); + + ColumnDirectiveUtils.removeDroppedDirectiveOptions("emb", DataTypeRoot.VECTOR, opts); + + assertThat(opts).containsEntry(CoreOptions.BLOB_FIELD.key(), "a"); + assertThat(opts).containsEntry(CoreOptions.VECTOR_FIELD.key(), "emb2"); + assertThat(opts).doesNotContainKey("field.emb.vector-dim"); + } + + @Test + public void testRemoveDroppedNonDirectiveTypeIsNoop() { + Map opts = new HashMap<>(); + opts.put(CoreOptions.BLOB_FIELD.key(), "a"); + opts.put(CoreOptions.VECTOR_FIELD.key(), "v"); + + ColumnDirectiveUtils.removeDroppedDirectiveOptions("x", DataTypeRoot.INTEGER, opts); + + assertThat(opts).containsEntry(CoreOptions.BLOB_FIELD.key(), "a"); + assertThat(opts).containsEntry(CoreOptions.VECTOR_FIELD.key(), "v"); + } +} diff --git a/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java b/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java index ff04a167a7eb..47fc100032f8 100644 --- a/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java @@ -83,7 +83,7 @@ public void testMergeTableSchemas() { DataField f = new DataField(-1, "f", 
fDataType); RowType t = new RowType(Lists.newArrayList(a, b, d, f)); - TableSchema merged = SchemaMergingUtils.mergeSchemas(current, t, false); + TableSchema merged = SchemaMergingUtils.mergeSchemas(current, t, true, false, true); assertThat(merged.id()).isEqualTo(1); assertThat(merged.highestFieldId()).isEqualTo(6); assertThat(merged.primaryKeys()).containsExactlyInAnyOrder("a", "d"); @@ -110,10 +110,40 @@ public void testMergeTableSchemaNotChanges() { // fake the RowType of data with different field sequences RowType t = new RowType(Lists.newArrayList(b, a)); - TableSchema merged = SchemaMergingUtils.mergeSchemas(current, t, false); + TableSchema merged = SchemaMergingUtils.mergeSchemas(current, t, true, false, true); assertThat(merged.id()).isEqualTo(0); } + @Test + public void testMergeSchemasWithoutTypeWidening() { + // typeWidening=false (default): existing column types are kept; only new columns are added. + DataField a = new DataField(0, "a", new IntType()); + DataField b = new DataField(1, "b", new VarCharType(VarCharType.MAX_LENGTH)); + TableSchema current = + new TableSchema( + 0, + Lists.newArrayList(a, b), + 1, + new ArrayList<>(), + Lists.newArrayList("a"), + new HashMap<>(), + ""); + + // Incoming data widens `a` (INT -> BIGINT) and adds a new column `c`. + DataField aWidened = new DataField(-1, "a", new BigIntType()); + DataField c = new DataField(-1, "c", new IntType()); + RowType t = new RowType(Lists.newArrayList(aWidened, b, c)); + + TableSchema merged = SchemaMergingUtils.mergeSchemas(current, t, false, false, true); + List fields = merged.fields(); + assertThat(fields.size()).isEqualTo(3); + // `a` keeps its existing INT type instead of widening to BIGINT. + assertThat(fields.get(0).type()).isEqualTo(new IntType()); + // `c` is appended as a new column. 
+ assertThat(fields.get(2).name()).isEqualTo("c"); + assertThat(fields.get(2).type()).isEqualTo(new IntType()); + } + @Test public void testMergeSchemas() { // This will test both `mergeSchemas` and `merge` methods. @@ -132,7 +162,8 @@ public void testMergeSchemas() { // Case 1: an additional field. DataField e = new DataField(-1, "e", new DateType()); RowType t1 = new RowType(Lists.newArrayList(a, b, c, d, e)); - RowType r1 = (RowType) SchemaMergingUtils.merge(source, t1, highestFieldId, false); + RowType r1 = + (RowType) SchemaMergingUtils.merge(source, t1, highestFieldId, true, false, true); assertThat(highestFieldId.get()).isEqualTo(4); assertThat(r1.isNullable()).isTrue(); assertThat(r1.getFieldCount()).isEqualTo(5); @@ -143,7 +174,7 @@ public void testMergeSchemas() { // Case 2: two missing fields. RowType t2 = new RowType(Lists.newArrayList(a, c, e)); - RowType r2 = SchemaMergingUtils.mergeSchemas(r1, t2, highestFieldId, false); + RowType r2 = SchemaMergingUtils.mergeSchemas(r1, t2, highestFieldId, true, false, true); assertThat(highestFieldId.get()).isEqualTo(4); assertThat(r2.getFieldCount()).isEqualTo(5); assertThat(r2.getTypeAt(3)).isEqualTo(d.type()); @@ -155,7 +186,7 @@ public void testMergeSchemas() { RowType fDataType = new RowType(Lists.newArrayList(f1, f2)); DataField f = new DataField(-1, "f", fDataType); RowType t3 = new RowType(Lists.newArrayList(a, b, c, d, f)); - RowType r3 = (RowType) SchemaMergingUtils.merge(r2, t3, highestFieldId, false); + RowType r3 = (RowType) SchemaMergingUtils.merge(r2, t3, highestFieldId, true, false, true); assertThat(highestFieldId.get()).isEqualTo(7); assertThat(r3.getFieldCount()).isEqualTo(6); RowType expectedFDataType = new RowType(Lists.newArrayList(f1.newId(5), f2.newId(6))); @@ -168,7 +199,7 @@ public void testMergeSchemas() { RowType newFDataType = new RowType(Lists.newArrayList(f1, f2, f3)); DataField newF = new DataField(-1, "f", newFDataType); RowType t4 = new RowType(Lists.newArrayList(a, b, c, d, e, 
newF)); - RowType r4 = SchemaMergingUtils.mergeSchemas(r3, t4, highestFieldId, false); + RowType r4 = SchemaMergingUtils.mergeSchemas(r3, t4, highestFieldId, true, false, true); assertThat(highestFieldId.get()).isEqualTo(8); assertThat(r4.getFieldCount()).isEqualTo(6); RowType newExpectedFDataType = @@ -178,14 +209,15 @@ public void testMergeSchemas() { // Case 5: a field that isn't compatible with the existing one. DataField newA = new DataField(-1, "a", new SmallIntType()); RowType t5 = new RowType(Lists.newArrayList(newA, b, c, d, e, newF)); - assertThatThrownBy(() -> SchemaMergingUtils.merge(r4, t5, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(r4, t5, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // Case 6: all new-coming fields DataField g = new DataField(-1, "g", new TimeType()); DataField h = new DataField(-1, "h", new TimeType()); RowType t6 = new RowType(Lists.newArrayList(g, h)); - RowType r6 = SchemaMergingUtils.mergeSchemas(r4, t6, highestFieldId, false); + RowType r6 = SchemaMergingUtils.mergeSchemas(r4, t6, highestFieldId, true, false, true); assertThat(highestFieldId.get()).isEqualTo(10); assertThat(r6.getFieldCount()).isEqualTo(8); } @@ -261,7 +293,8 @@ public void testDiffNestedSchemaChangesInArrayAndMap() { new HashMap<>(), ""); - List changes = SchemaMergingUtils.diffSchemaChanges(oldSchema, newSchema); + List changes = + SchemaMergingUtils.diffSchemaChanges(oldSchema, newSchema, true); assertThat(changes).hasSize(2); SchemaChange.AddColumn addArrayNestedField = (SchemaChange.AddColumn) changes.get(0); @@ -317,7 +350,8 @@ public void testDiffNestedSchemaChangesDoesNotTreatMapKeyAsValueChange() { new HashMap<>(), ""); - List changes = SchemaMergingUtils.diffSchemaChanges(oldSchema, newSchema); + List changes = + SchemaMergingUtils.diffSchemaChanges(oldSchema, newSchema, true); assertThat(changes).hasSize(1); SchemaChange.UpdateColumnType updateMapType = @@ -360,7 +394,8 @@ 
public void testDiffNestedSchemaChangesFallsBackToTypeUpdateWhenNestedFieldRemov new HashMap<>(), ""); - List changes = SchemaMergingUtils.diffSchemaChanges(oldSchema, newSchema); + List changes = + SchemaMergingUtils.diffSchemaChanges(oldSchema, newSchema, true); assertThat(changes).hasSize(1); SchemaChange.UpdateColumnType updateItemsType = @@ -377,23 +412,29 @@ public void testMergeArrayTypes() { // the element types are same. DataType t1 = new ArrayType(true, new IntType()); - ArrayType r1 = (ArrayType) SchemaMergingUtils.merge(source, t1, highestFieldId, false); + ArrayType r1 = + (ArrayType) SchemaMergingUtils.merge(source, t1, highestFieldId, true, false, true); assertThat(r1.isNullable()).isFalse(); assertThat(r1.getElementType() instanceof IntType).isTrue(); // the element types aren't same, but can be evolved safety. DataType t2 = new ArrayType(true, new BigIntType()); - ArrayType r2 = (ArrayType) SchemaMergingUtils.merge(source, t2, highestFieldId, false); + ArrayType r2 = + (ArrayType) SchemaMergingUtils.merge(source, t2, highestFieldId, true, false, true); assertThat(r2.isNullable()).isFalse(); assertThat(r2.getElementType() instanceof BigIntType).isTrue(); // the element types aren't same, and can't be evolved safety. DataType t3 = new ArrayType(true, new SmallIntType()); - assertThatThrownBy(() -> SchemaMergingUtils.merge(source, t3, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + source, t3, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // the value type of target's isn't same to the source's, but the source type can be cast to // the target type explicitly. 
- ArrayType r3 = (ArrayType) SchemaMergingUtils.merge(source, t3, highestFieldId, true); + ArrayType r3 = + (ArrayType) SchemaMergingUtils.merge(source, t3, highestFieldId, true, true, true); assertThat(r3.isNullable()).isFalse(); assertThat(r3.getElementType() instanceof SmallIntType).isTrue(); } @@ -406,25 +447,31 @@ public void testMergeMapTypes() { // both the key and value types are same to the source's. DataType t1 = new MapType(new VarCharType(VarCharType.MAX_LENGTH), new IntType()); - MapType r1 = (MapType) SchemaMergingUtils.merge(source, t1, highestFieldId, false); + MapType r1 = + (MapType) SchemaMergingUtils.merge(source, t1, highestFieldId, true, false, true); assertThat(r1.isNullable()).isTrue(); assertThat(r1.getKeyType() instanceof VarCharType).isTrue(); assertThat(r1.getValueType() instanceof IntType).isTrue(); // the value type of target's isn't same to the source's, but can be evolved safety. DataType t2 = new MapType(new VarCharType(VarCharType.MAX_LENGTH), new DoubleType()); - MapType r2 = (MapType) SchemaMergingUtils.merge(source, t2, highestFieldId, false); + MapType r2 = + (MapType) SchemaMergingUtils.merge(source, t2, highestFieldId, true, false, true); assertThat(r2.isNullable()).isTrue(); assertThat(r2.getKeyType() instanceof VarCharType).isTrue(); assertThat(r2.getValueType() instanceof DoubleType).isTrue(); // the value type of target's isn't same to the source's, and can't be evolved safety. DataType t3 = new MapType(new VarCharType(VarCharType.MAX_LENGTH), new SmallIntType()); - assertThatThrownBy(() -> SchemaMergingUtils.merge(source, t3, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + source, t3, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // the value type of target's isn't same to the source's, but the source type can be cast to // the target type explicitly. 
- MapType r3 = (MapType) SchemaMergingUtils.merge(source, t3, highestFieldId, true); + MapType r3 = + (MapType) SchemaMergingUtils.merge(source, t3, highestFieldId, true, true, true); assertThat(r3.isNullable()).isTrue(); assertThat(r3.getKeyType() instanceof VarCharType).isTrue(); assertThat(r3.getValueType() instanceof SmallIntType).isTrue(); @@ -439,24 +486,31 @@ public void testMergeMultisetTypes() { // the element types are same. DataType t1 = new MultisetType(true, new IntType()); MultisetType r1 = - (MultisetType) SchemaMergingUtils.merge(source, t1, highestFieldId, false); + (MultisetType) + SchemaMergingUtils.merge(source, t1, highestFieldId, true, false, true); assertThat(r1.isNullable()).isFalse(); assertThat(r1.getElementType() instanceof IntType).isTrue(); // the element types aren't same, but can be evolved safety. DataType t2 = new MultisetType(true, new BigIntType()); MultisetType r2 = - (MultisetType) SchemaMergingUtils.merge(source, t2, highestFieldId, false); + (MultisetType) + SchemaMergingUtils.merge(source, t2, highestFieldId, true, false, true); assertThat(r2.isNullable()).isFalse(); assertThat(r2.getElementType() instanceof BigIntType).isTrue(); // the element types aren't same, and can't be evolved safety. DataType t3 = new MultisetType(true, new SmallIntType()); - assertThatThrownBy(() -> SchemaMergingUtils.merge(source, t3, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + source, t3, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // the value type of target's isn't same to the source's, but the source type can be cast to // the target type explicitly. 
- MultisetType r3 = (MultisetType) SchemaMergingUtils.merge(source, t3, highestFieldId, true); + MultisetType r3 = + (MultisetType) + SchemaMergingUtils.merge(source, t3, highestFieldId, true, true, true); assertThat(r3.isNullable()).isFalse(); assertThat(r3.getElementType() instanceof SmallIntType).isTrue(); } @@ -467,26 +521,30 @@ public void testMergeDecimalTypes() { DataType s1 = new DecimalType(); DataType t1 = new DecimalType(10, 0); - DecimalType r1 = (DecimalType) SchemaMergingUtils.merge(s1, t1, highestFieldId, false); + DecimalType r1 = + (DecimalType) SchemaMergingUtils.merge(s1, t1, highestFieldId, true, false, true); assertThat(r1.isNullable()).isTrue(); assertThat(r1.getPrecision()).isEqualTo(DecimalType.DEFAULT_PRECISION); assertThat(r1.getScale()).isEqualTo(DecimalType.DEFAULT_SCALE); DataType s2 = new DecimalType(5, 2); DataType t2 = new DecimalType(7, 3); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s2, t2, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s2, t2, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); DataType s3 = new DecimalType(false, 5, 2); DataType t3 = new DecimalType(7, 2); - DecimalType r3 = (DecimalType) SchemaMergingUtils.merge(s3, t3, highestFieldId, false); + DecimalType r3 = + (DecimalType) SchemaMergingUtils.merge(s3, t3, highestFieldId, true, false, true); assertThat(r3.isNullable()).isFalse(); assertThat(r3.getPrecision()).isEqualTo(7); assertThat(r3.getScale()).isEqualTo(2); DataType s4 = new DecimalType(7, 2); DataType t4 = new DecimalType(5, 2); - DecimalType r4 = (DecimalType) SchemaMergingUtils.merge(s4, t4, highestFieldId, false); + DecimalType r4 = + (DecimalType) SchemaMergingUtils.merge(s4, t4, highestFieldId, true, false, true); assertThat(r4.isNullable()).isTrue(); assertThat(r4.getPrecision()).isEqualTo(7); assertThat(r4.getScale()).isEqualTo(2); @@ -495,10 +553,13 @@ public void testMergeDecimalTypes() { DataType dcmSource = new 
DecimalType(); DataType iTarget = new IntType(); assertThatThrownBy( - () -> SchemaMergingUtils.merge(dcmSource, iTarget, highestFieldId, false)) + () -> + SchemaMergingUtils.merge( + dcmSource, iTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // DecimalType -> Other Numeric Type with allowExplicitCast = true - DataType res = SchemaMergingUtils.merge(dcmSource, iTarget, highestFieldId, true); + DataType res = + SchemaMergingUtils.merge(dcmSource, iTarget, highestFieldId, true, true, true); assertThat(res instanceof IntType).isTrue(); } @@ -509,118 +570,141 @@ public void testMergeTypesWithLength() { // BinaryType DataType s1 = new BinaryType(10); DataType t1 = new BinaryType(10); - BinaryType r1 = (BinaryType) SchemaMergingUtils.merge(s1, t1, highestFieldId, false); + BinaryType r1 = + (BinaryType) SchemaMergingUtils.merge(s1, t1, highestFieldId, true, false, true); assertThat(r1.getLength()).isEqualTo(10); DataType s2 = new BinaryType(2); DataType t2 = new BinaryType(); // smaller length - assertThatThrownBy(() -> SchemaMergingUtils.merge(s2, t2, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s2, t2, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // smaller length with allowExplicitCast = true - BinaryType r2 = (BinaryType) SchemaMergingUtils.merge(s2, t2, highestFieldId, true); + BinaryType r2 = + (BinaryType) SchemaMergingUtils.merge(s2, t2, highestFieldId, true, true, true); assertThat(r2.getLength()).isEqualTo(BinaryType.DEFAULT_LENGTH); // bigger length DataType t3 = new BinaryType(5); - BinaryType r3 = (BinaryType) SchemaMergingUtils.merge(s2, t3, highestFieldId, false); + BinaryType r3 = + (BinaryType) SchemaMergingUtils.merge(s2, t3, highestFieldId, true, false, true); assertThat(r3.getLength()).isEqualTo(5); // VarCharType DataType s4 = new VarCharType(); DataType t4 = new VarCharType(1); - VarCharType r4 = (VarCharType) 
SchemaMergingUtils.merge(s4, t4, highestFieldId, false); + VarCharType r4 = + (VarCharType) SchemaMergingUtils.merge(s4, t4, highestFieldId, true, false, true); assertThat(r4.getLength()).isEqualTo(VarCharType.DEFAULT_LENGTH); DataType s5 = new VarCharType(2); DataType t5 = new VarCharType(); // smaller length - assertThatThrownBy(() -> SchemaMergingUtils.merge(s5, t5, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s5, t5, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // smaller length with allowExplicitCast = true - VarCharType r5 = (VarCharType) SchemaMergingUtils.merge(s5, t5, highestFieldId, true); + VarCharType r5 = + (VarCharType) SchemaMergingUtils.merge(s5, t5, highestFieldId, true, true, true); assertThat(r5.getLength()).isEqualTo(VarCharType.DEFAULT_LENGTH); // bigger length DataType t6 = new VarCharType(5); - VarCharType r6 = (VarCharType) SchemaMergingUtils.merge(s5, t6, highestFieldId, false); + VarCharType r6 = + (VarCharType) SchemaMergingUtils.merge(s5, t6, highestFieldId, true, false, true); assertThat(r6.getLength()).isEqualTo(5); // CharType DataType s7 = new CharType(); DataType t7 = new CharType(1); - CharType r7 = (CharType) SchemaMergingUtils.merge(s7, t7, highestFieldId, false); + CharType r7 = + (CharType) SchemaMergingUtils.merge(s7, t7, highestFieldId, true, false, true); assertThat(r7.getLength()).isEqualTo(CharType.DEFAULT_LENGTH); DataType s8 = new CharType(2); DataType t8 = new CharType(); // smaller length - assertThatThrownBy(() -> SchemaMergingUtils.merge(s8, t8, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s8, t8, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // smaller length with allowExplicitCast = true - CharType r8 = (CharType) SchemaMergingUtils.merge(s8, t8, highestFieldId, true); + CharType r8 = (CharType) SchemaMergingUtils.merge(s8, t8, highestFieldId, true, true, true); 
assertThat(r8.getLength()).isEqualTo(CharType.DEFAULT_LENGTH); // bigger length DataType t9 = new CharType(5); - CharType r9 = (CharType) SchemaMergingUtils.merge(s8, t9, highestFieldId, false); + CharType r9 = + (CharType) SchemaMergingUtils.merge(s8, t9, highestFieldId, true, false, true); assertThat(r9.getLength()).isEqualTo(5); // VarBinaryType DataType s10 = new VarBinaryType(); DataType t10 = new VarBinaryType(1); VarBinaryType r10 = - (VarBinaryType) SchemaMergingUtils.merge(s10, t10, highestFieldId, false); + (VarBinaryType) + SchemaMergingUtils.merge(s10, t10, highestFieldId, true, false, true); assertThat(r10.getLength()).isEqualTo(VarBinaryType.DEFAULT_LENGTH); DataType s11 = new VarBinaryType(2); DataType t11 = new VarBinaryType(); // smaller length - assertThatThrownBy(() -> SchemaMergingUtils.merge(s11, t11, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s11, t11, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // smaller length with allowExplicitCast = true VarBinaryType r11 = - (VarBinaryType) SchemaMergingUtils.merge(s11, t11, highestFieldId, true); + (VarBinaryType) + SchemaMergingUtils.merge(s11, t11, highestFieldId, true, true, true); assertThat(r11.getLength()).isEqualTo(VarBinaryType.DEFAULT_LENGTH); // bigger length DataType t12 = new VarBinaryType(5); VarBinaryType r12 = - (VarBinaryType) SchemaMergingUtils.merge(s11, t12, highestFieldId, false); + (VarBinaryType) + SchemaMergingUtils.merge(s11, t12, highestFieldId, true, false, true); assertThat(r12.getLength()).isEqualTo(5); // CharType -> VarCharType DataType s13 = new CharType(); DataType t13 = new VarCharType(10); - VarCharType r13 = (VarCharType) SchemaMergingUtils.merge(s13, t13, highestFieldId, false); + VarCharType r13 = + (VarCharType) SchemaMergingUtils.merge(s13, t13, highestFieldId, true, false, true); assertThat(r13.getLength()).isEqualTo(10); // VarCharType ->CharType DataType s14 = new VarCharType(10); 
DataType t14 = new CharType(); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s14, t14, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s14, t14, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); - CharType r14 = (CharType) SchemaMergingUtils.merge(s14, t14, highestFieldId, true); + CharType r14 = + (CharType) SchemaMergingUtils.merge(s14, t14, highestFieldId, true, true, true); assertThat(r14.getLength()).isEqualTo(CharType.DEFAULT_LENGTH); // BinaryType -> VarBinaryType DataType s15 = new BinaryType(); DataType t15 = new VarBinaryType(10); VarBinaryType r15 = - (VarBinaryType) SchemaMergingUtils.merge(s15, t15, highestFieldId, false); + (VarBinaryType) + SchemaMergingUtils.merge(s15, t15, highestFieldId, true, false, true); assertThat(r15.getLength()).isEqualTo(10); // VarBinaryType -> BinaryType DataType s16 = new VarBinaryType(10); DataType t16 = new BinaryType(); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s16, t16, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s16, t16, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); - BinaryType r16 = (BinaryType) SchemaMergingUtils.merge(s16, t16, highestFieldId, true); + BinaryType r16 = + (BinaryType) SchemaMergingUtils.merge(s16, t16, highestFieldId, true, true, true); assertThat(r16.getLength()).isEqualTo(BinaryType.DEFAULT_LENGTH); // VarCharType -> VarBinaryType DataType s17 = new VarCharType(10); DataType t17 = new VarBinaryType(); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s17, t17, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s17, t17, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); VarBinaryType r17 = - (VarBinaryType) SchemaMergingUtils.merge(s17, t17, highestFieldId, true); + (VarBinaryType) + SchemaMergingUtils.merge(s17, t17, highestFieldId, true, true, true); 
assertThat(r17.getLength()).isEqualTo(VarBinaryType.DEFAULT_LENGTH); } @@ -632,7 +716,8 @@ public void testMergeTypesWithPrecision() { DataType s1 = new LocalZonedTimestampType(); DataType t1 = new LocalZonedTimestampType(); LocalZonedTimestampType r1 = - (LocalZonedTimestampType) SchemaMergingUtils.merge(s1, t1, highestFieldId, false); + (LocalZonedTimestampType) + SchemaMergingUtils.merge(s1, t1, highestFieldId, true, false, true); assertThat(r1.isNullable()).isTrue(); assertThat(r1.getPrecision()).isEqualTo(LocalZonedTimestampType.DEFAULT_PRECISION); @@ -640,31 +725,40 @@ public void testMergeTypesWithPrecision() { assertThatThrownBy( () -> SchemaMergingUtils.merge( - s1, new LocalZonedTimestampType(3), highestFieldId, false)) + s1, + new LocalZonedTimestampType(3), + highestFieldId, + true, + false, + true)) .isInstanceOf(UnsupportedOperationException.class); // higher precision DataType t2 = new LocalZonedTimestampType(6); LocalZonedTimestampType r2 = - (LocalZonedTimestampType) SchemaMergingUtils.merge(s1, t2, highestFieldId, false); + (LocalZonedTimestampType) + SchemaMergingUtils.merge(s1, t2, highestFieldId, true, false, true); assertThat(r2.getPrecision()).isEqualTo(6); // LocalZonedTimestampType -> TimeType DataType s3 = new LocalZonedTimestampType(); DataType t3 = new TimeType(6); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s3, t3, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s3, t3, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // LocalZonedTimestampType -> TimestampType DataType s4 = new LocalZonedTimestampType(); DataType t4 = new TimestampType(); - TimestampType r4 = (TimestampType) SchemaMergingUtils.merge(s4, t4, highestFieldId, false); + TimestampType r4 = + (TimestampType) SchemaMergingUtils.merge(s4, t4, highestFieldId, true, false, true); assertThat(r4.getPrecision()).isEqualTo(TimestampType.DEFAULT_PRECISION); // TimestampType. 
DataType s5 = new TimestampType(); DataType t5 = new TimestampType(); - TimestampType r5 = (TimestampType) SchemaMergingUtils.merge(s5, t5, highestFieldId, false); + TimestampType r5 = + (TimestampType) SchemaMergingUtils.merge(s5, t5, highestFieldId, true, false, true); assertThat(r5.isNullable()).isTrue(); assertThat(r5.getPrecision()).isEqualTo(TimestampType.DEFAULT_PRECISION); @@ -672,65 +766,81 @@ s1, new LocalZonedTimestampType(3), highestFieldId, false)) assertThatThrownBy( () -> SchemaMergingUtils.merge( - s5, new TimestampType(3), highestFieldId, false)) + s5, + new TimestampType(3), + highestFieldId, + true, + false, + true)) .isInstanceOf(UnsupportedOperationException.class); // higher precision DataType t6 = new TimestampType(9); - TimestampType r6 = (TimestampType) SchemaMergingUtils.merge(s5, t6, highestFieldId, false); + TimestampType r6 = + (TimestampType) SchemaMergingUtils.merge(s5, t6, highestFieldId, true, false, true); assertThat(r6.getPrecision()).isEqualTo(9); // TimestampType -> LocalZonedTimestampType DataType s7 = new TimestampType(); DataType t7 = new LocalZonedTimestampType(); LocalZonedTimestampType r7 = - (LocalZonedTimestampType) SchemaMergingUtils.merge(s7, t7, highestFieldId, false); + (LocalZonedTimestampType) + SchemaMergingUtils.merge(s7, t7, highestFieldId, true, false, true); assertThat(r7.getPrecision()).isEqualTo(TimestampType.DEFAULT_PRECISION); // TimestampType -> TimestampType DataType s8 = new TimestampType(); DataType t8 = new TimeType(6); - TimeType r8 = (TimeType) SchemaMergingUtils.merge(s8, t8, highestFieldId, false); + TimeType r8 = + (TimeType) SchemaMergingUtils.merge(s8, t8, highestFieldId, true, false, true); assertThat(r8.getPrecision()).isEqualTo(TimestampType.DEFAULT_PRECISION); // TimeType. 
DataType s9 = new TimeType(); DataType t9 = new TimeType(); - TimeType r9 = (TimeType) SchemaMergingUtils.merge(s9, t9, highestFieldId, false); + TimeType r9 = + (TimeType) SchemaMergingUtils.merge(s9, t9, highestFieldId, true, false, true); assertThat(r9.isNullable()).isTrue(); assertThat(r9.getPrecision()).isEqualTo(TimeType.DEFAULT_PRECISION); // lower precision DataType s10 = new TimeType(6); assertThatThrownBy( - () -> SchemaMergingUtils.merge(s10, new TimeType(3), highestFieldId, false)) + () -> + SchemaMergingUtils.merge( + s10, new TimeType(3), highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // higher precision DataType t10 = new TimeType(9); - TimeType r10 = (TimeType) SchemaMergingUtils.merge(s9, t10, highestFieldId, false); + TimeType r10 = + (TimeType) SchemaMergingUtils.merge(s9, t10, highestFieldId, true, false, true); assertThat(r10.getPrecision()).isEqualTo(9); // TimeType -> LocalZonedTimestampType DataType s11 = new TimeType(); DataType t11 = new LocalZonedTimestampType(); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s11, t11, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s11, t11, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // TimeType -> LocalZonedTimestampType with allowExplicitCast = true LocalZonedTimestampType r11 = - (LocalZonedTimestampType) SchemaMergingUtils.merge(s11, t11, highestFieldId, true); + (LocalZonedTimestampType) + SchemaMergingUtils.merge(s11, t11, highestFieldId, true, true, true); assertThat(r11.getPrecision()).isEqualTo(LocalZonedTimestampType.DEFAULT_PRECISION); // TimeType -> TimestampType DataType s12 = new TimeType(); DataType t12 = new TimestampType(); - assertThatThrownBy(() -> SchemaMergingUtils.merge(s12, t12, highestFieldId, false)) + assertThatThrownBy( + () -> SchemaMergingUtils.merge(s12, t12, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // 
TimeType -> TimestampType with allowExplicitCast = true TimestampType r12 = - (TimestampType) SchemaMergingUtils.merge(s12, t12, highestFieldId, true); + (TimestampType) + SchemaMergingUtils.merge(s12, t12, highestFieldId, true, true, true); assertThat(r12.getPrecision()).isEqualTo(TimestampType.DEFAULT_PRECISION); } @@ -756,186 +866,274 @@ public void testMergePrimitiveTypes() { DataType dcmTarget = new DecimalType(); // BooleanType - DataType btRes1 = SchemaMergingUtils.merge(bSource, bTarget, highestFieldId, false); + DataType btRes1 = + SchemaMergingUtils.merge(bSource, bTarget, highestFieldId, true, false, true); assertThat(btRes1 instanceof BooleanType).isTrue(); // BooleanType -> Numeric Type - assertThatThrownBy(() -> SchemaMergingUtils.merge(bSource, tiTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + bSource, tiTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // BooleanType -> Numeric Type with allowExplicitCast = true - DataType btRes2 = SchemaMergingUtils.merge(bSource, tiTarget, highestFieldId, true); + DataType btRes2 = + SchemaMergingUtils.merge(bSource, tiTarget, highestFieldId, true, true, true); assertThat(btRes2 instanceof TinyIntType).isTrue(); // TinyIntType - DataType tiRes1 = SchemaMergingUtils.merge(tiSource, tiTarget, highestFieldId, false); + DataType tiRes1 = + SchemaMergingUtils.merge(tiSource, tiTarget, highestFieldId, true, false, true); assertThat(tiRes1 instanceof TinyIntType).isTrue(); // TinyIntType -> SmallIntType - DataType tiRes2 = SchemaMergingUtils.merge(tiSource, siTarget, highestFieldId, false); + DataType tiRes2 = + SchemaMergingUtils.merge(tiSource, siTarget, highestFieldId, true, false, true); assertThat(tiRes2 instanceof SmallIntType).isTrue(); // TinyIntType -> IntType - DataType tiRes3 = SchemaMergingUtils.merge(tiSource, iTarget, highestFieldId, false); + DataType tiRes3 = + SchemaMergingUtils.merge(tiSource, iTarget, 
highestFieldId, true, false, true); assertThat(tiRes3 instanceof IntType).isTrue(); // TinyIntType -> BigIntType - DataType tiRes4 = SchemaMergingUtils.merge(tiSource, biTarget, highestFieldId, false); + DataType tiRes4 = + SchemaMergingUtils.merge(tiSource, biTarget, highestFieldId, true, false, true); assertThat(tiRes4 instanceof BigIntType).isTrue(); // TinyIntType -> FloatType - DataType tiRes5 = SchemaMergingUtils.merge(tiSource, fTarget, highestFieldId, false); + DataType tiRes5 = + SchemaMergingUtils.merge(tiSource, fTarget, highestFieldId, true, false, true); assertThat(tiRes5 instanceof FloatType).isTrue(); // TinyIntType -> DoubleType - DataType tiRes6 = SchemaMergingUtils.merge(tiSource, dTarget, highestFieldId, false); + DataType tiRes6 = + SchemaMergingUtils.merge(tiSource, dTarget, highestFieldId, true, false, true); assertThat(tiRes6 instanceof DoubleType).isTrue(); // TinyIntType -> DecimalType - DataType tiRes7 = SchemaMergingUtils.merge(tiSource, dcmTarget, highestFieldId, false); + DataType tiRes7 = + SchemaMergingUtils.merge(tiSource, dcmTarget, highestFieldId, true, false, true); assertThat(tiRes7 instanceof DecimalType).isTrue(); // TinyIntType -> BooleanType - assertThatThrownBy(() -> SchemaMergingUtils.merge(tiSource, bTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + tiSource, bTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // TinyIntType -> BooleanType with allowExplicitCast = true - DataType tiRes8 = SchemaMergingUtils.merge(tiSource, bTarget, highestFieldId, true); + DataType tiRes8 = + SchemaMergingUtils.merge(tiSource, bTarget, highestFieldId, true, true, true); assertThat(tiRes8 instanceof BooleanType).isTrue(); // SmallIntType - DataType siRes1 = SchemaMergingUtils.merge(siSource, siTarget, highestFieldId, false); + DataType siRes1 = + SchemaMergingUtils.merge(siSource, siTarget, highestFieldId, true, false, true); assertThat(siRes1 instanceof 
SmallIntType).isTrue(); // SmallIntType -> TinyIntType assertThatThrownBy( - () -> SchemaMergingUtils.merge(siSource, tiTarget, highestFieldId, false)) + () -> + SchemaMergingUtils.merge( + siSource, tiTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // SmallIntType -> TinyIntType with allowExplicitCast = true - DataType siRes2 = SchemaMergingUtils.merge(siSource, tiTarget, highestFieldId, true); + DataType siRes2 = + SchemaMergingUtils.merge(siSource, tiTarget, highestFieldId, true, true, true); assertThat(siRes2 instanceof TinyIntType).isTrue(); // SmallIntType -> IntType - DataType siRes3 = SchemaMergingUtils.merge(siSource, iTarget, highestFieldId, false); + DataType siRes3 = + SchemaMergingUtils.merge(siSource, iTarget, highestFieldId, true, false, true); assertThat(siRes3 instanceof IntType).isTrue(); // SmallIntType -> BigIntType - DataType siRes4 = SchemaMergingUtils.merge(siSource, biTarget, highestFieldId, false); + DataType siRes4 = + SchemaMergingUtils.merge(siSource, biTarget, highestFieldId, true, false, true); assertThat(siRes4 instanceof BigIntType).isTrue(); // SmallIntType -> FloatType - DataType siRes5 = SchemaMergingUtils.merge(siSource, fTarget, highestFieldId, false); + DataType siRes5 = + SchemaMergingUtils.merge(siSource, fTarget, highestFieldId, true, false, true); assertThat(siRes5 instanceof FloatType).isTrue(); // SmallIntType -> DoubleType - DataType siRes6 = SchemaMergingUtils.merge(siSource, dTarget, highestFieldId, false); + DataType siRes6 = + SchemaMergingUtils.merge(siSource, dTarget, highestFieldId, true, false, true); assertThat(siRes6 instanceof DoubleType).isTrue(); // SmallIntType -> DecimalType - DataType siRes7 = SchemaMergingUtils.merge(siSource, dcmTarget, highestFieldId, false); + DataType siRes7 = + SchemaMergingUtils.merge(siSource, dcmTarget, highestFieldId, true, false, true); assertThat(siRes7 instanceof DecimalType).isTrue(); // IntType - DataType iRes1 = 
SchemaMergingUtils.merge(iSource, iTarget, highestFieldId, false); + DataType iRes1 = + SchemaMergingUtils.merge(iSource, iTarget, highestFieldId, true, false, true); assertThat(iRes1 instanceof IntType).isTrue(); // IntType -> TinyIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(iSource, tiTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + iSource, tiTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // IntType -> TinyIntType with allowExplicitCast = true - DataType iRes2 = SchemaMergingUtils.merge(iSource, tiTarget, highestFieldId, true); + DataType iRes2 = + SchemaMergingUtils.merge(iSource, tiTarget, highestFieldId, true, true, true); assertThat(iRes2 instanceof TinyIntType).isTrue(); // IntType -> SmallIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(iSource, siTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + iSource, siTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // IntType -> SmallIntType with allowExplicitCast = true - DataType iRes3 = SchemaMergingUtils.merge(iSource, siTarget, highestFieldId, true); + DataType iRes3 = + SchemaMergingUtils.merge(iSource, siTarget, highestFieldId, true, true, true); assertThat(iRes3 instanceof SmallIntType).isTrue(); // IntType -> BigIntType - DataType iRes4 = SchemaMergingUtils.merge(iSource, biTarget, highestFieldId, false); + DataType iRes4 = + SchemaMergingUtils.merge(iSource, biTarget, highestFieldId, true, false, true); assertThat(iRes4 instanceof BigIntType).isTrue(); // IntType -> FloatType - DataType iRes5 = SchemaMergingUtils.merge(iSource, fTarget, highestFieldId, false); + DataType iRes5 = + SchemaMergingUtils.merge(iSource, fTarget, highestFieldId, true, false, true); assertThat(iRes5 instanceof FloatType).isTrue(); // IntType -> DoubleType - DataType iRes6 = SchemaMergingUtils.merge(iSource, dTarget, 
highestFieldId, false); + DataType iRes6 = + SchemaMergingUtils.merge(iSource, dTarget, highestFieldId, true, false, true); assertThat(iRes6 instanceof DoubleType).isTrue(); // IntType -> DecimalType - DataType iRes7 = SchemaMergingUtils.merge(iSource, dcmTarget, highestFieldId, false); + DataType iRes7 = + SchemaMergingUtils.merge(iSource, dcmTarget, highestFieldId, true, false, true); assertThat(iRes7 instanceof DecimalType).isTrue(); // BigIntType - DataType biRes1 = SchemaMergingUtils.merge(biSource, biTarget, highestFieldId, false); + DataType biRes1 = + SchemaMergingUtils.merge(biSource, biTarget, highestFieldId, true, false, true); assertThat(biRes1 instanceof BigIntType).isTrue(); // BigIntType -> TinyIntType assertThatThrownBy( - () -> SchemaMergingUtils.merge(biSource, tiTarget, highestFieldId, false)) + () -> + SchemaMergingUtils.merge( + biSource, tiTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // BigIntType -> TinyIntType with allowExplicitCast = true - DataType biRes2 = SchemaMergingUtils.merge(biSource, tiTarget, highestFieldId, true); + DataType biRes2 = + SchemaMergingUtils.merge(biSource, tiTarget, highestFieldId, true, true, true); assertThat(biRes2 instanceof TinyIntType).isTrue(); // BigIntType -> SmallIntType assertThatThrownBy( - () -> SchemaMergingUtils.merge(biSource, siTarget, highestFieldId, false)) + () -> + SchemaMergingUtils.merge( + biSource, siTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // BigIntType -> SmallIntType with allowExplicitCast = true - DataType biRes3 = SchemaMergingUtils.merge(biSource, siTarget, highestFieldId, true); + DataType biRes3 = + SchemaMergingUtils.merge(biSource, siTarget, highestFieldId, true, true, true); assertThat(biRes3 instanceof SmallIntType).isTrue(); // BigIntType -> IntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(biSource, iTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + 
SchemaMergingUtils.merge( + biSource, iTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // BigIntType -> IntType with allowExplicitCast = true - DataType biRes4 = SchemaMergingUtils.merge(biSource, iTarget, highestFieldId, true); + DataType biRes4 = + SchemaMergingUtils.merge(biSource, iTarget, highestFieldId, true, true, true); assertThat(biRes4 instanceof IntType).isTrue(); // BigIntType -> FloatType - DataType biRes5 = SchemaMergingUtils.merge(biSource, fTarget, highestFieldId, false); + DataType biRes5 = + SchemaMergingUtils.merge(biSource, fTarget, highestFieldId, true, false, true); assertThat(biRes5 instanceof FloatType).isTrue(); // BigIntType -> DoubleType - DataType biRes6 = SchemaMergingUtils.merge(biSource, dTarget, highestFieldId, false); + DataType biRes6 = + SchemaMergingUtils.merge(biSource, dTarget, highestFieldId, true, false, true); assertThat(biRes6 instanceof DoubleType).isTrue(); // BigIntType -> DecimalType - DataType biRes7 = SchemaMergingUtils.merge(biSource, dcmTarget, highestFieldId, false); + DataType biRes7 = + SchemaMergingUtils.merge(biSource, dcmTarget, highestFieldId, true, false, true); assertThat(biRes7 instanceof DecimalType).isTrue(); // FloatType - DataType fRes1 = SchemaMergingUtils.merge(fSource, fTarget, highestFieldId, false); + DataType fRes1 = + SchemaMergingUtils.merge(fSource, fTarget, highestFieldId, true, false, true); assertThat(fRes1 instanceof FloatType).isTrue(); // FloatType -> TinyIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(fSource, tiTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + fSource, tiTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // FloatType -> TinyIntType with allowExplicitCast = true - DataType fRes2 = SchemaMergingUtils.merge(fSource, tiTarget, highestFieldId, true); + DataType fRes2 = + SchemaMergingUtils.merge(fSource, tiTarget, highestFieldId, 
true, true, true); assertThat(fRes2 instanceof TinyIntType).isTrue(); // FloatType -> SmallIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(fSource, siTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + fSource, siTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // FloatType -> IntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(fSource, iTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + fSource, iTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // FloatType -> IntType with allowExplicitCast = true - DataType fRes4 = SchemaMergingUtils.merge(fSource, iTarget, highestFieldId, true); + DataType fRes4 = + SchemaMergingUtils.merge(fSource, iTarget, highestFieldId, true, true, true); assertThat(fRes4 instanceof IntType).isTrue(); // FloatType -> BigIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(fSource, biTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + fSource, biTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // FloatType -> DoubleType - DataType fRes6 = SchemaMergingUtils.merge(fSource, dTarget, highestFieldId, false); + DataType fRes6 = + SchemaMergingUtils.merge(fSource, dTarget, highestFieldId, true, false, true); assertThat(fRes6 instanceof DoubleType).isTrue(); // FloatType -> DecimalType - DataType fRes7 = SchemaMergingUtils.merge(fSource, dcmTarget, highestFieldId, false); + DataType fRes7 = + SchemaMergingUtils.merge(fSource, dcmTarget, highestFieldId, true, false, true); assertThat(fRes7 instanceof DecimalType).isTrue(); // DoubleType - DataType dRes1 = SchemaMergingUtils.merge(dSource, dTarget, highestFieldId, false); + DataType dRes1 = + SchemaMergingUtils.merge(dSource, dTarget, highestFieldId, true, false, true); assertThat(dRes1 instanceof 
DoubleType).isTrue(); // DoubleType -> TinyIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(dSource, tiTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + dSource, tiTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // DoubleType -> SmallIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(dSource, siTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + dSource, siTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // DoubleType -> SmallIntType with allowExplicitCast = true - DataType dRes3 = SchemaMergingUtils.merge(dSource, siTarget, highestFieldId, true); + DataType dRes3 = + SchemaMergingUtils.merge(dSource, siTarget, highestFieldId, true, true, true); assertThat(dRes3 instanceof SmallIntType).isTrue(); // DoubleType -> IntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(dSource, iTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + dSource, iTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // DoubleType -> BigIntType - assertThatThrownBy(() -> SchemaMergingUtils.merge(dSource, biTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + dSource, biTarget, highestFieldId, true, false, true)) .isInstanceOf(UnsupportedOperationException.class); // DoubleType -> BigIntType with allowExplicitCast = true - DataType dRes5 = SchemaMergingUtils.merge(dSource, biTarget, highestFieldId, true); + DataType dRes5 = + SchemaMergingUtils.merge(dSource, biTarget, highestFieldId, true, true, true); assertThat(dRes5 instanceof BigIntType).isTrue(); // DoubleType -> FloatType - assertThatThrownBy(() -> SchemaMergingUtils.merge(dSource, fTarget, highestFieldId, false)) + assertThatThrownBy( + () -> + SchemaMergingUtils.merge( + dSource, fTarget, highestFieldId, true, false, 
true)) .isInstanceOf(UnsupportedOperationException.class); // DoubleType -> DecimalType - DataType dRes7 = SchemaMergingUtils.merge(dSource, dcmTarget, highestFieldId, false); + DataType dRes7 = + SchemaMergingUtils.merge(dSource, dcmTarget, highestFieldId, true, false, true); assertThat(dRes7 instanceof DecimalType).isTrue(); } } diff --git a/paimon-core/src/test/java/org/apache/paimon/schema/SchemaValidationTest.java b/paimon-core/src/test/java/org/apache/paimon/schema/SchemaValidationTest.java index c3a79d91fdf1..6bda43769232 100644 --- a/paimon-core/src/test/java/org/apache/paimon/schema/SchemaValidationTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/schema/SchemaValidationTest.java @@ -40,6 +40,8 @@ import static org.apache.paimon.CoreOptions.VECTOR_FIELD; import static org.apache.paimon.CoreOptions.VECTOR_FILE_FORMAT; import static org.apache.paimon.schema.SchemaValidation.validateTableSchema; +import static org.apache.paimon.schema.TableSchema.CURRENT_VERSION; +import static org.apache.paimon.schema.TableSchema.PAIMON_07_VERSION; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -468,6 +470,55 @@ private void validateTableSchemaWithMapField(Map options) { new TableSchema(1, fields, 10, emptyList(), singletonList("f1"), options, "")); } + @Test + public void testSnapshotSequenceOrderingHappyPath() { + Map options = new HashMap<>(); + options.put(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key(), "true"); + options.put(CoreOptions.WRITE_ONLY.key(), "true"); + assertThatNoException().isThrownBy(() -> validateTableSchemaExec(options)); + } + + @Test + public void testSnapshotSequenceOrderingRejectsNonWriteOnly() { + Map options = new HashMap<>(); + options.put(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key(), "true"); + assertThatThrownBy(() -> validateTableSchemaExec(options)) + 
.hasMessageContaining(CoreOptions.WRITE_ONLY.key()); + } + + @Test + public void testSnapshotSequenceOrderingRejectsSequenceField() { + Map options = new HashMap<>(); + options.put(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key(), "true"); + options.put(CoreOptions.WRITE_ONLY.key(), "true"); + options.put(CoreOptions.SEQUENCE_FIELD.key(), "f2"); + assertThatThrownBy(() -> validateTableSchemaExec(options)) + .hasMessageContaining("sequence.field"); + } + + @Test + public void testSnapshotSequenceOrderingRejectsNonPkTable() { + List fields = + Arrays.asList( + new DataField(0, "f0", DataTypes.INT()), + new DataField(1, "f1", DataTypes.INT())); + Map options = new HashMap<>(); + options.put(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING.key(), "true"); + options.put(BUCKET.key(), String.valueOf(-1)); + assertThatThrownBy( + () -> + validateTableSchema( + new TableSchema( + 1, + fields, + 10, + emptyList(), + emptyList(), + options, + ""))) + .hasMessageContaining("primary-key"); + } + @Test public void testFileFormatPerLevelRejectsIncompatibleSchema() { List fields = @@ -509,6 +560,68 @@ public void testFileFormatPerLevelAcceptsCompatibleSchema() { new TableSchema(1, fields, 10, emptyList(), singletonList("k"), options, "")); } + @Test + void testManifestSortValidation() { + List fields = + Arrays.asList( + new DataField(0, "f0", DataTypes.INT()), + new DataField(1, "f1", DataTypes.INT())); + + // Test 1: manifest-sort.enabled on non-partition table should fail + Map options1 = new HashMap<>(); + options1.put(CoreOptions.MANIFEST_SORT_ENABLED.key(), "true"); + options1.put(BUCKET.key(), String.valueOf(-1)); + assertThatThrownBy( + () -> + validateTableSchema( + new TableSchema( + 1, + fields, + 10, + emptyList(), + emptyList(), + options1, + ""))) + .hasMessageContaining( + "Cannot enable 'manifest-sort.enabled' for non-partition table."); + + // Test 2: manifest-sort-partition-field not in partition keys should fail + Map options2 = new HashMap<>(); + 
options2.put(CoreOptions.MANIFEST_SORT_ENABLED.key(), "true"); + options2.put(CoreOptions.MANIFEST_SORT_PARTITION_FIELD.key(), "f1"); + options2.put(BUCKET.key(), String.valueOf(-1)); + assertThatThrownBy( + () -> + validateTableSchema( + new TableSchema( + 1, + fields, + 10, + singletonList("f0"), + emptyList(), + options2, + ""))) + .hasMessageContaining("is not a partition field"); + + // Test 3: valid manifest-sort config should pass + Map options3 = new HashMap<>(); + options3.put(CoreOptions.MANIFEST_SORT_ENABLED.key(), "true"); + options3.put(CoreOptions.MANIFEST_SORT_PARTITION_FIELD.key(), "f0"); + options3.put(BUCKET.key(), String.valueOf(-1)); + assertThatNoException() + .isThrownBy( + () -> + validateTableSchema( + new TableSchema( + 1, + fields, + 10, + singletonList("f0"), + emptyList(), + options3, + ""))); + } + @Test public void testMergeOnReadCoexistsWithVisibilityCallback() { Map options = new HashMap<>(); @@ -545,6 +658,69 @@ public void testMergeOnReadCoexistsWithVisibilityCallbackAndPostponeBucket() { .doesNotThrowAnyException(); } + @Test + public void testBucketAppendBackwardCompatibility() { + List fields = + Arrays.asList( + new DataField(0, "f0", DataTypes.INT()), + new DataField(1, "f1", DataTypes.STRING())); + + Map legacyOptions = new HashMap<>(); + legacyOptions.put(BUCKET.key(), "1"); + + TableSchema legacySchema = + new TableSchema( + PAIMON_07_VERSION, + 0L, + fields, + 1, + emptyList(), + emptyList(), + legacyOptions, + "", + 0L); + + assertThatCode(() -> validateTableSchema(legacySchema)).doesNotThrowAnyException(); + + Map currentOptions = new HashMap<>(); + currentOptions.put(BUCKET.key(), "1"); + + TableSchema currentSchema = + new TableSchema( + CURRENT_VERSION, + 0L, + fields, + 1, + emptyList(), + emptyList(), + currentOptions, + "", + 0L); + + assertThatThrownBy(() -> validateTableSchema(currentSchema)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("bucket-key"); + + Map legacyMultiBucketOptions = new 
HashMap<>(); + legacyMultiBucketOptions.put(BUCKET.key(), "2"); + + TableSchema legacyMultiBucketSchema = + new TableSchema( + PAIMON_07_VERSION, + 0L, + fields, + 1, + emptyList(), + emptyList(), + legacyMultiBucketOptions, + "", + 0L); + + assertThatThrownBy(() -> validateTableSchema(legacyMultiBucketSchema)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("bucket-key"); + } + @Test public void testMergeOnReadRequiresDvEnabled() { Map options = new HashMap<>(); diff --git a/paimon-core/src/test/java/org/apache/paimon/table/AppendOnlySimpleTableTest.java b/paimon-core/src/test/java/org/apache/paimon/table/AppendOnlySimpleTableTest.java index e1bd365b2380..fca17b5b9800 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/AppendOnlySimpleTableTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/AppendOnlySimpleTableTest.java @@ -77,6 +77,7 @@ import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; import org.apache.paimon.utils.BranchMergeHandler; +import org.apache.paimon.utils.CloseableIterator; import org.apache.paimon.utils.RoaringBitmap32; import org.apache.commons.math3.random.RandomDataGenerator; @@ -85,6 +86,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -101,7 +103,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -319,7 +320,14 @@ public void testDiscardDuplicateFiles() throws Exception { public void testDiscardDuplicateFilesMultiThread() throws Exception { FileStoreTable table = createFileStoreTable( - options -> 
options.set(CoreOptions.COMMIT_DISCARD_DUPLICATE_FILES, true)); + options -> { + options.set(CoreOptions.COMMIT_DISCARD_DUPLICATE_FILES, true); + options.set(CoreOptions.COMMIT_MAX_RETRIES, 50); + options.set(CoreOptions.COMMIT_MAX_RETRY_WAIT, Duration.ofMillis(100)); + // Keep all snapshots so concurrent expiry does not race readers. + options.set(CoreOptions.SNAPSHOT_NUM_RETAINED_MIN, 1000); + options.set(CoreOptions.SNAPSHOT_NUM_RETAINED_MAX, 1000); + }); BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); List> messages = new ArrayList<>(); for (int i = 0; i < 10; i++) { @@ -328,33 +336,43 @@ public void testDiscardDuplicateFilesMultiThread() throws Exception { messages.add(write.prepareCommit()); } } - Runnable doCommit = - () -> { - ThreadLocalRandom rnd = ThreadLocalRandom.current(); - for (int i = 0; i < 10; i++) { - try (BatchTableCommit commit = writeBuilder.newCommit()) { - commit.commit(messages.get(rnd.nextInt(messages.size()))); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }; - + int commitThreadNum = 10; + int commitsPerThread = 10; Runnable asserter = () -> { List splits = table.newReadBuilder().newScan().plan().splits(); assertThat(splits.size()).isEqualTo(1); - assertTrue(splits.get(0).convertToRawFiles().get().size() <= 10); + assertThat(splits.get(0).convertToRawFiles().get().size()) + .isLessThanOrEqualTo(messages.size()); }; - // test multiple threads - ExecutorService pool = Executors.newCachedThreadPool(); - List> futures = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - futures.add(pool.submit(doCommit)); - } - for (Future future : futures) { - future.get(); + ExecutorService pool = Executors.newFixedThreadPool(commitThreadNum); + try { + List> futures = new ArrayList<>(); + for (int thread = 0; thread < commitThreadNum; thread++) { + int threadId = thread; + futures.add( + pool.submit( + () -> { + for (int round = 0; round < commitsPerThread; round++) { + int messageIndex = (threadId + round) % 
messages.size(); + try (BatchTableCommit commit = writeBuilder.newCommit()) { + commit.commit(messages.get(messageIndex)); + } catch (Exception e) { + throw new RuntimeException( + String.format( + "Failed to commit message %s in thread %s round %s.", + messageIndex, threadId, round), + e); + } + } + })); + } + for (Future future : futures) { + future.get(); + } + } finally { + pool.shutdownNow(); } asserter.run(); } @@ -1258,6 +1276,43 @@ public void testLimitPushDown() throws Exception { Thread.sleep(1_000); } + @Test + public void testLimitWithCloseableIterator() throws Exception { + RowType rowType = RowType.builder().field("id", DataTypes.INT()).build(); + Consumer configure = + options -> { + options.set(FILE_FORMAT, FILE_FORMAT_PARQUET); + options.set(WRITE_ONLY, true); + options.set(SOURCE_SPLIT_TARGET_SIZE, MemorySize.ofMebiBytes(256)); + }; + FileStoreTable table = createUnawareBucketFileStoreTable(rowType, configure); + + int rowCount = 5000; + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + for (int i = 0; i < rowCount; i++) { + write.write(GenericRow.of(i)); + } + commit.commit(0, write.prepareCommit(true, 0)); + write.close(); + commit.close(); + + int limit = 10; + TableScan.Plan plan = table.newScan().withLimit(limit).plan(); + RecordReader reader = + table.newRead().withLimit(limit).createReader(plan.splits()); + AtomicInteger count = new AtomicInteger(0); + try (CloseableIterator iterator = reader.toCloseableIterator()) { + while (iterator.hasNext()) { + iterator.next(); + count.incrementAndGet(); + } + } + assertThat(count.get()).isEqualTo(limit); + + Thread.sleep(1_000); + } + @Test public void testWithShardAppendTable() throws Exception { FileStoreTable table = createFileStoreTable(conf -> conf.set(BUCKET, -1)); diff --git a/paimon-core/src/test/java/org/apache/paimon/table/BitmapGlobalIndexTableTest.java 
b/paimon-core/src/test/java/org/apache/paimon/table/BitmapGlobalIndexTableTest.java deleted file mode 100644 index 6443c6a5ec53..000000000000 --- a/paimon-core/src/test/java/org/apache/paimon/table/BitmapGlobalIndexTableTest.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.paimon.table; - -import org.apache.paimon.data.BinaryRow; -import org.apache.paimon.data.BinaryString; -import org.apache.paimon.data.InternalRow; -import org.apache.paimon.fs.FileIO; -import org.apache.paimon.globalindex.DataEvolutionBatchScan; -import org.apache.paimon.globalindex.GlobalIndexFileReadWrite; -import org.apache.paimon.globalindex.GlobalIndexResult; -import org.apache.paimon.globalindex.GlobalIndexScanner; -import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; -import org.apache.paimon.globalindex.GlobalIndexer; -import org.apache.paimon.globalindex.GlobalIndexerFactory; -import org.apache.paimon.globalindex.GlobalIndexerFactoryUtils; -import org.apache.paimon.globalindex.ResultEntry; -import org.apache.paimon.globalindex.bitmap.BitmapGlobalIndexerFactory; -import org.apache.paimon.index.GlobalIndexMeta; -import org.apache.paimon.index.IndexFileMeta; -import org.apache.paimon.io.CompactIncrement; -import org.apache.paimon.io.DataIncrement; -import org.apache.paimon.options.Options; -import org.apache.paimon.partition.PartitionPredicate; -import org.apache.paimon.predicate.Predicate; -import org.apache.paimon.predicate.PredicateBuilder; -import org.apache.paimon.reader.RecordReader; -import org.apache.paimon.table.sink.CommitMessage; -import org.apache.paimon.table.sink.CommitMessageImpl; -import org.apache.paimon.table.source.ReadBuilder; -import org.apache.paimon.types.DataField; -import org.apache.paimon.types.RowType; -import org.apache.paimon.utils.Range; -import org.apache.paimon.utils.RoaringNavigableMap64; - -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertNotNull; - -/** Test for BTree indexed batch scan. 
*/ -public class BitmapGlobalIndexTableTest extends DataEvolutionTestBase { - - @Test - public void testBitmapGlobalIndex() throws Exception { - write(100000L); - createIndex("f1"); - - FileStoreTable table = (FileStoreTable) catalog.getTable(identifier()); - - Predicate predicate = - new PredicateBuilder(table.rowType()).equal(1, BinaryString.fromString("a100")); - - RoaringNavigableMap64 rowIds = globalIndexScan(table, predicate); - assertNotNull(rowIds); - Assertions.assertThat(rowIds.getLongCardinality()).isEqualTo(1); - Assertions.assertThat(rowIds.toRangeList()).containsExactly(new Range(100L, 100L)); - - Predicate predicate2 = - new PredicateBuilder(table.rowType()) - .in( - 1, - Arrays.asList( - BinaryString.fromString("a200"), - BinaryString.fromString("a300"), - BinaryString.fromString("a400"))); - - rowIds = globalIndexScan(table, predicate2); - assertNotNull(rowIds); - Assertions.assertThat(rowIds.getLongCardinality()).isEqualTo(3); - Assertions.assertThat(rowIds.toRangeList()) - .containsExactlyInAnyOrder( - new Range(200L, 200L), new Range(300L, 300L), new Range(400L, 400L)); - - RoaringNavigableMap64 finalRowIds = rowIds; - DataEvolutionBatchScan scan = - (DataEvolutionBatchScan) - table.newScan() - .withGlobalIndexResult(GlobalIndexResult.create(() -> finalRowIds)); - - List readF1 = new ArrayList<>(); - table.newRead() - .createReader(scan.plan()) - .forEachRemaining( - row -> { - readF1.add(row.getString(1).toString()); - }); - - Assertions.assertThat(readF1).containsExactly("a200", "a300", "a400"); - } - - @Test - public void testBitmapGlobalIndexWithCoreScan() throws Exception { - write(100000L); - createIndex("f1"); - - FileStoreTable table = (FileStoreTable) catalog.getTable(identifier()); - - Predicate predicate = - new PredicateBuilder(table.rowType()) - .in( - 1, - Arrays.asList( - BinaryString.fromString("a200"), - BinaryString.fromString("a300"), - BinaryString.fromString("a400"), - BinaryString.fromString("a56789"))); - - ReadBuilder 
readBuilder = table.newReadBuilder().withFilter(predicate); - - List readF1 = new ArrayList<>(); - readBuilder - .newRead() - .createReader(readBuilder.newScan().plan()) - .forEachRemaining( - row -> { - readF1.add(row.getString(1).toString()); - }); - - Assertions.assertThat(readF1).containsExactly("a200", "a300", "a400", "a56789"); - } - - @Test - public void testMultipleBitmapIndices() throws Exception { - write(100000L); - createIndex("f1"); - createIndex("f2"); - - FileStoreTable table = (FileStoreTable) catalog.getTable(identifier()); - Predicate predicate1 = - new PredicateBuilder(table.rowType()) - .in( - 1, - Arrays.asList( - BinaryString.fromString("a200"), - BinaryString.fromString("a300"), - BinaryString.fromString("a56789"))); - - Predicate predicate2 = - new PredicateBuilder(table.rowType()) - .in( - 2, - Arrays.asList( - BinaryString.fromString("b200"), - BinaryString.fromString("b400"), - BinaryString.fromString("b56789"))); - - Predicate predicate = PredicateBuilder.and(predicate1, predicate2); - ReadBuilder readBuilder = table.newReadBuilder().withFilter(predicate); - - List result = new ArrayList<>(); - readBuilder - .newRead() - .createReader(readBuilder.newScan().plan()) - .forEachRemaining( - row -> { - result.add(row.getString(1).toString()); - }); - - Assertions.assertThat(result).containsExactly("a200", "a56789"); - } - - private void createIndex(String fieldName) throws Exception { - FileStoreTable table = (FileStoreTable) catalog.getTable(identifier()); - FileIO fileIO = table.fileIO(); - RowType rowType = SpecialFields.rowTypeWithRowTracking(table.rowType().project(fieldName)); - ReadBuilder readBuilder = table.newReadBuilder().withReadType(rowType); - RecordReader reader = - readBuilder.newRead().createReader(readBuilder.newScan().plan()); - - GlobalIndexFileReadWrite indexFileReadWrite = - new GlobalIndexFileReadWrite( - fileIO, - table.store().pathFactory().indexFileFactory(BinaryRow.EMPTY_ROW, 0)); - - DataField indexField = 
table.rowType().getField(fieldName); - GlobalIndexerFactory globalIndexerFactory = - GlobalIndexerFactoryUtils.load(BitmapGlobalIndexerFactory.IDENTIFIER); - - List indexFileMetas = - createBitmapIndex(globalIndexerFactory, indexField, reader, indexFileReadWrite); - - DataIncrement dataIncrement = DataIncrement.indexIncrement(indexFileMetas); - - CommitMessage commitMessage = - new CommitMessageImpl( - BinaryRow.EMPTY_ROW, - 0, - null, - dataIncrement, - CompactIncrement.emptyIncrement()); - - table.newBatchWriteBuilder().newCommit().commit(Collections.singletonList(commitMessage)); - } - - private List createBitmapIndex( - GlobalIndexerFactory indexerFactory, - DataField indexField, - RecordReader reader, - GlobalIndexFileReadWrite indexFileReadWrite) - throws Exception { - GlobalIndexer globalIndexer = indexerFactory.create(indexField, new Options()); - GlobalIndexSingletonWriter writer = - (GlobalIndexSingletonWriter) globalIndexer.createWriter(indexFileReadWrite); - - reader.forEachRemaining(r -> writer.write(r.getString(0))); - - List results = writer.finish(); - // bitmap index only generate one file for each writer - Assertions.assertThat(results).hasSize(1); - ResultEntry result = results.get(0); - - String fileName = result.fileName(); - long fileSize = indexFileReadWrite.fileSize(fileName); - GlobalIndexMeta globalIndexMeta = - new GlobalIndexMeta(0, result.rowCount() - 1, indexField.id(), null, result.meta()); - return Collections.singletonList( - new IndexFileMeta( - BitmapGlobalIndexerFactory.IDENTIFIER, - fileName, - fileSize, - result.rowCount(), - globalIndexMeta, - null)); - } - - private RoaringNavigableMap64 globalIndexScan(FileStoreTable table, Predicate predicate) - throws Exception { - try (GlobalIndexScanner scanner = - GlobalIndexScanner.create(table, PartitionPredicate.ALWAYS_TRUE, predicate).get()) { - return scanner.scan(predicate).get().results(); - } - } -} diff --git 
a/paimon-core/src/test/java/org/apache/paimon/table/BtreeGlobalIndexTableTest.java b/paimon-core/src/test/java/org/apache/paimon/table/BtreeGlobalIndexTableTest.java index 4be621942bc5..80c82fd62641 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/BtreeGlobalIndexTableTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/BtreeGlobalIndexTableTest.java @@ -19,6 +19,7 @@ package org.apache.paimon.table; import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; import org.apache.paimon.globalindex.DataEvolutionBatchScan; import org.apache.paimon.globalindex.GlobalIndexResult; import org.apache.paimon.globalindex.GlobalIndexScanner; @@ -27,11 +28,15 @@ import org.apache.paimon.partition.PartitionPredicate; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.schema.SchemaChange; import org.apache.paimon.table.sink.BatchTableCommit; +import org.apache.paimon.table.sink.BatchTableWrite; +import org.apache.paimon.table.sink.BatchWriteBuilder; import org.apache.paimon.table.sink.CommitMessage; import org.apache.paimon.table.source.DataSplit; import org.apache.paimon.table.source.ReadBuilder; import org.apache.paimon.table.source.Split; +import org.apache.paimon.types.DataTypes; import org.apache.paimon.utils.Range; import org.apache.paimon.utils.RoaringNavigableMap64; @@ -81,7 +86,7 @@ public void testBTreeGlobalIndex() throws Exception { DataEvolutionBatchScan scan = (DataEvolutionBatchScan) table.newScan(); RoaringNavigableMap64 finalRowIds = rowIds; - scan.withGlobalIndexResult(GlobalIndexResult.create(() -> finalRowIds)); + scan.withGlobalIndexResult(GlobalIndexResult.create(finalRowIds)); List readF1 = new ArrayList<>(); table.newRead() @@ -165,6 +170,36 @@ public void testMultipleBTreeIndices() throws Exception { assertThat(result).containsExactly("a200", "a56789"); } + @Test + public void 
testBTreeGlobalIndexOnAddedColumnContainsOldRowsAsNull() throws Exception { + long oldRowCount = 10L; + write(oldRowCount); + + catalog.alterTable(identifier(), SchemaChange.addColumn("f3", DataTypes.STRING()), false); + FileStoreTable table = (FileStoreTable) catalog.getTable(identifier()); + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = writeBuilder.newWrite()) { + write.write( + GenericRow.of( + 100, + BinaryString.fromString("a-new"), + BinaryString.fromString("b-new"), + BinaryString.fromString("not-null"))); + try (BatchTableCommit commit = writeBuilder.newCommit()) { + commit.commit(write.prepareCommit()); + } + } + + createIndex("f3"); + + table = (FileStoreTable) catalog.getTable(identifier()); + Predicate predicate = new PredicateBuilder(table.rowType()).isNull(3); + RoaringNavigableMap64 rowIds = globalIndexScan(table, predicate); + assertNotNull(rowIds); + assertThat(rowIds.getLongCardinality()).isEqualTo(oldRowCount); + assertThat(rowIds.toRangeList()).containsExactly(new Range(0L, oldRowCount - 1)); + } + private void createIndex(String fieldName) throws Exception { createIndex(fieldName, null); } diff --git a/paimon-core/src/test/java/org/apache/paimon/table/DataEvolutionTableTest.java b/paimon-core/src/test/java/org/apache/paimon/table/DataEvolutionTableTest.java index cc1b07fe6b22..0b1799ac813f 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/DataEvolutionTableTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/DataEvolutionTableTest.java @@ -1541,6 +1541,382 @@ public void testReadAfterMultipleAppendsToDifferentColumnSets() throws Exception assertThat(rows.get(2).getString(2).toString()).isEqualTo("b"); } + /** + * Central repro for the ADD COLUMN bug fixed in this change. Pre-ALTER files do not carry the + * new column physically; {@code WHERE new_col IS NULL} must match every pre-ALTER row. 
Before + * the fix, the single-entry filterByStats dropped pre-ALTER files at the manifest layer and the + * predicate returned zero rows. + */ + @Test + public void testAddColumnIsNullKeepsPreAlterRows() throws Exception { + createTableDefault(); + Schema schema = schemaDefault(); + + // Pre-ALTER write: only (f0, f1). + BatchWriteBuilder builder = getTableDefault().newBatchWriteBuilder(); + RowType writeF0F1 = schema.rowType().project(Arrays.asList("f0", "f1")); + try (BatchTableWrite write = builder.newWrite().withWriteType(writeF0F1)) { + for (int i = 0; i < 5; i++) { + write.write(GenericRow.of(i, BinaryString.fromString("a" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + + // ADD COLUMN f3 (post-ALTER) and write a full-schema row at a fresh row id. + catalog.alterTable(identifier(), SchemaChange.addColumn("f3", DataTypes.STRING()), false); + FileStoreTable table = getTableDefault(); + builder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite()) { + for (int i = 5; i < 10; i++) { + write.write( + GenericRow.of( + i, + BinaryString.fromString("a" + i), + BinaryString.fromString("c" + i), + BinaryString.fromString("e" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + + // WHERE f3 IS NULL -> pre-ALTER rows (5 of them). + PredicateBuilder pb = new PredicateBuilder(table.rowType()); + int f3Idx = table.rowType().getFieldIndex("f3"); + ReadBuilder rb = table.newReadBuilder().withFilter(pb.isNull(f3Idx)); + assertThat(countMatchingRows(rb)).isEqualTo(5); + } + + /** + * Predicate-aware stats pruning for ADD COLUMN: WHERE new_col = 'something' cannot match + * pre-ALTER rows (their new_col is implicit NULL), so the pre-ALTER manifest must be pruned at + * planning time. The all-NULL encoding in EvolutionStats / DataEvolutionArray makes + * LeafPredicate.test drop the file via the leaf's normal decision instead of falling back to + * "unknown stats -> keep". 
+ */ + @Test + public void testAddColumnEqualityPredicatePrunesPreAlterFiles() throws Exception { + createTableDefault(); + Schema schema = schemaDefault(); + + // Pre-ALTER write: only (f0, f1). + BatchWriteBuilder builder = getTableDefault().newBatchWriteBuilder(); + RowType writeF0F1 = schema.rowType().project(Arrays.asList("f0", "f1")); + try (BatchTableWrite write = builder.newWrite().withWriteType(writeF0F1)) { + for (int i = 0; i < 5; i++) { + write.write(GenericRow.of(i, BinaryString.fromString("a" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + + catalog.alterTable(identifier(), SchemaChange.addColumn("f3", DataTypes.STRING()), false); + FileStoreTable table = getTableDefault(); + builder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite()) { + for (int i = 5; i < 10; i++) { + write.write( + GenericRow.of( + i, + BinaryString.fromString("a" + i), + BinaryString.fromString("c" + i), + BinaryString.fromString("e" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + + // Total files on the table. + assertThat(plannedFileCount(table, null, null)).isEqualTo(2); + + // WHERE f3 = 'e7' -> only the post-ALTER file can match. The pre-ALTER file is + // pruned at planning because EvolutionStats encodes its missing f3 as all-NULL, + // letting LeafPredicate.test evaluate Equal against (min=null, max=null, + // nullCount=rowCount) and return false instead of falling through to + // "unknown stats -> keep". + PredicateBuilder pb = new PredicateBuilder(table.rowType()); + int f3Idx = table.rowType().getFieldIndex("f3"); + Predicate filter = pb.equal(f3Idx, BinaryString.fromString("e7")); + assertThat(plannedFileCount(table, null, filter)).isEqualTo(1); + } + + /** + * Central repro for the RENAME COLUMN bug fixed in this change. 
The renamed field's id is + * preserved across schemas, so a predicate on the latest name must still match rows in the + * pre-rename file (whose physical writeCols carry the old name). Before the fix, the + * single-entry filterByStats compared by name and dropped pre-rename files at the manifest + * layer. + */ + @Test + public void testRenameColumnPredicateKeepsPreRenameRows() throws Exception { + createTableDefault(); + Schema schema = schemaDefault(); + + // Pre-rename write: f2 carries the values that will later be queried as f3. + BatchWriteBuilder builder = getTableDefault().newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite().withWriteType(schema.rowType())) { + for (int i = 0; i < 5; i++) { + write.write( + GenericRow.of( + i, + BinaryString.fromString("a" + i), + BinaryString.fromString("preR_" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + + catalog.alterTable(identifier(), SchemaChange.renameColumn("f2", "f3"), false); + FileStoreTable table = getTableDefault(); + builder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite()) { + for (int i = 5; i < 10; i++) { + write.write( + GenericRow.of( + i, + BinaryString.fromString("a" + i), + BinaryString.fromString("postR_" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + + // WHERE f3 LIKE 'preR_%' -> rows from the pre-rename file (5 rows). + PredicateBuilder pb = new PredicateBuilder(table.rowType()); + int f3Idx = table.rowType().getFieldIndex("f3"); + ReadBuilder rb = + table.newReadBuilder() + .withFilter(pb.startsWith(f3Idx, BinaryString.fromString("preR_"))); + assertThat(countMatchingRows(rb)).isEqualTo(5); + } + + /** + * Columnar-split: two files cover the same row id range, each carrying a different subset of + * columns. A query that projects only columns owned by one file should not read the other. 
+ */ + @Test + public void testNoFilterProjectionPrunesColumnarSplitFiles() throws Exception { + write(5); + FileStoreTable table = getTableDefault(); + Schema schema = schemaDefault(); + assertThat(plannedFileCount(table, null, null)).isEqualTo(2); + + RowType readF0 = schema.rowType().project(Collections.singletonList("f0")); + assertThat(plannedFileCount(table, readF0, null)).isEqualTo(1); + + RowType readF1 = schema.rowType().project(Collections.singletonList("f1")); + assertThat(plannedFileCount(table, readF1, null)).isEqualTo(1); + + RowType readF2 = schema.rowType().project(Collections.singletonList("f2")); + assertThat(plannedFileCount(table, readF2, null)).isEqualTo(1); + + RowType readF0F2 = schema.rowType().project(Arrays.asList("f0", "f2")); + assertThat(plannedFileCount(table, readF0F2, null)).isEqualTo(2); + + assertThat(plannedFileCount(table, schema.rowType(), null)).isEqualTo(2); + } + + /** + * Row-disjoint pre-ALTER files must not be dropped by the column-pruning logic — the reader + * needs them to emit rowCount NULL-filled rows for the projection. 
+ */ + @Test + public void testNoFilterProjectionKeepsRowDisjointFiles() throws Exception { + createTableDefault(); + Schema schema = schemaDefault(); + BatchWriteBuilder builder = getTableDefault().newBatchWriteBuilder(); + RowType writeType = schema.rowType().project(Arrays.asList("f0", "f1")); + try (BatchTableWrite write = builder.newWrite().withWriteType(writeType)) { + for (int i = 0; i < 5; i++) { + write.write(GenericRow.of(i, BinaryString.fromString("a" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + builder = getTableDefault().newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite().withWriteType(schema.rowType())) { + for (int i = 5; i < 10; i++) { + write.write( + GenericRow.of( + i, + BinaryString.fromString("a" + i), + BinaryString.fromString("b" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + FileStoreTable table = getTableDefault(); + + assertThat(plannedFileCount(table, null, null)).isEqualTo(2); + + // Projecting f2 must still keep the pre-ALTER file as a row-count witness so + // the reader emits 5 NULL-filled rows for the pre-ALTER range. + RowType readF2 = schema.rowType().project(Collections.singletonList("f2")); + assertThat(plannedFileCount(table, readF2, null)).isEqualTo(2); + } + + /** + * Columnar split + predicate on the file-A column: stats prune through file A's column, column + * pruning then drops file B from the kept group. 
+ */ + @Test + public void testColumnarSplitWithPredicateOnFileAColumn() throws Exception { + write(10); + FileStoreTable table = getTableDefault(); + Schema schema = schemaDefault(); + PredicateBuilder pb = new PredicateBuilder(table.rowType()); + int f0Idx = table.rowType().getFieldIndex("f0"); + RowType readF0 = schema.rowType().project(Collections.singletonList("f0")); + assertThat(plannedFileCount(table, readF0, pb.greaterThan(f0Idx, 5))).isEqualTo(1); + assertThat(plannedFileCount(table, readF0, pb.greaterThan(f0Idx, 1000))).isEqualTo(0); + } + + /** + * Columnar split + predicate on the file-B column: stats prune through file B's column, column + * pruning then drops file A from the kept group. + */ + @Test + public void testColumnarSplitWithPredicateOnFileBColumn() throws Exception { + write(10); + FileStoreTable table = getTableDefault(); + Schema schema = schemaDefault(); + PredicateBuilder pb = new PredicateBuilder(table.rowType()); + int f2Idx = table.rowType().getFieldIndex("f2"); + RowType readF2 = schema.rowType().project(Collections.singletonList("f2")); + assertThat(plannedFileCount(table, readF2, pb.equal(f2Idx, BinaryString.fromString("b5")))) + .isEqualTo(1); + } + + /** + * Three-way columnar split: fileA{f0}, fileB{f1}, fileC{f2} share a row id range. A query that + * touches one column should retain exactly that one file. 
+ */ + @Test + public void testThreeWayColumnarSplitPruning() throws Exception { + createTableDefault(); + Schema schema = schemaDefault(); + BatchWriteBuilder builder = getTableDefault().newBatchWriteBuilder(); + + RowType writeF0 = schema.rowType().project(Collections.singletonList("f0")); + try (BatchTableWrite write = builder.newWrite().withWriteType(writeF0)) { + for (int i = 0; i < 5; i++) { + write.write(GenericRow.of(i)); + } + builder.newCommit().commit(write.prepareCommit()); + } + + builder = getTableDefault().newBatchWriteBuilder(); + RowType writeF1 = schema.rowType().project(Collections.singletonList("f1")); + try (BatchTableWrite write = builder.newWrite().withWriteType(writeF1)) { + for (int i = 0; i < 5; i++) { + write.write(GenericRow.of(BinaryString.fromString("f1_" + i))); + } + List msgs = write.prepareCommit(); + setFirstRowId(msgs, 0L); + builder.newCommit().commit(msgs); + } + + builder = getTableDefault().newBatchWriteBuilder(); + RowType writeF2 = schema.rowType().project(Collections.singletonList("f2")); + try (BatchTableWrite write = builder.newWrite().withWriteType(writeF2)) { + for (int i = 0; i < 5; i++) { + write.write(GenericRow.of(BinaryString.fromString("f2_" + i))); + } + List msgs = write.prepareCommit(); + setFirstRowId(msgs, 0L); + builder.newCommit().commit(msgs); + } + + FileStoreTable table = getTableDefault(); + assertThat(plannedFileCount(table, null, null)).isEqualTo(3); + assertThat( + plannedFileCount( + table, + schema.rowType().project(Collections.singletonList("f0")), + null)) + .isEqualTo(1); + assertThat( + plannedFileCount( + table, + schema.rowType().project(Collections.singletonList("f1")), + null)) + .isEqualTo(1); + assertThat( + plannedFileCount( + table, + schema.rowType().project(Collections.singletonList("f2")), + null)) + .isEqualTo(1); + assertThat( + plannedFileCount( + table, schema.rowType().project(Arrays.asList("f0", "f2")), null)) + .isEqualTo(2); + assertThat( + plannedFileCount( + table, 
schema.rowType().project(Arrays.asList("f1", "f2")), null)) + .isEqualTo(2); + } + + /** + * A columnar-split group covering rows 0..4 (file A {f0,f1} + file B {f2}), plus a row-disjoint + * group at rows 5..9 (file C with the full schema). Per-group column pruning composes correctly + * across the two topologies. + */ + @Test + public void testMixedColumnarSplitAndRowDisjoint() throws Exception { + write(5); + Schema schema = schemaDefault(); + BatchWriteBuilder builder = getTableDefault().newBatchWriteBuilder(); + try (BatchTableWrite write = builder.newWrite().withWriteType(schema.rowType())) { + for (int i = 5; i < 10; i++) { + write.write( + GenericRow.of( + i, + BinaryString.fromString("a" + i), + BinaryString.fromString("c" + i))); + } + builder.newCommit().commit(write.prepareCommit()); + } + FileStoreTable table = getTableDefault(); + + assertThat(plannedFileCount(table, null, null)).isEqualTo(3); + RowType readF0 = schema.rowType().project(Collections.singletonList("f0")); + assertThat(plannedFileCount(table, readF0, null)).isEqualTo(2); + RowType readF2 = schema.rowType().project(Collections.singletonList("f2")); + assertThat(plannedFileCount(table, readF2, null)).isEqualTo(2); + } + + /** + * System-field-only projection is filtered out of readType in + * DataEvolutionFileStoreScan.withReadType — readType stays null and + * postFilterManifestEntriesEnabled returns false. The column-pruning path is not entered, so + * every file in every group flows through unchanged. 
+ */ + @Test + public void testSystemFieldOnlyProjectionIsNotPruned() throws Exception { + write(5); + FileStoreTable table = getTableDefault(); + assertThat(plannedFileCount(table, null, null)).isEqualTo(2); + assertThat(plannedFileCount(table, RowType.of(SpecialFields.ROW_ID), null)).isEqualTo(2); + } + + private static int plannedFileCount(FileStoreTable table, RowType readType, Predicate filter) { + ReadBuilder rb = table.newReadBuilder(); + if (readType != null) { + rb = rb.withReadType(readType); + } + if (filter != null) { + rb = rb.withFilter(filter); + } + return rb.newScan().plan().splits().stream() + .mapToInt( + s -> + s instanceof DataSplit + ? ((DataSplit) s).dataFiles().size() + : ((IndexedSplit) s).dataSplit().dataFiles().size()) + .sum(); + } + + private static long countMatchingRows(ReadBuilder rb) throws Exception { + RecordReader reader = rb.newRead().createReader(rb.newScan().plan()); + AtomicInteger cnt = new AtomicInteger(0); + reader.forEachRemaining(r -> cnt.incrementAndGet()); + reader.close(); + return cnt.get(); + } + private Range assertContinuousRowIdRange(List files) { files.sort(Comparator.comparingLong(DataFileMeta::nonNullFirstRowId)); long start = files.get(0).nonNullFirstRowId(); diff --git a/paimon-core/src/test/java/org/apache/paimon/table/PrimaryKeySimpleTableTest.java b/paimon-core/src/test/java/org/apache/paimon/table/PrimaryKeySimpleTableTest.java index 82b097d8949a..4628710b9e78 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/PrimaryKeySimpleTableTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/PrimaryKeySimpleTableTest.java @@ -1786,7 +1786,7 @@ public void testPartialUpdateWithAgg() throws Exception { createFileStoreTable( options -> { options.set("merge-engine", "partial-update"); - options.set("fields.a.sequence-group", "c"); + options.set("fields.b.sequence-group", "c"); options.set("fields.c.aggregate-function", "sum"); }, rowType); @@ -1795,32 +1795,32 @@ public void 
testPartialUpdateWithAgg() throws Exception { TableRead read = table.newRead(); StreamTableWrite write = table.newWrite(""); StreamTableCommit commit = table.newCommit(""); - // 1. inserts + // 1. inserts (b=3 is the sequence field, all rows have same b=3 so all accepted) write.write(GenericRow.of(1, 1, 3, 3)); - write.write(GenericRow.of(1, 1, 1, 1)); - write.write(GenericRow.of(1, 1, 2, 2)); + write.write(GenericRow.of(1, 1, 3, 1)); + write.write(GenericRow.of(1, 1, 3, 2)); commit.commit(0, write.prepareCommit(true, 0)); List result = getResult(read, toSplits(snapshotReader.read().dataSplits()), rowToString); - assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 2, 6]"); + assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 3, 6]"); - // 2. Update Before - write.write(GenericRow.ofKind(RowKind.UPDATE_BEFORE, 1, 1, 2, 2)); + // 2. Update Before (b=3, same sequence) + write.write(GenericRow.ofKind(RowKind.UPDATE_BEFORE, 1, 1, 3, 2)); commit.commit(1, write.prepareCommit(true, 1)); result = getResult(read, toSplits(snapshotReader.read().dataSplits()), rowToString); - assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 2, 4]"); + assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 3, 4]"); - // 3. Update After - write.write(GenericRow.ofKind(RowKind.UPDATE_AFTER, 1, 1, 2, 3)); + // 3. Update After (b=3, same sequence) + write.write(GenericRow.ofKind(RowKind.UPDATE_AFTER, 1, 1, 3, 3)); commit.commit(2, write.prepareCommit(true, 2)); result = getResult(read, toSplits(snapshotReader.read().dataSplits()), rowToString); - assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 2, 7]"); + assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 3, 7]"); - // 4. Retracts - write.write(GenericRow.ofKind(RowKind.DELETE, 1, 1, 2, 3)); + // 4. 
Retracts (b=3, same sequence) + write.write(GenericRow.ofKind(RowKind.DELETE, 1, 1, 3, 3)); commit.commit(3, write.prepareCommit(true, 3)); result = getResult(read, toSplits(snapshotReader.read().dataSplits()), rowToString); - assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 2, 4]"); + assertThat(result).containsExactlyInAnyOrder("+I[1, 1, 3, 4]"); write.close(); commit.close(); } @@ -2688,4 +2688,573 @@ public void testMergeBranchPrimaryKeyTable() throws Exception { assertThatThrownBy(() -> table.mergeBranch(BRANCH_NAME, "main")) .satisfies(anyCauseMatches(IllegalArgumentException.class, "append-only tables")); } + + @Test + public void testSnapshotSequenceOrdering() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Snapshot 1: write pk=(1,10) many times so that the per-record sequence number is high. + for (int i = 0; i < 100; i++) { + write.write(rowData(1, 10, 999L)); + } + commit.commit(0, write.prepareCommit(false, 0)); + + // Snapshot 2: write pk=(1,10) once with a lower value. Because the snapshot id (2) + // is larger than snapshot 1, this record should win even though its per-record sequence + // number is much lower. 
+ write.write(rowData(1, 10, 1L)); + commit.commit(1, write.prepareCommit(false, 1)); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).containsExactly("1|10|1"); + + write.close(); + commit.close(); + } + + @Test + public void testSnapshotSequenceOrderingWithMinHeap() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(CoreOptions.SORT_ENGINE, CoreOptions.SortEngine.MIN_HEAP); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + for (int i = 0; i < 100; i++) { + write.write(rowData(1, 10, 999L)); + } + commit.commit(0, write.prepareCommit(false, 0)); + + write.write(rowData(1, 10, 1L)); + commit.commit(1, write.prepareCommit(false, 1)); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).containsExactly("1|10|1"); + + write.close(); + commit.close(); + } + + @Test + public void testSnapshotSequenceOrderingFallsBackToSequenceWithinSnapshot() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Within a single snapshot, sequence number is the tiebreaker. The later write (999) + // gets a higher sequence number and should win. 
+ write.write(rowData(1, 10, 1L)); + write.write(rowData(1, 10, 999L)); + commit.commit(0, write.prepareCommit(false, 0)); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).containsExactly("1|10|999"); + + write.close(); + commit.close(); + } + + @Test + public void testSnapshotSequenceOrderingCompactionPreservesInputSnapshotId() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Snapshot 1: write pk=(1,10) with val=100 + write.write(rowData(1, 10, 100L)); + commit.commit(0, write.prepareCommit(false, 0)); + + // Snapshot 2: write pk=(1,10) with val=200 (this should win after compaction) + write.write(rowData(1, 10, 200L)); + commit.commit(1, write.prepareCommit(false, 1)); + + // Snapshot 3: write a DIFFERENT key pk=(1,20) + write.write(rowData(1, 20, 300L)); + commit.commit(2, write.prepareCommit(false, 2)); + + // Snapshot 4: compact using dedicated compact writer (simulates compact job) + write.close(); + commit.close(); + FileStoreTable compactTable = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "false")); + StreamTableWrite compactWrite = compactTable.newWrite(commitUser); + StreamTableCommit compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(3, compactWrite.prepareCommit(true, 3)); + compactWrite.close(); + compactCommit.close(); + + List splits = table.newSnapshotReader().read().dataSplits(); + for (DataSplit split : splits) { + for (DataFileMeta file : split.dataFiles()) { + // The 
compacted file's minSequenceNumber should reflect the min snapshot id + // of records inside (from per-record _SEQUENCE_NUMBER values written during + // compaction), NOT the compaction commit's snapshot id (4). + assertThat(file.minSequenceNumber()) + .as( + "Compacted file %s should have minSequenceNumber from per-record " + + "snapshot ids, not the compaction commit's snapshot id", + file.fileName()) + .isLessThanOrEqualTo(3); + } + } + + // Also verify the read result is correct + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, toSplits(splits), toString); + assertThat(result).containsExactlyInAnyOrder("1|10|200", "1|20|300"); + } + + @Test + public void testSnapshotSequenceOrderingCompactionNoOrderingReversal() throws Exception { + // Reproduces the scenario from the PR review: compaction of files from + // snapshot 1 and 3 must NOT cause records from snapshot 1 to win over + // an uncompacted file from snapshot 2. 
+ FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(CoreOptions.BUCKET, 1); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Snapshot 1: write pk=(1,10) with val=100 + write.write(rowData(1, 10, 100L)); + commit.commit(0, write.prepareCommit(false, 0)); + + // Snapshot 2: write SAME key pk=(1,10) with val=200 — this should win + write.write(rowData(1, 10, 200L)); + commit.commit(1, write.prepareCommit(false, 1)); + + // Snapshot 3: write DIFFERENT key pk=(1,20) with val=300 + write.write(rowData(1, 20, 300L)); + commit.commit(2, write.prepareCommit(false, 2)); + + // Compact all files using dedicated compact writer + write.close(); + commit.close(); + FileStoreTable compactTable = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "false")); + StreamTableWrite compactWrite = compactTable.newWrite(commitUser); + StreamTableCommit compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(3, compactWrite.prepareCommit(true, 3)); + compactWrite.close(); + compactCommit.close(); + + // Write pk=(1,10) again with val=999 — snapshot 5 should definitely win + write = table.newWrite(commitUser); + commit = table.newCommit(commitUser); + write.write(rowData(1, 10, 999L)); + commit.commit(4, write.prepareCommit(false, 4)); + + write.close(); + commit.close(); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + // pk=(1,10): snapshot 5 (val=999) wins over snapshot 2 (val=200) and snapshot 1 (val=100) + // pk=(1,20): snapshot 3 (val=300) is the only version + 
assertThat(result).containsExactlyInAnyOrder("1|10|999", "1|20|300"); + } + + @Test + public void testSnapshotSequenceOrderingMultiRoundCompaction() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(CoreOptions.BUCKET, 1); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Snapshot 1: pk=(1,10) val=100 + write.write(rowData(1, 10, 100L)); + commit.commit(0, write.prepareCommit(false, 0)); + + // Snapshot 2: pk=(1,10) val=200 — should win over snapshot 1 + write.write(rowData(1, 10, 200L)); + commit.commit(1, write.prepareCommit(false, 1)); + + // Snapshot 3: pk=(1,20) val=300 + write.write(rowData(1, 20, 300L)); + commit.commit(2, write.prepareCommit(false, 2)); + + // First compaction (snapshot 4) using dedicated compact writer + write.close(); + commit.close(); + FileStoreTable compactTable = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "false")); + StreamTableWrite compactWrite = compactTable.newWrite(commitUser); + StreamTableCommit compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(3, compactWrite.prepareCommit(true, 3)); + compactWrite.close(); + compactCommit.close(); + + // Snapshot 5: pk=(1,10) val=500 — should win over everything + write = table.newWrite(commitUser); + commit = table.newCommit(commitUser); + write.write(rowData(1, 10, 500L)); + commit.commit(4, write.prepareCommit(false, 4)); + + // Snapshot 6: pk=(1,30) val=600 + write.write(rowData(1, 30, 600L)); + commit.commit(5, write.prepareCommit(false, 5)); + + // Second compaction (snapshot 7) using dedicated compact writer + write.close(); + commit.close(); + compactWrite = compactTable.newWrite(commitUser); + compactCommit = compactTable.newCommit(commitUser); + 
compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(6, compactWrite.prepareCommit(true, 6)); + compactWrite.close(); + compactCommit.close(); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).containsExactlyInAnyOrder("1|10|500", "1|20|300", "1|30|600"); + } + + @Test + public void testSnapshotSequenceOrderingWithChangelogInput() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(CHANGELOG_PRODUCER, ChangelogProducer.INPUT); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + write.write(rowData(1, 10, 100L)); + commit.commit(0, write.prepareCommit(false, 0)); + + write.write(rowData(1, 10, 1L)); + commit.commit(1, write.prepareCommit(false, 1)); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).containsExactly("1|10|1"); + + write.close(); + commit.close(); + } + + @Test + public void testSnapshotSequenceOrderingWithChangelogLookup() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(CHANGELOG_PRODUCER, LOOKUP); + }); + StreamTableWrite write = + table.newWrite(commitUser).withIOManager(new IOManagerImpl(tempDir.toString())); + StreamTableCommit commit = table.newCommit(commitUser); + + write.write(rowData(1, 10, 100L)); + commit.commit(0, 
write.prepareCommit(false, 0)); + + write.write(rowData(1, 10, 1L)); + commit.commit(1, write.prepareCommit(false, 1)); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).containsExactly("1|10|1"); + + write.close(); + commit.close(); + } + + @Test + public void testSnapshotSequenceOrderingDeleteFromLaterSnapshot() throws Exception { + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + }); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + write.write(rowData(1, 10, 100L)); + commit.commit(0, write.prepareCommit(false, 0)); + + write.write(rowDataWithKind(RowKind.DELETE, 1, 10, 100L)); + commit.commit(1, write.prepareCommit(false, 1)); + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getLong(2); + List result = getResult(read, splits, toString); + assertThat(result).isEmpty(); + + write.close(); + commit.close(); + } + + /** + * Regression: with snapshot-ordering on, a partial-update merge function must keep its result's + * {@code sequenceNumber} equal to the snapshot id carried by its inputs. The compacted file's + * per-record {@code _SEQUENCE_NUMBER} (and therefore its file-level minSequenceNumber) must + * stay a real snapshot id (>= 0); a regression to -1 would break ordering against later + * snapshots. 
+ */ + @Test + public void testSnapshotSequenceOrderingPartialUpdateCompactionPreservesSnapshotId() + throws Exception { + RowType rowType = + RowType.of( + new DataType[] { + DataTypes.INT(), DataTypes.INT(), DataTypes.INT(), DataTypes.INT() + }, + new String[] {"pt", "a", "b", "c"}); + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(MERGE_ENGINE, PARTIAL_UPDATE); + conf.set(BUCKET, 1); + }, + rowType); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Snapshot 1: partial write of column b + write.write(GenericRow.of(1, 1, 100, null)); + commit.commit(0, write.prepareCommit(false, 0)); + + // Snapshot 2: partial write of column c — partial-update merges with snapshot 1's row + write.write(GenericRow.of(1, 1, null, 200)); + commit.commit(1, write.prepareCommit(false, 1)); + + // Snapshot 3: compact files from snapshots 1+2 using a dedicated compact writer. The + // compaction reader merges the two partial rows through PartialUpdateMergeFunction; the + // merged record's sequenceNumber must stay a real snapshot id. + write.close(); + commit.close(); + FileStoreTable compactTable = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "false")); + StreamTableWrite compactWrite = compactTable.newWrite(commitUser); + StreamTableCommit compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(2, compactWrite.prepareCommit(true, 2)); + compactWrite.close(); + compactCommit.close(); + + List splitsAfterCompact = table.newSnapshotReader().read().dataSplits(); + for (DataSplit split : splitsAfterCompact) { + for (DataFileMeta file : split.dataFiles()) { + assertThat(file.minSequenceNumber()) + .as( + "Compacted file %s must carry a real snapshot id in" + + " minSequenceNumber (>= 0). 
A value of -1 means the" + + " partial-update merge result lost its snapshot id" + + " during compaction.", + file.fileName()) + .isGreaterThanOrEqualTo(0L); + } + } + + // Snapshot 4: write a fresh value of b — this snapshot must win. + write = table.newWrite(commitUser); + commit = table.newCommit(commitUser); + write.write(GenericRow.of(1, 1, 999, null)); + commit.commit(3, write.prepareCommit(false, 3)); + + // Snapshot 5: another compaction using dedicated compact writer + write.close(); + commit.close(); + compactWrite = compactTable.newWrite(commitUser); + compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(4, compactWrite.prepareCommit(true, 4)); + compactWrite.close(); + compactCommit.close(); + for (DataSplit split : table.newSnapshotReader().read().dataSplits()) { + for (DataFileMeta file : split.dataFiles()) { + assertThat(file.minSequenceNumber()) + .as("Final compacted file %s minSequenceNumber", file.fileName()) + .isGreaterThanOrEqualTo(0L); + assertThat(file.maxSequenceNumber()) + .as("Final compacted file %s maxSequenceNumber", file.fileName()) + .isGreaterThanOrEqualTo(0L); + } + } + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> + r.getInt(0) + + "|" + + r.getInt(1) + + "|" + + (r.isNullAt(2) ? "null" : r.getInt(2)) + + "|" + + (r.isNullAt(3) ? "null" : r.getInt(3)); + List result = getResult(read, splits, toString); + // b=999 (snapshot 4 wins over snapshot 1's 100), c=200 (only snapshot 2 wrote it) + assertThat(result).containsExactly("1|1|999|200"); + + write.close(); + commit.close(); + } + + /** + * Regression: with snapshot-ordering on, an aggregate merge function must keep its result's + * {@code sequenceNumber} equal to the snapshot id carried by its inputs. 
Mirrors the + * partial-update regression — if the merged record loses the snapshot id, the compacted file's + * minSequenceNumber regresses to -1. + */ + @Test + public void testSnapshotSequenceOrderingAggregateCompactionPreservesSnapshotId() + throws Exception { + RowType rowType = + RowType.of( + new DataType[] { + DataTypes.INT(), DataTypes.INT(), DataTypes.INT(), DataTypes.INT() + }, + new String[] {"pt", "a", "b", "c"}); + FileStoreTable table = + createFileStoreTable( + conf -> { + conf.set(CoreOptions.SEQUENCE_SNAPSHOT_ORDERING, true); + conf.set(CoreOptions.WRITE_ONLY, true); + conf.set(MERGE_ENGINE, AGGREGATE); + conf.set(BUCKET, 1); + conf.set("fields.b.aggregate-function", "sum"); + conf.set("fields.c.aggregate-function", "max"); + }, + rowType); + StreamTableWrite write = table.newWrite(commitUser); + StreamTableCommit commit = table.newCommit(commitUser); + + // Snapshot 1 + write.write(GenericRow.of(1, 1, 10, 100)); + commit.commit(0, write.prepareCommit(false, 0)); + + // Snapshot 2: aggregates with snapshot 1's row. + write.write(GenericRow.of(1, 1, 20, 50)); + commit.commit(1, write.prepareCommit(false, 1)); + + // Snapshot 3: compact using dedicated compact writer + write.close(); + commit.close(); + FileStoreTable compactTable = + table.copy(Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), "false")); + StreamTableWrite compactWrite = compactTable.newWrite(commitUser); + StreamTableCommit compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(2, compactWrite.prepareCommit(true, 2)); + compactWrite.close(); + compactCommit.close(); + + for (DataSplit split : table.newSnapshotReader().read().dataSplits()) { + for (DataFileMeta file : split.dataFiles()) { + assertThat(file.minSequenceNumber()) + .as( + "Aggregate-compacted file %s must carry a real snapshot id in" + + " minSequenceNumber (>= 0). 
A value of -1 means the" + + " aggregate merge result lost its snapshot id during" + + " compaction.", + file.fileName()) + .isGreaterThanOrEqualTo(0L); + } + } + + // Snapshot 4: another insert that must aggregate on top of the compacted result. + write = table.newWrite(commitUser); + commit = table.newCommit(commitUser); + write.write(GenericRow.of(1, 1, 5, 999)); + commit.commit(3, write.prepareCommit(false, 3)); + + // Snapshot 5: final compaction using dedicated compact writer + write.close(); + commit.close(); + compactWrite = compactTable.newWrite(commitUser); + compactCommit = compactTable.newCommit(commitUser); + compactWrite.compact(binaryRow(1), 0, true); + compactCommit.commit(4, compactWrite.prepareCommit(true, 4)); + compactWrite.close(); + compactCommit.close(); + for (DataSplit split : table.newSnapshotReader().read().dataSplits()) { + for (DataFileMeta file : split.dataFiles()) { + assertThat(file.minSequenceNumber()) + .as("Final compacted file %s minSequenceNumber", file.fileName()) + .isGreaterThanOrEqualTo(0L); + assertThat(file.maxSequenceNumber()) + .as("Final compacted file %s maxSequenceNumber", file.fileName()) + .isGreaterThanOrEqualTo(0L); + } + } + + List splits = toSplits(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newReadBuilder().newRead(); + Function toString = + r -> r.getInt(0) + "|" + r.getInt(1) + "|" + r.getInt(2) + "|" + r.getInt(3); + List result = getResult(read, splits, toString); + // b = sum(10, 20, 5) = 35, c = max(100, 50, 999) = 999 + assertThat(result).containsExactly("1|1|35|999"); + + write.close(); + commit.close(); + } } diff --git a/paimon-core/src/test/java/org/apache/paimon/table/SchemaEvolutionTest.java b/paimon-core/src/test/java/org/apache/paimon/table/SchemaEvolutionTest.java index 58d82b0bc97d..01570b422ff0 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/SchemaEvolutionTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/SchemaEvolutionTest.java @@ 
-41,8 +41,10 @@ import org.apache.paimon.table.source.snapshot.SnapshotReader; import org.apache.paimon.types.DataField; import org.apache.paimon.types.DataType; +import org.apache.paimon.types.DataTypeRoot; import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.LazyField; import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableList; import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; @@ -188,6 +190,352 @@ public void testAddDuplicateField() throws Exception { columnName, identifier.getFullName()); } + @Test + public void testAddBlobColumnViaCommentDirective() throws Exception { + // create table with one pre-existing BLOB column registered in blob-field, so we can + // also verify that ADD COLUMN appends to (rather than overwrites) the existing value. + Map options = blobEnabledOptions(); + options.put(CoreOptions.BLOB_FIELD.key(), "existing_col"); + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField( + 1, "existing_col", DataTypes.BLOB().copy(true)) + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + + // bare directive — no user comment, appended to existing blob-field value. + // directive + user comment, SDK caller passes a BlobType directly (allowed when + // accompanied by a directive so the storage mode is explicit). 
+ schemaManager.commitChanges( + ImmutableList.of( + SchemaChange.addColumn("picture", DataTypes.BYTES(), "__BLOB_FIELD", null), + SchemaChange.addColumn( + "desc_col", + DataTypes.BLOB(), + "__BLOB_DESCRIPTOR_FIELD; descriptor comment", + null))); + + TableSchema latest = schemaManager.latest().get(); + + DataField picture = + latest.fields().stream().filter(f -> f.name().equals("picture")).findFirst().get(); + assertThat(picture.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(picture.description()).isNull(); + + DataField desc = + latest.fields().stream().filter(f -> f.name().equals("desc_col")).findFirst().get(); + assertThat(desc.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(desc.description()).isEqualTo("descriptor comment"); + + assertThat(latest.options().get(CoreOptions.BLOB_FIELD.key())) + .isEqualTo("existing_col,picture"); + assertThat(latest.options().get(CoreOptions.BLOB_DESCRIPTOR_FIELD.key())) + .isEqualTo("desc_col"); + } + + @Test + public void testAddBlobColumnErrors() throws Exception { + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField( + 1, + "nested", + DataTypes.ROW( + new DataField(2, "a", DataTypes.INT()))) + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + new HashMap<>(), + "")); + + // non-BYTES/BINARY type rejected. + assertThatThrownBy( + () -> + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.addColumn( + "bad", + DataTypes.INT(), + "__BLOB_FIELD", + null)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be of BYTES, BINARY or BLOB type"); + + // nested column rejected. 
+ assertThatThrownBy( + () -> + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.addColumn( + new String[] {"nested", "blob"}, + DataTypes.BYTES(), + "__BLOB_FIELD", + null)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("nested column"); + + // SET OPTION on blob-field is rejected (the option is @Immutable). + TableSchema oldSchema = schemaManager.latest().get(); + LazyField hasSnapshots = new LazyField<>(() -> true); + LazyField lazyId = new LazyField<>(() -> identifier); + assertThatThrownBy( + () -> + SchemaManager.generateTableSchema( + oldSchema, + Collections.singletonList( + SchemaChange.setOption( + CoreOptions.BLOB_FIELD.key(), "k")), + hasSnapshots, + lazyId)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining(CoreOptions.BLOB_FIELD.key()); + } + + @Test + public void testDropColumnCleansOptions() throws Exception { + Map options = blobEnabledOptions(); + options.put(CoreOptions.VECTOR_FILE_FORMAT.key(), "json"); + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField( + 1, "pic", DataTypes.BYTES(), "__BLOB_FIELD"), + new DataField( + 2, + "emb", + DataTypes.ARRAY(DataTypes.FLOAT()), + "__VECTOR_FIELD;64") + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + + TableSchema before = schemaManager.latest().get(); + assertThat(before.options().get(CoreOptions.BLOB_FIELD.key())).isEqualTo("pic"); + assertThat(before.options().get(CoreOptions.VECTOR_FIELD.key())).isEqualTo("emb"); + + schemaManager.commitChanges(Collections.singletonList(SchemaChange.dropColumn("pic"))); + TableSchema afterPic = schemaManager.latest().get(); + assertThat(afterPic.options()).doesNotContainKey(CoreOptions.BLOB_FIELD.key()); + assertThat(afterPic.options().get(CoreOptions.VECTOR_FIELD.key())).isEqualTo("emb"); + + 
schemaManager.commitChanges(Collections.singletonList(SchemaChange.dropColumn("emb"))); + TableSchema afterEmb = schemaManager.latest().get(); + assertThat(afterEmb.options()).doesNotContainKey(CoreOptions.VECTOR_FIELD.key()); + } + + @Test + public void testUpdateColumnTypeOnBlobIsRejected() throws Exception { + Map options = blobEnabledOptions(); + options.put(CoreOptions.BLOB_FIELD.key(), "pic"); + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField(1, "pic", DataTypes.BLOB().copy(true)), + new DataField(2, "raw", DataTypes.BYTES()) + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + + // BLOB -> BYTES rejected. + assertThatThrownBy( + () -> + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.updateColumnType( + "pic", DataTypes.BYTES())))) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("BLOB"); + + // BYTES -> BLOB rejected (must be added via ADD COLUMN directive instead). 
+ assertThatThrownBy( + () -> + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.updateColumnType( + "raw", DataTypes.BLOB())))) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("BLOB"); + } + + @Test + public void testAddBlobViewColumnViaCommentDirective() throws Exception { + Map options = blobEnabledOptions(); + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.addColumn( + "view_col", + DataTypes.BYTES(), + "__BLOB_VIEW_FIELD; view comment", + null))); + + TableSchema latest = schemaManager.latest().get(); + DataField viewCol = + latest.fields().stream().filter(f -> f.name().equals("view_col")).findFirst().get(); + assertThat(viewCol.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(viewCol.description()).isEqualTo("view comment"); + assertThat(latest.options().get(CoreOptions.BLOB_VIEW_FIELD.key())).isEqualTo("view_col"); + } + + @Test + public void testAddVectorColumnViaCommentDirective() throws Exception { + Map options = blobEnabledOptions(); + options.put(CoreOptions.VECTOR_FILE_FORMAT.key(), "json"); + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.addColumn( + "embedding", + DataTypes.ARRAY(DataTypes.FLOAT()), + "__VECTOR_FIELD;128; embedding vector", + null))); + + TableSchema latest = schemaManager.latest().get(); + DataField embedding = + latest.fields().stream() + .filter(f -> f.name().equals("embedding")) + .findFirst() + .get(); + assertThat(embedding.type().getTypeRoot()).isEqualTo(DataTypeRoot.VECTOR); + 
assertThat(embedding.description()).isEqualTo("embedding vector"); + assertThat(latest.options().get(CoreOptions.VECTOR_FIELD.key())).isEqualTo("embedding"); + } + + @Test + public void testAddVectorColumnErrors() throws Exception { + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + new HashMap<>(), + "")); + + // non-ARRAY type rejected for vector directive + assertThatThrownBy( + () -> + schemaManager.commitChanges( + Collections.singletonList( + SchemaChange.addColumn( + "bad", + DataTypes.INT(), + "__VECTOR_FIELD;128", + null)))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be of ARRAY type"); + } + + @Test + public void testCreateTableWithCommentDirectives() throws Exception { + Map options = blobEnabledOptions(); + options.put(CoreOptions.VECTOR_FILE_FORMAT.key(), "json"); + schemaManager.createTable( + new Schema( + RowType.of( + new DataField[] { + new DataField(0, "k", DataTypes.INT()), + new DataField( + 1, + "pic", + DataTypes.BYTES(), + "__BLOB_FIELD; picture"), + new DataField( + 2, + "view_col", + DataTypes.BYTES(), + "__BLOB_VIEW_FIELD; view field"), + new DataField( + 3, + "embedding", + DataTypes.ARRAY(DataTypes.FLOAT()), + "__VECTOR_FIELD;64; my embedding") + }) + .getFields(), + Collections.emptyList(), + Collections.emptyList(), + options, + "")); + + TableSchema latest = schemaManager.latest().get(); + + DataField pic = + latest.fields().stream().filter(f -> f.name().equals("pic")).findFirst().get(); + assertThat(pic.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(pic.description()).isEqualTo("picture"); + + DataField viewCol = + latest.fields().stream().filter(f -> f.name().equals("view_col")).findFirst().get(); + assertThat(viewCol.type().getTypeRoot()).isEqualTo(DataTypeRoot.BLOB); + assertThat(viewCol.description()).isEqualTo("view field"); + + 
DataField embedding = + latest.fields().stream() + .filter(f -> f.name().equals("embedding")) + .findFirst() + .get(); + assertThat(embedding.type().getTypeRoot()).isEqualTo(DataTypeRoot.VECTOR); + assertThat(embedding.description()).isEqualTo("my embedding"); + + assertThat(latest.options().get(CoreOptions.BLOB_FIELD.key())).isEqualTo("pic"); + assertThat(latest.options().get(CoreOptions.BLOB_VIEW_FIELD.key())).isEqualTo("view_col"); + assertThat(latest.options().get(CoreOptions.VECTOR_FIELD.key())).isEqualTo("embedding"); + } + + private static Map blobEnabledOptions() { + Map options = new HashMap<>(); + options.put(CoreOptions.DATA_EVOLUTION_ENABLED.key(), "true"); + options.put(CoreOptions.ROW_TRACKING_ENABLED.key(), "true"); + options.put(CoreOptions.BUCKET.key(), "-1"); + return options; + } + @Test public void testUpdateFieldType() throws Exception { Schema schema = diff --git a/paimon-core/src/test/java/org/apache/paimon/table/SimpleTableTestBase.java b/paimon-core/src/test/java/org/apache/paimon/table/SimpleTableTestBase.java index cfdccc197ec3..e02c2bae1bea 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/SimpleTableTestBase.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/SimpleTableTestBase.java @@ -615,6 +615,17 @@ public void testCopyWithLatestSchema() throws Exception { "1|40|400|binary|varbinary|mapKey:mapVal|multiset|4000")); } + @Test + public void testCopyWithLatestSchemaPicksUpAlteredOptions() throws Exception { + FileStoreTable table = createFileStoreTable(); + SchemaManager schemaManager = new SchemaManager(table.fileIO(), table.location()); + + schemaManager.commitChanges(SchemaChange.setOption("my-custom-key", "my-custom-value")); + + FileStoreTable updated = table.copyWithLatestSchema(); + assertThat(updated.schema().options()).containsEntry("my-custom-key", "my-custom-value"); + } + @Test public void testConsumerIdNotBlank() throws Exception { FileStoreTable table = diff --git 
a/paimon-core/src/test/java/org/apache/paimon/table/format/FormatTableScanTest.java b/paimon-core/src/test/java/org/apache/paimon/table/format/FormatTableScanTest.java index e6602fb716fd..cda08b61e725 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/format/FormatTableScanTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/format/FormatTableScanTest.java @@ -237,6 +237,31 @@ void testGetScanPathAndLevelWithEqualityFilter() throws IOException { assertThat(searched.size()).isEqualTo(1); } + @TestTemplate + void testComputeScanPathWithDateEqualityFilter() { + Path tableLocation = new Path(tmpPath.toUri()); + RowType datePartitionType = RowType.builder().field("dt", DataTypes.DATE()).build(); + List datePartitionKeys = datePartitionType.getFieldNames(); + + PredicateBuilder builder = new PredicateBuilder(datePartitionType); + Predicate equalityPredicate = + builder.equal(0, (int) java.time.LocalDate.parse("2026-05-01").toEpochDay()); + PartitionPredicate partitionFilter = + PartitionPredicate.fromPredicate(datePartitionType, equalityPredicate); + + Pair result = + FormatTableScan.computeScanPathAndLevel( + tableLocation, + datePartitionKeys, + partitionFilter, + datePartitionType, + enablePartitionValueOnly); + String partitionPath = enablePartitionValueOnly ? 
"2026-05-01" : "dt=2026-05-01"; + + assertThat(result.getLeft().toString()).isEqualTo(tableLocation + partitionPath); + assertThat(result.getRight()).isEqualTo(0); + } + @TestTemplate void testComputeScanPathWithFirstLevel() throws IOException { Path tableLocation = new Path(tmpPath.toUri()); @@ -554,7 +579,7 @@ public void testExtractEqualityPartitionSpecWithAllEqualityWhenAllIsAnd() { Map result = FormatTableScan.extractLeadingEqualityPartitionSpecWhenOnlyAnd( - partitionKeys, equalityPredicate); + partitionKeys, equalityPredicate, type); assertThat(result).hasSize(3); assertThat(result.get("year")).isEqualTo("2023"); @@ -581,7 +606,7 @@ public void testExtractEqualityPartitionSpecWithLeadingConsecutiveEqualityWhenAl Map result = FormatTableScan.extractLeadingEqualityPartitionSpecWhenOnlyAnd( - partitionKeys, mixedPredicate); + partitionKeys, mixedPredicate, type); assertThat(result).isNotNull(); assertThat(result).hasSize(2); @@ -609,7 +634,7 @@ public void testExtractEqualityPartitionSpecWithFirstPartitionKeyEqualityWhenAll Map result = FormatTableScan.extractLeadingEqualityPartitionSpecWhenOnlyAnd( - partitionKeys, mixedPredicate); + partitionKeys, mixedPredicate, type); assertThat(result).hasSize(1); assertThat(result.get("year")).isEqualTo("2023"); assertThat(result.containsKey("month")).isFalse(); @@ -635,7 +660,7 @@ public void testExtractEqualityPartitionSpecWithNoLeadingEqualityWhenAllIsAnd() Map result = FormatTableScan.extractLeadingEqualityPartitionSpecWhenOnlyAnd( - partitionKeys, mixedPredicate); + partitionKeys, mixedPredicate, type); assertThat(result).isEmpty(); } @@ -656,7 +681,7 @@ public void testExtractEqualityPartitionSpecWithNonEqualityPredicateWhenAllIsAnd Map result = FormatTableScan.extractLeadingEqualityPartitionSpecWhenOnlyAnd( - partitionKeys, nonEqualityPredicate); + partitionKeys, nonEqualityPredicate, type); assertThat(result).isEmpty(); } @@ -676,7 +701,7 @@ public void 
testExtractLeadingEqualityPartitionSpecWhenOnlyAndWithOrPredicate() Map result = FormatTableScan.extractLeadingEqualityPartitionSpecWhenOnlyAnd( - partitionKeys, orPredicate); + partitionKeys, orPredicate, type); assertThat(result).isEmpty(); } diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/DvAwareStatsTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/DvAwareStatsTest.java new file mode 100644 index 000000000000..8cb07245f948 --- /dev/null +++ b/paimon-core/src/test/java/org/apache/paimon/table/source/DvAwareStatsTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.io.DataFileMeta; +import org.apache.paimon.io.DataFileTestUtils; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link DvAwareStats}. 
*/ +public class DvAwareStatsTest { + + @Test + public void testNoDvAndEmptyDeleteRowCountIsTight() { + DataFileMeta file = DataFileTestUtils.newFile(); + assertThat(DvAwareStats.isTightBounds(file, null)).isTrue(); + } + + @Test + public void testNoDvAndPositiveDeleteRowCountIsWide() { + DataFileMeta file = DataFileTestUtils.newFile("f1", 0, 0, 9, 10L, 5L); + assertThat(DvAwareStats.isTightBounds(file, null)).isFalse(); + } + + @Test + public void testEmptyDvIsTight() { + DataFileMeta file = DataFileTestUtils.newFile(); + DeletionFile dv = new DeletionFile("dv-1", 0L, 0L, 0L); + assertThat(DvAwareStats.isTightBounds(file, dv)).isTrue(); + } + + @Test + public void testPopulatedDvIsWide() { + DataFileMeta file = DataFileTestUtils.newFile(); + DeletionFile dv = new DeletionFile("dv-1", 0L, 16L, 5L); + assertThat(DvAwareStats.isTightBounds(file, dv)).isFalse(); + } + + @Test + public void testUnknownCardinalityDvIsConservativelyWide() { + DataFileMeta file = DataFileTestUtils.newFile(); + DeletionFile dv = new DeletionFile("dv-1", 0L, 16L, null); + assertThat(DvAwareStats.isTightBounds(file, dv)).isFalse(); + } + + @Test + public void testNoDvAndZeroDeleteRowCountIsTight() { + DataFileMeta file = DataFileTestUtils.newFile("f1", 0, 0, 9, 10L, 0L); + assertThat(DvAwareStats.isTightBounds(file, null)).isTrue(); + } + + @Test + public void testBothDeleteRowCountAndPopulatedDvIsWide() { + DataFileMeta file = DataFileTestUtils.newFile("f1", 0, 0, 9, 10L, 5L); + DeletionFile dv = new DeletionFile("dv-1", 0L, 16L, 3L); + assertThat(DvAwareStats.isTightBounds(file, dv)).isFalse(); + } +} diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/PushDownUtilsTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/PushDownUtilsTest.java new file mode 100644 index 000000000000..b1c3e5003ad6 --- /dev/null +++ b/paimon-core/src/test/java/org/apache/paimon/table/source/PushDownUtilsTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.io.DataFileMeta; +import org.apache.paimon.io.DataFileTestUtils; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.OptionalLong; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link PushDownUtils}. 
*/ +public class PushDownUtilsTest { + + @Test + public void testNonDataSplitReturnsFalse() { + Split nonDataSplit = + new Split() { + @Override + public long rowCount() { + return 0; + } + + @Override + public OptionalLong mergedRowCount() { + return OptionalLong.empty(); + } + }; + assertThat(PushDownUtils.tightBoundsAvailable(nonDataSplit)).isFalse(); + } + + @Test + public void testEmptySplitReturnsTrue() { + DataSplit split = newSplit(Collections.emptyList(), null); + assertThat(PushDownUtils.tightBoundsAvailable(split)).isTrue(); + } + + @Test + public void testAllFilesTightReturnsTrue() { + List files = + Arrays.asList( + DataFileTestUtils.newFile(), + DataFileTestUtils.newFile(), + DataFileTestUtils.newFile()); + DataSplit split = newSplit(files, null); + assertThat(PushDownUtils.tightBoundsAvailable(split)).isTrue(); + } + + @Test + public void testAnyFileWithPopulatedDvReturnsFalse() { + List files = + Arrays.asList( + DataFileTestUtils.newFile(), + DataFileTestUtils.newFile(), + DataFileTestUtils.newFile()); + List dvs = + Arrays.asList( + new DeletionFile("dv-0", 0L, 0L, 0L), + new DeletionFile("dv-1", 0L, 16L, 5L), + new DeletionFile("dv-2", 0L, 0L, 0L)); + DataSplit split = newSplit(files, dvs); + assertThat(PushDownUtils.tightBoundsAvailable(split)).isFalse(); + } + + @Test + public void testFileWithDeleteRowCountReturnsFalse() { + List files = + Arrays.asList( + DataFileTestUtils.newFile("f0", 0, 0, 9, 10L, 5L), + DataFileTestUtils.newFile(), + DataFileTestUtils.newFile()); + List dvs = + Arrays.asList( + new DeletionFile("dv-0", 0L, 0L, 0L), + new DeletionFile("dv-1", 0L, 0L, 0L), + new DeletionFile("dv-2", 0L, 0L, 0L)); + DataSplit split = newSplit(files, dvs); + assertThat(PushDownUtils.tightBoundsAvailable(split)).isFalse(); + } + + private static DataSplit newSplit(List files, List deletionFiles) { + DataSplit.Builder builder = + DataSplit.builder() + .withSnapshot(1L) + .withPartition(BinaryRow.EMPTY_ROW) + .withBucket(0) + 
.withBucketPath("dummy") + .withDataFiles(files); + if (deletionFiles != null) { + builder.withDataDeletionFiles(deletionFiles); + } + return builder.build(); + } +} diff --git a/paimon-core/src/test/java/org/apache/paimon/utils/DVMetaCacheTest.java b/paimon-core/src/test/java/org/apache/paimon/utils/DVMetaCacheTest.java index c0c28ce3abcd..ab1ae19cfbd9 100644 --- a/paimon-core/src/test/java/org/apache/paimon/utils/DVMetaCacheTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/utils/DVMetaCacheTest.java @@ -31,6 +31,12 @@ import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; @@ -101,10 +107,92 @@ public void testEmptyMap() { } @Test - public void testCacheEviction() { - DVMetaCache cache = new DVMetaCache(5); + public void testLazyValue() { + DVMetaCache cache = new DVMetaCache(100); + Path path = new Path("manifest/index-manifest-00003"); + BinaryRow partition = partition("year=2023/month=10"); + AtomicInteger invoked = new AtomicInteger(); + + cache.putLazy( + path, + partition, + 1, + 1, + () -> { + invoked.incrementAndGet(); + Map dvFiles = new HashMap<>(); + dvFiles.put( + "data-d4e5f6g7-h8i9-0123-defg-456789012345-1.parquet", + new DeletionFile( + "index-d4e5f6g7-h8i9-0123-defg-456789012345-1", 0L, 100L, 1L)); + return dvFiles; + }); + + assertThat(invoked).hasValue(0); + + Map result1 = cache.read(path, partition, 1); + assertThat(result1).isNotNull().hasSize(1); + assertThat(invoked).hasValue(1); + + Map result2 = cache.read(path, partition, 1); + assertThat(result2).isSameAs(result1); + assertThat(invoked).hasValue(1); + } + + @Test + public void testLazyValueInitializedOnceConcurrently() throws Exception { + DVMetaCache cache = new 
DVMetaCache(100); Path path = new Path("manifest/index-manifest-00004"); BinaryRow partition = partition("year=2023/month=09"); + AtomicInteger invoked = new AtomicInteger(); + CountDownLatch supplierEntered = new CountDownLatch(1); + CountDownLatch releaseSupplier = new CountDownLatch(1); + Map dvFiles = new HashMap<>(); + dvFiles.put( + "data-d4e5f6g7-h8i9-0123-defg-456789012345-1.parquet", + new DeletionFile("index-d4e5f6g7-h8i9-0123-defg-456789012345-1", 0L, 100L, 1L)); + + cache.putLazy( + path, + partition, + 1, + 1, + () -> { + invoked.incrementAndGet(); + supplierEntered.countDown(); + try { + assertThat(releaseSupplier.await(5, TimeUnit.SECONDS)).isTrue(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + return dvFiles; + }); + + ExecutorService executor = Executors.newFixedThreadPool(2); + try { + Future> first = + executor.submit(() -> cache.read(path, partition, 1)); + assertThat(supplierEntered.await(5, TimeUnit.SECONDS)).isTrue(); + + Future> second = + executor.submit(() -> cache.read(path, partition, 1)); + releaseSupplier.countDown(); + + assertThat(first.get(5, TimeUnit.SECONDS)).isSameAs(dvFiles); + assertThat(second.get(5, TimeUnit.SECONDS)).isSameAs(dvFiles); + assertThat(invoked).hasValue(1); + } finally { + releaseSupplier.countDown(); + executor.shutdownNow(); + } + } + + @Test + public void testCacheEviction() { + DVMetaCache cache = new DVMetaCache(5); + Path path = new Path("manifest/index-manifest-00005"); + BinaryRow partition = partition("year=2023/month=08"); // Fill cache to capacity Map dvFiles1 = new HashMap<>(); diff --git a/paimon-filesystems/paimon-hadoop-uber/pom.xml b/paimon-filesystems/paimon-hadoop-uber/pom.xml index fe14145f29d9..07991ccd1843 100644 --- a/paimon-filesystems/paimon-hadoop-uber/pom.xml +++ b/paimon-filesystems/paimon-hadoop-uber/pom.xml @@ -650,6 +650,7 @@ + org.apache.paimon:paimon-hadoop-shaded com.google.guava:* com.google.protobuf:* 
com.google.code.findbugs:* diff --git a/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplit.java b/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplit.java index ed84e141330c..aa059225204b 100644 --- a/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplit.java +++ b/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplit.java @@ -69,6 +69,12 @@ public long getSchemaId() { return schemaId; } + @Override + public TableAwareFileStoreSourceSplit updateWithRecordsToSkip(long recordsToSkip) { + return new TableAwareFileStoreSourceSplit( + splitId(), split(), recordsToSkip, identifier, lastSchemaId, schemaId); + } + @Override public boolean equals(Object o) { if (!(o instanceof TableAwareFileStoreSourceSplit)) { diff --git a/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReader.java b/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReader.java index 48c0ef82f25a..9aeb48567d0b 100644 --- a/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReader.java +++ b/paimon-flink/paimon-flink-cdc/src/main/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReader.java @@ -76,8 +76,9 @@ public class CDCSourceSplitReader private TableReaderInfo currentTableReaderInfo; @Nullable private LazyRecordReader currentReader; @Nullable private String currentSplitId; - private long currentNumRead; + private long currentDataRowsRead; private RecordIterator currentFirstBatch; + private final Queue currentSchemaChangeEvents = new LinkedList<>(); private boolean paused; private final AtomicBoolean wakeup; @@ -183,6 +184,7 @@ public void wakeUp() 
{ @Override public void close() throws Exception { + currentSchemaChangeEvents.clear(); if (currentReader != null) { if (currentReader.lazyRecordReader != null) { currentReader.lazyRecordReader.close(); @@ -214,13 +216,17 @@ private void checkSplitOrStartNext() throws IOException { List schemaChangeEvents = tableManager.generateSchemaChangeEventList( identifier, nextSplit.getLastSchemaId(), nextSplit.getSchemaId()); - currentTableReaderInfo = new TableReaderInfo(identifier, tableSchema, schemaChangeEvents); + currentTableReaderInfo = new TableReaderInfo(identifier, tableSchema); currentSplitId = nextSplit.splitId(); currentReader = createLazyRecordReader(nextSplit.split()); - currentNumRead = nextSplit.recordsToSkip(); + currentDataRowsRead = nextSplit.recordsToSkip(); + currentSchemaChangeEvents.clear(); + if (currentDataRowsRead == 0) { + currentSchemaChangeEvents.addAll(schemaChangeEvents); + } - if (currentNumRead > 0) { - seek(currentNumRead); + if (currentDataRowsRead > 0) { + seek(currentDataRowsRead); } } @@ -257,6 +263,7 @@ private CDCRecordsWithSplitIds finishSplit() throws IOException { currentReader = null; } + currentSchemaChangeEvents.clear(); final CDCRecordsWithSplitIds finishRecords = CDCRecordsWithSplitIds.finishedSplit(currentSplitId); currentSplitId = null; @@ -271,33 +278,24 @@ private class FileStoreRecordIterator implements BulkFormat.RecordIterator(); private TableReaderInfo tableReaderInfo; - private final Queue schemaChangeEventList = new LinkedList<>(); public FileStoreRecordIterator replace( RecordIterator iterator, TableReaderInfo tableReaderInfo) { this.iterator = iterator; - this.recordAndPosition.set(null, RecordAndPosition.NO_OFFSET, currentNumRead); + this.recordAndPosition.set(null, RecordAndPosition.NO_OFFSET, currentDataRowsRead); this.tableReaderInfo = tableReaderInfo; - this.schemaChangeEventList.addAll(tableReaderInfo.schemaChangeEvents); return this; } @Nullable @Override public RecordAndPosition next() { - Event event = 
nextEvent(); - if (event == null) { - return null; - } - - recordAndPosition.setNext(event); - currentNumRead++; - return recordAndPosition; - } - - private Event nextEvent() { - if (!schemaChangeEventList.isEmpty()) { - return schemaChangeEventList.poll(); + if (!currentSchemaChangeEvents.isEmpty()) { + recordAndPosition.set( + currentSchemaChangeEvents.poll(), + RecordAndPosition.NO_OFFSET, + currentDataRowsRead); + return recordAndPosition; } InternalRow row; @@ -310,11 +308,14 @@ private Event nextEvent() { return null; } - return convertRowToDataChangeEvent( - tableReaderInfo.tableId, - row, - tableReaderInfo.fieldGetters, - tableReaderInfo.generator); + recordAndPosition.setNext( + convertRowToDataChangeEvent( + tableReaderInfo.tableId, + row, + tableReaderInfo.fieldGetters, + tableReaderInfo.generator)); + currentDataRowsRead++; + return recordAndPosition; } @Override @@ -358,19 +359,14 @@ private static class TableReaderInfo { private final Identifier identifier; private final TableId tableId; private final TableSchema currentSchema; - private final List schemaChangeEvents; private final BinaryRecordDataGenerator generator; private final List fieldGetters; - private TableReaderInfo( - Identifier identifier, - TableSchema currentSchema, - List schemaChangeEvents) { + private TableReaderInfo(Identifier identifier, TableSchema currentSchema) { this.identifier = identifier; this.tableId = TableId.tableId(identifier.getDatabaseName(), identifier.getTableName()); this.currentSchema = currentSchema; - this.schemaChangeEvents = schemaChangeEvents; org.apache.flink.cdc.common.schema.Schema currentCDCSchema = convertPaimonSchemaToFlinkCDCSchema(currentSchema); diff --git a/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/action/cdc/mysql/MySqlSyncDatabaseActionITCase.java b/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/action/cdc/mysql/MySqlSyncDatabaseActionITCase.java index a0026a528ba1..2af4b678a7cc 100644 --- 
a/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/action/cdc/mysql/MySqlSyncDatabaseActionITCase.java +++ b/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/action/cdc/mysql/MySqlSyncDatabaseActionITCase.java @@ -969,7 +969,7 @@ public void testSyncManyTableWithLimitedMemory() throws Exception { } @Test - @Timeout(60) + @Timeout(120) public void testSyncMultipleShards() throws Exception { Map mySqlConfig = getBasicMySqlConfig(); diff --git a/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplitSerializerTest.java b/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplitSerializerTest.java index 901c7d0b1ad4..79ccabf518be 100644 --- a/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplitSerializerTest.java +++ b/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/TableAwareFileStoreSourceSplitSerializerTest.java @@ -19,8 +19,11 @@ package org.apache.paimon.flink.pipeline.cdc.source; import org.apache.paimon.catalog.Identifier; +import org.apache.paimon.flink.source.FileStoreSourceSplit; +import org.apache.paimon.flink.source.FileStoreSourceSplitState; import org.apache.paimon.table.source.DataSplit; +import org.apache.flink.connector.file.src.util.RecordAndPosition; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -37,18 +40,9 @@ public class TableAwareFileStoreSourceSplitSerializerTest { @Test public void test() throws Exception { Identifier identifier = Identifier.create("test_database", "test_table"); - DataSplit dataSplit = - DataSplit.builder() - .withSnapshot(1) - .withPartition(row(1)) - .withBucket(2) - .withDataFiles(Arrays.asList(newFile(0), newFile(1))) - .isStreaming(false) - .rawConvertible(false) - .withBucketPath("/temp/2") // not used - .build(); 
TableAwareFileStoreSourceSplit split = - new TableAwareFileStoreSourceSplit("split-1", dataSplit, 0L, identifier, null, 1L); + new TableAwareFileStoreSourceSplit( + "split-1", newDataSplit(), 0L, identifier, null, 1L); TableAwareFileStoreSourceSplit.Serializer serializer = new TableAwareFileStoreSourceSplit.Serializer(); @@ -57,4 +51,38 @@ public void test() throws Exception { serializer.deserialize(serializer.getVersion(), serialized); assertThat(deserialized).isEqualTo(split); } + + @Test + public void testUpdateWithRecordsToSkipKeepsTableAwareSplit() { + Identifier identifier = Identifier.create("test_database", "test_table"); + DataSplit dataSplit = newDataSplit(); + TableAwareFileStoreSourceSplit split = + new TableAwareFileStoreSourceSplit("split-1", dataSplit, 0L, identifier, 1L, 2L); + FileStoreSourceSplitState state = new FileStoreSourceSplitState(split); + + state.setPosition(new RecordAndPosition<>(null, RecordAndPosition.NO_OFFSET, 10L)); + + FileStoreSourceSplit restored = state.toSourceSplit(); + assertThat(restored).isInstanceOf(TableAwareFileStoreSourceSplit.class); + TableAwareFileStoreSourceSplit tableAwareRestored = + (TableAwareFileStoreSourceSplit) restored; + assertThat(tableAwareRestored.splitId()).isEqualTo(split.splitId()); + assertThat(tableAwareRestored.split()).isEqualTo(split.split()); + assertThat(tableAwareRestored.recordsToSkip()).isEqualTo(10L); + assertThat(tableAwareRestored.getIdentifier()).isEqualTo(identifier); + assertThat(tableAwareRestored.getLastSchemaId()).isEqualTo(1L); + assertThat(tableAwareRestored.getSchemaId()).isEqualTo(2L); + } + + private static DataSplit newDataSplit() { + return DataSplit.builder() + .withSnapshot(1) + .withPartition(row(1)) + .withBucket(2) + .withDataFiles(Arrays.asList(newFile(0), newFile(1))) + .isStreaming(false) + .rawConvertible(false) + .withBucketPath("/temp/2") // not used + .build(); + } } diff --git 
a/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReaderTest.java b/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReaderTest.java index e12ff8f19e9a..82199811b5e3 100644 --- a/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReaderTest.java +++ b/paimon-flink/paimon-flink-cdc/src/test/java/org/apache/paimon/flink/pipeline/cdc/source/reader/CDCSourceSplitReaderTest.java @@ -42,9 +42,13 @@ import org.apache.paimon.utils.RecordWriter; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.cdc.common.event.AddColumnEvent; import org.apache.flink.cdc.common.event.DataChangeEvent; import org.apache.flink.cdc.common.event.Event; import org.apache.flink.cdc.common.event.OperationType; +import org.apache.flink.cdc.common.event.SchemaChangeEvent; +import org.apache.flink.cdc.common.event.TableId; +import org.apache.flink.cdc.common.schema.Column; import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds; import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition; import org.apache.flink.connector.base.source.reader.splitreader.SplitsChange; @@ -62,6 +66,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -144,8 +149,15 @@ public void testSplitReaderWakeupAble() throws Exception { } private CDCSourceSplitReader createReader(TableRead tableRead) { + return createReader(tableRead, Collections.emptyList()); + } + + private CDCSourceSplitReader createReader( + TableRead tableRead, List schemaChangeEvents) { return new TestCDCSourceSplitReader( - new FileStoreSourceReaderMetrics(new DummyMetricGroup()), tableRead); + new FileStoreSourceReaderMetrics(new DummyMetricGroup()), + tableRead, + schemaChangeEvents); } private void 
innerTestOnce(int skip) throws Exception { @@ -251,6 +263,92 @@ public void testMultipleBatchInSplit() throws Exception { reader.close(); } + @Test + public void testSchemaChangeEventOnlyEmittedOnceInMultipleBatchSplit() throws Exception { + TestChangelogDataReadWrite rw = new TestChangelogDataReadWrite(tablePath); + CDCSourceSplitReader reader = createReader(rw.createReadWithKey(), schemaChangeEvents()); + + List> input1 = kvs(); + List files = rw.writeFiles(row(1), 0, input1); + + List> input2 = kvs(6); + List files2 = rw.writeFiles(row(1), 0, input2); + files.addAll(files2); + + assignSplit(reader, newSourceSplit("id1", row(1), 0, files)); + + RecordsWithSplitIds> records = reader.fetch(); + assertThat(readEventTypes(records, "id1")) + .containsExactly( + SchemaChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class); + + records = reader.fetch(); + assertThat(readEventTypes(records, "id1")) + .containsExactly( + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class, + DataChangeEvent.class); + + records = reader.fetch(); + assertRecords(records, "id1", "id1", 0, null); + + reader.close(); + } + + @Test + public void testSchemaChangeEventDoesNotAdvanceRecordsToSkip() throws Exception { + TestChangelogDataReadWrite rw = new TestChangelogDataReadWrite(tablePath); + CDCSourceSplitReader reader = createReader(rw.createReadWithKey(), schemaChangeEvents()); + + List> input = kvs(); + List files = rw.writeFiles(row(1), 0, input); + + assignSplit(reader, newSourceSplit("id1", row(1), 0, files)); + + RecordsWithSplitIds> records = reader.fetch(); + assertThat(readRecordSkipCounts(records, "id1")) + .containsExactly(0L, 1L, 2L, 3L, 4L, 5L, 6L); + + reader.close(); + } + + @Test + public void testRestoreWithSchemaChangeEventsDoesNotReemitSchemaEvent() throws Exception { + 
TestChangelogDataReadWrite rw = new TestChangelogDataReadWrite(tablePath); + CDCSourceSplitReader reader = createReader(rw.createReadWithKey(), schemaChangeEvents()); + + List> input1 = kvs(); + List files = rw.writeFiles(row(1), 0, input1); + + List> input2 = kvs(6); + List files2 = rw.writeFiles(row(1), 0, input2); + files.addAll(files2); + + assignSplit(reader, newSourceSplit("id1", row(1), 0, files, input1.size())); + + RecordsWithSplitIds> records = reader.fetch(); + assertRecords(records, null, "id1", input1.size(), Collections.emptyList()); + + records = reader.fetch(); + assertRecords( + records, + null, + "id1", + input1.size(), + input2.stream().map(t -> t.f1).collect(Collectors.toList())); + + reader.close(); + } + @Test public void testRestore() throws Exception { TestChangelogDataReadWrite rw = new TestChangelogDataReadWrite(tablePath); @@ -456,6 +554,53 @@ private List> readRecords( return result; } + private List> readEventTypes( + RecordsWithSplitIds> records, String nextSplit) { + assertThat(records.finishedSplits()).isEmpty(); + assertThat(records.nextSplit()).isEqualTo(nextSplit); + List> result = new ArrayList<>(); + RecordIterator iterator; + while ((iterator = records.nextRecordFromSplit()) != null) { + RecordAndPosition record; + while ((record = iterator.next()) != null) { + result.add( + record.getRecord() instanceof SchemaChangeEvent + ? 
SchemaChangeEvent.class + : DataChangeEvent.class); + } + } + records.recycle(); + return result; + } + + private List readRecordSkipCounts( + RecordsWithSplitIds> records, String nextSplit) { + assertThat(records.finishedSplits()).isEmpty(); + assertThat(records.nextSplit()).isEqualTo(nextSplit); + List result = new ArrayList<>(); + RecordIterator iterator; + while ((iterator = records.nextRecordFromSplit()) != null) { + RecordAndPosition record; + while ((record = iterator.next()) != null) { + result.add(record.getRecordSkipCount()); + } + } + records.recycle(); + return result; + } + + private List schemaChangeEvents() { + return Collections.singletonList( + new AddColumnEvent( + TableId.tableId(DATABASE, TABLE), + Arrays.asList( + AddColumnEvent.last( + Column.physicalColumn( + "extra", + org.apache.flink.cdc.common.types.DataTypes + .BIGINT()))))); + } + private List> kvs() { return kvs(0); } @@ -524,8 +669,11 @@ public static TableAwareFileStoreSourceSplit newSourceSplit( private static class TestCDCSourceSplitReader extends CDCSourceSplitReader { private final TableRead tableRead; - public TestCDCSourceSplitReader(FileStoreSourceReaderMetrics metrics, TableRead tableRead) { - super(metrics, new TestTableManager(tableRead)); + public TestCDCSourceSplitReader( + FileStoreSourceReaderMetrics metrics, + TableRead tableRead, + List schemaChangeEvents) { + super(metrics, new TestTableManager(tableRead, schemaChangeEvents)); this.tableRead = tableRead; } @@ -555,10 +703,16 @@ public RecordReader recordReader() throws IOException { private static class TestTableManager extends CDCSource.TableManager { private final TableRead tableRead; + private final List schemaChangeEvents; public TestTableManager(TableRead tableRead) { + this(tableRead, Collections.emptyList()); + } + + public TestTableManager(TableRead tableRead, List schemaChangeEvents) { super(null, null, null); this.tableRead = tableRead; + this.schemaChangeEvents = schemaChangeEvents; } @Override @@ -571,5 
+725,11 @@ public TestTableManager(TableRead tableRead) { public TableRead getTableRead(Identifier identifier, TableSchema schema) { return tableRead; } + + @Override + public List generateSchemaChangeEventList( + Identifier identifier, @Nullable Long lastSchemaId, long schemaId) { + return schemaChangeEvents; + } } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/action/ExpirePartitionsAction.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/action/ExpirePartitionsAction.java index ff4cb93426f7..7efb65d2d61d 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/action/ExpirePartitionsAction.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/action/ExpirePartitionsAction.java @@ -23,6 +23,7 @@ import org.apache.paimon.catalog.Identifier; import org.apache.paimon.operation.PartitionExpire; import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.utils.Preconditions; import org.apache.paimon.utils.TimeUtils; import java.time.Duration; @@ -64,6 +65,7 @@ public ExpirePartitionsAction( public void executeLocally() throws Exception { FileStoreTable fileStoreTable = (FileStoreTable) table; FileStore fileStore = fileStoreTable.store(); + PartitionExpire partitionExpire = fileStore.newPartitionExpire( "", @@ -76,7 +78,9 @@ public void executeLocally() throws Exception { catalogLoader(), new Identifier( identifier.getDatabaseName(), identifier.getTableName()))); - + Preconditions.checkNotNull( + partitionExpire, + "Both the partition expiration time and partition field can not be null."); partitionExpire.expire(Long.MAX_VALUE); } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java index 3d41de6f9d55..a7699d578b8c 100644 --- 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/lineage/LineageUtils.java @@ -19,13 +19,19 @@ package org.apache.paimon.flink.lineage; import org.apache.paimon.CoreOptions; +import org.apache.paimon.catalog.CatalogContext; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.FormatTable; import org.apache.paimon.table.Table; +import org.apache.paimon.utils.StringUtils; import org.apache.flink.api.connector.source.Boundedness; import org.apache.flink.streaming.api.lineage.LineageDataset; import org.apache.flink.streaming.api.lineage.LineageVertex; import org.apache.flink.streaming.api.lineage.SourceLineageVertex; +import javax.annotation.Nullable; + import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -39,10 +45,23 @@ public class LineageUtils { private static final String PAIMON_DATASET_PREFIX = "paimon://"; + private static final String CATALOG_KEY = "catalog-key"; private static final Set PAIMON_OPTION_KEYS = CoreOptions.getOptions().stream().map(opt -> opt.key()).collect(Collectors.toSet()); + /** Extracts the {@link CatalogContext} from a table, or null if not available. */ + @Nullable + private static CatalogContext catalogContext(Table table) { + if (table instanceof FileStoreTable) { + return ((FileStoreTable) table).catalogEnvironment().catalogContext(); + } + if (table instanceof FormatTable) { + return ((FormatTable) table).catalogContext(); + } + return null; + } + /** * Builds the config map for a dataset facet from a {@link Table}. Includes filtered Paimon * {@link CoreOptions}, partition keys, primary keys, and the table comment (if present). 
@@ -85,6 +104,29 @@ public static SourceLineageVertex sourceLineageVertex( return new PaimonSourceLineageVertex(boundedness, Collections.singletonList(dataset)); } + private static String getFullName(Table table) { + String name = table.fullName(); + CatalogContext ctx = catalogContext(table); + if (ctx != null) { + String catalogKey = ctx.options().toMap().get(CATALOG_KEY); + if (!StringUtils.isNullOrWhitespaceOnly(catalogKey)) { + name = catalogKey + "." + name; + } + } + return name; + } + + /** + * Creates a {@link SourceLineageVertex} for a Paimon DataStream source table. The table name is + * derived from the table's full name, prefixed with the {@code catalog-key} if available. + * + * @param isBounded whether the source is bounded (batch) or unbounded (streaming) + * @param table the Paimon table + */ + public static SourceLineageVertex sourceLineageVertex(boolean isBounded, Table table) { + return sourceLineageVertex(getFullName(table), isBounded, table); + } + /** * Creates a {@link LineageVertex} for a Paimon sink table. * @@ -97,4 +139,14 @@ public static LineageVertex sinkLineageVertex(String name, Table table) { name, getNamespace(table), buildConfigMap(table), table.rowType()); return new PaimonSinkLineageVertex(Collections.singletonList(dataset)); } + + /** + * Creates a {@link LineageVertex} for a Paimon DataStream sink table. The table name is derived + * from the table's full name, prefixed with the {@code catalog-key} if available. 
+ * + * @param table the Paimon table + */ + public static LineageVertex sinkLineageVertex(Table table) { + return sinkLineageVertex(getFullName(table), table); + } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/ExpirePartitionsProcedure.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/ExpirePartitionsProcedure.java index 8a6528bd0229..ecddcf11e16e 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/ExpirePartitionsProcedure.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/procedure/ExpirePartitionsProcedure.java @@ -90,11 +90,9 @@ public String identifier() { FileStore fileStore = fileStoreTable.store(); PartitionExpire partitionExpire = fileStore.newPartitionExpire("", fileStoreTable); - Preconditions.checkNotNull( partitionExpire, "Both the partition expiration time and partition field can not be null."); - List> expired = partitionExpire.expire(Long.MAX_VALUE); return expired == null || expired.isEmpty() ? 
new Row[] {Row.of("No expired partitions.")} diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java index 7305076cba57..a223ad51295e 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkFormatTableDataStreamSink.java @@ -73,7 +73,7 @@ public FormatTableSink( @Override public LineageVertex getLineageVertex() { - return LineageUtils.sinkLineageVertex(table.fullName(), table); + return LineageUtils.sinkLineageVertex(table); } /** diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java index c5949dfa07ac..84c504166966 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/PaimonDiscardingSink.java @@ -41,6 +41,6 @@ public PaimonDiscardingSink(FileStoreTable table) { @Override public LineageVertex getLineageVertex() { - return LineageUtils.sinkLineageVertex(table.fullName(), table); + return LineageUtils.sinkLineageVertex(table); } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java index 95999ab39e3f..bfef8bc6309d 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/PaimonDataStreamSource.java @@ -84,7 +84,6 @@ 
public SimpleVersionedSerializer getEnumeratorCheckpointSerializer( @Override public LineageVertex getLineageVertex() { - return LineageUtils.sourceLineageVertex( - table.fullName(), getBoundedness() == Boundedness.BOUNDED, table); + return LineageUtils.sourceLineageVertex(getBoundedness() == Boundedness.BOUNDED, table); } } diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/metrics/FileStoreSourceReaderMetrics.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/metrics/FileStoreSourceReaderMetrics.java index a270e0eceecd..cdbf770c1fc9 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/metrics/FileStoreSourceReaderMetrics.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/metrics/FileStoreSourceReaderMetrics.java @@ -27,7 +27,6 @@ public class FileStoreSourceReaderMetrics { private long latestFileCreationTime = UNDEFINED; private long lastSplitUpdateTime = UNDEFINED; - public static final long UNDEFINED = -1L; public static final long ACTIVE = Long.MAX_VALUE; diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/ReadOperator.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/ReadOperator.java index a9b9767041e6..30b2ab0d62e4 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/ReadOperator.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/operator/ReadOperator.java @@ -22,6 +22,7 @@ import org.apache.paimon.disk.IOManager; import org.apache.paimon.flink.FlinkRowData; import org.apache.paimon.flink.NestedProjectedRowData; +import org.apache.paimon.flink.source.RecordLimiter; import org.apache.paimon.flink.source.metrics.FileStoreSourceReaderMetrics; import org.apache.paimon.table.source.DataSplit; import org.apache.paimon.table.source.Split; @@ 
-68,6 +69,7 @@ public class ReadOperator extends AbstractStreamOperator private transient long emitEventTimeLag = FileStoreSourceReaderMetrics.UNDEFINED; private transient long idleStartTime = FileStoreSourceReaderMetrics.ACTIVE; private transient Counter numRecordsIn; + @Nullable private transient RecordLimiter recordLimiter; @Nullable private final Long limit; public ReadOperator( @@ -98,6 +100,7 @@ public void open() throws Exception { .getIOManager() .getSpillingDirectoriesPaths()); this.read = readSupplier.get().withIOManager(ioManager); + this.recordLimiter = RecordLimiter.create(limit); this.reuseRow = new FlinkRowData(null); this.reuseRecord = new StreamRecord<>(null); this.idlingStarted(); @@ -122,7 +125,7 @@ public void processElement(StreamRecord record) throws Exception { boolean firstRecord = true; try (CloseableIterator iterator = read.createReader(split).toCloseableIterator()) { - while (iterator.hasNext()) { + while (!reachLimit() && iterator.hasNext()) { emitEventTimeLag = System.currentTimeMillis() - eventTime; // each Split is already counted as one input record, @@ -133,10 +136,6 @@ public void processElement(StreamRecord record) throws Exception { numRecordsIn.inc(); } - if (reachLimit()) { - return; - } - reuseRow.replace(iterator.next()); if (nestedProjectedRowData == null) { reuseRecord.replace(reuseRow); @@ -145,6 +144,10 @@ public void processElement(StreamRecord record) throws Exception { reuseRecord.replace(nestedProjectedRowData); } output.collect(reuseRecord); + + if (recordLimiter != null) { + recordLimiter.increment(); + } } } // start idle when data sending is completed @@ -160,7 +163,7 @@ public void close() throws Exception { } private boolean reachLimit() { - if (limit != null && numRecordsIn.getCount() > limit) { + if (recordLimiter != null && recordLimiter.reachLimit()) { LOG.info("Reader {} reach the limit record {}.", this, limit); return true; } diff --git 
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/BatchFileStoreITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/BatchFileStoreITCase.java index 93ef7fbe4d2e..a3cbc264b96c 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/BatchFileStoreITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/BatchFileStoreITCase.java @@ -1415,6 +1415,38 @@ public void testBatchReadSourceWithSnapshot() { .containsExactlyInAnyOrder(Row.of(1, 11, 111), Row.of(2, 22, 222)); } + @Test + public void testDedicatedPathLimitTenOnManyRows() { + sql("CREATE TABLE limit_many_rows (a INT, b INT, c INT)"); + StringBuilder insertValues = new StringBuilder(); + for (int i = 1; i <= 100; i++) { + if (i > 1) { + insertValues.append(", "); + } + insertValues.append(String.format("(%d, %d, %d)", i, i * 10, i * 100)); + } + batchSql("INSERT INTO limit_many_rows VALUES " + insertValues); + + List result = + batchSql( + "SELECT * FROM limit_many_rows " + + "/*+ OPTIONS('scan.dedicated-split-generation'='true') */ " + + "LIMIT 10"); + assertThat(result).hasSize(10); + assertThat(result) + .containsExactlyInAnyOrder( + Row.of(1, 10, 100), + Row.of(2, 20, 200), + Row.of(3, 30, 300), + Row.of(4, 40, 400), + Row.of(5, 50, 500), + Row.of(6, 60, 600), + Row.of(7, 70, 700), + Row.of(8, 80, 800), + Row.of(9, 90, 900), + Row.of(10, 100, 1000)); + } + @Test public void testBatchReadSourceWithoutSnapshot() { assertThat( diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java index 01e0b6c11b02..c1fcc2ca0914 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java @@ -337,7 +337,7 @@ public void 
testProjectPushDownWithLookupChangelogProducer() { "CREATE TABLE IF NOT EXISTS T_P (" + "j INT, k INT, a INT, b INT, c STRING, PRIMARY KEY (j,k) NOT ENFORCED)" + " WITH ('merge-engine'='partial-update', 'changelog-producer' = 'lookup', " - + "'fields.a.sequence-group'='j', 'fields.b.sequence-group'='c');"); + + "'fields.a.sequence-group'='b,c');"); batchSql("INSERT INTO T_P VALUES (1, 1, 1, 1, '1')"); assertThat(sql("SELECT k, c FROM T_P")).containsExactlyInAnyOrder(Row.of(1, "1")); } diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/SchemaChangeITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/SchemaChangeITCase.java index d51c1d006ac2..8e6ae0b123ea 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/SchemaChangeITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/SchemaChangeITCase.java @@ -1839,4 +1839,25 @@ public void testDropPrimaryKeyOnNonEmptyTable() { UnsupportedOperationException.class, "Cannot drop primary keys on a non-empty table.")); } + + private static final String BLOB_TABLE_OPTIONS = + "'row-tracking.enabled'='true', 'data-evolution.enabled'='true', 'bucket'='-1'"; + + @Test + public void testAddBlobColumnViaCommentDirective() { + sql("CREATE TABLE T (id INT, data STRING) WITH (" + BLOB_TABLE_OPTIONS + ")"); + + // bare directive — no user comment + sql("ALTER TABLE T ADD desc_col BYTES COMMENT '__BLOB_DESCRIPTOR_FIELD'"); + // directive + user comment + sql("ALTER TABLE T ADD picture BYTES COMMENT '__BLOB_FIELD; profile picture'"); + + String createSql = sql("SHOW CREATE TABLE T").get(0).toString(); + assertThat(createSql).doesNotContain("__BLOB"); + assertThat(createSql).contains("`desc_col`"); + assertThat(createSql).contains("`picture`"); + assertThat(createSql).contains("COMMENT 'profile picture'"); + assertThat(createSql).contains("'blob-field' = 'picture'"); + 
assertThat(createSql).contains("'blob-descriptor-field' = 'desc_col'"); + } } diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/ConsumerActionITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/ConsumerActionITCase.java index 86358055f474..017e496685bc 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/ConsumerActionITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/ConsumerActionITCase.java @@ -26,11 +26,9 @@ import org.apache.paimon.types.DataType; import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; -import org.apache.paimon.utils.BlockingIterator; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.table.api.TableException; -import org.apache.flink.types.Row; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -40,9 +38,7 @@ import java.util.List; import java.util.Optional; -import static org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow; import static org.apache.paimon.flink.util.ReadWriteTableTestUtil.init; -import static org.apache.paimon.flink.util.ReadWriteTableTestUtil.testStreamingRead; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -76,22 +72,8 @@ public void testResetConsumer(String invoker) throws Exception { writeData(rowData(2L, BinaryString.fromString("Hello"))); writeData(rowData(3L, BinaryString.fromString("Paimon"))); - // use consumer streaming read table - BlockingIterator iterator = - testStreamingRead( - "SELECT * FROM `" - + tableName - + "` /*+ OPTIONS('consumer-id'='myid','consumer.expiration-time'='3h') */", - Arrays.asList( - changelogRow("+I", 1L, "Hi"), - changelogRow("+I", 2L, "Hello"), - changelogRow("+I", 3L, 
"Paimon"))); - ConsumerManager consumerManager = new ConsumerManager(table.fileIO(), table.location()); - while (!consumerManager.consumer("myid").isPresent()) { - Thread.sleep(1000); - } - iterator.close(); + consumerManager.resetConsumer("myid", new Consumer(4)); Optional consumer1 = consumerManager.consumer("myid"); assertThat(consumer1).isPresent(); @@ -242,23 +224,9 @@ public void testResetBranchConsumer(String invoker) throws Exception { table.createBranch("b1", "tag"); String branchTableName = tableName + "$branch_b1"; - // use consumer streaming read table - BlockingIterator iterator = - testStreamingRead( - "SELECT * FROM `" - + branchTableName - + "` /*+ OPTIONS('consumer-id'='myid','consumer.expiration-time'='3h') */", - Arrays.asList( - changelogRow("+I", 1L, "Hi"), - changelogRow("+I", 2L, "Hello"), - changelogRow("+I", 3L, "Paimon"))); - ConsumerManager consumerManager = new ConsumerManager(table.fileIO(), table.location(), branchName); - while (!consumerManager.consumer("myid").isPresent()) { - Thread.sleep(1000); - } - iterator.close(); + consumerManager.resetConsumer("myid", new Consumer(4)); Optional consumer1 = consumerManager.consumer("myid"); assertThat(consumer1).isPresent(); @@ -356,54 +324,10 @@ public void testClearConsumers(String invoker) throws Exception { writeData(rowData(2L, BinaryString.fromString("Hello"))); writeData(rowData(3L, BinaryString.fromString("Paimon"))); - // use consumer streaming read table - BlockingIterator iterator1 = - testStreamingRead( - "SELECT * FROM `" - + tableName - + "` /*+ OPTIONS('consumer-id'='myid1_1','consumer.expiration-time'='3h') */", - Arrays.asList( - changelogRow("+I", 1L, "Hi"), - changelogRow("+I", 2L, "Hello"), - changelogRow("+I", 3L, "Paimon"))); - ConsumerManager consumerManager = new ConsumerManager(table.fileIO(), table.location()); - while (!consumerManager.consumer("myid1_1").isPresent()) { - Thread.sleep(1000); - } - iterator1.close(); - - // use consumer streaming read table - 
BlockingIterator iterator2 = - testStreamingRead( - "SELECT * FROM `" - + tableName - + "` /*+ OPTIONS('consumer-id'='myid1_2','consumer.expiration-time'='3h') */", - Arrays.asList( - changelogRow("+I", 1L, "Hi"), - changelogRow("+I", 2L, "Hello"), - changelogRow("+I", 3L, "Paimon"))); - - while (!consumerManager.consumer("myid1_2").isPresent()) { - Thread.sleep(1000); - } - iterator2.close(); - - // use consumer streaming read table - BlockingIterator iterator3 = - testStreamingRead( - "SELECT * FROM `" - + tableName - + "` /*+ OPTIONS('consumer-id'='myid2','consumer.expiration-time'='3h') */", - Arrays.asList( - changelogRow("+I", 1L, "Hi"), - changelogRow("+I", 2L, "Hello"), - changelogRow("+I", 3L, "Paimon"))); - - while (!consumerManager.consumer("myid2").isPresent()) { - Thread.sleep(1000); - } - iterator3.close(); + consumerManager.resetConsumer("myid1_1", new Consumer(4)); + consumerManager.resetConsumer("myid1_2", new Consumer(4)); + consumerManager.resetConsumer("myid2", new Consumer(4)); Optional consumer1 = consumerManager.consumer("myid1_1"); Optional consumer2 = consumerManager.consumer("myid1_2"); diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java index cea640ab8f10..ed78d5cec175 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/lineage/LineageUtilsTest.java @@ -19,6 +19,7 @@ package org.apache.paimon.flink.lineage; import org.apache.paimon.CoreOptions; +import org.apache.paimon.catalog.CatalogContext; import org.apache.paimon.flink.PaimonDataStreamScanProvider; import org.apache.paimon.flink.sink.PaimonDiscardingSink; import org.apache.paimon.flink.source.ContinuousFileStoreSource; @@ -26,8 +27,10 @@ import 
org.apache.paimon.flink.source.operator.MonitorSource; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.options.Options; import org.apache.paimon.schema.Schema; import org.apache.paimon.schema.SchemaManager; +import org.apache.paimon.table.CatalogEnvironment; import org.apache.paimon.table.FileStoreTable; import org.apache.paimon.table.FileStoreTableFactory; import org.apache.paimon.types.IntType; @@ -82,6 +85,24 @@ private FileStoreTable createTable( return FileStoreTableFactory.create(LocalFileIO.create(), tablePath); } + private FileStoreTable createTableWithCatalogOptions(Map catalogOptions) + throws Exception { + FileStoreTable table = + createTable(new HashMap<>(), Collections.emptyList(), Arrays.asList("f0")); + CatalogEnvironment catalogEnvironment = + new CatalogEnvironment( + null, + null, + null, + null, + null, + CatalogContext.create(Options.fromMap(catalogOptions)), + false, + false); + return FileStoreTableFactory.create( + LocalFileIO.create(), tablePath, table.schema(), catalogEnvironment); + } + @Test void testGetNamespace() throws Exception { FileStoreTable table = @@ -120,6 +141,28 @@ void testSourceLineageVertexUnbounded() throws Exception { assertThat(vertex.boundedness()).isEqualTo(Boundedness.CONTINUOUS_UNBOUNDED); } + @Test + void testSourceLineageVertexKeepsProvidedNameWhenCatalogKeyExists() throws Exception { + Map catalogOptions = new HashMap<>(); + catalogOptions.put("catalog-key", "jdbc-warehouse"); + FileStoreTable table = createTableWithCatalogOptions(catalogOptions); + + SourceLineageVertex vertex = LineageUtils.sourceLineageVertex("paimon.db.src", true, table); + + assertThat(vertex.datasets().get(0).name()).isEqualTo("paimon.db.src"); + } + + @Test + void testDataStreamSourceLineageVertexUsesCatalogKey() throws Exception { + Map catalogOptions = new HashMap<>(); + catalogOptions.put("catalog-key", "jdbc-warehouse"); + FileStoreTable table = 
createTableWithCatalogOptions(catalogOptions); + + SourceLineageVertex vertex = LineageUtils.sourceLineageVertex(true, table); + + assertThat(vertex.datasets().get(0).name()).isEqualTo("jdbc-warehouse." + table.fullName()); + } + @Test void testSinkLineageVertex() throws Exception { FileStoreTable table = @@ -135,6 +178,17 @@ void testSinkLineageVertex() throws Exception { assertThat(dataset.namespace()).startsWith("paimon://"); } + @Test + void testDataStreamSinkLineageVertexUsesCatalogKey() throws Exception { + Map catalogOptions = new HashMap<>(); + catalogOptions.put("catalog-key", "jdbc-warehouse"); + FileStoreTable table = createTableWithCatalogOptions(catalogOptions); + + LineageVertex vertex = LineageUtils.sinkLineageVertex(table); + + assertThat(vertex.datasets().get(0).name()).isEqualTo("jdbc-warehouse." + table.fullName()); + } + @Test void testConfigFacetContainsPartitionAndPrimaryKeys() throws Exception { FileStoreTable table = diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/DropGlobalIndexProcedureITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/DropGlobalIndexProcedureITCase.java index acda61357c13..5659467d8aa9 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/DropGlobalIndexProcedureITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/procedure/DropGlobalIndexProcedureITCase.java @@ -40,7 +40,7 @@ public class DropGlobalIndexProcedureITCase extends CatalogITCaseBase { @Test - public void testDropBitmapGlobalIndex() throws Exception { + public void testDropBtreeGlobalIndex() throws Exception { sql( "CREATE TABLE T (" + " id INT," @@ -65,58 +65,58 @@ public void testDropBitmapGlobalIndex() throws Exception { commit.close(); } - // Create bitmap index + // Create btree index tEnv.getConfig() .set(org.apache.flink.table.api.config.TableConfigOptions.TABLE_DML_SYNC, true); List 
createResult = sql( "CALL sys.create_global_index(`table` => 'default.T', " + "`index_column` => 'name', " - + "`index_type` => 'bitmap')"); + + "`index_type` => 'btree')"); assertThat(createResult).hasSize(1); assertThat(createResult.get(0).getField(0)) - .isEqualTo("bitmap global index created successfully for table: T"); + .isEqualTo("BTree global index created successfully for table: T"); // Verify index was created table = paimonTable("T"); - List bitmapEntries = + List btreeEntries = table.store().newIndexFileHandler().scanEntries().stream() - .filter(entry -> entry.indexFile().indexType().equals("bitmap")) + .filter(entry -> entry.indexFile().indexType().equals("btree")) .collect(Collectors.toList()); - assertThat(bitmapEntries).isNotEmpty(); + assertThat(btreeEntries).isNotEmpty(); long totalRowCount = - bitmapEntries.stream() + btreeEntries.stream() .map(entry -> entry.indexFile().rowCount()) .mapToLong(Long::longValue) .sum(); assertThat(totalRowCount).isEqualTo(100000L); - // Drop bitmap index + // Drop btree index List dropResult = sql( "CALL sys.drop_global_index(`table` => 'default.T', " + "`index_column` => 'name', " - + "`index_type` => 'bitmap')"); + + "`index_type` => 'btree')"); assertThat(dropResult).hasSize(1); assertThat(dropResult.get(0).getField(0)) .isInstanceOf(String.class) .asString() .contains("Dropped") - .contains("bitmap") + .contains("btree") .contains("global index files") .contains("name"); // Verify index was dropped table = paimonTable("T"); - bitmapEntries = + btreeEntries = table.store().newIndexFileHandler().scanEntries().stream() - .filter(entry -> entry.indexFile().indexType().equals("bitmap")) + .filter(entry -> entry.indexFile().indexType().equals("btree")) .collect(Collectors.toList()); - assertThat(bitmapEntries).isEmpty(); + assertThat(btreeEntries).isEmpty(); } @Test - public void testDropBitmapGlobalIndexWithPartition() throws Exception { + public void testDropBtreeGlobalIndexWithPartition() throws Exception { sql( 
"CREATE TABLE T (" + " id INT," @@ -194,27 +194,27 @@ public void testDropBitmapGlobalIndexWithPartition() throws Exception { commit.close(); } - // Create bitmap index + // Create btree index tEnv.getConfig() .set(org.apache.flink.table.api.config.TableConfigOptions.TABLE_DML_SYNC, true); List createResult = sql( "CALL sys.create_global_index(`table` => 'default.T', " + "`index_column` => 'name', " - + "`index_type` => 'bitmap')"); + + "`index_type` => 'btree')"); assertThat(createResult).hasSize(1); // Verify index was created table = paimonTable("T"); - List bitmapEntries = + List btreeEntries = table.store().newIndexFileHandler().scanEntries().stream() - .filter(entry -> entry.indexFile().indexType().equals("bitmap")) + .filter(entry -> entry.indexFile().indexType().equals("btree")) .collect(Collectors.toList()); - assertThat(bitmapEntries).isNotEmpty(); + assertThat(btreeEntries).isNotEmpty(); // Verify total row count long totalRowCount = - bitmapEntries.stream() + btreeEntries.stream() .map( entry -> entry.indexFile().globalIndexMeta().rowRangeEnd() @@ -226,30 +226,30 @@ public void testDropBitmapGlobalIndexWithPartition() throws Exception { .sum(); assertThat(totalRowCount).isEqualTo(189088L); - // Drop bitmap index for partition p1 only + // Drop btree index for partition p1 only List dropResult = sql( "CALL sys.drop_global_index(`table` => 'default.T', " + "`index_column` => 'name', " - + "`index_type` => 'bitmap', " + + "`index_type` => 'btree', " + "`partitions` => 'pt=p1')"); assertThat(dropResult).hasSize(1); assertThat(dropResult.get(0).getField(0)) .isInstanceOf(String.class) .asString() .contains("Dropped") - .contains("bitmap"); + .contains("btree"); // Verify only p1 index was dropped - bitmapEntries = + btreeEntries = table.store().newIndexFileHandler().scanEntries().stream() - .filter(entry -> entry.indexFile().indexType().equals("bitmap")) + .filter(entry -> entry.indexFile().indexType().equals("btree")) .collect(Collectors.toList()); - 
assertThat(bitmapEntries).isNotEmpty(); + assertThat(btreeEntries).isNotEmpty(); // Verify remaining row count (p0: 87222 + p2: 33433 = 120655) long remainingRowCount = - bitmapEntries.stream() + btreeEntries.stream() .map( entry -> entry.indexFile().globalIndexMeta().rowRangeEnd() @@ -294,11 +294,11 @@ public void testDropNonExistentIndex() throws Exception { sql( "CALL sys.drop_global_index(`table` => 'default.T', " + "`index_column` => 'name', " - + "`index_type` => 'bitmap')"); + + "`index_type` => 'btree')"); assertThat(dropResult).hasSize(1); assertThat(dropResult.get(0).getField(0)) .isInstanceOf(String.class) .asString() - .contains("No bitmap global index found for column 'name'"); + .contains("No btree global index found for column 'name'"); } } diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/DedicatedSplitReadLimitTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/DedicatedSplitReadLimitTest.java new file mode 100644 index 000000000000..9ef7757dfaaa --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/DedicatedSplitReadLimitTest.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.source.operator; + +import org.apache.paimon.catalog.Catalog; +import org.apache.paimon.catalog.CatalogContext; +import org.apache.paimon.catalog.CatalogFactory; +import org.apache.paimon.catalog.Identifier; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.disk.IOManager; +import org.apache.paimon.metrics.MetricRegistry; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.schema.Schema; +import org.apache.paimon.table.Table; +import org.apache.paimon.table.sink.BatchTableCommit; +import org.apache.paimon.table.sink.BatchTableWrite; +import org.apache.paimon.table.sink.BatchWriteBuilder; +import org.apache.paimon.table.source.Split; +import org.apache.paimon.table.source.TableRead; +import org.apache.paimon.types.DataTypes; + +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests {@link ReadOperator} limit on the dedicated split read path. 
*/ +public class DedicatedSplitReadLimitTest { + + private static final int LIMIT = 10; + + @TempDir Path tempDir; + + private Table table; + + @BeforeEach + public void before() + throws Catalog.TableAlreadyExistException, Catalog.DatabaseNotExistException, + Catalog.TableNotExistException, Catalog.DatabaseAlreadyExistException { + Catalog catalog = + CatalogFactory.createCatalog( + CatalogContext.create(new org.apache.paimon.fs.Path(tempDir.toUri()))); + Schema schema = + Schema.newBuilder() + .column("a", DataTypes.INT()) + .column("b", DataTypes.INT()) + .column("c", DataTypes.INT()) + .primaryKey("a") + .option("bucket", "1") + .build(); + Identifier identifier = Identifier.create("default", "t"); + catalog.createDatabase("default", false); + catalog.createTable(identifier, schema, false); + this.table = catalog.getTable(identifier); + } + + @Test + public void testReadOperatorStopsAfterLimit() throws Exception { + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + BatchTableWrite write = writeBuilder.newWrite(); + for (int i = 0; i < 100; i++) { + write.write(GenericRow.of(i, i, i)); + } + BatchTableCommit commit = writeBuilder.newCommit(); + commit.commit(write.prepareCommit()); + write.close(); + commit.close(); + + ReadBatchCountingRead countingRead = + new ReadBatchCountingRead(table.newReadBuilder().newRead()); + ReadOperator readOperator = new ReadOperator(() -> countingRead, null, (long) LIMIT); + + OneInputStreamOperatorTestHarness harness = + new OneInputStreamOperatorTestHarness<>(readOperator); + harness.setup( + InternalSerializers.create( + RowType.of(new IntType(), new IntType(), new IntType()))); + harness.open(); + for (Split split : table.newReadBuilder().newScan().plan().splits()) { + harness.processElement(new StreamRecord<>(split)); + } + + assertThat(harness.getOutput()).hasSize(LIMIT); + assertThat(countingRead.readBatchInvocations()).isEqualTo(1); + } + + private static class ReadBatchCountingRead implements TableRead { + 
+ private final TableRead delegate; + private final AtomicInteger readBatchInvocations = new AtomicInteger(); + + private ReadBatchCountingRead(TableRead delegate) { + this.delegate = delegate; + } + + int readBatchInvocations() { + return readBatchInvocations.get(); + } + + @Override + public TableRead withMetricRegistry(MetricRegistry registry) { + delegate.withMetricRegistry(registry); + return this; + } + + @Override + public TableRead executeFilter() { + delegate.executeFilter(); + return this; + } + + @Override + public TableRead withIOManager(IOManager ioManager) { + delegate.withIOManager(ioManager); + return this; + } + + @Override + public RecordReader createReader(Split split) throws IOException { + RecordReader reader = delegate.createReader(split); + return new RecordReader() { + @Override + public RecordIterator readBatch() throws IOException { + readBatchInvocations.incrementAndGet(); + return reader.readBatch(); + } + + @Override + public void close() throws IOException { + reader.close(); + } + }; + } + } +} diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/OperatorSourceTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/OperatorSourceTest.java index e4ab4ec15799..9c57f27b866d 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/OperatorSourceTest.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/operator/OperatorSourceTest.java @@ -238,15 +238,11 @@ public void testReadOperatorWithLimit() throws Exception { } ArrayList values = new ArrayList<>(harness.getOutput()); - // In ReadOperator each Split is already counted as one input record. But in this case it - // will not happen. - // So in this case the result values's size if 3 even if the limit is 2. - // The IT case see BatchFileStoreITCase#testBatchReadSourceWithSnapshot. 
+ // ReadOperator limit is enforced on emitted records. assertThat(values) .containsExactlyInAnyOrder( new StreamRecord<>(GenericRowData.of(1, 1, 1)), - new StreamRecord<>(GenericRowData.of(2, 2, 2)), - new StreamRecord<>(GenericRowData.of(3, 3, 3))); + new StreamRecord<>(GenericRowData.of(2, 2, 2))); } @Test diff --git a/paimon-format/src/main/java/org/apache/paimon/format/row/RowFormatWriter.java b/paimon-format/src/main/java/org/apache/paimon/format/row/RowFormatWriter.java index ff9ee1112bd6..7fcd75074491 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/row/RowFormatWriter.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/row/RowFormatWriter.java @@ -89,7 +89,6 @@ public void close() throws IOException { footer.writeTo(out); out.flush(); - out.close(); } private void flushBlock() throws IOException { diff --git a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveAlterTableUtils.java b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveAlterTableUtils.java index 06a0df9cff7c..7759a3fd9371 100644 --- a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveAlterTableUtils.java +++ b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveAlterTableUtils.java @@ -29,20 +29,22 @@ /** Utils for hive alter table. 
*/ public class HiveAlterTableUtils { - public static void alterTable(IMetaStoreClient client, Identifier identifier, Table table) + public static void alterTable( + IMetaStoreClient client, Identifier identifier, Table table, boolean skipUpdateStats) throws TException { try { - alterTableWithEnv(client, identifier, table); + alterTableWithEnv(client, identifier, table, skipUpdateStats); } catch (NoClassDefFoundError | NoSuchMethodError e) { alterTableWithoutEnv(client, identifier, table); } } private static void alterTableWithEnv( - IMetaStoreClient client, Identifier identifier, Table table) throws TException { + IMetaStoreClient client, Identifier identifier, Table table, boolean skipUpdateStats) + throws TException { boolean skipHiveUpdateStats = - Boolean.parseBoolean( - table.getParameters().get(StatsSetupConst.DO_NOT_UPDATE_STATS)); + Boolean.parseBoolean(table.getParameters().get(StatsSetupConst.DO_NOT_UPDATE_STATS)) + || skipUpdateStats; EnvironmentContext environmentContext = new EnvironmentContext(); environmentContext.putToProperties(StatsSetupConst.CASCADE, "true"); environmentContext.putToProperties( diff --git a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalog.java b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalog.java index 63ba25e1b544..c988d2b41adb 100644 --- a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalog.java +++ b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalog.java @@ -116,6 +116,7 @@ import static org.apache.paimon.format.csv.CsvOptions.FIELD_DELIMITER; import static org.apache.paimon.hive.HiveCatalogOptions.HADOOP_CONF_DIR; import static org.apache.paimon.hive.HiveCatalogOptions.HIVE_CONF_DIR; +import static org.apache.paimon.hive.HiveCatalogOptions.HIVE_SKIP_UPDATE_STATS; import static org.apache.paimon.hive.HiveCatalogOptions.IDENTIFIER; import static 
org.apache.paimon.hive.HiveCatalogOptions.LOCATION_IN_PROPERTIES; import static org.apache.paimon.hive.HiveTableUtils.tryToFormatSchema; @@ -1292,7 +1293,12 @@ private void alterTableToHms( Path location = getTableLocation(identifier, table); // file format is null, because only data table support alter table. updateHmsTable(table, identifier, newSchema, null, location); - clients().execute(client -> HiveAlterTableUtils.alterTable(client, identifier, table)); + boolean skipUpdateStats = options.get(HIVE_SKIP_UPDATE_STATS); + clients() + .execute( + client -> + HiveAlterTableUtils.alterTable( + client, identifier, table, skipUpdateStats)); } @Override diff --git a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalogOptions.java b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalogOptions.java index ceab49836820..7700a763d093 100644 --- a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalogOptions.java +++ b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/HiveCatalogOptions.java @@ -72,6 +72,14 @@ public final class HiveCatalogOptions { .defaultValue(TimeUnit.MINUTES.toMillis(5)) .withDescription("Setting the client's pool cache eviction interval(ms).\n"); + public static final ConfigOption HIVE_SKIP_UPDATE_STATS = + ConfigOptions.key("hive.skip-update-stats") + .booleanType() + .defaultValue(false) + .withDescription( + "If true, sets DO_NOT_UPDATE_STATS in the Hive EnvironmentContext " + + "when altering tables, preventing Hive from updating table statistics."); + public static final ConfigOption CLIENT_POOL_CACHE_KEYS = ConfigOptions.key("client-pool-cache.keys") .stringType() diff --git a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/migrate/HiveMigrator.java b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/migrate/HiveMigrator.java index e25a61e16fcb..830a01f9781c 100644 --- 
a/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/migrate/HiveMigrator.java +++ b/paimon-hive/paimon-hive-catalog/src/main/java/org/apache/paimon/hive/migrate/HiveMigrator.java @@ -165,20 +165,22 @@ public void executeMigrate() throws Exception { FileStoreTable paimonTable = (FileStoreTable) hiveCatalog.getTable(identifier); checkPaimonTable(paimonTable); - List partitions = - client.listPartitions(sourceDatabase, sourceTable, Short.MAX_VALUE); checkCompatible(sourceHiveTable, paimonTable); List tasks = new ArrayList<>(); Map rollBack = new ConcurrentHashMap<>(); - if (partitions.isEmpty()) { + if (sourceHiveTable.getPartitionKeys().isEmpty()) { tasks.add( importUnPartitionedTableTask( fileIO, sourceHiveTable, paimonTable, rollBack)); } else { - tasks.addAll( - importPartitionedTableTask( - fileIO, partitions, sourceHiveTable, paimonTable, rollBack)); + List partitions = + client.listPartitions(sourceDatabase, sourceTable, Short.MAX_VALUE); + if (!partitions.isEmpty()) { + tasks.addAll( + importPartitionedTableTask( + fileIO, partitions, sourceHiveTable, paimonTable, rollBack)); + } } List> futures = diff --git a/paimon-hive/paimon-hive-catalog/src/test/java/org/apache/paimon/hive/HiveTableStatsTest.java b/paimon-hive/paimon-hive-catalog/src/test/java/org/apache/paimon/hive/HiveTableStatsTest.java index 35065eb43b03..1d21bb56be84 100644 --- a/paimon-hive/paimon-hive-catalog/src/test/java/org/apache/paimon/hive/HiveTableStatsTest.java +++ b/paimon-hive/paimon-hive-catalog/src/test/java/org/apache/paimon/hive/HiveTableStatsTest.java @@ -33,7 +33,6 @@ import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; import org.apache.paimon.shade.guava30.com.google.common.collect.Maps; -import org.apache.hadoop.hive.common.StatsSetupConst; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.api.Table; import org.junit.jupiter.api.BeforeEach; @@ -44,6 +43,7 @@ import java.util.UUID; import static 
org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTORECONNECTURLKEY; +import static org.apache.paimon.hive.HiveCatalogOptions.HIVE_SKIP_UPDATE_STATS; import static org.assertj.core.api.Assertions.assertThat; /** Verify that table stats has been updated. */ @@ -59,7 +59,7 @@ public void setUp() throws Exception { hiveConf.setVar(METASTORECONNECTURLKEY, jdoConnectionURL + ";create=true"); String metastoreClientClass = "org.apache.hadoop.hive.metastore.HiveMetaStoreClient"; Options catalogOptions = new Options(); - catalogOptions.set(StatsSetupConst.DO_NOT_UPDATE_STATS, "true"); + catalogOptions.set(HIVE_SKIP_UPDATE_STATS, true); catalogOptions.set(CatalogOptions.WAREHOUSE, warehouse); CatalogContext catalogContext = CatalogContext.create(catalogOptions); FileIO fileIO = FileIO.get(new Path(warehouse), catalogContext); @@ -81,14 +81,15 @@ public void testAlterTable() throws Exception { Maps.newHashMap(), ""), false); + HiveCatalog hiveCatalog = (HiveCatalog) catalog; + Table table1 = hiveCatalog.getHmsTable(identifier); catalog.alterTable( identifier, Lists.newArrayList( SchemaChange.addColumn("col2", DataTypes.DATE()), SchemaChange.addColumn("col3", DataTypes.STRING(), "col3 field")), false); - HiveCatalog hiveCatalog = (HiveCatalog) catalog; - Table table = hiveCatalog.getHmsTable(identifier); - assertThat(table.getParameters().get("COLUMN_STATS_ACCURATE")).isEqualTo(null); + Table table2 = hiveCatalog.getHmsTable(identifier); + assertThat(table1.getParameters()).isEqualTo(table2.getParameters()); } } diff --git a/paimon-hive/paimon-hive-common/src/main/java/org/apache/paimon/hive/HiveTypeUtils.java b/paimon-hive/paimon-hive-common/src/main/java/org/apache/paimon/hive/HiveTypeUtils.java index e4799341d1dc..23c01d1144ac 100644 --- a/paimon-hive/paimon-hive-common/src/main/java/org/apache/paimon/hive/HiveTypeUtils.java +++ b/paimon-hive/paimon-hive-common/src/main/java/org/apache/paimon/hive/HiveTypeUtils.java @@ -45,6 +45,7 @@ import 
org.apache.paimon.types.VarBinaryType; import org.apache.paimon.types.VarCharType; import org.apache.paimon.types.VariantType; +import org.apache.paimon.types.VectorType; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveVarchar; @@ -235,6 +236,11 @@ public TypeInfo visit(BlobType blobType) { return TypeInfoFactory.binaryTypeInfo; } + @Override + public TypeInfo visit(VectorType vectorType) { + return TypeInfoFactory.getListTypeInfo(vectorType.getElementType().accept(this)); + } + @Override protected TypeInfo defaultMethod(org.apache.paimon.types.DataType dataType) { throw new UnsupportedOperationException("Unsupported type: " + dataType); diff --git a/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PaimonOutputFormat.java b/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PaimonOutputFormat.java index ef1ee687ecb0..9254ba91f61e 100644 --- a/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PaimonOutputFormat.java +++ b/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PaimonOutputFormat.java @@ -19,10 +19,15 @@ package org.apache.paimon.hive.mapred; import org.apache.paimon.CoreOptions; +import org.apache.paimon.data.GenericRow; import org.apache.paimon.hive.RowDataContainer; +import org.apache.paimon.schema.TableSchema; import org.apache.paimon.table.FileStoreTable; import org.apache.paimon.table.sink.BatchTableWrite; import org.apache.paimon.table.sink.BatchWriteBuilder; +import org.apache.paimon.types.DataField; +import org.apache.paimon.utils.PartitionPathUtils; +import org.apache.paimon.utils.TypeUtils; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -37,7 +42,10 @@ import org.apache.hadoop.util.Progressable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; +import java.util.LinkedHashMap; +import 
java.util.List; import java.util.Map; import java.util.Properties; @@ -69,20 +77,122 @@ public FileSinkOperator.RecordWriter getHiveRecordWriter( Properties properties, Progressable progressable) throws IOException { - return writer(jobConf); + FileStoreTable table = forWriteOnly(createFileStoreTable(jobConf)); + PaimonRecordWriter inner = writer(jobConf, table); + + GenericRow staticPartitionRow = buildStaticPartitionRow(path, table.schema()); + return staticPartitionRow == null + ? inner + : new PartitionedRecordWriter( + inner, staticPartitionRow, table.schema().fields().size()); } private static PaimonRecordWriter writer(JobConf jobConf) { + return writer(jobConf, forWriteOnly(createFileStoreTable(jobConf))); + } + + private static PaimonRecordWriter writer(JobConf jobConf, FileStoreTable table) { TaskAttemptID taskAttemptID = TezUtil.taskAttemptWrapper(jobConf); + BatchWriteBuilder batchWriteBuilder = table.newBatchWriteBuilder(); + BatchTableWrite batchTableWrite = batchWriteBuilder.newWrite(); + return new PaimonRecordWriter(batchTableWrite, taskAttemptID, table.name()); + } - FileStoreTable table = createFileStoreTable(jobConf); - // force write-only = true + private static FileStoreTable forWriteOnly(FileStoreTable table) { Map newOptions = Collections.singletonMap(CoreOptions.WRITE_ONLY.key(), Boolean.TRUE.toString()); - FileStoreTable copy = table.copy(newOptions); - BatchWriteBuilder batchWriteBuilder = copy.newBatchWriteBuilder(); - BatchTableWrite batchTableWrite = batchWriteBuilder.newWrite(); + return table.copy(newOptions); + } + + static GenericRow buildStaticPartitionRow(Path path, TableSchema schema) { + List partitionKeys = schema.partitionKeys(); + if (partitionKeys.isEmpty()) { + return null; + } + + LinkedHashMap spec = + PartitionPathUtils.extractPartitionSpecFromPath( + new org.apache.paimon.fs.Path(path.toString())); + if (spec.isEmpty()) { + return null; + } + + assertPartitionKeysAreSchemaTail(schema.fields(), partitionKeys); + + 
GenericRow row = new GenericRow(partitionKeys.size()); + List fields = schema.fields(); + for (int i = 0; i < partitionKeys.size(); i++) { + String key = partitionKeys.get(i); + String raw = lookupCaseInsensitive(spec, key); + if (raw == null) { + throw new IllegalArgumentException( + "Mixed static and dynamic partition writes are not supported for managed " + + "Paimon Hive tables. Partition key '" + + key + + "' has no value in the static partition path " + + path + + ". Provide values for every partition key in the INSERT."); + } + DataField field = findField(fields, key); + row.setField(i, TypeUtils.castFromString(raw, field.type())); + } + return row; + } + + private static void assertPartitionKeysAreSchemaTail( + List fields, List partitionKeys) { + int n = partitionKeys.size(); + int start = fields.size() - n; + if (start < 0) { + throw new IllegalArgumentException( + "Table schema has " + + fields.size() + + " columns but the schema declares " + + n + + " partition keys. Static partition write requires partition keys " + + "to be a trailing slice of the schema."); + } + for (int i = 0; i < n; i++) { + String expected = partitionKeys.get(i); + String actual = fields.get(start + i).name(); + if (!actual.equalsIgnoreCase(expected)) { + List schemaNames = new ArrayList<>(fields.size()); + for (DataField f : fields) { + schemaNames.add(f.name()); + } + throw new IllegalArgumentException( + "Static partition write requires partition keys to be the trailing " + + "columns of the schema in declared order. Expected column at " + + "position " + + (start + i) + + " to be '" + + expected + + "', found '" + + actual + + "'. 
Schema: " + + schemaNames + + ", partition keys: " + + partitionKeys + + "."); + } + } + } + + private static String lookupCaseInsensitive(Map spec, String key) { + for (Map.Entry e : spec.entrySet()) { + if (e.getKey().equalsIgnoreCase(key)) { + return e.getValue(); + } + } + return null; + } - return new PaimonRecordWriter(batchTableWrite, taskAttemptID, copy.name()); + private static DataField findField(List fields, String name) { + for (DataField f : fields) { + if (f.name().equalsIgnoreCase(name)) { + return f; + } + } + throw new IllegalStateException("Partition column not found in schema: " + name); } } diff --git a/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PartitionedRecordWriter.java b/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PartitionedRecordWriter.java new file mode 100644 index 000000000000..10c0bd2a4119 --- /dev/null +++ b/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/mapred/PartitionedRecordWriter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.hive.mapred; + +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.JoinedRow; +import org.apache.paimon.hive.RowDataContainer; + +import org.apache.hadoop.hive.ql.exec.FileSinkOperator; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.Reporter; + +import java.io.IOException; + +class PartitionedRecordWriter + implements FileSinkOperator.RecordWriter, + org.apache.hadoop.mapred.RecordWriter { + + private final PaimonRecordWriter inner; + private final GenericRow partitionRow; + private final int dataOnlyWidth; + private final int fullSchemaWidth; + private final JoinedRow joinedRow = new JoinedRow(); + + PartitionedRecordWriter( + PaimonRecordWriter inner, GenericRow partitionRow, int fullSchemaWidth) { + this.inner = inner; + this.partitionRow = partitionRow; + this.fullSchemaWidth = fullSchemaWidth; + this.dataOnlyWidth = fullSchemaWidth - partitionRow.getFieldCount(); + } + + @Override + public void write(Writable row) throws IOException { + InternalRow source = ((RowDataContainer) row).get(); + InternalRow toWrite; + int width = source.getFieldCount(); + if (width == dataOnlyWidth) { + joinedRow.replace(source, partitionRow); + toWrite = joinedRow; + } else if (width == fullSchemaWidth) { + toWrite = source; + } else { + throw new IOException( + "Unexpected row width " + + width + + "; expected " + + dataOnlyWidth + + " (static partition path) or " + + fullSchemaWidth + + " (full schema)"); + } + try { + inner.batchTableWrite().write(toWrite); + } catch (Exception e) { + throw new IOException(e); + } + } + + @Override + public void write(NullWritable key, RowDataContainer value) throws IOException { + write(value); + } + + @Override + public void close(boolean abort) throws IOException { + inner.close(abort); + } + + @Override + public void close(Reporter reporter) throws IOException { + 
inner.close(reporter); + } +} diff --git a/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/objectinspector/PaimonObjectInspectorFactory.java b/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/objectinspector/PaimonObjectInspectorFactory.java index 9e9eb0f31f7a..1b81722d000f 100644 --- a/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/objectinspector/PaimonObjectInspectorFactory.java +++ b/paimon-hive/paimon-hive-connector-common/src/main/java/org/apache/paimon/hive/objectinspector/PaimonObjectInspectorFactory.java @@ -28,6 +28,7 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.types.TimeType; import org.apache.paimon.types.VarCharType; +import org.apache.paimon.types.VectorType; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; @@ -51,6 +52,7 @@ public static ObjectInspector create(DataType logicalType) { case FLOAT: case DOUBLE: case BINARY: + case BLOB: case VARBINARY: return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( (PrimitiveTypeInfo) HiveTypeUtils.toTypeInfo(logicalType)); @@ -81,6 +83,9 @@ public static ObjectInspector create(DataType logicalType) { case ARRAY: ArrayType arrayType = (ArrayType) logicalType; return new PaimonListObjectInspector(arrayType.getElementType()); + case VECTOR: + VectorType vectorType = (VectorType) logicalType; + return new PaimonListObjectInspector(vectorType.getElementType()); case MAP: MapType mapType = (MapType) logicalType; return new PaimonMapObjectInspector(mapType.getKeyType(), mapType.getValueType()); diff --git a/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/HiveWriteITCase.java b/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/HiveWriteITCase.java index b124cef1b1ae..73a83e9d9f83 100644 --- 
a/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/HiveWriteITCase.java +++ b/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/HiveWriteITCase.java @@ -196,6 +196,75 @@ public void testInsert() throws Exception { assertThat(select).containsExactly("1\t2\t3\tHello", "4\t5\t6\tFine"); } + @Test + public void testInsertStaticPartition() throws Exception { + List sourceData = + Arrays.asList( + GenericRow.of( + 1, + BinaryString.fromString("Alice"), + 5000.0, + BinaryString.fromString("IT")), + GenericRow.of( + 2, + BinaryString.fromString("Bob"), + 6000.0, + BinaryString.fromString("HR")), + GenericRow.of( + 3, + BinaryString.fromString("Charlie"), + 5500.0, + BinaryString.fromString("IT"))); + + String sourceTableName = + createAppendOnlyExternalTable( + RowType.of( + new DataType[] { + DataTypes.INT(), + DataTypes.STRING(), + DataTypes.DOUBLE(), + DataTypes.STRING() + }, + new String[] {"id", "name", "salary", "department"}), + Collections.emptyList(), + sourceData); + + String tableName = + "static_partition_insert_" + UUID.randomUUID().toString().replace('-', '_'); + hiveShell.execute("SET hive.metastore.warehouse.dir=" + folder.newFolder().toURI()); + hiveShell.execute( + String.join( + "\n", + Arrays.asList( + "CREATE TABLE " + tableName + " (", + " id INT,", + " name STRING,", + " salary DOUBLE,", + " department STRING", + ")", + "PARTITIONED BY (dt STRING)", + "STORED BY '" + PaimonStorageHandler.class.getName() + "'", + "TBLPROPERTIES (", + " 'primary-key' = 'id,dt',", + " 'bucket' = '2',", + " 'file.format' = 'parquet'", + ")"))); + + hiveShell.execute( + "INSERT INTO TABLE " + + tableName + + " PARTITION (dt='2026') " + + "SELECT id, name, salary, department FROM " + + sourceTableName); + + List rows = hiveShell.executeQuery("SELECT * FROM " + tableName + " ORDER BY id"); + assertThat(rows) + .containsExactly( + "1\tAlice\t5000.0\tIT\t2026", + "2\tBob\t6000.0\tHR\t2026", + 
"3\tCharlie\t5500.0\tIT\t2026"); + } + @Test public void testWriteOnlyWithChangeLogTableOption() throws Exception { diff --git a/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/mapred/PaimonOutputFormatTest.java b/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/mapred/PaimonOutputFormatTest.java new file mode 100644 index 000000000000..ee31675ab56f --- /dev/null +++ b/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/mapred/PaimonOutputFormatTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.hive.mapred; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.schema.TableSchema; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.DataTypes; + +import org.apache.hadoop.fs.Path; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link PaimonOutputFormat}. 
*/ +public class PaimonOutputFormatTest { + + @Test + public void buildsRowMatchingUserBugReport() { + TableSchema schema = + new TableSchema( + 0, + Arrays.asList( + new DataField(0, "id", DataTypes.INT().notNull()), + new DataField(1, "name", DataTypes.STRING()), + new DataField(2, "salary", DataTypes.DOUBLE()), + new DataField(3, "department", DataTypes.STRING()), + new DataField(4, "dt", DataTypes.STRING().notNull())), + 4, + Collections.singletonList("dt"), + Arrays.asList("id", "dt"), + Collections.emptyMap(), + ""); + + Path path = new Path("/wh/test_paimon2/dt=2026/file"); + + GenericRow row = PaimonOutputFormat.buildStaticPartitionRow(path, schema); + assertThat(row).isNotNull(); + assertThat(row.getFieldCount()).isEqualTo(1); + assertThat(row.getString(0)).isEqualTo(BinaryString.fromString("2026")); + } + + @Test + public void buildRowConvertsTypedPartitionValues() { + TableSchema schema = + new TableSchema( + 0, + Arrays.asList( + new DataField(0, "v", DataTypes.INT()), + new DataField(1, "region", DataTypes.STRING().notNull()), + new DataField(2, "year", DataTypes.INT().notNull())), + 2, + Arrays.asList("region", "year"), + Collections.emptyList(), + Collections.emptyMap(), + ""); + + Path path = new Path("/wh/t/region=us/year=2026/file"); + + GenericRow row = PaimonOutputFormat.buildStaticPartitionRow(path, schema); + assertThat(row).isNotNull(); + assertThat(row.getString(0)).isEqualTo(BinaryString.fromString("us")); + assertThat(row.getInt(1)).isEqualTo(2026); + } + + @Test + public void buildRowReturnsNullForUnpartitionedTable() { + GenericRow row = + PaimonOutputFormat.buildStaticPartitionRow( + new Path("/wh/t/file"), singleFieldSchema()); + assertThat(row).isNull(); + } + + @Test + public void buildRowReturnsNullWhenPathHasNoPartitionSegments() { + TableSchema schema = + new TableSchema( + 0, + Arrays.asList( + new DataField(0, "v", DataTypes.INT()), + new DataField(1, "dt", DataTypes.STRING().notNull())), + 1, + Collections.singletonList("dt"), 
+ Collections.emptyList(), + Collections.emptyMap(), + ""); + + GenericRow row = + PaimonOutputFormat.buildStaticPartitionRow(new Path("/wh/t/_tmp/file"), schema); + assertThat(row).isNull(); + } + + @Test + public void buildRowFailsWhenPartitionKeysNotAtSchemaTail() { + TableSchema schema = + new TableSchema( + 0, + Arrays.asList( + new DataField(0, "id", DataTypes.INT().notNull()), + new DataField(1, "dt", DataTypes.STRING().notNull()), + new DataField(2, "name", DataTypes.STRING())), + 2, + Collections.singletonList("dt"), + Collections.emptyList(), + Collections.emptyMap(), + ""); + + Path path = new Path("/wh/t/dt=2026/file"); + + assertThatThrownBy(() -> PaimonOutputFormat.buildStaticPartitionRow(path, schema)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("trailing columns") + .hasMessageContaining("'dt'"); + } + + @Test + public void buildRowFailsOnMixedStaticAndDynamicPartition() { + TableSchema schema = + new TableSchema( + 0, + Arrays.asList( + new DataField(0, "v", DataTypes.INT()), + new DataField(1, "region", DataTypes.STRING().notNull()), + new DataField(2, "year", DataTypes.INT().notNull())), + 2, + Arrays.asList("region", "year"), + Collections.emptyList(), + Collections.emptyMap(), + ""); + + Path path = new Path("/wh/t/region=us/file"); + + assertThatThrownBy(() -> PaimonOutputFormat.buildStaticPartitionRow(path, schema)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Mixed static and dynamic partition"); + } + + private static TableSchema singleFieldSchema() { + return new TableSchema( + 0, + Collections.singletonList(new DataField(0, "v", DataTypes.INT())), + 0, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyMap(), + ""); + } +} diff --git a/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/objectinspector/PaimonBlobObjectInspectorTest.java 
b/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/objectinspector/PaimonBlobObjectInspectorTest.java new file mode 100644 index 000000000000..bd35017f3406 --- /dev/null +++ b/paimon-hive/paimon-hive-connector-common/src/test/java/org/apache/paimon/hive/objectinspector/PaimonBlobObjectInspectorTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.hive.objectinspector; + +import org.apache.paimon.types.DataTypes; + +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaBinaryObjectInspector; +import org.apache.hadoop.io.BytesWritable; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test create object inspector for blob type. 
*/ +public class PaimonBlobObjectInspectorTest { + + @Test + public void testCategoryAndClass() { + PrimitiveObjectInspector oi = + (PrimitiveObjectInspector) PaimonObjectInspectorFactory.create(DataTypes.BLOB()); + assertThat(oi.getCategory()).isEqualTo(ObjectInspector.Category.PRIMITIVE); + assertThat(oi.getPrimitiveCategory()) + .isEqualTo(PrimitiveObjectInspector.PrimitiveCategory.BINARY); + assertThat(oi.getJavaPrimitiveClass()).isEqualTo(byte[].class); + assertThat(oi.getPrimitiveWritableClass()).isEqualTo(BytesWritable.class); + } + + @Test + public void testGetPrimitiveJavaObject() { + PrimitiveObjectInspector oi = + (PrimitiveObjectInspector) PaimonObjectInspectorFactory.create(DataTypes.BLOB()); + byte[] input = new byte[] {1, 2, 3, 4}; + assertThat((byte[]) oi.getPrimitiveJavaObject(input)).isEqualTo(input); + assertThat(oi.getPrimitiveJavaObject(null)).isNull(); + } + + @Test + public void testGetPrimitiveWritableObject() { + PrimitiveObjectInspector oi = + (PrimitiveObjectInspector) PaimonObjectInspectorFactory.create(DataTypes.BLOB()); + byte[] input = new byte[] {1, 2, 3, 4}; + BytesWritable expected = new BytesWritable(input); + assertThat(oi.getPrimitiveWritableObject(input)).isEqualTo(expected); + assertThat(oi.getPrimitiveWritableObject(null)).isNull(); + } + + @Test + public void testCopyObject() { + PrimitiveObjectInspector oi = + (PrimitiveObjectInspector) PaimonObjectInspectorFactory.create(DataTypes.BLOB()); + byte[] input = new byte[] {1, 2, 3, 4}; + Object copy = oi.copyObject(input); + assertThat(copy).isEqualTo(input); + assertThat(oi.copyObject(null)).isNull(); + } + + @Test + public void testCreateObjectInspector() { + PrimitiveObjectInspector oi = + (PrimitiveObjectInspector) PaimonObjectInspectorFactory.create(DataTypes.BLOB()); + assertThat(oi).isInstanceOf(JavaBinaryObjectInspector.class); + } +} diff --git a/paimon-lumina/pom.xml b/paimon-lumina/pom.xml index bf7a2a62f1ec..905c304965b1 100644 --- a/paimon-lumina/pom.xml +++ 
b/paimon-lumina/pom.xml @@ -34,7 +34,7 @@ under the License. lumina - https://lumina-binary.oss-cn-shanghai.aliyuncs.com/mvn-repo/ + https://dlf-mvn-repo.oss-cn-shanghai.aliyuncs.com/mvn-repo/release jindodata diff --git a/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexReader.java b/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexReader.java index df20e14f851c..68dd9b43fb80 100644 --- a/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexReader.java +++ b/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexReader.java @@ -43,6 +43,8 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; import static org.apache.paimon.utils.Preconditions.checkArgument; @@ -65,6 +67,7 @@ public class LuminaVectorGlobalIndexReader implements GlobalIndexReader { private final GlobalIndexFileReader fileReader; private final DataType fieldType; private final LuminaVectorIndexOptions options; + private final ExecutorService executor; private volatile LuminaIndexMeta indexMeta; private volatile LuminaIndex index; @@ -75,8 +78,10 @@ public LuminaVectorGlobalIndexReader( GlobalIndexFileReader fileReader, List ioMetas, DataType fieldType, - LuminaVectorIndexOptions options) { + LuminaVectorIndexOptions options, + ExecutorService executor) { checkArgument(ioMetas.size() == 1, "Expected exactly one index file per shard"); + this.executor = executor; this.fileReader = fileReader; this.ioMeta = ioMetas.get(0); this.fieldType = fieldType; @@ -84,17 +89,22 @@ public LuminaVectorGlobalIndexReader( } @Override - public Optional visitVectorSearch(VectorSearch vectorSearch) { - try { - ensureLoaded(); - return Optional.ofNullable(search(vectorSearch)); - } catch (IOException e) { - throw new RuntimeException( - String.format( - "Failed to search Lumina 
vector index with fieldName=%s, limit=%d", - vectorSearch.fieldName(), vectorSearch.limit()), - e); - } + public CompletableFuture> visitVectorSearch( + VectorSearch vectorSearch) { + return CompletableFuture.supplyAsync( + () -> { + try { + ensureLoaded(); + return Optional.ofNullable(search(vectorSearch)); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to search Lumina vector index with fieldName=%s, limit=%d", + vectorSearch.fieldName(), vectorSearch.limit()), + e); + } + }, + executor); } private ScoredGlobalIndexResult search(VectorSearch vectorSearch) throws IOException { @@ -345,73 +355,85 @@ public void close() throws IOException { // =================== unsupported ===================== @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitStartsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLike(FieldRef fieldRef, Object 
literal) { - return Optional.empty(); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotIn(FieldRef 
fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } /** diff --git a/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexer.java b/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexer.java index 00ea925e86c7..276cc5aa8543 100644 --- a/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexer.java +++ b/paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexer.java @@ -28,6 +28,7 @@ import org.apache.paimon.types.DataType; import java.util.List; +import java.util.concurrent.ExecutorService; /** Lumina vector global indexer. */ public class LuminaVectorGlobalIndexer implements GlobalIndexer { @@ -47,7 +48,9 @@ public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) { @Override public GlobalIndexReader createReader( - GlobalIndexFileReader fileReader, List files) { - return new LuminaVectorGlobalIndexReader(fileReader, files, fieldType, options); + GlobalIndexFileReader fileReader, + List files, + ExecutorService executor) { + return new LuminaVectorGlobalIndexReader(fileReader, files, fieldType, options, executor); } } diff --git a/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorBenchmark.java b/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorBenchmark.java index 01dc570d19ae..1458352bba31 100644 --- a/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorBenchmark.java +++ b/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorBenchmark.java @@ -47,6 +47,8 @@ import java.util.Map; import java.util.SplittableRandom; import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Benchmark for Lumina vector index using {@link LuminaVectorGlobalIndexWriter} and {@link 
@@ -541,10 +543,11 @@ public void benchmarkQuery() throws Exception { ioMetaArg -> benchFileIO.newInputStream( new Path(benchIndexDir, ioMetaArg.filePath().getName())); + ExecutorService executor = Executors.newCachedThreadPool(); try (LuminaVectorGlobalIndexReader reader = new LuminaVectorGlobalIndexReader( - gFileReader, ioMetas, vectorType, indexOptions)) { - reader.visitVectorSearch(vs); + gFileReader, ioMetas, vectorType, indexOptions, executor)) { + reader.visitVectorSearch(vs).join(); openBytesArr[i] = reader.getOpenBytesRead(); openSeekArr[i] = reader.getOpenSeekCount(); openReadArr[i] = reader.getOpenReadCount(); diff --git a/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexTest.java b/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexTest.java index 210baab03b6e..ab3724729548 100644 --- a/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexTest.java +++ b/paimon-lumina/src/test/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexTest.java @@ -51,6 +51,8 @@ import java.util.List; import java.util.Random; import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -64,6 +66,7 @@ public class LuminaVectorGlobalIndexTest { private Path indexPath; private DataType vectorType; private final String fieldName = "vec"; + private ExecutorService executor; @BeforeEach public void setup() { @@ -86,10 +89,14 @@ public void setup() { fileIO = new LocalFileIO(); indexPath = new Path(tempDir.toString()); vectorType = new ArrayType(new FloatType()); + executor = Executors.newCachedThreadPool(); } @AfterEach public void cleanup() throws IOException { + if (executor != null) { + executor.shutdownNow(); + } if (fileIO != null) { fileIO.delete(indexPath, true); } @@ -151,11 +158,11 @@ public void 
testDifferentMetrics() throws IOException { GlobalIndexFileReader fileReader = createFileReader(metricIndexPath); try (LuminaVectorGlobalIndexReader reader = new LuminaVectorGlobalIndexReader( - fileReader, metas, vectorType, indexOptions)) { + fileReader, metas, vectorType, indexOptions, executor)) { VectorSearch vectorSearch = new VectorSearch(testVectors.get(0), 3, fieldName); LuminaScoredGlobalIndexResult searchResult = (LuminaScoredGlobalIndexResult) - reader.visitVectorSearch(vectorSearch).get(); + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(searchResult.results().getLongCardinality()).isEqualTo(3); assertThat(searchResult.results().contains(0L)).isTrue(); float score = searchResult.scoreGetter().score(0L); @@ -186,11 +193,11 @@ public void testDifferentDimensions() throws IOException { GlobalIndexFileReader fileReader = createFileReader(dimIndexPath); try (LuminaVectorGlobalIndexReader reader = new LuminaVectorGlobalIndexReader( - fileReader, metas, vectorType, indexOptions)) { + fileReader, metas, vectorType, indexOptions, executor)) { VectorSearch vectorSearch = new VectorSearch(testVectors.get(0), 5, fieldName); LuminaScoredGlobalIndexResult searchResult = (LuminaScoredGlobalIndexResult) - reader.visitVectorSearch(vectorSearch).get(); + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(searchResult.results().getLongCardinality()).isEqualTo(5); assertThat(searchResult.results().contains(0L)).isTrue(); float score = searchResult.scoreGetter().score(0L); @@ -236,12 +243,14 @@ public void testFloatVectorIndexEndToEnd() throws IOException { GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { // Query vector[0] = (1.0, 0.0); nearest neighbors by L2 should be // row 0 (1.0, 0.0), row 3 (0.98, 
0.05), row 1 (0.95, 0.1). VectorSearch vectorSearch = new VectorSearch(vectors[0], 3, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(3); assertThat(result.results().contains(0L)).isTrue(); assertThat(result.results().contains(3L)).isTrue(); @@ -255,14 +264,18 @@ public void testFloatVectorIndexEndToEnd() throws IOException { filterResults.add(expectedRowId); vectorSearch = new VectorSearch(vectors[0], 3, fieldName).withIncludeRowIds(filterResults); - result = (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + result = + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(1); assertThat(result.results().contains(expectedRowId)).isTrue(); // Test with multiple results float[] queryVector = new float[] {0.85f, 0.15f}; vectorSearch = new VectorSearch(queryVector, 2, fieldName); - result = (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + result = + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(2); } } @@ -292,12 +305,13 @@ public void testSearchWithFilter() throws IOException { GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { // Unfiltered: query (1,0) top-3 should come from the first cluster (rows 0,1,2). 
VectorSearch search = new VectorSearch(vectors[0], 3, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).get(); + (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).join().get(); assertThat(result.results().contains(0L)).isTrue(); assertThat(result.results().contains(1L)).isTrue(); assertThat(result.results().contains(2L)).isTrue(); @@ -306,7 +320,7 @@ public void testSearchWithFilter() throws IOException { RoaringNavigableMap64 filter = new RoaringNavigableMap64(); filter.add(3L); search = new VectorSearch(vectors[0], 3, fieldName).withIncludeRowIds(filter); - result = (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).get(); + result = (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).join().get(); assertThat(result.results().contains(3L)).isTrue(); assertThat(result.results().getLongCardinality()).isEqualTo(1); @@ -315,7 +329,7 @@ public void testSearchWithFilter() throws IOException { crossFilter.add(1L); crossFilter.add(4L); search = new VectorSearch(vectors[0], 6, fieldName).withIncludeRowIds(crossFilter); - result = (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).get(); + result = (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).join().get(); assertThat(result.results().contains(1L)).isTrue(); assertThat(result.results().contains(4L)).isTrue(); assertThat(result.results().getLongCardinality()).isEqualTo(2); @@ -363,13 +377,14 @@ public void testLargeVectorSet() throws IOException { GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { for (int queryIdx : new int[] {50, 150, 320}) { VectorSearch vectorSearch = new VectorSearch(testVectors.get(queryIdx), 3, fieldName); 
LuminaScoredGlobalIndexResult searchResult = (LuminaScoredGlobalIndexResult) - reader.visitVectorSearch(vectorSearch).get(); + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(searchResult.results().getLongCardinality()).isEqualTo(3); assertThat(searchResult.results().contains((long) queryIdx)).isTrue(); assertThat(searchResult.scoreGetter().score((long) queryIdx)).isNotNaN(); @@ -377,7 +392,8 @@ public void testLargeVectorSet() throws IOException { VectorSearch vectorSearch = new VectorSearch(testVectors.get(200), 5, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(5); assertThat(result.results().contains(200L)).isTrue(); } @@ -418,10 +434,11 @@ public void testReaderMetaOptionsOverrideDefaultOptions() throws IOException { GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = new LuminaVectorGlobalIndexReader( - fileReader, metas, vectorType, readIndexOptions)) { + fileReader, metas, vectorType, readIndexOptions, executor)) { VectorSearch vectorSearch = new VectorSearch(vectors[0], 3, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(3); assertThat(result.results().contains(0L)).isTrue(); } @@ -459,10 +476,12 @@ public void testVectorTypeEndToEnd() throws IOException { GlobalIndexFileReader fileReader = createFileReader(vecIndexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vecFieldType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vecFieldType, 
indexOptions, executor)) { VectorSearch vectorSearch = new VectorSearch(vectors[0], 3, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(3); assertThat(result.results().contains(0L)).isTrue(); assertThat(result.results().contains(3L)).isTrue(); @@ -543,11 +562,13 @@ public void testNullVectorSkipWithCorrectIds() throws IOException { List metas = toIOMetas(results, indexPath); GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { // Search for vec0=(1,0), should find ID=0 VectorSearch vectorSearch = new VectorSearch(vectors[0], 3, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + (LuminaScoredGlobalIndexResult) + reader.visitVectorSearch(vectorSearch).join().get(); assertThat(result.results().getLongCardinality()).isEqualTo(3); // IDs should be {0, 2, 5} - shard-relative with null gaps assertThat(result.results().contains(0L)).isTrue(); @@ -609,7 +630,8 @@ public void testNullGapWithPreFilter() throws IOException { GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { // Pre-filter includes null gap IDs {1, 4} and valid ID {2} RoaringNavigableMap64 filter = new RoaringNavigableMap64(); filter.add(1L); // null position - should not match @@ -618,7 +640,7 @@ public void testNullGapWithPreFilter() 
throws IOException { VectorSearch search = new VectorSearch(vectors[0], 3, fieldName).withIncludeRowIds(filter); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).get(); + (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).join().get(); // Only row 2 should be in results (rows 1 and 4 are null gaps) assertThat(result.results().getLongCardinality()).isEqualTo(1); assertThat(result.results().contains(2L)).isTrue(); @@ -647,10 +669,11 @@ public void testNullAtStart() throws IOException { List metas = toIOMetas(results, indexPath); GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { VectorSearch search = new VectorSearch(new float[] {1.0f, 0.0f}, 1, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).get(); + (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).join().get(); assertThat(result.results().contains(1L)).isTrue(); assertThat(result.results().contains(0L)).isFalse(); } @@ -676,10 +699,11 @@ public void testNullAtEnd() throws IOException { List metas = toIOMetas(results, indexPath); GlobalIndexFileReader fileReader = createFileReader(indexPath); try (LuminaVectorGlobalIndexReader reader = - new LuminaVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + new LuminaVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions, executor)) { VectorSearch search = new VectorSearch(new float[] {1.0f, 0.0f}, 1, fieldName); LuminaScoredGlobalIndexResult result = - (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).get(); + (LuminaScoredGlobalIndexResult) reader.visitVectorSearch(search).join().get(); 
assertThat(result.results().contains(0L)).isTrue(); assertThat(result.results().contains(1L)).isFalse(); } diff --git a/paimon-mosaic/pom.xml b/paimon-mosaic/pom.xml new file mode 100644 index 000000000000..6e0d1eae286c --- /dev/null +++ b/paimon-mosaic/pom.xml @@ -0,0 +1,86 @@ + + + + 4.0.0 + + + paimon-parent + org.apache.paimon + 1.5-SNAPSHOT + + + paimon-mosaic + Paimon : Mosaic Format + + + + org.apache.paimon + mosaic + 0.1.0 + + + + org.apache.paimon + paimon-arrow + ${project.version} + + + + org.apache.paimon + paimon-common + ${project.version} + provided + + + + org.apache.paimon + paimon-core + ${project.version} + provided + + + + + + org.apache.paimon + paimon-common + ${project.version} + test-jar + test + + + + org.apache.paimon + paimon-test-utils + ${project.version} + test + + + + org.apache.paimon + paimon-core + ${project.version} + test-jar + test + + + diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormat.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormat.java new file mode 100644 index 000000000000..bff850a4e06c --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormat.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.format.FileFormat; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatReaderFactory; +import org.apache.paimon.format.FormatWriterFactory; +import org.apache.paimon.format.SimpleStatsExtractor; +import org.apache.paimon.options.ConfigOption; +import org.apache.paimon.options.ConfigOptions; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.statistics.SimpleColStatsCollector; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.BigIntType; +import org.apache.paimon.types.BinaryType; +import org.apache.paimon.types.BlobType; +import org.apache.paimon.types.BooleanType; +import org.apache.paimon.types.CharType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.DataTypeVisitor; +import org.apache.paimon.types.DateType; +import org.apache.paimon.types.DecimalType; +import org.apache.paimon.types.DoubleType; +import org.apache.paimon.types.FloatType; +import org.apache.paimon.types.IntType; +import org.apache.paimon.types.LocalZonedTimestampType; +import org.apache.paimon.types.MapType; +import org.apache.paimon.types.MultisetType; +import org.apache.paimon.types.RowType; +import org.apache.paimon.types.SmallIntType; +import org.apache.paimon.types.TimeType; +import org.apache.paimon.types.TimestampType; +import org.apache.paimon.types.TinyIntType; +import org.apache.paimon.types.VarBinaryType; +import org.apache.paimon.types.VarCharType; +import org.apache.paimon.types.VariantType; +import org.apache.paimon.types.VectorType; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Optional; + +/** Mosaic {@link FileFormat}. 
*/ +public class MosaicFileFormat extends FileFormat { + + public static final ConfigOption STATS_COLUMNS = + ConfigOptions.key("mosaic.stats-columns") + .stringType() + .defaultValue("") + .withDescription( + "Comma-separated list of column names to collect statistics for. " + + "Empty means no statistics collection."); + + public static final ConfigOption NUM_BUCKETS = + ConfigOptions.key("mosaic.num-buckets") + .intType() + .noDefaultValue() + .withDescription("Number of column buckets for parallel IO."); + + static { + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + private final FileFormatFactory.FormatContext formatContext; + + public MosaicFileFormat(FileFormatFactory.FormatContext formatContext) { + super("mosaic"); + this.formatContext = formatContext; + } + + @Override + public FormatReaderFactory createReaderFactory( + RowType dataSchemaRowType, + RowType projectedRowType, + @Nullable List predicates) { + return new MosaicReaderFactory(dataSchemaRowType, projectedRowType, predicates); + } + + @Override + public FormatWriterFactory createWriterFactory(RowType type) { + return new MosaicWriterFactory(type, formatContext); + } + + @Override + public void validateDataFields(RowType rowType) { + MosaicRowTypeVisitor visitor = new MosaicRowTypeVisitor(); + for (DataType fieldType : rowType.getFieldTypes()) { + fieldType.accept(visitor); + } + } + + @Override + public Optional createStatsExtractor( + RowType type, SimpleColStatsCollector.Factory[] statsCollectors) { + return Optional.of(new MosaicSimpleStatsExtractor(type, statsCollectors)); + } + + static class MosaicRowTypeVisitor implements DataTypeVisitor { + + @Override + public Void visit(CharType charType) { + return null; + } + + @Override + public Void visit(VarCharType varCharType) { + return null; + } + + @Override + public Void visit(BooleanType booleanType) { + return null; + } + + @Override + public Void visit(BinaryType binaryType) { + return null; + } + + @Override + 
public Void visit(VarBinaryType varBinaryType) { + return null; + } + + @Override + public Void visit(DecimalType decimalType) { + return null; + } + + @Override + public Void visit(TinyIntType tinyIntType) { + return null; + } + + @Override + public Void visit(SmallIntType smallIntType) { + return null; + } + + @Override + public Void visit(IntType intType) { + return null; + } + + @Override + public Void visit(BigIntType bigIntType) { + return null; + } + + @Override + public Void visit(FloatType floatType) { + return null; + } + + @Override + public Void visit(DoubleType doubleType) { + return null; + } + + @Override + public Void visit(DateType dateType) { + return null; + } + + @Override + public Void visit(TimeType timeType) { + return null; + } + + @Override + public Void visit(TimestampType timestampType) { + return null; + } + + @Override + public Void visit(LocalZonedTimestampType localZonedTimestampType) { + return null; + } + + @Override + public Void visit(VariantType variantType) { + throw new UnsupportedOperationException( + "Mosaic file format does not support type VARIANT"); + } + + @Override + public Void visit(BlobType blobType) { + throw new UnsupportedOperationException( + "Mosaic file format does not support type BLOB"); + } + + @Override + public Void visit(ArrayType arrayType) { + throw new UnsupportedOperationException( + "Mosaic file format does not support type ARRAY"); + } + + @Override + public Void visit(VectorType vectorType) { + throw new UnsupportedOperationException( + "Mosaic file format does not support type VECTOR"); + } + + @Override + public Void visit(MultisetType multisetType) { + throw new UnsupportedOperationException( + "Mosaic file format does not support type MULTISET"); + } + + @Override + public Void visit(MapType mapType) { + throw new UnsupportedOperationException("Mosaic file format does not support type MAP"); + } + + @Override + public Void visit(RowType rowType) { + throw new 
UnsupportedOperationException("Mosaic file format does not support type ROW"); + } + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormatFactory.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormatFactory.java new file mode 100644 index 000000000000..782faba3e8f9 --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicFileFormatFactory.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.format.FileFormat; +import org.apache.paimon.format.FileFormatFactory; + +/** Factory to create {@link MosaicFileFormat}. 
*/ +public class MosaicFileFormatFactory implements FileFormatFactory { + + public static final String IDENTIFIER = "mosaic"; + + @Override + public String identifier() { + return IDENTIFIER; + } + + @Override + public FileFormat create(FormatContext formatContext) { + return new MosaicFileFormat(formatContext); + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicInputFileAdapter.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicInputFileAdapter.java new file mode 100644 index 000000000000..3a307ea0f296 --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicInputFileAdapter.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.SeekableInputStream; +import org.apache.paimon.fs.VectoredReadable; +import org.apache.paimon.mosaic.InputFile; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; + +/** + * Adapts Paimon's {@link FileIO} to Mosaic's {@link InputFile} interface. + * + *

    Maintains a single {@link SeekableInputStream}. If the stream implements {@link + * VectoredReadable}, reads use {@link VectoredReadable#preadFully} which is thread-safe. Otherwise, + * reads are synchronized to protect seek+read sequences. + */ +public class MosaicInputFileAdapter implements InputFile, Closeable { + + private final Path path; + private final SeekableInputStream in; + private final VectoredReadable vectoredReadable; + + public MosaicInputFileAdapter(FileIO fileIO, Path path) throws IOException { + this.path = path; + this.in = fileIO.newInputStream(path); + this.vectoredReadable = in instanceof VectoredReadable ? (VectoredReadable) in : null; + } + + @Override + public void readFully(long position, byte[] buffer, int offset, int length) throws IOException { + if (vectoredReadable != null) { + vectoredReadable.preadFully(position, buffer, offset, length); + } else { + synchronized (in) { + in.seek(position); + int remaining = length; + int off = offset; + while (remaining > 0) { + int read = in.read(buffer, off, remaining); + if (read < 0) { + throw new EOFException( + "Reached end of file while reading " + + path + + " at position " + + position); + } + off += read; + remaining -= read; + } + } + } + } + + @Override + public void close() throws IOException { + in.close(); + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicObjects.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicObjects.java new file mode 100644 index 000000000000..54d15c43c09e --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicObjects.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.DecimalType; +import org.apache.paimon.types.LocalZonedTimestampType; +import org.apache.paimon.types.TimestampType; + +import javax.annotation.Nullable; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; + +/** Converts Mosaic's byte[] statistics to Paimon objects. 
*/ +public class MosaicObjects { + + @Nullable + public static Object convertStatsValue(byte[] bytes, DataType dataType) { + if (bytes == null) { + return null; + } + switch (dataType.getTypeRoot()) { + case CHAR: + case VARCHAR: + return BinaryString.fromBytes(bytes); + case BINARY: + case VARBINARY: + return bytes; + default: + break; + } + if (bytes.length == 0) { + return null; + } + ByteBuffer buf = ByteBuffer.wrap(bytes); + switch (dataType.getTypeRoot()) { + case BOOLEAN: + return bytes[0] != 0; + case TINYINT: + return bytes[0]; + case SMALLINT: + return buf.getShort(); + case INTEGER: + case DATE: + case TIME_WITHOUT_TIME_ZONE: + return buf.getInt(); + case BIGINT: + return buf.getLong(); + case FLOAT: + return buf.getFloat(); + case DOUBLE: + return buf.getDouble(); + case DECIMAL: + DecimalType decimalType = (DecimalType) dataType; + BigInteger unscaled = new BigInteger(bytes); + BigDecimal decimal = new BigDecimal(unscaled, decimalType.getScale()); + return Decimal.fromBigDecimal( + decimal, decimalType.getPrecision(), decimalType.getScale()); + case TIMESTAMP_WITHOUT_TIME_ZONE: + return convertTimestamp(buf, ((TimestampType) dataType).getPrecision()); + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + return convertTimestamp(buf, ((LocalZonedTimestampType) dataType).getPrecision()); + default: + return null; + } + } + + private static Timestamp convertTimestamp(ByteBuffer buf, int precision) { + if (precision <= 3) { + return Timestamp.fromEpochMillis(buf.getLong()); + } else if (precision <= 6) { + return Timestamp.fromMicros(buf.getLong()); + } else { + // precision 7-9: 12 bytes = i64 millis (BE) + i32 nanos_of_milli (BE) + long millis = buf.getLong(); + int nanosOfMilli = buf.getInt(); + return Timestamp.fromEpochMillis(millis, nanosOfMilli); + } + } + + private MosaicObjects() {} +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicReaderFactory.java 
b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicReaderFactory.java new file mode 100644 index 000000000000..5b39c867e290 --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicReaderFactory.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.format.FormatReaderFactory; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.reader.FileRecordReader; +import org.apache.paimon.types.RowType; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.List; + +/** A factory to create Mosaic reader. 
*/ +public class MosaicReaderFactory implements FormatReaderFactory { + + private final RowType dataSchemaRowType; + private final RowType projectedRowType; + @Nullable private final List predicates; + + public MosaicReaderFactory( + RowType dataSchemaRowType, + RowType projectedRowType, + @Nullable List predicates) { + this.dataSchemaRowType = dataSchemaRowType; + this.projectedRowType = projectedRowType; + this.predicates = predicates; + } + + @Override + public FileRecordReader createReader(Context context) throws IOException { + MosaicInputFileAdapter inputFile = + new MosaicInputFileAdapter(context.fileIO(), context.filePath()); + return new MosaicRecordsReader( + inputFile, + context.fileSize(), + dataSchemaRowType, + projectedRowType, + predicates, + context.filePath()); + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsReader.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsReader.java new file mode 100644 index 000000000000..24cdcbf05c96 --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsReader.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.arrow.reader.ArrowBatchReader; +import org.apache.paimon.data.GenericArray; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.fs.Path; +import org.apache.paimon.mosaic.ColumnStatistics; +import org.apache.paimon.mosaic.MosaicReader; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.reader.FileRecordIterator; +import org.apache.paimon.reader.FileRecordReader; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.RowType; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.paimon.format.mosaic.MosaicObjects.convertStatsValue; + +/** File reader for Mosaic format. 
*/ +public class MosaicRecordsReader implements FileRecordReader { + + private final MosaicInputFileAdapter inputFileAdapter; + private final MosaicReader reader; + private final ArrowBatchReader arrowBatchReader; + private final Path filePath; + private final BufferAllocator allocator; + private final int numRowGroups; + private final RowType dataSchemaRowType; + @Nullable private final List predicates; + + private int currentRowGroup; + private long returnedPosition = -1; + private VectorSchemaRoot currentVsr; + + public MosaicRecordsReader( + MosaicInputFileAdapter inputFileAdapter, + long fileSize, + RowType dataSchemaRowType, + RowType projectedRowType, + @Nullable List predicates, + Path filePath) { + this.filePath = filePath; + this.inputFileAdapter = inputFileAdapter; + this.dataSchemaRowType = dataSchemaRowType; + this.predicates = predicates; + this.allocator = new RootAllocator(); + + try { + this.reader = MosaicReader.open(inputFileAdapter, fileSize, allocator); + } catch (Exception e) { + allocator.close(); + throw e; + } + + Schema fileSchema = reader.getSchema(); + Set fileColumnNames = new HashSet<>(); + for (Field field : fileSchema.getFields()) { + fileColumnNames.add(field.getName()); + } + List projectedNames = projectedRowType.getFieldNames(); + List existingColumns = new ArrayList<>(); + for (String name : projectedNames) { + if (fileColumnNames.contains(name)) { + existingColumns.add(name); + } + } + if (!existingColumns.isEmpty()) { + reader.project(existingColumns.toArray(new String[0])); + } + + this.numRowGroups = reader.numRowGroups(); + this.currentRowGroup = 0; + this.arrowBatchReader = new ArrowBatchReader(projectedRowType, true); + } + + @Nullable + @Override + public FileRecordIterator readBatch() throws IOException { + while (currentRowGroup < numRowGroups) { + int numRows = reader.rowGroupNumRows(currentRowGroup); + if (!matchesRowGroup(currentRowGroup, numRows)) { + returnedPosition += numRows; + currentRowGroup++; + continue; + 
} + + releaseCurrentVsr(); + + VectorSchemaRoot vsr = reader.readRowGroup(currentRowGroup, allocator); + currentRowGroup++; + this.currentVsr = vsr; + + Iterator rows = arrowBatchReader.readBatch(vsr).iterator(); + + return new FileRecordIterator() { + @Override + public long returnedPosition() { + return returnedPosition; + } + + @Override + public Path filePath() { + return filePath; + } + + @Nullable + @Override + public InternalRow next() { + if (rows.hasNext()) { + returnedPosition++; + return rows.next(); + } + return null; + } + + @Override + public void releaseBatch() { + releaseCurrentVsr(); + } + }; + } + return null; + } + + private boolean matchesRowGroup(int rowGroupIndex, long rowCount) { + if (predicates == null || predicates.isEmpty()) { + return true; + } + + Map statsMap = reader.getRowGroupStatistics(rowGroupIndex); + if (statsMap.isEmpty()) { + return true; + } + + int fieldCount = dataSchemaRowType.getFieldCount(); + GenericRow minValues = new GenericRow(fieldCount); + GenericRow maxValues = new GenericRow(fieldCount); + long[] nullCounts = new long[fieldCount]; + + List fields = dataSchemaRowType.getFields(); + for (int i = 0; i < fieldCount; i++) { + String colName = fields.get(i).name(); + ColumnStatistics stats = statsMap.get(colName); + if (stats == null) { + continue; + } + + nullCounts[i] = stats.getNullCount(); + if (stats.hasMinMax()) { + DataType dataType = fields.get(i).type(); + Object min = convertStatsValue(stats.getMin(), dataType); + Object max = convertStatsValue(stats.getMax(), dataType); + minValues.setField(i, min); + maxValues.setField(i, max); + } + } + + for (Predicate predicate : predicates) { + if (!predicate.test(rowCount, minValues, maxValues, new GenericArray(nullCounts))) { + return false; + } + } + return true; + } + + private void releaseCurrentVsr() { + if (currentVsr != null) { + currentVsr.close(); + currentVsr = null; + } + } + + @Override + public void close() throws IOException { + releaseCurrentVsr(); + 
reader.close(); + allocator.close(); + inputFileAdapter.close(); + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsWriter.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsWriter.java new file mode 100644 index 000000000000..fdef0eb3652b --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicRecordsWriter.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.arrow.ArrowBundleRecords; +import org.apache.paimon.arrow.vector.ArrowFormatWriter; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.format.BundleFormatWriter; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.io.BundleRecords; +import org.apache.paimon.mosaic.ColumnStatistics; +import org.apache.paimon.mosaic.MosaicWriter; +import org.apache.paimon.mosaic.WriterOptions; +import org.apache.paimon.options.MemorySize; +import org.apache.paimon.types.RowType; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** Mosaic records writer. */ +public class MosaicRecordsWriter implements BundleFormatWriter { + + private final ArrowFormatWriter arrowFormatWriter; + private final MosaicWriter nativeWriter; + private final BufferAllocator allocator; + private final List statsColumnNames; + @Nullable private MosaicWriterMetadata metadata; + + public MosaicRecordsWriter( + OutputStream outputStream, + RowType rowType, + FileFormatFactory.FormatContext formatContext, + List statsColumnNames, + @Nullable Integer numBuckets) { + this.statsColumnNames = statsColumnNames; + this.allocator = new RootAllocator(); + + int writeBatchSize = formatContext.writeBatchSize(); + long writeBatchMemory = formatContext.writeBatchMemory().getBytes(); + + this.arrowFormatWriter = + new ArrowFormatWriter(rowType, writeBatchSize, true, allocator, writeBatchMemory); + + WriterOptions options = new WriterOptions().zstdLevel(formatContext.zstdLevel()); + if (numBuckets != null) { + options = options.numBuckets(numBuckets); + } + MemorySize blockSize = 
formatContext.blockSize(); + if (blockSize != null) { + options = options.rowGroupMaxSize(blockSize.getBytes()); + } + if (!statsColumnNames.isEmpty()) { + options.statsColumns(statsColumnNames.toArray(new String[0])); + } + + Schema arrowSchema = arrowFormatWriter.getVectorSchemaRoot().getSchema(); + this.nativeWriter = new MosaicWriter(outputStream, arrowSchema, options, allocator); + } + + @Override + public void addElement(InternalRow internalRow) { + if (!arrowFormatWriter.write(internalRow)) { + flush(); + if (!arrowFormatWriter.write(internalRow)) { + throw new RuntimeException("Failed to write row to Mosaic file"); + } + } + } + + @Override + public void writeBundle(BundleRecords bundleRecords) { + if (bundleRecords instanceof ArrowBundleRecords) { + flush(); + nativeWriter.write(((ArrowBundleRecords) bundleRecords).getVectorSchemaRoot()); + } else { + for (InternalRow row : bundleRecords) { + addElement(row); + } + } + } + + @Override + public boolean reachTargetSize(boolean suggestedCheck, long targetSize) { + if (!suggestedCheck) { + return false; + } + return nativeWriter.estimatedFileSize() >= targetSize; + } + + @Override + public void close() throws IOException { + Throwable throwable = null; + + try { + flush(); + } catch (Throwable t) { + throwable = t; + } + + try { + nativeWriter.close(); + } catch (Throwable t) { + throwable = addSuppressed(throwable, t); + } + + try { + collectMetadata(); + } catch (Throwable t) { + throwable = addSuppressed(throwable, t); + } + + try { + arrowFormatWriter.close(); + } catch (Throwable t) { + throwable = addSuppressed(throwable, t); + } + + try { + allocator.close(); + } catch (Throwable t) { + throwable = addSuppressed(throwable, t); + } + + if (throwable != null) { + rethrow(throwable); + } + } + + @Nullable + @Override + public Object writerMetadata() { + return metadata; + } + + private void collectMetadata() { + int numRowGroups = nativeWriter.numRowGroups(); + List> allStats = new 
ArrayList<>(numRowGroups); + for (int i = 0; i < numRowGroups; i++) { + allStats.add(nativeWriter.getRowGroupStatistics(i)); + } + this.metadata = new MosaicWriterMetadata(numRowGroups, allStats, statsColumnNames); + } + + private void flush() { + arrowFormatWriter.flush(); + if (!arrowFormatWriter.empty()) { + VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot(); + nativeWriter.write(vsr); + } + arrowFormatWriter.reset(); + } + + private static Throwable addSuppressed(Throwable throwable, Throwable suppressed) { + if (throwable == null) { + return suppressed; + } + throwable.addSuppressed(suppressed); + return throwable; + } + + private static void rethrow(Throwable throwable) throws IOException { + if (throwable instanceof IOException) { + throw (IOException) throwable; + } + if (throwable instanceof RuntimeException) { + throw (RuntimeException) throwable; + } + if (throwable instanceof Error) { + throw (Error) throwable; + } + throw new IOException(throwable); + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractor.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractor.java new file mode 100644 index 000000000000..f426b27cfa81 --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractor.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.format.SimpleColStats; +import org.apache.paimon.format.SimpleStatsExtractor; +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.Path; +import org.apache.paimon.mosaic.ColumnStatistics; +import org.apache.paimon.mosaic.MosaicReader; +import org.apache.paimon.statistics.SimpleColStatsCollector; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.Pair; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.paimon.format.mosaic.MosaicObjects.convertStatsValue; + +/** Extracts statistics from Mosaic file metadata. 
*/ +public class MosaicSimpleStatsExtractor implements SimpleStatsExtractor { + + private final RowType rowType; + private final SimpleColStatsCollector.Factory[] statsCollectors; + + public MosaicSimpleStatsExtractor( + RowType rowType, SimpleColStatsCollector.Factory[] statsCollectors) { + this.rowType = rowType; + this.statsCollectors = statsCollectors; + } + + @Override + public SimpleColStats[] extract(FileIO fileIO, Path path, long length) { + try (MosaicInputFileAdapter inputFile = new MosaicInputFileAdapter(fileIO, path); + BufferAllocator allocator = new RootAllocator(); + MosaicReader reader = MosaicReader.open(inputFile, length, allocator)) { + return extractFromStats(reader.numRowGroups(), reader::getRowGroupStatistics, null); + } catch (IOException e) { + throw new RuntimeException("Failed to extract stats from " + path, e); + } + } + + @Override + public SimpleColStats[] extract( + FileIO fileIO, Path path, long length, @Nullable Object writerMetadata) { + if (writerMetadata instanceof MosaicWriterMetadata) { + MosaicWriterMetadata meta = (MosaicWriterMetadata) writerMetadata; + Set statsFieldIndices = resolveStatsFieldIndices(meta.statsColumnNames()); + return extractFromStats( + meta.numRowGroups(), meta::getRowGroupStatistics, statsFieldIndices); + } + return extract(fileIO, path, length); + } + + @Override + public Pair extractWithFileInfo( + FileIO fileIO, Path path, long length) { + try (MosaicInputFileAdapter inputFile = new MosaicInputFileAdapter(fileIO, path); + BufferAllocator allocator = new RootAllocator(); + MosaicReader reader = MosaicReader.open(inputFile, length, allocator)) { + int numRowGroups = reader.numRowGroups(); + SimpleColStats[] stats = + extractFromStats(numRowGroups, reader::getRowGroupStatistics, null); + long rowCount = 0; + for (int rg = 0; rg < numRowGroups; rg++) { + rowCount += reader.rowGroupNumRows(rg); + } + return Pair.of(stats, new FileInfo(rowCount)); + } catch (IOException e) { + throw new 
RuntimeException("Failed to extract stats from " + path, e); + } + } + + @SuppressWarnings("unchecked") + private SimpleColStats[] extractFromStats( + int numRowGroups, + RowGroupStatsProvider statsProvider, + @Nullable Set statsFieldIndices) { + int fieldCount = rowType.getFieldCount(); + List fieldNames = rowType.getFieldNames(); + Object[] minValues = new Object[fieldCount]; + Object[] maxValues = new Object[fieldCount]; + long[] nullCounts = new long[fieldCount]; + Set seenColumns = new HashSet<>(); + + for (int rg = 0; rg < numRowGroups; rg++) { + Map statsMap = statsProvider.getRowGroupStatistics(rg); + for (Map.Entry entry : statsMap.entrySet()) { + int colIdx = fieldNames.indexOf(entry.getKey()); + if (colIdx < 0) { + continue; + } + + ColumnStatistics stat = entry.getValue(); + seenColumns.add(colIdx); + nullCounts[colIdx] += stat.getNullCount(); + + if (stat.hasMinMax()) { + DataType dataType = rowType.getFields().get(colIdx).type(); + Object min = convertStatsValue(stat.getMin(), dataType); + Object max = convertStatsValue(stat.getMax(), dataType); + if (min instanceof Comparable) { + if (minValues[colIdx] == null) { + minValues[colIdx] = min; + } else { + if (((Comparable) min).compareTo(minValues[colIdx]) < 0) { + minValues[colIdx] = min; + } + } + } + if (max instanceof Comparable) { + if (maxValues[colIdx] == null) { + maxValues[colIdx] = max; + } else { + if (((Comparable) max).compareTo(maxValues[colIdx]) > 0) { + maxValues[colIdx] = max; + } + } + } + } + } + } + + Set trackedColumns = statsFieldIndices != null ? 
statsFieldIndices : seenColumns; + SimpleColStatsCollector[] collectors = SimpleColStatsCollector.create(statsCollectors); + SimpleColStats[] result = new SimpleColStats[fieldCount]; + for (int i = 0; i < fieldCount; i++) { + if (!trackedColumns.contains(i) || !seenColumns.contains(i)) { + result[i] = collectors[i].convert(new SimpleColStats(null, null, null)); + } else { + SimpleColStats fieldStats = + new SimpleColStats(minValues[i], maxValues[i], nullCounts[i]); + result[i] = collectors[i].convert(fieldStats); + } + } + return result; + } + + private Set resolveStatsFieldIndices(List statsColumnNames) { + Set indices = new HashSet<>(); + List fieldNames = rowType.getFieldNames(); + for (String name : statsColumnNames) { + int idx = fieldNames.indexOf(name); + if (idx >= 0) { + indices.add(idx); + } + } + return indices; + } + + @FunctionalInterface + private interface RowGroupStatsProvider { + Map getRowGroupStatistics(int rowGroupIndex); + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterFactory.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterFactory.java new file mode 100644 index 000000000000..ca67647f8cab --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterFactory.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatWriter; +import org.apache.paimon.format.FormatWriterFactory; +import org.apache.paimon.fs.PositionOutputStream; +import org.apache.paimon.types.RowType; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +/** A factory to create Mosaic {@link FormatWriter}. */ +public class MosaicWriterFactory implements FormatWriterFactory { + + private final RowType rowType; + private final FileFormatFactory.FormatContext formatContext; + private final List statsColumnNames; + private final @Nullable Integer numBuckets; + + public MosaicWriterFactory(RowType rowType, FileFormatFactory.FormatContext formatContext) { + this.rowType = rowType; + this.formatContext = formatContext; + String statsColumnsValue = formatContext.options().get(MosaicFileFormat.STATS_COLUMNS); + if (statsColumnsValue == null || statsColumnsValue.trim().isEmpty()) { + this.statsColumnNames = new ArrayList<>(); + } else { + this.statsColumnNames = + Arrays.stream(statsColumnsValue.split(",")) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .collect(Collectors.toList()); + } + this.numBuckets = formatContext.options().get(MosaicFileFormat.NUM_BUCKETS); + } + + @Override + public FormatWriter create(PositionOutputStream out, String compression) { + validateCompression(compression); + return new MosaicRecordsWriter(out, 
rowType, formatContext, statsColumnNames, numBuckets); + } + + private static void validateCompression(String compression) { + if (compression == null) { + return; + } + String normalized = compression.toLowerCase(Locale.ROOT); + if (!normalized.equals("zstd")) { + throw new UnsupportedOperationException( + "Mosaic format only supports zstd compression, but got: " + compression); + } + } +} diff --git a/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterMetadata.java b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterMetadata.java new file mode 100644 index 000000000000..cd3149fd4470 --- /dev/null +++ b/paimon-mosaic/src/main/java/org/apache/paimon/format/mosaic/MosaicWriterMetadata.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.mosaic.ColumnStatistics; + +import java.util.List; +import java.util.Map; + +/** In-memory metadata captured from MosaicWriter after close. 
*/ +public class MosaicWriterMetadata { + + private final int numRowGroups; + private final List> rowGroupStats; + private final List statsColumnNames; + + public MosaicWriterMetadata( + int numRowGroups, + List> rowGroupStats, + List statsColumnNames) { + this.numRowGroups = numRowGroups; + this.rowGroupStats = rowGroupStats; + this.statsColumnNames = statsColumnNames; + } + + public int numRowGroups() { + return numRowGroups; + } + + public Map getRowGroupStatistics(int rowGroupIndex) { + return rowGroupStats.get(rowGroupIndex); + } + + public List statsColumnNames() { + return statsColumnNames; + } +} diff --git a/paimon-mosaic/src/main/resources/META-INF/services/org.apache.paimon.format.FileFormatFactory b/paimon-mosaic/src/main/resources/META-INF/services/org.apache.paimon.format.FileFormatFactory new file mode 100644 index 000000000000..bc955c493506 --- /dev/null +++ b/paimon-mosaic/src/main/resources/META-INF/services/org.apache.paimon.format.FileFormatFactory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +org.apache.paimon.format.mosaic.MosaicFileFormatFactory diff --git a/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFileFormatTest.java b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFileFormatTest.java new file mode 100644 index 000000000000..8e53164e8627 --- /dev/null +++ b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFileFormatTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatReaderFactory; +import org.apache.paimon.format.FormatWriterFactory; +import org.apache.paimon.options.Options; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Unit tests for {@link MosaicFileFormat} and {@link MosaicFileFormatFactory}. 
*/ +class MosaicFileFormatTest { + + @Test + void testFactoryIdentifier() { + MosaicFileFormatFactory factory = new MosaicFileFormatFactory(); + assertThat(factory.identifier()).isEqualTo("mosaic"); + } + + @Test + void testFactoryCreate() { + MosaicFileFormatFactory factory = new MosaicFileFormatFactory(); + FileFormatFactory.FormatContext context = + new FileFormatFactory.FormatContext(new Options(), 1024, 1024); + assertThat(factory.create(context)).isInstanceOf(MosaicFileFormat.class); + } + + @Test + void testCreateReaderFactory() { + MosaicFileFormat format = createFormat(); + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + FormatReaderFactory readerFactory = + format.createReaderFactory(rowType, rowType, new ArrayList<>()); + assertThat(readerFactory).isInstanceOf(MosaicReaderFactory.class); + } + + @Test + void testCreateWriterFactory() { + MosaicFileFormat format = createFormat(); + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + assertThat(writerFactory).isInstanceOf(MosaicWriterFactory.class); + } + + @Test + void testValidateDataFieldsSupported() { + MosaicFileFormat format = createFormat(); + RowType rowType = + DataTypes.ROW( + DataTypes.INT(), + DataTypes.BIGINT(), + DataTypes.STRING(), + DataTypes.DOUBLE(), + DataTypes.FLOAT(), + DataTypes.BOOLEAN(), + DataTypes.DATE(), + DataTypes.TIMESTAMP(3), + DataTypes.DECIMAL(10, 2), + DataTypes.BYTES()); + format.validateDataFields(rowType); + } + + @Test + void testValidateDataFieldsMapUnsupported() { + MosaicFileFormat format = createFormat(); + RowType rowType = DataTypes.ROW(DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())); + assertThatThrownBy(() -> format.validateDataFields(rowType)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("MAP"); + } + + @Test + void testValidateDataFieldsMultisetUnsupported() { + MosaicFileFormat format = createFormat(); + 
RowType rowType = DataTypes.ROW(DataTypes.MULTISET(DataTypes.STRING())); + assertThatThrownBy(() -> format.validateDataFields(rowType)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("MULTISET"); + } + + @Test + void testCreateStatsExtractor() { + MosaicFileFormat format = createFormat(); + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + assertThat( + format.createStatsExtractor( + rowType, + new org.apache.paimon.statistics.SimpleColStatsCollector.Factory[] { + org.apache.paimon.statistics.SimpleColStatsCollector.from( + "full"), + org.apache.paimon.statistics.SimpleColStatsCollector.from( + "full") + })) + .isPresent(); + } + + private static MosaicFileFormat createFormat() { + return new MosaicFileFormat(new FileFormatFactory.FormatContext(new Options(), 1024, 1024)); + } +} diff --git a/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFormatReadWriteTest.java b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFormatReadWriteTest.java new file mode 100644 index 000000000000..41f632b3ee3f --- /dev/null +++ b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicFormatReadWriteTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.format.FileFormat; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatReadWriteTest; +import org.apache.paimon.options.Options; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.api.BeforeAll; + +import java.math.BigDecimal; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** Round-trip read/write tests for Mosaic format. */ +class MosaicFormatReadWriteTest extends FormatReadWriteTest { + + MosaicFormatReadWriteTest() { + super("mosaic"); + } + + @BeforeAll + static void checkNativeLibrary() { + assumeTrue(isNativeAvailable(), "Mosaic native library not available"); + } + + @Override + protected FileFormat fileFormat() { + return new MosaicFileFormat(new FileFormatFactory.FormatContext(new Options(), 1024, 1024)); + } + + @Override + public String compression() { + return "zstd"; + } + + @Override + public boolean supportNestedReadPruning() { + return false; + } + + @Override + protected RowType rowTypeForFullTypesTest() { + return RowType.builder() + .field("f_int", DataTypes.INT().notNull()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE().notNull()) + .field("f_boolean", DataTypes.BOOLEAN()) + .field("f_tinyint", DataTypes.TINYINT()) + .field("f_smallint", DataTypes.SMALLINT()) + .field("f_bigint", DataTypes.BIGINT()) + .field("f_float", DataTypes.FLOAT()) + .field("f_binary", DataTypes.BYTES()) + .field("f_date", DataTypes.DATE()) + .field("f_timestamp3", 
DataTypes.TIMESTAMP(3)) + .field("f_timestamp6", DataTypes.TIMESTAMP(6)) + .field("f_decimal_5_2", DataTypes.DECIMAL(5, 2)) + .field("f_decimal_20_0", DataTypes.DECIMAL(20, 0)) + .build(); + } + + @Override + protected GenericRow expectedRowForFullTypesTest() { + return GenericRow.of( + 42, + BinaryString.fromString("hello mosaic"), + 3.14d, + true, + (byte) 7, + (short) 256, + 9876543210L, + 1.5f, + new byte[] {1, 2, 3}, + 18000, + Timestamp.fromEpochMillis(1700000000000L), + Timestamp.fromMicros(1700000000000000L), + Decimal.fromBigDecimal(new BigDecimal("123.45"), 5, 2), + Decimal.fromBigDecimal(new BigDecimal("12345678901234567890"), 20, 0)); + } + + @Override + protected void validateFullTypesResult(InternalRow actual, InternalRow expected) { + for (int i = 0; i < 14; i++) { + if (expected.isNullAt(i)) { + assertThat(actual.isNullAt(i)).isTrue(); + } + } + assertThat(actual.getInt(0)).isEqualTo(expected.getInt(0)); + assertThat(actual.getString(1)).isEqualTo(expected.getString(1)); + assertThat(actual.getDouble(2)).isEqualTo(expected.getDouble(2)); + assertThat(actual.getBoolean(3)).isEqualTo(expected.getBoolean(3)); + assertThat(actual.getByte(4)).isEqualTo(expected.getByte(4)); + assertThat(actual.getShort(5)).isEqualTo(expected.getShort(5)); + assertThat(actual.getLong(6)).isEqualTo(expected.getLong(6)); + assertThat(actual.getFloat(7)).isEqualTo(expected.getFloat(7)); + assertThat(actual.getBinary(8)).isEqualTo(expected.getBinary(8)); + assertThat(actual.getInt(9)).isEqualTo(expected.getInt(9)); + assertThat(actual.getTimestamp(10, 3)).isEqualTo(expected.getTimestamp(10, 3)); + assertThat(actual.getTimestamp(11, 6)).isEqualTo(expected.getTimestamp(11, 6)); + assertThat(actual.getDecimal(12, 5, 2)).isEqualTo(expected.getDecimal(12, 5, 2)); + assertThat(actual.getDecimal(13, 20, 0)).isEqualTo(expected.getDecimal(13, 20, 0)); + } + + private static boolean isNativeAvailable() { + try { + Class.forName("org.apache.paimon.mosaic.NativeLib"); + return true; + } 
catch (Throwable t) { + return false; + } + } +} diff --git a/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicObjectsTest.java b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicObjectsTest.java new file mode 100644 index 000000000000..e05ed1709c4a --- /dev/null +++ b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicObjectsTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.types.DataTypes; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Unit tests for {@link MosaicObjects}. 
*/ +class MosaicObjectsTest { + + @Test + void testNullBytes() { + assertThat(MosaicObjects.convertStatsValue(null, DataTypes.INT())).isNull(); + } + + @Test + void testEmptyBytes() { + assertThat(MosaicObjects.convertStatsValue(new byte[0], DataTypes.INT())).isNull(); + } + + @Test + void testBoolean() { + assertThat(MosaicObjects.convertStatsValue(new byte[] {1}, DataTypes.BOOLEAN())) + .isEqualTo(true); + assertThat(MosaicObjects.convertStatsValue(new byte[] {0}, DataTypes.BOOLEAN())) + .isEqualTo(false); + } + + @Test + void testTinyInt() { + assertThat(MosaicObjects.convertStatsValue(new byte[] {42}, DataTypes.TINYINT())) + .isEqualTo((byte) 42); + assertThat(MosaicObjects.convertStatsValue(new byte[] {(byte) -1}, DataTypes.TINYINT())) + .isEqualTo((byte) -1); + } + + @Test + void testSmallInt() { + byte[] bytes = ByteBuffer.allocate(2).putShort((short) 1234).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.SMALLINT())) + .isEqualTo((short) 1234); + } + + @Test + void testInt() { + byte[] bytes = ByteBuffer.allocate(4).putInt(123456).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.INT())).isEqualTo(123456); + } + + @Test + void testIntNegative() { + byte[] bytes = ByteBuffer.allocate(4).putInt(-999).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.INT())).isEqualTo(-999); + } + + @Test + void testBigInt() { + byte[] bytes = ByteBuffer.allocate(8).putLong(9876543210L).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.BIGINT())) + .isEqualTo(9876543210L); + } + + @Test + void testFloat() { + byte[] bytes = ByteBuffer.allocate(4).putFloat(3.14f).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.FLOAT())).isEqualTo(3.14f); + } + + @Test + void testDouble() { + byte[] bytes = ByteBuffer.allocate(8).putDouble(2.718281828).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.DOUBLE())) + .isEqualTo(2.718281828); + } + + @Test + void 
testVarChar() { + byte[] bytes = "hello".getBytes(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.STRING())) + .isEqualTo(BinaryString.fromString("hello")); + } + + @Test + void testBinary() { + byte[] bytes = new byte[] {1, 2, 3, 4, 5}; + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.BYTES())).isEqualTo(bytes); + } + + @Test + void testDate() { + byte[] bytes = ByteBuffer.allocate(4).putInt(18000).array(); + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.DATE())).isEqualTo(18000); + } + + @Test + void testTimestampMillis() { + long millis = 1700000000000L; + byte[] bytes = ByteBuffer.allocate(8).putLong(millis).array(); + Object result = MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP(3)); + assertThat(result).isEqualTo(Timestamp.fromEpochMillis(millis)); + } + + @Test + void testTimestampMicros() { + long micros = 1700000000000000L; + byte[] bytes = ByteBuffer.allocate(8).putLong(micros).array(); + Object result = MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP(6)); + assertThat(result).isEqualTo(Timestamp.fromMicros(micros)); + } + + @Test + void testDecimal() { + // 1000 in big-endian two's complement = 0x03E8 + byte[] beBytes = new byte[] {0x03, (byte) 0xE8}; + Object result = MosaicObjects.convertStatsValue(beBytes, DataTypes.DECIMAL(10, 2)); + assertThat(result).isInstanceOf(Decimal.class); + Decimal decimal = (Decimal) result; + assertThat(decimal.toBigDecimal().intValue()).isEqualTo(10); + } + + @Test + void testTimestampNanos() { + long millis = 1700000000123L; + int nanosOfMilli = 456789; + byte[] bytes = ByteBuffer.allocate(12).putLong(millis).putInt(nanosOfMilli).array(); + Object result = MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP(9)); + assertThat(result).isEqualTo(Timestamp.fromEpochMillis(millis, nanosOfMilli)); + } + + @Test + void testTimestampNanosPrecision7() { + long millis = 1700000000000L; + int nanosOfMilli = 100000; + byte[] bytes = 
ByteBuffer.allocate(12).putLong(millis).putInt(nanosOfMilli).array(); + Object result = MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP(7)); + assertThat(result).isEqualTo(Timestamp.fromEpochMillis(millis, nanosOfMilli)); + } + + @Test + void testTimestampWithLocalTimeZoneMillis() { + long millis = 1700000000000L; + byte[] bytes = ByteBuffer.allocate(8).putLong(millis).array(); + Object result = + MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)); + assertThat(result).isEqualTo(Timestamp.fromEpochMillis(millis)); + } + + @Test + void testTimestampWithLocalTimeZoneMicros() { + long micros = 1700000000000000L; + byte[] bytes = ByteBuffer.allocate(8).putLong(micros).array(); + Object result = + MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(6)); + assertThat(result).isEqualTo(Timestamp.fromMicros(micros)); + } + + @Test + void testTimestampWithLocalTimeZoneNanos() { + long millis = 1700000000123L; + int nanosOfMilli = 456789; + byte[] bytes = ByteBuffer.allocate(12).putLong(millis).putInt(nanosOfMilli).array(); + Object result = + MosaicObjects.convertStatsValue(bytes, DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(9)); + assertThat(result).isEqualTo(Timestamp.fromEpochMillis(millis, nanosOfMilli)); + } + + @Test + void testEmptyStringVarChar() { + Object result = MosaicObjects.convertStatsValue(new byte[0], DataTypes.STRING()); + assertThat(result).isEqualTo(BinaryString.fromString("")); + } + + @Test + void testEmptyBinary() { + Object result = MosaicObjects.convertStatsValue(new byte[0], DataTypes.BYTES()); + assertThat(result).isEqualTo(new byte[0]); + } + + @Test + void testUnsupportedTypeReturnsNull() { + byte[] bytes = new byte[] {1, 2, 3}; + assertThat(MosaicObjects.convertStatsValue(bytes, DataTypes.ARRAY(DataTypes.INT()))) + .isNull(); + } +} diff --git a/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicReaderWriterTest.java 
b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicReaderWriterTest.java new file mode 100644 index 000000000000..60efceed08e3 --- /dev/null +++ b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicReaderWriterTest.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.serializer.InternalRowSerializer; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatReaderContext; +import org.apache.paimon.format.FormatReaderFactory; +import org.apache.paimon.format.FormatWriter; +import org.apache.paimon.format.FormatWriterFactory; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.reader.FileRecordIterator; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** Integration tests for Mosaic reader and writer. 
*/ +class MosaicReaderWriterTest { + + @TempDir java.nio.file.Path tempDir; + + @BeforeAll + static void checkNativeLibrary() { + assumeTrue(isNativeAvailable(), "Mosaic native library not available"); + } + + @Test + void testWriteAndRead() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + + writeRows( + rowType, + path, + GenericRow.of(1, BinaryString.fromString("hello")), + GenericRow.of(2, BinaryString.fromString("world"))); + + List result = readAll(rowType, rowType, path, null); + assertThat(result).hasSize(2); + assertThat(result.get(0).getInt(0)).isEqualTo(1); + assertThat(result.get(0).getString(1).toString()).isEqualTo("hello"); + assertThat(result.get(1).getInt(0)).isEqualTo(2); + assertThat(result.get(1).getString(1).toString()).isEqualTo("world"); + } + + @Test + void testNullValues() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + + writeRows( + rowType, + path, + GenericRow.of(1, null), + GenericRow.of(null, BinaryString.fromString("test")), + GenericRow.of(null, null)); + + List result = readAll(rowType, rowType, path, null); + assertThat(result).hasSize(3); + assertThat(result.get(0).isNullAt(1)).isTrue(); + assertThat(result.get(1).isNullAt(0)).isTrue(); + assertThat(result.get(2).isNullAt(0)).isTrue(); + assertThat(result.get(2).isNullAt(1)).isTrue(); + } + + @Test + void testColumnProjection() throws IOException { + RowType writeType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE()) + .build(); + RowType readType = RowType.builder().field("f_string", DataTypes.STRING()).build(); + Path path = newPath(); + + writeRows( + writeType, + path, + GenericRow.of(1, BinaryString.fromString("aaa"), 1.1), + GenericRow.of(2, BinaryString.fromString("bbb"), 2.2)); + + List result = readAll(writeType, readType, path, null); + 
assertThat(result).hasSize(2); + assertThat(result.get(0).getString(0).toString()).isEqualTo("aaa"); + assertThat(result.get(1).getString(0).toString()).isEqualTo("bbb"); + } + + @Test + void testLargeDataset() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + + int numRows = 10000; + GenericRow[] rows = new GenericRow[numRows]; + for (int i = 0; i < numRows; i++) { + rows[i] = GenericRow.of(i, BinaryString.fromString("row" + i)); + } + writeRows(rowType, path, rows); + + List result = readAll(rowType, rowType, path, null); + assertThat(result).hasSize(numRows); + assertThat(result.get(0).getInt(0)).isEqualTo(0); + assertThat(result.get(numRows - 1).getInt(0)).isEqualTo(numRows - 1); + } + + @Test + void testRowGroupPredicateFiltering() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .build(); + Path path = newPath(); + + int numRows = 10000; + GenericRow[] rows = new GenericRow[numRows]; + for (int i = 0; i < numRows; i++) { + rows[i] = GenericRow.of(i, BinaryString.fromString("v" + i)); + } + writeRows(rowType, path, "f_int", rows); + + // Predicate that cannot match any row group (all values are 0..9999) + PredicateBuilder builder = new PredicateBuilder(rowType); + Predicate predicate = builder.greaterThan(0, 99999); + List result = + readAll(rowType, rowType, path, Collections.singletonList(predicate)); + assertThat(result).isEmpty(); + + // Predicate that matches the row group (values include range 0..9999) + Predicate matchPredicate = builder.greaterThan(0, 5000); + List matchResult = + readAll(rowType, rowType, path, Collections.singletonList(matchPredicate)); + assertThat(matchResult).hasSize(numRows); + } + + @Test + void testReturnedPosition() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + + writeRows( + rowType, + path, + 
GenericRow.of(1, BinaryString.fromString("a")), + GenericRow.of(2, BinaryString.fromString("b")), + GenericRow.of(3, BinaryString.fromString("c"))); + + MosaicFileFormat format = createFormat(); + FormatReaderFactory readerFactory = format.createReaderFactory(rowType, rowType, null); + LocalFileIO fileIO = new LocalFileIO(); + RecordReader reader = + readerFactory.createReader( + new FormatReaderContext(fileIO, path, fileIO.getFileSize(path))); + + RecordReader.RecordIterator batch = reader.readBatch(); + assertThat(batch).isNotNull(); + FileRecordIterator fileIter = (FileRecordIterator) batch; + + fileIter.next(); + assertThat(fileIter.returnedPosition()).isEqualTo(0); + fileIter.next(); + assertThat(fileIter.returnedPosition()).isEqualTo(1); + fileIter.next(); + assertThat(fileIter.returnedPosition()).isEqualTo(2); + + reader.close(); + } + + @Test + void testProjectionWithMissingColumns() throws IOException { + RowType writeType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .build(); + // Read type has a column that doesn't exist in the file (schema evolution) + RowType readType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_new_col", DataTypes.BIGINT()) + .field("f_string", DataTypes.STRING()) + .build(); + Path path = newPath(); + + writeRows( + writeType, + path, + GenericRow.of(1, BinaryString.fromString("aaa")), + GenericRow.of(2, BinaryString.fromString("bbb"))); + + List result = readAll(writeType, readType, path, null); + assertThat(result).hasSize(2); + assertThat(result.get(0).getInt(0)).isEqualTo(1); + assertThat(result.get(0).isNullAt(1)).isTrue(); + assertThat(result.get(0).getString(2).toString()).isEqualTo("aaa"); + assertThat(result.get(1).getInt(0)).isEqualTo(2); + assertThat(result.get(1).isNullAt(1)).isTrue(); + assertThat(result.get(1).getString(2).toString()).isEqualTo("bbb"); + } + + @Test + void testProjectionAllColumnsMissing() throws IOException { + RowType 
writeType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .build(); + // Read type has only columns that don't exist in the file + RowType readType = + RowType.builder() + .field("f_new_a", DataTypes.INT()) + .field("f_new_b", DataTypes.STRING()) + .build(); + Path path = newPath(); + + writeRows( + writeType, + path, + GenericRow.of(1, BinaryString.fromString("x")), + GenericRow.of(2, BinaryString.fromString("y"))); + + List result = readAll(writeType, readType, path, null); + assertThat(result).hasSize(2); + assertThat(result.get(0).isNullAt(0)).isTrue(); + assertThat(result.get(0).isNullAt(1)).isTrue(); + assertThat(result.get(1).isNullAt(0)).isTrue(); + assertThat(result.get(1).isNullAt(1)).isTrue(); + } + + @Test + void testUnsupportedCompressionThrows() { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + MosaicFileFormat format = createFormat(); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + LocalFileIO fileIO = new LocalFileIO(); + + assertThatThrownBy(() -> writerFactory.create(fileIO.newOutputStream(path, false), "lz4")) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("lz4"); + } + + @Test + void testReachTargetSize() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + MosaicFileFormat format = createFormat(); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + + LocalFileIO fileIO = new LocalFileIO(); + FormatWriter writer = writerFactory.create(fileIO.newOutputStream(path, false), "zstd"); + + boolean reached = false; + for (int i = 0; i < 100000; i++) { + writer.addElement(GenericRow.of(i, BinaryString.fromString("value_" + i + "_padding"))); + if (writer.reachTargetSize(true, 1024)) { + reached = true; + break; + } + } + writer.close(); + assertThat(reached).isTrue(); + } + + private Path newPath() { + 
return new Path(tempDir.toUri().toString(), UUID.randomUUID() + ".mosaic"); + } + + private void writeRows(RowType rowType, Path path, GenericRow... rows) throws IOException { + writeRows(rowType, path, "", rows); + } + + private void writeRows(RowType rowType, Path path, String statsColumns, GenericRow... rows) + throws IOException { + MosaicFileFormat format = createFormat(statsColumns); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + LocalFileIO fileIO = new LocalFileIO(); + FormatWriter writer = writerFactory.create(fileIO.newOutputStream(path, false), "zstd"); + for (GenericRow row : rows) { + writer.addElement(row); + } + writer.close(); + } + + private List readAll( + RowType dataType, RowType readType, Path path, List predicates) + throws IOException { + MosaicFileFormat format = createFormat(); + FormatReaderFactory readerFactory = + format.createReaderFactory(dataType, readType, predicates); + LocalFileIO fileIO = new LocalFileIO(); + RecordReader reader = + readerFactory.createReader( + new FormatReaderContext(fileIO, path, fileIO.getFileSize(path))); + + InternalRowSerializer serializer = new InternalRowSerializer(readType); + List result = new ArrayList<>(); + reader.forEachRemaining(row -> result.add(serializer.copy(row))); + reader.close(); + return result; + } + + private static MosaicFileFormat createFormat() { + return createFormat(""); + } + + private static MosaicFileFormat createFormat(String statsColumns) { + Options options = new Options(); + if (!statsColumns.isEmpty()) { + options.set(MosaicFileFormat.STATS_COLUMNS, statsColumns); + } + return new MosaicFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024)); + } + + private static boolean isNativeAvailable() { + try { + Class.forName("org.apache.paimon.mosaic.NativeLib"); + return true; + } catch (Throwable t) { + return false; + } + } +} diff --git a/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractorTest.java 
b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractorTest.java new file mode 100644 index 000000000000..8477c5b06540 --- /dev/null +++ b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicSimpleStatsExtractorTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.format.FileFormat; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatWriter; +import org.apache.paimon.format.FormatWriterFactory; +import org.apache.paimon.format.SimpleColStats; +import org.apache.paimon.format.SimpleColStatsExtractorTest; +import org.apache.paimon.format.SimpleStatsExtractor; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.options.Options; +import org.apache.paimon.statistics.SimpleColStatsCollector; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.UUID; +import java.util.stream.IntStream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** Tests for {@link MosaicSimpleStatsExtractor}. 
*/ +class MosaicSimpleStatsExtractorTest extends SimpleColStatsExtractorTest { + + @TempDir java.nio.file.Path statsTestTempDir; + + @BeforeAll + static void checkNativeLibrary() { + assumeTrue(isNativeAvailable(), "Mosaic native library not available"); + } + + @Override + protected FileFormat createFormat() { + Options options = new Options(); + options.set( + MosaicFileFormat.STATS_COLUMNS, + "f_boolean,f_tinyint,f_smallint,f_int,f_bigint,f_float," + + "f_double,f_string,f_decimal_5_2,f_date,f_timestamp3,f_timestamp6"); + return new MosaicFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024)); + } + + @Override + protected RowType rowType() { + return RowType.builder() + .field("f_boolean", DataTypes.BOOLEAN()) + .field("f_tinyint", DataTypes.TINYINT()) + .field("f_smallint", DataTypes.SMALLINT()) + .field("f_int", DataTypes.INT()) + .field("f_bigint", DataTypes.BIGINT()) + .field("f_float", DataTypes.FLOAT()) + .field("f_double", DataTypes.DOUBLE()) + .field("f_string", DataTypes.VARCHAR(100)) + .field("f_decimal_5_2", DataTypes.DECIMAL(5, 2)) + .field("f_date", DataTypes.DATE()) + .field("f_timestamp3", DataTypes.TIMESTAMP(3)) + .field("f_timestamp6", DataTypes.TIMESTAMP(6)) + .build(); + } + + @Override + protected String fileCompression() { + return "zstd"; + } + + @Test + void testUntrackedColumnsReturnNone() throws IOException { + // stats_columns only tracks f_int, but the table has f_int + f_string + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .build(); + Options options = new Options(); + options.set(MosaicFileFormat.STATS_COLUMNS, "f_int"); + MosaicFileFormat format = + new MosaicFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024)); + + Path path = new Path(statsTestTempDir.toUri().toString(), UUID.randomUUID() + ".mosaic"); + LocalFileIO fileIO = new LocalFileIO(); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + 
FormatWriter writer = writerFactory.create(fileIO.newOutputStream(path, false), "zstd"); + writer.addElement(GenericRow.of(1, BinaryString.fromString("a"))); + writer.addElement(GenericRow.of(2, BinaryString.fromString("b"))); + writer.close(); + + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, rowType.getFieldCount()) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + SimpleColStats[] stats = extractor.extract(fileIO, path, fileIO.getFileSize(path)); + + // f_int is tracked, should have real stats + assertThat(stats[0].min()).isEqualTo(1); + assertThat(stats[0].max()).isEqualTo(2); + assertThat(stats[0].nullCount()).isEqualTo(0L); + // f_string is NOT tracked, should be NONE (null nullCount) + assertThat(stats[1].min()).isNull(); + assertThat(stats[1].max()).isNull(); + assertThat(stats[1].nullCount()).isNull(); + } + + @Test + void testBinaryColumnStatsNoException() throws Exception { + // Binary columns produce byte[] from convertStatsValue, which is not Comparable. + // Verify multi-row-group aggregation doesn't throw ClassCastException. 
+ RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_binary", DataTypes.VARBINARY(100)) + .build(); + // Build a fake MosaicWriterMetadata with binary stats across 2 row groups + java.lang.reflect.Constructor ctor = + org.apache.paimon.mosaic.ColumnStatistics.class.getDeclaredConstructor( + long.class, byte[].class, byte[].class); + ctor.setAccessible(true); + + java.util.Map rg0 = + new java.util.HashMap<>(); + rg0.put( + "f_int", + (org.apache.paimon.mosaic.ColumnStatistics) + ctor.newInstance(0L, intBytes(0), intBytes(100))); + rg0.put( + "f_binary", + (org.apache.paimon.mosaic.ColumnStatistics) + ctor.newInstance(0L, new byte[] {1, 2}, new byte[] {3, 4})); + + java.util.Map rg1 = + new java.util.HashMap<>(); + rg1.put( + "f_int", + (org.apache.paimon.mosaic.ColumnStatistics) + ctor.newInstance(0L, intBytes(50), intBytes(200))); + rg1.put( + "f_binary", + (org.apache.paimon.mosaic.ColumnStatistics) + ctor.newInstance(0L, new byte[] {5, 6}, new byte[] {7, 8})); + + java.util.List> allStats = + java.util.Arrays.asList(rg0, rg1); + MosaicWriterMetadata metadata = + new MosaicWriterMetadata(2, allStats, java.util.Arrays.asList("f_int", "f_binary")); + + // Write a minimal file (only f_int in stats_columns since native rejects binary) + Options options = new Options(); + options.set(MosaicFileFormat.STATS_COLUMNS, "f_int"); + MosaicFileFormat format = + new MosaicFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024)); + Path path = new Path(statsTestTempDir.toUri().toString(), UUID.randomUUID() + ".mosaic"); + LocalFileIO fileIO = new LocalFileIO(); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + FormatWriter writer = writerFactory.create(fileIO.newOutputStream(path, false), "zstd"); + writer.addElement(GenericRow.of(1, new byte[] {1})); + writer.close(); + + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, rowType.getFieldCount()) + .mapToObj(i -> 
SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + // Should not throw ClassCastException + SimpleColStats[] stats = + extractor.extract(fileIO, path, fileIO.getFileSize(path), metadata); + + // f_int aggregated across row groups: min=0, max=200 + assertThat(stats[0].min()).isEqualTo(0); + assertThat(stats[0].max()).isEqualTo(200); + // f_binary min/max should be null (byte[] not Comparable, skipped) + assertThat(stats[1].min()).isNull(); + assertThat(stats[1].max()).isNull(); + } + + private static byte[] intBytes(int value) { + return java.nio.ByteBuffer.allocate(4).putInt(value).array(); + } + + private static boolean isNativeAvailable() { + try { + Class.forName("org.apache.paimon.mosaic.NativeLib"); + return true; + } catch (Throwable t) { + return false; + } + } +} diff --git a/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicWriterMetadataTest.java b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicWriterMetadataTest.java new file mode 100644 index 000000000000..4caf65d66176 --- /dev/null +++ b/paimon-mosaic/src/test/java/org/apache/paimon/format/mosaic/MosaicWriterMetadataTest.java @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.format.mosaic; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatWriter; +import org.apache.paimon.format.FormatWriterFactory; +import org.apache.paimon.format.SimpleColStats; +import org.apache.paimon.format.SimpleStatsExtractor; +import org.apache.paimon.format.SimpleStatsExtractor.FileInfo; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.options.Options; +import org.apache.paimon.statistics.SimpleColStatsCollector; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.Pair; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.UUID; +import java.util.stream.IntStream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** Tests for writer metadata based stats extraction in Mosaic format. 
*/ +class MosaicWriterMetadataTest { + + @TempDir java.nio.file.Path tempDir; + + @BeforeAll + static void checkNativeLibrary() { + assumeTrue(isNativeAvailable(), "Mosaic native library not available"); + } + + @Test + void testWriterMetadataNotNull() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + + FormatWriter writer = createWriter(rowType, path, "f0,f1"); + writer.addElement(GenericRow.of(1, BinaryString.fromString("hello"))); + writer.addElement(GenericRow.of(2, BinaryString.fromString("world"))); + writer.close(); + + Object metadata = writer.writerMetadata(); + assertThat(metadata).isNotNull(); + assertThat(metadata).isInstanceOf(MosaicWriterMetadata.class); + + MosaicWriterMetadata mosaicMeta = (MosaicWriterMetadata) metadata; + assertThat(mosaicMeta.numRowGroups()).isGreaterThan(0); + } + + @Test + void testStatsFromMetadataMatchesStatsFromFile() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_bigint", DataTypes.BIGINT()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE()) + .build(); + Path path = newPath(); + String statsColumns = "f_int,f_bigint,f_string,f_double"; + + FormatWriter writer = createWriter(rowType, path, statsColumns); + for (int i = 0; i < 1000; i++) { + writer.addElement( + GenericRow.of(i, (long) i * 100, BinaryString.fromString("val_" + i), i * 1.1)); + } + writer.close(); + + Object metadata = writer.writerMetadata(); + assertThat(metadata).isNotNull(); + + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + LocalFileIO fileIO = new LocalFileIO(); + 
long fileSize = fileIO.getFileSize(path); + + SimpleColStats[] fromFile = extractor.extract(fileIO, path, fileSize); + SimpleColStats[] fromMetadata = extractor.extract(fileIO, path, fileSize, metadata); + + assertThat(fromMetadata).isEqualTo(fromFile); + } + + @Test + void testStatsFromMetadataWithNullValues() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .build(); + Path path = newPath(); + String statsColumns = "f_int,f_string"; + + FormatWriter writer = createWriter(rowType, path, statsColumns); + writer.addElement(GenericRow.of(1, null)); + writer.addElement(GenericRow.of(null, BinaryString.fromString("a"))); + writer.addElement(GenericRow.of(3, BinaryString.fromString("b"))); + writer.close(); + + Object metadata = writer.writerMetadata(); + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + LocalFileIO fileIO = new LocalFileIO(); + long fileSize = fileIO.getFileSize(path); + + SimpleColStats[] fromMetadata = extractor.extract(fileIO, path, fileSize, metadata); + assertThat(fromMetadata).isNotNull(); + assertThat(fromMetadata[0].nullCount()).isEqualTo(1L); + assertThat(fromMetadata[1].nullCount()).isEqualTo(1L); + } + + @Test + void testExtractWithFileInfoRowCount() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .build(); + Path path = newPath(); + String statsColumns = "f_int,f_string"; + + int numRows = 500; + FormatWriter writer = createWriter(rowType, path, statsColumns); + for (int i = 0; i < numRows; i++) { + writer.addElement(GenericRow.of(i, 
BinaryString.fromString("row_" + i))); + } + writer.close(); + + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + LocalFileIO fileIO = new LocalFileIO(); + long fileSize = fileIO.getFileSize(path); + + Pair result = + extractor.extractWithFileInfo(fileIO, path, fileSize); + assertThat(result.getRight().getRowCount()).isEqualTo(numRows); + assertThat(result.getLeft()).isNotNull(); + assertThat(result.getLeft()).hasSize(fieldCount); + } + + @Test + void testPartialStatsColumnsFromMetadata() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE()) + .build(); + Path path = newPath(); + String statsColumns = "f_int"; + + FormatWriter writer = createWriter(rowType, path, statsColumns); + writer.addElement(GenericRow.of(1, BinaryString.fromString("a"), 1.0)); + writer.addElement(GenericRow.of(null, BinaryString.fromString("b"), 2.0)); + writer.addElement(GenericRow.of(3, null, null)); + writer.close(); + + Object metadata = writer.writerMetadata(); + assertThat(metadata).isInstanceOf(MosaicWriterMetadata.class); + MosaicWriterMetadata mosaicMeta = (MosaicWriterMetadata) metadata; + assertThat(mosaicMeta.statsColumnNames()).containsExactly("f_int"); + + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); 
+ LocalFileIO fileIO = new LocalFileIO(); + long fileSize = fileIO.getFileSize(path); + + SimpleColStats[] fromMetadata = extractor.extract(fileIO, path, fileSize, metadata); + + // f_int has stats: min=1, max=3, nullCount=1 + assertThat(fromMetadata[0].min()).isEqualTo(1); + assertThat(fromMetadata[0].max()).isEqualTo(3); + assertThat(fromMetadata[0].nullCount()).isEqualTo(1L); + + // f_string and f_double have no stats (not in statsColumns) + assertThat(fromMetadata[1].min()).isNull(); + assertThat(fromMetadata[1].max()).isNull(); + assertThat(fromMetadata[1].nullCount()).isNull(); + assertThat(fromMetadata[2].min()).isNull(); + assertThat(fromMetadata[2].max()).isNull(); + assertThat(fromMetadata[2].nullCount()).isNull(); + } + + @Test + void testStatsOnMiddleColumn() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE()) + .build(); + Path path = newPath(); + String statsColumns = "f_string"; + + FormatWriter writer = createWriter(rowType, path, statsColumns); + writer.addElement(GenericRow.of(1, BinaryString.fromString("banana"), 1.0)); + writer.addElement(GenericRow.of(2, BinaryString.fromString("apple"), 2.0)); + writer.addElement(GenericRow.of(3, null, 3.0)); + writer.close(); + + Object metadata = writer.writerMetadata(); + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + LocalFileIO fileIO = new LocalFileIO(); + long fileSize = fileIO.getFileSize(path); + + SimpleColStats[] fromMetadata = extractor.extract(fileIO, path, fileSize, metadata); + + // f_int has no stats + 
assertThat(fromMetadata[0].min()).isNull(); + assertThat(fromMetadata[0].max()).isNull(); + assertThat(fromMetadata[0].nullCount()).isNull(); + + // f_string has stats: min="apple", max="banana", nullCount=1 + assertThat(fromMetadata[1].min()).isEqualTo(BinaryString.fromString("apple")); + assertThat(fromMetadata[1].max()).isEqualTo(BinaryString.fromString("banana")); + assertThat(fromMetadata[1].nullCount()).isEqualTo(1L); + + // f_double has no stats + assertThat(fromMetadata[2].min()).isNull(); + assertThat(fromMetadata[2].max()).isNull(); + assertThat(fromMetadata[2].nullCount()).isNull(); + } + + @Test + void testPartialStatsColumnsFromFile() throws IOException { + RowType rowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE()) + .build(); + Path path = newPath(); + String statsColumns = "f_string"; + + FormatWriter writer = createWriter(rowType, path, statsColumns); + writer.addElement(GenericRow.of(1, BinaryString.fromString("banana"), 1.0)); + writer.addElement(GenericRow.of(2, BinaryString.fromString("apple"), 2.0)); + writer.addElement(GenericRow.of(3, null, 3.0)); + writer.close(); + + // Extract from file (no writer metadata), simulating fallback path + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + LocalFileIO fileIO = new LocalFileIO(); + long fileSize = fileIO.getFileSize(path); + + SimpleColStats[] fromFile = extractor.extract(fileIO, path, fileSize); + + // f_int has no stats in file + assertThat(fromFile[0].min()).isNull(); + assertThat(fromFile[0].max()).isNull(); + assertThat(fromFile[0].nullCount()).isNull(); + + // f_string 
has stats + assertThat(fromFile[1].min()).isEqualTo(BinaryString.fromString("apple")); + assertThat(fromFile[1].max()).isEqualTo(BinaryString.fromString("banana")); + assertThat(fromFile[1].nullCount()).isEqualTo(1L); + + // f_double has no stats in file + assertThat(fromFile[2].min()).isNull(); + assertThat(fromFile[2].max()).isNull(); + assertThat(fromFile[2].nullCount()).isNull(); + } + + @Test + void testFallbackToFileWhenMetadataIsNull() throws IOException { + RowType rowType = DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()); + Path path = newPath(); + String statsColumns = "f0,f1"; + + FormatWriter writer = createWriter(rowType, path, statsColumns); + writer.addElement(GenericRow.of(10, BinaryString.fromString("test"))); + writer.close(); + + MosaicFileFormat format = createFormat(statsColumns); + int fieldCount = rowType.getFieldCount(); + SimpleColStatsCollector.Factory[] collectors = + IntStream.range(0, fieldCount) + .mapToObj(i -> SimpleColStatsCollector.from("full")) + .toArray(SimpleColStatsCollector.Factory[]::new); + + SimpleStatsExtractor extractor = format.createStatsExtractor(rowType, collectors).get(); + LocalFileIO fileIO = new LocalFileIO(); + long fileSize = fileIO.getFileSize(path); + + SimpleColStats[] fromFile = extractor.extract(fileIO, path, fileSize); + SimpleColStats[] fromNull = extractor.extract(fileIO, path, fileSize, null); + + assertThat(fromNull).isEqualTo(fromFile); + } + + private Path newPath() { + return new Path(tempDir.toUri().toString(), UUID.randomUUID() + ".mosaic"); + } + + private FormatWriter createWriter(RowType rowType, Path path, String statsColumns) + throws IOException { + MosaicFileFormat format = createFormat(statsColumns); + FormatWriterFactory writerFactory = format.createWriterFactory(rowType); + LocalFileIO fileIO = new LocalFileIO(); + return writerFactory.create(fileIO.newOutputStream(path, false), "zstd"); + } + + private static MosaicFileFormat createFormat() { + return createFormat(""); + } + + 
private static MosaicFileFormat createFormat(String statsColumns) { + Options options = new Options(); + if (!statsColumns.isEmpty()) { + options.set(MosaicFileFormat.STATS_COLUMNS, statsColumns); + } + return new MosaicFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024)); + } + + private static boolean isNativeAvailable() { + try { + Class.forName("org.apache.paimon.mosaic.NativeLib"); + return true; + } catch (Throwable t) { + return false; + } + } +} diff --git a/paimon-python/README.md b/paimon-python/README.md index e216dc00197c..5f716c8828f6 100644 --- a/paimon-python/README.md +++ b/paimon-python/README.md @@ -31,3 +31,82 @@ pip3 install dist/*.tar.gz The command will install the package and core dependencies to your local Python environment. +# HDFS without a local Hadoop install + +`pypaimon` supports HDFS through a pure-protocol client based on +[`hdfs-native`](https://github.com/Kimahriman/hdfs-native) (Rust + PyO3). +Use it when you want HDFS access **without** installing Hadoop, a JDK, +`libhdfs`, or wrestling with `CLASSPATH` / `LD_LIBRARY_PATH`. + +Install with the optional extra: + +```commandline +pip install 'pypaimon[hdfs]' +``` + +The native backend requires **Python 3.10+** (and is unavailable on Windows). +On older interpreters the extra is skipped, so `pypaimon` still installs — keep +using the legacy `pyarrow` (`libhdfs`/JVM) backend there via +`hdfs.client.impl=pyarrow`. + +For `hdfs://` and `viewfs://` URIs this backend is now the default. +Switch back to the legacy `libhdfs` (JNI) path with: + +```python +catalog = CatalogFactory.create({ + "warehouse": "hdfs://ns1/warehouse", + "hdfs.client.impl": "pyarrow", # default: "native" +}) +``` + +## Sourcing the cluster wiring + +The client still needs to know about NameNode addresses, HA failover +groups, and `viewfs` mount tables. Three options: + +1. 
**Local XML** — set `HADOOP_CONF_DIR` (or the `hdfs.conf-dir` option) + to a directory containing `core-site.xml` / `hdfs-site.xml`. Only the + XML files are required; no Hadoop binaries or JDK. + +2. **Catalog options (REST-friendly)** — pass the original Hadoop + key/values directly in catalog options. Keys with prefixes `dfs.`, + `fs.`, `hadoop.`, `ipc.`, `io.` are forwarded as-is. A REST catalog + can deliver these in its response, giving a fully zero-file client + experience: + + ```python + CatalogFactory.create({ + "warehouse": "viewfs://cluster/warehouse", + "dfs.nameservices": "ns1", + "dfs.ha.namenodes.ns1": "nn1,nn2", + "dfs.namenode.rpc-address.ns1.nn1": "host-1:8020", + "dfs.namenode.rpc-address.ns1.nn2": "host-2:8020", + "fs.viewfs.mounttable.cluster.link./prod": "hdfs://ns1/prod", + }) + ``` + +3. **Namespaced overrides** — use `hdfs.config.<hadoop-key>` to forward any + other Hadoop key not covered by the prefix whitelist. + +The three sources can be combined; catalog options take precedence over +XML. + +## Kerberos + +A secured cluster still needs the GSSAPI system library +(`libgssapi-krb5-2` on Debian/Ubuntu, `krb5` via Homebrew on macOS, +`krb5-libs` on RHEL) plus a `krb5.conf`. Provide credentials by either: + +- Running `kinit` yourself and pointing `KRB5CCNAME` at the cache, or +- Setting `security.kerberos.login.principal` and + `security.kerberos.login.keytab` in catalog options — `pypaimon` will + run `kinit` for you. + +## Fallback behaviour + +If the native backend fails to initialise (e.g. wheel missing on an +unsupported platform such as Windows), `pypaimon` automatically falls +back to the `pyarrow` (`libhdfs`/JVM) path and logs a warning. Disable +the fallback with `hdfs.client.fallback-to-pyarrow=false` if you want +hard failures instead.
+ diff --git a/paimon-python/dev/requirements-dev.txt b/paimon-python/dev/requirements-dev.txt index d4e9a0645b17..c83a2e44b8b6 100644 --- a/paimon-python/dev/requirements-dev.txt +++ b/paimon-python/dev/requirements-dev.txt @@ -21,11 +21,14 @@ duckdb==1.3.2 flake8==4.0.1 pytest~=7.0 -# Ray: 2.48+ has no wheel for Python 3.8; use 2.10.0 on 3.8, 2.48.0 on 3.9+ -ray>=2.10.0 +# merge_into needs Dataset.join (added in Ray 2.50). Python 3.8 has no 2.50 wheel. +ray>=2.10.0; python_version < "3.9" +ray>=2.50.0; python_version >= "3.9" requests parameterized # Vortex 0.71.0 regresses native predicate pushdown on single-row files. vortex-data==0.70.0; python_version >= "3.11" +# merge_into condition expressions (optional, for condition tests) +datafusion>=52; python_version >= "3.10" # Lumina vector search (optional, for lumina index tests) lumina-data>=0.1.0 diff --git a/paimon-python/dev/requirements.txt b/paimon-python/dev/requirements.txt index 9cd250500bd6..49cdfef0340f 100644 --- a/paimon-python/dev/requirements.txt +++ b/paimon-python/dev/requirements.txt @@ -26,13 +26,16 @@ pandas>=1.1,<2; python_version < "3.7" pandas>=1.3,<3; python_version >= "3.7" and python_version < "3.9" pandas>=1.5,<3; python_version >= "3.9" polars>=0.9,<1; python_version<"3.8" -polars>=1,<2; python_version>="3.8" +polars>=1,<2; python_version=="3.8" +polars>=1.32,<2; python_version>="3.9" pyarrow>=6,<7; python_version < "3.8" pyarrow>=16,<20; python_version >= "3.8" pyroaring<=0.3.3; python_version < "3.7" pyroaring<=0.4.5; python_version == "3.7" pyroaring>=1.0.0; python_version >= "3.8" readerwriterlock>=1,<2 +requests>=2.21.0,<3 +urllib3>=1.26,<3 zstandard>=0.19,<1 backports.zstd>=1.0.0,<1.4.0; python_version >= "3.9" and python_version < "3.14" cramjam>=1.3.0,<3; python_version>="3.7" diff --git a/paimon-python/dev/run_mixed_tests.sh b/paimon-python/dev/run_mixed_tests.sh index 7cc32eedeec0..c2b75f476da7 100755 --- a/paimon-python/dev/run_mixed_tests.sh +++ 
b/paimon-python/dev/run_mixed_tests.sh @@ -373,6 +373,11 @@ run_tantivy_fulltext_test() { return 1 fi cd "$PAIMON_PYTHON_DIR" + echo "Installing Python jieba tokenizer dependency for Tantivy jieba index reads..." + if ! python -m pip install 'jieba>=0.42,<1'; then + echo -e "${RED}✗ Failed to install jieba${NC}" + return 1 + fi echo "Running Python test for JavaPyReadWriteTest.test_read_tantivy_full_text_index..." if python -m pytest java_py_read_write_test.py::JavaPyReadWriteTest::test_read_tantivy_full_text_index -v; then echo -e "${GREEN}✓ Python test completed successfully${NC}" @@ -612,6 +617,48 @@ run_py_variant_write_java_read_test() { fi } +# Function to run ROW format test (Java write, Python read, Python write, Java read) +run_row_format_test() { + echo -e "${YELLOW}=== Running ROW Format Test (Java Write → Python Read, Python Write → Java Read) ===${NC}" + + cd "$PROJECT_ROOT" + + echo "Running Maven test for JavaPyE2ETest.testJavaWriteRowAppendTable..." + if ! mvn test -Dtest=org.apache.paimon.JavaPyE2ETest#testJavaWriteRowAppendTable -pl paimon-core -q -Drun.e2e.tests=true; then + echo -e "${RED}✗ Java ROW write test failed${NC}" + return 1 + fi + echo -e "${GREEN}✓ Java ROW write test completed successfully${NC}" + + cd "$PAIMON_PYTHON_DIR" + echo "Running Python test for JavaPyReadWriteTest.test_read_row_append_table..." + if ! python -m pytest java_py_read_write_test.py::JavaPyReadWriteTest::test_read_row_append_table -v; then + echo -e "${RED}✗ Python ROW read test failed${NC}" + return 1 + fi + echo -e "${GREEN}✓ Python ROW read test completed successfully${NC}" + + echo "" + + echo "Running Python test for JavaPyReadWriteTest.test_py_write_row_append_table..." + if ! 
python -m pytest java_py_read_write_test.py::JavaPyReadWriteTest::test_py_write_row_append_table -v; then + echo -e "${RED}✗ Python ROW write test failed${NC}" + return 1 + fi + echo -e "${GREEN}✓ Python ROW write test completed successfully${NC}" + + echo "" + + cd "$PROJECT_ROOT" + echo "Running Maven test for JavaPyE2ETest.testReadRowAppendTable..." + if ! mvn test -Dtest=org.apache.paimon.JavaPyE2ETest#testReadRowAppendTable -pl paimon-core -q -Drun.e2e.tests=true; then + echo -e "${RED}✗ Java ROW read test failed${NC}" + return 1 + fi + echo -e "${GREEN}✓ Java ROW read test completed successfully${NC}" + return 0 +} + # Main execution main() { local java_write_result=0 @@ -635,6 +682,7 @@ main() { local multi_vector_dedicated_py_write_result=0 local java_variant_write_py_read_result=0 local py_variant_write_java_read_result=0 + local row_format_result=0 # Detect Python version PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')" 2>/dev/null || echo "unknown") @@ -815,6 +863,13 @@ main() { echo "" + # Run ROW format test (Java write + Python read + Python write + Java read) + if ! 
run_row_format_test; then + row_format_result=1 + fi + + echo "" + echo -e "${YELLOW}=== Test Results Summary ===${NC}" if [[ $java_write_result -eq 0 ]]; then @@ -943,12 +998,18 @@ main() { echo -e "${RED}✗ VARIANT Type Test (Python Write, Java Read): FAILED${NC}" fi + if [[ $row_format_result -eq 0 ]]; then + echo -e "${GREEN}✓ ROW Format Test (Java Write ↔ Python Read/Write): PASSED${NC}" + else + echo -e "${RED}✗ ROW Format Test (Java Write ↔ Python Read/Write): FAILED${NC}" + fi + echo "" # Clean up warehouse directory after all tests cleanup_warehouse - if [[ $java_write_result -eq 0 && $python_read_result -eq 0 && $python_write_result -eq 0 && $java_read_result -eq 0 && $pk_dv_result -eq 0 && $btree_index_result -eq 0 && $compressed_text_result -eq 0 && $tantivy_fulltext_result -eq 0 && $lumina_vector_result -eq 0 && $lumina_vector_btree_result -eq 0 && $compact_conflict_result -eq 0 && $blob_alter_compact_result -eq 0 && $data_evolution_result -eq 0 && $data_evolution_py_write_result -eq 0 && $java_variant_write_py_read_result -eq 0 && $py_variant_write_java_read_result -eq 0 && $vector_append_table_result -eq 0 && $vector_dedicated_java_write_result -eq 0 && $vector_dedicated_py_write_result -eq 0 && $multi_vector_dedicated_java_write_result -eq 0 && $multi_vector_dedicated_py_write_result -eq 0 ]]; then + if [[ $java_write_result -eq 0 && $python_read_result -eq 0 && $python_write_result -eq 0 && $java_read_result -eq 0 && $pk_dv_result -eq 0 && $btree_index_result -eq 0 && $compressed_text_result -eq 0 && $tantivy_fulltext_result -eq 0 && $lumina_vector_result -eq 0 && $lumina_vector_btree_result -eq 0 && $compact_conflict_result -eq 0 && $blob_alter_compact_result -eq 0 && $data_evolution_result -eq 0 && $data_evolution_py_write_result -eq 0 && $java_variant_write_py_read_result -eq 0 && $py_variant_write_java_read_result -eq 0 && $vector_append_table_result -eq 0 && $vector_dedicated_java_write_result -eq 0 && $vector_dedicated_py_write_result -eq 0 && 
$multi_vector_dedicated_java_write_result -eq 0 && $multi_vector_dedicated_py_write_result -eq 0 && $row_format_result -eq 0 ]]; then echo -e "${GREEN}🎉 All tests passed! Java-Python interoperability verified.${NC}" return 0 else @@ -958,4 +1019,4 @@ main() { } # Run main function -main "$@" \ No newline at end of file +main "$@" diff --git a/paimon-python/pypaimon/benchmark/hdfs_io_bench.py b/paimon-python/pypaimon/benchmark/hdfs_io_bench.py new file mode 100644 index 000000000000..5b6b12dd18de --- /dev/null +++ b/paimon-python/pypaimon/benchmark/hdfs_io_bench.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""HDFS FileIO benchmark: native (hdfs-native) vs pyarrow (libhdfs/JVM). + +Compares throughput of common FileIO operations between the two backends +against the same HDFS cluster. Each backend is exercised via the FileIO +factory by toggling the `hdfs.client.impl` option. + +Usage: + python -m pypaimon.benchmark.hdfs_io_bench \\ + --warehouse hdfs://localhost:8020/bench \\ + [--backend native|pyarrow|both] \\ + [--write-size-mb 256] \\ + [--list-files 1000] \\ + [--read-iters 3] + +Notes: +- `pyarrow` backend requires HADOOP_HOME + HADOOP_CONF_DIR + libhdfs. 
+- `native` backend requires `pip install pypaimon[hdfs]`. +- The benchmark writes/reads scratch files under /bench_/ + and removes them on exit. +""" + +import argparse +import os +import sys +import time +import uuid +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from pypaimon.common.file_io import FileIO # noqa: E402 +from pypaimon.common.options import Options # noqa: E402 + + +def _build_file_io(warehouse: str, backend: str) -> FileIO: + opts = Options({"hdfs.client.impl": backend}) + return FileIO.get(warehouse, opts) + + +def _human(seconds: float) -> str: + if seconds < 1e-3: + return f"{seconds * 1e6:.0f}us" + if seconds < 1: + return f"{seconds * 1e3:.1f}ms" + return f"{seconds:.2f}s" + + +def _bench_write(file_io: FileIO, root: str, size_mb: int) -> float: + payload = os.urandom(min(size_mb, 16) * 1024 * 1024) + path = f"{root}/write-{uuid.uuid4().hex[:8]}.bin" + t0 = time.perf_counter() + with file_io.new_output_stream(path) as stream: + written = 0 + target = size_mb * 1024 * 1024 + while written < target: + chunk = payload[: min(len(payload), target - written)] + n = stream.write(chunk) + written += n if isinstance(n, int) and n > 0 else len(chunk) + return time.perf_counter() - t0 + + +def _bench_read(file_io: FileIO, path: str) -> float: + t0 = time.perf_counter() + with file_io.new_input_stream(path) as stream: + while True: + chunk = stream.read(8 * 1024 * 1024) + if not chunk: + break + return time.perf_counter() - t0 + + +def _bench_list(file_io: FileIO, root: str, num_files: int) -> float: + scratch = f"{root}/list-{uuid.uuid4().hex[:8]}" + file_io.mkdirs(scratch) + try: + for i in range(num_files): + with file_io.new_output_stream(f"{scratch}/f-{i:06d}.txt") as s: + s.write(b"x") + t0 = time.perf_counter() + results = file_io.list_status(scratch) + _ = list(results) + return time.perf_counter() - t0 + finally: + file_io.delete(scratch, recursive=True) + + +def run_one(backend: str, args) -> 
None: + print(f"\n=== backend={backend} ===") + try: + file_io = _build_file_io(args.warehouse, backend) + except Exception as e: + print(f" init failed: {e}") + return + + bench_root = f"{args.warehouse.rstrip('/')}/bench_{uuid.uuid4().hex[:8]}" + file_io.mkdirs(bench_root) + try: + # Write + sample_path = f"{bench_root}/write-sample.bin" + with file_io.new_output_stream(sample_path) as stream: + payload = os.urandom(min(args.write_size_mb, 16) * 1024 * 1024) + written = 0 + target = args.write_size_mb * 1024 * 1024 + t0 = time.perf_counter() + while written < target: + chunk = payload[: min(len(payload), target - written)] + n = stream.write(chunk) + written += n if isinstance(n, int) and n > 0 else len(chunk) + write_elapsed = time.perf_counter() - t0 + mb_per_s = args.write_size_mb / write_elapsed if write_elapsed else 0 + print(f" write {args.write_size_mb}MB: " + f"{_human(write_elapsed)} ({mb_per_s:.1f} MB/s)") + + # Read (warm) + read_times = [] + for _ in range(args.read_iters): + read_times.append(_bench_read(file_io, sample_path)) + avg_read = sum(read_times) / len(read_times) + rmb_per_s = args.write_size_mb / avg_read if avg_read else 0 + print(f" read {args.write_size_mb}MB (avg of {args.read_iters}): " + f"{_human(avg_read)} ({rmb_per_s:.1f} MB/s)") + + # List + list_elapsed = _bench_list(file_io, bench_root, args.list_files) + print(f" list {args.list_files} files: {_human(list_elapsed)}") + + finally: + try: + file_io.delete(bench_root, recursive=True) + except Exception as e: + print(f" cleanup failed: {e}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--warehouse", required=True, + help="HDFS URI under which scratch files are created") + parser.add_argument("--backend", default="both", + choices=["native", "pyarrow", "both"]) + parser.add_argument("--write-size-mb", type=int, default=128) + parser.add_argument("--list-files", type=int, default=1000) + parser.add_argument("--read-iters", type=int, default=3) + args 
= parser.parse_args() + + backends = ["native", "pyarrow"] if args.backend == "both" else [args.backend] + for backend in backends: + run_one(backend, args) + + +if __name__ == "__main__": + main() diff --git a/paimon-python/pypaimon/catalog/catalog_factory.py b/paimon-python/pypaimon/catalog/catalog_factory.py index 741d785d299d..117c10b2e902 100644 --- a/paimon-python/pypaimon/catalog/catalog_factory.py +++ b/paimon-python/pypaimon/catalog/catalog_factory.py @@ -21,6 +21,7 @@ from pypaimon.catalog.catalog import Catalog from pypaimon.catalog.catalog_context import CatalogContext from pypaimon.catalog.filesystem_catalog import FileSystemCatalog +from pypaimon.catalog.jdbc_catalog import JdbcCatalog from pypaimon.catalog.rest.rest_catalog import RESTCatalog from pypaimon.common.options.config import CatalogOptions @@ -29,6 +30,7 @@ class CatalogFactory: CATALOG_REGISTRY = { "filesystem": FileSystemCatalog, + "jdbc": JdbcCatalog, "rest": RESTCatalog, } @@ -39,6 +41,6 @@ def create(catalog_options: Dict) -> Catalog: if catalog_class is None: raise ValueError("Unknown catalog identifier: {}. 
" "Available types: {}".format(identifier, list(CatalogFactory.CATALOG_REGISTRY.keys()))) - return catalog_class( - CatalogContext.create_from_options(Options(catalog_options))) if identifier == "rest" else catalog_class( - Options(catalog_options)) + if identifier in ("jdbc", "rest"): + return catalog_class(CatalogContext.create_from_options(Options(catalog_options))) + return catalog_class(Options(catalog_options)) diff --git a/paimon-python/pypaimon/catalog/filesystem_catalog.py b/paimon-python/pypaimon/catalog/filesystem_catalog.py index 86e2f775e769..b7356b45955b 100644 --- a/paimon-python/pypaimon/catalog/filesystem_catalog.py +++ b/paimon-python/pypaimon/catalog/filesystem_catalog.py @@ -142,11 +142,11 @@ def _load_data_table(self, identifier: Identifier) -> FileStoreTable: table_schema = self.get_table_schema(identifier) # Create catalog environment for filesystem catalog - # Filesystem catalog doesn't support version management by default + from pypaimon.catalog.filesystem_catalog_loader import FileSystemCatalogLoader catalog_environment = CatalogEnvironment( identifier=identifier, - uuid=None, # Filesystem catalog doesn't track table UUIDs - catalog_loader=None, # No catalog loader for filesystem + uuid=None, + catalog_loader=FileSystemCatalogLoader(self.catalog_context), supports_version_management=False ) diff --git a/paimon-python/pypaimon/catalog/jdbc_catalog.py b/paimon-python/pypaimon/catalog/jdbc_catalog.py new file mode 100644 index 000000000000..ccf5dbfa1a39 --- /dev/null +++ b/paimon-python/pypaimon/catalog/jdbc_catalog.py @@ -0,0 +1,622 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# + +import sqlite3 +from contextlib import closing, contextmanager +from typing import Dict, List, Optional, Tuple, Union +from urllib.parse import parse_qs, urlparse + +from pypaimon.catalog.catalog import Catalog +from pypaimon.catalog.catalog_context import CatalogContext +from pypaimon.catalog.catalog_environment import CatalogEnvironment +from pypaimon.catalog.catalog_exception import ( + DatabaseAlreadyExistException, + DatabaseNotExistException, + TableAlreadyExistException, + TableNotExistException +) +from pypaimon.catalog.database import Database +from pypaimon.common.file_io import FileIO +from pypaimon.common.identifier import Identifier +from pypaimon.common.options.config import CatalogOptions, JdbcCatalogOptions +from pypaimon.common.options.core_options import CoreOptions +from pypaimon.schema.schema import Schema +from pypaimon.schema.schema_change import SchemaChange +from pypaimon.schema.schema_manager import SchemaManager +from pypaimon.snapshot.snapshot import Snapshot +from pypaimon.snapshot.snapshot_commit import PartitionStatistics +from pypaimon.table.file_store_table import FileStoreTable +from pypaimon.table.table import Table + + +def _convert_qmark_placeholders(sql: str, placeholder: str) -> str: + if placeholder == "?": + return sql + + result = [] + in_single_quote = False + 
in_double_quote = False + index = 0 + while index < len(sql): + char = sql[index] + if char == "'" and not in_double_quote: + result.append(char) + if in_single_quote and index + 1 < len(sql) and sql[index + 1] == "'": + index += 1 + result.append(sql[index]) + else: + in_single_quote = not in_single_quote + elif char == '"' and not in_single_quote: + result.append(char) + if in_double_quote and index + 1 < len(sql) and sql[index + 1] == '"': + index += 1 + result.append(sql[index]) + else: + in_double_quote = not in_double_quote + elif char == "?" and not in_single_quote and not in_double_quote: + result.append(placeholder) + else: + result.append(char) + index += 1 + return "".join(result) + + +class _DbApiConnection: + def __init__(self, options: Dict[str, str]): + self.options = options + self.uri = options.get(CatalogOptions.URI.key()) + if not self.uri: + raise ValueError(f"Paimon '{CatalogOptions.URI.key()}' must be set for JDBC catalog") + self.protocol, self.placeholder, self.connection = self._connect(self.uri, options) + + def close(self): + self.connection.close() + + def execute(self, sql: str, args: Tuple = ()): + with closing(self.connection.cursor()) as cursor: + cursor.execute(self._sql(sql), args) + return cursor.rowcount + + def executemany(self, sql: str, args): + with closing(self.connection.cursor()) as cursor: + cursor.executemany(self._sql(sql), args) + return cursor.rowcount + + def fetch_all(self, sql: str, args: Tuple = ()): + with closing(self.connection.cursor()) as cursor: + cursor.execute(self._sql(sql), args) + return cursor.fetchall() + + def fetch_one(self, sql: str, args: Tuple = ()): + with closing(self.connection.cursor()) as cursor: + cursor.execute(self._sql(sql), args) + return cursor.fetchone() + + def _sql(self, sql: str) -> str: + return _convert_qmark_placeholders(sql, self.placeholder) + + @contextmanager + def transaction(self): + try: + yield + self.connection.commit() + except Exception: + self.connection.rollback() + 
raise + + @staticmethod + def _jdbc_properties(options: Dict[str, str]) -> Dict[str, str]: + result = {} + for key, value in options.items(): + if key.startswith("jdbc."): + result[key[len("jdbc."):]] = value + return result + + def _connect(self, uri: str, options: Dict[str, str]): + if uri.startswith("jdbc:sqlite:"): + return self._connect_sqlite(uri) + if uri.startswith("jdbc:mysql:"): + return self._connect_mysql(uri, options) + if uri.startswith("jdbc:postgresql:"): + return self._connect_postgresql(uri, options) + raise ValueError(f"Unsupported JDBC catalog URI for Python DB-API connection: {uri}") + + def _connect_sqlite(self, uri: str): + sqlite_uri = uri[len("jdbc:sqlite:"):] + if sqlite_uri.startswith("file:"): + connection = sqlite3.connect(sqlite_uri, uri=True, check_same_thread=False) + else: + connection = sqlite3.connect(sqlite_uri, check_same_thread=False) + return "sqlite", "?", connection + + def _connect_mysql(self, uri: str, options: Dict[str, str]): + try: + import pymysql + connector = "pymysql" + except ImportError: + try: + import mysql.connector as mysql_connector + connector = "mysql-connector" + except ImportError as e: + raise ImportError( + "PyPaimon JDBC catalog uses Python DB-API drivers and requires " + "pymysql or mysql-connector-python to connect to MySQL." 
+ ) from e + + parsed = urlparse(uri[len("jdbc:"):]) + props = self._jdbc_properties(options) + query = {k: v[0] for k, v in parse_qs(parsed.query).items()} + props.update(query) + user = props.pop("user", props.pop("username", None)) + password = props.pop("password", None) + host = props.pop("host", parsed.hostname) + database = parsed.path.lstrip("/") or props.pop("database", "") + props.pop("database", None) + port_value = props.pop("port", None) + port = parsed.port or int(port_value or 3306) + if connector == "pymysql": + connection = pymysql.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + autocommit=False, + **props + ) + else: + connection = mysql_connector.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + **props + ) + return "mysql", "%s", connection + + def _connect_postgresql(self, uri: str, options: Dict[str, str]): + try: + import psycopg2 + connector = "psycopg2" + except ImportError: + try: + import psycopg + connector = "psycopg" + except ImportError as e: + raise ImportError( + "PyPaimon JDBC catalog uses Python DB-API drivers and requires " + "psycopg2 or psycopg to connect to PostgreSQL." 
+ ) from e + + parsed = urlparse(uri[len("jdbc:"):]) + props = self._jdbc_properties(options) + query = {k: v[0] for k, v in parse_qs(parsed.query).items()} + props.update(query) + user = props.pop("user", props.pop("username", None)) + password = props.pop("password", None) + host = props.pop("host", parsed.hostname) + database = parsed.path.lstrip("/") or props.get("database") or props.get("dbname") or "" + props.pop("database", None) + props.pop("dbname", None) + port_value = props.pop("port", None) + port = parsed.port or int(port_value or 5432) + connect_kwargs = { + "host": host, + "port": port, + "user": user, + "password": password, + "dbname": database, + } + connect_kwargs.update(props) + if connector == "psycopg2": + connection = psycopg2.connect(**connect_kwargs) + else: + connection = psycopg.connect(**connect_kwargs) + return "postgresql", "%s", connection + + +class JdbcCatalog(Catalog): + CATALOG_TABLE_NAME = "paimon_tables" + DATABASE_PROPERTIES_TABLE_NAME = "paimon_database_properties" + TABLE_PROPERTIES_TABLE_NAME = "paimon_table_properties" + CATALOG_KEY = "catalog_key" + TABLE_DATABASE = "database_name" + TABLE_NAME = "table_name" + PROPERTY_KEY = "property_key" + PROPERTY_VALUE = "property_value" + DATABASE_EXISTS_PROPERTY = "exists" + + def __init__(self, context: CatalogContext): + catalog_options = context.options + if not catalog_options.contains(CatalogOptions.WAREHOUSE): + raise ValueError(f"Paimon '{CatalogOptions.WAREHOUSE.key()}' path must be set") + self.context = context + self.catalog_options = catalog_options + self.options = catalog_options.to_map() + self.warehouse = catalog_options.get(CatalogOptions.WAREHOUSE) + self.catalog_key = catalog_options.get(JdbcCatalogOptions.CATALOG_KEY) + self.file_io = FileIO.get(self.warehouse, self.catalog_options) + self.connection = _DbApiConnection(self.options) + self._initialize_catalog_tables() + + def close(self): + self.connection.close() + + def __enter__(self): + return self + + def 
__exit__(self, exc_type, exc_value, traceback): + self.close() + return False + + def _initialize_catalog_tables(self): + with self.connection.transaction(): + self.connection.execute( + "CREATE TABLE IF NOT EXISTS paimon_tables (" + "catalog_key VARCHAR(255) NOT NULL, " + "database_name VARCHAR(255) NOT NULL, " + "table_name VARCHAR(255) NOT NULL, " + "PRIMARY KEY (catalog_key, database_name, table_name))" + ) + self.connection.execute( + "CREATE TABLE IF NOT EXISTS paimon_database_properties (" + "catalog_key VARCHAR(255) NOT NULL, " + "database_name VARCHAR(255) NOT NULL, " + "property_key VARCHAR(255), " + "property_value VARCHAR(1000), " + "PRIMARY KEY (catalog_key, database_name, property_key))" + ) + self.connection.execute( + "CREATE TABLE IF NOT EXISTS paimon_table_properties (" + "catalog_key VARCHAR(255) NOT NULL, " + "database_name VARCHAR(255) NOT NULL, " + "table_name VARCHAR(255) NOT NULL, " + "property_key VARCHAR(255) NOT NULL, " + "property_value VARCHAR(1000), " + "PRIMARY KEY (catalog_key, database_name, table_name, property_key))" + ) + + def list_databases(self) -> List[str]: + table_rows = self.connection.fetch_all( + "SELECT DISTINCT database_name FROM paimon_tables WHERE catalog_key = ?", + (self.catalog_key,) + ) + property_rows = self.connection.fetch_all( + "SELECT DISTINCT database_name FROM paimon_database_properties WHERE catalog_key = ?", + (self.catalog_key,) + ) + databases = {row[0] for row in table_rows} + databases.update(row[0] for row in property_rows) + return sorted(databases) + + def get_database(self, name: str) -> Database: + if not self._database_exists(name): + raise DatabaseNotExistException(name) + properties = self._fetch_database_properties(name) + if Catalog.DB_LOCATION_PROP not in properties: + properties[Catalog.DB_LOCATION_PROP] = self.get_database_path(name) + properties.pop(self.DATABASE_EXISTS_PROPERTY, None) + return Database(name, properties) + + def create_database(self, name: str, ignore_if_exists: bool, 
properties: Optional[dict] = None): + if self._database_exists(name): + if not ignore_if_exists: + raise DatabaseAlreadyExistException(name) + return + create_props = {self.DATABASE_EXISTS_PROPERTY: "true"} + if properties: + create_props.update(properties) + if Catalog.DB_LOCATION_PROP not in create_props: + create_props[Catalog.DB_LOCATION_PROP] = self.get_database_path(name) + with self.connection.transaction(): + self._insert_database_properties(name, create_props) + + def drop_database(self, name: str, ignore_if_not_exists: bool = False, cascade: bool = False): + if not self._database_exists(name): + if not ignore_if_not_exists: + raise DatabaseNotExistException(name) + return + tables = self.list_tables(name) + if tables and not cascade: + raise ValueError(f"Database {name} is not empty. Use cascade=True to drop all tables first.") + if cascade: + for table in tables: + self.drop_table(Identifier.create(name, table), True) + with self.connection.transaction(): + self.connection.execute( + "DELETE FROM paimon_tables WHERE catalog_key = ? AND database_name = ?", + (self.catalog_key, name) + ) + self.connection.execute( + "DELETE FROM paimon_database_properties WHERE catalog_key = ? AND database_name = ?", + (self.catalog_key, name) + ) + self.connection.execute( + "DELETE FROM paimon_table_properties WHERE catalog_key = ? 
AND database_name = ?", + (self.catalog_key, name) + ) + + def alter_database(self, name: str, changes: list): + self.get_database(name) + from pypaimon.catalog.rest.property_change import PropertyChange + set_properties, remove_keys = PropertyChange.get_set_properties_to_remove_keys(changes) + current = self._fetch_database_properties(name) + with self.connection.transaction(): + update_args = [ + (value, self.catalog_key, name, key) + for key, value in set_properties.items() + if key in current + ] + insert_properties = { + key: value for key, value in set_properties.items() if key not in current + } + if update_args: + self.connection.executemany( + "UPDATE paimon_database_properties SET property_value = ? " + "WHERE catalog_key = ? AND database_name = ? AND property_key = ?", + update_args + ) + if insert_properties: + self._insert_database_properties(name, insert_properties) + for key in remove_keys: + self.connection.execute( + "DELETE FROM paimon_database_properties " + "WHERE catalog_key = ? AND database_name = ? AND property_key = ?", + (self.catalog_key, name, key) + ) + + def list_tables(self, database_name: str) -> List[str]: + self.get_database(database_name) + rows = self.connection.fetch_all( + "SELECT table_name FROM paimon_tables WHERE catalog_key = ? 
AND database_name = ?", + (self.catalog_key, database_name) + ) + return sorted(row[0] for row in rows) + + def get_table(self, identifier: Union[str, Identifier]) -> Table: + if not isinstance(identifier, Identifier): + identifier = Identifier.from_string(identifier) + if self.catalog_options.contains(CoreOptions.SCAN_FALLBACK_BRANCH): + raise ValueError(f"Unsupported CoreOption {CoreOptions.SCAN_FALLBACK_BRANCH}") + if not self._table_exists(identifier): + raise TableNotExistException(identifier) + table_path = self.get_table_path(identifier) + table_schema = self.get_table_schema(identifier) + from pypaimon.catalog.jdbc_catalog_loader import JdbcCatalogLoader + catalog_environment = CatalogEnvironment( + identifier=identifier, + uuid=None, + catalog_loader=JdbcCatalogLoader(self.context), + supports_version_management=False + ) + return FileStoreTable(self.file_io, identifier, table_path, table_schema, catalog_environment) + + def create_table(self, identifier: Union[str, Identifier], schema: 'Schema', ignore_if_exists: bool): + if schema.options and schema.options.get(CoreOptions.AUTO_CREATE.key()): + raise ValueError(f"The value of {CoreOptions.AUTO_CREATE.key()} property should be False.") + if not isinstance(identifier, Identifier): + identifier = Identifier.from_string(identifier) + self.get_database(identifier.get_database_name()) + if self._table_exists(identifier): + if not ignore_if_exists: + raise TableAlreadyExistException(identifier) + return + if schema.options and CoreOptions.TYPE.key() in schema.options and schema.options.get( + CoreOptions.TYPE.key()) != "table": + raise ValueError(f"Table Type: {schema.options.get(CoreOptions.TYPE.key())}") + + table_path = self.get_table_path(identifier) + schema_manager = SchemaManager(self.file_io, table_path) + table_schema = schema_manager.create_table(schema) + try: + with self.connection.transaction(): + self.connection.execute( + "INSERT INTO paimon_tables (catalog_key, database_name, table_name) VALUES 
(?, ?, ?)", + (self.catalog_key, identifier.get_database_name(), identifier.get_table_name()) + ) + if self._sync_all_properties(): + self._insert_table_properties(identifier, self._collect_table_properties(table_schema)) + except Exception: + self.file_io.delete_directory_quietly(table_path) + raise + + def drop_table(self, identifier: Union[str, Identifier], ignore_if_not_exists: bool = False): + if not isinstance(identifier, Identifier): + identifier = Identifier.from_string(identifier) + if not self._table_exists(identifier): + if not ignore_if_not_exists: + raise TableNotExistException(identifier) + return + table_path = self.get_table_path(identifier) + with self.connection.transaction(): + self.connection.execute( + "DELETE FROM paimon_tables WHERE catalog_key = ? AND database_name = ? AND table_name = ?", + (self.catalog_key, identifier.get_database_name(), identifier.get_table_name()) + ) + self.connection.execute( + "DELETE FROM paimon_table_properties WHERE catalog_key = ? AND database_name = ? 
AND table_name = ?", + (self.catalog_key, identifier.get_database_name(), identifier.get_table_name()) + ) + self.file_io.delete_directory_quietly(table_path) + + def rename_table(self, source_identifier: Union[str, Identifier], target_identifier: Union[str, Identifier]): + if not isinstance(source_identifier, Identifier): + source_identifier = Identifier.from_string(source_identifier) + if not isinstance(target_identifier, Identifier): + target_identifier = Identifier.from_string(target_identifier) + if not self._table_exists(source_identifier): + raise TableNotExistException(source_identifier) + self.get_database(target_identifier.get_database_name()) + if self._table_exists(target_identifier): + raise TableAlreadyExistException(target_identifier) + + source_path = self.get_table_path(source_identifier) + target_path = self.get_table_path(target_identifier) + renamed_path = False + if self.file_io.exists(source_path): + self.file_io.rename(source_path, target_path) + renamed_path = True + try: + with self.connection.transaction(): + self.connection.execute( + "UPDATE paimon_tables SET database_name = ?, table_name = ? " + "WHERE catalog_key = ? AND database_name = ? AND table_name = ?", + ( + target_identifier.get_database_name(), + target_identifier.get_table_name(), + self.catalog_key, + source_identifier.get_database_name(), + source_identifier.get_table_name() + ) + ) + self.connection.execute( + "UPDATE paimon_table_properties SET database_name = ?, table_name = ? " + "WHERE catalog_key = ? AND database_name = ? 
AND table_name = ?", + ( + target_identifier.get_database_name(), + target_identifier.get_table_name(), + self.catalog_key, + source_identifier.get_database_name(), + source_identifier.get_table_name() + ) + ) + except Exception: + if renamed_path and self.file_io.exists(target_path): + self.file_io.rename(target_path, source_path) + raise + + def alter_table( + self, + identifier: Union[str, Identifier], + changes: List[SchemaChange], + ignore_if_not_exists: bool = False + ): + if not isinstance(identifier, Identifier): + identifier = Identifier.from_string(identifier) + if not self._table_exists(identifier): + if not ignore_if_not_exists: + raise TableNotExistException(identifier) + return + schema_manager = SchemaManager(self.file_io, self.get_table_path(identifier)) + table_schema = schema_manager.commit_changes(changes) + if self._sync_all_properties(): + with self.connection.transaction(): + self.connection.execute( + "DELETE FROM paimon_table_properties " + "WHERE catalog_key = ? AND database_name = ? 
AND table_name = ?", + (self.catalog_key, identifier.get_database_name(), identifier.get_table_name()) + ) + self._insert_table_properties(identifier, self._collect_table_properties(table_schema)) + + def get_table_schema(self, identifier: Identifier): + table_schema = SchemaManager(self.file_io, self.get_table_path(identifier)).latest() + if table_schema is None: + raise TableNotExistException(identifier) + return table_schema + + def get_database_path(self, name: str) -> str: + warehouse = self.warehouse.rstrip('/') + return f"{warehouse}/{name}{Catalog.DB_SUFFIX}" + + def get_table_path(self, identifier: Identifier) -> str: + db_path = self.get_database_path(identifier.get_database_name()) + return f"{db_path}/{identifier.get_table_name()}" + + def load_snapshot(self, identifier: Identifier): + raise NotImplementedError("JDBC catalog does not support load_snapshot") + + def commit_snapshot( + self, + identifier: Identifier, + table_uuid: Optional[str], + snapshot: Snapshot, + statistics: List[PartitionStatistics] + ) -> bool: + raise NotImplementedError("This catalog does not support commit catalog") + + def _database_exists(self, database_name: str) -> bool: + row = self.connection.fetch_one( + "SELECT database_name FROM paimon_tables " + "WHERE catalog_key = ? AND database_name = ? LIMIT 1", + (self.catalog_key, database_name) + ) + if row is not None: + return True + row = self.connection.fetch_one( + "SELECT database_name FROM paimon_database_properties " + "WHERE catalog_key = ? AND database_name = ? LIMIT 1", + (self.catalog_key, database_name) + ) + return row is not None + + def _table_exists(self, identifier: Identifier) -> bool: + row = self.connection.fetch_one( + "SELECT table_name FROM paimon_tables " + "WHERE catalog_key = ? AND database_name = ? AND table_name = ? 
LIMIT 1", + (self.catalog_key, identifier.get_database_name(), identifier.get_table_name()) + ) + return row is not None + + def _fetch_database_properties(self, database_name: str) -> Dict[str, str]: + rows = self.connection.fetch_all( + "SELECT property_key, property_value FROM paimon_database_properties " + "WHERE catalog_key = ? AND database_name = ?", + (self.catalog_key, database_name) + ) + return {row[0]: row[1] for row in rows} + + def _insert_database_properties(self, database_name: str, properties: Dict[str, str]): + if properties: + self.connection.executemany( + "INSERT INTO paimon_database_properties " + "(catalog_key, database_name, property_key, property_value) VALUES (?, ?, ?, ?)", + [(self.catalog_key, database_name, key, value) for key, value in properties.items()] + ) + + def _insert_table_properties(self, identifier: Identifier, properties: Dict[str, str]): + if properties: + self.connection.executemany( + "INSERT INTO paimon_table_properties " + "(catalog_key, database_name, table_name, property_key, property_value) " + "VALUES (?, ?, ?, ?, ?)", + [ + ( + self.catalog_key, + identifier.get_database_name(), + identifier.get_table_name(), + key, + value + ) + for key, value in properties.items() + ] + ) + + def _sync_all_properties(self) -> bool: + from pypaimon.common.options.options_utils import OptionsUtils + return OptionsUtils.convert_to_boolean( + self.catalog_options.get(CatalogOptions.SYNC_ALL_PROPERTIES)) + + @staticmethod + def _collect_table_properties(table_schema) -> Dict[str, str]: + properties = dict(table_schema.options or {}) + if table_schema.primary_keys: + properties["primary-key"] = ",".join(table_schema.primary_keys) + if table_schema.partition_keys: + properties["partition"] = ",".join(table_schema.partition_keys) + return properties diff --git a/paimon-python/pypaimon/catalog/jdbc_catalog_loader.py b/paimon-python/pypaimon/catalog/jdbc_catalog_loader.py new file mode 100644 index 000000000000..801ecb1a1c73 --- /dev/null 
+++ b/paimon-python/pypaimon/catalog/jdbc_catalog_loader.py @@ -0,0 +1,32 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# + +from pypaimon.catalog.catalog_context import CatalogContext +from pypaimon.catalog.catalog_loader import CatalogLoader +from pypaimon.catalog.jdbc_catalog import JdbcCatalog + + +class JdbcCatalogLoader(CatalogLoader): + def __init__(self, context: CatalogContext): + self._context = context + + def context(self) -> CatalogContext: + return self._context + + def load(self) -> JdbcCatalog: + return JdbcCatalog(self._context) diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py index 777dd6fef145..42dabb268cb2 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py +++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py @@ -191,6 +191,9 @@ def write_lance(self, path: str, data, **kwargs): def write_blob(self, path: str, data, **kwargs): return self.file_io().write_blob(path, data, **kwargs) + def write_row(self, path: str, data, fields=None, 
zstd_level: int = 1, **kwargs): + return self.file_io().write_row(path, data, fields, zstd_level, **kwargs) + @property def uri_reader_factory(self): if self._uri_reader_factory_cache is None: diff --git a/paimon-python/pypaimon/common/external_path_provider.py b/paimon-python/pypaimon/common/external_path_provider.py index f553dced521e..e039b1d86ed2 100644 --- a/paimon-python/pypaimon/common/external_path_provider.py +++ b/paimon-python/pypaimon/common/external_path_provider.py @@ -15,29 +15,185 @@ # specific language governing permissions and limitations # under the License. +import bisect +import ctypes import random +import struct +from abc import ABC, abstractmethod from typing import List -class ExternalPathProvider: - def __init__(self, external_table_paths: List[str], relative_bucket_path: str): - self.external_table_paths = external_table_paths - self.relative_bucket_path = relative_bucket_path - self.position = random.randint(0, len(external_table_paths) - 1) if external_table_paths else 0 +class ExternalPathProvider(ABC): + """Provider for external data paths.""" + @abstractmethod def get_next_external_data_path(self, file_name: str) -> str: - """ - Get the next external data path using round-robin strategy. 
- """ - if not self.external_table_paths: - raise ValueError("No external paths available") - - self.position += 1 - if self.position == len(self.external_table_paths): - self.position = 0 - - external_base = self.external_table_paths[self.position] - if self.relative_bucket_path: - return f"{external_base.rstrip('/')}/{self.relative_bucket_path.strip('/')}/{file_name}" + """Get the next external data path for the given file name.""" + + @staticmethod + def create(strategy, external_table_paths, relative_bucket_path="", weights=None): + """Factory method to create the appropriate ExternalPathProvider.""" + from pypaimon.common.options.core_options import ExternalPathStrategy + + if strategy is None: + return None + if isinstance(strategy, str): + strategy = ExternalPathStrategy(strategy) + + if strategy == ExternalPathStrategy.NONE: + return None + elif strategy in (ExternalPathStrategy.ROUND_ROBIN, ExternalPathStrategy.SPECIFIC_FS): + return RoundRobinExternalPathProvider(external_table_paths, relative_bucket_path) + elif strategy == ExternalPathStrategy.ENTROPY_INJECT: + return EntropyInjectExternalPathProvider(external_table_paths, relative_bucket_path) + elif strategy == ExternalPathStrategy.WEIGHTED: + if len(external_table_paths) < 2 or not weights: + return RoundRobinExternalPathProvider(external_table_paths, relative_bucket_path) + return WeightedExternalPathProvider(external_table_paths, relative_bucket_path, weights) + else: + raise ValueError(f"Unsupported external path strategy: {strategy}") + + +class RoundRobinExternalPathProvider(ExternalPathProvider): + """Provider for round-robin external data paths.""" + + def __init__(self, external_table_paths: List[str], relative_bucket_path: str = ""): + if not external_table_paths: + raise ValueError("external_table_paths must not be empty") + self._external_table_paths = external_table_paths + self._relative_bucket_path = relative_bucket_path + self._position = random.randint(0, len(external_table_paths) - 
1) + + def get_next_external_data_path(self, file_name: str) -> str: + self._position += 1 + if self._position == len(self._external_table_paths): + self._position = 0 + + external_base = self._external_table_paths[self._position] + if self._relative_bucket_path: + return f"{external_base.rstrip('/')}/{self._relative_bucket_path.strip('/')}/{file_name}" else: return f"{external_base.rstrip('/')}/{file_name}" + + +class EntropyInjectExternalPathProvider(ExternalPathProvider): + """Provider for entropy-injected external data paths. + + Generates hash-based directory structures from filenames using murmur3_32. + Constants: 20-bit hash, depth=3 dirs of 4 bits each, 8-bit remainder. + """ + + _HASH_BINARY_STRING_BITS = 20 + _ENTROPY_DIR_LENGTH = 4 + _ENTROPY_DIR_DEPTH = 3 + + def __init__(self, external_table_paths: List[str], relative_bucket_path: str = ""): + if not external_table_paths: + raise ValueError("external_table_paths must not be empty") + self._external_table_paths = external_table_paths + self._relative_bucket_path = relative_bucket_path + self._position = 0 + + def get_next_external_data_path(self, file_name: str) -> str: + hash_dirs = self._compute_hash(file_name) + if self._relative_bucket_path: + file_path_with_hash = f"{self._relative_bucket_path.strip('/')}/{hash_dirs}/{file_name}" + else: + file_path_with_hash = f"{hash_dirs}/{file_name}" + + self._position += 1 + if self._position == len(self._external_table_paths): + self._position = 0 + + external_base = self._external_table_paths[self._position] + return f"{external_base.rstrip('/')}/{file_path_with_hash}" + + def _compute_hash(self, file_name: str) -> str: + hash_int = _murmur3_32(file_name.encode('utf-8')) + binary_string = format((hash_int & 0xFFFFFFFF) | 0x80000000, '032b') + hash_str = binary_string[32 - self._HASH_BINARY_STRING_BITS:] + + parts = [] + total_prefix = self._ENTROPY_DIR_DEPTH * self._ENTROPY_DIR_LENGTH + for i in range(0, total_prefix, self._ENTROPY_DIR_LENGTH): + end = min(i 
+ self._ENTROPY_DIR_LENGTH, len(hash_str)) + parts.append(hash_str[i:end]) + if len(hash_str) > total_prefix: + parts.append(hash_str[total_prefix:]) + return "/".join(parts) + + +class WeightedExternalPathProvider(ExternalPathProvider): + """Provider for weighted external data paths. + + Uses cumulative weights with binary search for path selection. + """ + + def __init__(self, external_table_paths: List[str], relative_bucket_path: str, weights: List[int]): + if len(external_table_paths) != len(weights): + raise ValueError( + f"The number of external paths and weights should be the same. " + f"Paths: {len(external_table_paths)}, Weights: {len(weights)}" + ) + self._external_table_paths = external_table_paths + self._relative_bucket_path = relative_bucket_path + self._total_weight = sum(weights) + self._cumulative_weights: List[int] = [] + cumulative = 0 + for w in weights: + cumulative += w + self._cumulative_weights.append(cumulative) + + def get_next_external_data_path(self, file_name: str) -> str: + random_value = random.random() * self._total_weight + index = bisect.bisect_right(self._cumulative_weights, random_value) + if index >= len(self._external_table_paths): + index = len(self._external_table_paths) - 1 + selected_base = self._external_table_paths[index] + if self._relative_bucket_path: + return f"{selected_base.rstrip('/')}/{self._relative_bucket_path.strip('/')}/{file_name}" + else: + return f"{selected_base.rstrip('/')}/{file_name}" + + +def _murmur3_32(data: bytes, seed: int = 0) -> int: + """Pure-Python murmur3_32 hash, compatible with Guava Hashing.murmur3_32(). + + Returns a signed 32-bit integer identical to Java's int representation. 
+ """ + c1 = 0xCC9E2D51 + c2 = 0x1B873593 + length = len(data) + h1 = seed & 0xFFFFFFFF + rounded_end = (length & 0xFFFFFFFC) + + for i in range(0, rounded_end, 4): + k1 = struct.unpack_from('<I', data, i)[0] + k1 = (k1 * c1) & 0xFFFFFFFF + k1 = ((k1 << 15) | (k1 >> 17)) & 0xFFFFFFFF + k1 = (k1 * c2) & 0xFFFFFFFF + h1 ^= k1 + h1 = ((h1 << 13) | (h1 >> 19)) & 0xFFFFFFFF + h1 = (h1 * 5 + 0xE6546B64) & 0xFFFFFFFF + + k1 = 0 + remaining = length & 3 + if remaining >= 3: + k1 ^= data[rounded_end + 2] << 16 + if remaining >= 2: + k1 ^= data[rounded_end + 1] << 8 + if remaining >= 1: + k1 ^= data[rounded_end] + k1 = (k1 * c1) & 0xFFFFFFFF + k1 = ((k1 << 15) | (k1 >> 17)) & 0xFFFFFFFF + k1 = (k1 * c2) & 0xFFFFFFFF + h1 ^= k1 + + h1 ^= length + h1 ^= h1 >> 16 + h1 = (h1 * 0x85EBCA6B) & 0xFFFFFFFF + h1 ^= h1 >> 13 + h1 = (h1 * 0xC2B2AE35) & 0xFFFFFFFF + h1 ^= h1 >> 16 + + return ctypes.c_int32(h1).value diff --git a/paimon-python/pypaimon/common/file_io.py b/paimon-python/pypaimon/common/file_io.py index 6f9758965b57..fabda004959b 100644 --- a/paimon-python/pypaimon/common/file_io.py +++ b/paimon-python/pypaimon/common/file_io.py @@ -16,6 +16,7 @@ # under the License. import logging +import os import uuid from abc import ABC, abstractmethod from pathlib import Path @@ -27,6 +28,26 @@ from pypaimon.common.options import Options +def supports_pread(stream) -> bool: + """Check if the stream supports position-based reads (thread-safe I/O).""" + if hasattr(stream, 'read_at'): + return True + if hasattr(stream, 'fileno'): + try: + stream.fileno() + return True + except Exception: + pass + return False + + +def pread(stream, length: int, offset: int) -> bytes: + """Position-based read without changing the stream cursor. Thread-safe.""" + if hasattr(stream, 'read_at'): + return stream.read_at(length, offset) + return os.pread(stream.fileno(), length, offset) + + class FileIO(ABC): """ File IO interface to read and write files. 
@@ -257,6 +278,9 @@ def write_blob(self, path: str, data, **kwargs): def write_vortex(self, path: str, data, **kwargs): raise NotImplementedError("write_vortex must be implemented by FileIO subclasses") + def write_row(self, path: str, data, fields=None, zstd_level: int = 1, **kwargs): + raise NotImplementedError("write_row must be implemented by FileIO subclasses") + def close(self): pass @@ -265,8 +289,11 @@ def get(path: str, catalog_options: Optional[Options] = None) -> 'FileIO': """ Returns a FileIO instance for accessing the file system identified by the given path. - LocalFileIO for local file system (file:// or no scheme) - - PyArrowFileIO for remote file systems (oss://, s3://, hdfs://, etc.) + - HdfsNativeFileIO for HDFS/ViewFS (default; pure protocol client, no Hadoop install) + - PyArrowFileIO for other remote file systems (oss://, s3://, gs://, ...), + and for HDFS when explicitly requested via hdfs.client.impl=pyarrow """ + import os as _os from urllib.parse import urlparse uri = urlparse(path) @@ -276,5 +303,39 @@ def get(path: str, catalog_options: Optional[Options] = None) -> 'FileIO': from pypaimon.filesystem.local_file_io import LocalFileIO return LocalFileIO(path, catalog_options) + opts = catalog_options or Options({}) + + if scheme in ("hdfs", "viewfs"): + from pypaimon.common.options.config import HdfsOptions + impl_source = "hdfs.client.impl option" + # Treat an empty option value the same as "unset" so callers can + # blank it out (common in templated configs) without tripping + # the unsupported-impl branch. 
+ impl_value = opts.to_map().get(HdfsOptions.HDFS_CLIENT_IMPL.key()) + if not impl_value: + impl_value = _os.environ.get("PYPAIMON_HDFS_IMPL") + impl_source = "PYPAIMON_HDFS_IMPL env var" + if not impl_value: + impl_value = HdfsOptions.HDFS_CLIENT_IMPL.default_value() + impl_source = "default" + impl = impl_value.lower() + if impl == "native": + try: + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + return HdfsNativeFileIO(path, opts) + except (ImportError, RuntimeError) as e: + fallback = opts.get(HdfsOptions.HDFS_CLIENT_FALLBACK_TO_PYARROW) + if not fallback: + raise + logging.getLogger(__name__).warning( + "Native HDFS backend init failed, falling back to " + "pyarrow: %s", e, + ) + elif impl != "pyarrow": + raise ValueError( + f"Unsupported hdfs.client.impl '{impl_value}' " + f"(from {impl_source}). Supported: 'native', 'pyarrow'." + ) + from pypaimon.filesystem.pyarrow_file_io import PyArrowFileIO - return PyArrowFileIO(path, catalog_options or Options({})) + return PyArrowFileIO(path, opts) diff --git a/paimon-python/pypaimon/common/merge_engine_dispatch.py b/paimon-python/pypaimon/common/merge_engine_dispatch.py new file mode 100644 index 000000000000..a9096fd49204 --- /dev/null +++ b/paimon-python/pypaimon/common/merge_engine_dispatch.py @@ -0,0 +1,167 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Centralised merge-engine dispatch. + +Both the read path (``MergeFileSplitRead``) and the write path +(``KeyValueDataWriter``'s in-memory merge buffer) need to pick a +``MergeFunction`` based on the table's ``merge-engine`` option. This +module is the single source of truth so the two sides cannot drift. +""" + +from typing import List, Optional + +from pypaimon.common.options.core_options import MergeEngine +from pypaimon.read.reader.deduplicate_merge_function import \ + DeduplicateMergeFunction +from pypaimon.read.reader.first_row_merge_function import \ + FirstRowMergeFunction +from pypaimon.read.reader.partial_update_merge_function import \ + PartialUpdateMergeFunction + + +# Boolean-valued options that, when truthy, opt the table into +# behaviour the pypaimon PartialUpdateMergeFunction does not yet +# implement. Setting any of these forces the dispatch to refuse the +# write instead of running the simple last-non-null merge silently. +_PARTIAL_UPDATE_UNSUPPORTED_BOOLEAN_OPTIONS = ( + "ignore-delete", + "partial-update.ignore-delete", + "first-row.ignore-delete", + "deduplicate.ignore-delete", + "partial-update.remove-record-on-delete", + "partial-update.remove-record-on-sequence-group", +) +_FIELDS_PREFIX = "fields." 
+_FIELD_SEQUENCE_GROUP_SUFFIX = ".sequence-group" +_FIELD_AGGREGATE_FUNCTION_SUFFIX = ".aggregate-function" +_DEFAULT_AGGREGATE_FUNCTION_KEY = "fields.default-aggregate-function" + +# Mirror ``CoreOptions.ignore_delete()``: any of these keys, if set to +# ``"true"``, opts the engine into silently dropping +# DELETE/UPDATE_BEFORE records. Kept as a raw-option lookup here so the +# dispatch stays table-agnostic. +_IGNORE_DELETE_KEYS = ( + "ignore-delete", + "first-row.ignore-delete", + "deduplicate.ignore-delete", + "partial-update.ignore-delete", +) + + +def build_merge_function( + *, + engine: MergeEngine, + raw_options: dict, + key_arity: int, + value_arity: int, + value_field_nullables: List[bool], + value_field_names: Optional[List[str]] = None, +): + """Pick the MergeFunction for the table's ``merge-engine`` option. + + ``engine`` and ``raw_options`` come from the table's ``CoreOptions`` + (typically ``table.options.merge_engine()`` and + ``table.options.options.to_map()``). ``key_arity`` / ``value_arity`` + / ``value_field_nullables`` describe the value-side schema the + caller wants the merge function to operate on -- for the read path + this is the projected read schema, for the write path it's the full + table schema (minus primary keys). + + ``value_field_names`` is optional and only used by + ``PartialUpdateMergeFunction`` to surface the offending field name + when a NOT NULL constraint is violated; pass ``None`` if the caller + doesn't have names handy. + """ + if engine == MergeEngine.DEDUPLICATE: + return DeduplicateMergeFunction() + if engine == MergeEngine.PARTIAL_UPDATE: + unsupported = partial_update_unsupported_options(raw_options) + if unsupported: + raise NotImplementedError( + "merge-engine 'partial-update' is enabled together with " + "options that pypaimon does not yet implement: {}. 
The " + "supported subset is per-key last-non-null merge with " + "no sequence-group, no per-field aggregator override, " + "no ignore-delete and no partial-update.remove-record-on-* " + "flags. Open an issue to track Python support.".format( + ", ".join(sorted(unsupported)) + ) + ) + return PartialUpdateMergeFunction( + key_arity=key_arity, + value_arity=value_arity, + nullables=list(value_field_nullables), + value_field_names=( + list(value_field_names) + if value_field_names is not None else None), + ) + if engine == MergeEngine.FIRST_ROW: + return FirstRowMergeFunction( + ignore_delete=_ignore_delete_from_options(raw_options), + ) + raise NotImplementedError( + "merge-engine '{}' is not implemented in pypaimon yet " + "(supported: deduplicate, first-row, partial-update). Open an " + "issue to track support.".format(engine.value) + ) + + +def _ignore_delete_from_options(raw_options: dict) -> bool: + for key in _IGNORE_DELETE_KEYS: + val = raw_options.get(key) + if val is not None: + return _option_is_truthy(val) + return False + + +def partial_update_unsupported_options(raw_options: dict): + """Return the set of option keys this table sets that + ``PartialUpdateMergeFunction`` does not yet support. Empty set + means we can safely run the simple last-non-null merge. + """ + flagged = set() + for key, value in raw_options.items(): + if (key in _PARTIAL_UPDATE_UNSUPPORTED_BOOLEAN_OPTIONS + and _option_is_truthy(value)): + flagged.add(key) + elif key == _DEFAULT_AGGREGATE_FUNCTION_KEY: + flagged.add(key) + elif key.startswith(_FIELDS_PREFIX) and ( + key.endswith(_FIELD_SEQUENCE_GROUP_SUFFIX) + or key.endswith(_FIELD_AGGREGATE_FUNCTION_SUFFIX)): + flagged.add(key) + return flagged + + +def _option_is_truthy(raw): + """Strict ``"true"`` boolean parsing for table-option strings. + + A string is truthy iff it equals ``"true"`` (case-insensitive). 
+ ``"yes"``, ``"on"``, ``"1"`` and similar Python-truthy strings are + treated as falsey, matching the table-option parser used elsewhere + in Paimon so an option string the rest of the toolchain treats as + ``false`` is not silently elevated to ``true`` here. + """ + if raw is None: + return False + if isinstance(raw, bool): + return raw + if isinstance(raw, str): + return raw.strip().lower() == "true" + return False diff --git a/paimon-python/pypaimon/common/options/config.py b/paimon-python/pypaimon/common/options/config.py index 465671725d40..0004cb896a89 100644 --- a/paimon-python/pypaimon/common/options/config.py +++ b/paimon-python/pypaimon/common/options/config.py @@ -59,6 +59,11 @@ class GcsOptions: .with_description("GCP project ID for GCS requests.")) +class JdbcCatalogOptions: + CATALOG_KEY = ConfigOptions.key("catalog-key").string_type().default_value("jdbc").with_description( + "Custom JDBC catalog store key.") + + class PVFSOptions: CACHE_ENABLED = ConfigOptions.key("cache-enabled").boolean_type().default_value("true").with_description( "Enable cache") @@ -101,9 +106,46 @@ class CatalogOptions: PREFIX = ConfigOptions.key("prefix").string_type().no_default_value().with_description("Prefix") HTTP_USER_AGENT_HEADER = ConfigOptions.key( "header.HTTP_USER_AGENT").string_type().no_default_value().with_description("HTTP User Agent header") + SYNC_ALL_PROPERTIES = ConfigOptions.key("sync-all-properties").boolean_type().default_value(True).with_description( + "Sync all table properties to the catalog metastore") BLOB_FILE_IO_DEFAULT_CACHE_SIZE = 2 ** 31 - 1 +class HdfsOptions: + HDFS_CLIENT_IMPL = ( + ConfigOptions.key("hdfs.client.impl") + .string_type() + .default_value("native") + .with_description( + "HDFS FileIO backend. Supported values: 'native' (default, uses " + "hdfs-native protocol client, no Hadoop install required), " + "'pyarrow' (legacy, requires HADOOP_HOME / libhdfs / JVM)." 
+ ) + ) + HDFS_CLIENT_FALLBACK_TO_PYARROW = ( + ConfigOptions.key("hdfs.client.fallback-to-pyarrow") + .boolean_type() + .default_value(True) + .with_description( + "When the native backend fails to initialise (e.g. missing wheel " + "or unsupported platform), fall back to the pyarrow backend " + "instead of raising." + ) + ) + HDFS_CONF_DIR = ( + ConfigOptions.key("hdfs.conf-dir") + .string_type() + .no_default_value() + .with_description( + "Directory containing core-site.xml / hdfs-site.xml that the " + "native client should load. Defaults to $HADOOP_CONF_DIR." + ) + ) + + HDFS_CONFIG_PREFIX = "hdfs.config." + HDFS_NATIVE_CONFIG_KEY_PREFIXES = ("dfs.", "fs.", "hadoop.", "ipc.", "io.") + + class SecurityOptions: KERBEROS_PRINCIPAL = ( ConfigOptions.key("security.kerberos.login.principal") diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py index b2f0e019166a..874b888faf42 100644 --- a/paimon-python/pypaimon/common/options/core_options.py +++ b/paimon-python/pypaimon/common/options/core_options.py @@ -16,14 +16,16 @@ # under the License. import sys +import warnings from datetime import timedelta from enum import Enum -from typing import Dict, Optional +from typing import Dict, List, Optional from pypaimon.common.memory_size import MemorySize from pypaimon.common.options import Options from pypaimon.common.options.config_option import ConfigOption from pypaimon.common.options.config_options import ConfigOptions +from pypaimon.common.options.options_utils import OptionsUtils class ExternalPathStrategy(str, Enum): @@ -33,6 +35,8 @@ class ExternalPathStrategy(str, Enum): NONE = "none" ROUND_ROBIN = "round-robin" SPECIFIC_FS = "specific-fs" + ENTROPY_INJECT = "entropy-inject" + WEIGHTED = "weight-robin" class ChangelogProducer(str, Enum): @@ -55,6 +59,37 @@ class MergeEngine(str, Enum): FIRST_ROW = "first-row" +class SortOrder(str, Enum): + """ + Specifies the order of ``sequence.field``. 
Mirrors Java + ``CoreOptions.SortOrder``. + """ + ASCENDING = "ascending" + DESCENDING = "descending" + + +class StartupMode(str, Enum): + """ + Startup mode for scan operations. + """ + DEFAULT = "default" + LATEST_FULL = "latest-full" + FULL = "full" + LATEST = "latest" + COMPACTED_FULL = "compacted-full" + FROM_TIMESTAMP = "from-timestamp" + FROM_SNAPSHOT = "from-snapshot" + FROM_SNAPSHOT_FULL = "from-snapshot-full" + FROM_CREATION_TIMESTAMP = "from-creation-timestamp" + FROM_FILE_CREATION_TIME = "from-file-creation-time" + INCREMENTAL = "incremental" + + +class GlobalIndexColumnUpdateAction(str, Enum): + THROW_ERROR = "THROW_ERROR" + DROP_PARTITION_INDEX = "DROP_PARTITION_INDEX" + + class CoreOptions: """Core options for Paimon tables.""" # File format constants @@ -64,6 +99,8 @@ class CoreOptions: FILE_FORMAT_BLOB: str = "blob" FILE_FORMAT_LANCE: str = "lance" FILE_FORMAT_VORTEX: str = "vortex" + FILE_FORMAT_ROW: str = "row" + FILE_FORMAT_MOSAIC: str = "mosaic" # Basic options AUTO_CREATE: ConfigOption[bool] = ( @@ -209,6 +246,13 @@ class CoreOptions: .with_description("Whether to return blob values as serialized BlobDescriptor bytes when reading.") ) + BLOB_FIELD: ConfigOption[str] = ( + ConfigOptions.key("blob-field") + .string_type() + .no_default_value() + .with_description("Comma-separated column names that should be stored as blob type.") + ) + BLOB_DESCRIPTOR_FIELD: ConfigOption[str] = ( ConfigOptions.key("blob-descriptor-field") .string_type() @@ -241,6 +285,31 @@ class CoreOptions: ) ) + BLOB_VIEW_FIELD: ConfigOption[str] = ( + ConfigOptions.key("blob-view-field") + .string_type() + .no_default_value() + .with_description("Comma-separated field names to treat as BLOB view fields.") + ) + + BLOB_VIEW_RESOLVE_ENABLED: ConfigOption[bool] = ( + ConfigOptions.key("blob-view.resolve.enabled") + .boolean_type() + .default_value(True) + .with_description( + "Whether to resolve blob-view-field values from upstream tables at " + "read time. 
Set to false to preserve BlobViewStruct references when " + "forwarding blob view values to another blob-view table." + ) + ) + + VECTOR_FIELD: ConfigOption[str] = ( + ConfigOptions.key("vector-field") + .string_type() + .no_default_value() + .with_description("Comma-separated column names that should be stored as vector type.") + ) + TARGET_FILE_SIZE: ConfigOption[MemorySize] = ( ConfigOptions.key("target-file-size") .memory_type() @@ -276,6 +345,21 @@ class CoreOptions: .with_description("Specify the file name prefix of data files.") ) # Scan options + SCAN_MODE: ConfigOption[StartupMode] = ( + ConfigOptions.key("scan.mode") + .enum_type(StartupMode) + .default_value(StartupMode.DEFAULT) + .with_description( + "Scan startup mode for the table. " + "'default' resolves the actual mode from other scan options. " + "'latest-full' reads the latest snapshot then streams changes. " + "'latest' only streams changes without an initial snapshot. " + "'from-timestamp' reads from a specific timestamp. " + "'from-snapshot' reads from a specific snapshot. " + "'incremental' reads incremental changes between two snapshots/tags." + ) + ) + SCAN_FALLBACK_BRANCH: ConfigOption[str] = ( ConfigOptions.key("scan.fallback-branch") .string_type() @@ -337,6 +421,24 @@ class CoreOptions: ) ) + SCAN_FILE_CREATION_TIME_MILLIS: ConfigOption[int] = ( + ConfigOptions.key("scan.file-creation-time-millis") + .long_type() + .no_default_value() + .with_description( + "After configuring this time, only the data files created after this time will be read." + ) + ) + + SCAN_CREATION_TIME_MILLIS: ConfigOption[int] = ( + ConfigOptions.key("scan.creation-time-millis") + .long_type() + .no_default_value() + .with_description( + "Optional timestamp used in case of 'from-creation-timestamp' scan mode." 
+ ) + ) + SOURCE_SPLIT_TARGET_SIZE: ConfigOption[MemorySize] = ( ConfigOptions.key("source.split.target-size") .memory_type() @@ -375,6 +477,30 @@ class CoreOptions: .with_description("Specify the merge engine for table with primary key. " "Options: deduplicate, partial-update, aggregation, first-row.") ) + + IGNORE_DELETE: ConfigOption[bool] = ( + ConfigOptions.key("ignore-delete") + .boolean_type() + .default_value(False) + .with_description("Whether to ignore delete records.") + ) + + SEQUENCE_FIELD: ConfigOption[str] = ( + ConfigOptions.key("sequence.field") + .string_type() + .no_default_value() + .with_description("The field that generates the sequence number for " + "primary key table, the sequence number determines " + "which data is the most recent.") + ) + + SEQUENCE_FIELD_SORT_ORDER: ConfigOption[SortOrder] = ( + ConfigOptions.key("sequence.field.sort-order") + .enum_type(SortOrder) + .default_value(SortOrder.ASCENDING) + .with_description("Specify the order of sequence.field.") + ) + # Commit options COMMIT_USER_PREFIX: ConfigOption[str] = ( ConfigOptions.key("commit.user-prefix") @@ -436,7 +562,10 @@ class CoreOptions: ConfigOptions.key("data-file.external-paths.strategy") .string_type() .default_value(ExternalPathStrategy.NONE) - .with_description("Strategy for selecting external paths. Options: none, round-robin, specific-fs.") + .with_description( + "Strategy for selecting external paths. " + "Options: none, round-robin, specific-fs, entropy-inject, weight-robin." + ) ) DATA_FILE_EXTERNAL_PATHS_SPECIFIC_FS: ConfigOption[str] = ( @@ -446,6 +575,16 @@ class CoreOptions: .with_description("Specific filesystem for external paths when using specific-fs strategy.") ) + DATA_FILE_EXTERNAL_PATHS_WEIGHTS: ConfigOption[str] = ( + ConfigOptions.key("data-file.external-paths.weights") + .string_type() + .no_default_value() + .with_description( + "Weights for external paths when strategy is weight-robin. 
" + "Format: comma-separated positive integers corresponding to paths in order." + ) + ) + # Global Index options GLOBAL_INDEX_ENABLED: ConfigOption[bool] = ( ConfigOptions.key("global-index.enabled") @@ -457,13 +596,19 @@ class CoreOptions: GLOBAL_INDEX_THREAD_NUM: ConfigOption[int] = ( ConfigOptions.key("global-index.thread-num") .int_type() - .no_default_value() + .default_value(32) .with_description( - "The maximum number of concurrent scanner for global index. " - "By default is the number of processors available." + "The maximum number of concurrent threads for global index I/O. " + "Defaults to 32 for optimal I/O parallelism." ) ) + GLOBAL_INDEX_COLUMN_UPDATE_ACTION: ConfigOption[GlobalIndexColumnUpdateAction] = ( + ConfigOptions.key("global-index.column-update-action") + .enum_type(GlobalIndexColumnUpdateAction) + .default_value(GlobalIndexColumnUpdateAction.THROW_ERROR) + ) + LOCAL_CACHE_ENABLED: ConfigOption[bool] = ( ConfigOptions.key("local-cache.enabled") .boolean_type() @@ -661,6 +806,21 @@ def variant_shredding_schema(self) -> Optional[str]: def blob_descriptor_fields(self, default=None): value = self.options.get(CoreOptions.BLOB_DESCRIPTOR_FIELD, default) + return CoreOptions._parse_field_set(value) + + def blob_view_fields(self, default=None): + value = self.options.get(CoreOptions.BLOB_VIEW_FIELD, default) + return CoreOptions._parse_field_set(value) + + def blob_field(self, default=None): + value = self.options.get(CoreOptions.BLOB_FIELD, default) + return CoreOptions._parse_field_set(value) + + def blob_view_resolve_enabled(self, default=True): + return self.options.get(CoreOptions.BLOB_VIEW_RESOLVE_ENABLED, default) + + @staticmethod + def _parse_field_set(value): if value is None: return set() if isinstance(value, str): @@ -718,6 +878,42 @@ def vector_target_file_size(self, default=None): def data_file_prefix(self, default=None): return self.options.get(CoreOptions.DATA_FILE_PREFIX, default) + def scan_mode(self, default=None): + return 
self.options.get(CoreOptions.SCAN_MODE, default) + + def startup_mode(self) -> 'StartupMode': + """Resolve the effective startup mode, matching Java CoreOptions.startupMode(). + + If scan.mode is DEFAULT, auto-detects from other scan options. + Maps deprecated FULL to LATEST_FULL. + """ + mode = self.scan_mode() + if mode == StartupMode.DEFAULT: + if (self.options.contains(CoreOptions.SCAN_TIMESTAMP_MILLIS) + or self.options.contains(CoreOptions.SCAN_TIMESTAMP)): + return StartupMode.FROM_TIMESTAMP + elif (self.options.contains(CoreOptions.SCAN_SNAPSHOT_ID) + or self.options.contains(CoreOptions.SCAN_TAG_NAME) + or self.options.contains(CoreOptions.SCAN_WATERMARK)): + return StartupMode.FROM_SNAPSHOT + elif self.options.contains(CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP): + return StartupMode.INCREMENTAL + elif self.options.contains(CoreOptions.SCAN_FILE_CREATION_TIME_MILLIS): + return StartupMode.FROM_FILE_CREATION_TIME + elif self.options.contains(CoreOptions.SCAN_CREATION_TIME_MILLIS): + return StartupMode.FROM_CREATION_TIMESTAMP + else: + return StartupMode.LATEST_FULL + elif mode == StartupMode.FULL: + warnings.warn( + "scan.mode 'full' is deprecated, use 'latest-full' instead.", + DeprecationWarning, + stacklevel=2, + ) + return StartupMode.LATEST_FULL + else: + return mode + def scan_fallback_branch(self, default=None): return self.options.get(CoreOptions.SCAN_FALLBACK_BRANCH, default) @@ -754,6 +950,9 @@ def row_tracking_enabled(self, default=None): def data_evolution_enabled(self, default=None): return self.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, default) + def global_index_column_update_action(self, default=None): + return self.options.get(CoreOptions.GLOBAL_INDEX_COLUMN_UPDATE_ACTION, default) + def deletion_vectors_enabled(self, default=None): return self.options.get(CoreOptions.DELETION_VECTORS_ENABLED, default) @@ -763,6 +962,46 @@ def changelog_producer(self, default=None): def merge_engine(self, default=None): return 
self.options.get(CoreOptions.MERGE_ENGINE, default) + def sequence_field(self) -> List[str]: + """User-defined sequence fields, in declaration order. Empty list + when ``sequence.field`` is unset. Mirrors Java + ``CoreOptions.sequenceField()``. + """ + raw = self.options.get(CoreOptions.SEQUENCE_FIELD) + if not raw: + return [] + # Mirror Java ``CoreOptions.sequenceField()`` + # (``Arrays.stream(s.split(',')).map(String::trim)``): Java's + # ``String.split(",")`` drops *trailing* empty segments (so ``'ts,'`` + # yields ``['ts']``) but keeps interior ones, and each segment is + # then trimmed. So an interior empty segment (``'ts,,ts2'``) survives + # as an empty field name that ``check_sequence_field_valid`` rejects, + # while a trailing comma is tolerated. + segments = raw.split(",") + while segments and segments[-1] == "": + segments.pop() + return [name.strip() for name in segments] + + def sequence_field_sort_order_is_ascending(self) -> bool: + """Whether ``sequence.field.sort-order`` is ascending (the default). + Mirrors Java ``CoreOptions.sequenceFieldSortOrderIsAscending()``. 
+ """ + return (self.options.get(CoreOptions.SEQUENCE_FIELD_SORT_ORDER) + == SortOrder.ASCENDING) + + def ignore_delete(self) -> bool: + raw = self.options.to_map() + fallback_keys = ( + "ignore-delete", "first-row.ignore-delete", + "deduplicate.ignore-delete", + "partial-update.ignore-delete", + ) + for key in fallback_keys: + val = raw.get(key) + if val is not None: + return OptionsUtils.convert_to_boolean(val) + return False + def data_file_external_paths(self, default=None): external_paths_str = self.options.get(CoreOptions.DATA_FILE_EXTERNAL_PATHS, default) if not external_paths_str: @@ -775,6 +1014,23 @@ def data_file_external_paths_strategy(self, default=None): def data_file_external_paths_specific_fs(self, default=None): return self.options.get(CoreOptions.DATA_FILE_EXTERNAL_PATHS_SPECIFIC_FS, default) + def data_file_external_paths_weights(self, default=None): + value = self.options.get( + CoreOptions.DATA_FILE_EXTERNAL_PATHS_WEIGHTS, default + ) + if value is None: + return None + parts = value.split(",") + weights = [] + for part in parts: + parsed = int(part.strip()) + if parsed <= 0: + raise ValueError( + f"Weight must be positive, got: {parsed}" + ) + weights.append(parsed) + return weights + def commit_max_retries(self) -> int: return self.options.get(CoreOptions.COMMIT_MAX_RETRIES) diff --git a/paimon-python/pypaimon/common/predicate.py b/paimon-python/pypaimon/common/predicate.py index cb90dd7b1980..540705aa5cdd 100644 --- a/paimon-python/pypaimon/common/predicate.py +++ b/paimon-python/pypaimon/common/predicate.py @@ -299,7 +299,7 @@ def test_by_stats(self, min_v, max_v, literals) -> bool: return not any(min_v == l == max_v for l in literals) def test_by_arrow(self, val, literals) -> bool: - return ~val.isin(literals) + return (~val.isin(literals)) & val.is_valid() class Between(Tester): diff --git a/paimon-python/pypaimon/daft/__init__.py b/paimon-python/pypaimon/daft/__init__.py index e9651dd47744..b854830173d8 100644 --- 
a/paimon-python/pypaimon/daft/__init__.py +++ b/paimon-python/pypaimon/daft/__init__.py @@ -16,9 +16,9 @@ # limitations under the License. ################################################################################ -from pypaimon.daft.daft_paimon import read_paimon, write_paimon +from pypaimon.daft.daft_paimon import explain_paimon_scan, read_paimon, write_paimon -__all__ = ["read_paimon", "write_paimon", "PaimonCatalog", "PaimonTable"] +__all__ = ["explain_paimon_scan", "read_paimon", "write_paimon", "PaimonCatalog", "PaimonTable"] def __getattr__(name): diff --git a/paimon-python/pypaimon/daft/daft_catalog.py b/paimon-python/pypaimon/daft/daft_catalog.py index aaf3f6879d8a..d59df52dac42 100644 --- a/paimon-python/pypaimon/daft/daft_catalog.py +++ b/paimon-python/pypaimon/daft/daft_catalog.py @@ -115,7 +115,9 @@ def _drop_namespace(self, ident: Identifier) -> None: raise NotFoundError(f"Namespace '{db_name}' not found.") from ex def _drop_table(self, ident: Identifier) -> None: - paimon_ident = _to_paimon_ident(ident) + paimon_ident = _to_paimon_table_ident(ident) + if paimon_ident is None: + raise NotFoundError(f"Table '{ident}' not found.") try: self._inner.drop_table(paimon_ident, ignore_if_not_exists=False) except TableNotExistException as ex: @@ -134,7 +136,9 @@ def _has_namespace(self, ident: Identifier) -> bool: return False def _has_table(self, ident: Identifier) -> bool: - paimon_ident = _to_paimon_ident(ident) + paimon_ident = _to_paimon_table_ident(ident) + if paimon_ident is None: + return False try: self._inner.get_table(paimon_ident) return True @@ -149,7 +153,9 @@ def _get_function(self, ident: Identifier) -> Function: raise NotFoundError(f"Function '{ident}' not found in catalog '{self.name}'") def _get_table(self, ident: Identifier) -> PaimonTable: - paimon_ident = _to_paimon_ident(ident) + paimon_ident = _to_paimon_table_ident(ident) + if paimon_ident is None: + raise NotFoundError(f"Table '{ident}' not found.") try: inner = 
self._inner.get_table(paimon_ident) return PaimonTable(inner, catalog_options=self._catalog_options) @@ -216,6 +222,29 @@ def read(self, **options: Any) -> DataFrame: Table._validate_options("Paimon read", options, set()) return _read_table(self._inner, catalog_options=self._catalog_options) + def explain_scan( + self, + *, + filters: Any = None, + partition_filters: Any = None, + columns: list[str] | None = None, + limit: int | None = None, + io_config=None, + verbose: bool = False, + ) -> Any: + from pypaimon.daft.daft_paimon import _explain_table + + return _explain_table( + self._inner, + catalog_options=self._catalog_options, + filters=filters, + partition_filters=partition_filters, + columns=columns, + limit=limit, + io_config=io_config, + verbose=verbose, + ) + def append(self, df: DataFrame, **options: Any) -> None: from pypaimon.daft.daft_paimon import _write_table @@ -250,7 +279,7 @@ def truncate_partitions(self, partitions: list[dict[str, str]]) -> None: def _to_paimon_ident(ident: Identifier) -> str: """Convert a Daft identifier to a pypaimon identifier string. - - 1 part (table,) -> 'table' + - 1 part (namespace/table,) -> 'namespace_or_table' - 2 parts (db, table) -> 'db.table' - 3 parts (catalog, db, table) -> 'db.table' (catalog prefix stripped) """ @@ -264,6 +293,18 @@ def _to_paimon_ident(ident: Identifier) -> str: return ident +def _to_paimon_table_ident(ident: Identifier) -> str | None: + """Convert a Daft table identifier to Paimon's required db.table form.""" + if isinstance(ident, Identifier): + parts = tuple(ident) + if len(parts) == 3: + return f"{parts[1]}.{parts[2]}" + if len(parts) == 2: + return f"{parts[0]}.{parts[1]}" + return None + return ident + + def _cast_large_types(arrow_schema: pa.Schema) -> pa.Schema: """Convert PyArrow schema to be compatible with pypaimon. 
diff --git a/paimon-python/pypaimon/daft/daft_datasink.py b/paimon-python/pypaimon/daft/daft_datasink.py index c019b16a31f3..7e6b871f06dd 100644 --- a/paimon-python/pypaimon/daft/daft_datasink.py +++ b/paimon-python/pypaimon/daft/daft_datasink.py @@ -32,6 +32,93 @@ from pypaimon.table.file_store_table import FileStoreTable +_PaimonIdentifier = tuple[str, str, str | None] + + +def _options_to_dict(options: Any) -> dict[str, Any]: + if options is None: + return {} + if isinstance(options, dict): + return dict(options) + + to_map = getattr(options, "to_map", None) + if callable(to_map): + return dict(to_map()) + + data = getattr(options, "data", None) + if isinstance(data, dict): + return dict(data) + + return {} + + +def _extract_catalog_options(table: FileStoreTable) -> dict[str, Any]: + file_io = getattr(table, "file_io", None) + properties = getattr(file_io, "properties", None) + if properties is None: + properties = getattr(file_io, "catalog_options", None) + return _options_to_dict(properties) + + +def _extract_identifier(table: FileStoreTable) -> _PaimonIdentifier | None: + identifier = getattr(table, "identifier", None) + if identifier is None: + return None + + get_database_name = getattr(identifier, "get_database_name", None) + get_table_name = getattr(identifier, "get_table_name", None) + get_branch_name = getattr(identifier, "get_branch_name", None) + + database_name = ( + get_database_name() + if callable(get_database_name) + else getattr(identifier, "database", None) + ) + table_name = ( + get_table_name() + if callable(get_table_name) + else getattr(identifier, "object", None) + ) + branch_name = ( + get_branch_name() + if callable(get_branch_name) + else getattr(identifier, "branch", None) + ) + if database_name is None or table_name is None: + return None + return database_name, table_name, branch_name + + +def _to_paimon_identifier(identifier: _PaimonIdentifier) -> Any: + database_name, table_name, branch_name = identifier + if branch_name: + from 
pypaimon.common.identifier import Identifier + + return Identifier(database_name, table_name, branch_name) + return f"{database_name}.{table_name}" + + +def _load_table( + catalog_options: dict[str, Any], + table_identifier: _PaimonIdentifier | None, + table_path: str | None, +) -> FileStoreTable: + if catalog_options and table_identifier is not None: + from pypaimon.catalog.catalog_factory import CatalogFactory + + catalog = CatalogFactory.create(catalog_options) + return catalog.get_table(_to_paimon_identifier(table_identifier)) + + if table_path: + from pypaimon.table.file_store_table import FileStoreTable + + return FileStoreTable.from_path(table_path) + + raise RuntimeError( + "Unable to reconstruct Paimon table while deserializing PaimonDataSink." + ) + + class PaimonDataSink(DataSink[list[Any]]): """DataSink for writing data to an Apache Paimon table. @@ -45,14 +132,51 @@ class PaimonDataSink(DataSink[list[Any]]): def __init__(self, table: FileStoreTable, mode: str = "append") -> None: if mode not in ("append", "overwrite"): raise ValueError(f"Only 'append' or 'overwrite' mode is supported, got: {mode!r}") - self._table = table self._mode = mode + self._catalog_options = _extract_catalog_options(table) + self._table_identifier = _extract_identifier(table) + table_path = getattr(table, "table_path", None) + self._table_path = str(table_path) if table_path is not None else None + self._commit_user: str | None = None + self._init_table(table) + + def __getstate__(self) -> dict[str, Any]: + return { + "_mode": self._mode, + "_catalog_options": self._catalog_options, + "_table_identifier": self._table_identifier, + "_table_path": self._table_path, + "_commit_user": self._commit_user, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self._mode = state["_mode"] + self._catalog_options = state["_catalog_options"] + self._table_identifier = state["_table_identifier"] + self._table_path = state["_table_path"] + self._commit_user = state["_commit_user"] 
+ table = _load_table( + self._catalog_options, + self._table_identifier, + self._table_path, + ) + self._init_table(table) + + def _init_table(self, table: FileStoreTable) -> None: + self._table = table from pypaimon.schema.data_types import PyarrowFieldParser self._target_schema: pa.Schema = PyarrowFieldParser.from_paimon_schema(table.fields) self._write_builder = table.new_batch_write_builder() - if mode == "overwrite": + if ( + self._commit_user is not None + and hasattr(self._write_builder, "commit_user") + ): + self._write_builder.commit_user = self._commit_user + else: + self._commit_user = getattr(self._write_builder, "commit_user", None) + if self._mode == "overwrite": self._write_builder.overwrite({}) def name(self) -> str: diff --git a/paimon-python/pypaimon/daft/daft_datasource.py b/paimon-python/pypaimon/daft/daft_datasource.py index 457fae375c57..5308ff017820 100644 --- a/paimon-python/pypaimon/daft/daft_datasource.py +++ b/paimon-python/pypaimon/daft/daft_datasource.py @@ -18,7 +18,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace import logging from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -32,6 +32,12 @@ from daft.recordbatch import RecordBatch from pypaimon.daft.daft_compat import require_file_range_reads +from pypaimon.daft.daft_explain import ( + PaimonReaderSplitExplain, + PaimonScanExplain, + READER_MODE_NATIVE_PARQUET, + READER_MODE_PYPAIMON_FALLBACK, +) from pypaimon.daft.daft_predicate_visitor import convert_filters_to_paimon if TYPE_CHECKING: @@ -39,7 +45,7 @@ from pypaimon.common.predicate import Predicate from pypaimon.manifest.schema.data_file_meta import DataFileMeta - from pypaimon.read.table_read import TableRead + from pypaimon.read.explain import ExplainSplitInfo from pypaimon.read.split import Split from pypaimon.table.file_store_table import FileStoreTable @@ -52,6 +58,8 @@ PAIMON_FILE_FORMAT_ORC = "orc" PAIMON_FILE_FORMAT_AVRO = 
"avro" +_PaimonIdentifier = tuple[str, str, str | None] + @dataclass(frozen=True, slots=True) class _ReadPushdownState: @@ -63,6 +71,94 @@ class _ReadPushdownState: source_limit: int | None +@dataclass(frozen=True, slots=True) +class _ReaderRouting: + reader_mode: str + fallback_reason: str | None + + @property + def use_native_reader(self) -> bool: + return self.reader_mode == READER_MODE_NATIVE_PARQUET + + +def _options_to_dict(options: Any) -> dict[str, Any]: + if options is None: + return {} + if isinstance(options, dict): + return dict(options) + return dict(options.to_map()) + + +def _extract_catalog_options(table: FileStoreTable) -> dict[str, Any]: + # Every FileIO exposes catalog properties via ``properties`` (CachingFileIO + # delegates to its wrapped FileIO), so no per-implementation handling needed. + return _options_to_dict(table.file_io.properties) + + +def _extract_identifier(table: FileStoreTable) -> _PaimonIdentifier | None: + identifier = table.identifier + if identifier is None: + return None + + database_name = identifier.get_database_name() + table_name = identifier.get_table_name() + if database_name is None or table_name is None: + return None + return database_name, table_name, identifier.get_branch_name() + + +def _extract_table_options(table: FileStoreTable) -> dict[str, Any]: + return _options_to_dict(table.schema().options) + + +def _to_paimon_identifier(identifier: _PaimonIdentifier) -> Any: + database_name, table_name, branch_name = identifier + if branch_name: + from pypaimon.common.identifier import Identifier + + return Identifier(database_name, table_name, branch_name) + return f"{database_name}.{table_name}" + + +def _load_table( + catalog_options: dict[str, Any], + table_identifier: _PaimonIdentifier | None, + table_path: str | None, + table_options: dict[str, Any], +) -> FileStoreTable: + if catalog_options and table_identifier is not None: + from pypaimon.catalog.catalog_factory import CatalogFactory + + catalog = 
CatalogFactory.create(catalog_options) + table = catalog.get_table(_to_paimon_identifier(table_identifier)) + elif table_path: + from pypaimon.table.file_store_table import FileStoreTable + + table = FileStoreTable.from_path(table_path) + else: + raise RuntimeError( + "Unable to reconstruct Paimon table while deserializing PaimonDataSource." + ) + + if table_options: + table = table.copy(table_options) + return table + + +def _build_storage_config( + catalog_options: dict[str, Any], + multithreaded_io: bool, +) -> StorageConfig: + from daft import context + from daft.daft import StorageConfig + + from pypaimon.daft.daft_io_config import _convert_paimon_catalog_options_to_io_config + + io_config = _convert_paimon_catalog_options_to_io_config(catalog_options) + io_config = io_config or context.get_context().daft_planning_config.default_io_config + return StorageConfig(multithreaded_io, io_config) + + class _PaimonPKSplitTask(DataSourceTask): """DataSourceTask for PK-table splits that require LSM-tree merge. 
@@ -73,15 +169,27 @@ class _PaimonPKSplitTask(DataSourceTask): def __init__( self, - table_read: TableRead, + table_catalog_options: dict[str, Any], + table_identifier: _PaimonIdentifier | None, + table_path: str | None, + table_options: dict[str, Any], split: Split, schema: Schema, + read_columns: list[str] | None = None, + limit: int | None = None, + predicate: Predicate | None = None, output_columns: list[str] | None = None, blob_column_names: set[str] | None = None, ) -> None: - self._table_read = table_read + self._table_catalog_options = table_catalog_options + self._table_identifier = table_identifier + self._table_path = table_path + self._table_options = table_options self._split = split self._schema = schema + self._read_columns = read_columns + self._limit = limit + self._predicate = predicate self._output_columns = output_columns self._blob_column_names = blob_column_names or set() @@ -90,7 +198,21 @@ def schema(self) -> Schema: return self._schema async def read(self) -> AsyncIterator[RecordBatch]: - reader = self._table_read.to_arrow_batch_reader([self._split]) + table = _load_table( + self._table_catalog_options, + self._table_identifier, + self._table_path, + self._table_options, + ) + read_builder = table.new_read_builder() + if self._read_columns is not None: + read_builder = read_builder.with_projection(self._read_columns) + if self._limit is not None: + read_builder = read_builder.with_limit(self._limit) + if self._predicate is not None: + read_builder = read_builder.with_filter(self._predicate) + + reader = read_builder.new_read().to_arrow_batch_reader([self._split]) for batch in iter(reader.read_next_batch, None): if self._output_columns is not None: batch = batch.select(self._output_columns) @@ -151,9 +273,58 @@ def __init__( storage_config: StorageConfig, catalog_options: dict[str, str], ) -> None: - self._table = table self._storage_config = storage_config - self._catalog_options = catalog_options + self._catalog_options = 
dict(catalog_options or {}) + self._table_catalog_options = { + **_extract_catalog_options(table), + **self._catalog_options, + } + self._table_identifier = _extract_identifier(table) + table_path = getattr(table, "table_path", None) + self._table_path = str(table_path) if table_path is not None else None + self._table_options = _extract_table_options(table) + self._pushed_filters: list[PyExpr] | None = None + self._paimon_predicate: Predicate | None = None + self._remaining_filters: list[PyExpr] | None = None + self._init_table(table) + + def __getstate__(self) -> dict[str, Any]: + return { + "_multithreaded_io": self._storage_config.multithreaded_io, + "_catalog_options": self._catalog_options, + "_table_catalog_options": self._table_catalog_options, + "_table_identifier": self._table_identifier, + "_table_path": self._table_path, + "_table_options": self._table_options, + "_pushed_filters": self._pushed_filters, + "_paimon_predicate": self._paimon_predicate, + "_remaining_filters": self._remaining_filters, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self._catalog_options = state["_catalog_options"] + self._table_catalog_options = state["_table_catalog_options"] + self._table_identifier = state["_table_identifier"] + self._table_path = state["_table_path"] + self._table_options = state["_table_options"] + self._pushed_filters = state.get("_pushed_filters") + self._paimon_predicate = state["_paimon_predicate"] + self._remaining_filters = state["_remaining_filters"] + self._storage_config = _build_storage_config( + self._table_catalog_options, + state["_multithreaded_io"], + ) + + table = _load_table( + self._table_catalog_options, + self._table_identifier, + self._table_path, + self._table_options, + ) + self._init_table(table) + + def _init_table(self, table: FileStoreTable) -> None: + self._table = table from pypaimon.schema.data_types import PyarrowFieldParser @@ -177,7 +348,11 @@ def __init__( else: self._schema = 
Schema.from_pyarrow_schema(pa_schema) - warehouse = catalog_options.get("warehouse", "") + warehouse = ( + self._catalog_options.get("warehouse") + or self._table_catalog_options.get("warehouse") + or "" + ) self._warehouse_scheme = urlparse(warehouse).scheme self._file_format = table.options.file_format().lower() @@ -189,9 +364,6 @@ def __init__( else {} ) - self._paimon_predicate: Predicate | None = None - self._remaining_filters: list[PyExpr] | None = None - @property def name(self) -> str: table_path = getattr(self._table, "table_path", None) @@ -213,6 +385,7 @@ def push_filters(self, filters: list[PyExpr]) -> tuple[list[PyExpr], list[PyExpr """ pushed_filters, remaining_filters, paimon_predicate = convert_filters_to_paimon(self._table, filters) + self._pushed_filters = pushed_filters self._paimon_predicate = paimon_predicate self._remaining_filters = remaining_filters @@ -225,13 +398,17 @@ def push_filters(self, filters: list[PyExpr]) -> tuple[list[PyExpr], list[PyExpr return pushed_filters, remaining_filters - async def get_tasks(self, pushdowns: Pushdowns) -> AsyncIterator[DataSourceTask]: - read_table = self._table + def _read_table_for_scan(self) -> FileStoreTable: if self._has_blob_columns: - read_table = self._table.copy({"blob-as-descriptor": "true"}) + return self._table.copy({"blob-as-descriptor": "true"}) + return self._table - read_builder = read_table.new_read_builder() - read_pushdowns = self._read_pushdown_state(read_table, pushdowns) + def _scan_read_builder( + self, + table: FileStoreTable, + read_pushdowns: _ReadPushdownState, + ) -> Any: + read_builder = table.new_read_builder() if read_pushdowns.requested_columns is not None: read_builder = read_builder.with_projection(read_pushdowns.requested_columns) @@ -246,6 +423,13 @@ async def get_tasks(self, pushdowns: Pushdowns) -> AsyncIterator[DataSourceTask] read_pushdowns.planning_predicate, ) + return read_builder + + async def get_tasks(self, pushdowns: Pushdowns) -> 
AsyncIterator[DataSourceTask]: + read_table = self._read_table_for_scan() + read_pushdowns = self._read_pushdown_state(read_table, pushdowns) + read_builder = self._scan_read_builder(read_table, read_pushdowns) + if self._table.partition_keys and pushdowns.partition_filters is None: logger.warning( "%s has partition keys %s but no partition filter was specified. " @@ -256,34 +440,21 @@ async def get_tasks(self, pushdowns: Pushdowns) -> AsyncIterator[DataSourceTask] plan = read_builder.new_scan().plan() - pv_cache: dict[tuple[Any, ...], RecordBatch | None] = {} + pv_cache: dict[tuple[tuple[str, Any], ...], RecordBatch | None] = {} for split in plan.splits(): - if self._table.partition_keys and pushdowns.partition_filters is not None: - pv_key = tuple(sorted(split.partition.to_dict().items())) - if pv_key not in pv_cache: - pv_cache[pv_key] = self._build_partition_values(split) - pv = pv_cache[pv_key] - if pv is not None and len(pv.filter(ExpressionsProjection([pushdowns.partition_filters]))) == 0: - continue - - _deletion_files = getattr(split, "data_deletion_files", None) - has_deletion_vectors = _deletion_files is not None and any(df is not None for df in _deletion_files) - - can_use_native_reader = ( - self._is_parquet - and not self._has_blob_columns - and (not self._table.is_primary_key_table or split.raw_convertible) - and not has_deletion_vectors + if self._partition_filter_skips_split(split, pushdowns, pv_cache): + continue + + routing = self._reader_routing( + raw_convertible=split.raw_convertible, + has_deletion_vectors=self._split_has_deletion_vectors(split), ) - if can_use_native_reader: + if routing.use_native_reader: pv = None if self._table.partition_keys: - pv_key = tuple(sorted(split.partition.to_dict().items())) - if pv_key not in pv_cache: - pv_cache[pv_key] = self._build_partition_values(split) - pv = pv_cache[pv_key] + pv = self._partition_values(split, pv_cache) for data_file in split.files: file_uri = 
self._build_file_uri(self._data_file_path(data_file)) @@ -297,32 +468,187 @@ async def get_tasks(self, pushdowns: Pushdowns) -> AsyncIterator[DataSourceTask] storage_config=self._storage_config, ) else: - if not self._is_parquet: - reason = "non-parquet format" - elif self._has_blob_columns: - reason = "blob columns present" - elif has_deletion_vectors: - reason = "deletion vectors present" - else: - reason = "LSM merge required" logger.debug( "Split with %d files using pypaimon fallback (%s).", len(split.files), - reason, + routing.fallback_reason, ) yield _PaimonPKSplitTask( - self._fallback_read_builder( - read_table, - read_pushdowns.read_columns, - read_pushdowns.source_limit, - read_pushdowns.reader_predicate, - ).new_read(), + self._table_catalog_options, + self._table_identifier, + self._table_path, + _extract_table_options(read_table), split, self._project_schema(read_pushdowns.task_columns), + read_pushdowns.read_columns, + read_pushdowns.source_limit, + read_pushdowns.reader_predicate, read_pushdowns.task_columns, self._blob_column_names, ) + def explain_scan(self, pushdowns: Pushdowns, verbose: bool = False) -> PaimonScanExplain: + read_table = self._read_table_for_scan() + read_pushdowns = self._read_pushdown_state(read_table, pushdowns) + read_builder = self._scan_read_builder(read_table, read_pushdowns) + + paimon_scan = read_builder.explain(verbose=True) + split_details = paimon_scan.splits or [] + + native_split_count = 0 + native_file_count = 0 + fallback_split_count = 0 + fallback_file_count = 0 + fallback_reasons: dict[str, int] = {} + explained_splits: list[PaimonReaderSplitExplain] | None = [] if verbose else None + pv_cache: dict[tuple[tuple[str, Any], ...], RecordBatch | None] = {} + + for split in split_details: + if self._partition_filter_skips_explain_split(split, pushdowns, pv_cache): + continue + + routing = self._reader_routing( + raw_convertible=split.raw_convertible, + has_deletion_vectors=split.has_deletion_vectors, + ) + if 
routing.use_native_reader: + native_split_count += 1 + native_file_count += split.file_count + else: + fallback_split_count += 1 + fallback_file_count += split.file_count + reason = routing.fallback_reason or "unknown" + fallback_reasons[reason] = fallback_reasons.get(reason, 0) + 1 + + if explained_splits is not None: + explained_splits.append( + PaimonReaderSplitExplain( + partition=split.partition, + bucket=split.bucket, + file_count=split.file_count, + row_count=split.row_count, + file_size=split.file_size, + reader_mode=routing.reader_mode, + fallback_reason=routing.fallback_reason, + file_paths=split.file_paths, + ) + ) + + if not verbose: + paimon_scan = replace(paimon_scan, splits=None) + + pushed_filters, remaining_filters = self._filter_pushdown_explain(pushdowns) + return PaimonScanExplain( + paimon_scan=paimon_scan, + native_parquet_split_count=native_split_count, + native_parquet_file_count=native_file_count, + pypaimon_fallback_split_count=fallback_split_count, + pypaimon_fallback_file_count=fallback_file_count, + fallback_reasons=fallback_reasons, + pushed_filters=pushed_filters, + remaining_filters=remaining_filters, + partition_filters=self._format_partition_filters(pushdowns), + requested_columns=read_pushdowns.requested_columns, + task_columns=read_pushdowns.task_columns, + fallback_read_columns=read_pushdowns.read_columns, + requested_limit=pushdowns.limit, + source_limit=read_pushdowns.source_limit, + limit_pushed=pushdowns.limit is not None and read_pushdowns.source_limit == pushdowns.limit, + splits=explained_splits, + ) + + def _reader_routing( + self, + raw_convertible: bool, + has_deletion_vectors: bool, + ) -> _ReaderRouting: + can_use_native_reader = ( + self._is_parquet + and not self._has_blob_columns + and (not self._table.is_primary_key_table or raw_convertible) + and not has_deletion_vectors + ) + if can_use_native_reader: + return _ReaderRouting(READER_MODE_NATIVE_PARQUET, None) + + if not self._is_parquet: + reason = "non-parquet 
format" + elif self._has_blob_columns: + reason = "blob columns present" + elif has_deletion_vectors: + reason = "deletion vectors present" + else: + reason = "LSM merge required" + return _ReaderRouting(READER_MODE_PYPAIMON_FALLBACK, reason) + + @staticmethod + def _split_has_deletion_vectors(split: Split) -> bool: + deletion_files = getattr(split, "data_deletion_files", None) + return deletion_files is not None and any(df is not None for df in deletion_files) + + def _partition_filter_skips_split( + self, + split: Split, + pushdowns: Pushdowns, + pv_cache: dict[tuple[tuple[str, Any], ...], RecordBatch | None], + ) -> bool: + if not self._table.partition_keys or pushdowns.partition_filters is None: + return False + pv = self._partition_values(split, pv_cache) + return self._partition_filter_skips_values(pv, pushdowns) + + def _partition_filter_skips_explain_split( + self, + split: ExplainSplitInfo, + pushdowns: Pushdowns, + pv_cache: dict[tuple[tuple[str, Any], ...], RecordBatch | None], + ) -> bool: + if not self._table.partition_keys or pushdowns.partition_filters is None: + return False + pv = self._partition_values_from_dict(split.partition, pv_cache) + return self._partition_filter_skips_values(pv, pushdowns) + + @staticmethod + def _partition_filter_skips_values( + partition_values: RecordBatch | None, + pushdowns: Pushdowns, + ) -> bool: + return ( + partition_values is not None + and len(partition_values.filter(ExpressionsProjection([pushdowns.partition_filters]))) == 0 + ) + + def _format_partition_filters(self, pushdowns: Pushdowns) -> list[str]: + if pushdowns.partition_filters is None: + return [] + return self._format_pyexprs([getattr(pushdowns.partition_filters, "_expr", pushdowns.partition_filters)]) + + def _filter_pushdown_explain(self, pushdowns: Pushdowns) -> tuple[list[str], list[str]]: + if self._remaining_filters is not None: + return ( + self._format_pyexprs(self._pushed_filters or []), + self._format_pyexprs(self._remaining_filters), + ) + 
+ if pushdowns.filters is None: + return [], [] + + py_expr = getattr(pushdowns.filters, "_expr", pushdowns.filters) + pushed_filters, remaining_filters, _ = convert_filters_to_paimon(self._table, [py_expr]) + return self._format_pyexprs(pushed_filters), self._format_pyexprs(remaining_filters) + + @staticmethod + def _format_pyexprs(py_exprs: list[PyExpr]) -> list[str]: + from daft.expressions import Expression + + result = [] + for py_expr in py_exprs: + try: + result.append(str(Expression._from_pyexpr(py_expr))) + except Exception: + result.append(str(py_expr)) + return result + def _build_file_uri(self, file_path: str) -> str: """Reconstruct a full URI from a (potentially scheme-stripped) file_path.""" if urlparse(file_path).scheme: @@ -337,10 +663,29 @@ def _data_file_path(data_file: DataFileMeta) -> str: def _build_partition_values(self, split: Split) -> daft.recordbatch.RecordBatch | None: """Build a single-row RecordBatch encoding the partition values for a split.""" + return self._build_partition_values_from_dict(split.partition.to_dict()) + + def _partition_values( + self, + split: Split, + pv_cache: dict[tuple[tuple[str, Any], ...], RecordBatch | None], + ) -> RecordBatch | None: + return self._partition_values_from_dict(split.partition.to_dict(), pv_cache) + + def _partition_values_from_dict( + self, + partition_dict: dict[str, Any], + pv_cache: dict[tuple[tuple[str, Any], ...], RecordBatch | None], + ) -> RecordBatch | None: + pv_key = tuple(sorted(partition_dict.items())) + if pv_key not in pv_cache: + pv_cache[pv_key] = self._build_partition_values_from_dict(partition_dict) + return pv_cache[pv_key] + + def _build_partition_values_from_dict(self, partition_dict: dict[str, Any]) -> daft.recordbatch.RecordBatch | None: if not self._table.partition_keys: return None - partition_dict = split.partition.to_dict() arrays: dict[str, daft.Series] = {} for pfield in self._table.partition_keys_fields: value = partition_dict.get(pfield.name) @@ -412,22 +757,6 @@ 
def _project_schema(self, columns: list[str] | None) -> Schema: [(name, field_map[name].dtype) for name in columns if name in field_map] ) - def _fallback_read_builder( - self, - table: FileStoreTable, - read_columns: list[str] | None, - limit: int | None, - predicate: Predicate | None, - ) -> Any: - read_builder = table.new_read_builder() - if read_columns is not None: - read_builder = read_builder.with_projection(read_columns) - if limit is not None: - read_builder = read_builder.with_limit(limit) - if predicate is not None: - read_builder = read_builder.with_filter(predicate) - return read_builder - def _read_pushdown_state( self, table: FileStoreTable, diff --git a/paimon-python/pypaimon/daft/daft_explain.py b/paimon-python/pypaimon/daft/daft_explain.py new file mode 100644 index 000000000000..6c97f393aea0 --- /dev/null +++ b/paimon-python/pypaimon/daft/daft_explain.py @@ -0,0 +1,160 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +"""Structured explain result for Daft Paimon scans.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pypaimon.read.explain import ExplainResult + + +READER_MODE_NATIVE_PARQUET = "native_parquet" +READER_MODE_PYPAIMON_FALLBACK = "pypaimon_fallback" + + +@dataclass(frozen=True, slots=True) +class PaimonReaderSplitExplain: + partition: dict[str, Any] + bucket: int + file_count: int + row_count: int + file_size: int + reader_mode: str + fallback_reason: str | None + file_paths: list[str] + + +@dataclass(frozen=True, slots=True) +class PaimonScanExplain: + paimon_scan: ExplainResult + + native_parquet_split_count: int = 0 + native_parquet_file_count: int = 0 + pypaimon_fallback_split_count: int = 0 + pypaimon_fallback_file_count: int = 0 + fallback_reasons: dict[str, int] = field(default_factory=dict) + + pushed_filters: list[str] = field(default_factory=list) + remaining_filters: list[str] = field(default_factory=list) + partition_filters: list[str] = field(default_factory=list) + + requested_columns: list[str] | None = None + task_columns: list[str] | None = None + fallback_read_columns: list[str] | None = None + + requested_limit: int | None = None + source_limit: int | None = None + limit_pushed: bool = False + + splits: list[PaimonReaderSplitExplain] | None = None + + @property + def total_split_count(self) -> int: + return self.native_parquet_split_count + self.pypaimon_fallback_split_count + + @property + def total_file_count(self) -> int: + return self.native_parquet_file_count + self.pypaimon_fallback_file_count + + def __str__(self) -> str: + return render_daft_paimon_explain(self) + + +def render_daft_paimon_explain(result: PaimonScanExplain) -> str: + out = [] + out.append("== Daft Paimon Scan ==") + _line(out, "Native Parquet splits", _count_files( + 
result.native_parquet_split_count, + result.native_parquet_file_count, + )) + _line(out, "pypaimon fallback splits", _count_files( + result.pypaimon_fallback_split_count, + result.pypaimon_fallback_file_count, + )) + _line(out, "Fallback reasons", _format_reason_counts(result.fallback_reasons)) + _line(out, "Pushed filters", _format_list(result.pushed_filters)) + _line(out, "Remaining filters", _format_list(result.remaining_filters)) + _line(out, "Partition filters", _format_list(result.partition_filters)) + _line(out, "Requested columns", _format_optional_list(result.requested_columns, "")) + _line(out, "Task columns", _format_optional_list(result.task_columns, "")) + _line(out, "Fallback read columns", _format_optional_list( + result.fallback_read_columns, + "", + )) + _line(out, "Limit", _format_limit(result)) + + if result.splits is not None: + out.append("") + out.append("Splits:") + for index, split in enumerate(result.splits): + suffix = "" if split.fallback_reason is None else " ({})".format(split.fallback_reason) + out.append( + " #{} bucket={} files={} rows={} size={} mode={}{}".format( + index, + split.bucket, + split.file_count, + split.row_count, + split.file_size, + split.reader_mode, + suffix, + ) + ) + + out.append("") + out.append(str(result.paimon_scan).rstrip()) + return "\n".join(out) + + +def _line(out: list[str], key: str, value: str) -> None: + out.append("{:<28} {}".format(key + ":", value)) + + +def _count_files(split_count: int, file_count: int) -> str: + return "{} ({} files)".format(split_count, file_count) + + +def _format_reason_counts(reasons: dict[str, int]) -> str: + if not reasons: + return "" + return ", ".join("{}: {}".format(reason, count) for reason, count in sorted(reasons.items())) + + +def _format_list(values: list[str]) -> str: + if not values: + return "" + return ", ".join(values) + + +def _format_optional_list(values: list[str] | None, empty: str) -> str: + if values is None: + return empty + if not values: + return "[]" 
+ return "[{}]".format(", ".join(values)) + + +def _format_limit(result: PaimonScanExplain) -> str: + if result.requested_limit is None: + return "" + pushed = "pushed" if result.limit_pushed else "not pushed" + source = "" if result.source_limit is None else str(result.source_limit) + return "requested {}, source {} ({})".format(result.requested_limit, source, pushed) diff --git a/paimon-python/pypaimon/daft/daft_paimon.py b/paimon-python/pypaimon/daft/daft_paimon.py index cde1e23d808e..29825fbc11a9 100644 --- a/paimon-python/pypaimon/daft/daft_paimon.py +++ b/paimon-python/pypaimon/daft/daft_paimon.py @@ -20,20 +20,22 @@ Usage:: - from pypaimon.daft import read_paimon, write_paimon + from pypaimon.daft import explain_paimon_scan, read_paimon, write_paimon df = read_paimon("db.table", catalog_options={"warehouse": "/path"}) + explain = explain_paimon_scan("db.table", catalog_options={"warehouse": "/path"}) write_paimon(df, "db.table", catalog_options={"warehouse": "/path"}) """ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from urllib.parse import urlparse if TYPE_CHECKING: import daft + from pypaimon.daft.daft_explain import PaimonScanExplain from pypaimon.table.file_store_table import FileStoreTable @@ -57,19 +59,31 @@ def _enrich_options_with_rest_token( return enriched -def _read_table( +def _time_travel_table( table: FileStoreTable, - catalog_options: Dict[str, str] | None = None, - io_config=None, snapshot_id: int | None = None, tag_name: str | None = None, -) -> "daft.DataFrame": - """Read a Paimon table object into a lazy Daft DataFrame.""" +) -> FileStoreTable: if snapshot_id is not None and tag_name is not None: raise ValueError( "snapshot_id and tag_name cannot be set at the same time" ) + travel_options: dict[str, str] = {} + if snapshot_id is not None: + travel_options["scan.snapshot-id"] = str(snapshot_id) + if tag_name is not None: + 
travel_options["scan.tag-name"] = tag_name + if travel_options: + return table.copy(travel_options) + return table + + +def _source_for_table( + table: FileStoreTable, + catalog_options: Dict[str, str] | None = None, + io_config=None, +): from daft import context, runners from daft.daft import StorageConfig @@ -78,14 +92,6 @@ def _read_table( _convert_paimon_catalog_options_to_io_config, ) - travel_options: dict[str, str] = {} - if snapshot_id is not None: - travel_options["scan.snapshot-id"] = str(snapshot_id) - if tag_name is not None: - travel_options["scan.tag-name"] = tag_name - if travel_options: - table = table.copy(travel_options) - if catalog_options is None: catalog_options = {} @@ -97,10 +103,71 @@ def _read_table( multithreaded_io = runners.get_or_create_runner().name != "ray" storage_config = StorageConfig(multithreaded_io, io_config) - source = PaimonDataSource( + return PaimonDataSource( table, storage_config=storage_config, catalog_options=catalog_options ) - return source.read() + + +def _read_table( + table: FileStoreTable, + catalog_options: Dict[str, str] | None = None, + io_config=None, + snapshot_id: int | None = None, + tag_name: str | None = None, +) -> "daft.DataFrame": + """Read a Paimon table object into a lazy Daft DataFrame.""" + table = _time_travel_table(table, snapshot_id=snapshot_id, tag_name=tag_name) + return _source_for_table(table, catalog_options=catalog_options, io_config=io_config).read() + + +def _normalize_explain_filters(filters: Any) -> tuple[Any, list[Any]]: + if filters is None: + return None, [] + + if isinstance(filters, (list, tuple)): + if not filters: + return None, [] + filter_exprs = list(filters) + combined = filter_exprs[0] + for filter_expr in filter_exprs[1:]: + combined = combined & filter_expr + else: + filter_exprs = [filters] + combined = filters + + return combined, [getattr(filter_expr, "_expr", filter_expr) for filter_expr in filter_exprs] + + +def _explain_table( + table: FileStoreTable, + 
catalog_options: Dict[str, str] | None = None, + io_config=None, + snapshot_id: int | None = None, + tag_name: str | None = None, + filters: Any = None, + partition_filters: Any = None, + columns: list[str] | None = None, + limit: int | None = None, + verbose: bool = False, +) -> "PaimonScanExplain": + """Explain a Paimon table object using Daft's datasource pushdown model.""" + from daft.io.pushdowns import Pushdowns + + table = _time_travel_table(table, snapshot_id=snapshot_id, tag_name=tag_name) + source = _source_for_table(table, catalog_options=catalog_options, io_config=io_config) + filter_expr, filter_pyexprs = _normalize_explain_filters(filters) + partition_filter_expr, _ = _normalize_explain_filters(partition_filters) + if filter_pyexprs: + source.push_filters(filter_pyexprs) + return source.explain_scan( + Pushdowns( + filters=filter_expr, + partition_filters=partition_filter_expr, + columns=columns, + limit=limit, + ), + verbose=verbose, + ) def _write_table( @@ -143,6 +210,11 @@ def read_paimon( Returns: A lazy ``daft.DataFrame`` backed by this Paimon table. """ + if snapshot_id is not None and tag_name is not None: + raise ValueError( + "snapshot_id and tag_name cannot be set at the same time" + ) + from pypaimon.catalog.catalog_factory import CatalogFactory catalog = CatalogFactory.create(catalog_options) @@ -154,6 +226,44 @@ def read_paimon( ) +def explain_paimon_scan( + table_identifier: str, + catalog_options: Dict[str, str], + *, + filters: Any = None, + partition_filters: Any = None, + columns: list[str] | None = None, + limit: int | None = None, + snapshot_id: Optional[int] = None, + tag_name: Optional[str] = None, + io_config=None, + verbose: bool = False, +) -> "PaimonScanExplain": + """Explain a Paimon scan through Daft's reader-routing layer. + + The optional ``filters`` argument accepts a Daft expression or a list of + Daft expressions. Lists are treated as conjunctions, matching how multiple + pushed filters reach Daft datasources. 
+ """ + from pypaimon.catalog.catalog_factory import CatalogFactory + + catalog = CatalogFactory.create(catalog_options) + table = catalog.get_table(table_identifier) + + return _explain_table( + table, + catalog_options=catalog_options, + io_config=io_config, + snapshot_id=snapshot_id, + tag_name=tag_name, + filters=filters, + partition_filters=partition_filters, + columns=columns, + limit=limit, + verbose=verbose, + ) + + def write_paimon( df: "daft.DataFrame", table_identifier: str, diff --git a/paimon-python/pypaimon/daft/daft_predicate_visitor.py b/paimon-python/pypaimon/daft/daft_predicate_visitor.py index 5dfe4bca3d0d..897672f09623 100644 --- a/paimon-python/pypaimon/daft/daft_predicate_visitor.py +++ b/paimon-python/pypaimon/daft/daft_predicate_visitor.py @@ -28,6 +28,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from daft.expressions import Expression from daft.expressions.visitor import PredicateVisitor if TYPE_CHECKING: @@ -36,7 +37,6 @@ from pypaimon.table.file_store_table import FileStoreTable from daft.daft import PyExpr - from daft.expressions import Expression logger = logging.getLogger(__name__) @@ -63,6 +63,43 @@ class _Unsupported: _UNSUPPORTED = _Unsupported() +class _AndSplitter(PredicateVisitor[tuple[Expression, Expression] | None]): + """Returns the direct children only when the expression root is AND.""" + + def visit_and(self, left: Expression, right: Expression) -> tuple[Expression, Expression]: + return left, right + + def _not_and(self, *args: Any) -> None: + return None + + visit_or = _not_and + visit_not = _not_and + visit_equal = _not_and + visit_not_equal = _not_and + visit_less_than = _not_and + visit_less_than_or_equal = _not_and + visit_greater_than = _not_and + visit_greater_than_or_equal = _not_and + visit_between = _not_and + visit_is_in = _not_and + visit_is_null = _not_and + visit_not_null = _not_and + visit_col = _not_and + visit_lit = _not_and + visit_alias = _not_and + visit_cast = _not_and 
+ visit_coalesce = _not_and + visit_function = _not_and + + +def _split_conjuncts(expr: Expression) -> list[Expression]: + children = _AndSplitter().visit(expr) + if children is None: + return [expr] + left, right = children + return _split_conjuncts(left) + _split_conjuncts(right) + + class PaimonPredicateVisitor(PredicateVisitor[Any]): """Tree fold visitor that converts Daft expressions to Paimon predicates. @@ -119,7 +156,22 @@ def visit_or(self, left: Expression, right: Expression) -> Predicate | _Unsuppor return self._builder.or_predicates([left_pred, right_pred]) return _UNSUPPORTED - def visit_not(self, expr: Expression) -> _Unsupported: + def visit_not(self, expr: Expression) -> Predicate | _Unsupported: + predicate = self.visit(expr) + if not self._is_predicate(predicate): + return _UNSUPPORTED + + if predicate.method == "equal": + return self._builder.not_equal(predicate.field, predicate.literals[0]) + if predicate.method == "in": + return self._builder.is_not_in(predicate.field, predicate.literals) + if predicate.method == "between": + return self._builder.not_between(predicate.field, predicate.literals[0], predicate.literals[1]) + if predicate.method == "isNull": + return self._builder.is_not_null(predicate.field) + if predicate.method == "isNotNull": + return self._builder.is_null(predicate.field) + return _UNSUPPORTED # -- Comparison operators -- @@ -272,14 +324,16 @@ def convert_filters_to_paimon( for py_expr in py_filters: expr = Expression._from_pyexpr(py_expr) - predicate = converter.visit(expr) - - if isinstance(predicate, Predicate): - pushed_filters.append(py_expr) - predicates.append(predicate) - else: - remaining_filters.append(py_expr) - logger.debug("Filter %s cannot be pushed down to Paimon", expr) + + for conjunct in _split_conjuncts(expr): + predicate = converter.visit(conjunct) + + if isinstance(predicate, Predicate): + pushed_filters.append(conjunct._expr) + predicates.append(predicate) + else: + 
remaining_filters.append(conjunct._expr) + logger.debug("Filter %s cannot be pushed down to Paimon", conjunct) combined_predicate: Predicate | None = None if predicates: diff --git a/paimon-python/pypaimon/filesystem/_kerberos.py b/paimon-python/pypaimon/filesystem/_kerberos.py new file mode 100644 index 000000000000..00a31ba975eb --- /dev/null +++ b/paimon-python/pypaimon/filesystem/_kerberos.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Shared Kerberos helpers used by HDFS FileIO backends.""" + +import os +import subprocess +from typing import Optional + + +def kerberos_login_from_keytab(principal: str, keytab: str) -> None: + if not os.path.isfile(keytab): + raise FileNotFoundError(f"Kerberos keytab file not found: {keytab}") + if not os.access(keytab, os.R_OK): + raise PermissionError(f"Kerberos keytab file is not readable: {keytab}") + subprocess.run( + ['kinit', '-kt', keytab, principal], + check=True, capture_output=True, text=True, + ) + + +def get_ticket_cache_path() -> Optional[str]: + cc = os.environ.get('KRB5CCNAME') + if cc: + if cc.startswith('FILE:'): + return cc[5:] + return cc + default_path = f'/tmp/krb5cc_{os.getuid()}' + if os.path.exists(default_path): + return default_path + return None diff --git a/paimon-python/pypaimon/filesystem/caching_file_io.py b/paimon-python/pypaimon/filesystem/caching_file_io.py index 778614c56c2c..8aa64e227f19 100644 --- a/paimon-python/pypaimon/filesystem/caching_file_io.py +++ b/paimon-python/pypaimon/filesystem/caching_file_io.py @@ -31,7 +31,7 @@ from collections import OrderedDict from typing import Optional -from pypaimon.common.file_io import FileIO +from pypaimon.common.file_io import FileIO, supports_pread, pread from pypaimon.utils.file_type import FileType @@ -202,6 +202,8 @@ def __init__(self, file_io, file_path: str, cache): self._file_size = -1 self._cache = cache self._pos = 0 + self._io_lock = threading.Lock() + self._remote_supports_pread = None def _get_file_size(self) -> int: if self._file_size == -1: @@ -249,6 +251,28 @@ def read(self, size=-1) -> bytes: self._pos = end return bytes(result) + def read_at(self, nbytes: int, offset: int) -> bytes: + """Position-based read. Does not change the cursor. 
Thread-safe.""" + if nbytes <= 0 or offset >= self._get_file_size(): + return b'' + + end = min(offset + nbytes, self._get_file_size()) + block_size = self._cache.block_size + + first_block = offset // block_size + last_block = (end - 1) // block_size + + result = bytearray() + for bi in range(first_block, last_block + 1): + block_data = self._read_block(bi) + + block_start = bi * block_size + start_in_block = max(offset - block_start, 0) + end_in_block = min(end - block_start, len(block_data)) + result.extend(block_data[start_in_block:end_in_block]) + + return bytes(result) + def _read_block(self, block_index: int) -> bytes: cached = self._cache.get_block(self._file_path, block_index) if cached is not None: @@ -258,13 +282,20 @@ def _read_block(self, block_index: int) -> bytes: offset = block_index * block_size read_size = min(block_size, self._get_file_size() - offset) - stream = self._get_remote_stream() - stream.seek(offset) - data = self._read_fully(stream, read_size) - + data = self._read_remote(offset, read_size) self._cache.put_block(self._file_path, block_index, data) return data + def _read_remote(self, offset: int, size: int) -> bytes: + stream = self._get_remote_stream() + if self._remote_supports_pread is None: + self._remote_supports_pread = supports_pread(stream) + if self._remote_supports_pread: + return pread(stream, size, offset) + with self._io_lock: + stream.seek(offset) + return self._read_fully(stream, size) + def _read_fully(self, stream, size: int) -> bytes: buf = bytearray() remaining = size @@ -356,6 +387,10 @@ def wrap_with_caching_if_needed(file_io, options, cache=None): return file_io return CachingFileIO(file_io, cache, whitelist) + @property + def properties(self): + return self._delegate.properties + def new_input_stream(self, path: str): file_type = FileType.classify(path) if self._cache is None or file_type not in self._whitelist or FileType.is_mutable(path): @@ -411,9 +446,15 @@ def write_lance(self, *args, **kwargs): def 
write_blob(self, *args, **kwargs): return self._delegate.write_blob(*args, **kwargs) + def write_mosaic(self, *args, **kwargs): + return self._delegate.write_mosaic(*args, **kwargs) + def write_vortex(self, *args, **kwargs): return self._delegate.write_vortex(*args, **kwargs) + def write_row(self, *args, **kwargs): + return self._delegate.write_row(*args, **kwargs) + def __getattr__(self, name): return getattr(self._delegate, name) diff --git a/paimon-python/pypaimon/filesystem/hdfs_native_file_io.py b/paimon-python/pypaimon/filesystem/hdfs_native_file_io.py new file mode 100644 index 000000000000..aa68e3b747b6 --- /dev/null +++ b/paimon-python/pypaimon/filesystem/hdfs_native_file_io.py @@ -0,0 +1,698 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""HDFS FileIO backed by the hdfs-native protocol client (no JVM, no libhdfs).""" + +import logging +import os +import xml.etree.ElementTree as ET +from datetime import datetime, timezone +from pathlib import PurePosixPath +from typing import Dict, Optional +from urllib.parse import urlparse + +import pyarrow +import pyarrow.fs as pafs + +from pypaimon.common.file_io import FileIO +from pypaimon.common.options import Options +from pypaimon.common.options.config import HdfsOptions, SecurityOptions +from pypaimon.common.uri_reader import UriReaderFactory +from pypaimon.filesystem import _kerberos +from pypaimon.schema.data_types import AtomicType, DataField, PyarrowFieldParser +from pypaimon.write.blob_format_writer import BlobFormatWriter + + +class _HdfsFileInfo: + """pafs.FileInfo-shaped adapter built from hdfs_native.FileStatus.""" + __slots__ = ('path', 'size', 'type', 'mtime', 'base_name') + + def __init__(self, path: str, size: Optional[int], file_type, mtime): + self.path = path + self.size = size + self.type = file_type + self.mtime = mtime + self.base_name = path.rsplit('/', 1)[-1] if path else '' + + +class _HdfsWriterAdapter: + """File-like wrapper over hdfs_native.FileWriter.""" + + def __init__(self, fw): + self._fw = fw + self._pos = 0 + self._closed = False + + def write(self, buf) -> int: + n = self._fw.write(buf) + if n is None: + n = len(buf) if hasattr(buf, '__len__') else 0 + self._pos += n + return n + + def tell(self) -> int: + return self._pos + + def flush(self): + pass + + def close(self): + if not self._closed: + try: + self._fw.close() + finally: + self._closed = True + + @property + def closed(self) -> bool: + return self._closed + + def writable(self) -> bool: + return True + + def readable(self) -> bool: + return False + + def seekable(self) -> bool: + return False + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close() + return False + + +class _HdfsReaderAdapter: + """File-like wrapper over 
hdfs_native.FileReader. + + Delegates read/seek/tell straight to the underlying reader (which is an + io.RawIOBase subclass with full seek/tell support). The wrapper only + exists so that exiting a `with` block guarantees the underlying handle + is closed — hdfs-native's own FileReader.__exit__ is a no-op. + """ + + def __init__(self, fr): + self._fr = fr + self._closed = False + + def read(self, size: int = -1) -> bytes: + return self._fr.read(-1 if size is None else size) + + def read1(self, size: int = -1) -> bytes: + return self.read(size) + + def seek(self, pos: int, whence: int = 0) -> int: + self._fr.seek(pos, whence) + return self._fr.tell() + + def tell(self) -> int: + return self._fr.tell() + + def close(self): + if self._closed: + return + try: + close = getattr(self._fr, 'close', None) + if close is not None: + close() + finally: + self._closed = True + + @property + def closed(self) -> bool: + return self._closed + + def readable(self) -> bool: + return True + + def writable(self) -> bool: + return False + + def seekable(self) -> bool: + return True + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close() + return False + + +class HdfsNativeFileIO(FileIO): + """HDFS FileIO that speaks the HDFS RPC protocol directly. + + No JVM, no libhdfs, no Hadoop install required. Hadoop xml is still + consumed if present (HADOOP_CONF_DIR or `hdfs.conf-dir` option) for + viewfs mount tables and HA NameNode lists; alternatively the same + key/values can be delivered via the catalog options channel (a REST + catalog can therefore push the cluster wiring with the response). 
+ """ + + NATIVE_KEY_PREFIXES = HdfsOptions.HDFS_NATIVE_CONFIG_KEY_PREFIXES + NS_PREFIX = HdfsOptions.HDFS_CONFIG_PREFIX + + def __init__(self, path: str, catalog_options: Options): + self.properties = catalog_options or Options({}) + self.logger = logging.getLogger(__name__) + self.uri_reader_factory = UriReaderFactory(self.properties) + + scheme, netloc, _ = self.parse_location(path) + if scheme not in {"hdfs", "viewfs"}: + raise ValueError( + f"HdfsNativeFileIO does not support scheme '{scheme}'" + ) + self._scheme = scheme + self._netloc = netloc + + try: + from hdfs_native import Client, WriteOptions + except ImportError as e: + raise ImportError( + "hdfs-native is not installed. " + "Install with: pip install 'pypaimon[hdfs]'" + ) from e + self._WriteOptions = WriteOptions + + self._setup_kerberos() + + config_dir = ( + self.properties.get(HdfsOptions.HDFS_CONF_DIR) + or os.environ.get("HADOOP_CONF_DIR") + ) + hadoop_xml = self._load_hadoop_xml(config_dir) + + config = self._build_config_dict() + self._maybe_inject_viewfs_fallback(scheme, netloc, config, hadoop_xml) + + # Stash for the lazy `filesystem` property (the fsspec/pyarrow facade + # is only built if a caller asks for it). + self._config = config + self._hadoop_xml = hadoop_xml + self._config_dir = config_dir + self._filesystem = None + + client_kwargs = {} + url = self._build_url(scheme, netloc) + if url: + client_kwargs["url"] = url + if config: + client_kwargs["config"] = config + if config_dir: + client_kwargs["config_dir"] = config_dir + + self._client = Client(**client_kwargs) + + def __reduce__(self): + """Pickle support for Ray / multiprocessing. + + hdfs_native.Client is a Rust binding that can't be pickled; rather + than try to serialise live handles, we serialise the constructor + inputs and let workers re-init their own Client. Same pattern + pyarrow.fs.HadoopFileSystem uses. + + Pin the resolved config_dir into the carried options. 
If the + driver resolved it from $HADOOP_CONF_DIR, a worker on a host with + a different env var would otherwise pick up the worker's value + and silently talk to a different cluster. + """ + netloc = self._netloc or "" + path = f"{self._scheme}://{netloc}" + props_map = dict(self.properties.to_map()) + if self._config_dir and not props_map.get( + HdfsOptions.HDFS_CONF_DIR.key() + ): + props_map[HdfsOptions.HDFS_CONF_DIR.key()] = self._config_dir + return (type(self), (path, Options(props_map))) + + @property + def filesystem(self): + """pyarrow.fs.FileSystem facade backed by hdfs_native.fsspec. + + Lazily constructed: FileIO-only call paths + (exists/list_status/new_input_stream/...) never pay the fsspec init + cost; only ds.dataset / open_input_file callers do. + """ + if self._filesystem is None: + import pyarrow.fs as pafs + try: + from hdfs_native.fsspec import ( + HdfsFileSystem, + ViewfsFileSystem, + ) + except ImportError as e: + raise RuntimeError( + "hdfs-native fsspec adapter is required to bridge " + "HdfsNativeFileIO to a pyarrow.fs filesystem; upgrade " + "hdfs-native (>=0.13)." + ) from e + cls = (ViewfsFileSystem if self._scheme == "viewfs" + else HdfsFileSystem) + # Merge xml + overrides so the fsspec instance can connect + # without relying on HADOOP_CONF_DIR (BaseFileSystem.__init__ + # only forwards storage_options to Client, not config_dir). 
+ merged_config = {**self._hadoop_xml, **self._config} + fsspec_fs = cls(host=self._netloc, **merged_config) + self._filesystem = pafs.PyFileSystem( + pafs.FSSpecHandler(fsspec_fs)) + return self._filesystem + + @staticmethod + def parse_location(location: str): + uri = urlparse(location) + if not uri.scheme: + return "file", uri.netloc, os.path.abspath(location) + return uri.scheme, uri.netloc, uri.path + + @staticmethod + def _build_url(scheme: str, netloc: Optional[str]) -> Optional[str]: + if not netloc: + return None + return f"{scheme}://{netloc}" + + @staticmethod + def _load_hadoop_xml(config_dir: Optional[str]) -> Dict[str, str]: + """Parse core-site.xml + hdfs-site.xml from a Hadoop config dir into a + flat {name: value} dict. Returns empty dict if the dir is missing or + unreadable. + + Used only to discover viewfs mount-table state so we can polyfill the + linkFallback mount that hdfs-native requires but libhdfs tolerates. + The final config dir is still handed to hdfs-native for its own + (more complete) xml parsing. 
+ """ + result: Dict[str, str] = {} + if not config_dir or not os.path.isdir(config_dir): + return result + for fname in ("core-site.xml", "hdfs-site.xml"): + path = os.path.join(config_dir, fname) + if not os.path.isfile(path): + continue + try: + tree = ET.parse(path) + except (ET.ParseError, OSError): + continue + for prop in tree.getroot().findall("property"): + name_el = prop.find("name") + value_el = prop.find("value") + if name_el is None or name_el.text is None: + continue + value = ( + value_el.text.strip() + if value_el is not None and value_el.text + else "" + ) + result[name_el.text.strip()] = value + return result + + @staticmethod + def _maybe_inject_viewfs_fallback( + scheme: str, + netloc: Optional[str], + overrides: Dict[str, str], + hadoop_xml: Dict[str, str], + ) -> None: + """If we're opening a viewfs URI and no linkFallback is configured for + the cluster, pick a usable nameservice URI from existing link.* + targets or dfs.nameservices and inject one into `overrides`. + + hdfs-native rejects viewfs init without a fallback mount; libhdfs + tolerates it. This bridges the gap without touching cluster xml. + + The mount-table state is read from the merged view of hadoop xml and + catalog-option overrides, so a zero-file viewfs setup (link.* / + dfs.nameservices pushed purely through catalog options) gets a + fallback too; the injected key is still only written back to + `overrides`. + """ + if scheme != "viewfs" or not netloc: + return + cluster = netloc + fallback_key = f"fs.viewfs.mounttable.{cluster}.linkFallback" + if fallback_key in overrides or fallback_key in hadoop_xml: + return + + merged = {**hadoop_xml, **overrides} + + link_prefix = f"fs.viewfs.mounttable.{cluster}.link." 
+ for key, value in merged.items(): + if key.startswith(link_prefix) and value: + parsed = urlparse(value) + if parsed.scheme == "hdfs" and parsed.netloc: + overrides[fallback_key] = f"hdfs://{parsed.netloc}/" + return + + nameservices = [ + ns.strip() + for ns in merged.get("dfs.nameservices", "").split(",") + if ns.strip() + ] + if nameservices: + overrides[fallback_key] = f"hdfs://{nameservices[0]}/" + + def _setup_kerberos(self): + principal = ( + self.properties.get(SecurityOptions.KERBEROS_PRINCIPAL) + or self.properties.to_map().get("security.principal") + ) + keytab = ( + self.properties.get(SecurityOptions.KERBEROS_KEYTAB) + or self.properties.to_map().get("security.keytab") + ) + if bool(principal) != bool(keytab): + raise ValueError( + "security.kerberos.login.principal and " + "security.kerberos.login.keytab " + "must be both set or both unset" + ) + if principal and keytab: + _kerberos.kerberos_login_from_keytab(principal, keytab) + cache_path = _kerberos.get_ticket_cache_path() + if not cache_path: + raise RuntimeError( + "kinit succeeded but no ticket cache path could be " + "determined. Set the KRB5CCNAME environment variable " + "to specify the cache location." + ) + # hdfs-native's GSSAPI layer reads KRB5CCNAME from the process + # env, which is global state. If a different cache was already + # configured (typically because another HdfsNativeFileIO with + # a different principal lives in the same process), warn — the + # last writer wins and earlier instances will start using the + # new ticket, which is almost certainly not what the caller + # wanted. 
+ existing = os.environ.get("KRB5CCNAME") + existing_stripped = ( + existing[5:] if existing and existing.startswith("FILE:") + else existing + ) + if existing_stripped and existing_stripped != cache_path: + self.logger.warning( + "Overwriting process-global KRB5CCNAME from %r to %r; " + "concurrent HdfsNativeFileIO instances with different " + "Kerberos principals share env state and will clobber " + "each other's ticket caches.", + existing, cache_path, + ) + # Preserve the `FILE:` qualifier if the existing value carried + # it — some GSSAPI tooling distinguishes cache types by prefix. + os.environ["KRB5CCNAME"] = ( + f"FILE:{cache_path}" + if existing and existing.startswith("FILE:") + else cache_path + ) + + def _build_config_dict(self) -> Dict[str, str]: + config: Dict[str, str] = {} + for key, value in self.properties.to_map().items(): + if value is None: + continue + if any(key.startswith(p) for p in self.NATIVE_KEY_PREFIXES): + config[key] = str(value) + elif key.startswith(self.NS_PREFIX): + config[key[len(self.NS_PREFIX):]] = str(value) + return config + + def to_filesystem_path(self, path: str) -> str: + # hdfs-native expects an absolute path within the cluster the Client is + # bound to; passing a full URI makes its Rust-side MountTable::resolve + # treat the string as a relative path (since it doesn't start with '/') + # and prepend the user's home dir, producing nonsense like + # `/user/foo/viewfs://cluster/...`. Strip the matching scheme+authority + # so a plain absolute path reaches the client. 
+ parsed = urlparse(path) + if parsed.scheme in ("hdfs", "viewfs"): + if parsed.scheme == self._scheme and ( + not parsed.netloc or parsed.netloc == self._netloc + ): + return parsed.path or "/" + return path + + def _adapt_status(self, status, fallback_path: str = '') -> _HdfsFileInfo: + path = getattr(status, 'path', None) or fallback_path + is_dir = bool(getattr(status, 'isdir', False)) + length = getattr(status, 'length', 0) + mtime_ms = getattr(status, 'modification_time', None) + mtime = ( + datetime.fromtimestamp(mtime_ms / 1000.0, tz=timezone.utc) + if mtime_ms else None + ) + size = None if is_dir else int(length or 0) + ftype = pafs.FileType.Directory if is_dir else pafs.FileType.File + return _HdfsFileInfo(path, size, ftype, mtime) + + def new_input_stream(self, path: str): + path_str = self.to_filesystem_path(path) + reader = self._client.read(path_str) + return _HdfsReaderAdapter(reader) + + def new_output_stream(self, path: str): + path_str = self.to_filesystem_path(path) + writer = self._client.create( + path_str, + self._WriteOptions(create_parent=True, overwrite=True), + ) + return _HdfsWriterAdapter(writer) + + def get_file_status(self, path: str): + path_str = self.to_filesystem_path(path) + try: + status = self._client.get_file_info(path_str) + except FileNotFoundError: + raise FileNotFoundError(f"File {path} does not exist") + return self._adapt_status(status, path_str) + + def list_status(self, path: str): + path_str = self.to_filesystem_path(path) + return [self._adapt_status(s) for s in self._client.list_status(path_str)] + + def exists(self, path: str) -> bool: + path_str = self.to_filesystem_path(path) + try: + self._client.get_file_info(path_str) + return True + except FileNotFoundError: + return False + + def delete(self, path: str, recursive: bool = False) -> bool: + path_str = self.to_filesystem_path(path) + try: + status = self._client.get_file_info(path_str) + except FileNotFoundError: + return False + if bool(getattr(status, 
'isdir', False)) and not recursive: + if next(iter(self._client.list_status(path_str)), None) is not None: + raise OSError(f"Directory {path} is not empty") + return bool(self._client.delete(path_str, recursive)) + + def mkdirs(self, path: str) -> bool: + path_str = self.to_filesystem_path(path) + try: + status = self._client.get_file_info(path_str) + except FileNotFoundError: + self._client.mkdirs(path_str, create_parent=True) + return True + if bool(getattr(status, 'isdir', False)): + return True + raise FileExistsError(f"Path exists but is not a directory: {path}") + + def rename(self, src: str, dst: str) -> bool: + src_str = self.to_filesystem_path(src) + dst_str = self.to_filesystem_path(dst) + dst_parent = str(PurePosixPath(dst_str).parent) + if dst_parent and dst_parent != '.': + try: + self._client.get_file_info(dst_parent) + except FileNotFoundError: + self._client.mkdirs(dst_parent, create_parent=True) + try: + dst_status = self._client.get_file_info(dst_str) + if not getattr(dst_status, 'isdir', False): + return False + src_name = PurePosixPath(src_str).name + dst_str = str(PurePosixPath(dst_str) / src_name) + try: + self._client.get_file_info(dst_str) + return False + except FileNotFoundError: + pass + except FileNotFoundError: + pass + try: + self._client.rename(src_str, dst_str) + return True + except FileNotFoundError: + return False + except (PermissionError, OSError): + return False + + def write_parquet(self, path: str, data: pyarrow.Table, + compression: str = 'zstd', zstd_level: int = 1, **kwargs): + try: + import pyarrow.parquet as pq + if compression.lower() == 'zstd': + kwargs['compression_level'] = zstd_level + with self.new_output_stream(path) as raw_stream: + stream = pyarrow.PythonFile(raw_stream, mode='wb') + try: + pq.write_table( + data, stream, compression=compression, **kwargs) + finally: + stream.close() + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write Parquet file {path}: {e}") from e + + 
def write_orc(self, path: str, data: pyarrow.Table, + compression: str = 'zstd', zstd_level: int = 1, **kwargs): + try: + import sys + import pyarrow.orc as orc + data = self._cast_time_columns_for_orc(data) + with self.new_output_stream(path) as raw_stream: + stream = pyarrow.PythonFile(raw_stream, mode='wb') + try: + if sys.version_info[:2] == (3, 6): + orc.write_table(data, stream, **kwargs) + else: + orc.write_table( + data, stream, compression=compression, **kwargs) + finally: + stream.close() + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write ORC file {path}: {e}") from e + + def write_avro(self, path: str, data: pyarrow.Table, + avro_schema=None, compression: str = 'zstd', + zstd_level: int = 1, **kwargs): + import fastavro + if avro_schema is None: + avro_schema = PyarrowFieldParser.to_avro_schema(data.schema) + + records_dict = data.to_pydict() + + def record_generator(): + num_rows = len(list(records_dict.values())[0]) + for i in range(num_rows): + record = {} + for col in records_dict.keys(): + value = records_dict[col][i] + if isinstance(value, datetime) and value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + record[col] = value + yield record + + codec_map = { + 'null': 'null', 'deflate': 'deflate', 'snappy': 'snappy', + 'bzip2': 'bzip2', 'xz': 'xz', 'zstandard': 'zstandard', + 'zstd': 'zstandard', + } + codec = codec_map.get(compression.lower()) + if codec is None: + raise ValueError( + f"Unsupported compression '{compression}' for Avro format. " + f"Supported compressions: {', '.join(sorted(codec_map.keys()))}." 
+ ) + if codec == 'zstandard': + kwargs['codec_compression_level'] = zstd_level + with self.new_output_stream(path) as output_stream: + fastavro.writer(output_stream, avro_schema, + record_generator(), codec=codec, **kwargs) + + def write_blob(self, path: str, data: pyarrow.Table, **kwargs): + try: + if data.num_columns != 1: + raise RuntimeError( + f"Blob format only supports a single column, " + f"got {data.num_columns} columns") + field = data.schema[0] + if pyarrow.types.is_large_binary(field.type): + fields = [DataField(0, field.name, AtomicType("BLOB"))] + else: + paimon_type = PyarrowFieldParser.to_paimon_type( + field.type, field.nullable) + fields = [DataField(0, field.name, paimon_type)] + records_dict = data.to_pydict() + num_rows = data.num_rows + field_name = fields[0].name + with self.new_output_stream(path) as output_stream: + writer = BlobFormatWriter(output_stream) + for i in range(num_rows): + writer.write_value(records_dict[field_name][i], + fields, self.uri_reader_factory) + writer.close() + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write blob file {path}: {e}") from e + + def write_lance(self, path: str, data: pyarrow.Table, **kwargs): + # Mirror the remote-scheme writer: lance/vortex talk to the backend + # through their own object_store, so we hand them the URI plus any + # storage options the FileIO exposes rather than routing through the + # native client. Without these two methods, an HDFS table configured + # with file.format=lance/vortex would hit FileIO's NotImplementedError + # now that this class is the default hdfs:// backend. 
+ try: + import lance + + from pypaimon.read.reader.lance_utils import to_lance_specified + file_path_for_lance, storage_options = to_lance_specified(self, path) + + writer = lance.file.LanceFileWriter( + file_path_for_lance, data.schema, + storage_options=storage_options, **kwargs) + try: + for batch in data.to_batches(): + writer.write_batch(batch) + finally: + writer.close() + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write Lance file {path}: {e}") from e + + def write_mosaic(self, path: str, data: pyarrow.Table, **kwargs): + try: + import mosaic + with self.new_output_stream(path) as output_stream: + mosaic.write_table(data, output_stream) + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write Mosaic file {path}: {e}") from e + + def write_vortex(self, path: str, data: pyarrow.Table, **kwargs): + try: + import vortex + from vortex import store + + from pypaimon.read.reader.vortex_utils import to_vortex_specified + file_path_for_vortex, store_kwargs = to_vortex_specified(self, path) + + if store_kwargs: + vortex_store = store.from_url(file_path_for_vortex, **store_kwargs) + vortex_store.write(vortex.array(data)) + else: + from vortex._lib.io import write as vortex_write + vortex_write(vortex.array(data), file_path_for_vortex) + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write Vortex file {path}: {e}") from e + + def close(self): + self._client = None diff --git a/paimon-python/pypaimon/filesystem/local_file_io.py b/paimon-python/pypaimon/filesystem/local_file_io.py index 120f2c6b3aae..d3f5f81f4f4e 100644 --- a/paimon-python/pypaimon/filesystem/local_file_io.py +++ b/paimon-python/pypaimon/filesystem/local_file_io.py @@ -393,6 +393,16 @@ def write_lance(self, path: str, data: pyarrow.Table, **kwargs): self.delete_quietly(path) raise RuntimeError(f"Failed to write Lance file {path}: {e}") from e + def write_mosaic(self, path: str, data: 
pyarrow.Table, **kwargs): + try: + import mosaic + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'wb') as f: + mosaic.write_table(data, f) + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write Mosaic file {path}: {e}") from e + def write_vortex(self, path: str, data: pyarrow.Table, **kwargs): try: import vortex @@ -412,6 +422,27 @@ def write_vortex(self, path: str, data: pyarrow.Table, **kwargs): self.delete_quietly(path) raise RuntimeError(f"Failed to write Vortex file {path}: {e}") from e + def write_row(self, path: str, data: pyarrow.Table, fields=None, zstd_level: int = 1, **kwargs): + try: + from pypaimon.write.writer.format_row_writer import FormatRowWriter + from pypaimon.schema.data_types import PyarrowFieldParser + + if fields is None: + fields = PyarrowFieldParser.to_paimon_schema(data.schema) + + file_path = self._to_file(path) + parent = file_path.parent + if parent and not parent.exists(): + parent.mkdir(parents=True, exist_ok=True) + + with open(file_path, 'wb') as output_stream: + writer = FormatRowWriter(output_stream, fields, zstd_level=zstd_level) + writer.write_table(data) + writer.close() + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write row file {path}: {e}") from e + def write_blob(self, path: str, data: pyarrow.Table, **kwargs): try: if data.num_columns != 1: diff --git a/paimon-python/pypaimon/filesystem/pvfs.py b/paimon-python/pypaimon/filesystem/pvfs.py index 236c00481c83..a56128ca4025 100644 --- a/paimon-python/pypaimon/filesystem/pvfs.py +++ b/paimon-python/pypaimon/filesystem/pvfs.py @@ -18,6 +18,7 @@ import datetime import importlib import logging +import posixpath import time from abc import ABC from dataclasses import dataclass @@ -100,7 +101,12 @@ def __eq__(self, __value: Any) -> bool: def get_actual_path(self, storage_location: str): if self.sub_path: - return '{}/{}'.format(storage_location.rstrip("/"), self.sub_path.lstrip("/")) 
+ normalized_sub = posixpath.normpath(self.sub_path) + if normalized_sub == ".." or normalized_sub.startswith("../") or normalized_sub.startswith("/"): + raise ValueError( + "Path traversal detected: resolved path escapes table storage boundary" + ) + return '{}/{}'.format(storage_location.rstrip("/"), normalized_sub) return storage_location def get_virtual_location(self): @@ -733,6 +739,11 @@ def _extract_pvfs_identifier(self, path: str) -> Optional['PVFSIdentifier']: return None components = [component for component in path_without_protocol.rstrip('/').split('/') if component] + for component in components: + if component == '..' or '\x00' in component: + raise ValueError( + "Invalid path: path traversal components are not allowed" + ) catalog: str = None endpoint: str = self.options.get(CatalogOptions.URI) if len(components) > 0: diff --git a/paimon-python/pypaimon/filesystem/pyarrow_file_io.py b/paimon-python/pypaimon/filesystem/pyarrow_file_io.py index e7315f3effe9..12d3e91b5c74 100644 --- a/paimon-python/pypaimon/filesystem/pyarrow_file_io.py +++ b/paimon-python/pypaimon/filesystem/pyarrow_file_io.py @@ -329,26 +329,13 @@ def _initialize_gcs_fs(self) -> FileSystem: @staticmethod def _kerberos_login_from_keytab(principal: str, keytab: str): - if not os.path.isfile(keytab): - raise FileNotFoundError(f"Kerberos keytab file not found: {keytab}") - if not os.access(keytab, os.R_OK): - raise PermissionError(f"Kerberos keytab file is not readable: {keytab}") - subprocess.run( - ['kinit', '-kt', keytab, principal], - check=True, capture_output=True, text=True - ) + from pypaimon.filesystem import _kerberos + _kerberos.kerberos_login_from_keytab(principal, keytab) @staticmethod def _get_ticket_cache_path() -> Optional[str]: - cc = os.environ.get('KRB5CCNAME') - if cc: - if cc.startswith('FILE:'): - return cc[5:] - return cc - default_path = f'/tmp/krb5cc_{os.getuid()}' - if os.path.exists(default_path): - return default_path - return None + from pypaimon.filesystem 
import _kerberos + return _kerberos.get_ticket_cache_path() def new_input_stream(self, path: str): path_str = self.to_filesystem_path(path) @@ -655,6 +642,15 @@ def write_lance(self, path: str, data: pyarrow.Table, **kwargs): self.delete_quietly(path) raise RuntimeError(f"Failed to write Lance file {path}: {e}") from e + def write_mosaic(self, path: str, data: pyarrow.Table, **kwargs): + try: + import mosaic + with self.new_output_stream(path) as output_stream: + mosaic.write_table(data, output_stream) + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write Mosaic file {path}: {e}") from e + def write_vortex(self, path: str, data: pyarrow.Table, **kwargs): try: import vortex @@ -673,6 +669,21 @@ def write_vortex(self, path: str, data: pyarrow.Table, **kwargs): self.delete_quietly(path) raise RuntimeError(f"Failed to write Vortex file {path}: {e}") from e + def write_row(self, path: str, data: pyarrow.Table, fields=None, zstd_level: int = 1, **kwargs): + try: + from pypaimon.write.writer.format_row_writer import FormatRowWriter + + if fields is None: + fields = PyarrowFieldParser.to_paimon_schema(data.schema) + + with self.new_output_stream(path) as output_stream: + writer = FormatRowWriter(output_stream, fields, zstd_level=zstd_level) + writer.write_table(data) + writer.close() + except Exception as e: + self.delete_quietly(path) + raise RuntimeError(f"Failed to write row file {path}: {e}") from e + def write_blob(self, path: str, data: pyarrow.Table, **kwargs): try: if data.num_columns != 1: diff --git a/paimon-python/pypaimon/globalindex/btree/__init__.py b/paimon-python/pypaimon/globalindex/btree/__init__.py index d01b38f31834..17a6a30cbc4f 100644 --- a/paimon-python/pypaimon/globalindex/btree/__init__.py +++ b/paimon-python/pypaimon/globalindex/btree/__init__.py @@ -19,6 +19,14 @@ from pypaimon.globalindex.btree.btree_index_reader import BTreeIndexReader from pypaimon.globalindex.btree.btree_index_meta import BTreeIndexMeta 
+from pypaimon.globalindex.btree.btree_file_meta_selector import BTreeFileMetaSelector from pypaimon.globalindex.btree.key_serializer import KeySerializer +from pypaimon.globalindex.btree.lazy_filtered_btree_reader import LazyFilteredBTreeReader -__all__ = ['BTreeIndexReader', 'BTreeIndexMeta', 'KeySerializer'] +__all__ = [ + 'BTreeIndexReader', + 'BTreeIndexMeta', + 'BTreeFileMetaSelector', + 'KeySerializer', + 'LazyFilteredBTreeReader', +] diff --git a/paimon-python/pypaimon/globalindex/btree/btree_file_meta_selector.py b/paimon-python/pypaimon/globalindex/btree/btree_file_meta_selector.py new file mode 100644 index 000000000000..3770d842aab0 --- /dev/null +++ b/paimon-python/pypaimon/globalindex/btree/btree_file_meta_selector.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Prunes BTree index files by comparing predicate literals against file min/max keys.""" + +from typing import List, Optional, Tuple + +from pypaimon.globalindex.btree.btree_index_meta import BTreeIndexMeta +from pypaimon.globalindex.btree.key_serializer import KeySerializer +from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta + + +class BTreeFileMetaSelector: + """Selects which BTree index files may contain rows matching a predicate. + + Each select_* method returns: + - None: cannot evaluate (caller must visit all files) + - []: all files pruned (result is empty) + - [...]: only these files need to be visited + """ + + def __init__(self, files: List[Tuple[GlobalIndexIOMeta, BTreeIndexMeta]], + key_serializer: KeySerializer): + self._files = files + self._key_serializer = key_serializer + self._comparator = key_serializer.create_comparator() + + def _filter(self, predicate) -> List[GlobalIndexIOMeta]: + return [io_meta for io_meta, idx_meta in self._files if predicate(idx_meta)] + + def _all_non_null_files(self) -> List[GlobalIndexIOMeta]: + return self._filter(lambda m: not m.only_nulls()) + + def select_is_null(self) -> List[GlobalIndexIOMeta]: + return self._filter(lambda m: m.has_nulls) + + def select_is_not_null(self) -> List[GlobalIndexIOMeta]: + return self._filter(lambda m: not m.only_nulls()) + + def select_equal(self, literal) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + first = self._key_serializer.deserialize(meta.first_key) + last = self._key_serializer.deserialize(meta.last_key) + return self._comparator(literal, first) >= 0 and self._comparator(literal, last) <= 0 + return self._filter(pred) + + def select_not_equal(self, literal) -> Optional[List[GlobalIndexIOMeta]]: + return None + + def select_less_than(self, literal) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + first = self._key_serializer.deserialize(meta.first_key) + return 
self._comparator(first, literal) < 0 + return self._filter(pred) + + def select_less_or_equal(self, literal) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + first = self._key_serializer.deserialize(meta.first_key) + return self._comparator(first, literal) <= 0 + return self._filter(pred) + + def select_greater_than(self, literal) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + last = self._key_serializer.deserialize(meta.last_key) + return self._comparator(last, literal) > 0 + return self._filter(pred) + + def select_greater_or_equal(self, literal) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + last = self._key_serializer.deserialize(meta.last_key) + return self._comparator(last, literal) >= 0 + return self._filter(pred) + + def select_in(self, literals) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + first = self._key_serializer.deserialize(meta.first_key) + last = self._key_serializer.deserialize(meta.last_key) + for lit in literals: + if self._comparator(lit, first) >= 0 and self._comparator(lit, last) <= 0: + return True + return False + return self._filter(pred) + + def select_not_in(self, literals) -> Optional[List[GlobalIndexIOMeta]]: + return None + + def select_between(self, from_v, to_v) -> List[GlobalIndexIOMeta]: + def pred(meta): + if meta.only_nulls(): + return False + first = self._key_serializer.deserialize(meta.first_key) + last = self._key_serializer.deserialize(meta.last_key) + return self._comparator(from_v, last) <= 0 and self._comparator(to_v, first) >= 0 + return self._filter(pred) + + def select_starts_with(self, literal) -> Optional[List[GlobalIndexIOMeta]]: + return None + + def select_ends_with(self, literal) -> Optional[List[GlobalIndexIOMeta]]: + return None + + def select_contains(self, literal) -> Optional[List[GlobalIndexIOMeta]]: + return None + + def select_like(self, literal) 
-> Optional[List[GlobalIndexIOMeta]]: + return None diff --git a/paimon-python/pypaimon/globalindex/btree/btree_index_reader.py b/paimon-python/pypaimon/globalindex/btree/btree_index_reader.py index 8dda2ef5c187..5a27a24ecc91 100644 --- a/paimon-python/pypaimon/globalindex/btree/btree_index_reader.py +++ b/paimon-python/pypaimon/globalindex/btree/btree_index_reader.py @@ -15,22 +15,21 @@ # specific language governing permissions and limitations # under the License. -""" -The BTreeIndexReader implementation for btree index. +"""The BTreeIndexReader implementation for btree index. -This reader provides efficient querying capabilities for B-tree based global indexes, -supporting various predicate operations like equality, range, and null checks. +Synchronous index reader for a single BTree index file. Parallelism across +multiple files is handled by LazyFilteredBTreeReader. """ import struct +import threading import zlib from typing import List, Optional -from pypaimon.common.file_io import FileIO +from pypaimon.common.file_io import FileIO, supports_pread, pread from pypaimon.globalindex.btree.btree_index_meta import BTreeIndexMeta from pypaimon.globalindex.btree.key_serializer import KeySerializer from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta -from pypaimon.globalindex.global_index_reader import FieldRef, GlobalIndexReader from pypaimon.globalindex.global_index_result import GlobalIndexResult from pypaimon.utils.roaring_bitmap import RoaringBitmap64 from pypaimon.globalindex.btree.btree_file_footer import BTreeFileFooter @@ -52,12 +51,10 @@ def _deserialize_row_ids(data: bytes) -> List[int]: return row_ids -class BTreeIndexReader(GlobalIndexReader): - """ - The GlobalIndexReader implementation for btree index. - - This reader provides efficient querying capabilities for B-tree based global indexes, - supporting various predicate operations like equality, range, and null checks. 
+class BTreeIndexReader: + """Synchronous index reader for a single BTree index file. + + Parallelism across multiple files is handled by LazyFilteredBTreeReader. """ FOOTER_ENCODED_LENGTH = 52 @@ -67,51 +64,52 @@ def __init__( key_serializer: KeySerializer, file_io: FileIO, index_path: str, - io_meta: GlobalIndexIOMeta + io_meta: GlobalIndexIOMeta, ): self.key_serializer = key_serializer self.comparator = key_serializer.create_comparator() self.io_meta = io_meta - + # Deserialize index metadata index_meta = BTreeIndexMeta.deserialize(io_meta.metadata) - + if index_meta.first_key is not None: self.min_key = key_serializer.deserialize(index_meta.first_key) self.max_key = key_serializer.deserialize(index_meta.last_key) else: - # This is possible if this btree index file only stores nulls self.min_key = None self.max_key = None - + self.has_nulls = index_meta.has_nulls file_path = (io_meta.external_path if io_meta.external_path else index_path + "/" + io_meta.file_name) self.input_stream = file_io.new_input_stream(file_path) - + self._supports_pread = supports_pread(self.input_stream) + self._io_lock = threading.Lock() + # Lazy-loaded null bitmap self._null_bitmap: Optional[RoaringBitmap64] = None - + self._null_bitmap_lock = threading.Lock() + # Read footer to get index and bloom filter handles self.footer = self._read_footer() - - # Initialize SST file reader (simplified version) + + # Initialize SST file reader self.reader = self._create_sst_reader() - def _read_footer(self) -> BTreeFileFooter: - """ - Read the file footer to get metadata handles. 
+ def _read_from(self, offset: int, length: int) -> bytes: + if self._supports_pread: + return pread(self.input_stream, length, offset) + with self._io_lock: + self.input_stream.seek(offset) + return self.input_stream.read(length) - Returns: - BTreeFileFooter containing index_block_handle and bloom_filter_handle - """ + def _read_footer(self) -> BTreeFileFooter: file_size = self.io_meta.file_size - # Seek to footer position - self.input_stream.seek(file_size - BTreeFileFooter.ENCODED_LENGTH) - footer_data = self.input_stream.read(BTreeFileFooter.ENCODED_LENGTH) - - # Parse footer + footer_data = self._read_from( + file_size - BTreeFileFooter.ENCODED_LENGTH, + BTreeFileFooter.ENCODED_LENGTH) return BTreeFileFooter.read_footer(footer_data) def _create_sst_reader(self) -> SstFileReader: @@ -120,51 +118,39 @@ def comparator(a: bytes, b: bytes) -> int: o2 = self.key_serializer.deserialize(b) return self.comparator(o1, o2) - return SstFileReader(self.input_stream, comparator, self.footer.index_block_handle) + return SstFileReader( + self.input_stream, comparator, self.footer.index_block_handle, + use_pread=self._supports_pread, io_lock=self._io_lock) def _read_null_bitmap(self) -> RoaringBitmap64: - """ - Read the null bitmap from the index file. - - Returns: - RoaringBitmap64 containing null row IDs - """ if self._null_bitmap is not None: return self._null_bitmap - - bitmap = RoaringBitmap64() - - # Read from the null bitmap block handle if available - if self.footer.null_bitmap_handle is not None: - self.input_stream.seek(self.footer.null_bitmap_handle.offset) - data = self.input_stream.read(self.footer.null_bitmap_handle.size + 4) - # Read bitmap data (excluding CRC32) - bitmap_length = len(data) - 4 - bitmap_bytes = data[:bitmap_length] - crc32_value = struct.unpack(' RoaringBitmap64: - """ - Get all non-null row IDs. - - This traverses all data to avoid returning null values, which is very - advantageous in situations where there are many null values. 
- - Returns: - RoaringBitmap64 containing all non-null row IDs - """ if self.min_key is None: return RoaringBitmap64() - return self._range_query(self.min_key, self.max_key, True, True) def _range_query( @@ -174,31 +160,16 @@ def _range_query( from_inclusive: bool, to_inclusive: bool ) -> RoaringBitmap64: - """ - Range query on underlying SST File. - - Args: - from_key: Lower bound key - to_key: Upper bound key - from_inclusive: Whether to include lower bound - to_inclusive: Whether to include upper bound - - Returns: - RoaringBitmap64 containing all qualified row IDs - """ result = RoaringBitmap64() - # Create iterator and seek to start key file_iter = self.reader.create_iterator() file_iter.seek_to(self.key_serializer.serialize(from_key)) - # Iterate through data blocks while True: data_iter = file_iter.read_batch() if data_iter is None: break - # Process entries in current block while data_iter.has_next(): entry = data_iter.__next__() if entry is None: @@ -209,217 +180,81 @@ def _range_query( key = self.key_serializer.deserialize(key_bytes) - # Skip if key equals from_key and from_inclusive is False if not from_inclusive and self.comparator(key, from_key) == 0: continue - # Check if key is beyond the range difference = self.comparator(key, to_key) if difference > 0 or (not to_inclusive and difference == 0): return result - # Add all row IDs for this key row_ids = _deserialize_row_ids(value_bytes) for row_id in row_ids: result.add(row_id) return result - def _is_in_range( - self, - key: object, - from_key: object, - to_key: object, - from_inclusive: bool, - to_inclusive: bool - ) -> bool: - """ - Check if a key falls within the specified range. 
- - Args: - key: The key to check - from_key: Lower bound - to_key: Upper bound - from_inclusive: Whether lower bound is inclusive - to_inclusive: Whether upper bound is inclusive - - Returns: - True if key is in range, False otherwise - """ - if not from_inclusive and self.comparator(key, from_key) == 0: - return False - - cmp_to = self.comparator(key, to_key) - if cmp_to > 0: - return False - if not to_inclusive and cmp_to == 0: - return False - - return True - - def visit_is_not_null(self, field_ref: FieldRef) -> Optional[GlobalIndexResult]: - """ - Visit an is-not-null predicate. - - Nulls are stored separately in null bitmap. - """ - def supplier(): - try: - return self._all_non_null_rows() - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_is_null(self, field_ref: FieldRef) -> Optional[GlobalIndexResult]: - """ - Visit an is-null predicate. - - Nulls are stored separately in null bitmap. - """ - return GlobalIndexResult.create(self._read_null_bitmap) - - def visit_less_than(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a less-than predicate.""" - def supplier(): - try: - return self._range_query(self.min_key, literal, True, False) - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_greater_or_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a greater-or-equal predicate.""" - def supplier(): - try: - return self._range_query(literal, self.max_key, True, True) - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_not_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a not-equal predicate.""" - def supplier(): - try: - result = self._all_non_null_rows() - equal_result = 
self._range_query(literal, literal, True, True) - return RoaringBitmap64.remove_all(result, equal_result) - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_less_or_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a less-or-equal predicate.""" - def supplier(): - try: - return self._range_query(self.min_key, literal, True, True) - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit an equality predicate.""" - def supplier(): - return self._range_query(literal, literal, True, True) - - return GlobalIndexResult.create(supplier) - - def visit_greater_than(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a greater-than predicate.""" - def supplier(): - try: - return self._range_query(literal, self.max_key, False, True) - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_in(self, field_ref: FieldRef, literals: List[object]) -> Optional[GlobalIndexResult]: - """Visit an in predicate.""" - def supplier(): - try: - result = RoaringBitmap64() - for literal in literals: - range_result = self._range_query(literal, literal, True, True) - result = RoaringBitmap64.or_(result, range_result) - return result - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_not_in(self, field_ref: FieldRef, literals: List[object]) -> Optional[GlobalIndexResult]: - """Visit a not-in predicate.""" - def supplier(): - try: - result = self._all_non_null_rows() - in_result = self.visit_in(field_ref, literals).results() - return RoaringBitmap64.remove_all(result, in_result) 
- except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_starts_with(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """ - Visit a starts-with predicate. - - Note: `startsWith` can also be covered by btree index in the future. - """ - def supplier(): - try: - return self._all_non_null_rows() - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_ends_with(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit an ends-with predicate.""" - def supplier(): - try: - return self._all_non_null_rows() - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_contains(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a contains predicate.""" - def supplier(): - try: - return self._all_non_null_rows() - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_like(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - """Visit a like predicate.""" - def supplier(): - try: - return self._all_non_null_rows() - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) - - def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object) -> Optional[GlobalIndexResult]: - """Visit a between predicate.""" - - def supplier(): - try: - return self._range_query(min_v, max_v, True, True) - except Exception as e: - raise RuntimeError("fail to read btree index file.", e) - - return GlobalIndexResult.create(supplier) + def visit_is_not_null(self) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create(self._all_non_null_rows()) + + def 
visit_is_null(self) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create(self._read_null_bitmap()) + + def visit_less_than(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create( + self._range_query(self.min_key, literal, True, False)) + + def visit_greater_or_equal(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create( + self._range_query(literal, self.max_key, True, True)) + + def visit_not_equal(self, literal: object) -> Optional[GlobalIndexResult]: + result = self._all_non_null_rows() + equal_result = self._range_query(literal, literal, True, True) + return GlobalIndexResult.create( + RoaringBitmap64.remove_all(result, equal_result)) + + def visit_less_or_equal(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create( + self._range_query(self.min_key, literal, True, True)) + + def visit_equal(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create( + self._range_query(literal, literal, True, True)) + + def visit_greater_than(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create( + self._range_query(literal, self.max_key, False, True)) + + def visit_in(self, literals: List[object]) -> Optional[GlobalIndexResult]: + result = RoaringBitmap64() + for literal in literals: + range_result = self._range_query(literal, literal, True, True) + result = RoaringBitmap64.or_(result, range_result) + return GlobalIndexResult.create(result) + + def visit_not_in(self, literals: List[object]) -> Optional[GlobalIndexResult]: + result = self._all_non_null_rows() + for literal in literals: + range_result = self._range_query(literal, literal, True, True) + result = RoaringBitmap64.remove_all(result, range_result) + return GlobalIndexResult.create(result) + + def visit_starts_with(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create(self._all_non_null_rows()) + + def 
visit_ends_with(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create(self._all_non_null_rows()) + + def visit_contains(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create(self._all_non_null_rows()) + + def visit_like(self, literal: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create(self._all_non_null_rows()) + + def visit_between(self, min_v: object, max_v: object) -> Optional[GlobalIndexResult]: + return GlobalIndexResult.create( + self._range_query(min_v, max_v, True, True)) def close(self) -> None: - """Close the reader and release resources.""" if self.input_stream is not None: self.input_stream.close() diff --git a/paimon-python/pypaimon/globalindex/btree/lazy_filtered_btree_reader.py b/paimon-python/pypaimon/globalindex/btree/lazy_filtered_btree_reader.py new file mode 100644 index 000000000000..8dc19e7151ab --- /dev/null +++ b/paimon-python/pypaimon/globalindex/btree/lazy_filtered_btree_reader.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""A GlobalIndexReader that manages multiple BTree files with lazy reader creation and file pruning.""" + +import threading +from concurrent.futures import Executor, Future +from typing import Callable, Dict, List, Optional + +from pypaimon.globalindex.btree.btree_file_meta_selector import BTreeFileMetaSelector +from pypaimon.globalindex.btree.btree_index_meta import BTreeIndexMeta +from pypaimon.globalindex.btree.btree_index_reader import BTreeIndexReader +from pypaimon.globalindex.btree.key_serializer import KeySerializer +from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta +from pypaimon.globalindex.global_index_reader import FieldRef, GlobalIndexReader, _completed_future +from pypaimon.globalindex.global_index_result import GlobalIndexResult + + +class LazyFilteredBTreeReader(GlobalIndexReader): + """Manages multiple BTree index files for one range. + + Uses BTreeFileMetaSelector to prune files before I/O, and lazily + creates/caches BTreeIndexReader instances only for files that are + actually visited. 
+ """ + + def __init__( + self, + key_serializer: KeySerializer, + file_io, + index_path: str, + io_metas: List[GlobalIndexIOMeta], + executor: Executor, + ): + self._key_serializer = key_serializer + self._file_io = file_io + self._index_path = index_path + self._executor = executor + + files = [] + for io_meta in io_metas: + idx_meta = BTreeIndexMeta.deserialize(io_meta.metadata) + files.append((io_meta, idx_meta)) + + self._selector = BTreeFileMetaSelector(files, key_serializer) + self._reader_cache: Dict[str, BTreeIndexReader] = {} + self._cache_lock = threading.Lock() + + def _get_or_create_reader(self, meta: GlobalIndexIOMeta) -> BTreeIndexReader: + key = meta.external_path or meta.file_name + with self._cache_lock: + reader = self._reader_cache.get(key) + if reader is not None: + return reader + reader = BTreeIndexReader( + key_serializer=self._key_serializer, + file_io=self._file_io, + index_path=self._index_path, + io_meta=meta, + ) + self._reader_cache[key] = reader + return reader + + def _visit_parallel( + self, + selector_fn: Callable[[], Optional[List[GlobalIndexIOMeta]]], + visitor_fn: Callable[[BTreeIndexReader], Optional[GlobalIndexResult]], + ) -> 'Future[Optional[GlobalIndexResult]]': + selected = selector_fn() + if selected is None: + selected = [io_meta for io_meta, _ in self._selector._files] + if not selected: + return _completed_future(GlobalIndexResult.create_empty()) + + # Single-level submit: reader creation + query in one task. + # This avoids the nested submission deadlock when using a + # semaphore-limited executor. 
+ task_futures: List[Future] = [] + for meta in selected: + task_futures.append(self._executor.submit( + lambda m=meta: visitor_fn(self._get_or_create_reader(m)))) + + # Union all results once every task future is done + all_done: Future = Future() + remaining = [len(task_futures)] + lock = threading.Lock() + + def on_done(_): + with lock: + remaining[0] -= 1 + if remaining[0] == 0: + try: + result: Optional[GlobalIndexResult] = None + for f in task_futures: + current = f.result() + if current is None: + continue + if result is None: + result = current + else: + result = result.or_(current) + all_done.set_result(result if result is not None + else GlobalIndexResult.create_empty()) + except Exception as e: + all_done.set_exception(e) + + for f in task_futures: + f.add_done_callback(on_done) + + return all_done + + # ---- visit methods ------------------------------------------------------- + + def visit_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_equal(literal), + lambda r: r.visit_equal(literal)) + + def visit_not_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_not_equal(literal), + lambda r: r.visit_not_equal(literal)) + + def visit_less_than(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_less_than(literal), + lambda r: r.visit_less_than(literal)) + + def visit_less_or_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_less_or_equal(literal), + lambda r: r.visit_less_or_equal(literal)) + + def visit_greater_than(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_greater_than(literal), + 
lambda r: r.visit_greater_than(literal)) + + def visit_greater_or_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_greater_or_equal(literal), + lambda r: r.visit_greater_or_equal(literal)) + + def visit_is_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_is_null(), + lambda r: r.visit_is_null()) + + def visit_is_not_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_is_not_null(), + lambda r: r.visit_is_not_null()) + + def visit_in(self, field_ref: FieldRef, literals) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_in(literals), + lambda r: r.visit_in(literals)) + + def visit_not_in(self, field_ref: FieldRef, literals) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_not_in(literals), + lambda r: r.visit_not_in(literals)) + + def visit_starts_with(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_starts_with(literal), + lambda r: r.visit_starts_with(literal)) + + def visit_ends_with(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_ends_with(literal), + lambda r: r.visit_ends_with(literal)) + + def visit_contains(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_contains(literal), + lambda r: r.visit_contains(literal)) + + def visit_like(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_like(literal), + lambda r: 
r.visit_like(literal)) + + def visit_between(self, field_ref: FieldRef, from_v, to_v) -> 'Future[Optional[GlobalIndexResult]]': + return self._visit_parallel( + lambda: self._selector.select_between(from_v, to_v), + lambda r: r.visit_between(from_v, to_v)) + + def close(self) -> None: + with self._cache_lock: + for reader in self._reader_cache.values(): + try: + reader.close() + except Exception: + pass + self._reader_cache.clear() diff --git a/paimon-python/pypaimon/globalindex/btree/sst_file_reader.py b/paimon-python/pypaimon/globalindex/btree/sst_file_reader.py index 535561c060f3..b2e2c87775ec 100644 --- a/paimon-python/pypaimon/globalindex/btree/sst_file_reader.py +++ b/paimon-python/pypaimon/globalindex/btree/sst_file_reader.py @@ -21,14 +21,17 @@ Users can call createIterator to create a file iterator and then use seek and read methods to do range queries. -Note that this class is NOT thread-safe. +Thread-safe when the underlying stream supports position-based reads +(PyArrow NativeFile.read_at or os.pread). """ import struct +import threading import zlib from typing import Optional, Callable from typing import BinaryIO +from pypaimon.common.file_io import pread from pypaimon.globalindex.btree.block_handle import BlockHandle from pypaimon.globalindex.btree.block_entry import BlockEntry from pypaimon.globalindex.btree.block_reader import BlockReader, BlockIterator @@ -103,27 +106,38 @@ def read_batch(self) -> Optional[BlockIterator]: class SstFileReader: """ An SST File Reader which serves point queries and range queries. - + Users can call createIterator to create a file iterator and then use seek and read methods to do range queries. - - Note that this class is NOT thread-safe. + + Thread-safe when the underlying stream supports pread, or when an + io_lock is provided for seek+read fallback. 
""" - + def __init__( self, input_stream: BinaryIO, comparator: Callable[[bytes, bytes], int], - index_block_handle: BlockHandle + index_block_handle: BlockHandle, + use_pread: bool = False, + io_lock: Optional[threading.Lock] = None, ): self.comparator = comparator self.input_stream = input_stream + self._supports_pread = use_pread + self._lock = io_lock or threading.Lock() self.index_block = self._read_block(index_block_handle) + def _read_from(self, offset: int, length: int) -> bytes: + if self._supports_pread: + return pread(self.input_stream, length, offset) + with self._lock: + self.input_stream.seek(offset) + return self.input_stream.read(length) + def _read_block(self, block_handle: BlockHandle) -> BlockReader: - self.input_stream.seek(block_handle.offset) # Read block data + 5 bytes trailer (1 byte compression type + 4 bytes CRC32) - block_data = self.input_stream.read(block_handle.size + 5) + block_data = self._read_from(block_handle.offset, block_handle.size + 5) # Parse block trailer (last 5 bytes: 1 byte compression type + 4 bytes CRC32) if len(block_data) < 5: raise ValueError("Block data too short to contain trailer") diff --git a/paimon-python/pypaimon/globalindex/full_text_search.py b/paimon-python/pypaimon/globalindex/full_text_search.py index f80c8a539aa5..2ec41818d061 100644 --- a/paimon-python/pypaimon/globalindex/full_text_search.py +++ b/paimon-python/pypaimon/globalindex/full_text_search.py @@ -17,6 +17,7 @@ """FullTextSearch for performing full-text search on a text column.""" +from concurrent.futures import Future from dataclasses import dataclass from typing import Optional @@ -35,6 +36,7 @@ class FullTextSearch: query_text: str limit: int field_name: str + query_operator: str = "or" def __post_init__(self): if not self.query_text: @@ -43,10 +45,21 @@ def __post_init__(self): raise ValueError(f"Limit must be positive, got: {self.limit}") if not self.field_name: raise ValueError("Field name cannot be null or empty") + query_operator = ( + 
"or" if self.query_operator is None else self.query_operator.strip().lower() + ) + if query_operator not in ("or", "and"): + raise ValueError( + "Query operator must be 'or' or 'and', got: %s" % self.query_operator + ) + self.query_operator = query_operator - def visit(self, visitor: 'GlobalIndexReader') -> Optional['ScoredGlobalIndexResult']: + def visit(self, visitor: 'GlobalIndexReader') -> 'Future[Optional[ScoredGlobalIndexResult]]': """Visit the global index reader with this full-text search.""" return visitor.visit_full_text_search(self) def __repr__(self) -> str: - return f"FullTextSearch(field={self.field_name}, query='{self.query_text}', limit={self.limit})" + return ( + f"FullTextSearch(field={self.field_name}, query='{self.query_text}', " + f"limit={self.limit}, operator={self.query_operator})" + ) diff --git a/paimon-python/pypaimon/globalindex/global_index_evaluator.py b/paimon-python/pypaimon/globalindex/global_index_evaluator.py index f399b84de6cb..ebc05d479114 100644 --- a/paimon-python/pypaimon/globalindex/global_index_evaluator.py +++ b/paimon-python/pypaimon/globalindex/global_index_evaluator.py @@ -19,7 +19,7 @@ import threading from collections import deque -from concurrent.futures import Executor, Future +from concurrent.futures import Future from typing import Callable, Collection, Dict, List, Optional from pypaimon.globalindex.global_index_reader import GlobalIndexReader, FieldRef @@ -28,38 +28,22 @@ from pypaimon.schema.data_types import DataField -class _DirectExecutor(Executor): - """Executor that runs callables in the calling thread.""" - - def submit(self, fn, *args, **kwargs): - f = Future() - try: - result = fn(*args, **kwargs) - f.set_result(result) - except Exception as e: - f.set_exception(e) - return f - - def shutdown(self, wait=True): - pass - - class GlobalIndexEvaluator: - """Predicate evaluator for filtering data using global indexes.""" + """Predicate evaluator for filtering data using global indexes. 
+ + Reader visit methods return Future internally — the evaluator no longer + dispatches to an executor. + """ def __init__( self, fields: List[DataField], readers_function: Callable[[DataField], Collection[GlobalIndexReader]], - executor: Optional[Executor] = None, ): self._fields = fields self._field_by_name = {f.name: f for f in fields} self._readers_function = readers_function self._index_readers_cache: Dict[int, Collection[GlobalIndexReader]] = {} - self._reader_locks: Dict[int, threading.Lock] = {} - self._locks_lock = threading.Lock() - self._executor = executor if executor is not None else _DirectExecutor() def evaluate( self, @@ -92,11 +76,8 @@ def _visit_leaf_async(self, predicate: Predicate) -> Future: reader_futures = [] for reader in readers: - lock = self._get_reader_lock(id(reader)) reader_futures.append( - self._executor.submit( - self._visit_reader, reader, predicate, field_ref, lock - ) + self._visit_function(reader, predicate, field_ref) ) all_done = Future() @@ -123,13 +104,6 @@ def on_done(_): return all_done - def _visit_reader(self, reader, predicate, field_ref, lock): - with lock: - result = self._visit_function(reader, predicate, field_ref) - if result is not None: - result.results() - return result - def _combine_reader_results( self, reader_futures: List[Future] ) -> Optional[GlobalIndexResult]: @@ -209,20 +183,12 @@ def _flatten_children(self, method: str, children) -> list: result.append(child) return result - def _get_reader_lock(self, reader_id: int) -> threading.Lock: - with self._locks_lock: - lock = self._reader_locks.get(reader_id) - if lock is None: - lock = threading.Lock() - self._reader_locks[reader_id] = lock - return lock - def _visit_function( self, reader: GlobalIndexReader, predicate: Predicate, field_ref: FieldRef - ) -> Optional[GlobalIndexResult]: + ) -> 'Future[Optional[GlobalIndexResult]]': method = predicate.method literals = predicate.literals @@ -257,7 +223,8 @@ def _visit_function( elif method == 'between': 
return reader.visit_between(field_ref, literals[0], literals[1]) - return None + from pypaimon.globalindex.global_index_reader import _completed_future + return _completed_future(None) def close(self) -> None: for readers in self._index_readers_cache.values(): diff --git a/paimon-python/pypaimon/globalindex/global_index_reader.py b/paimon-python/pypaimon/globalindex/global_index_reader.py index ada3d2437b16..84825c705607 100644 --- a/paimon-python/pypaimon/globalindex/global_index_reader.py +++ b/paimon-python/pypaimon/globalindex/global_index_reader.py @@ -18,6 +18,7 @@ """Global index reader interface.""" from abc import ABC, abstractmethod +from concurrent.futures import Future from typing import List, Optional @@ -30,82 +31,81 @@ def __init__(self, index: int, name: str, data_type: str): self.data_type = data_type +def _completed_future(value): + """Create a Future that is already completed with the given value.""" + f = Future() + f.set_result(value) + return f + + +def _map_future(source, transform): + """Create a new Future whose result is transform(source.result()).""" + result = Future() + + def on_done(f): + try: + result.set_result(transform(f.result())) + except Exception as e: + result.set_exception(e) + + source.add_done_callback(on_done) + return result + + class GlobalIndexReader(ABC): - """ - Index reader for global index, returns GlobalIndexResult. - - This is the base interface for all global index readers. - """ - - def visit_vector_search(self, vector_search: 'VectorSearch') -> Optional['GlobalIndexResult']: - """Visit a vector search query.""" + """Index reader for global index. 
All visit methods return Future[Optional[GlobalIndexResult]].""" + + def visit_vector_search(self, vector_search: 'VectorSearch') -> 'Future[Optional[GlobalIndexResult]]': raise NotImplementedError("Vector search not supported by this reader") - def visit_full_text_search(self, full_text_search: 'FullTextSearch') -> Optional['GlobalIndexResult']: - """Visit a full-text search query.""" + def visit_full_text_search(self, full_text_search: 'FullTextSearch') -> 'Future[Optional[GlobalIndexResult]]': raise NotImplementedError("Full-text search not supported by this reader") - def visit_equal(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit an equality predicate.""" - return None + def visit_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_not_equal(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a not-equal predicate.""" - return None + def visit_not_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_less_than(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a less-than predicate.""" - return None + def visit_less_than(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_less_or_equal(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a less-or-equal predicate.""" - return None + def visit_less_or_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_greater_than(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a greater-than predicate.""" - return None + def visit_greater_than(self, field_ref: FieldRef, literal: object) -> 
'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_greater_or_equal(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a greater-or-equal predicate.""" - return None + def visit_greater_or_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_is_null(self, field_ref: FieldRef) -> Optional['GlobalIndexResult']: - """Visit an is-null predicate.""" - return None + def visit_is_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_is_not_null(self, field_ref: FieldRef) -> Optional['GlobalIndexResult']: - """Visit an is-not-null predicate.""" - return None + def visit_is_not_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_in(self, field_ref: FieldRef, literals: List[object]) -> Optional['GlobalIndexResult']: - """Visit an in predicate.""" - return None + def visit_in(self, field_ref: FieldRef, literals: List[object]) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_not_in(self, field_ref: FieldRef, literals: List[object]) -> Optional['GlobalIndexResult']: - """Visit a not-in predicate.""" - return None + def visit_not_in(self, field_ref: FieldRef, literals: List[object]) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_starts_with(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a starts-with predicate.""" - return None + def visit_starts_with(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_ends_with(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit an ends-with predicate.""" - return None + def visit_ends_with(self, field_ref: FieldRef, literal: 
object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_contains(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a contains predicate.""" - return None + def visit_contains(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_like(self, field_ref: FieldRef, literal: object) -> Optional['GlobalIndexResult']: - """Visit a like predicate.""" - return None + def visit_like(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) - def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object) -> Optional['GlobalIndexResult']: - """Visit a between predicate.""" - return None + def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object) -> 'Future[Optional[GlobalIndexResult]]': + return _completed_future(None) @abstractmethod def close(self) -> None: - """Close the reader and release resources.""" pass diff --git a/paimon-python/pypaimon/globalindex/global_index_result.py b/paimon-python/pypaimon/globalindex/global_index_result.py index 1332f677cbd9..709e2d43934b 100644 --- a/paimon-python/pypaimon/globalindex/global_index_result.py +++ b/paimon-python/pypaimon/globalindex/global_index_result.py @@ -16,7 +16,7 @@ # under the License. 
from abc import ABC, abstractmethod -from typing import Callable, List, Optional +from typing import List from pypaimon.utils.roaring_bitmap import RoaringBitmap64 from pypaimon.utils.range import Range @@ -38,7 +38,7 @@ def offset(self, start_offset: int) -> 'GlobalIndexResult': offset_bitmap = RoaringBitmap64() for row_id in bitmap: offset_bitmap.add(row_id + start_offset) - return LazyGlobalIndexResult(lambda: offset_bitmap) + return SimpleGlobalIndexResult(offset_bitmap) def and_(self, other: 'GlobalIndexResult') -> 'GlobalIndexResult': """Returns the intersection of this result and the other result.""" @@ -59,31 +59,27 @@ def is_empty(self) -> bool: @staticmethod def create_empty() -> 'GlobalIndexResult': """Returns an empty GlobalIndexResult.""" - return LazyGlobalIndexResult(lambda: RoaringBitmap64()) + return SimpleGlobalIndexResult(RoaringBitmap64()) @staticmethod - def create(supplier: Callable[[], RoaringBitmap64]) -> 'GlobalIndexResult': - """Returns a new GlobalIndexResult from supplier.""" - return LazyGlobalIndexResult(supplier) + def create(bitmap: RoaringBitmap64) -> 'GlobalIndexResult': + """Returns a new GlobalIndexResult wrapping the given bitmap.""" + return SimpleGlobalIndexResult(bitmap) @staticmethod def from_range(range_: Range) -> 'GlobalIndexResult': """Returns a new GlobalIndexResult from Range.""" - def create_bitmap(): - result = RoaringBitmap64() - result.add_range(range_.from_, range_.to) - return result - return LazyGlobalIndexResult(create_bitmap) + result = RoaringBitmap64() + result.add_range(range_.from_, range_.to) + return SimpleGlobalIndexResult(result) @staticmethod def from_ranges(ranges: List[Range]) -> 'GlobalIndexResult': """Returns a new GlobalIndexResult from multiple Ranges.""" - def create_bitmap(): - result = RoaringBitmap64() - for r in ranges: - result.add_range(r.from_, r.to) - return result - return LazyGlobalIndexResult(create_bitmap) + result = RoaringBitmap64() + for r in ranges: + result.add_range(r.from_, 
r.to) + return SimpleGlobalIndexResult(result) class SimpleGlobalIndexResult(GlobalIndexResult): @@ -93,16 +89,3 @@ def __init__(self, result: RoaringBitmap64): def results(self) -> RoaringBitmap64: return self._result - - -class LazyGlobalIndexResult(GlobalIndexResult): - """Lazy implementation of GlobalIndexResult that delays computation.""" - - def __init__(self, supplier: Callable[[], RoaringBitmap64]): - self._supplier = supplier - self._cached: Optional[RoaringBitmap64] = None - - def results(self) -> RoaringBitmap64: - if self._cached is None: - self._cached = self._supplier() - return self._cached diff --git a/paimon-python/pypaimon/globalindex/global_index_scanner.py b/paimon-python/pypaimon/globalindex/global_index_scanner.py index e8e34f641aa9..924d30608e27 100644 --- a/paimon-python/pypaimon/globalindex/global_index_scanner.py +++ b/paimon-python/pypaimon/globalindex/global_index_scanner.py @@ -17,7 +17,6 @@ """Scanner for shard-based global indexes.""" -import os from concurrent.futures import ThreadPoolExecutor from typing import Collection, Optional @@ -43,7 +42,7 @@ def __init__( thread_num: Optional[int] = None, ): self._executor = ThreadPoolExecutor( - max_workers=thread_num or os.cpu_count() or 4 + max_workers=thread_num or 32 ) self._evaluator = self._create_evaluator( fields, file_io, index_path, index_files @@ -76,18 +75,24 @@ def _create_evaluator(self, fields, file_io, index_path, index_files): ) index_metas[field_id][index_type][range_key].append(io_meta) + executor = self._executor + def readers_function(field: DataField) -> Collection[GlobalIndexReader]: - return _create_readers(file_io, index_path, index_metas.get(field.id), field) + return _create_readers(file_io, index_path, index_metas.get(field.id), field, executor) - return GlobalIndexEvaluator(fields, readers_function, self._executor) + return GlobalIndexEvaluator(fields, readers_function) @staticmethod - def create(table, index_files=None, partition_filter=None, predicate=None) -> 
Optional['GlobalIndexScanner']: + def create(table, index_files=None, partition_filter=None, predicate=None, + snapshot=None) -> Optional['GlobalIndexScanner']: """Create a GlobalIndexScanner. Can be called in two ways: 1. create(table, index_files) - with explicit index files - 2. create(table, partition_filter=..., predicate=...) - scan index files from snapshot + 2. create(table, partition_filter=..., predicate=..., snapshot=...) - + scan index files from snapshot. ``snapshot`` may be passed in by the + caller to avoid a duplicate ``get_latest_snapshot`` REST round-trip + (the caller usually already fetched it for manifest scanning). """ from pypaimon.index.index_file_handler import IndexFileHandler @@ -119,7 +124,8 @@ def index_file_filter(entry): return False return global_index_meta.index_field_id in filter_field_ids - snapshot = table.snapshot_manager().get_latest_snapshot() + if snapshot is None: + snapshot = table.snapshot_manager().get_latest_snapshot() index_file_handler = IndexFileHandler(table=table) entries = index_file_handler.scan(snapshot, index_file_filter) scanned_index_files = [entry.index_file for entry in entries] @@ -150,7 +156,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() -def _create_readers(file_io, index_path, index_type_metas, field): +def _create_readers(file_io, index_path, index_type_metas, field, executor=None): """Create readers for a specific field, dispatched by index_type. 
Unknown indexTypes raise — a silent skip would make @@ -172,7 +178,7 @@ def _create_readers(file_io, index_path, index_type_metas, field): offset_readers = [] for range_key, io_metas in range_metas.items(): inner_readers = _create_inner_readers( - index_type, file_io, index_path, field, io_metas) + index_type, file_io, index_path, field, io_metas, executor) for inner in inner_readers: offset_readers.append( OffsetGlobalIndexReader( @@ -182,28 +188,25 @@ def _create_readers(file_io, index_path, index_type_metas, field): return readers -def _create_inner_readers(index_type, file_io, index_path, field, io_metas): +def _create_inner_readers(index_type, file_io, index_path, field, io_metas, executor=None): """Build per-file (or per-shard) readers for a single indexType/range.""" if index_type == 'btree': - from pypaimon.globalindex.btree import BTreeIndexReader + from pypaimon.globalindex.btree.lazy_filtered_btree_reader import LazyFilteredBTreeReader from pypaimon.globalindex.btree.key_serializer import create_serializer key_serializer = create_serializer(field.type) - return [ - BTreeIndexReader( - key_serializer=key_serializer, - file_io=file_io, - index_path=index_path, - io_meta=io_meta, - ) - for io_meta in io_metas - ] + return [LazyFilteredBTreeReader( + key_serializer=key_serializer, + file_io=file_io, + index_path=index_path, + io_metas=io_metas, + executor=executor, + )] from pypaimon.globalindex.tantivy import ( TANTIVY_FULLTEXT_IDENTIFIER, TantivyFullTextGlobalIndexReader, ) if index_type == TANTIVY_FULLTEXT_IDENTIFIER: - # Tantivy expects one file per shard; create one reader per io_meta. 
return [ TantivyFullTextGlobalIndexReader(file_io, index_path, [io_meta]) for io_meta in io_metas diff --git a/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py b/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py index e8e1187ea4e1..fd425b3a130a 100644 --- a/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py +++ b/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py @@ -22,9 +22,11 @@ """ import os +import threading + import numpy as np -from pypaimon.globalindex.global_index_reader import GlobalIndexReader +from pypaimon.globalindex.global_index_reader import GlobalIndexReader, _completed_future from pypaimon.globalindex.vector_search_result import DictBasedScoredIndexResult LUMINA_IDENTIFIER = "lumina" @@ -54,6 +56,7 @@ def __init__(self, file_io, index_path, io_metas, options=None): self._index_meta = None self._search_options = None self._stream = None + self._load_lock = threading.Lock() def visit_vector_search(self, vector_search): self._ensure_loaded() @@ -72,14 +75,14 @@ def visit_vector_search(self, vector_search): count = self._searcher.get_count() effective_k = min(limit, count) if effective_k <= 0: - return None + return _completed_future(None) include_row_ids = vector_search.include_row_ids if include_row_ids is not None: filter_id_list = list(include_row_ids) if len(filter_id_list) == 0: - return None + return _completed_future(None) effective_k = min(effective_k, len(filter_id_list)) search_opts = dict(self._search_options) search_opts["search.thread_safe_filter"] = "true" @@ -103,36 +106,38 @@ def visit_vector_search(self, vector_search): float(distances[i]), index_metric) id_to_scores[int(row_id)] = score - return DictBasedScoredIndexResult(id_to_scores) + return _completed_future(DictBasedScoredIndexResult(id_to_scores)) def _ensure_loaded(self): if self._searcher is not None: return - from lumina_data import LuminaSearcher - from 
pypaimon.globalindex.lumina.lumina_index_meta import LuminaIndexMeta - from pypaimon.globalindex.lumina.lumina_vector_index_options import ( - strip_lumina_options, - ) - - self._index_meta = LuminaIndexMeta.deserialize(self._io_meta.metadata) - # Merge paimon table options (prefix-stripped) with index metadata options; - # index metadata takes precedence as it reflects the actual built index. - searcher_options = strip_lumina_options(self._options) - searcher_options.update(self._index_meta.options) - self._search_options = searcher_options - - file_path = (self._io_meta.external_path - if self._io_meta.external_path - else os.path.join(self._index_path, self._io_meta.file_name)) - stream = self._file_io.new_input_stream(file_path) - try: - self._searcher = LuminaSearcher(searcher_options) - self._searcher.open_stream(stream, self._io_meta.file_size) - self._stream = stream - except Exception: - stream.close() - raise + with self._load_lock: + if self._searcher is not None: + return + + from lumina_data import LuminaSearcher + from pypaimon.globalindex.lumina.lumina_index_meta import LuminaIndexMeta + from pypaimon.globalindex.lumina.lumina_vector_index_options import ( + strip_lumina_options, + ) + + self._index_meta = LuminaIndexMeta.deserialize(self._io_meta.metadata) + searcher_options = strip_lumina_options(self._options) + searcher_options.update(self._index_meta.options) + self._search_options = searcher_options + + file_path = (self._io_meta.external_path + if self._io_meta.external_path + else os.path.join(self._index_path, self._io_meta.file_name)) + stream = self._file_io.new_input_stream(file_path) + try: + self._searcher = LuminaSearcher(searcher_options) + self._searcher.open_stream(stream, self._io_meta.file_size) + self._stream = stream + except Exception: + stream.close() + raise def __enter__(self): return self diff --git a/paimon-python/pypaimon/globalindex/offset_global_index_reader.py 
b/paimon-python/pypaimon/globalindex/offset_global_index_reader.py index b2c3a74394c4..8d2c5bc38bb4 100644 --- a/paimon-python/pypaimon/globalindex/offset_global_index_reader.py +++ b/paimon-python/pypaimon/globalindex/offset_global_index_reader.py @@ -17,16 +17,16 @@ """A GlobalIndexReader that wraps another reader and applies an offset to all row IDs.""" +from concurrent.futures import Future from typing import List, Optional -from pypaimon.globalindex.global_index_reader import GlobalIndexReader, FieldRef +from pypaimon.globalindex.global_index_reader import FieldRef, GlobalIndexReader, _map_future from pypaimon.globalindex.global_index_result import GlobalIndexResult class OffsetGlobalIndexReader(GlobalIndexReader): - """ - A GlobalIndexReader that wraps another reader and applies an offset - to all row IDs in the results. + """A GlobalIndexReader that wraps another reader and applies an offset + to all row IDs in the results. All visit methods return Future. """ def __init__(self, wrapped: GlobalIndexReader, offset: int, to: int): @@ -34,68 +34,68 @@ def __init__(self, wrapped: GlobalIndexReader, offset: int, to: int): self._offset = offset self._to = to - def visit_vector_search(self, vector_search) -> Optional[GlobalIndexResult]: - result = self._wrapped.visit_vector_search( - vector_search.offset_range(self._offset, self._to)) - if result is not None: - return result.offset(self._offset) - return None + def _apply_offset_future( + self, source: 'Future[Optional[GlobalIndexResult]]' + ) -> 'Future[Optional[GlobalIndexResult]]': + def transform(result): + if result is not None: + return result.offset(self._offset) + return None + return _map_future(source, transform) - def visit_full_text_search(self, full_text_search) -> Optional[GlobalIndexResult]: - result = self._wrapped.visit_full_text_search(full_text_search) - if result is not None: - return result.offset(self._offset) - return None + def visit_vector_search(self, vector_search) -> 
'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future( + self._wrapped.visit_vector_search( + vector_search.offset_range(self._offset, self._to))) - def visit_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_equal(field_ref, literal)) + def visit_full_text_search(self, full_text_search) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future( + self._wrapped.visit_full_text_search(full_text_search)) - def visit_not_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_not_equal(field_ref, literal)) + def visit_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_equal(field_ref, literal)) - def visit_less_than(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_less_than(field_ref, literal)) + def visit_not_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_not_equal(field_ref, literal)) - def visit_less_or_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_less_or_equal(field_ref, literal)) + def visit_less_than(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_less_than(field_ref, literal)) - def visit_greater_than(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_greater_than(field_ref, literal)) + def visit_less_or_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return 
self._apply_offset_future(self._wrapped.visit_less_or_equal(field_ref, literal)) - def visit_greater_or_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_greater_or_equal(field_ref, literal)) + def visit_greater_than(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_greater_than(field_ref, literal)) - def visit_is_null(self, field_ref: FieldRef) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_is_null(field_ref)) + def visit_greater_or_equal(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_greater_or_equal(field_ref, literal)) - def visit_is_not_null(self, field_ref: FieldRef) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_is_not_null(field_ref)) + def visit_is_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_is_null(field_ref)) - def visit_in(self, field_ref: FieldRef, literals: List[object]) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_in(field_ref, literals)) + def visit_is_not_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_is_not_null(field_ref)) - def visit_not_in(self, field_ref: FieldRef, literals: List[object]) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_not_in(field_ref, literals)) + def visit_in(self, field_ref: FieldRef, literals: List[object]) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_in(field_ref, literals)) - def visit_starts_with(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return 
self._apply_offset(self._wrapped.visit_starts_with(field_ref, literal)) + def visit_not_in(self, field_ref: FieldRef, literals: List[object]) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_not_in(field_ref, literals)) - def visit_ends_with(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_ends_with(field_ref, literal)) + def visit_starts_with(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_starts_with(field_ref, literal)) - def visit_contains(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_contains(field_ref, literal)) + def visit_ends_with(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_ends_with(field_ref, literal)) - def visit_like(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_like(field_ref, literal)) + def visit_contains(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_contains(field_ref, literal)) - def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object) -> Optional[GlobalIndexResult]: - return self._apply_offset(self._wrapped.visit_between(field_ref, min_v, max_v)) + def visit_like(self, field_ref: FieldRef, literal: object) -> 'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_like(field_ref, literal)) - def _apply_offset(self, result: Optional[GlobalIndexResult]) -> Optional[GlobalIndexResult]: - if result is not None: - return result.offset(self._offset) - return None + def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object) -> 
'Future[Optional[GlobalIndexResult]]': + return self._apply_offset_future(self._wrapped.visit_between(field_ref, min_v, max_v)) def close(self) -> None: self._wrapped.close() diff --git a/paimon-python/pypaimon/globalindex/tantivy/tantivy_full_text_global_index_reader.py b/paimon-python/pypaimon/globalindex/tantivy/tantivy_full_text_global_index_reader.py index e9d2150cf7d2..5ed0b9a91c25 100644 --- a/paimon-python/pypaimon/globalindex/tantivy/tantivy_full_text_global_index_reader.py +++ b/paimon-python/pypaimon/globalindex/tantivy/tantivy_full_text_global_index_reader.py @@ -21,20 +21,143 @@ backed by a stream-based Directory. No temp files are created on disk. """ +import json +import logging import os import struct import threading -from typing import Dict, List, Optional +from dataclasses import dataclass +from typing import Dict, List -from pypaimon.globalindex.global_index_reader import GlobalIndexReader, FieldRef -from pypaimon.globalindex.global_index_result import GlobalIndexResult +from pypaimon.globalindex.global_index_reader import GlobalIndexReader, FieldRef, _completed_future from pypaimon.globalindex.vector_search_result import ( - ScoredGlobalIndexResult, DictBasedScoredIndexResult, ) from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta +logger = logging.getLogger(__name__) + TANTIVY_FULLTEXT_IDENTIFIER = "tantivy-fulltext" +TANTIVY_NGRAM_TOKENIZER = "paimon_ngram" +TANTIVY_JIEBA_TOKENIZER = "paimon_jieba" +TANTIVY_CUSTOM_TOKENIZER = "paimon_custom" +_SUPPORTED_LANGUAGES = { + "arabic", + "danish", + "dutch", + "english", + "finnish", + "french", + "german", + "greek", + "hungarian", + "italian", + "norwegian", + "portuguese", + "romanian", + "russian", + "spanish", + "swedish", + "tamil", + "turkish", +} + + +@dataclass(frozen=True) +class TantivyFullTextIndexOptions: + """Tokenizer options serialized by the Java Tantivy full-text index.""" + + tokenizer: str = "default" + ngram_min_gram: int = 2 + ngram_max_gram: int = 2 + 
ngram_prefix_only: bool = False + lower_case: bool = True + max_token_length: int = 40 + ascii_folding: bool = False + stem: bool = False + language: str = "english" + remove_stop_words: bool = False + stop_words: str = "" + with_position: bool = True + + @staticmethod + def deserialize(data): + if not data: + return TantivyFullTextIndexOptions() + return TantivyFullTextIndexOptions._deserialize_json(data) + + @staticmethod + def _deserialize_json(data): + config = json.loads(data.decode("utf-8")) + stop_words = config.get("stop-words", []) + if isinstance(stop_words, list): + stop_words = ";".join( + word for word in stop_words if word is not None) + + return TantivyFullTextIndexOptions( + tokenizer=config.get("tokenizer", "default"), + ngram_min_gram=config.get("ngram.min-gram", 2), + ngram_max_gram=config.get("ngram.max-gram", 2), + ngram_prefix_only=config.get("ngram.prefix-only", False), + lower_case=config.get("lower-case", True), + max_token_length=config.get("max-token-length", 40), + ascii_folding=config.get("ascii-folding", False), + stem=config.get("stem", False), + language=config.get("language", "english"), + remove_stop_words=config.get("remove-stop-words", False), + stop_words=stop_words, + with_position=config.get("with-position", True)) + + def __post_init__(self): + tokenizer = "" if self.tokenizer is None else self.tokenizer.strip().lower() + object.__setattr__(self, "tokenizer", tokenizer) + language = "" if self.language is None else self.language.strip().lower() + object.__setattr__(self, "language", language) + object.__setattr__(self, "stop_words", _normalize_stop_words(self.stop_words)) + + supported_tokenizers = ("default", "simple", "whitespace", "raw", "ngram", "jieba") + if tokenizer not in supported_tokenizers: + raise ValueError("Unsupported Tantivy tokenizer: %s" % tokenizer) + if self.ngram_min_gram <= 0: + raise ValueError("ngram min gram must be positive.") + if self.ngram_max_gram <= 0: + raise ValueError("ngram max gram must be 
positive.") + if self.ngram_min_gram > self.ngram_max_gram: + raise ValueError( + "ngram min gram must not be greater than max gram.") + if self.max_token_length <= 0: + raise ValueError("max token length must be positive.") + if self.language not in _SUPPORTED_LANGUAGES: + raise ValueError("Unsupported Tantivy language: %s" % self.language) + + def tokenizer_name(self): + if self.tokenizer == "ngram": + return TANTIVY_NGRAM_TOKENIZER + if self.tokenizer == "jieba": + return TANTIVY_JIEBA_TOKENIZER + if self._needs_custom_tokenizer(): + return TANTIVY_CUSTOM_TOKENIZER + return self.tokenizer + + def stop_word_list(self): + if not self.stop_words: + return [] + return [ + word.strip() + for word in self.stop_words.split(";") + if word.strip() + ] + + def _needs_custom_tokenizer(self): + return ( + self.tokenizer in ("simple", "whitespace", "raw") + or self.max_token_length != 40 + or not self.lower_case + or self.ascii_folding + or self.stem + or self.remove_stop_words + or bool(self.stop_word_list()) + ) class StreamDirectory: @@ -139,22 +262,27 @@ def __init__(self, file_io, index_path: str, io_metas: List[GlobalIndexIOMeta]): self._file_io = file_io self._index_path = index_path self._io_meta = io_metas[0] + self._index_options = TantivyFullTextIndexOptions.deserialize( + self._io_meta.metadata) self._searcher = None self._index = None + self._schema = None self._stream = None + self._load_lock = threading.Lock() - def visit_full_text_search(self, full_text_search) -> Optional[ScoredGlobalIndexResult]: + def visit_full_text_search(self, full_text_search): self._ensure_loaded() - query_text = full_text_search.query_text limit = full_text_search.limit searcher = self._searcher - query = self._index.parse_query(query_text, ["text"]) + import tantivy + + query = self._parse_query(tantivy, full_text_search) results = searcher.search(query, limit) if not results.hits: - return DictBasedScoredIndexResult({}) + return _completed_future(DictBasedScoredIndexResult({})) 
doc_addresses = [addr for score, addr in results.hits] scores = [score for score, addr in results.hits] @@ -164,37 +292,226 @@ def visit_full_text_search(self, full_text_search) -> Optional[ScoredGlobalIndex for row_id, score in zip(row_ids, scores): id_to_scores[row_id] = score - return DictBasedScoredIndexResult(id_to_scores) + return _completed_future(DictBasedScoredIndexResult(id_to_scores)) def _ensure_loaded(self): if self._searcher is not None: return - import tantivy + with self._load_lock: + if self._searcher is not None: + return + + import tantivy + + self._verify_tantivy_tokenizer_api(tantivy) + file_path = (self._io_meta.external_path + if self._io_meta.external_path + else os.path.join(self._index_path, self._io_meta.file_name)) + stream = self._file_io.new_input_stream(file_path) + try: + file_names, file_offsets, file_lengths = self._parse_archive_header(stream) + directory = StreamDirectory(stream, file_names, file_offsets, file_lengths) + + schema = self._build_schema(tantivy) + try: + self._index = tantivy.Index( + schema, directory=directory, + ) + except ValueError as e: + if "schema does not match" not in str(e): + raise + logger.warning( + "Schema mismatch, retrying with " + "row_id stored=true" + ) + schema = self._build_schema( + tantivy, row_id_stored=True, + ) + self._index = tantivy.Index( + schema, directory=directory, + ) + self._schema = schema + self._register_tokenizer(tantivy, self._index) + self._index.reload() + self._searcher = self._index.searcher() + self._stream = stream + except Exception: + stream.close() + raise + + def _build_schema(self, tantivy, row_id_stored=False): + schema_builder = tantivy.SchemaBuilder() + schema_builder.add_unsigned_field( + "row_id", stored=row_id_stored, indexed=True, fast=True, + ) + tokenizer_name = self._index_options.tokenizer_name() + field_kwargs = {} + if not self._index_options.with_position: + field_kwargs["index_option"] = "freq" + if tokenizer_name == "default": + 
schema_builder.add_text_field( + "text", stored=False, **field_kwargs, + ) + else: + schema_builder.add_text_field( + "text", stored=False, + tokenizer_name=tokenizer_name, **field_kwargs, + ) + return schema_builder.build() + + def _register_tokenizer(self, tantivy, index): + if (self._index_options.tokenizer == "default" + and self._index_options.tokenizer_name() == "default"): + return + + if self._index_options.tokenizer == "ngram": + tokenizer = tantivy.Tokenizer.ngram( + min_gram=self._index_options.ngram_min_gram, + max_gram=self._index_options.ngram_max_gram, + prefix_only=self._index_options.ngram_prefix_only) + elif self._index_options.tokenizer in ("default", "simple"): + tokenizer = tantivy.Tokenizer.simple() + elif self._index_options.tokenizer == "whitespace": + tokenizer = tantivy.Tokenizer.whitespace() + elif self._index_options.tokenizer == "raw": + tokenizer = tantivy.Tokenizer.raw() + else: + return - # Open the archive stream (prefer external_path if the manifest set it). 
- file_path = (self._io_meta.external_path - if self._io_meta.external_path - else os.path.join(self._index_path, self._io_meta.file_name)) - stream = self._file_io.new_input_stream(file_path) + analyzer_builder = tantivy.TextAnalyzerBuilder(tokenizer) + if self._index_options.max_token_length != 40: + analyzer_builder = analyzer_builder.filter( + tantivy.Filter.remove_long(self._index_options.max_token_length)) + if self._index_options.lower_case: + analyzer_builder = analyzer_builder.filter(tantivy.Filter.lowercase()) + if self._index_options.ascii_folding: + analyzer_builder = analyzer_builder.filter(tantivy.Filter.ascii_fold()) + if self._index_options.stem: + analyzer_builder = analyzer_builder.filter( + tantivy.Filter.stemmer(self._index_options.language)) + if self._index_options.remove_stop_words: + analyzer_builder = analyzer_builder.filter( + tantivy.Filter.stopword(self._index_options.language)) + stop_words = self._index_options.stop_word_list() + if stop_words: + analyzer_builder = analyzer_builder.filter( + tantivy.Filter.custom_stopword(stop_words)) + analyzer = analyzer_builder.build() + index.register_tokenizer(self._index_options.tokenizer_name(), analyzer) + + def _parse_query(self, tantivy, full_text_search): + query_text = full_text_search.query_text + conjunction_by_default = full_text_search.query_operator == "and" + if self._index_options.tokenizer != "jieba": + if conjunction_by_default: + return self._index.parse_query( + query_text, ["text"], conjunction_by_default=True) + return self._index.parse_query(query_text, ["text"]) + + tokens = self._jieba_query_tokens(query_text) + if not tokens: + return tantivy.Query.empty_query() + + term_queries = [ + tantivy.Query.term_query(self._schema, "text", token) + for token in tokens + ] + if len(term_queries) == 1: + return term_queries[0] + occur = tantivy.Occur.Must if conjunction_by_default else tantivy.Occur.Should + return tantivy.Query.boolean_query([ + (occur, query) + for query in 
term_queries + ]) + + def _jieba_query_tokens(self, query_text): try: - # Parse archive header to get file layout - file_names, file_offsets, file_lengths = self._parse_archive_header(stream) - directory = StreamDirectory(stream, file_names, file_offsets, file_lengths) - - # Open tantivy index from stream-backed directory - schema_builder = tantivy.SchemaBuilder() - schema_builder.add_unsigned_field("row_id", stored=False, indexed=True, fast=True) - schema_builder.add_text_field("text", stored=False) - schema = schema_builder.build() - - self._index = tantivy.Index(schema, directory=directory) - self._index.reload() - self._searcher = self._index.searcher() - self._stream = stream - except Exception: - stream.close() - raise + import jieba + except ImportError as e: + raise RuntimeError( + "PyPaimon Tantivy full-text search requires Python package " + "'jieba' to query jieba tokenizer indexes. Install it with " + "`pip install jieba`.") from e + + seen = set() + tokens = [] + for word, _, _ in jieba.tokenize(query_text, mode="search", HMM=True): + token = word.strip() + if self._index_options.lower_case: + token = token.lower() + if token and token not in seen: + seen.add(token) + tokens.append(token) + return tokens + + def _verify_tantivy_tokenizer_api(self, tantivy): + if (self._index_options.tokenizer == "default" + and self._index_options.tokenizer_name() == "default"): + return + + missing = [] + if self._index_options.tokenizer != "jieba": + required_classes = ["TextAnalyzerBuilder", "Tokenizer"] + if (self._index_options.lower_case + or self._index_options.max_token_length != 40 + or self._index_options.ascii_folding + or self._index_options.stem + or self._index_options.remove_stop_words + or self._index_options.stop_word_list()): + required_classes.append("Filter") + else: + required_classes = ["Query", "Occur"] + for name in required_classes: + if not hasattr(tantivy, name): + missing.append(name) + + tokenizer = getattr(tantivy, "Tokenizer", None) + 
filter_ = getattr(tantivy, "Filter", None) + query = getattr(tantivy, "Query", None) + occur = getattr(tantivy, "Occur", None) + tokenizer_apis = { + "default": "simple", + "ngram": "ngram", + "simple": "simple", + "whitespace": "whitespace", + "raw": "raw", + } + tokenizer_api = tokenizer_apis.get(self._index_options.tokenizer) + if (tokenizer_api is not None and tokenizer is not None + and not hasattr(tokenizer, tokenizer_api)): + missing.append("Tokenizer.%s" % tokenizer_api) + if self._index_options.tokenizer != "jieba" and filter_ is not None: + filter_checks = [] + if self._index_options.max_token_length != 40: + filter_checks.append(("remove_long", "Filter.remove_long")) + if self._index_options.lower_case: + filter_checks.append(("lowercase", "Filter.lowercase")) + if self._index_options.ascii_folding: + filter_checks.append(("ascii_fold", "Filter.ascii_fold")) + if self._index_options.stem: + filter_checks.append(("stemmer", "Filter.stemmer")) + if self._index_options.remove_stop_words: + filter_checks.append(("stopword", "Filter.stopword")) + if self._index_options.stop_word_list(): + filter_checks.append(("custom_stopword", "Filter.custom_stopword")) + for attr, api_name in filter_checks: + if not hasattr(filter_, attr): + missing.append(api_name) + if self._index_options.tokenizer == "jieba" and query is not None: + for name in ("empty_query", "term_query", "boolean_query"): + if not hasattr(query, name): + missing.append("Query.%s" % name) + if self._index_options.tokenizer == "jieba" and occur is not None: + for name in ("Should", "Must"): + if not hasattr(occur, name): + missing.append("Occur.%s" % name) + if missing: + tokenizer_name = self._index_options.tokenizer + raise RuntimeError( + "PyPaimon Tantivy full-text search requires a tantivy-py " + "version with %s tokenizer support. 
Missing API(s): %s" + % (tokenizer_name, ", ".join(missing))) @staticmethod def _parse_archive_header(stream): @@ -225,54 +542,55 @@ def _parse_archive_header(stream): # =================== unsupported ===================== - def visit_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_equal(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_not_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_not_equal(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_less_than(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_less_than(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_less_or_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_less_or_equal(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_greater_than(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_greater_than(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_greater_or_equal(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_greater_or_equal(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_is_null(self, field_ref: FieldRef) -> Optional[GlobalIndexResult]: - return None + def visit_is_null(self, field_ref: FieldRef): + return _completed_future(None) - def visit_is_not_null(self, field_ref: FieldRef) -> Optional[GlobalIndexResult]: - return None + def visit_is_not_null(self, field_ref: FieldRef): + return _completed_future(None) - def visit_in(self, field_ref: FieldRef, literals: List[object]) -> Optional[GlobalIndexResult]: - return None + def 
visit_in(self, field_ref: FieldRef, literals: List[object]): + return _completed_future(None) - def visit_not_in(self, field_ref: FieldRef, literals: List[object]) -> Optional[GlobalIndexResult]: - return None + def visit_not_in(self, field_ref: FieldRef, literals: List[object]): + return _completed_future(None) - def visit_starts_with(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_starts_with(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_ends_with(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_ends_with(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_contains(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_contains(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_like(self, field_ref: FieldRef, literal: object) -> Optional[GlobalIndexResult]: - return None + def visit_like(self, field_ref: FieldRef, literal: object): + return _completed_future(None) - def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object) -> Optional[GlobalIndexResult]: - return None + def visit_between(self, field_ref: FieldRef, min_v: object, max_v: object): + return _completed_future(None) def close(self) -> None: self._searcher = None self._index = None + self._schema = None if self._stream is not None: self._stream.close() self._stream = None @@ -298,3 +616,19 @@ def _read_fully(stream, length: int) -> bytes: buf.extend(chunk) remaining -= len(chunk) return bytes(buf) + + +def _normalize_stop_words(stop_words): + if stop_words is None: + return "" + if isinstance(stop_words, list): + return ";".join( + word.strip() + for word in stop_words + if word is not None and word.strip() + ) + return ";".join( + word.strip() + for word in stop_words.split(";") + if word.strip() + ) diff --git 
a/paimon-python/pypaimon/globalindex/union_global_index_reader.py b/paimon-python/pypaimon/globalindex/union_global_index_reader.py index e9ec0dc6871f..6e2e2a9f9cc4 100644 --- a/paimon-python/pypaimon/globalindex/union_global_index_reader.py +++ b/paimon-python/pypaimon/globalindex/union_global_index_reader.py @@ -17,15 +17,15 @@ """A GlobalIndexReader that unions results from multiple underlying readers. -Each visit_* call is dispatched to every underlying reader and the results are -OR-combined into a single ``GlobalIndexResult``. Readers returning ``None`` -("cannot answer") are skipped; an empty bitmap DOES contribute to the union -and is NOT a short-circuit signal. +Each visit_* call dispatches to every underlying reader (which returns Future +internally) and combines the results via OR once all futures complete. """ +import threading +from concurrent.futures import Future from typing import Callable, List, Optional -from pypaimon.globalindex.global_index_reader import FieldRef, GlobalIndexReader +from pypaimon.globalindex.global_index_reader import FieldRef, GlobalIndexReader, _completed_future from pypaimon.globalindex.global_index_result import GlobalIndexResult @@ -34,82 +34,94 @@ class UnionGlobalIndexReader(GlobalIndexReader): def __init__(self, readers: List[GlobalIndexReader]): self._readers = readers - def _union(self, visitor: Callable[[GlobalIndexReader], Optional[GlobalIndexResult]] - ) -> Optional[GlobalIndexResult]: - result: Optional[GlobalIndexResult] = None - for reader in self._readers: - current = visitor(reader) - if current is None: - continue - if result is None: - result = current - else: - result = result.or_(current) - return result + def _union_futures(self, visitor: Callable[[GlobalIndexReader], 'Future[Optional[GlobalIndexResult]]'] + ) -> 'Future[Optional[GlobalIndexResult]]': + futures = [visitor(reader) for reader in self._readers] + + if not futures: + return _completed_future(None) + + all_done = Future() + remaining = 
[len(futures)] + lock = threading.Lock() + + def on_done(_): + with lock: + remaining[0] -= 1 + if remaining[0] == 0: + try: + result: Optional[GlobalIndexResult] = None + for f in futures: + current = f.result() + if current is None: + continue + if result is None: + result = current + else: + result = result.or_(current) + all_done.set_result(result) + except Exception as e: + all_done.set_exception(e) + + for f in futures: + f.add_done_callback(on_done) + + return all_done # ---- vector / full-text search ---------------------------------------- - def visit_vector_search(self, vector_search) -> Optional[GlobalIndexResult]: - from pypaimon.globalindex.vector_search_result import ( - ScoredGlobalIndexResult, - ) - result: Optional[ScoredGlobalIndexResult] = None - for reader in self._readers: - current = reader.visit_vector_search(vector_search) - if current is None: - continue - result = current if result is None else result.or_(current) - return result + def visit_vector_search(self, vector_search) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_vector_search(vector_search)) - def visit_full_text_search(self, full_text_search) -> Optional[GlobalIndexResult]: - return self._union(lambda r: r.visit_full_text_search(full_text_search)) + def visit_full_text_search(self, full_text_search) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_full_text_search(full_text_search)) # ---- scalar predicates (every reader sees the visit) ------------------ - def visit_equal(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_equal(field_ref, literal)) + def visit_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_equal(field_ref, literal)) - def visit_not_equal(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_not_equal(field_ref, literal)) + def visit_not_equal(self, 
field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_not_equal(field_ref, literal)) - def visit_less_than(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_less_than(field_ref, literal)) + def visit_less_than(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_less_than(field_ref, literal)) - def visit_less_or_equal(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_less_or_equal(field_ref, literal)) + def visit_less_or_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_less_or_equal(field_ref, literal)) - def visit_greater_than(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_greater_than(field_ref, literal)) + def visit_greater_than(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_greater_than(field_ref, literal)) - def visit_greater_or_equal(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_greater_or_equal(field_ref, literal)) + def visit_greater_or_equal(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_greater_or_equal(field_ref, literal)) - def visit_is_null(self, field_ref: FieldRef): - return self._union(lambda r: r.visit_is_null(field_ref)) + def visit_is_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_is_null(field_ref)) - def visit_is_not_null(self, field_ref: FieldRef): - return self._union(lambda r: r.visit_is_not_null(field_ref)) + def visit_is_not_null(self, field_ref: FieldRef) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_is_not_null(field_ref)) - def visit_in(self, 
field_ref: FieldRef, literals): - return self._union(lambda r: r.visit_in(field_ref, literals)) + def visit_in(self, field_ref: FieldRef, literals) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_in(field_ref, literals)) - def visit_not_in(self, field_ref: FieldRef, literals): - return self._union(lambda r: r.visit_not_in(field_ref, literals)) + def visit_not_in(self, field_ref: FieldRef, literals) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_not_in(field_ref, literals)) - def visit_starts_with(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_starts_with(field_ref, literal)) + def visit_starts_with(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_starts_with(field_ref, literal)) - def visit_ends_with(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_ends_with(field_ref, literal)) + def visit_ends_with(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_ends_with(field_ref, literal)) - def visit_contains(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_contains(field_ref, literal)) + def visit_contains(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_contains(field_ref, literal)) - def visit_like(self, field_ref: FieldRef, literal): - return self._union(lambda r: r.visit_like(field_ref, literal)) + def visit_like(self, field_ref: FieldRef, literal) -> 'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_like(field_ref, literal)) - def visit_between(self, field_ref: FieldRef, from_v, to_v): - return self._union(lambda r: r.visit_between(field_ref, from_v, to_v)) + def visit_between(self, field_ref: FieldRef, from_v, to_v) -> 
'Future[Optional[GlobalIndexResult]]': + return self._union_futures(lambda r: r.visit_between(field_ref, from_v, to_v)) def close(self) -> None: for reader in self._readers: diff --git a/paimon-python/pypaimon/globalindex/vector_search.py b/paimon-python/pypaimon/globalindex/vector_search.py index dd46ed60baf4..3cd8d2c88a7c 100644 --- a/paimon-python/pypaimon/globalindex/vector_search.py +++ b/paimon-python/pypaimon/globalindex/vector_search.py @@ -17,6 +17,7 @@ """VectorSearch for performing vector similarity search.""" +from concurrent.futures import Future from dataclasses import dataclass, field from typing import List, Optional, Union import numpy as np @@ -83,7 +84,7 @@ def offset_range(self, from_: int, to: int) -> 'VectorSearch': ) return self - def visit(self, visitor: 'GlobalIndexReader') -> Optional['GlobalIndexResult']: + def visit(self, visitor: 'GlobalIndexReader') -> 'Future[Optional[GlobalIndexResult]]': """Visit the global index reader with this vector search.""" return visitor.visit_vector_search(self) diff --git a/paimon-python/pypaimon/globalindex/vector_search_result.py b/paimon-python/pypaimon/globalindex/vector_search_result.py index 5953ea29ca29..fe2b974550d8 100644 --- a/paimon-python/pypaimon/globalindex/vector_search_result.py +++ b/paimon-python/pypaimon/globalindex/vector_search_result.py @@ -111,15 +111,15 @@ def create_empty() -> 'ScoredGlobalIndexResult': @staticmethod def create( - supplier: Callable[[], RoaringBitmap64], + bitmap: RoaringBitmap64, score_getter: ScoreGetter ) -> 'ScoredGlobalIndexResult': - """Creates a new VectorSearchGlobalIndexResult from supplier.""" - return LazyScoredGlobalIndexResult(supplier, score_getter) + """Creates a new ScoredGlobalIndexResult wrapping the given bitmap.""" + return SimpleScoredGlobalIndexResult(bitmap, score_getter) class SimpleScoredGlobalIndexResult(ScoredGlobalIndexResult): - """Simple implementation of VectorSearchGlobalIndexResult.""" + """Simple implementation of 
ScoredGlobalIndexResult.""" def __init__(self, bitmap: RoaringBitmap64, score_getter_fn: ScoreGetter): self._bitmap = bitmap @@ -132,23 +132,6 @@ def score_getter(self) -> ScoreGetter: return self._score_getter_fn -class LazyScoredGlobalIndexResult(ScoredGlobalIndexResult): - """Lazy implementation of VectorSearchGlobalIndexResult.""" - - def __init__(self, supplier: Callable[[], RoaringBitmap64], score_getter_fn: ScoreGetter): - self._supplier = supplier - self._score_getter_fn = score_getter_fn - self._cached: Optional[RoaringBitmap64] = None - - def results(self) -> RoaringBitmap64: - if self._cached is None: - self._cached = self._supplier() - return self._cached - - def score_getter(self) -> ScoreGetter: - return self._score_getter_fn - - class DictBasedScoredIndexResult(ScoredGlobalIndexResult): """Vector search result backed by a dictionary of row_id -> score.""" diff --git a/paimon-python/pypaimon/manifest/index_manifest_file.py b/paimon-python/pypaimon/manifest/index_manifest_file.py index 4e65e95e0cb1..a7d82d4b12ba 100644 --- a/paimon-python/pypaimon/manifest/index_manifest_file.py +++ b/paimon-python/pypaimon/manifest/index_manifest_file.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import uuid from io import BytesIO from typing import List, Optional @@ -24,11 +25,61 @@ from pypaimon.index.deletion_vector_meta import DeletionVectorMeta from pypaimon.index.index_file_meta import IndexFileMeta from pypaimon.manifest.index_manifest_entry import IndexManifestEntry -from pypaimon.table.row.generic_row import GenericRowDeserializer +from pypaimon.table.row.generic_row import (GenericRowDeserializer, + GenericRowSerializer) +from pypaimon.utils.file_store_path_factory import FileStorePathFactory + +# DV and global-index sub-schemas required by INDEX_MANIFEST_ENTRY_SCHEMA for +# Avro compatibility with Java; values are always null in data-evolution tables. 
+_DELETION_VECTOR_META_SCHEMA = { + "type": "record", + "name": "DeletionVectorMeta", + "fields": [ + {"name": "f0", "type": "string"}, + {"name": "f1", "type": "long"}, + {"name": "f2", "type": "int"}, + {"name": "_CARDINALITY", "type": ["null", "long"], "default": None}, + ], +} + +_GLOBAL_INDEX_META_SCHEMA = { + "type": "record", + "name": "GlobalIndexMeta", + "fields": [ + {"name": "_ROW_RANGE_START", "type": "long"}, + {"name": "_ROW_RANGE_END", "type": "long"}, + {"name": "_INDEX_FIELD_ID", "type": "int"}, + {"name": "_EXTRA_FIELD_IDS", + "type": ["null", {"type": "array", "items": "int"}], "default": None}, + {"name": "_INDEX_META", "type": ["null", "bytes"], "default": None}, + ], +} + +INDEX_MANIFEST_ENTRY_SCHEMA = { + "type": "record", + "name": "IndexManifestEntry", + "fields": [ + {"name": "_VERSION", "type": "int"}, + {"name": "_KIND", "type": "int"}, + {"name": "_PARTITION", "type": "bytes"}, + {"name": "_BUCKET", "type": "int"}, + {"name": "_INDEX_TYPE", "type": "string"}, + {"name": "_FILE_NAME", "type": "string"}, + {"name": "_FILE_SIZE", "type": "long"}, + {"name": "_ROW_COUNT", "type": "long"}, + {"name": "_DELETIONS_VECTORS_RANGES", + "type": ["null", {"type": "array", "items": _DELETION_VECTOR_META_SCHEMA}], + "default": None}, + {"name": "_EXTERNAL_PATH", "type": ["null", "string"], "default": None}, + {"name": "_GLOBAL_INDEX", + "type": ["null", _GLOBAL_INDEX_META_SCHEMA], "default": None}, + ], +} + +_INDEX_ENTRY_VERSION = 1 class IndexManifestFile: - """Index manifest file reader for reading index manifest entries.""" DELETION_VECTORS_INDEX = "DELETION_VECTORS" @@ -172,5 +223,69 @@ def _parse_global_index_meta(self, global_index_record) -> Optional[GlobalIndexM row_range_start=global_index_record.get('_ROW_RANGE_START', 0), row_range_end=global_index_record.get('_ROW_RANGE_END', 0), index_field_id=global_index_record.get('_INDEX_FIELD_ID', 0), + extra_field_ids=global_index_record.get('_EXTRA_FIELD_IDS'), 
index_meta=global_index_record.get('_INDEX_META') ) + + def combine_deletes( + self, + previous_name: Optional[str], + deletes: List[IndexManifestEntry], + ) -> Optional[str]: + if not deletes: + return previous_name + previous = self.read(previous_name) if previous_name else [] + delete_names = {e.index_file.file_name for e in deletes} + survivors = [e for e in previous if e.index_file.file_name not in delete_names] + if not survivors: + return None + return self.write(survivors) + + def write(self, entries: List[IndexManifestEntry]) -> str: + file_name = f"{FileStorePathFactory.INDEX_MANIFEST_PREFIX}{uuid.uuid4()}" + path = f"{self.manifest_path}/{file_name}" + records = [self._to_avro_record(e) for e in entries] + try: + buffer = BytesIO() + fastavro.writer(buffer, INDEX_MANIFEST_ENTRY_SCHEMA, records) + with self.file_io.new_output_stream(path) as output_stream: + output_stream.write(buffer.getvalue()) + except Exception as e: + self.file_io.delete_quietly(path) + raise RuntimeError( + f"Exception occurs when writing records to {path}. Clean up." 
+ ) from e + return file_name + + def _to_avro_record(self, entry: IndexManifestEntry) -> dict: + index_file = entry.index_file + dv_ranges = None + if index_file.dv_ranges: + dv_ranges = [ + {"f0": dv.data_file_name, "f1": dv.offset, "f2": dv.length, + "_CARDINALITY": dv.cardinality} + for dv in index_file.dv_ranges.values() + ] + global_index = None + if index_file.global_index_meta is not None: + gim = index_file.global_index_meta + global_index = { + "_ROW_RANGE_START": gim.row_range_start, + "_ROW_RANGE_END": gim.row_range_end, + "_INDEX_FIELD_ID": gim.index_field_id, + "_EXTRA_FIELD_IDS": gim.extra_field_ids, + "_INDEX_META": gim.index_meta, + } + return { + "_VERSION": _INDEX_ENTRY_VERSION, + "_KIND": entry.kind, + "_PARTITION": GenericRowSerializer.to_bytes(entry.partition), + "_BUCKET": entry.bucket, + "_INDEX_TYPE": index_file.index_type, + "_FILE_NAME": index_file.file_name, + "_FILE_SIZE": index_file.file_size, + "_ROW_COUNT": index_file.row_count, + "_DELETIONS_VECTORS_RANGES": dv_ranges, + "_EXTERNAL_PATH": index_file.external_path, + "_GLOBAL_INDEX": global_index, + } diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py index f36eb0253dd8..4280187956e3 100644 --- a/paimon-python/pypaimon/ray/__init__.py +++ b/paimon-python/pypaimon/ray/__init__.py @@ -16,5 +16,24 @@ # under the License. 
from pypaimon.ray.ray_paimon import read_paimon, write_paimon +from pypaimon.ray.data_evolution_merge_into import ( + WhenMatched, + WhenNotMatched, + merge_into, +) +from pypaimon.ray.data_evolution_merge_transform import ( + source_col, + target_col, + lit, +) -__all__ = ["read_paimon", "write_paimon"] +__all__ = [ + "read_paimon", + "write_paimon", + "merge_into", + "WhenMatched", + "WhenNotMatched", + "source_col", + "target_col", + "lit", +] diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py new file mode 100644 index 000000000000..fa824b44a2f9 --- /dev/null +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -0,0 +1,574 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""MERGE INTO ... USING ... 
for Paimon data-evolution tables via Ray Datasets.""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple + +import pyarrow as pa + +from pypaimon.ray.data_evolution_merge_join import ( + build_matched_update_ds, + build_not_matched_insert_ds, + distributed_update_apply, + distributed_write_collect_msgs, +) +from pypaimon.ray.data_evolution_merge_transform import ( + LiteralValue, + OnSpec, + SetSpec, + SourceColumnRef, + TargetColumnRef, + WhenMatched, + WhenNotMatched, + _NormalizedClause, +) + +__all__ = ["merge_into", "WhenMatched", "WhenNotMatched"] + + +@dataclass(frozen=True) +class _PrepareCtx: + """Bag of values _prepare hands to _build_datasets.""" + target_on_cols: List[str] + source_on_cols: List[str] + settable_field_names: List[str] + full_target_field_names: List[str] + update_pa_schema: pa.Schema + full_pa_schema: pa.Schema + catalog_options: Dict[str, str] + + +def merge_into( + target: str, + source: Any, + catalog_options: Dict[str, str], + *, + on: OnSpec, + when_matched: Sequence[WhenMatched] = (), + when_not_matched: Sequence[WhenNotMatched] = (), + num_partitions: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, +) -> Dict[str, int]: + _require_ray_join() + num_partitions = _resolve_num_partitions(num_partitions) + + table, source_ds, matched_specs, not_matched_specs, ctx = _prepare( + target, source, catalog_options, + list(when_matched), list(when_not_matched), on, + ) + base_snapshot = table.snapshot_manager().get_latest_snapshot() + + update_ds, insert_ds, update_cols_union = _build_datasets( + target, source_ds, matched_specs, not_matched_specs, + ctx, base_snapshot, num_partitions, ray_remote_args, + ) + + return _execute_and_commit( + table, update_ds, insert_ds, update_cols_union, + base_snapshot, num_partitions, + ray_remote_args, concurrency, + ) + + +def _prepare(target, source, catalog_options, when_matched, 
when_not_matched, on): + if not when_matched and not when_not_matched: + raise ValueError( + "At least one of when_matched or when_not_matched must be non-empty." + ) + if len(when_matched) > 1 or len(when_not_matched) > 1: + raise NotImplementedError( + "merge_into currently supports a single WhenMatched and a single " + "WhenNotMatched clause; multi-clause fall-through will be added " + "in a follow-up PR." + ) + target_on_cols, source_on_cols = _normalize_on(on) + + from pypaimon.catalog.catalog_factory import CatalogFactory + + catalog = CatalogFactory.create(catalog_options) + table = catalog.get_table(target) + if not table.options.data_evolution_enabled(): + raise ValueError( + f"merge_into requires 'data-evolution.enabled' = 'true' on '{target}'." + ) + if not table.options.row_tracking_enabled(): + raise ValueError( + f"merge_into requires 'row-tracking.enabled' = 'true' on '{target}'." + ) + + blob_cols = _blob_col_names(table) + full_target_field_names = list(table.field_names) + # SET specs only cover non-blob columns: update can't rewrite blob files + # (data evolution puts them in dedicated .blob files), and insert leaves + # blob columns null since the source can't carry them through SET="*". + settable_field_names = [ + c for c in full_target_field_names if c not in blob_cols + ] + on_map = dict(zip(target_on_cols, source_on_cols)) + if when_matched and table.partition_keys: + raise ValueError( + "merge_into does not support matched clauses on partitioned " + "tables; cross-partition row movement is not implemented." 
+ ) + matched_specs = [ + _NormalizedClause( + spec=_normalize_set_spec( + c.update, settable_field_names, on_map, + ), + condition=c.condition, + ) + for c in when_matched + ] + has_condition = any( + c.condition is not None + for c in list(when_matched) + list(when_not_matched) + ) + if has_condition: + from pypaimon.ray.merge_condition import ( + _require_datafusion, extract_target_columns, + ) + _require_datafusion() + for c in when_not_matched: + if c.condition is not None: + t_refs = extract_target_columns(c.condition) + if t_refs: + raise ValueError( + f"WhenNotMatched condition must not reference " + f"target columns (t.*), but found: {sorted(t_refs)}" + ) + for c in list(when_matched) + list(when_not_matched): + if c.condition is not None: + blob_refs = extract_target_columns(c.condition) & blob_cols + if blob_refs: + raise ValueError( + f"condition must not reference blob columns, " + f"but found: {sorted(blob_refs)}" + ) + not_matched_specs = [] + for c in when_not_matched: + spec = _normalize_set_spec( + c.insert, settable_field_names, on_map, + allow_target_refs=False, + ) + for tk, sk in on_map.items(): + if tk in settable_field_names and tk not in spec: + spec[tk] = SourceColumnRef(sk) + not_matched_specs.append( + _NormalizedClause(spec=spec, condition=c.condition) + ) + + source_snapshot_id = None + if isinstance(source, str): + source_snapshot = ( + catalog.get_table(source) + .snapshot_manager() + .get_latest_snapshot() + ) + if source_snapshot is not None: + source_snapshot_id = source_snapshot.id + + source_ds = _normalize_source( + source, catalog_options, source_snapshot_id=source_snapshot_id, + ) + _validate_source_on_cols(source_ds, source_on_cols) + _validate_source_has_target_cols( + source_ds, matched_specs + not_matched_specs, + ) + + if has_condition: + from pypaimon.ray.merge_condition import extract_columns + source_names = set(_source_schema_or_raise(source_ds).names) + target_names = set(full_target_field_names) + for c in 
list(when_matched) + list(when_not_matched): + if c.condition is not None: + for ref in extract_columns(c.condition): + prefix, col = ref.split(".", 1) + if prefix == "s" and col not in source_names: + raise ValueError( + f"condition references unknown source " + f"column '{col}'" + ) + if prefix == "t" and col not in target_names: + raise ValueError( + f"condition references unknown target " + f"column '{col}'" + ) + + from pypaimon.schema.data_types import PyarrowFieldParser + full_pa_schema = PyarrowFieldParser.from_paimon_schema( + table.table_schema.fields + ) + # update_pa_schema strips blob (only non-blob cols are written by the + # update path); insert_pa_schema is the full table schema so the writer + # gets every column (blob columns end up null). + update_pa_schema = pa.schema( + [full_pa_schema.field(c) for c in settable_field_names] + ) + ctx = _PrepareCtx( + target_on_cols=target_on_cols, + source_on_cols=source_on_cols, + settable_field_names=settable_field_names, + full_target_field_names=full_target_field_names, + update_pa_schema=update_pa_schema, + full_pa_schema=full_pa_schema, + catalog_options=catalog_options, + ) + return table, source_ds, matched_specs, not_matched_specs, ctx + + +def _build_datasets( + target, source_ds, matched_specs, not_matched_specs, + ctx: "_PrepareCtx", base_snapshot, num_partitions, ray_remote_args, +): + # Pin every target read to base_snapshot so all branches see the same + # snapshot the caller observed; otherwise concurrent commits in between + # would mix data from different snapshots. + base_snapshot_id = base_snapshot.id if base_snapshot is not None else None + + update_ds = None + insert_ds = None + update_cols_union: List[str] = [] + + # Mirror Spark: matched/not-matched run as two independent joins + # (inner / left_anti). One unified left_outer join would force + # joined.materialize() to feed both branches, which can OOM on large merges. 
+ if matched_specs and base_snapshot is not None: + update_cols_union = _union_update_cols(matched_specs) + update_ds = build_matched_update_ds( + target_identifier=target, + source_ds=source_ds, + target_on=ctx.target_on_cols, + source_on=ctx.source_on_cols, + clauses=matched_specs, + target_field_names=ctx.settable_field_names, + target_pa_schema=ctx.update_pa_schema, + update_cols=update_cols_union, + catalog_options=ctx.catalog_options, + num_partitions=num_partitions, + resolve_target_projection=_resolve_target_projection, + snapshot_id=base_snapshot_id, + ray_remote_args=ray_remote_args, + ) + + if not_matched_specs: + # Insert writes the full target schema; SET spec only covers + # settable cols, so blob columns fall through to null. + insert_ds = build_not_matched_insert_ds( + target_identifier=target, + source_ds=source_ds, + target_on=ctx.target_on_cols, + source_on=ctx.source_on_cols, + clauses=not_matched_specs, + target_field_names=ctx.full_target_field_names, + target_pa_schema=ctx.full_pa_schema, + catalog_options=ctx.catalog_options, + num_partitions=num_partitions, + snapshot_id=base_snapshot_id, + target_empty=base_snapshot is None, + ray_remote_args=ray_remote_args, + ) + + return update_ds, insert_ds, update_cols_union + + +def _execute_and_commit( + table, update_ds, insert_ds, update_cols_union, + base_snapshot, num_partitions, + ray_remote_args, concurrency, +): + update_msgs: list = [] + num_updated = 0 + if update_ds is not None: + try: + update_msgs, num_updated = distributed_update_apply( + update_ds, table, update_cols_union, + num_partitions=num_partitions, + ray_remote_args=ray_remote_args, + base_snapshot_id=( + base_snapshot.id + if base_snapshot is not None else None + ), + ) + except Exception as e: + _reraise_inner(e) + + all_msgs: list = list(update_msgs) + num_inserted = 0 + if insert_ds is not None: + try: + insert_msgs = distributed_write_collect_msgs( + insert_ds, table, + ray_remote_args=ray_remote_args, 
concurrency=concurrency, + ) + except Exception as e: + _reraise_inner(e) + num_inserted = sum( + f.row_count for m in insert_msgs for f in m.new_files + ) + all_msgs.extend(insert_msgs) + if all_msgs: + wb = table.new_batch_write_builder() + tc = wb.new_commit() + tc.commit(all_msgs) + tc.close() + + # num_matched = rows that passed the condition and were updated + return { + "num_matched": num_updated, + "num_inserted": num_inserted, + "num_unchanged": 0, + } + + +def _normalize_on(on: OnSpec) -> Tuple[List[str], List[str]]: + if isinstance(on, Mapping): + target_cols = list(on.keys()) + source_cols = list(on.values()) + else: + target_cols = list(on) + source_cols = list(on) + if not target_cols: + raise ValueError("'on' must be non-empty.") + return target_cols, source_cols + + +def _resolve_num_partitions(num_partitions: Optional[int]) -> int: + if num_partitions is not None: + return num_partitions + try: + import ray + + cpus = int(ray.cluster_resources().get("CPU", 4)) + return max(1, cpus * 2) + except Exception: + return 4 + + +def _require_ray_join() -> None: + import ray + from packaging.version import parse + + if parse(ray.__version__) < parse("2.50.0"): + raise RuntimeError( + f"merge_into requires ray>=2.50; " + f"installed ray is {ray.__version__}." 
+ ) + + +def _blob_col_names(table) -> set: + return { + f.name + for f in table.table_schema.fields + if getattr(f.type, "type", None) == "BLOB" + } + + +def _reraise_inner(err: BaseException) -> None: + """Unwrap Ray's RayTaskError so callers see the worker-side exception.""" + inner = err + cause = getattr(err, "cause", None) or getattr(err, "__cause__", None) + while cause is not None: + inner = cause + cause = getattr(inner, "cause", None) or getattr(inner, "__cause__", None) + if inner is err: + raise err + raise inner from err + + +def _union_update_cols(clauses: List[_NormalizedClause]) -> List[str]: + seen: List[str] = [] + seen_set: set = set() + for clause in clauses: + for col in clause.spec.keys(): + if col not in seen_set: + seen.append(col) + seen_set.add(col) + return seen + + +def _needed_target_cols( + clauses: List[_NormalizedClause], + on: Sequence[str], + update_cols: Sequence[str], + all_target_cols: Sequence[str], +) -> list: + # Target needs only: join keys, t.col refs, and cols that may fall back + # (not set by every clause). Cols all clauses set from source aren't read. 
+ needed = set(on) + set_by_all = set(update_cols) + for clause in clauses: + for value in clause.spec.values(): + if isinstance(value, TargetColumnRef): + needed.add(value.column) + set_by_all &= set(clause.spec.keys()) + needed |= set(update_cols) - set_by_all + return [c for c in all_target_cols if c in needed] + + +def _resolve_target_projection( + clauses: List[_NormalizedClause], + target_on: Sequence[str], + update_cols: Sequence[str], + target_field_names: Sequence[str], +) -> list: + needed = set(_needed_target_cols( + clauses, target_on, update_cols, target_field_names, + )) + if any(c.condition is not None for c in clauses): + from pypaimon.ray.merge_condition import extract_target_columns + target_set = set(target_field_names) + for clause in clauses: + if clause.condition is not None: + needed |= extract_target_columns(clause.condition) & target_set + return [c for c in target_field_names if c in needed] + + +def _normalize_set_spec( + spec: SetSpec, + target_field_names: Sequence[str], + on_map: Optional[Mapping[str, str]] = None, + allow_target_refs: bool = True, +) -> Dict[str, Any]: + on_map = on_map or {} + if spec == "*": + return { + col: SourceColumnRef(on_map.get(col, col)) + for col in target_field_names + } + if not isinstance(spec, Mapping): + raise TypeError( + f"SET spec must be '*' or a mapping, got {type(spec).__name__}" + ) + if not spec: + raise ValueError("SET spec must not be empty") + target_set = set(target_field_names) + for key in spec: + if key not in target_set: + raise ValueError( + f"SET spec references unknown target column '{key}'" + ) + result: Dict[str, Any] = {} + for key, val in spec.items(): + if callable(val) and not isinstance(val, type): + raise TypeError( + "SET values must be source_col(), target_col(), " + "lit(), or literals, not callables" + ) + if isinstance(val, SourceColumnRef): + result[key] = val + elif isinstance(val, TargetColumnRef): + if not allow_target_refs: + raise ValueError( + "INSERT spec must 
not reference target columns " + f"(t.*), but found: 't.{val.column}'" + ) + if val.column not in target_set: + raise ValueError( + f"SET spec references unknown target column " + f"'{val.column}'" + ) + result[key] = val + elif isinstance(val, LiteralValue): + result[key] = val + elif isinstance(val, str) and val.startswith("s."): + result[key] = SourceColumnRef(val[2:]) + elif isinstance(val, str) and val.startswith("t."): + if not allow_target_refs: + raise ValueError( + "INSERT spec must not reference target columns " + f"(t.*), but found: '{val}'" + ) + ref = val[2:] + if ref not in target_set: + raise ValueError( + f"SET spec references unknown target column '{ref}'" + ) + result[key] = TargetColumnRef(ref) + else: + result[key] = LiteralValue(val) + return result + + +def _normalize_source( + source: Any, + catalog_options: Dict[str, str], + source_snapshot_id: Optional[int] = None, +): + import ray.data + + if isinstance(source, ray.data.Dataset): + return source + if isinstance(source, str): + from pypaimon.ray.ray_paimon import read_paimon + read_kwargs = {} + if source_snapshot_id is not None: + read_kwargs["snapshot_id"] = source_snapshot_id + return read_paimon(source, catalog_options, **read_kwargs) + if isinstance(source, pa.Table): + return ray.data.from_arrow(source) + try: + import pandas as pd + except ImportError: + pd = None + if pd is not None and isinstance(source, pd.DataFrame): + return ray.data.from_pandas(source) + raise TypeError( + "source must be a ray.data.Dataset, a Paimon table identifier string, " + f"a pyarrow.Table, or a pandas.DataFrame; got {type(source).__name__}." + ) + + +def _source_schema_or_raise(source_ds): + """Get source schema; refuse to proceed if Ray can't tell us the columns.""" + schema = source_ds.schema() + if schema is None: + raise ValueError( + "merge_into could not infer the source schema; pass a " + "ray.data.Dataset that has been materialized (e.g. 
via " + ".materialize()) or constructed from pyarrow/pandas." + ) + return schema + + +def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None: + names = set(_source_schema_or_raise(source_ds).names) + missing = [c for c in on if c not in names] + if missing: + raise ValueError( + f"'on' columns {missing} missing from source schema {list(names)}." + ) + + +def _validate_source_has_target_cols( + source_ds, + specs: List[_NormalizedClause], +) -> None: + names = set(_source_schema_or_raise(source_ds).names) + needed = set() + for clause in specs: + for val in clause.spec.values(): + if isinstance(val, SourceColumnRef): + needed.add(val.column) + missing = sorted(needed - names) + if missing: + raise ValueError( + f"source is missing columns {missing} referenced by SET spec" + ) diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py new file mode 100644 index 000000000000..f01f9b59aab8 --- /dev/null +++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py @@ -0,0 +1,388 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import pyarrow as pa + +from pypaimon.ray.data_evolution_merge_transform import ( + _NormalizedClause, + build_update_schema, + vectorized_insert_transform, + vectorized_matched_transform, +) + + +def _map_kwargs( + ray_remote_args: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + """Build kwargs for map_batches/map_groups; spread ray_remote_args because + those APIs take remote options as **kwargs, not under a 'ray_remote_args' + key.""" + kwargs: Dict[str, Any] = {"batch_format": "pyarrow"} + if ray_remote_args: + kwargs.update(ray_remote_args) + return kwargs + + +def build_matched_update_ds( + *, + target_identifier: str, + source_ds, + target_on: Sequence[str], + source_on: Sequence[str], + clauses: List[_NormalizedClause], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, + update_cols: Sequence[str], + catalog_options: Dict[str, str], + num_partitions: int, + resolve_target_projection, + snapshot_id: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +) -> Tuple: + from pypaimon.ray.ray_paimon import read_paimon + from pypaimon.table.special_fields import SpecialFields + + row_id_name = SpecialFields.ROW_ID.name + needed_cols = resolve_target_projection( + clauses, target_on, update_cols, target_field_names, + ) + projection = [row_id_name] + [c for c in needed_cols if c != row_id_name] + + target_ds = read_paimon( + target_identifier, catalog_options, + projection=projection, snapshot_id=snapshot_id, + ) + update_schema = build_update_schema(target_pa_schema, update_cols, row_id_name) + + target_renamed = target_ds.rename_columns( + {c: f"t.{c}" for c in target_ds.schema().names} + ) + source_cols = list(source_ds.schema().names) + source_renamed = source_ds.rename_columns( + {c: f"s.{c}" for c in source_cols} + ) + + joined = target_renamed.join( + source_renamed, + 
join_type="inner", + num_partitions=num_partitions, + on=tuple(f"t.{c}" for c in target_on), + right_on=tuple(f"s.{c}" for c in source_on), + ) + + # MVP supports a single matched clause; future fan-out (conditions, multi- + # clause fall-through) must thread every clause's spec through the + # transform — guard so silent first-only behaviour can't sneak in. + assert len(clauses) == 1, ( + f"build_matched_update_ds expected 1 clause, got {len(clauses)}" + ) + spec = clauses[0].spec + condition = clauses[0].condition + captured_update_cols = list(update_cols) + captured_row_id_name = row_id_name + captured_on_pairs = list(zip(source_on, target_on)) + captured_schema = update_schema + + captured_apply = None + captured_rewritten = None + if condition is not None: + from pypaimon.ray.merge_condition import ( + apply_condition, remap_source_on_keys, rewrite_condition, + ) + on_map = dict(zip(source_on, target_on)) + captured_rewritten = remap_source_on_keys( + rewrite_condition(condition), on_map, + ) + captured_apply = apply_condition + + def _transform(batch: pa.Table) -> pa.Table: + if captured_apply is not None: + batch = captured_apply( + batch, captured_rewritten, captured_schema, + ) + if batch.num_rows == 0: + return batch + return vectorized_matched_transform( + batch, spec, captured_on_pairs, + captured_update_cols, captured_row_id_name, + captured_schema, + ) + + return joined.map_batches(_transform, **_map_kwargs(ray_remote_args)) + + +def distributed_update_apply( + update_ds, + table, + write_update_cols: Sequence[str], + *, + num_partitions: int, + ray_remote_args: Optional[Dict[str, Any]] = None, + base_snapshot_id: Optional[int] = None, +) -> Tuple[list, int]: + import numpy as np + import pickle + import uuid + + import pyarrow.compute as pc + import ray + + from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER + from pypaimon.table.special_fields import SpecialFields + from pypaimon.write.table_update_by_row_id import TableUpdateByRowId + + 
row_id_name = SpecialFields.ROW_ID.name + cols = list(write_update_cols) + + for col in cols: + if col not in table.field_names: + raise ValueError( + f"Column '{col}' is not in target table schema." + ) + + planner = TableUpdateByRowId( + table, + "_merge_into_planner_" + uuid.uuid4().hex[:8], + BATCH_COMMIT_IDENTIFIER, + ) + sorted_first_row_ids = list(planner.first_row_ids) + if not sorted_first_row_ids: + return [], 0 + + # Pin commit-time conflict check to the snapshot the join was built on, + # so concurrent commits between read and planner are detected. + check_from_snapshot = ( + base_snapshot_id if base_snapshot_id is not None + else planner.snapshot_id + ) + + # Put file metadata into Ray's object store and pass a single ref to + # workers. Avoids per-task manifest re-scans (Jingsong review #6) and + # avoids serializing the metadata into every task's closure. Override + # snapshot_id with the join's base snapshot so commit-time conflict + # detection covers the read→planner window. + from dataclasses import replace + files_info = replace( + planner._snapshot_files_info(), + snapshot_id=check_from_snapshot, + ) + precomputed_info_ref = ray.put(files_info) + + frid_col = "_FIRST_ROW_ID" + captured_sorted = sorted_first_row_ids + captured_sorted_arr = np.asarray(captured_sorted, dtype=np.int64) + valid_ranges = planner.valid_row_id_ranges + range_starts = np.asarray([r.from_ for r in valid_ranges], dtype=np.int64) + range_ends = np.asarray([r.to for r in valid_ranges], dtype=np.int64) + + def _assign_frid(batch: pa.Table) -> pa.Table: + if batch.num_rows == 0: + return batch.append_column( + frid_col, pa.array([], type=pa.int64()) + ) + rid_col = batch.column(row_id_name) + if rid_col.null_count: + raise ValueError( + "_ROW_ID is null; planner snapshot is stale " + "or matched rows come from a different table." + ) + rids = rid_col.to_numpy(zero_copy_only=False) + # Check each row_id belongs to a valid range (vectorized). 
+ in_range = np.zeros(len(rids), dtype=bool) + for s, e in zip(range_starts, range_ends): + in_range |= (rids >= s) & (rids <= e) + if not in_range.all(): + bad = rids[~in_range][0] + raise ValueError( + f"_ROW_ID {bad} does not belong to any valid range " + f"{[f'[{r.from_}, {r.to}]' for r in valid_ranges]}; " + f"planner snapshot is stale or matched rows come " + f"from a different table." + ) + idx = np.searchsorted( + captured_sorted_arr, rids, side="right" + ) - 1 + frids = captured_sorted_arr[idx] + return batch.append_column( + frid_col, pa.array(frids, type=pa.int64()) + ) + + map_kwargs = _map_kwargs(ray_remote_args) + with_frid = update_ds.map_batches(_assign_frid, **map_kwargs) + + captured_table = table + captured_cols = cols + + def _apply_group(group: pa.Table) -> pa.Table: + if group.num_rows == 0: + return pa.Table.from_pydict({ + "msgs_blob": pa.array([], type=pa.binary()), + "n_updated": pa.array([], type=pa.int64()), + }) + + if ( + pc.count_distinct(group.column(row_id_name)).as_py() + != group.num_rows + ): + raise ValueError( + "MERGE matched multiple source rows to the same " + "target _ROW_ID. Deduplicate the source before " + "merging." + ) + + for_update = group.drop_columns([frid_col]) + worker = TableUpdateByRowId( + captured_table, + "_merge_into_shard_" + uuid.uuid4().hex[:8], + BATCH_COMMIT_IDENTIFIER, + _precomputed_files_info=ray.get(precomputed_info_ref), + ) + msgs = worker.update_columns(for_update, list(captured_cols)) + return pa.Table.from_pydict({ + "msgs_blob": [pickle.dumps(msgs)], + "n_updated": pa.array( + [for_update.num_rows], type=pa.int64() + ), + }) + + # One group per target data file; bounded by file count and num_partitions. 
+ group_partitions = max( + 1, min(len(captured_sorted), num_partitions) + ) + msgs_ds = with_frid.groupby( + frid_col, num_partitions=group_partitions + ).map_groups(_apply_group, **map_kwargs) + + all_msgs: list = [] + num_updated = 0 + for batch in msgs_ds.iter_batches(batch_format="pyarrow"): + for blob in batch.column("msgs_blob").to_pylist(): + all_msgs.extend(pickle.loads(blob)) + for n in batch.column("n_updated").to_pylist(): + num_updated += n + return all_msgs, num_updated + + +def build_not_matched_insert_ds( + *, + target_identifier: str, + source_ds, + target_on: Sequence[str], + source_on: Sequence[str], + clauses: List[_NormalizedClause], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, + catalog_options: Dict[str, str], + num_partitions: int, + target_empty: bool = False, + snapshot_id: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +): + from pypaimon.ray.ray_paimon import read_paimon + from pypaimon.ray.shuffle import _coerce_large_string_types + + captured_field_names = list(target_field_names) + out_schema = target_pa_schema + + source_cols = list(source_ds.schema().names) + source_renamed = source_ds.rename_columns( + {c: f"s.{c}" for c in source_cols} + ) + + if target_empty: + unmatched = source_renamed + else: + target_ds = read_paimon( + target_identifier, catalog_options, + projection=list(target_on), snapshot_id=snapshot_id, + ) + target_renamed = target_ds.rename_columns( + {c: f"t.{c}" for c in target_on} + ) + unmatched = source_renamed.join( + target_renamed, + join_type="left_anti", + num_partitions=num_partitions, + on=tuple(f"s.{c}" for c in source_on), + right_on=tuple(f"t.{c}" for c in target_on), + ) + + # MVP supports a single not-matched clause; see build_matched_update_ds + # for why we assert instead of silently dropping the rest. 
+ assert len(clauses) == 1, ( + f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}" + ) + spec = clauses[0].spec + condition = clauses[0].condition + captured_apply = None + captured_rewritten = None + if condition is not None: + from pypaimon.ray.merge_condition import apply_condition, rewrite_condition + captured_rewritten = rewrite_condition(condition) + captured_apply = apply_condition + + def _transform(batch: pa.Table) -> pa.Table: + if captured_apply is not None: + batch = captured_apply( + batch, captured_rewritten, out_schema, + ) + if batch.num_rows == 0: + return _coerce_large_string_types(batch) + return _coerce_large_string_types( + vectorized_insert_transform( + batch, spec, captured_field_names, out_schema + ) + ) + + return unmatched.map_batches( + _transform, **_map_kwargs(ray_remote_args) + ) + + +def distributed_write_collect_msgs( + insert_ds, + table, + *, + ray_remote_args: Optional[Dict[str, Any]], + concurrency: Optional[int], +) -> list: + from pypaimon.write.ray_datasink import PaimonDatasink + + class _CollectingDatasink(PaimonDatasink): + def __init__(self, t): + super().__init__(t, overwrite=False) + self.collected: list = [] + + def on_write_complete(self, write_result): + self.collected = [ + m + for batch in self._extract_write_returns(write_result) + for m in batch + if not m.is_empty() + ] + + sink = _CollectingDatasink(table) + write_kwargs: Dict[str, Any] = {} + if ray_remote_args is not None: + write_kwargs["ray_remote_args"] = ray_remote_args + if concurrency is not None: + write_kwargs["concurrency"] = concurrency + insert_ds.write_datasink(sink, **write_kwargs) + return sink.collected diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py new file mode 100644 index 000000000000..003977f3e7f2 --- /dev/null +++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py @@ -0,0 +1,150 @@ 
+################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from dataclasses import dataclass +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union + +import pyarrow as pa + +SetSpec = Union[str, Mapping[str, Any]] +OnSpec = Union[Sequence[str], Mapping[str, str]] + + +@dataclass(frozen=True) +class SourceColumnRef: + column: str + + +@dataclass(frozen=True) +class TargetColumnRef: + column: str + + +@dataclass(frozen=True) +class LiteralValue: + value: Any + + +def source_col(name: str) -> SourceColumnRef: + return SourceColumnRef(name) + + +def target_col(name: str) -> TargetColumnRef: + return TargetColumnRef(name) + + +def lit(value: Any) -> LiteralValue: + return LiteralValue(value) + + +@dataclass +class WhenMatched: + update: SetSpec + condition: Optional[str] = None + + +@dataclass +class WhenNotMatched: + insert: SetSpec + condition: Optional[str] = None + + +@dataclass +class _NormalizedClause: + spec: Dict[str, Any] + condition: Optional[str] = None + + +def vectorized_matched_transform( + batch: pa.Table, + spec: Dict[str, Any], + on_pairs: 
Sequence[Tuple[str, str]], + update_cols: Sequence[str], + row_id_name: str, + update_schema: pa.Schema, +) -> pa.Table: + available = set(batch.schema.names) + arrays: list = [batch.column(f"t.{row_id_name}")] + for col in update_cols: + out_type = update_schema.field(col).type + if col in spec: + arrays.append( + _resolve_spec_array( + spec[col], batch, available, on_pairs, out_type + ) + ) + else: + arrays.append(batch.column(f"t.{col}")) + return pa.Table.from_arrays(arrays, schema=update_schema) + + +def vectorized_insert_transform( + batch: pa.Table, + spec: Dict[str, Any], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, +) -> pa.Table: + available = set(batch.schema.names) + arrays: list = [] + for col in target_field_names: + out_type = target_pa_schema.field(col).type + if col in spec: + arrays.append( + _resolve_spec_array( + spec[col], batch, available, (), out_type + ) + ) + else: + arrays.append(pa.nulls(batch.num_rows, type=out_type)) + return pa.Table.from_arrays(arrays, schema=target_pa_schema) + + +def build_update_schema( + target_pa_schema: pa.Schema, + update_cols: Sequence[str], + row_id_name: str, +) -> pa.Schema: + return pa.schema( + [pa.field(row_id_name, pa.int64(), nullable=False)] + + [target_pa_schema.field(col) for col in update_cols] + ) + + +def _resolve_spec_array( + val: Any, + batch: pa.Table, + available: set, + on_pairs: Sequence[Tuple[str, str]], + out_type: pa.DataType, +): + if isinstance(val, LiteralValue): + return pa.array([val.value] * batch.num_rows, type=out_type) + if isinstance(val, SourceColumnRef): + ref = val.column + if f"s.{ref}" in available: + return batch.column(f"s.{ref}") + for sk, tk in on_pairs: + if sk == ref and f"t.{tk}" in available: + return batch.column(f"t.{tk}") + return pa.nulls(batch.num_rows, type=out_type) + if isinstance(val, TargetColumnRef): + col_name = f"t.{val.column}" + return batch.column(col_name) if col_name in available else pa.nulls( + batch.num_rows, 
type=out_type + ) + raise TypeError(f"unexpected spec value type: {type(val).__name__}") diff --git a/paimon-python/pypaimon/ray/merge_condition.py b/paimon-python/pypaimon/ray/merge_condition.py new file mode 100644 index 000000000000..5497406c5cd2 --- /dev/null +++ b/paimon-python/pypaimon/ray/merge_condition.py @@ -0,0 +1,104 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import re +from typing import Mapping, Set + +import pyarrow as pa + + +_COL_REF_PATTERN = re.compile(r'\b([st])\.(\w+)\b') + + +def _require_datafusion(): + try: + import datafusion + return datafusion + except ImportError: + raise ImportError( + "merge_into condition expressions require the 'datafusion' " + "package. 
Install it with: pip install pypaimon[sql]" + ) + + +_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'") + + +def _strip_string_literals(condition: str) -> str: + return _STRING_LITERAL.sub('', condition) + + +def rewrite_condition(condition: str) -> str: + parts, last = [], 0 + for m in _STRING_LITERAL.finditer(condition): + parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:m.start()])) + parts.append(m.group()) + last = m.end() + parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:])) + return ''.join(parts) + + +def remap_source_on_keys( + rewritten: str, on_map: Mapping[str, str], +) -> str: + for s_col, t_col in on_map.items(): + old, new = f'"s.{s_col}"', f'"t.{t_col}"' + parts, last = [], 0 + for m in _STRING_LITERAL.finditer(rewritten): + parts.append(rewritten[last:m.start()].replace(old, new)) + parts.append(m.group()) + last = m.end() + parts.append(rewritten[last:].replace(old, new)) + rewritten = ''.join(parts) + return rewritten + + +def filter_batch( + batch: pa.Table, condition: str, _pre_rewritten: bool = False, +) -> pa.Table: + if batch.num_rows == 0: + return batch + datafusion = _require_datafusion() + rewritten = condition if _pre_rewritten else rewrite_condition(condition) + ctx = datafusion.SessionContext() + ctx.register_record_batches("_batch", [batch.to_batches()]) + result = ctx.sql( + f'SELECT * FROM _batch WHERE {rewritten}' + ) + return result.to_arrow_table() + + +def apply_condition( + batch: pa.Table, rewritten: str, empty_schema: pa.Schema, +) -> pa.Table: + batch = filter_batch(batch, rewritten, _pre_rewritten=True) + if batch.num_rows == 0: + return empty_schema.empty_table() + return batch + + +def extract_columns(condition: str) -> Set[str]: + stripped = _strip_string_literals(condition) + return {f"{m.group(1)}.{m.group(2)}" + for m in _COL_REF_PATTERN.finditer(stripped)} + + +def extract_target_columns(condition: str) -> Set[str]: + stripped = _strip_string_literals(condition) + return {m.group(2) for m in 
_COL_REF_PATTERN.finditer(stripped) + if m.group(1) == "t"} diff --git a/paimon-python/pypaimon/ray/ray_paimon.py b/paimon-python/pypaimon/ray/ray_paimon.py index 86505097d8a1..0e796dd7c223 100644 --- a/paimon-python/pypaimon/ray/ray_paimon.py +++ b/paimon-python/pypaimon/ray/ray_paimon.py @@ -26,12 +26,26 @@ write_paimon(ds, "db.table", catalog_options={"warehouse": "/path"}) """ -from typing import Any, Dict, List, Optional - -import ray.data +import importlib +from typing import Any, Dict, List, Optional, TYPE_CHECKING from pypaimon.common.predicate import Predicate +if TYPE_CHECKING: + import ray.data + + +def _require_ray_data(): + try: + return importlib.import_module("ray.data") + except ModuleNotFoundError as e: + if e.name not in ("ray", "ray.data"): + raise + raise ImportError( + "PyPaimon Ray APIs require the 'ray' package. " + "Install it with: pip install pypaimon[ray]" + ) from e + def read_paimon( table_identifier: str, @@ -46,7 +60,7 @@ def read_paimon( concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, **read_args, -) -> ray.data.Dataset: +) -> "ray.data.Dataset": """Read a Paimon table into a Ray Dataset. Args: @@ -68,8 +82,11 @@ def read_paimon( Returns: A ``ray.data.Dataset`` containing the table data. 
""" + ray_data = _require_ray_data() + from pypaimon.read.datasource.ray_datasource import RayDatasource from pypaimon.read.datasource.split_provider import CatalogSplitProvider + from pypaimon.schema.data_types import PyarrowFieldParser if snapshot_id is not None and tag_name is not None: raise ValueError( @@ -81,18 +98,29 @@ def read_paimon( "override_num_blocks must be at least 1, got {}".format(override_num_blocks) ) - datasource = RayDatasource( - CatalogSplitProvider( - table_identifier=table_identifier, - catalog_options=catalog_options, - predicate=filter, - projection=projection, - limit=limit, - snapshot_id=snapshot_id, - tag_name=tag_name, - ) + split_provider = CatalogSplitProvider( + table_identifier=table_identifier, + catalog_options=catalog_options, + predicate=filter, + projection=projection, + limit=limit, + snapshot_id=snapshot_id, + tag_name=tag_name, ) - ds = ray.data.read_datasource( + + if not split_provider.splits(): + schema = PyarrowFieldParser.from_paimon_schema( + split_provider.read_type() + ) + import pyarrow + empty_table = pyarrow.Table.from_arrays( + [pyarrow.array([], type=field.type) for field in schema], + schema=schema, + ) + return ray_data.from_arrow(empty_table) + + datasource = RayDatasource(split_provider) + ds = ray_data.read_datasource( datasource, ray_remote_args=ray_remote_args, concurrency=concurrency, @@ -107,21 +135,24 @@ def read_paimon( def write_paimon( - dataset: ray.data.Dataset, + dataset: "ray.data.Dataset", table_identifier: str, catalog_options: Dict[str, str], *, overwrite: bool = False, concurrency: Optional[int] = None, ray_remote_args: Optional[Dict[str, Any]] = None, + hash_fixed_precluster: str = "auto", ) -> None: """Write a Ray Dataset to a Paimon table. - For HASH_FIXED tables, rows are automatically clustered by - ``(partition_keys..., bucket)`` before writing so that each - (partition, bucket) lands in a single Ray task. 
This avoids the - small-file storm that Ray's default round-robin distribution would - otherwise produce. No user configuration is required. + HASH_FIXED rows are assigned to the correct bucket by the Paimon + writer. Optional pre-clustering is only a file-count optimization. + The legacy ``map_groups`` pre-clustering mode materializes each + ``(partition_keys..., bucket)`` group on one Ray node and should + only be used when every group fits in memory. HASH_DYNAMIC and + CROSS_PARTITION primary-key Ray writes are rejected because Ray + write tasks create independent Paimon writers. Args: dataset: The Ray Dataset to write. @@ -130,7 +161,14 @@ def write_paimon( overwrite: If ``True``, overwrite existing data in the table. concurrency: Optional max number of Ray write tasks to run concurrently. ray_remote_args: Optional kwargs passed to ``ray.remote`` in write tasks. + hash_fixed_precluster: HASH_FIXED pre-clustering mode. ``"auto"`` + and ``"off"`` write append-only HASH_FIXED tables directly + and reject HASH_FIXED primary-key tables. ``"map_groups"`` + preserves the legacy small-file optimization and its single + group memory bound for HASH_FIXED primary-key tables. """ + _require_ray_data() + from pypaimon.catalog.catalog_factory import CatalogFactory from pypaimon.ray.shuffle import maybe_apply_repartition from pypaimon.write.ray_datasink import PaimonDatasink @@ -138,7 +176,7 @@ def write_paimon( catalog = CatalogFactory.create(catalog_options) table = catalog.get_table(table_identifier) - dataset = maybe_apply_repartition(dataset, table) + dataset = maybe_apply_repartition(dataset, table, hash_fixed_precluster) datasink = PaimonDatasink(table, overwrite=overwrite) diff --git a/paimon-python/pypaimon/ray/shuffle.py b/paimon-python/pypaimon/ray/shuffle.py index b17f7a7ab1c4..5079bf621582 100644 --- a/paimon-python/pypaimon/ray/shuffle.py +++ b/paimon-python/pypaimon/ray/shuffle.py @@ -16,23 +16,16 @@ # limitations under the License. 
################################################################################ -"""Pre-repartition a Ray Dataset by (partition, bucket) before writing -to a Paimon table. - -Without this, Ray's default round-robin block distribution scatters rows -that share the same (partition, bucket) across many Ray tasks. Each -task then opens its own writer and emits its own data file, producing -``partitions × buckets × ray_tasks`` files instead of the -``partitions × buckets`` the writer would naturally produce. - -For HASH_FIXED tables we group rows by ``(partition_keys..., bucket)`` -so every distinct group lands in a single Ray task. ``bucket`` is -computed using the same ``FixedBucketRowKeyExtractor`` the writer -uses, so the bucket assignment seen by the groupby is byte-equivalent -to the writer's. HASH_FIXED writes are always pre-clustered; no user -opt-in is required. - -For any other bucket mode the dataset is returned unchanged. +"""Optional pre-clustering and write guards for Ray writes. + +The legacy ``map_groups`` strategy groups rows by +``(partition_keys..., bucket)`` so every distinct group lands in a +single Ray task. This can reduce file count, but Ray requires each +``map_groups`` group to fit in memory on one node. Keep that strategy +behind an explicit opt-in. + +For append-only tables in any other bucket mode the dataset is returned +unchanged. """ import uuid @@ -51,6 +44,14 @@ # runtime by ``_pick_bucket_col_name`` so user tables that happen to # contain a column with this name still work correctly. 
BUCKET_KEY_COL = "__paimon_bucket__" +HASH_FIXED_PRECLUSTER_AUTO = "auto" +HASH_FIXED_PRECLUSTER_OFF = "off" +HASH_FIXED_PRECLUSTER_MAP_GROUPS = "map_groups" +HASH_FIXED_PRECLUSTER_MODES = frozenset([ + HASH_FIXED_PRECLUSTER_AUTO, + HASH_FIXED_PRECLUSTER_OFF, + HASH_FIXED_PRECLUSTER_MAP_GROUPS, +]) def _pick_bucket_col_name(existing_names) -> str: @@ -67,14 +68,53 @@ def _pick_bucket_col_name(existing_names) -> str: def maybe_apply_repartition( dataset: "ray.data.Dataset", table: "Table", + hash_fixed_precluster: str = HASH_FIXED_PRECLUSTER_AUTO, ) -> "ray.data.Dataset": - """Cluster rows by ``(partition_keys..., bucket)`` for HASH_FIXED tables. - - For any other bucket mode the dataset is returned unchanged. - HASH_FIXED writes are always pre-clustered, with no user opt-in - required. + """Optionally cluster rows for HASH_FIXED tables. + + ``auto`` currently behaves like ``off`` for append-only tables + because the old ``map_groups`` strategy materializes each + ``(partition, bucket)`` group on one Ray node. For primary-key + tables, unsafe Ray write plans are rejected because multiple Ray + tasks create independent Paimon writers and can assign overlapping + sequence numbers. """ - if table.bucket_mode() != BucketMode.HASH_FIXED: + if hash_fixed_precluster not in HASH_FIXED_PRECLUSTER_MODES: + raise ValueError( + "hash_fixed_precluster must be one of {}, got {!r}".format( + sorted(HASH_FIXED_PRECLUSTER_MODES), + hash_fixed_precluster, + ) + ) + + bucket_mode = table.bucket_mode() + is_primary_key_table = getattr(table, "is_primary_key_table", False) + + if bucket_mode != BucketMode.HASH_FIXED: + if is_primary_key_table and bucket_mode in ( + BucketMode.HASH_DYNAMIC, + BucketMode.CROSS_PARTITION, + ): + raise ValueError( + "{} primary-key Ray writes are not supported. 
Multiple " + "Ray tasks create independent Paimon writers, which can " + "assign overlapping buckets or sequence numbers.".format( + bucket_mode.name + ) + ) + return dataset + + if hash_fixed_precluster in ( + HASH_FIXED_PRECLUSTER_AUTO, + HASH_FIXED_PRECLUSTER_OFF, + ): + if is_primary_key_table: + raise ValueError( + "HASH_FIXED primary-key Ray writes require " + "hash_fixed_precluster='map_groups'. Direct writes can " + "create overlapping sequence numbers when multiple Ray " + "tasks write the same bucket." + ) return dataset partition_keys = list(table.table_schema.partition_keys or []) diff --git a/paimon-python/pypaimon/read/merge_engine_support.py b/paimon-python/pypaimon/read/merge_engine_support.py index d54cd8b0e4c5..cb5d34ab8721 100644 --- a/paimon-python/pypaimon/read/merge_engine_support.py +++ b/paimon-python/pypaimon/read/merge_engine_support.py @@ -45,23 +45,170 @@ "partial-update.remove-record-on-delete", "partial-update.remove-record-on-sequence-group", ) +# Boolean-valued options that, when truthy, opt the table into the +# retract / delete-removal behaviour the Python +# ``AggregateMergeFunction`` does not implement. +_AGGREGATION_UNSUPPORTED_BOOLEAN_OPTIONS = ( + "aggregation.remove-record-on-delete", +) +# Aggregator identifiers the ``aggregation`` engine knows how to +# build. Duplicated from the registration site in +# ``aggregate/aggregators.py`` so this guard has no import-time +# dependency on the read-pipeline modules; keep both sides in sync +# when adding new aggregators. +_AGGREGATION_SUPPORTED_AGG_FUNCS = frozenset([ + "primary_key", + "last_value", "last_non_null_value", + "first_value", "first_non_null_value", + "sum", "max", "min", + "bool_or", "bool_and", +]) _FIELDS_PREFIX = "fields." 
_FIELD_SEQUENCE_GROUP_SUFFIX = ".sequence-group" _FIELD_AGGREGATE_FUNCTION_SUFFIX = ".aggregate-function" +_FIELD_IGNORE_RETRACT_SUFFIX = ".ignore-retract" +_FIELD_NESTED_SEQUENCE_SUFFIX = ".nested-sequence-field" _DEFAULT_AGGREGATE_FUNCTION_KEY = "fields.default-aggregate-function" +def _nested_sequence_field_options(table) -> Set[str]: + """Option keys configuring ``nested-sequence-field`` (a per-field + nested sequence ordering distinct from the top-level + ``sequence.field``). pypaimon implements top-level ``sequence.field`` + but not nested sequence fields, so reject them on every PK engine + rather than silently ignoring them. + """ + flagged: Set[str] = set() + raw = table.options.options.to_map() + for key in raw: + if key.startswith(_FIELDS_PREFIX) and key.endswith( + _FIELD_NESTED_SEQUENCE_SUFFIX): + flagged.add(key) + return flagged + + +def _unsupported_sequence_fields(table) -> Set[str]: + """Configured ``sequence.field`` names whose type pypaimon cannot order. + Java's ``UserDefinedSeqComparator`` delegates to ``RecordComparator`` + and supports ARRAY / VECTOR / MAP / MULTISET / ROW, but pypaimon's + ``builtin_seq_comparator`` only compares orderable atomic types. This + flags both complex (non-atomic) types and the atomic-but-unorderable + VARIANT, so a raw-convertible split (which skips the merge reader) can't + silently bypass the limitation. + """ + from pypaimon.read.reader.sort_merge_reader import is_comparable_seq_field + flagged: Set[str] = set() + for field in table.options.sequence_field(): + data_field = table.field_dict.get(field) + if data_field is not None and not is_comparable_seq_field(data_field): + flagged.add(field) + return flagged + + +def check_sequence_field_valid(table) -> None: + """Reject ``sequence.field`` configurations Java forbids at schema + validation (``SchemaValidation.validateSequenceField``), raising + ``ValueError`` to mirror Java's ``IllegalArgumentException``. 
+ + These are invalid configurations, not deferred features, so they are + rejected on every merge engine regardless of pypaimon's read-path + coverage. Mirrors all of Java's checks: + + 1. Every sequence field must exist in the table schema. + 2. No sequence field may be declared more than once. + 3. ``fields..aggregate-function`` on a sequence column: Java + forbids aggregating the sequence column outright. pypaimon's + aggregation engine otherwise silently overrides it with + ``last_value``, hiding the misconfiguration. + 4. ``sequence.field`` together with ``merge-engine=first-row``: + first-row keeps the earliest-written row and never honors a + sequence ordering. + 5. ``sequence.field`` together with cross-partition update (the PK + does not include all partition fields). + """ + sequence_fields = table.options.sequence_field() + if not sequence_fields: + return + + field_names = set(table.field_names) + seen: Set[str] = set() + options_map = table.options.options.to_map() + for field in sequence_fields: + if field not in field_names: + raise ValueError( + "Sequence field: '{}' can not be found in table " + "schema.".format(field) + ) + if field in seen: + raise ValueError( + "Sequence field '{}' is defined repeatedly.".format(field) + ) + seen.add(field) + agg_key = "fields.{}.aggregate-function".format(field) + if options_map.get(agg_key) is not None: + raise ValueError( + "Should not define aggregation on sequence field: '{}' " + "({}).".format(field, agg_key) + ) + + if table.options.merge_engine() == MergeEngine.FIRST_ROW: + raise ValueError( + "Do not support use sequence.field on FIRST_ROW merge engine." 
+ ) + + if table.cross_partition_update: + raise ValueError( + "You can not use sequence.field in cross partition update case " + "(primary keys {} do not include all partition fields " + "{}).".format(table.primary_keys, table.partition_keys) + ) + + def check_supported(table) -> None: """Raise ``NotImplementedError`` if the table's merge-engine - configuration is outside what pypaimon's read path implements. + configuration is outside what pypaimon's read path implements, or + ``ValueError`` if it is an outright-invalid configuration that Java + rejects at schema validation. Non-PK tables are always fine (no merge function involved). """ if not table.is_primary_key_table: return + # ``nested-sequence-field`` is unimplemented on every engine; reject it + # before per-engine dispatch so it can't be silently ignored by the + # top-level ``sequence.field`` comparator. + nested_seq = _nested_sequence_field_options(table) + if nested_seq: + raise NotImplementedError( + "nested-sequence-field is not implemented in pypaimon yet: {}. " + "Top-level 'sequence.field' is supported; open an issue to track " + "nested sequence field support.".format(", ".join(sorted(nested_seq))) + ) + # ``sequence.field`` validity is engine-independent in Java + # (SchemaValidation.validateSequenceField). pypaimon has no + # schema-creation validation, so enforce the same invariants here on + # the read path, before per-engine dispatch. + check_sequence_field_valid(table) + # ``sequence.field`` validity (above) is Java-aligned and engine + # independent. Some field *types* are valid in Java but unimplemented in + # pypaimon's orderable-atomic-only comparator (complex types, plus the + # atomic-but-unorderable VARIANT), so reject them as NotImplementedError + # here -- before per-engine dispatch, so a raw-convertible split can't + # bypass the merge reader and skip the check. 
+ unsupported_seq = _unsupported_sequence_fields(table) + if unsupported_seq: + raise NotImplementedError( + "sequence.field with unsupported type is not implemented in " + "pypaimon yet: {}. pypaimon only supports orderable atomic " + "sequence-field types; complex types (ARRAY / MAP / ROW etc., " + "handled by Java via RecordComparator) and VARIANT are not " + "supported. Open an issue to track support.".format( + ", ".join(sorted(unsupported_seq)))) engine = table.options.merge_engine() if engine == MergeEngine.DEDUPLICATE: return + if engine == MergeEngine.FIRST_ROW: + return if engine == MergeEngine.PARTIAL_UPDATE: unsupported = partial_update_unsupported_options(table) if unsupported: @@ -71,16 +218,35 @@ def check_supported(table) -> None: "supported subset is per-key last-non-null merge with " "no sequence-group, no per-field aggregator override, " "no ignore-delete and no partial-update.remove-record-" - "on-* flags. Use the Java client for the full feature " - "set, or open an issue to track Python support.".format( + "on-* flags. These options are not yet supported; open " + "an issue to track support.".format( ", ".join(sorted(unsupported)) ) ) return + if engine == MergeEngine.AGGREGATE: + unsupported = aggregation_unsupported_options(table) + if unsupported: + raise NotImplementedError( + "merge-engine 'aggregation' is enabled together with " + "options that pypaimon does not yet implement: {}. The " + "supported subset is per-key field aggregation with the " + "built-in aggregators ({}); retract opt-ins " + "(aggregation.remove-record-on-delete, " + "fields..ignore-retract) " + "and other aggregators (product / listagg / collect / " + "merge_map* / nested_update* / theta_sketch / " + "hll_sketch / roaring_bitmap_*) are not yet supported. 
" + "Open an issue to track support.".format( + ", ".join(sorted(unsupported)), + ", ".join(sorted(_AGGREGATION_SUPPORTED_AGG_FUNCS)), + ) + ) + return raise NotImplementedError( "merge-engine '{}' is not implemented in pypaimon yet " - "(supported: deduplicate, partial-update). Use the Java " - "client or open an issue to track support.".format(engine.value) + "(supported: deduplicate, first-row, partial-update, aggregation). " + "Open an issue to track support.".format(engine.value) ) @@ -104,6 +270,46 @@ def partial_update_unsupported_options(table) -> Set[str]: return flagged +def aggregation_unsupported_options(table) -> Set[str]: + """Return the set of option keys configured on this table that the + ``AggregateMergeFunction`` does not yet support. Empty set means + the configuration is safe to run. + + Three families of options are rejected: + + 1. Retract opt-ins: ``aggregation.remove-record-on-delete`` and + ``fields..ignore-retract`` only make sense in conjunction + with DELETE / UPDATE_BEFORE handling, which the engine does not + implement. + 2. Sequence-group configuration: ``fields..sequence-group`` is not + supported (top-level ``sequence.field`` is honored, see + ``builtin_seq_comparator``). + 3. Out-of-scope aggregator selections: ``fields..aggregate- + function`` and ``fields.default-aggregate-function`` set to an + identifier this engine doesn't support yet (e.g. ``collect``, + ``nested_update``). 
+ """ + flagged: Set[str] = set() + raw = table.options.options.to_map() + for key, value in raw.items(): + if (key in _AGGREGATION_UNSUPPORTED_BOOLEAN_OPTIONS + and _option_is_truthy(value)): + flagged.add(key) + elif key == _DEFAULT_AGGREGATE_FUNCTION_KEY: + if value not in _AGGREGATION_SUPPORTED_AGG_FUNCS: + flagged.add(key) + elif key.startswith(_FIELDS_PREFIX): + if key.endswith(_FIELD_IGNORE_RETRACT_SUFFIX): + if _option_is_truthy(value): + flagged.add(key) + elif key.endswith(_FIELD_SEQUENCE_GROUP_SUFFIX): + flagged.add(key) + elif key.endswith(_FIELD_AGGREGATE_FUNCTION_SUFFIX): + if value not in _AGGREGATION_SUPPORTED_AGG_FUNCS: + flagged.add(key) + return flagged + + def _option_is_truthy(raw) -> bool: if raw is None: return False diff --git a/paimon-python/pypaimon/read/reader/aggregate/__init__.py b/paimon-python/pypaimon/read/reader/aggregate/__init__.py new file mode 100644 index 000000000000..8d5fe75d4816 --- /dev/null +++ b/paimon-python/pypaimon/read/reader/aggregate/__init__.py @@ -0,0 +1,84 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +"""FieldAggregator registry and factory entry point. + +Looks up the registered factory for an aggregator identifier (``"sum"``, +``"last_value"``, ...) read from table options and builds an instance +for it. Concrete aggregators register themselves at import time via +:func:`register_aggregator`; importing this package eagerly imports the +built-in aggregator module so the registrations always happen, +regardless of which call site triggers the first lookup. +""" + +from typing import Callable, Dict, TYPE_CHECKING + +from pypaimon.read.reader.aggregate.field_aggregator import FieldAggregator +from pypaimon.schema.data_types import DataType + +if TYPE_CHECKING: + from pypaimon.common.options.core_options import CoreOptions + + +# Module-global registry keyed by aggregator identifier +# (``"sum"``, ``"last_value"`` ...). +_FACTORIES: Dict[str, Callable[[DataType, str, "CoreOptions"], FieldAggregator]] = {} + + +def register_aggregator( + identifier: str, + factory: Callable[[DataType, str, "CoreOptions"], FieldAggregator], +) -> None: + """Register ``factory`` under ``identifier``. + + Re-registering an identifier replaces the existing factory. The + built-in aggregators register themselves at module-import time from + :mod:`aggregators`. + """ + _FACTORIES[identifier] = factory + + +def create_field_aggregator( + field_type: DataType, + field_name: str, + agg_func_name: str, + options: "CoreOptions", +) -> FieldAggregator: + """Build a ``FieldAggregator`` for ``agg_func_name``. + + Raises ``ValueError`` if the identifier was never registered, so + typos or out-of-scope aggregators surface at merge-function + construction time rather than at the first row. + """ + factory = _FACTORIES.get(agg_func_name) + if factory is None: + raise ValueError( + "Use unsupported aggregation '{}' or spell aggregate function " + "incorrectly! 
Supported aggregators in pypaimon: {}".format( + agg_func_name, sorted(_FACTORIES.keys()) + ) + ) + return factory(field_type, field_name, options) + + +# Eager-import the built-in aggregator module so its top-level +# ``register_aggregator(...)`` calls populate ``_FACTORIES`` before any +# caller looks anything up. Placed at the bottom of the module so the +# names ``register_aggregator`` / ``FieldAggregator`` that ``aggregators`` +# imports back from here are already defined when its import runs. +from pypaimon.read.reader.aggregate import aggregators # noqa: E402, F401 diff --git a/paimon-python/pypaimon/read/reader/aggregate/aggregators.py b/paimon-python/pypaimon/read/reader/aggregate/aggregators.py new file mode 100644 index 000000000000..961cd100912c --- /dev/null +++ b/paimon-python/pypaimon/read/reader/aggregate/aggregators.py @@ -0,0 +1,283 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Built-in :class:`FieldAggregator` implementations.
+ +Each class registers itself with the global registry at import time +via :func:`register_aggregator`, so importing +``pypaimon.read.reader.aggregate`` makes all of them discoverable. + +This module ships 10 aggregators — the primary-key placeholder plus +the 9 most commonly-used value aggregators: ``primary_key`` / +``last_value`` / ``last_non_null_value`` / ``first_value`` / +``first_non_null_value`` / ``sum`` / ``max`` / ``min`` / ``bool_or`` +/ ``bool_and``. Other aggregators (``product`` / ``listagg`` / +``collect`` / ``merge_map`` / ``nested_update`` / ``theta_sketch`` / +``hll_sketch`` / ``roaring_bitmap_*``) are intentionally deferred — +the registry will report them as unsupported so users see a clear +error rather than a silent fallback. +""" + +from typing import Any + +from pypaimon.read.reader.aggregate import register_aggregator +from pypaimon.read.reader.aggregate.field_aggregator import FieldAggregator +from pypaimon.schema.data_types import AtomicType, DataType + + +# Aggregator identifiers exposed via ``fields.<field-name>.aggregate-function`` +# and ``fields.default-aggregate-function``. +NAME_PRIMARY_KEY = "primary_key" +NAME_LAST_VALUE = "last_value" +NAME_LAST_NON_NULL_VALUE = "last_non_null_value" +NAME_FIRST_VALUE = "first_value" +NAME_FIRST_NON_NULL_VALUE = "first_non_null_value" +NAME_SUM = "sum" +NAME_MAX = "max" +NAME_MIN = "min" +NAME_BOOL_OR = "bool_or" +NAME_BOOL_AND = "bool_and" + + +# Base SQL type names treated as numeric for sum/product-style +# aggregators. NUMERIC / DEC are SQL synonyms accepted by the parser; +# treat them the same as DECIMAL. +_NUMERIC_BASE_TYPES = frozenset([ + "TINYINT", "SMALLINT", "INT", "INTEGER", "BIGINT", + "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC", "DEC", +]) + + +def _atomic_base_name(field_type: DataType): + """Extract the bare SQL type name from an :class:`AtomicType`, + stripping precision arguments (``DECIMAL(10,2)``) and trailing + ``NOT NULL``.
Returns ``None`` for non-atomic types so callers can + raise a uniform "unsupported type" error. + """ + if not isinstance(field_type, AtomicType): + return None + raw = field_type.type + head = raw.split('(', 1)[0].split(' ', 1)[0] + return head.upper() + + +def _check_numeric(name: str, field_type: DataType) -> None: + base = _atomic_base_name(field_type) + if base not in _NUMERIC_BASE_TYPES: + raise ValueError( + "Data type for '{}' column must be a numeric type but was " + "'{}'.".format(name, field_type) + ) + + +def _check_boolean(name: str, field_type: DataType) -> None: + base = _atomic_base_name(field_type) + if base != "BOOLEAN": + raise ValueError( + "Data type for '{}' column must be 'BOOLEAN' but was " + "'{}'.".format(name, field_type) + ) + + +# --------------------------------------------------------------------------- +# Aggregator classes +# --------------------------------------------------------------------------- + + +class FieldPrimaryKeyAgg(FieldAggregator): + """Carries the primary-key column through merge unchanged.""" + + def agg(self, accumulator: Any, input_field: Any) -> Any: + return input_field + + +class FieldLastValueAgg(FieldAggregator): + """Latest value wins, including ``None``.""" + + def agg(self, accumulator: Any, input_field: Any) -> Any: + return input_field + + +class FieldLastNonNullValueAgg(FieldAggregator): + """Latest non-null value; ``None`` inputs are absorbed. + + This is the system-wide default aggregator when no per-field + override and no ``fields.default-aggregate-function`` are set. + """ + + def agg(self, accumulator: Any, input_field: Any) -> Any: + return accumulator if input_field is None else input_field + + +class FieldFirstValueAgg(FieldAggregator): + """First value (including ``None``) wins; locks after the first + :meth:`agg` call until the next :meth:`reset`. 
+ """ + + def __init__(self, name: str, field_type: DataType): + super().__init__(name, field_type) + self._initialized = False + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if not self._initialized: + self._initialized = True + return input_field + return accumulator + + def reset(self) -> None: + self._initialized = False + + +class FieldFirstNonNullValueAgg(FieldAggregator): + """First non-null value; locks after the first non-null + :meth:`agg` call until the next :meth:`reset`. + """ + + def __init__(self, name: str, field_type: DataType): + super().__init__(name, field_type) + self._initialized = False + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if not self._initialized and input_field is not None: + self._initialized = True + return input_field + return accumulator + + def reset(self) -> None: + self._initialized = False + + +class FieldSumAgg(FieldAggregator): + """Numeric sum. ``None`` on either side returns the non-null + operand. Python's native ``+`` works uniformly for int / float / + Decimal — the values produced by the pyarrow read path already + arrive as the right Python primitive for the column's SQL type, so + no per-type branching is needed. + """ + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if accumulator is None or input_field is None: + return accumulator if input_field is None else input_field + return accumulator + input_field + + +class FieldMaxAgg(FieldAggregator): + """Maximum value. ``None`` on either side returns the non-null + operand. Uses Python's native ``<`` so any orderable type + (numeric, string, date, datetime, Decimal) works. + """ + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if accumulator is None or input_field is None: + return accumulator if input_field is None else input_field + return input_field if accumulator < input_field else accumulator + + +class FieldMinAgg(FieldAggregator): + """Minimum value. 
``None`` on either side returns the non-null + operand. Uses Python's native ``<``. + """ + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if accumulator is None or input_field is None: + return accumulator if input_field is None else input_field + return accumulator if accumulator < input_field else input_field + + +class FieldBoolOrAgg(FieldAggregator): + """Logical OR. ``None`` on either side returns the non-null + operand. + """ + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if accumulator is None or input_field is None: + return accumulator if input_field is None else input_field + return bool(accumulator) or bool(input_field) + + +class FieldBoolAndAgg(FieldAggregator): + """Logical AND. ``None`` on either side returns the non-null + operand. + """ + + def agg(self, accumulator: Any, input_field: Any) -> Any: + if accumulator is None or input_field is None: + return accumulator if input_field is None else input_field + return bool(accumulator) and bool(input_field) + + +# --------------------------------------------------------------------------- +# Registration. Each builder binds an identifier to a factory that +# optionally validates the column DataType before constructing the +# aggregator instance. +# --------------------------------------------------------------------------- + + +def _build_no_type_check(cls, identifier: str): + """Build a factory that accepts any DataType. Used by + ``primary_key`` / ``last_value`` / ``first_value`` variants and by + ``max`` / ``min``, all of which work on any orderable DataType. 
+ """ + def _factory(field_type, field_name, options): + return cls(identifier, field_type) + return _factory + + +def _build_numeric(cls, identifier: str): + def _factory(field_type, field_name, options): + _check_numeric(identifier, field_type) + return cls(identifier, field_type) + return _factory + + +def _build_boolean(cls, identifier: str): + def _factory(field_type, field_name, options): + _check_boolean(identifier, field_type) + return cls(identifier, field_type) + return _factory + + +register_aggregator( + NAME_PRIMARY_KEY, + _build_no_type_check(FieldPrimaryKeyAgg, NAME_PRIMARY_KEY), +) +register_aggregator( + NAME_LAST_VALUE, + _build_no_type_check(FieldLastValueAgg, NAME_LAST_VALUE), +) +register_aggregator( + NAME_LAST_NON_NULL_VALUE, + _build_no_type_check(FieldLastNonNullValueAgg, NAME_LAST_NON_NULL_VALUE), +) +register_aggregator( + NAME_FIRST_VALUE, + _build_no_type_check(FieldFirstValueAgg, NAME_FIRST_VALUE), +) +register_aggregator( + NAME_FIRST_NON_NULL_VALUE, + _build_no_type_check(FieldFirstNonNullValueAgg, NAME_FIRST_NON_NULL_VALUE), +) +register_aggregator(NAME_SUM, _build_numeric(FieldSumAgg, NAME_SUM)) +register_aggregator(NAME_MAX, _build_no_type_check(FieldMaxAgg, NAME_MAX)) +register_aggregator(NAME_MIN, _build_no_type_check(FieldMinAgg, NAME_MIN)) +register_aggregator( + NAME_BOOL_OR, _build_boolean(FieldBoolOrAgg, NAME_BOOL_OR) +) +register_aggregator( + NAME_BOOL_AND, _build_boolean(FieldBoolAndAgg, NAME_BOOL_AND) +) diff --git a/paimon-python/pypaimon/read/reader/aggregate/field_aggregator.py b/paimon-python/pypaimon/read/reader/aggregate/field_aggregator.py new file mode 100644 index 000000000000..b306a4c3f75e --- /dev/null +++ b/paimon-python/pypaimon/read/reader/aggregate/field_aggregator.py @@ -0,0 +1,81 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Per-field aggregator abstraction used by the ``aggregation`` merge +engine. + +Each non-PK field is reduced across rows sharing the same primary key +by one ``FieldAggregator`` instance picked per field. The +``AggregateMergeFunction`` drives the lifecycle: ``reset()`` at the +start of each key group, ``agg()`` per input value, and the final +accumulator is read out to build the merged row. +""" + +from abc import ABC, abstractmethod +from typing import Any + +from pypaimon.schema.data_types import DataType + + +class FieldAggregator(ABC): + """Per-field aggregator base class. + + Concrete subclasses implement :meth:`agg` and may override + :meth:`reset`. :meth:`retract` is intentionally left as the default + "refuse" implementation: pypaimon's ``AggregateMergeFunction`` + rejects ``DELETE`` / ``UPDATE_BEFORE`` rows up-front, so no + aggregator's ``retract`` is reachable from the read path. The hook + is kept so a future PR can add retract semantics without changing + every subclass. 
+ """ + + def __init__(self, name: str, field_type: DataType): + self.name = name + self.field_type = field_type + + @abstractmethod + def agg(self, accumulator: Any, input_field: Any) -> Any: + """Combine ``accumulator`` with ``input_field`` and return the + new accumulator. Called once per row in the key group, in + arrival order (sequence-number ascending). ``accumulator`` is + ``None`` before the first add. + """ + + def reset(self) -> None: + """Reset internal state at the start of a new key group. + + Default is a no-op. Aggregators that carry per-group bookkeeping + beyond the externally-passed accumulator (e.g. ``first_value``'s + "have we seen any row yet?" flag) must override this. + """ + + def retract(self, accumulator: Any, retract_field: Any) -> Any: + """Refuse the retract operation by default. + + ``AggregateMergeFunction`` rejects retract rows at :meth:`add` + time, so this path is currently unreachable from the read + pipeline. The hook is kept for forward-compatibility: a future + PR that wires retract through the merge function can override + this on the aggregators that actually support it (sum, product, + last_value, ...). + """ + raise NotImplementedError( + "Aggregator '{}' does not support retract; the aggregation " + "merge engine does not implement DELETE / UPDATE_BEFORE " + "handling.".format(self.name) + ) diff --git a/paimon-python/pypaimon/read/reader/aggregation_merge_function.py b/paimon-python/pypaimon/read/reader/aggregation_merge_function.py new file mode 100644 index 000000000000..ae666734032a --- /dev/null +++ b/paimon-python/pypaimon/read/reader/aggregation_merge_function.py @@ -0,0 +1,203 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Merge function for the ``aggregation`` merge engine. + +Rows sharing a primary key are folded across each non-PK field by the +per-field :class:`FieldAggregator` configured in table options. +``DeduplicateMergeFunction`` keeps only the latest row; +``PartialUpdateMergeFunction`` lets later writes "fill in" fields the +earlier writes left null; ``AggregateMergeFunction`` runs an actual +aggregation (sum / max / min / last_value / ...) per column. + +This is the **core merge semantics only**. Retract on DELETE / +UPDATE_BEFORE rows (with ``aggregation.remove-record-on-delete`` and +``fields.<field-name>.ignore-retract`` opt-ins) and ~14 additional +aggregators (``product`` / ``listagg`` / ``collect`` / ``merge_map`` / +``nested_update`` / ``theta_sketch`` / ``hll_sketch`` / +``roaring_bitmap_*``) are intentionally deferred. Non-INSERT row +kinds raise ``NotImplementedError`` at :meth:`add` time so we never +silently corrupt data with a half-implemented contract, and +out-of-scope aggregator identifiers / options are rejected up-front in +:mod:`pypaimon.read.merge_engine_support`.
+""" + +from typing import Any, List, Optional + +from pypaimon.read.reader.aggregate import create_field_aggregator +from pypaimon.read.reader.aggregate.aggregators import ( + NAME_LAST_NON_NULL_VALUE, + NAME_LAST_VALUE, + NAME_PRIMARY_KEY, +) +from pypaimon.read.reader.aggregate.field_aggregator import FieldAggregator +from pypaimon.schema.data_types import DataField +from pypaimon.table.row.key_value import KeyValue +from pypaimon.table.row.row_kind import RowKind + + +# --------------------------------------------------------------------------- +# Aggregator-list construction helpers. Live in this module (rather than +# in split_read.py) so they can be exercised directly by unit tests. +# --------------------------------------------------------------------------- + + +def resolve_agg_func_name(field_name, primary_keys, options_map, + sequence_fields=()): + """Pick the aggregator identifier for ``field_name`` using the same + precedence as Java ``AggregateMergeFunction.getAggFuncName``: + + 1. Sequence fields use ``last_value`` (no aggregation -- the + sequence column just carries the latest-by-sequence value). + 2. Primary-key columns use ``primary_key`` (identity). + 3. Otherwise, field-level ``fields.<field-name>.aggregate-function`` + overrides everything. + 4. Otherwise, the table-wide ``fields.default-aggregate-function``. + 5. Otherwise, the system default ``last_non_null_value``. + + Sequence fields take precedence over the table-wide + ``fields.default-aggregate-function``, matching Java: the value of a + ``sequence.field`` column must not be aggregated. An *explicit* + ``fields.<field-name>.aggregate-function`` on a sequence column is rejected + up-front (see ``merge_engine_support.check_sequence_field_valid``), so + it never reaches this precedence.
+ """ + if field_name in sequence_fields: + return NAME_LAST_VALUE + if field_name in primary_keys: + return NAME_PRIMARY_KEY + return ( + options_map.get("fields.{}.aggregate-function".format(field_name)) + or options_map.get("fields.default-aggregate-function") + or NAME_LAST_NON_NULL_VALUE + ) + + +def build_field_aggregators( + value_fields: List[DataField], + primary_keys: List[str], + core_options, +) -> List[FieldAggregator]: + """Build the per-column aggregator list parallel to ``value_fields``. + + Resolves the identifier for each field via :func:`resolve_agg_func_name` + and instantiates the aggregator through the registry. Type validation + for aggregators that care (``sum`` requires numeric, ``bool_or`` / + ``bool_and`` require boolean) runs inside the registered factory, so + misconfigured tables fail here rather than at first row. + """ + options_map = core_options.options.to_map() + pk_set = set(primary_keys) + sequence_fields = set(core_options.sequence_field()) + aggregators = [] + for field in value_fields: + agg_name = resolve_agg_func_name( + field.name, pk_set, options_map, sequence_fields) + aggregators.append( + create_field_aggregator( + field.type, field.name, agg_name, core_options + ) + ) + return aggregators + + +# --------------------------------------------------------------------------- +# Merge function +# --------------------------------------------------------------------------- + + +class AggregateMergeFunction: + """A MergeFunction where the key is the primary key (unique) and + each non-PK column is reduced across the rows for that key by its + configured :class:`FieldAggregator`. + + Follows the same ``MergeFunction`` protocol used by + :class:`SortMergeReaderWithMinHeap`: :meth:`reset` between groups + of same-key rows, :meth:`add` one row at a time (oldest to + newest), :meth:`get_result` after the group is exhausted. 
+ """ + + def __init__(self, + key_arity: int, + value_arity: int, + field_aggregators: List[FieldAggregator]): + if len(field_aggregators) != value_arity: + raise ValueError( + "field_aggregators length {} does not match value_arity " + "{}".format(len(field_aggregators), value_arity) + ) + self._key_arity = key_arity + self._value_arity = value_arity + self._field_aggregators = field_aggregators + # Parallel to value indices. Reset at the start of every key + # group; updated in-place as ``add()`` calls feed rows in. + self._accumulators: List[Any] = [None] * value_arity + # Reference to the most recently added kv. Used only to + # propagate the key + sequence_number into the result row; we + # snapshot those values into a fresh tuple in ``get_result()`` + # so the result is not aliased to upstream's reused KeyValue. + self._latest_kv: Optional[KeyValue] = None + + def reset(self) -> None: + self._accumulators = [None] * self._value_arity + for agg in self._field_aggregators: + agg.reset() + self._latest_kv = None + + def add(self, kv: KeyValue) -> None: + row_kind_byte = kv.value_row_kind_byte + if not RowKind.is_add_byte(row_kind_byte): + # DELETE / UPDATE_BEFORE rows are not supported by this + # merge engine. Refuse them rather than silently swallow + # rows, which would let aggregations diverge from the + # underlying data. + raise NotImplementedError( + "AggregateMergeFunction received a {} row; the " + "aggregation merge engine does not yet implement " + "retract (DELETE / UPDATE_BEFORE) handling. Tables " + "producing such rows are not yet supported." 
+ .format(RowKind(row_kind_byte).to_string()) + ) + + for i, agg in enumerate(self._field_aggregators): + input_val = kv.value.get_field(i) + self._accumulators[i] = agg.agg(self._accumulators[i], input_val) + self._latest_kv = kv + + def get_result(self) -> Optional[KeyValue]: + if self._latest_kv is None: + return None + + kv = self._latest_kv + # Snapshot the key as a fresh tuple — we cannot keep a reference + # to ``kv`` because upstream readers (e.g. KeyValueWrapReader) + # reuse a single KeyValue instance and mutate its underlying + # row_tuple between calls. Building a fresh tuple here means + # the result we return is decoupled from any subsequent + # iteration. + key_values = tuple( + kv.key.get_field(i) for i in range(self._key_arity) + ) + result_row = key_values + ( + kv.sequence_number, + RowKind.INSERT.value, + ) + tuple(self._accumulators) + + result = KeyValue(self._key_arity, self._value_arity) + result.replace(result_row) + return result diff --git a/paimon-python/pypaimon/read/reader/blob_descriptor_convert_reader.py b/paimon-python/pypaimon/read/reader/blob_descriptor_convert_reader.py index 35fe046a03ce..12e975bae5ee 100644 --- a/paimon-python/pypaimon/read/reader/blob_descriptor_convert_reader.py +++ b/paimon-python/pypaimon/read/reader/blob_descriptor_convert_reader.py @@ -15,68 +15,189 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional +from typing import Callable, Optional, Set +import pyarrow from pyarrow import RecordBatch from pypaimon.common.options.core_options import CoreOptions from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader +from pypaimon.table.row.blob import Blob, BlobViewStruct -class BlobDescriptorConvertReader(RecordBatchReader): - def __init__(self, inner: RecordBatchReader, table): +class BlobInlineConvertReader(RecordBatchReader): + """Resolves BlobView and BlobDescriptor fields in record batches. 
+ + Processing is split into two clear stages: + Stage 1 (BlobView resolution): If view fields exist, use a lightweight + prescan reader (only projecting view columns) to collect + BlobViewStructs, bulk-preload their descriptors, then read + full data from the main reader and replace view field values + with the corresponding BlobDescriptor serialized bytes. + Stage 2 (BlobData resolution): Controlled by blob-as-descriptor option. + If false, resolve all BlobDescriptor bytes (from both descriptor + fields and view fields) into real blob data bytes. + If true, return as-is. + """ + + def __init__(self, inner: RecordBatchReader, table, + prescan_reader_factory: Optional[Callable[[Set[str]], RecordBatchReader]] = None): + """ + Args: + inner: The main data reader (reads all columns). + table: The table instance. + prescan_reader_factory: Optional factory that creates a lightweight + reader projecting only the specified field names. Used for + prescan to collect BlobViewStructs without reading all columns. + Signature: (field_names: Set[str]) -> RecordBatchReader + """ self._inner = inner self._table = table - self._descriptor_fields = CoreOptions.blob_descriptor_fields(table.options) + self._prescan_reader_factory = prescan_reader_factory self.file_io = inner.file_io self.blob_field_indices = inner.blob_field_indices + # Preserve original BlobViewStruct bytes when resolve disabled: skip both + # view resolution (Stage 1) and descriptor-to-data resolution (Stage 2). 
+ resolve_enabled = CoreOptions.blob_view_resolve_enabled( + table.options) and self._table.catalog_environment.catalog_loader is not None + self._view_fields = CoreOptions.blob_view_fields(table.options) if resolve_enabled else set() + self._descriptor_fields = CoreOptions.blob_descriptor_fields(table.options) + self._blob_as_descriptor = CoreOptions.blob_as_descriptor(table.options) + self._prescan_done = False + self._blob_view_lookup = None def read_arrow_batch(self) -> Optional[RecordBatch]: - import pyarrow + # Align with Java: only enter blob view resolution when catalog_loader is available + # If catalog_loader is None, skip both Stage 1 (view resolution) and Stage 2 (descriptor resolution) + # This matches Java's behavior in DataEvolutionTableRead.createReader where blob view reader + # is only created when catalogContext != null + if self._view_fields and not self._prescan_done: + self._prescan_view_structs() + batch = self._inner.read_arrow_batch() if batch is None: return None - return self._convert_batch(batch, pyarrow) + # Resolve view fields using the preloaded lookup + if self._view_fields and self._blob_view_lookup is not None: + batch = self._resolve_view_fields(batch, self._blob_view_lookup) + # Resolve BlobDescriptor -> real bytes (if blob-as-descriptor=false) + return self._resolve_descriptor_fields(batch) + + # ------------------------------------------------------------------ + # Stage 1: BlobView prescan (lightweight, only reads view columns) + # ------------------------------------------------------------------ + + def _prescan_view_structs(self): + """Use a lightweight prescan reader (projecting only view columns) to + collect all BlobViewStructs and bulk-preload their descriptors.""" + from pypaimon.table.row.blob import BlobViewStruct + from pypaimon.utils.blob_view_lookup import BlobViewLookup - def _convert_batch(self, batch, pyarrow): - from pypaimon.table.row.blob import Blob, BlobDescriptor + all_view_structs = [] - result = batch - 
for field_name in self._descriptor_fields: - if field_name not in result.schema.names: + prescan_reader = self._prescan_reader_factory(self._view_fields) + try: + while True: + batch = prescan_reader.read_arrow_batch() + if batch is None: + break + for field_name in self._view_fields: + if field_name not in batch.schema.names: + continue + for value in batch.column(field_name).to_pylist(): + value = self._normalize_blob_to_bytes(value) + if value is None: + continue + if isinstance(value, bytes) and BlobViewStruct.is_blob_view_struct(value): + all_view_structs.append(BlobViewStruct.deserialize(value)) + else: + raise ValueError( + f"Expected BlobViewStruct bytes in view field '{field_name}', " + f"but got non-BlobViewStruct bytes." + ) + finally: + prescan_reader.close() + + # Bulk-preload BlobViewStruct -> BlobDescriptor mapping + if all_view_structs: + self._blob_view_lookup = BlobViewLookup(self._table) + self._blob_view_lookup.preload(all_view_structs) + self._prescan_done = True + + def _resolve_view_fields(self, batch, blob_view_lookup): + """Replace BlobViewStruct bytes in view fields with the corresponding + BlobDescriptor serialized bytes.""" + for field_name in self._view_fields: + if field_name not in batch.schema.names: continue - values = result.column(field_name).to_pylist() + values = [self._normalize_blob_to_bytes(v) for v in batch.column(field_name).to_pylist()] converted_values = [] for value in values: if value is None: converted_values.append(None) continue - if hasattr(value, 'as_py'): - value = value.as_py() - if isinstance(value, str): - value = value.encode('utf-8') - if isinstance(value, bytearray): - value = bytes(value) if not isinstance(value, bytes): converted_values.append(value) continue - try: - descriptor = BlobDescriptor.deserialize(value) - if descriptor.serialize() != value: - converted_values.append(value) - continue - uri_reader = self._table.file_io.uri_reader_factory.create(descriptor.uri) - 
converted_values.append(Blob.from_descriptor(uri_reader, descriptor).to_data()) - except Exception: + if not BlobViewStruct.is_blob_view_struct(value): converted_values.append(value) + continue + view_struct = BlobViewStruct.deserialize(value) + if blob_view_lookup.resolve_to_null(view_struct): + converted_values.append(None) + else: + descriptor = blob_view_lookup.resolve_descriptor(view_struct) + converted_values.append(descriptor.serialize()) - column_idx = result.schema.names.index(field_name) - result = result.set_column( + column_idx = batch.schema.names.index(field_name) + batch = batch.set_column( column_idx, pyarrow.field(field_name, pyarrow.large_binary(), nullable=True), pyarrow.array(converted_values, type=pyarrow.large_binary()), ) - return result + return batch + + # ------------------------------------------------------------------ + # Stage 2: BlobData resolution (unified exit) + # ------------------------------------------------------------------ + + def _resolve_descriptor_fields(self, batch): + if self._blob_as_descriptor: + return batch + + all_inline_blob_fields = self._descriptor_fields | self._view_fields + for field_name in all_inline_blob_fields: + if field_name not in batch.schema.names: + continue + values = [self._normalize_blob_to_bytes(v) for v in batch.column(field_name).to_pylist()] + converted_values = [] + for value in values: + blob = Blob.from_bytes(value, self._table.file_io) + converted_values.append(blob.to_data() if blob else None) + + column_idx = batch.schema.names.index(field_name) + batch = batch.set_column( + column_idx, + pyarrow.field(field_name, pyarrow.large_binary(), nullable=True), + pyarrow.array(converted_values, type=pyarrow.large_binary()), + ) + return batch + + # ------------------------------------------------------------------ + # Utilities + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_blob_to_bytes(value): + if value is None: + return None + if 
hasattr(value, 'as_py'): + value = value.as_py() + if isinstance(value, str): + value = value.encode('utf-8') + if isinstance(value, bytearray): + value = bytes(value) + return value def close(self): self._inner.close() diff --git a/paimon-python/pypaimon/read/reader/concat_batch_reader.py b/paimon-python/pypaimon/read/reader/concat_batch_reader.py index 378868716202..67d1a7c40ace 100644 --- a/paimon-python/pypaimon/read/reader/concat_batch_reader.py +++ b/paimon-python/pypaimon/read/reader/concat_batch_reader.py @@ -16,13 +16,17 @@ # under the License. import collections -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import pyarrow as pa import pyarrow.dataset as ds from pyarrow import RecordBatch +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.read.reader.format_blob_reader import BlobRecordIterator from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader +from pypaimon.table.row.blob import Blob +from pypaimon.utils.range import Range _MIN_BATCH_SIZE_TO_REFILL = 1024 @@ -229,3 +233,127 @@ def close(self) -> None: reader.close() except Exception as e: raise IOError("Failed to close inner readers") from e + + +class BlobFallbackBatchReader(RecordBatchReader): + """Resolve blob placeholders by falling back through older blob versions.""" + + def __init__(self, file_reader_suppliers: List[Tuple[DataFileMeta, Callable]], + field_name: str, output_type, row_ranges: Optional[List[Range]] = None, + blob_as_descriptor: bool = False): + self._file_reader_suppliers = file_reader_suppliers + self._field_name = field_name + self._output_type = output_type + self._row_ranges = Range.sort_and_merge_overlap(row_ranges) if row_ranges else None + self._blob_as_descriptor = blob_as_descriptor + self._returned = False + self._readers: List[RecordBatchReader] = [] + + def read_arrow_batch(self) -> Optional[RecordBatch]: + if self._returned: + return None + self._returned = 
True + + groups: Dict[int, Dict[int, Tuple[object, bool]]] = {} + target_row_ids = self._target_row_ids() + + for file, supplier in self._file_reader_suppliers: + row_ids = self._selected_row_ids(file) + blob_values = self._read_blob_values(file, supplier) + if len(blob_values) != len(row_ids): + raise ValueError( + "Blob fallback reader returned an unexpected row count " + f"for {file.file_name}: expect {len(row_ids)}, got {len(blob_values)}." + ) + if not row_ids: + continue + group = groups.setdefault(file.max_sequence_number, {}) + for row_id, blob in zip(row_ids, blob_values): + if row_id in group: + raise ValueError( + "Blob files within the same max sequence should not overlap." + ) + if blob is None: + group[row_id] = (None, False) + elif blob is Blob.PLACE_HOLDER: + group[row_id] = (None, True) + else: + if self._blob_as_descriptor: + group[row_id] = (blob.to_descriptor().serialize(), False) + else: + group[row_id] = (blob.to_data(), False) + + if not groups: + return None + + result = [] + for row_id in target_row_ids: + found = False + for max_sequence_number in sorted(groups.keys(), reverse=True): + candidate = groups[max_sequence_number].get(row_id) + if candidate is None: + continue + value, is_placeholder = candidate + if not is_placeholder: + result.append(value) + found = True + break + if not found: + raise ValueError("All blob files at the same row id store a placeholder.") + + return pa.RecordBatch.from_arrays( + [pa.array(result, type=self._output_type)], + names=[self._field_name], + ) + + def _target_row_ids(self) -> List[int]: + file_ranges = [ + file.row_id_range() + for file, _ in self._file_reader_suppliers + ] + ranges = [ + Range( + min(row_range.from_ for row_range in file_ranges), + max(row_range.to for row_range in file_ranges), + ) + ] + if self._row_ranges is not None: + ranges = Range.and_(ranges, self._row_ranges) + return self._expand_ranges(ranges) + + def _selected_row_ids(self, file: DataFileMeta) -> List[int]: + ranges = 
[file.row_id_range()] + if self._row_ranges is not None: + ranges = Range.and_(ranges, self._row_ranges) + return self._expand_ranges(ranges) + + @staticmethod + def _expand_ranges(ranges: List[Range]) -> List[int]: + return [ + row_id + for row_range in ranges + for row_id in range(row_range.from_, row_range.to + 1) + ] + + def _read_blob_values(self, file: DataFileMeta, supplier: Callable) -> List[object]: + reader = supplier() + if reader is None: + return [] + self._readers.append(reader) + try: + iterator = BlobRecordIterator( + reader._file_io, + reader.file_path, + reader.blob_lengths, + reader.blob_offsets, + self._field_name, + reader._input_stream, + ) + return [row.values[0] for row in iterator] + except AttributeError as e: + raise TypeError("Blob fallback reader expects FormatBlobReader suppliers.") from e + + def close(self) -> None: + for reader in self._readers: + reader.close() + self._readers = [] diff --git a/paimon-python/pypaimon/read/reader/data_file_batch_reader.py b/paimon-python/pypaimon/read/reader/data_file_batch_reader.py index 64da0cc8400e..21d1b2a911df 100644 --- a/paimon-python/pypaimon/read/reader/data_file_batch_reader.py +++ b/paimon-python/pypaimon/read/reader/data_file_batch_reader.py @@ -25,7 +25,6 @@ from pypaimon.read.reader.format_blob_reader import FormatBlobReader from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader from pypaimon.schema.data_types import DataField, PyarrowFieldParser -from pypaimon.table.row.blob import Blob from pypaimon.table.special_fields import SpecialFields @@ -40,8 +39,6 @@ def __init__(self, format_reader: RecordBatchReader, index_mapping: List[int], p first_row_id: int, row_tracking_enabled: bool, system_fields: dict, - blob_as_descriptor: bool = False, - blob_descriptor_fields: Optional[set] = None, file_io: Optional[FileIO] = None, row_id_offsets: Optional[List[int]] = None): self.format_reader = format_reader @@ -55,19 +52,7 @@ def __init__(self, format_reader: 
RecordBatchReader, index_mapping: List[int], p self._row_id_cursor = 0 self.max_sequence_number = max_sequence_number self.system_fields = system_fields - self.blob_as_descriptor = blob_as_descriptor - self.blob_descriptor_fields = blob_descriptor_fields or set() self.file_io = file_io - self.blob_field_names = { - field.name - for field in fields - if hasattr(field.type, 'type') and field.type.type == 'BLOB' - } - self.descriptor_blob_fields = { - field_name - for field_name in self.blob_descriptor_fields - if field_name in self.blob_field_names - } def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch]: if isinstance(self.format_reader, FormatBlobReader): @@ -78,6 +63,16 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch return None if self.partition_info is None and self.index_mapping is None: + # A file written under an older schema (e.g. before an INT -> BIGINT + # promotion or a DECIMAL precision change) yields columns in the + # data file's original types. Without reordering or partition padding + # to rebuild the batch, those old types would otherwise leak through + # here -- returning a type that depends on whether this read happens + # to span newer-schema files, and failing to concatenate when it + # does. Align them to the current read schema, mirroring the rebuild + # path below. 
+ record_batch = self._align_batch_to_read_schema( + record_batch.schema.names, record_batch.columns) if self.row_tracking_enabled and self.system_fields: record_batch = self._assign_row_tracking(record_batch) return record_batch @@ -122,68 +117,41 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch inter_arrays = mapped_arrays inter_names = mapped_names - # to contains 'not null' property - final_fields = [] - for i, name in enumerate(inter_names): - array = inter_arrays[i] - target_field = self.schema_map.get(name) - if not target_field: - target_field = pa.field(name, array.type) - final_fields.append(target_field) - final_schema = pa.schema(final_fields) - record_batch = pa.RecordBatch.from_arrays(inter_arrays, schema=final_schema) + # Rebuild the batch typed by the read schema (carries 'not null' and + # aligns old-schema column types). + record_batch = self._align_batch_to_read_schema(inter_names, inter_arrays) # Handle row tracking fields if self.row_tracking_enabled and self.system_fields: record_batch = self._assign_row_tracking(record_batch) - record_batch = self._convert_descriptor_stored_blob_columns(record_batch) - return record_batch - def _convert_descriptor_stored_blob_columns(self, record_batch: RecordBatch) -> RecordBatch: - if isinstance(self.format_reader, FormatBlobReader): - return record_batch - if not self.descriptor_blob_fields: - return record_batch - - schema_names = set(record_batch.schema.names) - target_fields = [f for f in self.descriptor_blob_fields if f in schema_names] - if not target_fields: - return record_batch - - arrays = list(record_batch.columns) - for field_name in target_fields: - field_idx = record_batch.schema.get_field_index(field_name) - values = record_batch.column(field_idx).to_pylist() - - if self.blob_as_descriptor: - converted = [self._normalize_blob_cell(v) for v in values] - else: - converted = [self._blob_cell_to_data(v) for v in values] - arrays[field_idx] = pa.array(converted, 
type=pa.large_binary()) - - return pa.RecordBatch.from_arrays(arrays, schema=record_batch.schema) - - @staticmethod - def _normalize_blob_cell(value): - if value is None: - return None - if hasattr(value, 'as_py'): - value = value.as_py() - if isinstance(value, str): - value = value.encode('utf-8') - if isinstance(value, bytearray): - value = bytes(value) - return value - - def _blob_cell_to_data(self, value): - value = self._normalize_blob_cell(value) - if value is None: - return None - if not isinstance(value, bytes): - return value - return Blob.from_bytes(value, self.file_io).to_data() + def _align_batch_to_read_schema(self, names: List[str], arrays: list) -> RecordBatch: + """Build a record batch for ``names``/``arrays`` typed by the read schema. + + Each known field is cast to the current read schema's type, which also + carries the 'not null' property; unknown columns keep the array's own + type. Columns whose type already matches are reused as-is, keeping the + common (non-evolution) path zero-copy. + + Casts use ``safe=False`` to match Java ``CastExecutors`` semantics for + the read-time conversions a user-approved schema evolution implies + (e.g. DECIMAL scale-down or DOUBLE -> INT truncate rather than raise). + Evolution legality is the writer's concern (``DataTypeCasts``); the read + path only materializes the result. 
+ """ + out_arrays = [] + out_fields = [] + for name, array in zip(names, arrays): + target_field = self.schema_map.get(name) + if target_field is None: + target_field = pa.field(name, array.type) + elif array.type != target_field.type: + array = array.cast(target_field.type, safe=False) + out_arrays.append(array) + out_fields.append(target_field) + return pa.RecordBatch.from_arrays(out_arrays, schema=pa.schema(out_fields)) def _assign_row_tracking(self, record_batch: RecordBatch) -> RecordBatch: """Assign row tracking meta fields (_ROW_ID and _SEQUENCE_NUMBER).""" diff --git a/paimon-python/pypaimon/read/reader/deduplicate_merge_function.py b/paimon-python/pypaimon/read/reader/deduplicate_merge_function.py new file mode 100644 index 000000000000..5b669aae5564 --- /dev/null +++ b/paimon-python/pypaimon/read/reader/deduplicate_merge_function.py @@ -0,0 +1,50 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Default merge function for primary-key tables. 
+ +Mirrors Java ``DeduplicateMergeFunction`` -- for a run of KVs sharing +the same primary key, keep only the one with the highest sequence +number (by virtue of ``add`` being called in sequence-number order). +""" + +from typing import Optional + +from pypaimon.table.row.key_value import KeyValue + + +class DeduplicateMergeFunction: + """Keep only the latest KV per primary key. + + Used by both the read path (``SortMergeReaderWithMinHeap``) and the + write path (``KeyValueDataWriter`` in-memory merge buffer) -- the + latter is what enforces the LSM "PK unique within a file" + invariant on flush. + """ + + def __init__(self): + self.latest_kv: Optional[KeyValue] = None + + def reset(self) -> None: + self.latest_kv = None + + def add(self, kv: KeyValue) -> None: + self.latest_kv = kv + + def get_result(self) -> Optional[KeyValue]: + return self.latest_kv diff --git a/paimon-python/pypaimon/read/reader/field_bunch.py b/paimon-python/pypaimon/read/reader/field_bunch.py index 74162cbc8391..f2e0ae756cea 100644 --- a/paimon-python/pypaimon/read/reader/field_bunch.py +++ b/paimon-python/pypaimon/read/reader/field_bunch.py @@ -23,7 +23,9 @@ """ from abc import ABC from typing import List + from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.utils.range import Range class FieldBunch(ABC): @@ -75,6 +77,12 @@ def add(self, file: DataFileMeta) -> None: f"Only {self._file_type_label()} file can be added to " f"a {self._file_type_label()} bunch.") + if self._files and file.write_cols != self._files[0].write_cols: + raise ValueError( + f"All files in a {self._file_type_label()} bunch should " + f"have the same write columns." 
+ ) + if file.first_row_id == self.latest_first_row_id: if file.max_sequence_number >= self.latest_max_sequence_number: raise ValueError( @@ -136,6 +144,50 @@ def files(self) -> List[DataFileMeta]: class BlobBunch(_SpecialFieldBunch): """Files for partial field (blob files).""" + def add(self, file: DataFileMeta) -> None: + if not self._is_special_file(file.file_name): + raise ValueError("Only blob file can be added to a blob bunch.") + if self._files and file.write_cols != self._files[0].write_cols: + raise ValueError("All files in a blob bunch should have the same write columns.") + + self._files.append(file) + merged = Range.sort_and_merge_overlap( + [blob_file.row_id_range() for blob_file in self._files], + True, + True, + ) + self._row_count = sum(row_range.count() for row_range in merged) + if self.expected_row_count >= 0 and self._row_count > self.expected_row_count: + raise ValueError( + f"Blob files row count exceed the expect {self.expected_row_count}" + ) + + def row_count(self) -> int: + merged = Range.sort_and_merge_overlap( + [blob_file.row_id_range() for blob_file in self._files], + True, + True, + ) + row_count = sum(row_range.count() for row_range in merged) + if not self.row_id_push_down: + if len(merged) != 1: + raise ValueError("Blob file bunch should always contain a contiguous row range.") + if self.expected_row_count >= 0 and row_count != self.expected_row_count: + raise ValueError( + "The merged row count of blob file bunch should be aligned " + f"with normal files, expect {self.expected_row_count}, got {row_count}." 
+ ) + return row_count + + def sequential_read_optimize(self) -> bool: + if not self._files: + raise ValueError("Blob bunch should not be empty.") + max_sequence_number = self._files[0].max_sequence_number + return all( + file.max_sequence_number == max_sequence_number + for file in self._files + ) + def _is_special_file(self, file_name: str) -> bool: return DataFileMeta.is_blob_file(file_name) diff --git a/paimon-python/pypaimon/read/reader/first_row_merge_function.py b/paimon-python/pypaimon/read/reader/first_row_merge_function.py new file mode 100644 index 000000000000..f3a91aadcf22 --- /dev/null +++ b/paimon-python/pypaimon/read/reader/first_row_merge_function.py @@ -0,0 +1,55 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +from typing import Optional + +from pypaimon.table.row.key_value import KeyValue + + +class FirstRowMergeFunction: + """A MergeFunction where key is primary key (unique) and value is the + full record, only keep the first one.""" + + def __init__(self, ignore_delete: bool = False): + self.ignore_delete = ignore_delete + self.first: Optional[KeyValue] = None + + def reset(self) -> None: + self.first = None + + def add(self, kv: KeyValue) -> None: + if not kv.is_add(): + if self.ignore_delete: + return + raise ValueError( + "By default, First row merge engine can not accept " + "DELETE/UPDATE_BEFORE records.\n" + "You can config 'ignore-delete' to ignore the " + "DELETE/UPDATE_BEFORE records." + ) + + if self.first is None: + # Snapshot, don't keep the reference: the caller may pool/reuse + # a single KeyValue and replace() it for the next row (the write + # path's fold does exactly this). Holding the live reference + # would make get_result return the LAST row instead of the + # first, silently turning first-row into last-row. + self.first = kv.copy() + + def get_result(self) -> Optional[KeyValue]: + return self.first diff --git a/paimon-python/pypaimon/read/reader/format_blob_reader.py b/paimon-python/pypaimon/read/reader/format_blob_reader.py index 355fb36dc41d..52f197097f64 100644 --- a/paimon-python/pypaimon/read/reader/format_blob_reader.py +++ b/paimon-python/pypaimon/read/reader/format_blob_reader.py @@ -16,7 +16,7 @@ # under the License. 
import struct -from typing import List, Optional, Any, Iterator +from typing import List, Optional, Any, Iterator, BinaryIO import pyarrow as pa import pyarrow.dataset as ds @@ -37,33 +37,40 @@ class FormatBlobReader(RecordBatchReader): def __init__(self, file_io: FileIO, file_path: str, read_fields: List[str], full_fields: List[DataField], push_down_predicate: Any, blob_as_descriptor: bool, - batch_size: int = 1024): + batch_size: int = 1024, row_indices: Optional[Any] = None): self._file_io = file_io self._file_path = file_path self._push_down_predicate = push_down_predicate self._blob_as_descriptor = blob_as_descriptor self._batch_size = batch_size - # Get file size - self._file_size = file_io.get_file_size(file_path) # Initialize the low-level blob format reader self.file_path = file_path self.blob_lengths: List[int] = [] self.blob_offsets: List[int] = [] self.returned = False - self._read_index() - - # Set up fields and schema - if len(read_fields) > 1: - raise RuntimeError("Blob reader only supports one field.") - self._fields = read_fields - full_fields_map = {field.name: field for field in full_fields} - projected_data_fields = [full_fields_map[name] for name in read_fields] - self._schema = PyarrowFieldParser.from_paimon_schema(projected_data_fields) - - # Initialize iterator + self._input_stream = None self._blob_iterator = None self._current_batch = None + try: + self._file_size = file_io.get_file_size(file_path) + self._input_stream = file_io.new_input_stream(file_path) + self._read_index() + self._apply_row_indices(row_indices) + if self._blob_as_descriptor: + self._input_stream.close() + self._input_stream = None + + # Set up fields and schema + if len(read_fields) > 1: + raise RuntimeError("Blob reader only supports one field.") + self._fields = read_fields + full_fields_map = {field.name: field for field in full_fields} + projected_data_fields = [full_fields_map[name] for name in read_fields] + self._schema = 
PyarrowFieldParser.from_paimon_schema(projected_data_fields) + except Exception: + self.close() + raise def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch]: """ @@ -76,7 +83,7 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch self.returned = True batch_iterator = BlobRecordIterator( self._file_io, self.file_path, self.blob_lengths, - self.blob_offsets, self._fields[0] + self.blob_offsets, self._fields[0], self._input_stream ) self._blob_iterator = iter(batch_iterator) read_size = self._batch_size @@ -140,42 +147,66 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch def close(self): self._blob_iterator = None + if self._input_stream is not None: + self._input_stream.close() + self._input_stream = None def _read_index(self) -> None: - with self._file_io.new_input_stream(self.file_path) as f: - # Seek to header: last 5 bytes - f.seek(self._file_size - 5) - header = f.read(5) - - if len(header) != 5: - raise IOError("Invalid blob file: cannot read header") - - # Parse header - index_length = struct.unpack(' None: + if row_indices is None: + return + + selected_lengths = [] + selected_offsets = [] + record_count = len(self.blob_lengths) + for row_index in row_indices: + row_index = int(row_index) + if row_index < 0 or row_index >= record_count: + raise IndexError( + f"Blob row index {row_index} is out of range for file " + f"{self.file_path}, record count: {record_count}." 
+ ) + selected_lengths.append(self.blob_lengths[row_index]) + selected_offsets.append(self.blob_offsets[row_index]) + + self.blob_lengths = selected_lengths + self.blob_offsets = selected_offsets class BlobRecordIterator: @@ -185,9 +216,11 @@ class BlobRecordIterator: PLACE_HOLDER_LENGTH = -2 def __init__(self, file_io: FileIO, file_path: str, blob_lengths: List[int], - blob_offsets: List[int], field_name: str): + blob_offsets: List[int], field_name: str, + input_stream: Optional[BinaryIO] = None): self.file_io = file_io self.file_path = file_path + self.input_stream = input_stream self.field_name = field_name self.blob_lengths = blob_lengths self.blob_offsets = blob_offsets @@ -211,9 +244,28 @@ def __next__(self) -> GenericRow: # Skip magic number (4 bytes) and exclude length (8 bytes) + CRC (4 bytes) = 12 bytes blob_offset = self.blob_offsets[self.current_position] + self.MAGIC_NUMBER_SIZE # Skip magic number blob_length = length - self.METADATA_OVERHEAD - blob = Blob.from_file(self.file_io, self.file_path, blob_offset, blob_length) + if self.input_stream is not None: + blob = Blob.from_data(self._read_inline_blob(blob_offset, blob_length)) + else: + blob = Blob.from_file(self.file_io, self.file_path, blob_offset, blob_length) self.current_position += 1 return GenericRow([blob], fields, RowKind.INSERT) def returned_position(self) -> int: return self.current_position + + def _read_inline_blob(self, position: int, length: int) -> bytes: + self.input_stream.seek(position) + data = self._read_fully(length) + if len(data) != length: + raise IOError("Invalid blob file: cannot read blob data") + return data + + def _read_fully(self, length: int) -> bytes: + data = bytearray() + while len(data) < length: + chunk = self.input_stream.read(length - len(data)) + if not chunk: + break + data.extend(chunk) + return bytes(data) diff --git a/paimon-python/pypaimon/read/reader/format_mosaic_reader.py b/paimon-python/pypaimon/read/reader/format_mosaic_reader.py new file mode 
100644 index 000000000000..4ab7d3b69818 --- /dev/null +++ b/paimon-python/pypaimon/read/reader/format_mosaic_reader.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, List, Optional + +import pyarrow as pa +import pyarrow.dataset as ds +from pyarrow import RecordBatch + +from pypaimon.common.file_io import FileIO, supports_pread, pread +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader +from pypaimon.schema.data_types import DataField, PyarrowFieldParser +from pypaimon.table.special_fields import SpecialFields + + +class FormatMosaicReader(RecordBatchReader): + + def __init__(self, file_io: FileIO, file_path: str, read_fields: List[DataField], + push_down_predicate: Any, batch_size: int = 1024): + from mosaic import MosaicReader + + self._read_field_names = [f.name for f in read_fields] + self._batch_size = batch_size + + stream = file_io.new_input_stream(file_path) + file_length = file_io.get_file_size(file_path) + + if supports_pread(stream): + self._stream = stream + + def read_at(offset, length): + return pread(stream, length, offset) + else: + self._stream = None + file_data = stream.read() + stream.close() + file_length = len(file_data) + + def 
read_at(offset, length): + return file_data[offset:offset + length] + + self._reader = MosaicReader.from_input_file(read_at, file_length) + + file_schema_names = set(f.name for f in self._reader.schema) + self.existing_fields = [f.name for f in read_fields if f.name in file_schema_names] + self.missing_fields = [f.name for f in read_fields if f.name not in file_schema_names] + + if self.existing_fields: + self._reader.project(self.existing_fields) + + if self.missing_fields: + output_schema = PyarrowFieldParser.from_paimon_schema(read_fields) + self._missing_out_fields = [] + for name in self.missing_fields: + idx = output_schema.get_field_index(name) + col_type = output_schema.field(idx).type if idx >= 0 else pa.null() + nullable = not SpecialFields.is_system_field(name) + self._missing_out_fields.append(pa.field(name, col_type, nullable=nullable)) + + self._current_rg = 0 + self._num_row_groups = self._reader.num_row_groups + self._current_batches = None + + if push_down_predicate is not None: + self._predicate = push_down_predicate + else: + self._predicate = None + + def _next_row_group_batches(self): + while self._current_rg < self._num_row_groups: + batch = self._reader.read_row_group(self._current_rg) + self._current_rg += 1 + + if batch.num_rows == 0: + continue + + batch = self._fill_missing_fields(batch) + + if self._predicate is not None: + dataset = ds.InMemoryDataset(pa.Table.from_batches([batch])) + scanner = dataset.scanner(filter=self._predicate, batch_size=self._batch_size) + return scanner.to_reader() + else: + return iter(pa.Table.from_batches([batch]).to_batches( + max_chunksize=self._batch_size)) + return None + + def _fill_missing_fields(self, batch: RecordBatch) -> RecordBatch: + if not self.missing_fields: + return batch + + all_columns = [] + out_fields = [] + for field_name in self._read_field_names: + if field_name in self.existing_fields: + col_idx = self.existing_fields.index(field_name) + all_columns.append(batch.column(col_idx)) + 
out_fields.append(batch.schema.field(col_idx)) + else: + miss_idx = self.missing_fields.index(field_name) + out_field = self._missing_out_fields[miss_idx] + all_columns.append(pa.nulls(batch.num_rows, type=out_field.type)) + out_fields.append(out_field) + return pa.RecordBatch.from_arrays(all_columns, schema=pa.schema(out_fields)) + + def read_arrow_batch(self) -> Optional[RecordBatch]: + while True: + if self._current_batches is not None: + try: + if hasattr(self._current_batches, 'read_next_batch'): + return self._current_batches.read_next_batch() + else: + return next(self._current_batches) + except StopIteration: + self._current_batches = None + + self._current_batches = self._next_row_group_batches() + if self._current_batches is None: + return None + + def close(self): + if self._stream is not None: + self._stream.close() + self._stream = None + self._reader = None + self._current_batches = None diff --git a/paimon-python/pypaimon/read/reader/format_row_reader.py b/paimon-python/pypaimon/read/reader/format_row_reader.py new file mode 100644 index 000000000000..34a9c663a9a4 --- /dev/null +++ b/paimon-python/pypaimon/read/reader/format_row_reader.py @@ -0,0 +1,469 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import struct +from decimal import Decimal +from typing import Any, List, Optional + +import pyarrow as pa +import pyarrow.dataset as ds +from pyarrow import RecordBatch + +from pypaimon.common.delta_varint_compressor import DeltaVarintCompressor +from pypaimon.common.file_io import FileIO +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader +from pypaimon.schema.data_types import ( + ArrayType, DataField, MapType, MultisetType, PyarrowFieldParser, RowType, VectorType, AtomicType +) + +FOOTER_SIZE = 32 +MAGIC = 0x524F5753 # "ROWS" +VERSION = 1 + + +class FormatRowReader(RecordBatchReader): + + def __init__(self, file_io: FileIO, file_path: str, read_fields: List[str], + full_fields: List[DataField], push_down_predicate: Any, + batch_size: int = 1024, row_indices: Optional[List[int]] = None): + self._file_io = file_io + self._file_path = file_path + self._push_down_predicate = push_down_predicate + self._batch_size = batch_size + + self._file_size = file_io.get_file_size(file_path) + + full_fields_map = {field.name: field for field in full_fields} + self._projected_fields = [full_fields_map[name] for name in read_fields] + self._all_fields = full_fields + self._schema = PyarrowFieldParser.from_paimon_schema(self._projected_fields) + + self._block_compressed_sizes: List[int] = [] + self._block_uncompressed_sizes: List[int] = [] + self._block_row_starts: List[int] = [] + self._block_offsets: List[int] = [] + self._total_row_count = 0 + self._block_count = 0 + self._current_block_idx = 0 + + self._row_indices = row_indices + self._row_indices_pos = 0 + + self._read_metadata() + + if self._row_indices is not None: + self._blocks_to_read = self._compute_blocks_for_indices() + self._blocks_to_read_pos = 0 + else: + self._blocks_to_read = None + + def _compute_blocks_for_indices(self) -> List[tuple]: + """Group row_indices by block. 
Returns list of (block_idx, [local_row_offsets]).""" + import bisect + result = [] + row_starts = self._block_row_starts + for idx in self._row_indices: + block_idx = bisect.bisect_right(row_starts, idx) - 1 + local_row = idx - row_starts[block_idx] + if result and result[-1][0] == block_idx: + result[-1][1].append(local_row) + else: + result.append((block_idx, [local_row])) + return result + + def read_arrow_batch(self) -> Optional[RecordBatch]: + if self._row_indices is not None: + return self._read_batch_indexed() + return self._read_batch_sequential() + + def _read_batch_sequential(self) -> Optional[RecordBatch]: + if self._current_block_idx >= self._block_count: + return None + + block_data = self._read_and_decompress_block(self._current_block_idx) + self._current_block_idx += 1 + + columns = self._decode_block(block_data) + + if not columns or all(len(col) == 0 for col in columns): + return None + + pydict = {field.name: columns[i] for i, field in enumerate(self._projected_fields)} + table = pa.Table.from_pydict(pydict, self._schema) + + if self._push_down_predicate is not None: + dataset = ds.InMemoryDataset(table) + scanner = dataset.scanner(filter=self._push_down_predicate) + table = scanner.to_table().combine_chunks() + + if table.num_rows == 0: + return self._read_batch_sequential() + + return table.to_batches()[0] + + def _read_batch_indexed(self) -> Optional[RecordBatch]: + if self._blocks_to_read_pos >= len(self._blocks_to_read): + return None + + block_idx, local_rows = self._blocks_to_read[self._blocks_to_read_pos] + self._blocks_to_read_pos += 1 + + block_data = self._read_and_decompress_block(block_idx) + columns = self._decode_block(block_data, row_filter=local_rows) + + if not columns or all(len(col) == 0 for col in columns): + return self._read_batch_indexed() + + pydict = {field.name: columns[i] for i, field in enumerate(self._projected_fields)} + table = pa.Table.from_pydict(pydict, self._schema) + + if self._push_down_predicate is not None: 
+ dataset = ds.InMemoryDataset(table) + scanner = dataset.scanner(filter=self._push_down_predicate) + table = scanner.to_table().combine_chunks() + + if table.num_rows == 0: + return self._read_batch_indexed() + + return table.to_batches()[0] + + def close(self): + pass + + def _read_metadata(self): + with self._file_io.new_input_stream(self._file_path) as f: + f.seek(self._file_size - FOOTER_SIZE) + footer_bytes = f.read(FOOTER_SIZE) + + if len(footer_bytes) != FOOTER_SIZE: + raise IOError("Invalid row file: cannot read footer") + + magic = struct.unpack_from(' bytes: + import zstandard as zstd + + offset = self._block_offsets[block_idx] + compressed_size = self._block_compressed_sizes[block_idx] + + with self._file_io.new_input_stream(self._file_path) as f: + f.seek(offset) + compressed_data = f.read(compressed_size) + + decompressor = zstd.ZstdDecompressor() + uncompressed_size = self._block_uncompressed_sizes[block_idx] + return decompressor.decompress(compressed_data, max_output_size=uncompressed_size) + + def _decode_block(self, block_data: bytes, + row_filter: Optional[List[int]] = None) -> List[List]: + data_len = len(block_data) + row_count = struct.unpack_from(' bool: + v = self.data[self.pos] != 0 + self.pos += 1 + return v + + def read_byte(self) -> int: + v = struct.unpack_from(' int: + v = struct.unpack_from(' int: + v = struct.unpack_from(' int: + v = struct.unpack_from(' float: + v = struct.unpack_from(' float: + v = struct.unpack_from(' int: + result = 0 + shift = 0 + while True: + b = self.data[self.pos] + self.pos += 1 + result |= (b & 0x7F) << shift + if (b & 0x80) == 0: + return result + shift += 7 + + def read_string(self) -> str: + length = self.read_var_int() + s = self.data[self.pos:self.pos + length].decode('utf-8') + self.pos += length + return s + + def read_bytes(self) -> bytes: + length = self.read_var_int() + b = self.data[self.pos:self.pos + length] + self.pos += length + return bytes(b) + + +def _decode_var_int(data: bytes, offset: 
int) -> tuple: + result = 0 + shift = 0 + pos = offset + while True: + b = data[pos] + pos += 1 + result |= (b & 0x7F) << shift + if (b & 0x80) == 0: + return result, pos - offset + shift += 7 + + +def _read_field(decoder: _RowDecoder, data_type) -> Any: + if isinstance(data_type, AtomicType): + type_name = data_type.type.upper() + if type_name == 'BOOLEAN': + return decoder.read_boolean() + elif type_name == 'TINYINT': + return decoder.read_byte() + elif type_name == 'SMALLINT': + return decoder.read_short() + elif type_name in ('INT', 'INTEGER', 'DATE', 'TIME'): + return decoder.read_int() + elif type_name.startswith('TIME') and not type_name.startswith('TIMESTAMP'): + return decoder.read_int() + elif type_name == 'BIGINT': + return decoder.read_long() + elif type_name == 'FLOAT': + return decoder.read_float() + elif type_name == 'DOUBLE': + return decoder.read_double() + elif type_name == 'STRING' or type_name.startswith('CHAR') or type_name.startswith('VARCHAR'): + return decoder.read_string() + elif type_name == 'BYTES' or type_name.startswith('BINARY') or type_name.startswith('VARBINARY'): + return decoder.read_bytes() + elif type_name == 'BLOB': + return decoder.read_bytes() + elif type_name.startswith('DECIMAL'): + precision, scale = _parse_decimal_params(type_name) + if precision <= 18: + unscaled = decoder.read_long() + return Decimal(unscaled) / Decimal(10 ** scale) + else: + raw = decoder.read_bytes() + unscaled = int.from_bytes(raw, byteorder='big', signed=True) + return Decimal(unscaled) / Decimal(10 ** scale) + elif type_name.startswith('TIMESTAMP'): + precision = _parse_timestamp_precision(type_name) + millis = decoder.read_long() + if precision <= 3: + return millis + else: + nano_of_milli = decoder.read_var_int() + micros = millis * 1000 + nano_of_milli // 1000 + return micros + elif type_name == 'VARIANT': + value_bytes = decoder.read_bytes() + metadata_bytes = decoder.read_bytes() + return {'value': value_bytes, 'metadata': metadata_bytes} + 
else: + raise ValueError(f"Unsupported atomic type: {type_name}") + + elif isinstance(data_type, ArrayType): + return _read_array(decoder, data_type.element) + + elif isinstance(data_type, VectorType): + return _read_vector(decoder, data_type.element) + + elif isinstance(data_type, MapType): + keys = _read_array_elements(decoder, data_type.key) + values = _read_array_elements(decoder, data_type.value) + return list(zip(keys, values)) + + elif isinstance(data_type, MultisetType): + keys = _read_array_elements(decoder, data_type.element) + counts = _read_array_elements(decoder, AtomicType("INT")) + return list(zip(keys, counts)) + + elif isinstance(data_type, RowType): + return _read_nested_row(decoder, data_type) + + else: + raise ValueError(f"Unsupported data type: {data_type}") + + +def _read_array(decoder: _RowDecoder, element_type) -> list: + return _read_array_elements(decoder, element_type) + + +def _read_array_elements(decoder: _RowDecoder, element_type) -> list: + size = decoder.read_var_int() + null_bitmap_bytes = (size + 7) // 8 + null_bitmap = decoder.data[decoder.pos:decoder.pos + null_bitmap_bytes] + decoder.pos += null_bitmap_bytes + + elements = [] + for i in range(size): + is_null = (null_bitmap[i // 8] & (1 << (i % 8))) != 0 + if is_null: + elements.append(None) + else: + elements.append(_read_field(decoder, element_type)) + return elements + + +def _read_vector(decoder: _RowDecoder, element_type) -> list: + size = decoder.read_var_int() + elements = [] + for _ in range(size): + elements.append(_read_field(decoder, element_type)) + return elements + + +def _read_nested_row(decoder: _RowDecoder, row_type: RowType) -> dict: + fields = row_type.fields + arity = len(fields) + header_size = (arity + 7) // 8 + null_bitmap = decoder.data[decoder.pos:decoder.pos + header_size] + decoder.pos += header_size + + result = {} + for i, field in enumerate(fields): + is_null = (null_bitmap[i // 8] & (1 << (i % 8))) != 0 + if is_null: + result[field.name] = None + 
else: + result[field.name] = _read_field(decoder, field.type) + return result + + +def _parse_decimal_params(type_name: str) -> tuple: + import re + match = re.fullmatch(r'DECIMAL\((\d+),\s*(\d+)\)', type_name) + if match: + return int(match.group(1)), int(match.group(2)) + match = re.fullmatch(r'DECIMAL\((\d+)\)', type_name) + if match: + return int(match.group(1)), 0 + return 10, 0 + + +def _parse_timestamp_precision(type_name: str) -> int: + import re + match = re.search(r'\((\d+)\)', type_name) + if match: + return int(match.group(1)) + return 6 diff --git a/paimon-python/pypaimon/read/reader/format_vortex_reader.py b/paimon-python/pypaimon/read/reader/format_vortex_reader.py index ac48ca1abe91..fcd50a24f005 100644 --- a/paimon-python/pypaimon/read/reader/format_vortex_reader.py +++ b/paimon-python/pypaimon/read/reader/format_vortex_reader.py @@ -86,29 +86,23 @@ def __init__(self, file_io: FileIO, file_path: str, read_fields: List[DataField] PyarrowFieldParser.from_paimon_schema(read_fields) if read_fields else None ) - # Collect predicate-referenced fields for targeted view type casting - self._cast_fields = predicate_fields if predicate_fields and vortex_expr is not None else set() - @staticmethod - def _cast_view_types(batch: RecordBatch, target_fields: Set[str]) -> RecordBatch: - """Cast string_view/binary_view columns to string/binary, only for target fields.""" - if not target_fields: - return batch + def _cast_view_types(batch: RecordBatch) -> RecordBatch: + """Cast all string_view/binary_view columns to string/binary.""" columns = [] fields = [] changed = False for i in range(batch.num_columns): col = batch.column(i) field = batch.schema.field(i) - if field.name in target_fields: - if col.type == pa.string_view(): - col = col.cast(pa.utf8()) - field = field.with_type(pa.utf8()) - changed = True - elif col.type == pa.binary_view(): - col = col.cast(pa.binary()) - field = field.with_type(pa.binary()) - changed = True + if col.type == pa.string_view(): + 
col = col.cast(pa.utf8()) + field = field.with_type(pa.utf8()) + changed = True + elif col.type == pa.binary_view(): + col = col.cast(pa.binary()) + field = field.with_type(pa.binary()) + changed = True columns.append(col) fields.append(field) if changed: @@ -118,7 +112,7 @@ def _cast_view_types(batch: RecordBatch, target_fields: Set[str]) -> RecordBatch def read_arrow_batch(self) -> Optional[RecordBatch]: try: batch = next(self.record_batch_reader) - batch = self._cast_view_types(batch, self._cast_fields) + batch = self._cast_view_types(batch) if not self.missing_fields: return batch diff --git a/paimon-python/pypaimon/read/reader/limited_record_reader.py b/paimon-python/pypaimon/read/reader/limited_record_reader.py index 74f2612ebdc0..f78221d3f87b 100644 --- a/paimon-python/pypaimon/read/reader/limited_record_reader.py +++ b/paimon-python/pypaimon/read/reader/limited_record_reader.py @@ -26,6 +26,9 @@ from typing import Optional +from pyarrow import RecordBatch + +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader from pypaimon.read.reader.iface.record_iterator import RecordIterator from pypaimon.read.reader.iface.record_reader import RecordReader @@ -68,3 +71,38 @@ def next(self): return None self._limiter.count += 1 return row + + +class LimitedRecordBatchReader(RecordBatchReader): + """Stop emitting rows once ``limit`` rows have been delivered. + + Unlike ``LimitedRecordReader`` (which inherits ``RecordReader``), + this class inherits ``RecordBatchReader`` so that the + ``isinstance(..., RecordBatchReader)`` gate in TableRead picks the + arrow-batch code path. 
+ """ + + def __init__(self, inner: RecordBatchReader, limit: int): + if limit < 0: + raise ValueError("limit must be non-negative, got %d" % limit) + self._inner = inner + self._limit = limit + self.count = 0 + self.file_io = inner.file_io + self.blob_field_indices = inner.blob_field_indices + self.vector_field_indices = inner.vector_field_indices + + def read_arrow_batch(self) -> Optional[RecordBatch]: + if self.count >= self._limit: + return None + batch = self._inner.read_arrow_batch() + if batch is None: + return None + remaining = self._limit - self.count + if batch.num_rows > remaining: + batch = batch.slice(0, remaining) + self.count += batch.num_rows + return batch + + def close(self) -> None: + self._inner.close() diff --git a/paimon-python/pypaimon/read/reader/partial_update_merge_function.py b/paimon-python/pypaimon/read/reader/partial_update_merge_function.py index 978b48011c54..6fc2d453112b 100644 --- a/paimon-python/pypaimon/read/reader/partial_update_merge_function.py +++ b/paimon-python/pypaimon/read/reader/partial_update_merge_function.py @@ -16,26 +16,22 @@ # limitations under the License. ################################################################################ -""" -Python port of Java's ``PartialUpdateMergeFunction`` -(``paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ -PartialUpdateMergeFunction.java``). +"""Merge function for the ``partial-update`` merge engine on PK tables. -The merge function used by the ``partial-update`` merge engine on PK -tables: rows sharing a primary key are merged left-to-right, taking the -latest non-null value per non-PK field. ``DeduplicateMergeFunction`` -keeps only the latest row; ``PartialUpdateMergeFunction`` instead lets -later writes "fill in" fields the earlier writes left null, so users -can write the same logical record across multiple commits with -different sets of non-null columns. 
+Rows sharing a primary key are merged left-to-right, taking the latest +non-null value per non-PK field. ``DeduplicateMergeFunction`` keeps +only the latest row; ``PartialUpdateMergeFunction`` instead lets later +writes "fill in" fields the earlier writes left null, so users can +write the same logical record across multiple commits with different +sets of non-null columns. -This is the **core merge semantics only**. The Java implementation also +This is the **core merge semantics only**. The upstream engine also supports per-field aggregator overrides (``fields..aggregate- function``), sequence groups (``fields..sequence-group``), ``ignore-delete``, and ``partial-update.remove-record-on-*`` options. -None of those are implemented yet; non-INSERT row kinds raise -``NotImplementedError`` at ``add`` time so we never silently corrupt -data with a half-implemented contract. +None of those are implemented in pypaimon yet; non-INSERT row kinds +raise ``NotImplementedError`` at ``add`` time so we never silently +corrupt data with a half-implemented contract. """ from typing import Any, List, Optional @@ -56,19 +52,31 @@ class PartialUpdateMergeFunction: """ def __init__(self, key_arity: int, value_arity: int, - nullables: Optional[List[bool]] = None): + nullables: Optional[List[bool]] = None, + value_field_names: Optional[List[str]] = None): self._key_arity = key_arity self._value_arity = value_arity # Per-value-field nullable flags, parallel to value indices. When # ``None``, no nullability check runs (preserves the contract for # direct callers that don't have schema info handy). When given, - # mirrors Java's ``updateNonNullFields`` check: a null input on a - # NOT NULL field raises rather than being silently absorbed. + # the schema's NOT NULL declaration is enforced on every add(): + # a null input on a NOT NULL field raises rather than being + # silently absorbed. 
if nullables is not None and len(nullables) != value_arity: raise ValueError( "nullables length {} does not match value_arity {}".format( len(nullables), value_arity)) self._nullables = nullables + # Optional value-field names, parallel to value indices. When + # given, the NOT-NULL error message uses the field name instead + # of a bare position to make the failure actionable. + if value_field_names is not None \ + and len(value_field_names) != value_arity: + raise ValueError( + "value_field_names length {} does not match " + "value_arity {}".format( + len(value_field_names), value_arity)) + self._value_field_names = value_field_names # Lazily allocated on first add(); ``None`` means "no rows yet". self._accumulator: Optional[List[Any]] = None # Reference to the most recently added kv. We use it only to @@ -84,23 +92,24 @@ def reset(self) -> None: def add(self, kv: KeyValue) -> None: row_kind_byte = kv.value_row_kind_byte if not RowKind.is_add_byte(row_kind_byte): - # DELETE / UPDATE_BEFORE need ignore-delete or - # partial-update.remove-record-on-delete to be set in Java; - # neither option is wired up in pypaimon yet, so refuse the - # row rather than silently swallow it. + # DELETE / UPDATE_BEFORE require ignore-delete or + # partial-update.remove-record-on-delete to be enabled, + # and neither option is implemented in pypaimon yet, so + # refuse the row rather than silently swallow it. raise NotImplementedError( - "PartialUpdateMergeFunction received a {} row; this " - "Python port does not yet implement the ignore-delete / " - "partial-update.remove-record-on-delete options. Use the " - "Java client for tables that produce DELETE / " - "UPDATE_BEFORE rows.".format(RowKind(row_kind_byte).to_string()) + "PartialUpdateMergeFunction received a {} row; the " + "ignore-delete / partial-update.remove-record-on-delete " + "options needed to handle it are not yet implemented in " + "pypaimon. 
Tables that produce DELETE / UPDATE_BEFORE " + "rows are not supported here.".format( + RowKind(row_kind_byte).to_string()) ) - # Mirror Java's reset() + updateNonNullFields(): the accumulator - # starts as all-null (equivalent to ``new GenericRow(arity)``) and - # each add() writes non-null inputs; null inputs are absorbed — - # except when the schema marks the field NOT NULL, in which case - # we raise to match Java's IllegalArgumentException check. + # The accumulator starts as all-null and each add() writes + # non-null inputs; null inputs are absorbed -- except when the + # schema marks the field NOT NULL, in which case we raise so + # the violation surfaces at write time instead of producing a + # row that breaks the schema invariant. if self._accumulator is None: self._accumulator = [None] * self._value_arity for i in range(self._value_arity): @@ -108,7 +117,15 @@ def add(self, kv: KeyValue) -> None: if v is not None: self._accumulator[i] = v elif self._nullables is not None and not self._nullables[i]: - raise ValueError("Field {} can not be null".format(i)) + if self._value_field_names is not None: + field_ref = "'{}'".format(self._value_field_names[i]) + else: + field_ref = "at index {}".format(i) + raise ValueError( + "Partial-update received NULL for non-nullable field " + "{}. Declare the field nullable in the table schema " + "if writes can leave it unset, or supply a value." 
+ .format(field_ref)) self._latest_kv = kv def get_result(self) -> Optional[KeyValue]: diff --git a/paimon-python/pypaimon/read/reader/sort_merge_reader.py b/paimon-python/pypaimon/read/reader/sort_merge_reader.py index 56f42b6f3ca6..c525a7592c07 100644 --- a/paimon-python/pypaimon/read/reader/sort_merge_reader.py +++ b/paimon-python/pypaimon/read/reader/sort_merge_reader.py @@ -18,9 +18,11 @@ import heapq from typing import Any, Callable, List, Optional +from pypaimon.read.reader.deduplicate_merge_function import \ + DeduplicateMergeFunction from pypaimon.read.reader.iface.record_iterator import RecordIterator from pypaimon.read.reader.iface.record_reader import RecordReader -from pypaimon.schema.data_types import DataField, Keyword +from pypaimon.schema.data_types import AtomicType, DataField, Keyword from pypaimon.schema.table_schema import TableSchema from pypaimon.table.row.internal_row import InternalRow from pypaimon.table.row.key_value import KeyValue @@ -30,7 +32,8 @@ class SortMergeReaderWithMinHeap(RecordReader): """SortMergeReader implemented with min-heap.""" def __init__(self, readers: List[RecordReader[KeyValue]], schema: TableSchema, - merge_function: Optional[Any] = None): + merge_function: Optional[Any] = None, + seq_comparator: Optional[Callable[[Any, Any], int]] = None): self.next_batch_readers = list(readers) # Default to dedupe so callers that don't pass a merge_function # keep their old behaviour. The merge engine dispatch lives in @@ -38,6 +41,12 @@ def __init__(self, readers: List[RecordReader[KeyValue]], schema: TableSchema, # path; tests or other ad-hoc callers can pass a different # implementation here. self.merge_function = merge_function if merge_function is not None else DeduplicateMergeFunction() + # Optional user-defined sequence comparator (``sequence.field``). 
+ # When set, it breaks key-ties on the value row before the + # file-level sequence number, mirroring Java's + # ``SortMergeReaderWithMinHeap`` + ``UserDefinedSeqComparator``. + # Built by the caller, which knows the value-side schema. + self.seq_comparator = seq_comparator if schema.partition_keys: trimmed_primary_keys = [pk for pk in schema.primary_keys if pk not in schema.partition_keys] @@ -63,7 +72,8 @@ def read_batch(self) -> Optional[RecordIterator]: kv = iterator.next() if kv is not None: element = Element(kv, iterator, reader) - entry = HeapEntry(kv.key, element, self.key_comparator) + entry = HeapEntry(kv.key, element, self.key_comparator, + self.seq_comparator) heapq.heappush(self.min_heap, entry) break @@ -78,6 +88,7 @@ def read_batch(self) -> Optional[RecordIterator]: self.min_heap, self.merge_function, self.key_comparator, + self.seq_comparator, ) def close(self): @@ -93,12 +104,13 @@ def close(self): class SortMergeIterator(RecordIterator): def __init__(self, reader, polled: List['Element'], min_heap, merge_function, - key_comparator): + key_comparator, seq_comparator=None): self.reader = reader self.polled = polled self.min_heap = min_heap self.merge_function = merge_function self.key_comparator = key_comparator + self.seq_comparator = seq_comparator self.released = False def next(self): @@ -112,7 +124,8 @@ def next(self): def _next_impl(self): for element in self.polled: if element.update(): - entry = HeapEntry(element.kv.key, element, self.key_comparator) + entry = HeapEntry(element.kv.key, element, self.key_comparator, + self.seq_comparator) heapq.heappush(self.min_heap, entry) self.polled.clear() @@ -129,22 +142,6 @@ def _next_impl(self): return True -class DeduplicateMergeFunction: - """A MergeFunction where key is primary key (unique) and value is the full record, only keep the latest one.""" - - def __init__(self): - self.latest_kv = None - - def reset(self) -> None: - self.latest_kv = None - - def add(self, kv: KeyValue): - self.latest_kv = 
kv - - def get_result(self) -> Optional[KeyValue]: - return self.latest_kv - - class Element: def __init__(self, kv: KeyValue, iterator: RecordIterator[KeyValue], reader: RecordReader[KeyValue]): self.kv = kv @@ -168,36 +165,82 @@ def update(self) -> bool: class HeapEntry: - def __init__(self, key: InternalRow, element: Element, key_comparator): + def __init__(self, key: InternalRow, element: Element, key_comparator, + seq_comparator=None): self.key = key self.element = element self.key_comparator = key_comparator + self.seq_comparator = seq_comparator def __lt__(self, other): + # Heap order mirrors Java ``SortMergeReaderWithMinHeap``: user key + # -> user-defined sequence comparator (``sequence.field``) on the + # value row -> file-level sequence number. result = self.key_comparator(self.key, other.key) - if result < 0: - return True - elif result > 0: - return False - - return self.element.kv.sequence_number < other.element.kv.sequence_number - - -def builtin_key_comparator(key_schema: List[DataField]) -> Callable[[Any, Any], int]: - # Precompute comparability flags to avoid repeated type checks - comparable_types = {member.value for member in Keyword if member is not Keyword.VARIANT} - comparable_flags = [field.type.type.split(' ')[0] in comparable_types for field in key_schema] - - def comparator(key1: InternalRow, key2: InternalRow) -> int: - if key1 is None and key2 is None: + if result == 0 and self.seq_comparator is not None: + result = self.seq_comparator( + self.element.kv.value, other.element.kv.value) + if result == 0: + result = self.element.kv.sequence_number - other.element.kv.sequence_number + return result < 0 + + +def _base_type_name(field: DataField) -> str: + """Base type keyword of a field, stripping any ``(precision[, scale])`` + parameters and the ``NOT NULL`` suffix. E.g. ``DECIMAL(10, 2)`` and + ``TIMESTAMP(6)`` map to ``DECIMAL`` / ``TIMESTAMP``. 
+ """ + return field.type.type.split('(')[0].split(' ')[0] + + +# Atomic type keywords pypaimon can order with Python's native comparison +# operators. VARIANT is atomic but has no ordering, so it is excluded -- +# matching Java, which has no VARIANT sequence-field support. +_COMPARABLE_TYPE_NAMES = frozenset( + member.value for member in Keyword if member is not Keyword.VARIANT) + + +def is_comparable_seq_field(field: DataField) -> bool: + """Whether ``field`` can serve as a ``sequence.field`` for pypaimon's + atomic comparator: it must be an ``AtomicType`` whose base type name is + orderable. Complex types (ARRAY / MAP / ROW / ...) and the atomic-but- + unorderable VARIANT both return ``False``. Used by the read-builder + guard to reject unsupported sequence fields up front. + """ + return (isinstance(field.type, AtomicType) + and _base_type_name(field) in _COMPARABLE_TYPE_NAMES) + + +def _row_field_comparator( + fields: List[DataField], + indices: List[int], + ascending: bool = True) -> Callable[[Any, Any], int]: + """Build a comparator over two rows on the given ``indices`` (positions + in ``fields`` / the row's ``get_field``), compared left-to-right. + + Shared by :func:`builtin_key_comparator` (all key fields, ascending) and + :func:`builtin_seq_comparator` (the configured sequence fields, with + sort-order). Comparability is precomputed once. ``None`` rows/values + always sort first, independent of ``ascending`` -- only the comparison + of two non-null values is reversed when ``ascending=False``. This + mirrors Java ``GenerateUtils.generateRowCompare`` built with + ``nullIsLast=false`` (see ``CodeGeneratorImpl#getSortSpec``), where + descending order flips only the non-null value comparison and leaves + nulls sorting first. 
+ """ + comparable_flags = [_base_type_name(fields[idx]) in _COMPARABLE_TYPE_NAMES for idx in indices] + sign = 1 if ascending else -1 + + def comparator(row1: InternalRow, row2: InternalRow) -> int: + if row1 is None and row2 is None: return 0 - if key1 is None: + if row1 is None: return -1 - if key2 is None: + if row2 is None: return 1 - for i, comparable in enumerate(comparable_flags): - val1 = key1.get_field(i) - val2 = key2.get_field(i) + for pos, idx in enumerate(indices): + val1 = row1.get_field(idx) + val2 = row2.get_field(idx) if val1 is None and val2 is None: continue @@ -206,13 +249,73 @@ def comparator(key1: InternalRow, key2: InternalRow) -> int: if val2 is None: return 1 - if not comparable: - raise ValueError(f"Unsupported {key_schema[i].type} comparison") + if not comparable_flags[pos]: + raise ValueError(f"Unsupported {fields[idx].type} comparison") if val1 < val2: - return -1 + return -sign elif val1 > val2: - return 1 + return sign return 0 return comparator + + +def builtin_key_comparator(key_schema: List[DataField]) -> Callable[[Any, Any], int]: + return _row_field_comparator(key_schema, list(range(len(key_schema)))) + + +def builtin_seq_comparator( + value_fields: List[DataField], + sequence_field_names: List[str], + ascending: bool) -> Optional[Callable[[Any, Any], int]]: + """Build a comparator for the user-defined ``sequence.field`` option. + + Compares two *value* rows (the value side of a ``KeyValue``) on the + configured sequence fields, in declaration order, returning a negative + / zero / positive int. Mirrors Java ``UserDefinedSeqComparator``: + + - ``sequence_field_names`` empty -> ``None`` (no comparator; the caller + falls back to the file-level sequence number). + - field names resolve to indices within the value row + (``value_fields`` is the value-side schema, == ``read_type``); + ``get_field(idx)`` indexes the value ``OffsetRow``. + - multiple fields compared left-to-right. 
+ - ``ascending=False`` reverses only the non-null value comparison for + each field; null ordering stays nulls-first regardless of sort order + (mirroring Java's ``nullIsLast=false``). The value rows here carry a + homogeneous sort order, so reversing the final non-null comparison is + equivalent to Java reversing each field. + + A name that does not resolve raises ``ValueError`` -- the read path + injects missing sequence fields into the projection before this runs, + so a miss indicates a wiring bug rather than user error. + + A sequence field whose type pypaimon cannot order raises + ``NotImplementedError``: complex types (ARRAY / VECTOR / MAP / MULTISET / + ROW), which Java handles via ``RecordComparator``, and the atomic-but- + unorderable VARIANT. pypaimon only implements atomic-type comparison + here, so reject these explicitly rather than failing later with an + obscure error. + """ + if not sequence_field_names: + return None + + name_to_index = {field.name: i for i, field in enumerate(value_fields)} + indices = [] + for name in sequence_field_names: + if name not in name_to_index: + raise ValueError( + f"sequence.field '{name}' not found in value fields " + f"{[f.name for f in value_fields]}") + idx = name_to_index[name] + if not is_comparable_seq_field(value_fields[idx]): + raise NotImplementedError( + f"sequence.field '{name}' has unsupported type " + f"{value_fields[idx].type}; pypaimon only supports orderable " + f"atomic sequence-field types. 
Complex types (ARRAY / MAP / " + f"ROW etc., handled by Java via RecordComparator) and VARIANT " + f"are not supported -- open an issue to track support.") + indices.append(idx) + + return _row_field_comparator(value_fields, indices, ascending) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 1d2831c194e3..5424d2d51910 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -285,8 +285,14 @@ def scan(self) -> Plan: def _create_data_evolution_split_generator(self): row_ranges = None score_getter = None + # Fetch snapshot once and share with global index evaluation to avoid + # a duplicate /snapshot REST round-trip (#7513). + manifest_files, snapshot = self.manifest_scanner() + self._scanned_snapshot = snapshot + self._scanned_snapshot_id = snapshot.id if snapshot else None + global_index_result = self._global_index_result if self._global_index_result is not None \ - else self._eval_global_index() + else self._eval_global_index(snapshot) if global_index_result is not None: row_ranges = global_index_result.results().to_range_list() if isinstance(global_index_result, ScoredGlobalIndexResult): @@ -294,10 +300,6 @@ def _create_data_evolution_split_generator(self): if row_ranges is None and self.predicate is not None: row_ranges = _row_ranges_from_predicate(self.predicate) - manifest_files, snapshot = self.manifest_scanner() - self._scanned_snapshot = snapshot - self._scanned_snapshot_id = snapshot.id if snapshot else None - # Filter manifest files by row ranges if available if row_ranges is not None: manifest_files = _filter_manifest_files_by_row_ranges(manifest_files, row_ranges) @@ -324,7 +326,7 @@ def plan_files(self) -> List[ManifestEntry]: return [] return self.read_manifest_entries(manifest_files) - def _eval_global_index(self): + def _eval_global_index(self, snapshot=None): # No filter - nothing to evaluate if 
self.predicate is None: return None @@ -339,7 +341,8 @@ def _eval_global_index(self): scanner = GlobalIndexScanner.create( self.table, partition_filter=self.partition_key_predicate, - predicate=self.predicate + predicate=self.predicate, + snapshot=snapshot, ) if scanner is None: return None diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py index e432eca6b122..fe7652743167 100644 --- a/paimon-python/pypaimon/read/split_read.py +++ b/paimon-python/pypaimon/read/split_read.py @@ -20,6 +20,7 @@ from functools import partial from typing import Callable, Dict, List, Optional, Tuple +from pypaimon.common.merge_engine_dispatch import build_merge_function from pypaimon.common.options.core_options import CoreOptions, MergeEngine from pypaimon.common.predicate import Predicate from pypaimon.deletionvectors import ApplyDeletionVectorReader @@ -29,8 +30,9 @@ from pypaimon.read.interval_partition import IntervalPartition, SortedRun from pypaimon.read.partition_info import PartitionInfo from pypaimon.read.push_down_utils import rewrite_predicate_indices, trim_predicate_by_fields -from pypaimon.read.reader.concat_batch_reader import (ConcatBatchReader, - MergeAllBatchReader, DataEvolutionMergeReader) +from pypaimon.read.reader.concat_batch_reader import ( + BlobFallbackBatchReader, ConcatBatchReader, + MergeAllBatchReader, DataEvolutionMergeReader) from pypaimon.read.reader.concat_record_reader import ConcatRecordReader from pypaimon.read.reader.data_file_batch_reader import DataFileBatchReader @@ -39,12 +41,15 @@ from pypaimon.read.reader.field_bunch import BlobBunch, DataBunch, FieldBunch, VectorBunch from pypaimon.read.reader.filter_record_reader import FilterRecordReader from pypaimon.read.reader.format_avro_reader import FormatAvroReader -from pypaimon.read.reader.blob_descriptor_convert_reader import BlobDescriptorConvertReader +from pypaimon.read.reader.blob_descriptor_convert_reader import BlobInlineConvertReader from 
pypaimon.read.reader.filter_record_batch_reader import FilterRecordBatchReader +from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader, LimitedRecordReader from pypaimon.read.reader.row_range_filter_record_reader import RowIdFilterRecordBatchReader from pypaimon.read.reader.format_blob_reader import FormatBlobReader from pypaimon.read.reader.format_lance_reader import FormatLanceReader from pypaimon.read.reader.format_pyarrow_reader import FormatPyArrowReader +from pypaimon.read.reader.format_row_reader import FormatRowReader +from pypaimon.read.reader.format_mosaic_reader import FormatMosaicReader from pypaimon.read.reader.format_vortex_reader import FormatVortexReader from pypaimon.read.reader.iface.record_batch_reader import (RecordBatchReader, RowPositionReader, EmptyRecordBatchReader) @@ -53,10 +58,10 @@ KeyValueUnwrapRecordReader from pypaimon.read.reader.key_value_wrap_reader import KeyValueWrapReader from pypaimon.read.reader.shard_batch_reader import ShardBatchReader -from pypaimon.read.reader.partial_update_merge_function import \ - PartialUpdateMergeFunction -from pypaimon.read.reader.sort_merge_reader import (DeduplicateMergeFunction, - SortMergeReaderWithMinHeap) +from pypaimon.read.reader.aggregation_merge_function import ( + AggregateMergeFunction, build_field_aggregators) +from pypaimon.read.reader.sort_merge_reader import (SortMergeReaderWithMinHeap, + builtin_seq_comparator) from pypaimon.read.push_down_utils import _get_all_fields from pypaimon.read.split import Split from pypaimon.read.sliced_split import SlicedSplit @@ -102,7 +107,8 @@ def __init__( read_type: List[DataField], split: Split, row_tracking_enabled: bool, - nested_name_paths: Optional[List[List[str]]] = None): + nested_name_paths: Optional[List[List[str]]] = None, + limit: Optional[int] = None): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table @@ -112,9 +118,10 @@ def __init__( self.row_tracking_enabled = 
row_tracking_enabled self.value_arity = len(read_type) self.nested_name_paths = nested_name_paths + self.limit = limit # Snapshot the raw value-side schema before _create_key_value_fields # wraps it, so MergeFileSplitRead can hand per-value-field nullable - # flags to merge functions that mirror Java's NOT-NULL check. + # flags to merge functions that enforce NOT-NULL on every add(). self.value_fields = list(read_type) self.trimmed_primary_key = self.table.trimmed_primary_keys @@ -137,8 +144,8 @@ def __init__( # the space FilterRecordReader actually evaluates against. read_type_names = {f.name for f in read_type} if ( - self.predicate is not None - and _get_all_fields(self.predicate).issubset(read_type_names) + self.predicate is not None + and _get_all_fields(self.predicate).issubset(read_type_names) ): self.predicate_for_reader = rewrite_predicate_indices( self.predicate, read_type @@ -189,13 +196,17 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, batch_size = self.table.options.read_batch_size() - # Convert global row_ranges (IndexedSplit) to local row_indices for Vortex/Lance native pushdown + # Convert global row_ranges (IndexedSplit) to local row_indices for native pushdown. 
row_indices = None if row_ranges is not None: effective_row_ranges = Range.and_(row_ranges, [file.row_id_range()]) if len(effective_row_ranges) == 0: return EmptyRecordBatchReader() - if file_format in (CoreOptions.FILE_FORMAT_VORTEX, CoreOptions.FILE_FORMAT_LANCE): + row_index_formats = (CoreOptions.FILE_FORMAT_BLOB, + CoreOptions.FILE_FORMAT_VORTEX, + CoreOptions.FILE_FORMAT_LANCE, + CoreOptions.FILE_FORMAT_ROW) + if file_format in row_index_formats: row_indices = [] for r in effective_row_ranges: start = r.from_ - file.first_row_id @@ -236,7 +247,8 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, blob_as_descriptor = CoreOptions.blob_as_descriptor(self.table.options) format_reader = FormatBlobReader(self.table.file_io, file_path, read_file_fields, self.read_fields, read_arrow_predicate, blob_as_descriptor, - batch_size=batch_size) + batch_size=batch_size, + row_indices=row_indices) elif file_format == CoreOptions.FILE_FORMAT_LANCE: if has_nested: raise NotImplementedError( @@ -257,6 +269,13 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, row_indices=row_indices, shard_range=shard_range, predicate_fields=predicate_fields) + elif file_format == CoreOptions.FILE_FORMAT_MOSAIC: + if has_nested: + raise NotImplementedError( + "Nested-field projection is not supported on Mosaic files") + ordered_read_fields = [name_to_field[n] for n in read_file_fields if n in name_to_field] + format_reader = FormatMosaicReader(self.table.file_io, file_path, ordered_read_fields, + read_arrow_predicate, batch_size=batch_size) elif file_format == CoreOptions.FILE_FORMAT_PARQUET or file_format == CoreOptions.FILE_FORMAT_ORC: ordered_read_fields = [name_to_field[n] for n in read_file_fields if n in name_to_field] ordered_nested_paths = ( @@ -268,16 +287,33 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, ordered_read_fields, read_arrow_predicate, batch_size=batch_size, options=self.table.options, 
nested_name_paths=ordered_nested_paths) + elif file_format == CoreOptions.FILE_FORMAT_ROW: + if has_nested: + raise NotImplementedError( + "Nested-field projection is not supported on ROW files") + file_schema = self.table.schema_manager.get_schema( + file.schema_id) + if file.write_cols: + field_map = {f.name: f for f in file_schema.fields} + row_full_fields = [field_map[n] for n in file.write_cols + if n in field_map] + elif self.table.is_primary_key_table: + row_full_fields = self._create_key_value_fields( + file_schema.fields) + else: + row_full_fields = file_schema.fields + format_reader = FormatRowReader( + self.table.file_io, file_path, read_file_fields, + row_full_fields, + read_arrow_predicate, batch_size=batch_size, + row_indices=row_indices) elif file_format in ('json', 'csv'): raise NotImplementedError( f"Reading '{file_format}' format is not yet supported in Python SDK. " - f"Supported formats: parquet, orc, avro, lance, blob.") + f"Supported formats: parquet, orc, avro, lance, vortex, mosaic, blob, row.") else: raise ValueError(f"Unexpected file format: {file_format}") - blob_as_descriptor = CoreOptions.blob_as_descriptor(self.table.options) - blob_descriptor_fields = CoreOptions.blob_descriptor_fields(self.table.options) - index_mapping = self.create_index_mapping() partition_info = self._create_partition_info() system_fields = SpecialFields.find_system_fields(self.read_fields) @@ -305,8 +341,6 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, effective_first_row_id, row_tracking_enabled, system_fields, - blob_as_descriptor=blob_as_descriptor, - blob_descriptor_fields=blob_descriptor_fields, file_io=self.table.file_io, row_id_offsets=row_indices) else: @@ -320,8 +354,6 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, effective_first_row_id, row_tracking_enabled, system_fields, - blob_as_descriptor=blob_as_descriptor, - blob_descriptor_fields=blob_descriptor_fields, file_io=self.table.file_io, 
row_id_offsets=row_indices) @@ -588,9 +620,14 @@ def create_reader(self) -> RecordReader: vector_field_indices=_vector_field_indices(self.read_fields)) # if the table is appendonly table, we don't need extra filter, all predicates has pushed down if self.table.is_primary_key_table and self.predicate_for_reader: - return FilterRecordReader(concat_reader, self.predicate_for_reader) + reader = FilterRecordReader(concat_reader, self.predicate_for_reader) + if self.limit is not None: + reader = LimitedRecordReader(reader, self.limit) else: - return concat_reader + reader = concat_reader + if self.limit is not None: + reader = LimitedRecordBatchReader(reader, self.limit) + return reader def _get_all_data_fields(self): if self.row_tracking_enabled: @@ -618,9 +655,18 @@ def __init__( split=split, row_tracking_enabled=row_tracking_enabled, nested_name_paths=None, + limit=limit, ) self.outer_extract_name_paths = outer_extract_name_paths - self.limit = limit + # Built once per split-read (value_fields and options are constant + # for the object's life), not per section. ``None`` when + # ``sequence.field`` is unset, in which case the heap falls back to + # the file-level sequence number. 
+ self.seq_comparator = builtin_seq_comparator( + self.value_fields, + self.table.options.sequence_field(), + self.table.options.sequence_field_sort_order_is_ascending(), + ) def kv_reader_supplier(self, file: DataFileMeta, dv_factory: Optional[Callable] = None) -> RecordReader: file_batch_reader = self.file_reader_supplier(file, True, self._get_final_read_data_fields(), False) @@ -642,31 +688,48 @@ def section_reader_supplier(self, section: List[SortedRun]) -> RecordReader: readers.append(ConcatRecordReader(data_readers)) merge_function = self._build_merge_function() return SortMergeReaderWithMinHeap( - readers, self.table.table_schema, merge_function=merge_function) + readers, self.table.table_schema, merge_function=merge_function, + seq_comparator=self.seq_comparator) def _build_merge_function(self): - """Pick the right MergeFunction implementation for the table's - ``merge-engine`` option. - - The pre-flight checks that reject unsupported engines or option - combinations live in - :func:`pypaimon.read.merge_engine_support.check_supported` and - run at ``TableRead.__init__`` time, so by the point this method - executes only the supported engines are reachable. + """Pick the MergeFunction for the table's ``merge-engine`` option. + + Delegates to the shared dispatch in + ``pypaimon.common.merge_engine_dispatch`` so the read path and + the in-memory merge buffer on the write path cannot drift. + ``AGGREGATE`` is special-cased here because building the per- + field aggregators needs the full ``DataField`` objects, the + full primary-key list and the parsed ``CoreOptions`` -- which + sit outside the dispatch's raw-options contract. The writer- + side merge buffer falls back to dedupe for aggregation anyway + (see :meth:`FileStoreWrite._build_pk_merge_function`), so the + two sides only need to share the simple engines. 
""" engine = self.table.options.merge_engine() - if engine == MergeEngine.DEDUPLICATE: - return DeduplicateMergeFunction() - if engine == MergeEngine.PARTIAL_UPDATE: - return PartialUpdateMergeFunction( + if engine == MergeEngine.AGGREGATE: + # Use the full primary-key list, not ``trimmed_primary_key``: + # ``value_fields`` still carries partition columns, so any PK + # column that is also a partition column must be recognised + # as PK here. Otherwise a table with + # ``fields.default-aggregate-function`` would apply the + # default aggregator to that partition-PK column. + field_aggregators = build_field_aggregators( + self.value_fields, + self.table.primary_keys, + self.table.options, + ) + return AggregateMergeFunction( key_arity=len(self.trimmed_primary_key), value_arity=self.value_arity, - nullables=[f.type.nullable for f in self.value_fields], + field_aggregators=field_aggregators, ) - # check_supported() rejects everything else at TableRead.__init__. - raise AssertionError( - "unreachable: merge-engine '{}' should have been rejected by " - "merge_engine_support.check_supported".format(engine.value) + return build_merge_function( + engine=engine, + raw_options=self.table.options.options.to_map(), + key_arity=len(self.trimmed_primary_key), + value_arity=self.value_arity, + value_field_nullables=[f.type.nullable for f in self.value_fields], + value_field_names=[f.name for f in self.value_fields], ) def create_reader(self) -> RecordReader: @@ -694,8 +757,6 @@ def create_reader(self) -> RecordReader: blob_field_indices=_blob_field_indices(inner_value_fields), vector_field_indices=_vector_field_indices(inner_value_fields)) if self.limit is not None: - from pypaimon.read.reader.limited_record_reader import \ - LimitedRecordReader reader = LimitedRecordReader(reader, self.limit) return reader @@ -712,7 +773,8 @@ def __init__( read_type: List[DataField], split: Split, row_tracking_enabled: bool, - nested_name_paths: Optional[List[List[str]]] = None): + 
nested_name_paths: Optional[List[List[str]]] = None, + limit: Optional[int] = None): self.row_ranges = None actual_split = split if isinstance(split, IndexedSplit): @@ -721,6 +783,7 @@ def __init__( super().__init__( table, predicate, read_type, actual_split, row_tracking_enabled, nested_name_paths=nested_name_paths, + limit=limit, ) def _push_down_predicate(self) -> Optional[Predicate]: @@ -729,6 +792,20 @@ def _push_down_predicate(self) -> Optional[Predicate]: return None def create_reader(self) -> RecordReader: + reader = self._create_raw_reader() + + if ((CoreOptions.blob_view_fields(self.table.options) and CoreOptions.blob_view_resolve_enabled( + self.table.options)) + or (not CoreOptions.blob_as_descriptor(self.table.options) + and CoreOptions.blob_descriptor_fields(self.table.options))): + reader = BlobInlineConvertReader( + reader, self.table, + prescan_reader_factory=lambda names: self._create_prescan_reader(names)) + + return reader + + def _create_raw_reader(self) -> RecordReader: + """Core read logic: split_by_row_id -> suppliers -> ConcatBatchReader -> filter.""" files = self.split.files suppliers = [] @@ -760,12 +837,39 @@ def create_reader(self) -> RecordReader: else: reader = merge_reader - if (not CoreOptions.blob_as_descriptor(self.table.options) - and CoreOptions.blob_descriptor_fields(self.table.options)): - reader = BlobDescriptorConvertReader(reader, self.table) + if self.limit is not None: + reader = LimitedRecordBatchReader(reader, self.limit) return reader + def _create_prescan_reader(self, field_names): + """Create a prescan reader by constructing a new DataEvolutionSplitRead + instance that only projects the specified field names. + + Align with Java's configureBlobViewPrescanRead: pass limit to prescan reader + to avoid scanning entire split when there's a LIMIT clause. 
+ """ + from pypaimon.read.reader.iface.record_batch_reader import EmptyRecordBatchReader + + prescan_fields = [f for f in self.read_fields if f.name in field_names] + if not prescan_fields: + return EmptyRecordBatchReader() + + # When there's a normal field predicate, don't push down limit to prescan reader + # because the outer reader will apply predicate+limit filtering, + # while prescan reader would only apply limit without normal field predicate + # TODO support limit+predicate push down + prescan_read = DataEvolutionSplitRead( + table=self.table, + predicate=self.predicate, + read_type=prescan_fields, + split=self.split, + row_tracking_enabled=False, + limit=None if self.predicate else self.limit, + ) + prescan_read.row_ranges = self.row_ranges + return prescan_read._create_raw_reader() + def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]: """Split files by firstRowId for data evolution.""" @@ -888,6 +992,27 @@ def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordRe bunch.files()[0], read_field_names ): r] file_record_readers[i] = MergeAllBatchReader(suppliers, batch_size=batch_size) + elif DataFileMeta.is_blob_file(first_file.file_name): + file_reader_suppliers = [ + ( + file, + partial( + self._create_raw_blob_file_reader, + file=file, + read_fields=read_field_names, + ), + ) + for file in bunch.files() + ] + file_record_readers[i] = BlobFallbackBatchReader( + file_reader_suppliers, + read_fields[0].name, + PyarrowFieldParser.from_paimon_schema( + [read_fields[0]] + ).field(0).type, + self.row_ranges, + CoreOptions.blob_as_descriptor(self.table.options), + ) else: # Create concatenated reader for multiple files suppliers = [ @@ -915,6 +1040,30 @@ def _create_file_reader(self, file: DataFileMeta, read_fields: [str]) -> Optiona row_tracking_enabled=True, row_ranges=self.row_ranges) + def _create_raw_blob_file_reader( + self, file: DataFileMeta, read_fields: [str]) -> Optional[FormatBlobReader]: + 
row_indices = None + if self.row_ranges is not None: + row_indices = [ + row_id - file.first_row_id + for row_range in Range.and_([file.row_id_range()], self.row_ranges) + for row_id in range(row_range.from_, row_range.to + 1) + ] + if not row_indices: + return None + + file_path = file.external_path if file.external_path else file.file_path + return FormatBlobReader( + self.table.file_io, + file_path, + read_fields, + self.read_fields, + None, + CoreOptions.blob_as_descriptor(self.table.options), + batch_size=self.table.options.read_batch_size(), + row_indices=row_indices, + ) + def _split_field_bunches(self, need_merge_files: List[DataFileMeta]) -> List[FieldBunch]: """Split files into field bunches.""" diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index 52a4eaaa7f1a..826b2b4024a1 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -532,6 +532,32 @@ def _create_split_read(self, split: Split) -> SplitRead: # the requested sub-paths back to the user's flat schema. inner_read_type = self._widen_to_top_level_for_merge() outer_extract_name_paths = self.nested_name_paths + + # When the user's projection drops a ``sequence.field``, the merge + # heap can't compare it. Inject the missing sequence field(s) into + # the value row so the comparator resolves, then project them back + # out after merging (mirrors Java MergeFileSplitRead.withReadType + + # projectOuter). Reuses the OuterProjectionRecordReader machinery. 
+ seq_fields = self.table.options.sequence_field() + if seq_fields: + present = {f.name for f in inner_read_type} + missing = [name for name in seq_fields if name not in present] + if missing: + table_fields_by_name = {f.name: f for f in self.table.fields} + extra = [] + for name in missing: + field = table_fields_by_name.get(name) + if field is None: + raise ValueError( + "sequence.field %r not found in table schema" + % (name,)) + extra.append(field) + inner_read_type = list(inner_read_type) + extra + if outer_extract_name_paths is None: + # Drop the injected seq columns: project back to the + # user's requested (flat) columns in order. + outer_extract_name_paths = [ + [f.name] for f in self.read_type] return MergeFileSplitRead( table=self.table, predicate=self.predicate, @@ -554,6 +580,7 @@ def _create_split_read(self, split: Split) -> SplitRead: split=split, row_tracking_enabled=True, nested_name_paths=self.nested_name_paths, + limit=self.limit, ) else: return RawFileSplitRead( @@ -563,6 +590,7 @@ def _create_split_read(self, split: Split) -> SplitRead: split=split, row_tracking_enabled=self.table.options.row_tracking_enabled(), nested_name_paths=self.nested_name_paths, + limit=self.limit, ) def _widen_to_top_level_for_merge(self) -> List[DataField]: diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py index 623261803503..03a1c8b06297 100755 --- a/paimon-python/pypaimon/read/table_scan.py +++ b/paimon-python/pypaimon/read/table_scan.py @@ -58,6 +58,8 @@ def _create_file_scanner(self) -> FileScanner: snapshot_manager = self.table.snapshot_manager() manifest_list_manager = ManifestListManager(self.table) + self._validate_scan_mode() + from pypaimon.snapshot.time_travel_util import TimeTravelUtil, SCAN_KEYS has_time_travel = any(options.contains_key(key) for key in SCAN_KEYS) has_incremental = options.contains(CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP) @@ -158,3 +160,104 @@ def with_slice(self, start_pos, end_pos) -> 
'TableScan': def with_global_index_result(self, result) -> 'TableScan': self.file_scanner.with_global_index_result(result) return self + + def _validate_scan_mode(self): + """Validate scan.mode against companion options using a whitelist approach. + + Each StartupMode declares exactly which scan keys are allowed. Any + scan key present but not in the whitelist for the resolved mode is + rejected. This matches Java's SchemaValidation mutual-exclusion matrix. + """ + from pypaimon.common.options.core_options import StartupMode + + core_options = self.table.options + mode = core_options.startup_mode() + options = core_options.options + + has_snapshot_id = options.contains(CoreOptions.SCAN_SNAPSHOT_ID) + has_tag_name = options.contains(CoreOptions.SCAN_TAG_NAME) + has_watermark = options.contains(CoreOptions.SCAN_WATERMARK) + has_timestamp_millis = options.contains(CoreOptions.SCAN_TIMESTAMP_MILLIS) + has_timestamp = options.contains(CoreOptions.SCAN_TIMESTAMP) + has_incremental = options.contains(CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP) + has_file_creation_time = options.contains(CoreOptions.SCAN_FILE_CREATION_TIME_MILLIS) + has_creation_time = options.contains(CoreOptions.SCAN_CREATION_TIME_MILLIS) + + present_keys = [] + if has_snapshot_id: + present_keys.append(CoreOptions.SCAN_SNAPSHOT_ID.key()) + if has_tag_name: + present_keys.append(CoreOptions.SCAN_TAG_NAME.key()) + if has_watermark: + present_keys.append(CoreOptions.SCAN_WATERMARK.key()) + if has_timestamp_millis: + present_keys.append(CoreOptions.SCAN_TIMESTAMP_MILLIS.key()) + if has_timestamp: + present_keys.append(CoreOptions.SCAN_TIMESTAMP.key()) + if has_incremental: + present_keys.append(CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP.key()) + if has_file_creation_time: + present_keys.append(CoreOptions.SCAN_FILE_CREATION_TIME_MILLIS.key()) + if has_creation_time: + present_keys.append(CoreOptions.SCAN_CREATION_TIME_MILLIS.key()) + + # scan.timestamp-millis and scan.timestamp are mutually exclusive + if 
has_timestamp_millis and has_timestamp: + raise ValueError( + "scan.timestamp-millis and scan.timestamp cannot both be set." + ) + + # Define allowed companion keys per mode + if mode == StartupMode.FROM_TIMESTAMP: + allowed = { + CoreOptions.SCAN_TIMESTAMP_MILLIS.key(), + CoreOptions.SCAN_TIMESTAMP.key(), + } + if not (has_timestamp_millis or has_timestamp): + raise ValueError( + "scan.mode is 'from-timestamp' but neither " + "scan.timestamp-millis nor scan.timestamp is set." + ) + elif mode == StartupMode.FROM_SNAPSHOT_FULL: + allowed = {CoreOptions.SCAN_SNAPSHOT_ID.key()} + if not has_snapshot_id: + raise ValueError( + "scan.mode is 'from-snapshot-full' but scan.snapshot-id is not set." + ) + elif mode == StartupMode.FROM_SNAPSHOT: + allowed = { + CoreOptions.SCAN_SNAPSHOT_ID.key(), + CoreOptions.SCAN_TAG_NAME.key(), + CoreOptions.SCAN_WATERMARK.key(), + } + if not (has_snapshot_id or has_tag_name or has_watermark): + raise ValueError( + "scan.mode is 'from-snapshot' but none of " + "scan.snapshot-id, scan.tag-name, or scan.watermark is set." + ) + elif mode == StartupMode.INCREMENTAL: + allowed = {CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP.key()} + if not has_incremental: + raise ValueError( + "scan.mode is 'incremental' but " + "incremental-between-timestamp is not set." + ) + elif mode in (StartupMode.LATEST_FULL, StartupMode.LATEST): + allowed = set() + elif mode in (StartupMode.COMPACTED_FULL, + StartupMode.FROM_CREATION_TIMESTAMP, + StartupMode.FROM_FILE_CREATION_TIME): + raise ValueError( + f"scan.mode '{mode.value}' is not yet supported in pypaimon." + ) + else: + allowed = set() + + # Reject any scan key that's not in the whitelist for this mode + disallowed = [k for k in present_keys if k not in allowed] + if disallowed: + raise ValueError( + f"scan.mode '{mode.value}' conflicts with: {disallowed}. " + f"Only {sorted(allowed) if allowed else 'no scan keys'} " + f"are allowed for this mode." 
+ ) diff --git a/paimon-python/pypaimon/schema/column_directive_utils.py b/paimon-python/pypaimon/schema/column_directive_utils.py new file mode 100644 index 000000000000..b7127ae86683 --- /dev/null +++ b/paimon-python/pypaimon/schema/column_directive_utils.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Column comment directive utilities for BLOB / VECTOR type conversion. + +Mirrors Java's ColumnDirectiveUtils. 
Supported directives: + __BLOB_FIELD -> blob-field + __BLOB_DESCRIPTOR_FIELD -> blob-descriptor-field + __BLOB_VIEW_FIELD -> blob-view-field + __BLOB_EXTERNAL_STORAGE_FIELD -> blob-external-storage-field + blob-descriptor-field + __VECTOR_FIELD;dim -> vector-field +""" + +from dataclasses import dataclass +from typing import Dict, List, Optional + +from pypaimon.common.options.core_options import CoreOptions +from pypaimon.schema.data_types import (ArrayType, AtomicType, DataField, + DataType, VectorType) + +BLOB_FIELD_DIRECTIVE = "__BLOB_FIELD" +BLOB_DESCRIPTOR_FIELD_DIRECTIVE = "__BLOB_DESCRIPTOR_FIELD" +BLOB_VIEW_FIELD_DIRECTIVE = "__BLOB_VIEW_FIELD" +BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE = "__BLOB_EXTERNAL_STORAGE_FIELD" +VECTOR_FIELD_DIRECTIVE = "__VECTOR_FIELD" + +_BLOB_DIRECTIVES = [ + (BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE, CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key()), + (BLOB_VIEW_FIELD_DIRECTIVE, CoreOptions.BLOB_VIEW_FIELD.key()), + (BLOB_DESCRIPTOR_FIELD_DIRECTIVE, CoreOptions.BLOB_DESCRIPTOR_FIELD.key()), + (BLOB_FIELD_DIRECTIVE, CoreOptions.BLOB_FIELD.key()), +] + +_BLOB_OPTIONS = [ + CoreOptions.BLOB_FIELD, + CoreOptions.BLOB_DESCRIPTOR_FIELD, + CoreOptions.BLOB_VIEW_FIELD, + CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD, +] + +_VECTOR_OPTIONS = [ + CoreOptions.VECTOR_FIELD, +] + +# Legacy fallback keys that map to canonical option keys. +# Java's blob-descriptor-field has fallback key "blob.stored-descriptor-fields". 
+_FALLBACK_KEYS = { + CoreOptions.BLOB_DESCRIPTOR_FIELD.key(): ["blob.stored-descriptor-fields"], +} + + +@dataclass +class ParsedDirective: + option_key: str + real_comment: Optional[str] + is_vector: bool + vector_dim: int + + +@dataclass +class ConvertedColumn: + type: DataType + comment: Optional[str] + + +def parse_add_column_comment(comment: Optional[str]) -> Optional[ParsedDirective]: + if comment is None: + return None + comment = comment.strip() + if not comment: + return None + + if comment.startswith("__VECTOR"): + return _parse_vector_directive(comment) + if not comment.startswith("__BLOB"): + return None + + for marker, option_key in _BLOB_DIRECTIVES: + if not comment.startswith(marker): + continue + if len(comment) == len(marker): + real_comment = None + elif comment[len(marker)] == ';': + real_comment = comment[len(marker) + 1:].strip() or None + else: + continue + return ParsedDirective(option_key, real_comment, False, 0) + + raise ValueError( + f"Unsupported BLOB directive in column comment: '{comment}'. " + f"Supported directives are '{BLOB_FIELD_DIRECTIVE}', " + f"'{BLOB_DESCRIPTOR_FIELD_DIRECTIVE}', '{BLOB_VIEW_FIELD_DIRECTIVE}' " + f"and '{BLOB_EXTERNAL_STORAGE_FIELD_DIRECTIVE}'." + ) + + +def _parse_vector_directive(comment: str) -> ParsedDirective: + marker = VECTOR_FIELD_DIRECTIVE + if not comment.startswith(marker): + raise ValueError( + f"Unsupported VECTOR directive in column comment: '{comment}'. " + f"Supported directive is '{VECTOR_FIELD_DIRECTIVE}'." + ) + if len(comment) == len(marker) or comment[len(marker)] != ';': + raise ValueError( + f"VECTOR directive '{comment}' requires a dimension, " + f"e.g. '{VECTOR_FIELD_DIRECTIVE};128' or " + f"'{VECTOR_FIELD_DIRECTIVE};128; my comment'." 
+ ) + + rest = comment[len(marker) + 1:] + semi_pos = rest.find(';') + if semi_pos < 0: + dim_str = rest.strip() + real_comment = None + else: + dim_str = rest[:semi_pos].strip() + real_comment = rest[semi_pos + 1:].strip() or None + + try: + dim = int(dim_str) + except ValueError: + raise ValueError( + f"Expected an integer dimension after '{VECTOR_FIELD_DIRECTIVE};', " + f"but got: '{dim_str}'." + ) + if dim < 1: + raise ValueError(f"Vector dimension must be >= 1, but got: {dim}.") + + option_key = CoreOptions.VECTOR_FIELD.key() + return ParsedDirective(option_key, real_comment, True, dim) + + +def _convert_type(directive: ParsedDirective, field_name: str, source_type: DataType) -> DataType: + if directive.is_vector: + if not isinstance(source_type, ArrayType): + raise ValueError( + f"Column {field_name} declared with a VECTOR directive " + f"must be of ARRAY type, but was {source_type}." + ) + return VectorType(source_type.nullable, source_type.element, directive.vector_dim) + else: + type_name = getattr(source_type, 'type', None) if isinstance(source_type, AtomicType) else None + if type_name not in ('VARBINARY', 'BINARY', 'BYTES', 'BLOB'): + raise ValueError( + f"Column {field_name} declared with a BLOB directive " + f"must be of BYTES, BINARY or BLOB type, but was {source_type}." 
+ ) + return AtomicType('BLOB', source_type.nullable) + + +def _modify_field_options(option_key: str, field_name: str, options: Dict[str, str]): + existing = options.get(option_key) + if not existing: + for fk in _FALLBACK_KEYS.get(option_key, []): + fallback_value = options.pop(fk, None) + if fallback_value: + existing = fallback_value + break + if existing: + options[option_key] = existing + "," + field_name + else: + options[option_key] = field_name + + +def apply_add_column_directive( + comment: Optional[str], + field_name: str, + source_type: DataType, + options: Dict[str, str], +) -> Optional[ConvertedColumn]: + directive = parse_add_column_comment(comment) + if directive is None: + return None + new_type = _convert_type(directive, field_name, source_type) + _modify_field_options(directive.option_key, field_name, options) + if directive.option_key == CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key(): + _modify_field_options(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), field_name, options) + return ConvertedColumn(new_type, directive.real_comment) + + +def apply_directives(fields: List[DataField], options: Dict[str, str]): + """Process comment directives on fields for CREATE TABLE. + + Modifies fields list and options dict in place. + Returns True if any directive was applied. 
+ """ + changed = False + for i, field in enumerate(list(fields)): + converted = apply_add_column_directive( + field.description, field.name, field.type, options + ) + if converted is not None: + changed = True + fields[i] = DataField(field.id, field.name, converted.type, converted.comment) + return changed + + +def _remove_from_csv_option(key: str, field_name: str, options: Dict[str, str]): + existing = options.get(key) + if not existing: + return + parts = [v.strip() for v in existing.split(",") if v.strip() and v.strip() != field_name] + if parts: + options[key] = ",".join(parts) + else: + options.pop(key, None) + + +def remove_dropped_directive_options( + field_name: str, type_root: str, options: Dict[str, str] +): + """Remove directive-managed options when a BLOB or VECTOR column is dropped.""" + if type_root == 'BLOB': + for opt in _BLOB_OPTIONS: + _remove_from_csv_option(opt.key(), field_name, options) + for fk in _FALLBACK_KEYS.get(opt.key(), []): + _remove_from_csv_option(fk, field_name, options) + elif type_root == 'VECTOR': + for opt in _VECTOR_OPTIONS: + _remove_from_csv_option(opt.key(), field_name, options) + options.pop(f"field.{field_name}.vector-dim", None) diff --git a/paimon-python/pypaimon/schema/schema.py b/paimon-python/pypaimon/schema/schema.py index 912966732660..f3a63c88e14c 100644 --- a/paimon-python/pypaimon/schema/schema.py +++ b/paimon-python/pypaimon/schema/schema.py @@ -62,40 +62,6 @@ def from_pyarrow_schema(pa_schema: pa.Schema, partition_keys: Optional[List[str] if field.name in pk_set: field.type.nullable = False - # Check if Blob type exists in the schema - blob_names = [ - field.name for field in fields - if 'blob' in str(field.type).lower() - ] - - if blob_names: - if options is None: - options = {} - - if len(fields) <= len(blob_names): - raise ValueError( - "Table with BLOB type column must have other normal columns." 
- ) - - required_options = { - CoreOptions.ROW_TRACKING_ENABLED.key(): 'true', - CoreOptions.DATA_EVOLUTION_ENABLED.key(): 'true' - } - - missing_options = [] - for key, expected_value in required_options.items(): - if key not in options or options[key] != expected_value: - missing_options.append(f"{key}='{expected_value}'") - - if missing_options: - raise ValueError( - f"Schema contains Blob type but is missing required options: {', '.join(missing_options)}. " - f"Please add these options to the schema." - ) - - if primary_keys is not None: - raise ValueError("Blob type is not supported with primary key.") - # Check if Vector type with dedicated file format vector_names = [ field.name for field in fields diff --git a/paimon-python/pypaimon/schema/schema_manager.py b/paimon-python/pypaimon/schema/schema_manager.py index 1a01b31295c2..d01549c71bc5 100644 --- a/paimon-python/pypaimon/schema/schema_manager.py +++ b/paimon-python/pypaimon/schema/schema_manager.py @@ -23,6 +23,9 @@ from pypaimon.common.identifier import DEFAULT_MAIN_BRANCH from pypaimon.common.json_util import JSON from pypaimon.common.options import CoreOptions, Options +from pypaimon.schema.column_directive_utils import ( + apply_add_column_directive, apply_directives, + remove_dropped_directive_options) from pypaimon.schema.data_types import AtomicInteger, DataField from pypaimon.schema.schema import Schema from pypaimon.schema.schema_change import (AddColumn, DropColumn, RemoveOption, @@ -50,7 +53,7 @@ def _get_rename_mappings(changes: List[SchemaChange]) -> dict: def _handle_update_column_comment( - change: UpdateColumnComment, new_fields: List[DataField] + change: UpdateColumnComment, new_fields: List[DataField] ): field_name = change.field_names[-1] field_index = _find_field_index(new_fields, field_name) @@ -63,7 +66,7 @@ def _handle_update_column_comment( def _handle_update_column_nullability( - change: UpdateColumnNullability, new_fields: List[DataField] + change: UpdateColumnNullability, 
new_fields: List[DataField] ): field_name = change.field_names[-1] field_index = _find_field_index(new_fields, field_name) @@ -80,7 +83,7 @@ def _handle_update_column_nullability( def _handle_update_column_type( - change: UpdateColumnType, new_fields: List[DataField] + change: UpdateColumnType, new_fields: List[DataField] ): field_name = change.field_names[-1] field_index = _find_field_index(new_fields, field_name) @@ -107,16 +110,30 @@ def _drop_column_validation(schema: 'TableSchema', change: DropColumn): ) -def _handle_drop_column(change: DropColumn, new_fields: List[DataField]): +def _handle_drop_column(change: DropColumn, new_fields: List[DataField], + new_options: dict): field_name = change.field_names[-1] field_index = _find_field_index(new_fields, field_name) if field_index is None: raise ColumnNotExistException(field_name) + if len(change.field_names) == 1: + field = new_fields[field_index] + type_root = _get_type_root(field.type) + remove_dropped_directive_options(field_name, type_root, new_options) new_fields.pop(field_index) if not new_fields: raise ValueError("Cannot drop all fields in table") +def _get_type_root(data_type) -> str: + from pypaimon.schema.data_types import AtomicType, VectorType + if isinstance(data_type, VectorType): + return 'VECTOR' + if isinstance(data_type, AtomicType) and data_type.type == 'BLOB': + return 'BLOB' + return getattr(data_type, 'type', '') + + def _assert_not_updating_partition_keys( schema: 'TableSchema', field_names: List[str], operation: str): if len(field_names) > 1: @@ -149,6 +166,71 @@ def _assert_not_renaming_blob_column( ) +def _validate_blob_fields(fields: List[DataField], options: dict, primary_keys: List[str]): + """Validate blob field configurations in the schema.""" + if options is None: + options = {} + + blob_field_names = { + field.name for field in fields + if getattr(field.type, 'type', None) == 'BLOB' + } + + if len(fields) <= len(blob_field_names): + raise ValueError( + "Table with BLOB type column 
must have other normal columns." + ) + + core_options = CoreOptions(Options(options)) + + configured_blob_fields = core_options.blob_field() + for field in configured_blob_fields: + if field not in blob_field_names: + raise ValueError( + "Field '{}' in '{}' must be a BLOB field in table schema.".format( + field, CoreOptions.BLOB_FIELD.key() + ) + ) + + descriptor_fields = core_options.blob_descriptor_fields() + view_fields = core_options.blob_view_fields() + + all_inline_fields = descriptor_fields.union(view_fields) + non_blob_inline_fields = all_inline_fields.difference(blob_field_names) + if non_blob_inline_fields: + raise ValueError( + "Fields in 'blob-descriptor-field' or 'blob-view-field' must be blob fields " + "in schema. Non-BLOB fields: {}".format(sorted(non_blob_inline_fields)) + ) + + overlapping_inline_fields = descriptor_fields.intersection(view_fields) + if overlapping_inline_fields: + raise ValueError( + "Fields in 'blob-descriptor-field' and 'blob-view-field' must not overlap. " + "Overlapping fields: {}".format(sorted(overlapping_inline_fields)) + ) + + if blob_field_names: + required_options = { + CoreOptions.ROW_TRACKING_ENABLED.key(): 'true', + CoreOptions.DATA_EVOLUTION_ENABLED.key(): 'true' + } + + missing_options = [] + for key, expected_value in required_options.items(): + if key not in options or options[key] != expected_value: + missing_options.append(f"{key}='{expected_value}'") + + if missing_options: + raise ValueError( + f"Schema contains Blob type but is missing required options: {', '.join(missing_options)}. " + f"Please add these options to the schema." + ) + + if primary_keys: + raise ValueError("Blob type is not supported with primary key.") + + def _validate_blob_external_storage_fields(fields: List[DataField], options: dict): """Validate blob-external-storage-field configuration. 
@@ -242,7 +324,8 @@ def _handle_add_column( new_fields: List[DataField], highest_field_id: AtomicInteger, partition_keys: List[str], - add_column_before_partition: bool + add_column_before_partition: bool, + new_options: dict ): if not change.data_type.nullable: raise ValueError( @@ -252,7 +335,20 @@ def _handle_add_column( field_name = change.field_names[-1] if _find_field_index(new_fields, field_name) is not None: raise ColumnAlreadyExistException(field_name) - new_field = DataField(field_id, field_name, change.data_type, change.comment) + + data_type = change.data_type + comment = change.comment + converted = apply_add_column_directive(comment, field_name, data_type, new_options) + if converted is not None: + if len(change.field_names) > 1: + raise ValueError( + f"Comment directive cannot be used on a nested column " + f"{'.'.join(change.field_names)}." + ) + data_type = converted.type + comment = converted.comment + + new_field = DataField(field_id, field_name, data_type, comment) if change.move: _apply_move(new_fields, new_field, change.move) elif ( @@ -322,6 +418,18 @@ def create_table(self, schema: Schema) -> TableSchema: if latest is not None: raise RuntimeError("Schema in filesystem exists, creation is not allowed.") + fields = list(schema.fields) + options = dict(schema.options) + apply_directives(fields, options) + schema = Schema( + fields=fields, + partition_keys=schema.partition_keys, + primary_keys=schema.primary_keys, + options=options, + comment=schema.comment, + ) + + _validate_blob_fields(schema.fields, schema.options, schema.primary_keys) _validate_blob_external_storage_fields(schema.fields, schema.options) table_schema = TableSchema.from_schema(schema_id=0, schema=schema) success = self.commit(table_schema) @@ -329,6 +437,7 @@ def create_table(self, schema: Schema) -> TableSchema: return table_schema def commit(self, new_schema: TableSchema) -> bool: + _validate_blob_fields(new_schema.fields, new_schema.options, new_schema.primary_keys) 
schema_path = self._to_schema_path(new_schema.id) try: result = self.file_io.try_to_write_atomic(schema_path, JSON.to_json(new_schema, indent=2)) @@ -419,7 +528,8 @@ def _generate_table_schema( elif isinstance(change, AddColumn): _handle_add_column( change, new_fields, highest_field_id, - partition_keys, add_column_before_partition + partition_keys, add_column_before_partition, + new_options ) elif isinstance(change, RenameColumn): _assert_not_updating_partition_keys( @@ -429,7 +539,7 @@ def _generate_table_schema( _handle_rename_column(change, new_fields) elif isinstance(change, DropColumn): _drop_column_validation(old_table_schema, change) - _handle_drop_column(change, new_fields) + _handle_drop_column(change, new_fields, new_options) elif isinstance(change, UpdateColumnType): _assert_not_updating_partition_keys( old_table_schema, change.field_names, "update" diff --git a/paimon-python/pypaimon/table/file_store_table.py b/paimon-python/pypaimon/table/file_store_table.py index 67be2587b7e8..9fd61abed49b 100644 --- a/paimon-python/pypaimon/table/file_store_table.py +++ b/paimon-python/pypaimon/table/file_store_table.py @@ -81,6 +81,10 @@ def from_path(cls, table_path: str) -> 'FileStoreTable': return cls(file_io, identifier, table_path, table_schema) + def schema(self) -> TableSchema: + """Get the table schema.""" + return self.table_schema + def current_branch(self) -> str: """Get the current branch name from the identifier.""" return self.identifier.get_branch_name_or_default() @@ -113,14 +117,15 @@ def branch_manager(self): """Get the branch manager for this table.""" # If catalog environment has a catalog loader, use CatalogBranchManager catalog_loader = self.catalog_environment.catalog_loader - if catalog_loader is not None: + if catalog_loader is not None and self.catalog_environment.supports_version_management: from pypaimon.branch.catalog_branch_manager import CatalogBranchManager return CatalogBranchManager( catalog_loader, self.identifier ) # Otherwise, 
use FileSystemBranchManager - from pypaimon.branch.filesystem_branch_manager import FileSystemBranchManager + from pypaimon.branch.filesystem_branch_manager import \ + FileSystemBranchManager current_branch = self.current_branch() or "main" return FileSystemBranchManager( self.file_io, @@ -376,6 +381,8 @@ def path_factory(self) -> 'FileStorePathFactory': file_compression=file_compression, data_file_path_directory=None, external_paths=external_paths, + external_path_strategy=self.options.data_file_external_paths_strategy(), + external_path_weights=self.options.data_file_external_paths_weights(), index_file_in_data_file_dir=False, ) @@ -413,11 +420,13 @@ def new_stream_write_builder(self) -> StreamWriteBuilder: return StreamWriteBuilder(self) def new_full_text_search_builder(self) -> 'FullTextSearchBuilder': - from pypaimon.table.source.full_text_search_builder import FullTextSearchBuilderImpl + from pypaimon.table.source.full_text_search_builder import \ + FullTextSearchBuilderImpl return FullTextSearchBuilderImpl(self) def new_vector_search_builder(self) -> 'VectorSearchBuilder': - from pypaimon.table.source.vector_search_builder import VectorSearchBuilderImpl + from pypaimon.table.source.vector_search_builder import \ + VectorSearchBuilderImpl return VectorSearchBuilderImpl(self) def create_row_key_extractor(self) -> RowKeyExtractor: @@ -492,6 +501,7 @@ def _try_time_travel(self, options: Options) -> Optional[TableSchema]: def _create_external_paths(self) -> List[str]: from urllib.parse import urlparse + from pypaimon.common.options.core_options import ExternalPathStrategy external_paths_str = self.options.data_file_external_paths() diff --git a/paimon-python/pypaimon/table/row/blob.py b/paimon-python/pypaimon/table/row/blob.py index 43391775bd8d..eb2f00b76471 100644 --- a/paimon-python/pypaimon/table/row/blob.py +++ b/paimon-python/pypaimon/table/row/blob.py @@ -18,9 +18,10 @@ import io import struct from abc import ABC, abstractmethod -from typing import 
BinaryIO, Optional, Union +from typing import BinaryIO, Callable, Optional, Union from urllib.parse import urlparse +from pypaimon.common.identifier import Identifier from pypaimon.common.uri_reader import UriReader, FileUriReader @@ -162,6 +163,115 @@ def __repr__(self) -> str: return self.__str__() +class BlobViewStruct: + CURRENT_VERSION = 1 + MAGIC = 0x424C4F4256494557 # "BLOBVIEW" + + def __init__(self, identifier: Union[Identifier, str], field_id: int, row_id: int): + if isinstance(identifier, str): + identifier = Identifier.from_string(identifier) + if not isinstance(identifier, Identifier): + raise TypeError("BlobViewStruct identifier must be Identifier or str.") + self._identifier = identifier + self._field_id = field_id + self._row_id = row_id + + @property + def identifier(self) -> Identifier: + return self._identifier + + @property + def field_id(self) -> int: + return self._field_id + + @property + def row_id(self) -> int: + return self._row_id + + def serialize(self) -> bytes: + identifier_bytes = self._identifier.get_full_name().encode('utf-8') + data = struct.pack(' 'BlobViewStruct': + if len(data) < 25: + raise ValueError("Invalid BlobViewStruct data: too short") + + offset = 0 + version = struct.unpack(' len(data): + raise ValueError("Invalid BlobViewStruct data: identifier length exceeds data size") + + identifier = data[offset:offset + identifier_length].decode('utf-8') + offset += identifier_length + field_id = struct.unpack(' bool: + if not isinstance(data, (bytes, bytearray)): + return False + raw = bytes(data) + if len(raw) < 9: + return False + version = raw[0] + if version != cls.CURRENT_VERSION: + return False + try: + magic = struct.unpack(' bool: + if not isinstance(other, BlobViewStruct): + return False + return (self._identifier == other._identifier + and self._field_id == other._field_id + and self._row_id == other._row_id) + + def __hash__(self) -> int: + return hash((self._identifier.get_full_name(), self._field_id, self._row_id)) 
+ + def __str__(self) -> str: + return ( + f"BlobViewStruct(identifier={self._identifier.get_full_name()}, " + f"field_id={self._field_id}, row_id={self._row_id})" + ) + + def __repr__(self) -> str: + return self.__str__() + + class OffsetInputStream(io.RawIOBase): def __init__(self, wrapped, offset: int, length: int): @@ -276,6 +386,10 @@ def from_file(file_io, file_path: str, offset: int, length: int) -> 'Blob': def from_descriptor(uri_reader: UriReader, descriptor: BlobDescriptor) -> 'Blob': return BlobRef(uri_reader, descriptor) + @staticmethod + def from_view(view_struct: BlobViewStruct) -> 'BlobView': + return BlobView(view_struct) + @staticmethod def from_bytes(data: Optional[bytes], file_io=None, allow_blob_data: bool = True) -> Optional['Blob']: if data is None: @@ -283,6 +397,8 @@ def from_bytes(data: Optional[bytes], file_io=None, allow_blob_data: bool = True if not isinstance(data, (bytes, bytearray)): raise TypeError(f"Blob.from_bytes expects bytes, got {type(data)}") data = bytes(data) + if BlobViewStruct.is_blob_view_struct(data): + return Blob.from_view(BlobViewStruct.deserialize(data)) is_descriptor = BlobDescriptor.is_blob_descriptor(data) if not allow_blob_data and not is_descriptor: raise ValueError( @@ -382,3 +498,45 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return hash(self._descriptor) + + +BlobConsumer = Callable[[str, Optional[BlobDescriptor]], bool] + + +class BlobView(Blob): + + def __init__(self, view_struct: BlobViewStruct): + self._view_struct: BlobViewStruct = view_struct + self._resolved_blob: Optional[BlobRef] = None + + @property + def view_struct(self) -> BlobViewStruct: + return self._view_struct + + def is_resolved(self) -> bool: + return self._resolved_blob is not None + + def resolve(self, uri_reader: UriReader, descriptor: BlobDescriptor): + self._resolved_blob = BlobRef(uri_reader, descriptor) + + def to_data(self) -> bytes: + return self._resolved().to_data() + + def to_descriptor(self) -> 
BlobDescriptor: + return self._resolved().to_descriptor() + + def new_input_stream(self) -> BinaryIO: + return self._resolved().new_input_stream() + + def _resolved(self) -> BlobRef: + if self._resolved_blob is None: + raise RuntimeError("BlobView is not resolved.") + return self._resolved_blob + + def __eq__(self, other) -> bool: + if not isinstance(other, BlobView): + return False + return self._view_struct == other._view_struct + + def __hash__(self) -> int: + return hash(self._view_struct) diff --git a/paimon-python/pypaimon/table/row/key_value.py b/paimon-python/pypaimon/table/row/key_value.py index 845f77fadc5d..8fc89ea8b859 100644 --- a/paimon-python/pypaimon/table/row/key_value.py +++ b/paimon-python/pypaimon/table/row/key_value.py @@ -36,6 +36,20 @@ def replace(self, row_tuple: tuple): self._reused_value.replace(row_tuple) return self + def copy(self) -> 'KeyValue': + """Return an independent KeyValue carrying the current row tuple. + + ``replace`` swaps the tuple reference rather than mutating it, so a + copy stays valid even if a pooled/reused source KeyValue is later + replaced again. Callers that need to hold onto a kv past the next + ``replace`` (e.g. FirstRowMergeFunction keeping the first row while + the writer folds with a single pooled KeyValue) use this. 
+ """ + new = KeyValue(self.key_arity, self.value_arity) + if self._row_tuple is not None: + new.replace(self._row_tuple) + return new + def is_add(self) -> bool: return RowKind.is_add_byte(self.value_row_kind_byte) diff --git a/paimon-python/pypaimon/table/source/full_text_read.py b/paimon-python/pypaimon/table/source/full_text_read.py index ca61d528ccf0..3257525ad16f 100644 --- a/paimon-python/pypaimon/table/source/full_text_read.py +++ b/paimon-python/pypaimon/table/source/full_text_read.py @@ -18,7 +18,8 @@ """Full-text read to read index files.""" from abc import ABC, abstractmethod -from typing import List, Optional +from concurrent.futures import wait +from typing import List from pypaimon.globalindex.full_text_search import FullTextSearch from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta @@ -48,23 +49,32 @@ def __init__( table: 'FileStoreTable', limit: int, text_column: 'DataField', - query_text: str + query_text: str, + query_operator: str = "or" ): self._table = table self._limit = limit self._text_column = text_column self._query_text = query_text + self._query_operator = query_operator def read(self, splits: List[FullTextSearchSplit]) -> GlobalIndexResult: if not splits: return GlobalIndexResult.create_empty() - merged_scores = {} - for split in splits: - split_result = self._eval( + futures = [ + self._eval( split.row_range_start, split.row_range_end, split.full_text_index_files ) + for split in splits + ] + + wait(futures) + + merged_scores = {} + for future in futures: + split_result = future.result() if split_result is not None: score_getter = split_result.score_getter() for row_id in split_result.results(): @@ -73,8 +83,7 @@ def read(self, splits: List[FullTextSearchSplit]) -> GlobalIndexResult: return DictBasedScoredIndexResult(merged_scores).top_k(self._limit) - def _eval(self, row_range_start, row_range_end, full_text_index_files - ) -> Optional[GlobalIndexResult]: + def _eval(self, row_range_start, row_range_end, 
full_text_index_files): index_io_meta_list = [] for index_file in full_text_index_files: meta = index_file.global_index_meta @@ -83,7 +92,8 @@ def _eval(self, row_range_start, row_range_end, full_text_index_files GlobalIndexIOMeta( file_name=index_file.file_name, file_size=index_file.file_size, - metadata=meta.index_meta + metadata=meta.index_meta, + external_path=index_file.external_path, ) ) @@ -99,14 +109,14 @@ def _eval(self, row_range_start, row_range_end, full_text_index_files full_text_search = FullTextSearch( query_text=self._query_text, limit=self._limit, - field_name=self._text_column.name + field_name=self._text_column.name, + query_operator=self._query_operator ) - try: - offset_reader = OffsetGlobalIndexReader(reader, row_range_start, row_range_end) - return offset_reader.visit_full_text_search(full_text_search) - finally: - reader.close() + offset_reader = OffsetGlobalIndexReader(reader, row_range_start, row_range_end) + future = offset_reader.visit_full_text_search(full_text_search) + future.add_done_callback(lambda _: reader.close()) + return future def _create_full_text_reader(index_type, file_io, index_path, index_io_meta_list): diff --git a/paimon-python/pypaimon/table/source/full_text_search_builder.py b/paimon-python/pypaimon/table/source/full_text_search_builder.py index 6d88e956eff0..b7d92c9934c0 100644 --- a/paimon-python/pypaimon/table/source/full_text_search_builder.py +++ b/paimon-python/pypaimon/table/source/full_text_search_builder.py @@ -43,6 +43,11 @@ def with_query_text(self, query_text: str) -> 'FullTextSearchBuilder': """The query text to search.""" pass + @abstractmethod + def with_query_operator(self, query_operator: str) -> 'FullTextSearchBuilder': + """The default operator for query terms. 
Supported values are 'or' and 'and'.""" + pass + @abstractmethod def new_full_text_scan(self) -> FullTextScan: """Create full-text scan to scan index files.""" @@ -66,6 +71,7 @@ def __init__(self, table: 'FileStoreTable'): self._limit: int = 0 self._text_column: Optional['DataField'] = None self._query_text: Optional[str] = None + self._query_operator: str = "or" def with_limit(self, limit: int) -> 'FullTextSearchBuilder': self._limit = limit @@ -82,6 +88,10 @@ def with_query_text(self, query_text: str) -> 'FullTextSearchBuilder': self._query_text = query_text return self + def with_query_operator(self, query_operator: str) -> 'FullTextSearchBuilder': + self._query_operator = query_operator + return self + def new_full_text_scan(self) -> FullTextScan: if self._text_column is None: raise ValueError("Text column must be set via with_text_column()") @@ -95,5 +105,6 @@ def new_full_text_read(self) -> FullTextRead: if self._query_text is None: raise ValueError("Query text must be set via with_query_text()") return FullTextReadImpl( - self._table, self._limit, self._text_column, self._query_text + self._table, self._limit, self._text_column, + self._query_text, self._query_operator ) diff --git a/paimon-python/pypaimon/table/source/vector_search_read.py b/paimon-python/pypaimon/table/source/vector_search_read.py index aa89cba7444a..e6839ebe1058 100644 --- a/paimon-python/pypaimon/table/source/vector_search_read.py +++ b/paimon-python/pypaimon/table/source/vector_search_read.py @@ -18,6 +18,7 @@ """Vector search read to read index files.""" from abc import ABC, abstractmethod +from concurrent.futures import wait from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta from pypaimon.globalindex.global_index_result import GlobalIndexResult @@ -56,12 +57,19 @@ def read(self, splits): pre_filter = self._pre_filter(splits) - merged_scores = {} - for split in splits: - split_result = self._eval( + futures = [ + self._eval( split.row_range_start, split.row_range_end, 
split.vector_index_files, pre_filter ) + for split in splits + ] + + wait(futures) + + merged_scores = {} + for future in futures: + split_result = future.result() if split_result is not None: score_getter = split_result.score_getter() for row_id in split_result.results(): @@ -103,9 +111,10 @@ def _pre_filter(self, splits): def _eval(self, row_range_start, row_range_end, vector_index_files, include_row_ids): - # type: (int, int, list, Optional[RoaringBitmap64]) -> Optional[ScoredGlobalIndexResult] + from pypaimon.globalindex.global_index_reader import _completed_future + if not vector_index_files: - return None + return _completed_future(None) index_io_meta_list = [] for index_file in vector_index_files: meta = index_file.global_index_meta @@ -132,12 +141,14 @@ def _eval(self, row_range_start, row_range_end, vector_index_files, if include_row_ids is not None: vector_search = vector_search.with_include_row_ids(include_row_ids) - with _create_vector_reader( + reader = _create_vector_reader( index_type, file_io, index_path, index_io_meta_list, options - ) as reader: - offset_reader = OffsetGlobalIndexReader(reader, row_range_start, row_range_end) - return offset_reader.visit_vector_search(vector_search) + ) + offset_reader = OffsetGlobalIndexReader(reader, row_range_start, row_range_end) + future = offset_reader.visit_vector_search(vector_search) + future.add_done_callback(lambda _: reader.close()) + return future def _create_vector_reader(index_type, file_io, index_path, index_io_meta_list, options=None): diff --git a/paimon-python/pypaimon/table/special_fields.py b/paimon-python/pypaimon/table/special_fields.py index 5c578ec85f07..64d2429bef7d 100644 --- a/paimon-python/pypaimon/table/special_fields.py +++ b/paimon-python/pypaimon/table/special_fields.py @@ -81,3 +81,22 @@ def row_type_with_row_tracking(table_fields: List[DataField], fields_with_row_tracking.append(SpecialFields.SEQUENCE_NUMBER) return fields_with_row_tracking + + @staticmethod + def 
row_type_with_row_id(table_fields: List[DataField]) -> List[DataField]: + """Add ROW_ID field to the given fields list. + + Args: + table_fields: The original table fields + """ + fields_with_row_id = list(table_fields) + + for field in fields_with_row_id: + if SpecialFields.ROW_ID.name == field.name: + raise ValueError( + "Row tracking field name '{}' conflicts with existing field names." + .format(field.name) + ) + + fields_with_row_id.append(SpecialFields.ROW_ID) + return fields_with_row_id diff --git a/paimon-python/pypaimon/table/system/buckets_table.py b/paimon-python/pypaimon/table/system/buckets_table.py new file mode 100644 index 000000000000..ddd5e8c6c6f0 --- /dev/null +++ b/paimon-python/pypaimon/table/system/buckets_table.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""The ``$buckets`` system table — per-bucket aggregated stats.""" + +from typing import List, Optional + +import pyarrow + +from pypaimon.manifest.manifest_file_manager import ManifestFileManager +from pypaimon.manifest.manifest_list_manager import ManifestListManager +from pypaimon.schema.data_types import AtomicType, DataField, RowType +from pypaimon.table.system.system_table import SystemTable + + +TABLE_TYPE = RowType(False, [ + DataField(0, "partition", AtomicType("STRING", nullable=True)), + DataField(1, "bucket", AtomicType("INT", nullable=False)), + DataField(2, "record_count", AtomicType("BIGINT", nullable=False)), + DataField(3, "file_size_in_bytes", AtomicType("BIGINT", nullable=False)), + DataField(4, "file_count", AtomicType("BIGINT", nullable=False)), + DataField(5, "last_update_time", AtomicType("TIMESTAMP(3)", nullable=True)), +]) + + +_TIMESTAMP_TYPE = pyarrow.timestamp("ms") + + +class BucketsTable(SystemTable): + """The ``$buckets`` system table. + + Aggregates manifest entries by (partition, bucket) to show per-bucket + record counts, file sizes, file counts and last update times. 
+ """ + + def system_table_name(self) -> str: + return "buckets" + + def row_type(self) -> RowType: + return TABLE_TYPE + + def primary_keys(self) -> List[str]: + return ["partition", "bucket"] + + def _build_arrow_table(self) -> pyarrow.Table: + snapshot = self.base_table.snapshot_manager().get_latest_snapshot() + if snapshot is None: + return self._empty_table() + + manifest_list_manager = ManifestListManager(self.base_table) + manifest_files = manifest_list_manager.read_all(snapshot) + manifest_file_manager = ManifestFileManager(self.base_table) + entries = manifest_file_manager.read_entries_parallel( + manifest_files, drop_stats=True) + + _NULL = object() + + bucket_map: dict = {} + for entry in entries: + raw_key = tuple( + (field.name, _NULL if value is None else value) + for field, value in zip( + entry.partition.fields, entry.partition.values)) + bucket_id = int(entry.bucket) + key = (raw_key, bucket_id) + + stats = bucket_map.get(key) + if stats is None: + render_items = tuple( + (name, str(val) if val is not _NULL else None) + for name, val in raw_key) + stats = { + "render_items": render_items, + "bucket": bucket_id, + "record_count": 0, + "file_size_in_bytes": 0, + "file_count": 0, + "last_update_time": None, + } + bucket_map[key] = stats + + stats["record_count"] += int(entry.file.row_count) + stats["file_size_in_bytes"] += int(entry.file.file_size) + stats["file_count"] += 1 + ct_ms = entry.file.creation_time_epoch_millis() + if ct_ms is not None: + if (stats["last_update_time"] is None + or ct_ms > stats["last_update_time"]): + stats["last_update_time"] = ct_ms + + sorted_keys = sorted( + bucket_map.keys(), + key=lambda k: ( + _render_partition(bucket_map[k]["render_items"]) or "", + k[1])) + + partition_strings: List[Optional[str]] = [] + buckets: List[int] = [] + record_counts: List[int] = [] + file_sizes: List[int] = [] + file_counts: List[int] = [] + last_update_times: List[Optional[int]] = [] + + for key in sorted_keys: + stats = bucket_map[key] 
+ partition_strings.append(_render_partition(stats["render_items"])) + buckets.append(stats["bucket"]) + record_counts.append(stats["record_count"]) + file_sizes.append(stats["file_size_in_bytes"]) + file_counts.append(stats["file_count"]) + last_update_times.append(stats["last_update_time"]) + + return pyarrow.table({ + "partition": pyarrow.array( + partition_strings, type=pyarrow.string()), + "bucket": pyarrow.array(buckets, type=pyarrow.int32()), + "record_count": pyarrow.array( + record_counts, type=pyarrow.int64()), + "file_size_in_bytes": pyarrow.array( + file_sizes, type=pyarrow.int64()), + "file_count": pyarrow.array(file_counts, type=pyarrow.int64()), + "last_update_time": pyarrow.array( + last_update_times, type=_TIMESTAMP_TYPE), + }) + + @staticmethod + def _empty_table() -> pyarrow.Table: + return pyarrow.table({ + "partition": pyarrow.array([], type=pyarrow.string()), + "bucket": pyarrow.array([], type=pyarrow.int32()), + "record_count": pyarrow.array([], type=pyarrow.int64()), + "file_size_in_bytes": pyarrow.array([], type=pyarrow.int64()), + "file_count": pyarrow.array([], type=pyarrow.int64()), + "last_update_time": pyarrow.array([], type=_TIMESTAMP_TYPE), + }) + + +def _render_partition(spec_items) -> Optional[str]: + """Render a partition spec as ``pt=v/pt2=v2`` or None when empty. + + Null partition values are rendered as ``__NULL__`` to distinguish them + from the literal string ``"None"``. A partition whose value is + literally ``"__NULL__"`` will produce the same rendered string — + aggregation keys are still distinct, but the displayed partition + column will collide. This is a display-only limitation. 
+ """ + if not spec_items: + return None + return "/".join( + "{}={}".format(name, "__NULL__" if value is None else value) + for name, value in spec_items) diff --git a/paimon-python/pypaimon/table/system/system_table_loader.py b/paimon-python/pypaimon/table/system/system_table_loader.py index 9d0576aec1c3..72b758947d2e 100644 --- a/paimon-python/pypaimon/table/system/system_table_loader.py +++ b/paimon-python/pypaimon/table/system/system_table_loader.py @@ -24,7 +24,7 @@ The following short names are intentionally not registered here yet: audit_log, binlog, read_optimized, consumers, statistics, - aggregation_fields, buckets, file_key_ranges, table_indexes, + aggregation_fields, file_key_ranges, table_indexes, row_tracking, all_tables, all_partitions, all_table_options, catalog_options """ @@ -44,6 +44,7 @@ "manifests", "files", "partitions", + "buckets", "tags", "branches", ) @@ -66,6 +67,7 @@ def factory(base_table: "FileStoreTable") -> "SystemTable": "manifests": _lazy("pypaimon.table.system.manifests_table", "ManifestsTable"), "files": _lazy("pypaimon.table.system.files_table", "FilesTable"), "partitions": _lazy("pypaimon.table.system.partitions_table", "PartitionsTable"), + "buckets": _lazy("pypaimon.table.system.buckets_table", "BucketsTable"), "tags": _lazy("pypaimon.table.system.tags_table", "TagsTable"), "branches": _lazy("pypaimon.table.system.branches_table", "BranchesTable"), } diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py index c4e5a4d1bd3f..3d33594c4a92 100755 --- a/paimon-python/pypaimon/tests/blob_table_test.py +++ b/paimon-python/pypaimon/tests/blob_table_test.py @@ -29,8 +29,8 @@ from pypaimon.write.commit_message import CommitMessage -class DataBlobWriterTest(unittest.TestCase): - """Tests for DataBlobWriter functionality with paimon table operations.""" +class DedicatedFormatWriterTest(unittest.TestCase): + """Tests for DedicatedFormatWriter functionality with paimon table 
operations.""" @classmethod def setUpClass(cls): @@ -51,8 +51,8 @@ def tearDownClass(cls): except OSError: pass - def test_data_blob_writer_basic_functionality(self): - """Test basic DataBlobWriter functionality with paimon table.""" + def test_dedicated_format_writer_basic_functionality(self): + """Test basic DedicatedFormatWriter functionality with paimon table.""" from pypaimon import Schema # Create schema with normal and blob columns @@ -82,7 +82,7 @@ def test_data_blob_writer_basic_functionality(self): 'blob_data': [b'blob_data_1', b'blob_data_2', b'blob_data_3'] }, schema=pa_schema) - # Test DataBlobWriter initialization using proper table API + # Test DedicatedFormatWriter initialization using proper table API # Use proper table API to create writer write_builder = table.new_batch_write_builder() blob_writer = write_builder.new_write() @@ -108,8 +108,8 @@ def test_data_blob_writer_basic_functionality(self): blob_writer.close() - def test_data_blob_writer_schema_detection(self): - """Test that DataBlobWriter correctly detects blob columns from schema.""" + def test_dedicated_format_writer_schema_detection(self): + """Test that DedicatedFormatWriter correctly detects blob columns from schema.""" from pypaimon import Schema # Test schema with blob column @@ -132,7 +132,7 @@ def test_data_blob_writer_schema_detection(self): write_builder = table.new_batch_write_builder() blob_writer = write_builder.new_write() - # Test that DataBlobWriter was created internally + # Test that DedicatedFormatWriter was created internally # We can verify this by checking the internal data writers test_data = pa.Table.from_pydict({ 'id': [1, 2, 3], @@ -142,19 +142,19 @@ def test_data_blob_writer_schema_detection(self): # Write data to trigger writer creation blob_writer.write_arrow(test_data) - # Verify that a DataBlobWriter was created internally + # Verify that a DedicatedFormatWriter was created internally data_writers = blob_writer.file_store_write.data_writers 
self.assertGreater(len(data_writers), 0) - # Check that the writer is a DataBlobWriter + # Check that the writer is a DedicatedFormatWriter for writer in data_writers.values(): - from pypaimon.write.writer.data_blob_writer import DataBlobWriter - self.assertIsInstance(writer, DataBlobWriter) + from pypaimon.write.writer.dedicated_format_writer import DedicatedFormatWriter + self.assertIsInstance(writer, DedicatedFormatWriter) blob_writer.close() - def test_data_blob_writer_no_blob_column(self): - """Test that DataBlobWriter raises error when no blob column is found.""" + def test_dedicated_format_writer_no_blob_column(self): + """Test that DedicatedFormatWriter raises error when no blob column is found.""" from pypaimon import Schema # Test schema without blob column @@ -177,7 +177,7 @@ def test_data_blob_writer_no_blob_column(self): write_builder = table.new_batch_write_builder() writer = write_builder.new_write() - # Test that a regular writer (not DataBlobWriter) was created + # Test that a regular writer (not DedicatedFormatWriter) was created test_data = pa.Table.from_pydict({ 'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie'] @@ -186,19 +186,19 @@ def test_data_blob_writer_no_blob_column(self): # Write data to trigger writer creation writer.write_arrow(test_data) - # Verify that a regular writer was created (not DataBlobWriter) + # Verify that a regular writer was created (not DedicatedFormatWriter) data_writers = writer.file_store_write.data_writers self.assertGreater(len(data_writers), 0) - # Check that the writer is NOT a DataBlobWriter + # Check that the writer is NOT a DedicatedFormatWriter for writer_instance in data_writers.values(): - from pypaimon.write.writer.data_blob_writer import DataBlobWriter - self.assertNotIsInstance(writer_instance, DataBlobWriter) + from pypaimon.write.writer.dedicated_format_writer import DedicatedFormatWriter + self.assertNotIsInstance(writer_instance, DedicatedFormatWriter) writer.close() - def 
test_data_blob_writer_multiple_blob_columns(self): - """Test that DataBlobWriter supports multiple blob columns.""" + def test_dedicated_format_writer_multiple_blob_columns(self): + """Test that DedicatedFormatWriter supports multiple blob columns.""" from pypaimon import Schema # Test schema with multiple blob columns @@ -242,7 +242,7 @@ def test_data_blob_writer_multiple_blob_columns(self): result = table.new_read_builder().new_read().to_arrow(table.new_read_builder().new_scan().plan().splits()) self.assertEqual(result.num_rows, 3) - def test_data_blob_writer_partial_write_with_write_type(self): + def test_dedicated_format_writer_partial_write_with_write_type(self): """Partial write (normal + blob subset) via with_write_type: split must match batch columns.""" from pypaimon import Schema @@ -292,7 +292,7 @@ def test_data_blob_writer_partial_write_with_write_type(self): self.assertEqual(out.column('blob_data').to_pylist(), [b'a', b'b']) self.assertEqual(out.column('name').to_pylist(), [None, None]) - def test_data_blob_writer_partial_write_normal_only_with_write_type(self): + def test_dedicated_format_writer_partial_write_normal_only_with_write_type(self): """Partial write without blob columns in write_cols must not touch blob split paths.""" from pypaimon import Schema @@ -333,7 +333,7 @@ def test_data_blob_writer_partial_write_normal_only_with_write_type(self): self.assertEqual(out.column('name').to_pylist(), ['n']) self.assertEqual(out.column('blob_data').to_pylist(), [None]) - def test_data_blob_writer_partial_write_single_blob_of_two_with_write_type(self): + def test_dedicated_format_writer_partial_write_single_blob_of_two_with_write_type(self): """with_write_type lists only one blob column: only that column gets .blob files.""" from pypaimon import Schema @@ -369,8 +369,8 @@ def test_data_blob_writer_partial_write_single_blob_of_two_with_write_type(self) write_builder.new_commit().commit(commit_messages) writer.close() - def 
test_data_blob_writer_write_operations(self): - """Test DataBlobWriter write operations with real data.""" + def test_dedicated_format_writer_write_operations(self): + """Test DedicatedFormatWriter write operations with real data.""" from pypaimon import Schema # Create schema with blob column @@ -411,8 +411,8 @@ def test_data_blob_writer_write_operations(self): blob_writer.close() - def test_data_blob_writer_write_large_blob(self): - """Test DataBlobWriter with very large blob data (50MB per item) in 10 batches.""" + def test_dedicated_format_writer_write_large_blob(self): + """Test DedicatedFormatWriter with very large blob data (50MB per item) in 10 batches.""" from pypaimon import Schema # Create schema with blob column @@ -436,28 +436,27 @@ def test_data_blob_writer_write_large_blob(self): write_builder = table.new_batch_write_builder() blob_writer = write_builder.new_write() - # Create 50MB blob data per item - # Using a pattern to make the data more realistic and compressible - target_size = 50 * 1024 * 1024 # 50MB in bytes + # Create 5MB blob data per item + target_size = 5 * 1024 * 1024 # 5MB in bytes blob_pattern = b'LARGE_BLOB_DATA_PATTERN_' + b'X' * 1024 # ~1KB pattern pattern_size = len(blob_pattern) repetitions = target_size // pattern_size large_blob_data = blob_pattern * repetitions - # Verify the blob size is approximately 50MB + # Verify the blob size is approximately 5MB blob_size_mb = len(large_blob_data) / (1024 * 1024) - self.assertGreater(blob_size_mb, 49) # Should be at least 49MB - self.assertLess(blob_size_mb, 51) # Should be less than 51MB + self.assertGreater(blob_size_mb, 4) # Should be at least 4MB + self.assertLess(blob_size_mb, 6) # Should be less than 6MB total_rows = 0 # Write 10 batches, each with 5 rows (50 rows total) - # Total data volume: 50 rows * 50MB = 2.5GB of blob data + # Total data volume: 50 rows * 5MB = 250MB of blob data for batch_num in range(10): batch_data = pa.Table.from_pydict({ 'id': [batch_num * 5 + i for i in 
range(5)], 'description': [f'Large blob batch {batch_num}, row {i}' for i in range(5)], - 'large_blob': [large_blob_data] * 5 # 5 rows per batch, each with 50MB blob + 'large_blob': [large_blob_data] * 5 # 5 rows per batch, each with 5MB blob }, schema=pa_schema) # Write each batch @@ -468,7 +467,7 @@ def test_data_blob_writer_write_large_blob(self): # Log progress for large data processing print(f"Completed batch {batch_num + 1}/10 with {batch.num_rows} rows") - # Record count is tracked internally by DataBlobWriter + # Record count is tracked internally by DedicatedFormatWriter # Test prepare commit commit_messages: CommitMessage = blob_writer.prepare_commit() @@ -502,9 +501,9 @@ def test_data_blob_writer_write_large_blob(self): # Verify total data written (50 rows of normal data + 50 rows of blob data = 100 total) self.assertEqual(total_row_count, 50) - # Verify total file size is substantial (should be much larger than 2.5GB due to overhead) + # Verify total file size is substantial (should be at least 200MB) total_size_mb = total_file_size / (1024 * 1024) - self.assertGreater(total_size_mb, 2000) # Should be at least 2GB due to overhead + self.assertGreater(total_size_mb, 200) total_files = sum(len(commit_msg.new_files) for commit_msg in commit_messages) print(f"Total data written: {total_size_mb:.2f}MB across {total_files} files") @@ -512,8 +511,8 @@ def test_data_blob_writer_write_large_blob(self): blob_writer.close() - def test_data_blob_writer_abort_functionality(self): - """Test DataBlobWriter abort functionality.""" + def test_dedicated_format_writer_abort_functionality(self): + """Test DedicatedFormatWriter abort functionality.""" from pypaimon import Schema # Create schema with blob column @@ -547,12 +546,12 @@ def test_data_blob_writer_abort_functionality(self): blob_writer.write_arrow_batch(batch) # Test abort - BatchTableWrite doesn't have abort method - # The abort functionality is handled internally by DataBlobWriter + # The abort functionality is 
handled internally by DedicatedFormatWriter blob_writer.close() - def test_data_blob_writer_multiple_batches(self): - """Test DataBlobWriter with multiple batches and verify results.""" + def test_dedicated_format_writer_multiple_batches(self): + """Test DedicatedFormatWriter with multiple batches and verify results.""" from pypaimon import Schema # Create schema with blob column @@ -609,7 +608,7 @@ def test_data_blob_writer_multiple_batches(self): blob_writer.write_arrow_batch(batch) total_rows += batch.num_rows - # Record count is tracked internally by DataBlobWriter + # Record count is tracked internally by DedicatedFormatWriter # Test prepare commit commit_messages = blob_writer.prepare_commit() @@ -620,8 +619,8 @@ def test_data_blob_writer_multiple_batches(self): blob_writer.close() - def test_data_blob_writer_large_batches(self): - """Test DataBlobWriter with large batches to test rolling behavior.""" + def test_dedicated_format_writer_large_batches(self): + """Test DedicatedFormatWriter with large batches to test rolling behavior.""" from pypaimon import Schema # Create schema with blob column @@ -672,7 +671,7 @@ def test_data_blob_writer_large_batches(self): blob_writer.write_arrow_batch(batch) total_rows += batch.num_rows - # Record count is tracked internally by DataBlobWriter + # Record count is tracked internally by DedicatedFormatWriter # Test prepare commit commit_messages = blob_writer.prepare_commit() @@ -683,8 +682,8 @@ def test_data_blob_writer_large_batches(self): blob_writer.close() - def test_data_blob_writer_mixed_data_types(self): - """Test DataBlobWriter with mixed data types in blob column.""" + def test_dedicated_format_writer_mixed_data_types(self): + """Test DedicatedFormatWriter with mixed data types in blob column.""" from pypaimon import Schema # Create schema with blob column @@ -727,7 +726,7 @@ def test_data_blob_writer_mixed_data_types(self): blob_writer.write_arrow_batch(batch) total_rows += batch.num_rows - # Record count is 
tracked internally by DataBlobWriter + # Record count is tracked internally by DedicatedFormatWriter # Test prepare commit commit_messages = blob_writer.prepare_commit() @@ -787,8 +786,8 @@ def test_data_blob_writer_mixed_data_types(self): self.assertEqual(result_type, original_type, f"Row {i + 1}: Type should match") self.assertEqual(result_data, original_data, f"Row {i + 1}: Blob data should match") - def test_data_blob_writer_empty_batches(self): - """Test DataBlobWriter with empty batches.""" + def test_dedicated_format_writer_empty_batches(self): + """Test DedicatedFormatWriter with empty batches.""" from pypaimon import Schema # Create schema with blob column @@ -843,8 +842,8 @@ def test_data_blob_writer_empty_batches(self): total_rows += batch.num_rows # Verify record count (empty batch should not affect count) - # Record count is tracked internally by DataBlobWriter - # Record count is tracked internally by DataBlobWriter + # Record count is tracked internally by DedicatedFormatWriter + # Record count is tracked internally by DedicatedFormatWriter # Test prepare commit commit_messages = blob_writer.prepare_commit() @@ -852,8 +851,8 @@ def test_data_blob_writer_empty_batches(self): blob_writer.close() - def test_data_blob_writer_rolling_behavior(self): - """Test DataBlobWriter rolling behavior with multiple commits.""" + def test_dedicated_format_writer_rolling_behavior(self): + """Test DedicatedFormatWriter rolling behavior with multiple commits.""" from pypaimon import Schema # Create schema with blob column @@ -893,7 +892,7 @@ def test_data_blob_writer_rolling_behavior(self): blob_writer.write_arrow_batch(batch) # Verify total record count - # Record count is tracked internally by DataBlobWriter + # Record count is tracked internally by DedicatedFormatWriter # Test prepare commit commit_messages = blob_writer.prepare_commit() @@ -1063,6 +1062,82 @@ def test_null_blob(self): [b'first_blob', None, b'third_blob', None, b'fifth_blob'], ) + def 
test_update_blob_column(self): + from pypaimon import Schema + from pypaimon.read.reader.format_blob_reader import FormatBlobReader + from pypaimon.write.blob_format_writer import BlobFormatWriter + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('blob_data', pa.large_binary()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true' + } + ) + self.catalog.create_table('test_db.blob_update_column', schema, False) + table = self.catalog.get_table('test_db.blob_update_column') + + initial = pa.Table.from_pydict({ + 'id': [1, 2, 3], + 'name': ['a', 'b', 'c'], + 'blob_data': [b'blob-1', b'blob-2', b'blob-3'], + }, schema=pa_schema) + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(initial) + write_builder.new_commit().commit(writer.prepare_commit()) + writer.close() + + update_builder = table.new_batch_write_builder() + table_update = update_builder.new_update().with_update_type(['blob_data']) + update_data = pa.Table.from_pydict({ + '_ROW_ID': pa.array([1], type=pa.int64()), + 'blob_data': pa.array([b'updated-blob-2'], type=pa.large_binary()), + }) + update_messages = table_update.update_by_arrow_with_row_id(update_data) + update_builder.new_commit().commit(update_messages) + + update_files = [f for msg in update_messages for f in msg.new_files] + update_blob_files = [f for f in update_files if f.file_name.endswith('.blob')] + self.assertGreater(len(update_blob_files), 0) + self.assertTrue(all(f.write_cols == ['blob_data'] for f in update_files)) + update_blob_lengths = [] + blob_fields = [field for field in table.fields if field.name == 'blob_data'] + for blob_file in update_blob_files: + blob_reader = FormatBlobReader( + file_io=table.file_io, + file_path=blob_file.file_path, + read_fields=['blob_data'], + full_fields=blob_fields, + push_down_predicate=None, + blob_as_descriptor=False, + ) + 
update_blob_lengths.extend(blob_reader.blob_lengths) + blob_reader.close() + self.assertEqual( + update_blob_lengths.count(BlobFormatWriter.PLACE_HOLDER_LENGTH), + 2, + ) + + read_builder = table.new_read_builder() + result = read_builder.new_read().to_arrow(read_builder.new_scan().plan().splits()) + by_id = { + row['id']: row['blob_data'] + for row in result.select(['id', 'blob_data']).to_pylist() + } + self.assertEqual(by_id, { + 1: b'blob-1', + 2: b'updated-blob-2', + 3: b'blob-3', + }) + def test_blob_write_read_partition(self): """Test complete end-to-end blob functionality: write blob data and read it back to verify correctness.""" from pypaimon import Schema @@ -1315,6 +1390,483 @@ def test_blob_descriptor_fields_mixed_mode(self): self.assertEqual(result.column('pic1').to_pylist()[0], pic1_data) self.assertEqual(result.column('pic2').to_pylist()[0], pic2_data) + def test_blob_view_fields_resolve_upstream_blob(self): + from pypaimon import Schema + from pypaimon.common.options.core_options import CoreOptions + from pypaimon.table.row.blob import BlobViewStruct + + source_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + source = Schema.from_pyarrow_schema( + source_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + ) + self.catalog.create_table('test_db.blob_view_source', source, False) + source_table = self.catalog.get_table('test_db.blob_view_source') + payloads = [b'view-source-0', b'view-source-1'] + + write_builder = source_table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(pa.Table.from_pydict({ + 'id': [1, 2], + 'picture': payloads, + }, schema=source_schema)) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + picture_field_id = next( + field.id for field in source_table.table_schema.fields if field.name == 'picture' + ) + view_values = [ + 
BlobViewStruct('test_db.blob_view_source', picture_field_id, 0).serialize(), + BlobViewStruct('test_db.blob_view_source', picture_field_id, 1).serialize(), + ] + + target_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + target = Schema.from_pyarrow_schema( + target_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-view-field': 'picture', + } + ) + self.catalog.create_table('test_db.blob_view_target', target, False) + target_table = self.catalog.get_table('test_db.blob_view_target') + + target_write_builder = target_table.new_batch_write_builder() + target_writer = target_write_builder.new_write() + target_writer.write_arrow(pa.Table.from_pydict({ + 'id': [10, 11], + 'picture': view_values, + }, schema=target_schema)) + target_commit_messages = target_writer.prepare_commit() + target_write_builder.new_commit().commit(target_commit_messages) + target_writer.close() + + all_target_files = [f for msg in target_commit_messages for f in msg.new_files] + self.assertFalse( + any(f.file_name.endswith('.blob') for f in all_target_files), + "Blob view fields should be stored inline without writing new blob files", + ) + + result = target_table.new_read_builder().new_read().to_arrow( + target_table.new_read_builder().new_scan().plan().splits() + ).sort_by('id') + self.assertEqual(result.column('picture').to_pylist(), payloads) + + descriptor_table = target_table.copy({CoreOptions.BLOB_AS_DESCRIPTOR.key(): 'true'}) + descriptor_result = descriptor_table.new_read_builder().new_read().to_arrow( + descriptor_table.new_read_builder().new_scan().plan().splits() + ).sort_by('id') + # With blob-as-descriptor=true, view fields return BlobDescriptor bytes + from pypaimon.table.row.blob import BlobDescriptor + for value in descriptor_result.column('picture').to_pylist(): + self.assertTrue( + BlobDescriptor.is_blob_descriptor(value), + "Expected BlobDescriptor bytes when blob-as-descriptor=true" + ) + + def 
test_blob_view_resolve_disabled_preserves_references(self): + from pypaimon import Schema + from pypaimon.common.options.core_options import CoreOptions + from pypaimon.table.row.blob import BlobViewStruct + + source_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + source = Schema.from_pyarrow_schema( + source_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + ) + self.catalog.create_table('test_db.blob_view_resolve_source', source, False) + source_table = self.catalog.get_table('test_db.blob_view_resolve_source') + payloads = [b'resolve-source-0', b'resolve-source-1'] + + write_builder = source_table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(pa.Table.from_pydict({ + 'id': [1, 2], + 'picture': payloads, + }, schema=source_schema)) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + picture_field_id = next( + field.id for field in source_table.table_schema.fields if field.name == 'picture' + ) + view_values = [ + BlobViewStruct('test_db.blob_view_resolve_source', picture_field_id, 0).serialize(), + BlobViewStruct('test_db.blob_view_resolve_source', picture_field_id, 1).serialize(), + ] + + target_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + target = Schema.from_pyarrow_schema( + target_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-view-field': 'picture', + } + ) + self.catalog.create_table('test_db.blob_view_resolve_target', target, False) + target_table = self.catalog.get_table('test_db.blob_view_resolve_target') + + target_write_builder = target_table.new_batch_write_builder() + target_writer = target_write_builder.new_write() + target_writer.write_arrow(pa.Table.from_pydict({ + 'id': [10, 11], + 'picture': view_values, + }, schema=target_schema)) + target_commit_messages = 
target_writer.prepare_commit() + target_write_builder.new_commit().commit(target_commit_messages) + target_writer.close() + + # Default (resolve enabled): view fields are resolved to real blob data. + resolved_result = target_table.new_read_builder().new_read().to_arrow( + target_table.new_read_builder().new_scan().plan().splits() + ).sort_by('id') + self.assertEqual(resolved_result.column('picture').to_pylist(), payloads) + + # resolve disabled: view fields keep the original BlobViewStruct bytes. + preserve_table = target_table.copy( + {CoreOptions.BLOB_VIEW_RESOLVE_ENABLED.key(): 'false'} + ) + preserve_result = preserve_table.new_read_builder().new_read().to_arrow( + preserve_table.new_read_builder().new_scan().plan().splits() + ).sort_by('id') + preserved_values = preserve_result.column('picture').to_pylist() + self.assertEqual(preserved_values, view_values) + for value in preserved_values: + self.assertTrue( + BlobViewStruct.is_blob_view_struct(value), + "Expected original BlobViewStruct bytes when resolve disabled" + ) + + def test_blob_view_resolves_null_upstream_value(self): + from pypaimon import Schema + from pypaimon.table.row.blob import BlobViewStruct + + source_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + source = Schema.from_pyarrow_schema( + source_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + ) + self.catalog.create_table('test_db.blob_view_null_source', source, False) + source_table = self.catalog.get_table('test_db.blob_view_null_source') + # Row 0 has a real blob value, row 1 has a null blob value. 
+ payloads = [b'null-source-0', None] + + write_builder = source_table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(pa.Table.from_pydict({ + 'id': [1, 2], + 'picture': payloads, + }, schema=source_schema)) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + picture_field_id = next( + field.id for field in source_table.table_schema.fields if field.name == 'picture' + ) + view_values = [ + BlobViewStruct('test_db.blob_view_null_source', picture_field_id, 0).serialize(), + BlobViewStruct('test_db.blob_view_null_source', picture_field_id, 1).serialize(), + ] + + target_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + target = Schema.from_pyarrow_schema( + target_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-view-field': 'picture', + } + ) + self.catalog.create_table('test_db.blob_view_null_target', target, False) + target_table = self.catalog.get_table('test_db.blob_view_null_target') + + target_write_builder = target_table.new_batch_write_builder() + target_writer = target_write_builder.new_write() + target_writer.write_arrow(pa.Table.from_pydict({ + 'id': [10, 11], + 'picture': view_values, + }, schema=target_schema)) + target_commit_messages = target_writer.prepare_commit() + target_write_builder.new_commit().commit(target_commit_messages) + target_writer.close() + + # View referencing a real upstream value resolves to data; view + # referencing a null upstream value resolves to None (not an error). 
+ result = target_table.new_read_builder().new_read().to_arrow( + target_table.new_read_builder().new_scan().plan().splits() + ).sort_by('id') + self.assertEqual(result.column('picture').to_pylist(), [b'null-source-0', None]) + + def test_blob_view_fields_rejects_non_view_input(self): + from pypaimon import Schema + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-view-field': 'picture', + } + ) + self.catalog.create_table('test_db.blob_view_reject_test', schema, False) + table = self.catalog.get_table('test_db.blob_view_reject_test') + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + bad_data = pa.Table.from_pydict({ + 'id': [1], + 'picture': [b'not-a-view-struct'], + }, schema=pa_schema) + + with self.assertRaises(ValueError) as context: + writer.write_arrow(bad_data) + self.assertIn("blob-view-field", str(context.exception)) + + def test_blob_inline_fields_reject_overlap_and_unknown_fields(self): + from pypaimon import Schema + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + base_options = { + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + + overlap_options = dict(base_options) + overlap_options.update({ + 'blob-descriptor-field': 'picture', + 'blob-view-field': 'picture', + }) + overlap_schema = Schema.from_pyarrow_schema(pa_schema, options=overlap_options) + with self.assertRaises(ValueError) as overlap_context: + self.catalog.create_table( + 'test_db.blob_overlap_reject', overlap_schema, False) + self.assertIn("must not overlap", str(overlap_context.exception)) + + unknown_options = dict(base_options) + unknown_options.update({'blob-view-field': 'missing_picture'}) + unknown_schema = Schema.from_pyarrow_schema(pa_schema, options=unknown_options) + with 
self.assertRaises(ValueError) as unknown_context: + self.catalog.create_table( + 'test_db.blob_unknown_reject', unknown_schema, False) + self.assertIn("must be blob fields", str(unknown_context.exception)) + + def test_blob_view_prescan_with_limit(self): + """Test that limit is correctly pushed down to prescan reader. + + Regression test for: prescan should only scan up to limit rows, + not the entire split. + """ + from pypaimon import Schema + from pypaimon.table.row.blob import BlobViewStruct + + # Create source table with multiple rows + source_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + source = Schema.from_pyarrow_schema( + source_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + ) + self.catalog.create_table('test_db.blob_view_limit_source', source, False) + source_table = self.catalog.get_table('test_db.blob_view_limit_source') + + # Write 10 rows + num_rows = 10 + payloads = [f'payload-{i}'.encode() for i in range(num_rows)] + write_builder = source_table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(pa.Table.from_pydict({ + 'id': list(range(num_rows)), + 'picture': payloads, + }, schema=source_schema)) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + picture_field_id = next( + field.id for field in source_table.table_schema.fields if field.name == 'picture' + ) + view_values = [ + BlobViewStruct('test_db.blob_view_limit_source', picture_field_id, i).serialize() + for i in range(num_rows) + ] + + # Create target table with blob-view-field + target_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + target = Schema.from_pyarrow_schema( + target_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-view-field': 'picture', + } + ) + self.catalog.create_table('test_db.blob_view_limit_target', 
target, False) + target_table = self.catalog.get_table('test_db.blob_view_limit_target') + + target_write_builder = target_table.new_batch_write_builder() + target_writer = target_write_builder.new_write() + target_writer.write_arrow(pa.Table.from_pydict({ + 'id': list(range(num_rows)), + 'picture': view_values, + }, schema=target_schema)) + target_commit_messages = target_writer.prepare_commit() + target_write_builder.new_commit().commit(target_commit_messages) + target_writer.close() + + # Test with limit: should only return first 3 rows + read_builder = target_table.new_read_builder() + read_builder.with_limit(3) + result = read_builder.new_read().to_arrow( + read_builder.new_scan().plan().splits() + ) + self.assertEqual(result.num_rows, 3, "LIMIT should be respected in blob view prescan") + self.assertEqual(result.column('id').to_pylist(), [0, 1, 2]) + + def test_blob_view_prescan_only_collects_limited_view_structs(self): + """Verify that the prescan stage only collects as many BlobViewStructs as + the limit allows, instead of scanning the entire split. + + Unlike test_blob_view_prescan_with_limit (which only checks the final + output), this test patches BlobViewLookup.preload to capture the exact + list of view structs collected during prescan and asserts its length + equals the limit. 
+ """ + from unittest import mock + + from pypaimon import Schema + from pypaimon.table.row.blob import BlobViewStruct + from pypaimon.utils.blob_view_lookup import BlobViewLookup + + source_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + source = Schema.from_pyarrow_schema( + source_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + ) + self.catalog.create_table('test_db.blob_view_prescan_count_source', source, False) + source_table = self.catalog.get_table('test_db.blob_view_prescan_count_source') + + num_rows = 10 + payloads = [f'payload-{i}'.encode() for i in range(num_rows)] + write_builder = source_table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(pa.Table.from_pydict({ + 'id': list(range(num_rows)), + 'picture': payloads, + }, schema=source_schema)) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + picture_field_id = next( + field.id for field in source_table.table_schema.fields if field.name == 'picture' + ) + view_values = [ + BlobViewStruct('test_db.blob_view_prescan_count_source', picture_field_id, i).serialize() + for i in range(num_rows) + ] + + target_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + target = Schema.from_pyarrow_schema( + target_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-view-field': 'picture', + } + ) + self.catalog.create_table('test_db.blob_view_prescan_count_target', target, False) + target_table = self.catalog.get_table('test_db.blob_view_prescan_count_target') + + target_write_builder = target_table.new_batch_write_builder() + target_writer = target_write_builder.new_write() + target_writer.write_arrow(pa.Table.from_pydict({ + 'id': list(range(num_rows)), + 'picture': view_values, + }, schema=target_schema)) + target_commit_messages = 
target_writer.prepare_commit() + target_write_builder.new_commit().commit(target_commit_messages) + target_writer.close() + + captured_view_structs = [] + original_preload = BlobViewLookup.preload + + def capturing_preload(lookup_self, view_structs): + captured_view_structs.append(list(view_structs)) + return original_preload(lookup_self, view_structs) + + limit = 3 + read_builder = target_table.new_read_builder() + read_builder.with_limit(limit) + with mock.patch.object(BlobViewLookup, 'preload', autospec=True, + side_effect=capturing_preload): + result = read_builder.new_read().to_arrow( + read_builder.new_scan().plan().splits() + ) + + self.assertEqual(result.num_rows, limit) + self.assertEqual(len(captured_view_structs), 1, + "preload should be invoked exactly once during prescan") + self.assertEqual( + len(captured_view_structs[0]), limit, + "prescan should only collect as many view structs as the limit allows") + def test_to_arrow_batch_reader(self): import random from pypaimon import Schema @@ -2715,8 +3267,8 @@ def test_blob_large_data_volume_with_shard(self): actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('id') self.assertEqual(actual, expected) - def test_data_blob_writer_with_slice(self): - """Test DataBlobWriter with mixed data types in blob column.""" + def test_dedicated_format_writer_with_slice(self): + """Test DedicatedFormatWriter with mixed data types in blob column.""" # Create schema with blob column pa_schema = pa.schema([ @@ -2777,8 +3329,8 @@ def test_data_blob_writer_with_slice(self): self.assertEqual(result.num_columns, 3, "Should have 3 columns") self.assertEqual(result["id"].unique().to_pylist(), [2, 3], "Get incorrect column ID") - def test_data_blob_writer_with_shard(self): - """Test DataBlobWriter with mixed data types in blob column.""" + def test_dedicated_format_writer_with_shard(self): + """Test DedicatedFormatWriter with mixed data types in blob column.""" # Create schema with blob column pa_schema = pa.schema([ @@ 
-3129,7 +3681,7 @@ def test_blob_data_with_ray(self): total_split_row_count = sum([s.row_count for s in splits]) self.assertEqual(total_split_row_count, num_rows * 2, f"Total split row count should be {num_rows}, got {total_split_row_count}") - + total_merged_count = 0 for split in splits: merged_count = split.merged_row_count() @@ -3138,7 +3690,7 @@ def test_blob_data_with_ray(self): self.assertLessEqual( merged_count, split.row_count, f"merged_row_count ({merged_count}) should be <= row_count ({split.row_count})") - + if total_merged_count > 0: self.assertEqual( total_merged_count, num_rows, @@ -3178,14 +3730,14 @@ def test_blob_with_row_id_equal(self): self.catalog.create_table('test_db.blob_rowid_equal', schema, False) table = self.catalog.get_table('test_db.blob_rowid_equal') - blob_bytes = os.urandom(248 * 1024) + blob_values = [bytes([i]) * (248 * 1024) for i in range(20)] write_builder = table.new_batch_write_builder() tw = write_builder.new_write() tc = write_builder.new_commit() batch = pa.Table.from_pydict({ 'id': list(range(20)), 'name': [f'item_{i}' for i in range(20)], - 'data': [blob_bytes] * 20, + 'data': blob_values, }, schema=pa_schema) tw.write_arrow(batch) tc.commit(tw.prepare_commit()) @@ -3196,6 +3748,7 @@ def test_blob_with_row_id_equal(self): from pypaimon.table.special_fields import SpecialFields rb = table.new_read_builder() + rb.with_projection(['id', 'name', 'data', SpecialFields.ROW_ID.name]) fields = list(table.fields) fields.append(SpecialFields.ROW_ID) pb = PredicateBuilder(fields) @@ -3207,6 +3760,10 @@ def test_blob_with_row_id_equal(self): read = rb.new_read() result = read.to_arrow(splits) self.assertEqual(result.num_rows, 1) + self.assertEqual(result.column('id').to_pylist(), [5]) + self.assertEqual(result.column('name').to_pylist(), ['item_5']) + self.assertEqual(result.column('data').to_pylist(), [blob_values[5]]) + self.assertEqual(result.column(SpecialFields.ROW_ID.name).to_pylist(), [5]) def 
test_rename_blob_column_should_fail(self): pa_schema = pa.schema([ @@ -3232,6 +3789,24 @@ def test_rename_blob_column_should_fail(self): ) self.assertIn('Cannot rename BLOB column', str(ctx.exception)) + def test_nested_field_named_blob_not_treated_as_blob(self): + """Regression: a ROW field with a nested column whose name contains + 'blob' must NOT be treated as a top-level BLOB column. Previously + the substring match would falsely classify such fields, causing + create_table to require row-tracking and data-evolution options.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('payload', pa.struct([ + ('blob_name', pa.string()), + ('value', pa.int64()), + ])), + ]) + schema = Schema.from_pyarrow_schema(pa_schema) + self.catalog.create_table( + 'test_db.nested_blob_name_no_error', schema, False) + table = self.catalog.get_table('test_db.nested_blob_name_no_error') + self.assertIsNotNone(table) + class GetBlobTest(unittest.TestCase): @@ -3289,6 +3864,22 @@ def test_get_blob_access(self): self.assertEqual(results[1], (2, b'img_data_2')) self.assertEqual(results[2], (3, b'img_data_3')) + def test_get_blob_access_with_limit(self): + read_builder = self.table.new_read_builder().with_limit(2) + splits = read_builder.new_scan().plan().splits() + read = read_builder.new_read() + + results = [] + for row in read.to_iterator(splits): + blob = row.get_blob(2) + self.assertIsNotNone(blob) + results.append((row.get_field(0), blob.to_data())) + + self.assertEqual(len(results), 2) + for row_id, data in results: + self.assertIn(row_id, (1, 2, 3)) + self.assertIn(data, (b'img_data_1', b'img_data_2', b'img_data_3')) + def test_get_blob_streaming(self): read_builder = self.table.new_read_builder() splits = read_builder.new_scan().plan().splits() @@ -3488,5 +4079,241 @@ def test_get_blob_on_non_blob_column_with_magic_bytes_raises(self): mock_create.assert_not_called() +class BlobConsumerTest(unittest.TestCase): + """Tests for BlobConsumer callback functionality.""" + + 
@classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.temp_dir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('test_db', False) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + def test_blob_consumer_basic(self): + """Consumer receives one BlobDescriptor per blob written, None for nulls.""" + from pypaimon.table.row.blob import Blob, BlobDescriptor + from pypaimon.common.uri_reader import FileUriReader + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_basic', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_basic') + + blob_bytes = b'hello_blob_consumer' + received = [] + + def my_consumer(field_name, descriptor): + received.append((field_name, descriptor)) + return True + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.with_blob_consumer(my_consumer) + + test_data = pa.Table.from_pydict({ + 'id': [1, 2, 3], + 'name': ['a', 'b', 'c'], + 'blob_data': [blob_bytes, blob_bytes, None], + }, schema=pa_schema) + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + self.assertEqual(len(received), 3) + + for field_name, desc in received[:2]: + self.assertEqual(field_name, 'blob_data') + self.assertIsInstance(desc, BlobDescriptor) + uri_reader = FileUriReader(table.file_io) + blob = Blob.from_descriptor(uri_reader, desc) + self.assertEqual(blob.to_data(), blob_bytes) + + self.assertEqual(received[2][0], 'blob_data') + self.assertIsNone(received[2][1]) + + def test_blob_consumer_flush_behavior(self): + 
"""Consumer return value controls flush; verify flush count.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_flush', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_flush') + + blob_bytes = b'flush_test_blob' + descriptors = [] + flush_count = [0] + + def my_consumer(field_name, descriptor): + descriptors.append(descriptor) + should_flush = len(descriptors) % 2 == 0 + if should_flush: + flush_count[0] += 1 + return should_flush + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.with_blob_consumer(my_consumer) + + test_data = pa.Table.from_pydict({ + 'id': list(range(5)), + 'name': [f'row{i}' for i in range(5)], + 'blob_data': [blob_bytes] * 5, + }, schema=pa_schema) + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + self.assertEqual(len(descriptors), 5) + self.assertEqual(flush_count[0], 2) + + from pypaimon.table.row.blob import Blob + from pypaimon.common.uri_reader import FileUriReader + uri_reader = FileUriReader(table.file_io) + for desc in descriptors: + self.assertIsNotNone(desc) + blob = Blob.from_descriptor(uri_reader, desc) + self.assertEqual(blob.to_data(), blob_bytes) + + def test_blob_consumer_no_consumer_set(self): + """Without consumer, writing still works normally.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_no_consumer', schema, False) + table = self.catalog.get_table('test_db.blob_no_consumer') + + write_builder 
= table.new_batch_write_builder() + writer = write_builder.new_write() + + test_data = pa.Table.from_pydict({ + 'id': [1, 2], + 'blob_data': [b'data1', b'data2'], + }, schema=pa_schema) + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + result = table.new_read_builder().new_read().to_arrow( + table.new_read_builder().new_scan().plan().splits()) + self.assertEqual(result.column('blob_data').to_pylist(), [b'data1', b'data2']) + + def test_blob_consumer_chain_call(self): + """with_blob_consumer returns self for chaining.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_chain', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_chain') + + write_builder = table.new_batch_write_builder() + result = write_builder.new_write().with_blob_consumer(lambda f, d: False) + self.assertIsNotNone(result) + result.close() + + def test_blob_consumer_abort_preserves_files(self): + """Abort with consumer must not delete blob files that descriptors point to.""" + from pypaimon.table.row.blob import Blob + from pypaimon.common.uri_reader import FileUriReader + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob.target-file-size': '1KB', + }) + self.catalog.create_table('test_db.blob_consumer_abort', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_abort') + + blob_bytes = b'X' * 2048 + received = [] + + def my_consumer(field_name, descriptor): + received.append(descriptor) + return False + + write_builder = table.new_batch_write_builder() + writer = 
write_builder.new_write() + writer.with_blob_consumer(my_consumer) + + test_data = pa.Table.from_pydict({ + 'id': list(range(5)), + 'blob_data': [blob_bytes] * 5, + }, schema=pa_schema) + writer.write_arrow(test_data) + + self.assertGreater(len(received), 0) + + # Capture data writers before close() clears them, then abort each one. + data_writers = list(writer.file_store_write.data_writers.values()) + self.assertGreater(len(data_writers), 0) + for dw in data_writers: + dw.abort() + + # Every descriptor returned to the consumer must still be readable. + uri_reader = FileUriReader(table.file_io) + for desc in received: + self.assertIsNotNone(desc) + data = Blob.from_descriptor(uri_reader, desc).to_data() + self.assertEqual(data, blob_bytes) + + def test_blob_consumer_after_write_raises(self): + """Setting consumer after data has been written must raise.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_late', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_late') + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + + writer.write_arrow(pa.Table.from_pydict({ + 'id': [1], + 'blob_data': [b'data'], + }, schema=pa_schema)) + + with self.assertRaises(RuntimeError): + writer.with_blob_consumer(lambda f, d: False) + writer.close() + + if __name__ == '__main__': unittest.main() diff --git a/paimon-python/pypaimon/tests/blob_test.py b/paimon-python/pypaimon/tests/blob_test.py index e6b856432b50..37217f8b7cfe 100644 --- a/paimon-python/pypaimon/tests/blob_test.py +++ b/paimon-python/pypaimon/tests/blob_test.py @@ -31,7 +31,7 @@ from pypaimon.common.options import Options from pypaimon.read.reader.format_blob_reader import BlobRecordIterator, FormatBlobReader from pypaimon.schema.data_types 
import AtomicType, DataField -from pypaimon.table.row.blob import Blob, BlobData, BlobRef, BlobDescriptor +from pypaimon.table.row.blob import Blob, BlobData, BlobRef, BlobDescriptor, BlobViewStruct, BlobView from pypaimon.table.row.generic_row import GenericRowDeserializer, GenericRowSerializer, GenericRow from pypaimon.table.row.row_kind import RowKind @@ -166,6 +166,25 @@ def test_from_bytes_invalid_type_raises(self): with self.assertRaises(TypeError): Blob.from_bytes(12345) + def test_blob_view_struct_roundtrip(self): + """Test BlobViewStruct serialization compatibility.""" + view_struct = BlobViewStruct("test_db.source_table", 7, 42) + serialized = view_struct.serialize() + + self.assertTrue(BlobViewStruct.is_blob_view_struct(serialized)) + self.assertFalse(BlobDescriptor.is_blob_descriptor(serialized)) + + restored = BlobViewStruct.deserialize(serialized) + self.assertEqual(restored, view_struct) + self.assertEqual(restored.identifier.get_full_name(), "test_db.source_table") + self.assertEqual(restored.field_id, 7) + self.assertEqual(restored.row_id, 42) + + blob = Blob.from_bytes(view_struct.serialize()) + self.assertIsInstance(blob, BlobView) + self.assertFalse(blob.is_resolved()) + self.assertEqual(blob.view_struct, view_struct) + def test_blob_data_interface_compliance(self): """Test that BlobData properly implements Blob interface.""" test_data = b"interface test data" @@ -684,6 +703,106 @@ def test_blob_end_to_end(self): reader.close() + def test_blob_read_inline_bytes_reuses_reader_stream(self): + class CountingFileIO: + + def __init__(self, delegate): + self._delegate = delegate + self.input_stream_count = 0 + + def __getattr__(self, name): + return getattr(self._delegate, name) + + def new_input_stream(self, path): + self.input_stream_count += 1 + return self._delegate.new_input_stream(path) + + file_io = LocalFileIO(self.temp_dir, Options({})) + blob_field_name = "blob_field" + blob_data = [b"hello", b"world"] + schema = 
pa.schema([pa.field(blob_field_name, pa.large_binary())]) + table = pa.table([blob_data], schema=schema) + blob_file_path = Path(self.temp_dir) / (blob_field_name + "_inline.blob") + blob_file_url = _to_url(blob_file_path) + file_io.write_blob(blob_file_url, table) + + counting_file_io = CountingFileIO(file_io) + read_fields = [DataField(0, blob_field_name, AtomicType("BLOB"))] + reader = FormatBlobReader( + file_io=counting_file_io, + file_path=str(blob_file_path), + read_fields=[blob_field_name], + full_fields=read_fields, + push_down_predicate=None, + blob_as_descriptor=False + ) + + batch = reader.read_arrow_batch() + self.assertIsNotNone(batch) + self.assertEqual(batch.num_rows, 2) + self.assertEqual(batch.column(0)[0].as_py(), b"hello") + self.assertEqual(batch.column(0)[1].as_py(), b"world") + self.assertEqual(counting_file_io.input_stream_count, 1) + reader.close() + + def test_blob_reader_row_indices_pushdown(self): + file_io = LocalFileIO(self.temp_dir, Options({})) + blob_field_name = "blob_field" + blob_data = [f"value_{i}".encode("utf-8") for i in range(6)] + schema = pa.schema([pa.field(blob_field_name, pa.large_binary())]) + table = pa.table([blob_data], schema=schema) + blob_file_path = Path(self.temp_dir) / "row_indices.blob" + blob_file_url = _to_url(blob_file_path) + file_io.write_blob(blob_file_url, table) + + read_fields = [DataField(0, blob_field_name, AtomicType("BLOB"))] + reader = FormatBlobReader( + file_io=file_io, + file_path=str(blob_file_path), + read_fields=[blob_field_name], + full_fields=read_fields, + push_down_predicate=None, + blob_as_descriptor=False, + batch_size=2, + row_indices=[1, 3, 4], + ) + try: + batch = reader.read_arrow_batch() + self.assertIsNotNone(batch) + self.assertEqual(batch.column(0).to_pylist(), [blob_data[1], blob_data[3]]) + + batch = reader.read_arrow_batch() + self.assertIsNotNone(batch) + self.assertEqual(batch.column(0).to_pylist(), [blob_data[4]]) + + self.assertIsNone(reader.read_arrow_batch()) + 
finally: + reader.close() + + def test_blob_reader_row_indices_out_of_range(self): + file_io = LocalFileIO(self.temp_dir, Options({})) + blob_field_name = "blob_field" + blob_data = [b"value_0", b"value_1"] + schema = pa.schema([pa.field(blob_field_name, pa.large_binary())]) + table = pa.table([blob_data], schema=schema) + blob_file_path = Path(self.temp_dir) / "row_indices_out_of_range.blob" + blob_file_url = _to_url(blob_file_path) + file_io.write_blob(blob_file_url, table) + + read_fields = [DataField(0, blob_field_name, AtomicType("BLOB"))] + with self.assertRaises(IndexError) as context: + FormatBlobReader( + file_io=file_io, + file_path=str(blob_file_path), + read_fields=[blob_field_name], + full_fields=read_fields, + push_down_predicate=None, + blob_as_descriptor=False, + row_indices=[0, 2], + ) + + self.assertIn("Blob row index 2 is out of range", str(context.exception)) + def test_blob_complex_types_throw_exception(self): """Test that complex types containing BLOB elements throw exceptions during read/write operations.""" from pypaimon.schema.data_types import DataField, AtomicType, ArrayType, MultisetType, MapType diff --git a/paimon-python/pypaimon/tests/btree_thread_safety_test.py b/paimon-python/pypaimon/tests/btree_thread_safety_test.py new file mode 100644 index 000000000000..b328b512b5b0 --- /dev/null +++ b/paimon-python/pypaimon/tests/btree_thread_safety_test.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Thread-safety tests for BTree global index readers.""" + +import struct +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor + +from pypaimon.globalindex.btree.btree_file_meta_selector import BTreeFileMetaSelector +from pypaimon.globalindex.btree.btree_index_meta import BTreeIndexMeta +from pypaimon.globalindex.btree.lazy_filtered_btree_reader import LazyFilteredBTreeReader +from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta +from pypaimon.globalindex.global_index_reader import FieldRef +from pypaimon.globalindex.global_index_result import GlobalIndexResult +from pypaimon.utils.roaring_bitmap import RoaringBitmap64 + + +def _int_meta(first: int, last: int) -> bytes: + """Create serialized BTreeIndexMeta for INT key range.""" + first_bytes = struct.pack(' GlobalIndexIOMeta: + return GlobalIndexIOMeta( + file_name=file_name, + file_size=1024, + metadata=_int_meta(first, last), + ) + + +class _SlowBTreeReader: + """Simulates a BTreeIndexReader with configurable delay.""" + + def __init__(self, key_serializer, file_io, index_path, io_meta): + self._file_name = io_meta.file_name + self._creation_time = time.monotonic() + + def visit_equal(self, literal): + time.sleep(0.01) + bm = RoaringBitmap64() + bm.add(literal) + return GlobalIndexResult.create(bm) + + def visit_is_not_null(self): + time.sleep(0.01) + bm = RoaringBitmap64() + bm.add_range(0, 10) + return GlobalIndexResult.create(bm) + + def close(self): + pass + + +class LazyFilteredBTreeReaderThreadTest(unittest.TestCase): + 
"""Thread-safety tests for LazyFilteredBTreeReader.""" + + def _create_reader(self, num_files, pool_size, reader_cls=None): + from pypaimon.globalindex.btree.key_serializer import create_serializer + from pypaimon.schema.data_types import AtomicType + + if reader_cls is None: + reader_cls = _SlowBTreeReader + key_serializer = create_serializer(AtomicType("INT")) + io_metas = [_io_meta(f"file-{i}.index", i * 10, (i + 1) * 10 - 1) + for i in range(num_files)] + executor = ThreadPoolExecutor(max_workers=pool_size) + + import pypaimon.globalindex.btree.lazy_filtered_btree_reader as mod + original = mod.BTreeIndexReader + mod.BTreeIndexReader = reader_cls + reader = LazyFilteredBTreeReader( + key_serializer=key_serializer, + file_io=None, + index_path="/unused", + io_metas=io_metas, + executor=executor, + ) + # Keep the patch active — reader_cls stays in module + return reader, executor, (mod, original) + + def _cleanup(self, reader, executor, patch_info): + reader.close() + executor.shutdown(wait=False) + mod, original = patch_info + mod.BTreeIndexReader = original + + def test_no_deadlock_more_files_than_threads(self): + """With 8 files and only 2 threads, callback-based chaining must not deadlock.""" + reader, executor, patch = self._create_reader(num_files=8, pool_size=2) + field_ref = FieldRef(0, "id", "INT") + try: + future = reader.visit_is_not_null(field_ref) + result = future.result(timeout=5.0) + self.assertIsNotNone(result) + self.assertFalse(result.is_empty()) + finally: + self._cleanup(reader, executor, patch) + + def test_concurrent_visits_return_correct_results(self): + """Multiple concurrent visit_equal calls return correct disjoint results.""" + reader, executor, patch = self._create_reader(num_files=4, pool_size=8) + field_ref = FieldRef(0, "id", "INT") + try: + futures = [] + for i in range(20): + val = (i % 4) * 10 + 5 + futures.append((val, reader.visit_equal(field_ref, val))) + + for expected_val, future in futures: + result = 
future.result(timeout=5.0) + self.assertIsNotNone(result) + hits = list(result.results()) + self.assertIn(expected_val, hits) + finally: + self._cleanup(reader, executor, patch) + + def test_lazy_creation_only_once_per_file(self): + """_get_or_create_reader must create each reader exactly once under concurrency.""" + creation_counts = {} + creation_lock = threading.Lock() + + class _CountingReader: + def __init__(self_inner, key_serializer, file_io, index_path, io_meta): + with creation_lock: + name = io_meta.file_name + creation_counts[name] = creation_counts.get(name, 0) + 1 + + def visit_equal(self_inner, literal): + bm = RoaringBitmap64() + bm.add(literal) + return GlobalIndexResult.create(bm) + + def close(self_inner): + pass + + reader, executor, patch = self._create_reader( + num_files=4, pool_size=8, reader_cls=_CountingReader) + field_ref = FieldRef(0, "id", "INT") + try: + barrier = threading.Barrier(16) + + def do_query(): + barrier.wait() + return reader.visit_equal(field_ref, 5).result(timeout=5.0) + + with ThreadPoolExecutor(max_workers=16) as query_pool: + query_futures = [query_pool.submit(do_query) for _ in range(16)] + for f in query_futures: + f.result(timeout=10.0) + + for name, count in creation_counts.items(): + self.assertEqual(1, count, + f"File {name} was created {count} times (expected 1)") + finally: + self._cleanup(reader, executor, patch) + + def test_selector_prunes_files_correctly(self): + """BTreeFileMetaSelector correctly prunes files that cannot match.""" + from pypaimon.globalindex.btree.key_serializer import create_serializer + from pypaimon.schema.data_types import AtomicType + + key_serializer = create_serializer(AtomicType("INT")) + io_metas = [ + _io_meta("file-0.index", 0, 9), + _io_meta("file-1.index", 10, 19), + _io_meta("file-2.index", 20, 29), + ] + files = [(m, BTreeIndexMeta.deserialize(m.metadata)) for m in io_metas] + selector = BTreeFileMetaSelector(files, key_serializer) + + # equal(5) should only match file-0 + 
result = selector.select_equal(5) + self.assertEqual(1, len(result)) + self.assertEqual("file-0.index", result[0].file_name) + + # equal(15) should only match file-1 + result = selector.select_equal(15) + self.assertEqual(1, len(result)) + self.assertEqual("file-1.index", result[0].file_name) + + # less_than(15) should match file-0 and file-1 + result = selector.select_less_than(15) + self.assertEqual(2, len(result)) + + # greater_than(25) should match file-2 + result = selector.select_greater_than(25) + self.assertEqual(1, len(result)) + self.assertEqual("file-2.index", result[0].file_name) + + # between(5, 15) should match file-0 and file-1 + result = selector.select_between(5, 15) + self.assertEqual(2, len(result)) + + # equal(100) should match nothing + result = selector.select_equal(100) + self.assertEqual(0, len(result)) + + # in([5, 25]) should match file-0 and file-2 + result = selector.select_in([5, 25]) + self.assertEqual(2, len(result)) + names = {m.file_name for m in result} + self.assertEqual({"file-0.index", "file-2.index"}, names) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/column_directive_utils_test.py b/paimon-python/pypaimon/tests/column_directive_utils_test.py new file mode 100644 index 000000000000..4abe030336bf --- /dev/null +++ b/paimon-python/pypaimon/tests/column_directive_utils_test.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from pypaimon.common.options.core_options import CoreOptions +from pypaimon.schema.column_directive_utils import ( + apply_add_column_directive, + apply_directives, + parse_add_column_comment, + remove_dropped_directive_options, +) +from pypaimon.schema.data_types import ( + ArrayType, AtomicType, DataField, VectorType, +) + + +class TestParseAddColumnComment(unittest.TestCase): + + def test_none_and_empty(self): + self.assertIsNone(parse_add_column_comment(None)) + self.assertIsNone(parse_add_column_comment("")) + self.assertIsNone(parse_add_column_comment("normal comment")) + + def test_blob_field(self): + d = parse_add_column_comment("__BLOB_FIELD; picture") + self.assertEqual(d.option_key, CoreOptions.BLOB_FIELD.key()) + self.assertEqual(d.real_comment, "picture") + self.assertFalse(d.is_vector) + + def test_blob_field_bare(self): + d = parse_add_column_comment("__BLOB_FIELD") + self.assertIsNone(d.real_comment) + + def test_blob_descriptor_field(self): + d = parse_add_column_comment("__BLOB_DESCRIPTOR_FIELD; desc") + self.assertEqual(d.option_key, CoreOptions.BLOB_DESCRIPTOR_FIELD.key()) + + def test_blob_view_field(self): + d = parse_add_column_comment("__BLOB_VIEW_FIELD; view") + self.assertEqual(d.option_key, CoreOptions.BLOB_VIEW_FIELD.key()) + + def test_blob_external_storage_field(self): + d = parse_add_column_comment("__BLOB_EXTERNAL_STORAGE_FIELD; ext") + self.assertEqual(d.option_key, CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key()) + + def test_vector_field(self): + d = parse_add_column_comment("__VECTOR_FIELD;128; 
embedding") + self.assertEqual(d.option_key, CoreOptions.VECTOR_FIELD.key()) + self.assertTrue(d.is_vector) + self.assertEqual(d.vector_dim, 128) + self.assertEqual(d.real_comment, "embedding") + + def test_vector_field_no_comment(self): + d = parse_add_column_comment("__VECTOR_FIELD;64") + self.assertEqual(d.vector_dim, 64) + self.assertIsNone(d.real_comment) + + def test_unknown_blob_directive_rejected(self): + with self.assertRaises(ValueError): + parse_add_column_comment("__BLOB_UNKNOWN") + + def test_vector_without_dim_rejected(self): + with self.assertRaises(ValueError): + parse_add_column_comment("__VECTOR_FIELD") + + def test_vector_non_integer_dim_rejected(self): + with self.assertRaises(ValueError): + parse_add_column_comment("__VECTOR_FIELD;abc") + + +class TestApplyAddColumnDirective(unittest.TestCase): + + def test_non_directive_returns_none(self): + opts = {} + result = apply_add_column_directive( + "normal", "col", AtomicType("BYTES"), opts + ) + self.assertIsNone(result) + self.assertEqual(opts, {}) + + def test_blob_field(self): + opts = {} + result = apply_add_column_directive( + "__BLOB_FIELD; pic", "pic", AtomicType("BYTES"), opts + ) + self.assertIsNotNone(result) + self.assertEqual(result.type.type, "BLOB") + self.assertEqual(result.comment, "pic") + self.assertEqual(opts[CoreOptions.BLOB_FIELD.key()], "pic") + + def test_vector_field(self): + opts = {} + result = apply_add_column_directive( + "__VECTOR_FIELD;128; emb", + "emb", + ArrayType(True, AtomicType("FLOAT")), + opts + ) + self.assertIsNotNone(result) + self.assertIsInstance(result.type, VectorType) + self.assertEqual(result.type.length, 128) + self.assertEqual(result.comment, "emb") + self.assertEqual(opts[CoreOptions.VECTOR_FIELD.key()], "emb") + + def test_external_storage_registers_both(self): + opts = {} + apply_add_column_directive( + "__BLOB_EXTERNAL_STORAGE_FIELD", "vid", AtomicType("BYTES"), opts + ) + self.assertEqual(opts[CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key()], 
"vid") + self.assertEqual(opts[CoreOptions.BLOB_DESCRIPTOR_FIELD.key()], "vid") + + def test_blob_rejects_non_binary(self): + with self.assertRaises(ValueError): + apply_add_column_directive( + "__BLOB_FIELD", "col", AtomicType("INT"), {} + ) + + def test_vector_rejects_non_array(self): + with self.assertRaises(ValueError): + apply_add_column_directive( + "__VECTOR_FIELD;128", "col", AtomicType("INT"), {} + ) + + def test_appends_to_existing(self): + opts = {CoreOptions.BLOB_FIELD.key(): "a"} + apply_add_column_directive( + "__BLOB_FIELD", "b", AtomicType("BYTES"), opts + ) + self.assertEqual(opts[CoreOptions.BLOB_FIELD.key()], "a,b") + + def test_migrates_legacy_fallback_key(self): + opts = {"blob.stored-descriptor-fields": "legacy_col"} + apply_add_column_directive( + "__BLOB_DESCRIPTOR_FIELD", "new_col", AtomicType("BYTES"), opts + ) + self.assertEqual( + opts[CoreOptions.BLOB_DESCRIPTOR_FIELD.key()], "legacy_col,new_col" + ) + self.assertNotIn("blob.stored-descriptor-fields", opts) + + +class TestApplyDirectives(unittest.TestCase): + + def test_no_directives(self): + fields = [DataField(0, "k", AtomicType("INT"))] + opts = {} + changed = apply_directives(fields, opts) + self.assertFalse(changed) + self.assertEqual(fields[0].type.type, "INT") + + def test_mixed_fields(self): + fields = [ + DataField(0, "k", AtomicType("INT")), + DataField(1, "pic", AtomicType("BYTES"), "__BLOB_FIELD; picture"), + DataField( + 2, "emb", ArrayType(True, AtomicType("FLOAT")), + "__VECTOR_FIELD;64; my emb" + ), + ] + opts = {} + changed = apply_directives(fields, opts) + self.assertTrue(changed) + self.assertEqual(fields[1].type.type, "BLOB") + self.assertEqual(fields[1].description, "picture") + self.assertIsInstance(fields[2].type, VectorType) + self.assertEqual(fields[2].type.length, 64) + self.assertEqual(fields[2].description, "my emb") + self.assertEqual(opts[CoreOptions.BLOB_FIELD.key()], "pic") + self.assertEqual(opts[CoreOptions.VECTOR_FIELD.key()], "emb") + + +class 
TestRemoveDroppedDirectiveOptions(unittest.TestCase): + + def test_drop_blob(self): + opts = { + CoreOptions.BLOB_FIELD.key(): "a,b", + CoreOptions.BLOB_DESCRIPTOR_FIELD.key(): "b", + CoreOptions.VECTOR_FIELD.key(): "v", + } + remove_dropped_directive_options("b", "BLOB", opts) + self.assertEqual(opts[CoreOptions.BLOB_FIELD.key()], "a") + self.assertNotIn(CoreOptions.BLOB_DESCRIPTOR_FIELD.key(), opts) + self.assertEqual(opts[CoreOptions.VECTOR_FIELD.key()], "v") + + def test_drop_vector(self): + opts = { + CoreOptions.VECTOR_FIELD.key(): "emb,emb2", + "field.emb.vector-dim": "128", + CoreOptions.BLOB_FIELD.key(): "a", + } + remove_dropped_directive_options("emb", "VECTOR", opts) + self.assertEqual(opts[CoreOptions.VECTOR_FIELD.key()], "emb2") + self.assertNotIn("field.emb.vector-dim", opts) + self.assertEqual(opts[CoreOptions.BLOB_FIELD.key()], "a") + + def test_drop_blob_cleans_fallback_keys(self): + opts = { + CoreOptions.BLOB_DESCRIPTOR_FIELD.key(): "b,c", + "blob.stored-descriptor-fields": "b,legacy", + } + remove_dropped_directive_options("b", "BLOB", opts) + self.assertEqual(opts[CoreOptions.BLOB_DESCRIPTOR_FIELD.key()], "c") + self.assertEqual(opts["blob.stored-descriptor-fields"], "legacy") + + def test_drop_non_directive_is_noop(self): + opts = {CoreOptions.BLOB_FIELD.key(): "a"} + remove_dropped_directive_options("x", "INT", opts) + self.assertEqual(opts[CoreOptions.BLOB_FIELD.key()], "a") + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/daft/daft_catalog_test.py b/paimon-python/pypaimon/tests/daft/daft_catalog_test.py index 0b34452ed40f..95519aa3be55 100644 --- a/paimon-python/pypaimon/tests/daft/daft_catalog_test.py +++ b/paimon-python/pypaimon/tests/daft/daft_catalog_test.py @@ -156,6 +156,7 @@ def test_catalog_has_table(paimon_catalog): assert daft_catalog.has_table("test_db.test_table") assert not daft_catalog.has_table("test_db.nonexistent_table") assert not 
daft_catalog.has_table("nonexistent_db.test_table") + assert not daft_catalog.has_table("missing_table") def test_catalog_list_tables(paimon_catalog): @@ -188,6 +189,14 @@ def test_catalog_get_table_not_found(paimon_catalog): daft_catalog.get_table("test_db.nonexistent_table") +def test_catalog_get_table_single_part_not_found(paimon_catalog): + from daft.catalog import NotFoundError + + daft_catalog, _, _ = paimon_catalog + with pytest.raises(NotFoundError): + daft_catalog.get_table("missing_table") + + def test_catalog_drop_table(paimon_catalog): daft_catalog, _, _ = paimon_catalog assert daft_catalog.has_table("test_db.test_table") @@ -195,6 +204,14 @@ def test_catalog_drop_table(paimon_catalog): assert not daft_catalog.has_table("test_db.test_table") +def test_catalog_drop_table_single_part_not_found(paimon_catalog): + from daft.catalog import NotFoundError + + daft_catalog, _, _ = paimon_catalog + with pytest.raises(NotFoundError): + daft_catalog.drop_table("missing_table") + + def test_catalog_create_table(tmp_path): inner = pypaimon.CatalogFactory.create({"warehouse": str(tmp_path)}) inner.create_database("mydb", ignore_if_exists=True) diff --git a/paimon-python/pypaimon/tests/daft/daft_data_test.py b/paimon-python/pypaimon/tests/daft/daft_data_test.py index 9d7795cb9797..5520b3710072 100644 --- a/paimon-python/pypaimon/tests/daft/daft_data_test.py +++ b/paimon-python/pypaimon/tests/daft/daft_data_test.py @@ -39,11 +39,41 @@ from pypaimon.daft.daft_predicate_visitor import convert_filters_to_paimon +def _contains_expr(py_expr): + from daft.expressions import Expression + + expr_text = str(Expression._from_pyexpr(py_expr)) + return "contains" in expr_text + + +def _predicate_leaves(predicate): + if predicate is None: + return [] + if predicate.method in ("and", "or"): + result = [] + for child in predicate.literals: + result.extend(_predicate_leaves(child)) + return result + return [(predicate.method, predicate.field, tuple(predicate.literals or []))] + + # 
--------------------------------------------------------------------------- # Helper # --------------------------------------------------------------------------- +class _UnserializableFileIoMarker: + def __reduce__(self): + raise TypeError("file io marker should not be serialized") + + +class _UnserializableStorageConfig: + multithreaded_io = False + + def __reduce__(self): + raise TypeError("storage config marker should not be serialized") + + def _write_to_paimon(table, arrow_table, mode="append", overwrite_partition=None): write_builder = table.new_batch_write_builder() if mode == "overwrite": @@ -59,6 +89,18 @@ def _write_to_paimon(table, arrow_table, mode="append", overwrite_partition=None table_commit.close() +async def _collect_paimon_source_batches(source, pushdowns): + batches = [] + fallback_task_count = 0 + async for task in source.get_tasks(pushdowns): + if type(task).__name__ == "_PaimonPKSplitTask": + fallback_task_count += 1 + async for batch in task.read(): + batches.append(batch.to_pydict()) + assert fallback_task_count > 0 + return batches + + async def _read_paimon_source_batches( table, filter_expr=None, @@ -81,16 +123,8 @@ async def _read_paimon_source_batches( assert pushed_filters assert not remaining_filters - batches = [] - fallback_task_count = 0 pushdowns = Pushdowns(filters=filter_expr, columns=columns, limit=limit) - async for task in source.get_tasks(pushdowns): - if type(task).__name__ == "_PaimonPKSplitTask": - fallback_task_count += 1 - async for batch in task.read(): - batches.append(batch.to_pydict()) - assert fallback_task_count > 0 - return batches + return await _collect_paimon_source_batches(source, pushdowns) # --------------------------------------------------------------------------- @@ -191,6 +225,148 @@ def test_read_paimon_schema_matches(append_only_table): assert "dt" in schema.column_names() +def test_read_paimon_source_is_serializable(append_only_table): + """The Daft source must not serialize live 
table/file_io/storage objects.""" + from daft.pickle import dumps, loads + + from pypaimon.daft.daft_datasource import PaimonDataSource + + table, _ = append_only_table + table.file_io._unserializable_marker = _UnserializableFileIoMarker() + + source = PaimonDataSource( + table, + storage_config=_UnserializableStorageConfig(), + catalog_options={}, + ) + + restored = loads(dumps(source)) + + assert restored is not source + assert restored.schema.column_names() == source.schema.column_names() + assert restored._table is not table + assert restored._table.identifier.get_full_name() == table.identifier.get_full_name() + assert restored._storage_config.multithreaded_io is False + + +def test_read_paimon_source_serialization_preserves_pushed_filter_for_fallback(local_paimon_catalog): + """A serialized source must keep filters accepted by SupportsPushdownFilters.""" + from daft import context, runners + from daft.daft import StorageConfig + from daft.io.pushdowns import Pushdowns + from daft.pickle import dumps, loads + + from pypaimon.daft.daft_datasource import PaimonDataSource + + catalog, _ = local_paimon_catalog + schema = pypaimon.Schema.from_pyarrow_schema( + pa.schema([ + pa.field("id", pa.int64()), + pa.field("name", pa.string()), + ]), + options={ + "file.format": "avro", + "source.split.target-size": "800b", + "source.split.open-file-cost": "600b", + }, + ) + catalog.create_table("test_db.avro_serialized_pushdown_filter", schema, ignore_if_exists=False) + table = catalog.get_table("test_db.avro_serialized_pushdown_filter") + _write_to_paimon(table, pa.table({"id": [1], "name": ["first"]})) + _write_to_paimon(table, pa.table({"id": [999], "name": ["match"]})) + + io_config = context.get_context().daft_planning_config.default_io_config + storage_config = StorageConfig(runners.get_or_create_runner().name != "ray", io_config) + source = PaimonDataSource(table, storage_config=storage_config, catalog_options={}) + pushed_filters, remaining_filters = 
source.push_filters([(col("id") == 999)._expr]) + assert pushed_filters + assert not remaining_filters + + restored = loads(dumps(source)) + batches = asyncio.run( + _collect_paimon_source_batches( + restored, + Pushdowns(filters=None, limit=1), + ) + ) + + assert batches == [{"id": [999], "name": ["match"]}] + + +def test_read_paimon_remote_ray_task_is_serializable(pk_table, monkeypatch): + """A fallback PK split task must reopen the table from metadata on Ray workers. + + Splits that need an LSM merge (here, overlapping primary-key writes) are read + by the pypaimon reader task. Under the Ray runner that task is pickled to + remote workers, so it must serialize only rebuildable metadata -- never the + live table / file_io / storage objects. + """ + from daft import runners + from daft.io.pushdowns import Pushdowns + from daft.pickle import dumps, loads + + from pypaimon.daft.daft_datasource import PaimonDataSource + + class _RayRunner: + name = "ray" + + table, _ = pk_table + # Two overlapping writes on id=1 create non-raw-convertible splits that + # require the pypaimon merge reader (the fallback _PaimonPKSplitTask). 
+ _write_to_paimon( + table, + pa.table( + { + "id": pa.array([1, 2], pa.int64()), + "name": pa.array(["old_a", "old_b"], pa.string()), + "dt": pa.array(["2024-01-01", "2024-01-01"], pa.string()), + } + ), + ) + _write_to_paimon( + table, + pa.table( + { + "id": pa.array([1], pa.int64()), + "name": pa.array(["new_a"], pa.string()), + "dt": pa.array(["2024-01-01"], pa.string()), + } + ), + ) + table.file_io._unserializable_marker = _UnserializableFileIoMarker() + + source = PaimonDataSource( + table, + storage_config=_UnserializableStorageConfig(), + catalog_options={}, + ) + monkeypatch.setattr(runners, "get_or_create_runner", lambda: _RayRunner()) + + async def first_task(): + async for task in source.get_tasks(Pushdowns()): + return task + raise AssertionError("Expected at least one task") + + async def read_task(task): + rows = [] + async for batch in task.read(): + rows.append(batch.to_pydict()) + return rows + + task = asyncio.run(first_task()) + assert type(task).__name__ == "_PaimonPKSplitTask" + + restored_task = loads(dumps(task)) + batches = asyncio.run(read_task(restored_task)) + + merged = { + _id: name + for batch in batches + for _id, name in zip(batch["id"], batch["name"]) + } + assert merged == {1: "new_a", 2: "old_b"} + + # --------------------------------------------------------------------------- # Multi-partition reads # --------------------------------------------------------------------------- @@ -440,6 +616,34 @@ def test_read_paimon_fallback_plans_pushdown_filter_without_push_filters(local_p assert batches == [{"id": [999], "name": ["match"]}] +def test_read_paimon_fallback_not_in_filter_excludes_nulls_before_limit(local_paimon_catalog): + """Fallback datasource tasks must satisfy pushed NOT IN filters before limit.""" + catalog, _ = local_paimon_catalog + pa_schema = pa.schema([ + pa.field("id", pa.int64()), + pa.field("name", pa.string()), + ]) + schema = pypaimon.Schema.from_pyarrow_schema( + pa_schema, + options={"file.format": "row"}, + 
) + catalog.create_table("test_db.row_not_in_filter_limit", schema, ignore_if_exists=False) + table = catalog.get_table("test_db.row_not_in_filter_limit") + _write_to_paimon(table, pa.table({"id": [None], "name": ["null-row"]}, schema=pa_schema)) + _write_to_paimon(table, pa.table({"id": [3], "name": ["match"]}, schema=pa_schema)) + + batches = asyncio.run( + _read_paimon_source_batches( + table, + filter_expr=~col("id").is_in([1, 2]), + limit=1, + call_push_filters=False, + ) + ) + + assert batches == [{"id": [3], "name": ["match"]}] + + def test_read_paimon_fallback_keeps_limit_above_remaining_filter(local_paimon_catalog): """Fallback reads must not apply limit before Daft evaluates remaining filters.""" catalog, _ = local_paimon_catalog @@ -674,6 +878,108 @@ def test_filter_pushdown_combined(self, filter_table): result = df.to_pydict() assert result["id"] == [2, 3, 4] + def test_filter_pushdown_splits_supported_conjuncts(self, filter_table): + unsupported = col("value").contains(col("id")) + cases = [ + ((col("id") == 1) & unsupported, [("equal", "id", (1,))]), + (unsupported & (col("id") == 1), [("equal", "id", (1,))]), + ( + (col("id") == 1) & (unsupported & (col("value") == "a")), + [("equal", "id", (1,)), ("equal", "value", ("a",))], + ), + ] + + for expr, expected_leaves in cases: + pushed_filters, remaining_filters, predicate = convert_filters_to_paimon(filter_table, expr._expr) + + assert len(pushed_filters) == len(expected_leaves) + assert len(remaining_filters) == 1 + assert _contains_expr(remaining_filters[0]) + assert _predicate_leaves(predicate) == expected_leaves + + def test_filter_pushdown_does_not_split_or_with_unsupported_branch(self, filter_table): + expr = (col("id") == 1) | col("value").contains(col("id")) + + pushed_filters, remaining_filters, predicate = convert_filters_to_paimon(filter_table, expr._expr) + + assert pushed_filters == [] + assert remaining_filters == [expr._expr] + assert predicate is None + + def 
test_filter_pushdown_pushes_supported_or_conjunct(self, filter_table): + supported_or = (col("id") == 1) | (col("id") == 2) + expr = supported_or & col("value").contains(col("id")) + + pushed_filters, remaining_filters, predicate = convert_filters_to_paimon(filter_table, expr._expr) + + assert len(pushed_filters) == 1 + assert len(remaining_filters) == 1 + assert _contains_expr(remaining_filters[0]) + assert predicate is not None + assert predicate.method == "or" + assert _predicate_leaves(predicate) == [("equal", "id", (1,)), ("equal", "id", (2,))] + + def test_filter_pushdown_rewrites_supported_not_predicates(self, filter_table): + cases = [ + (~(col("id") == 1), [("notEqual", "id", (1,))]), + (~col("id").is_in([1, 2]), [("notIn", "id", (1, 2))]), + (~col("id").between(1, 3), [("notBetween", "id", (1, 3))]), + (~col("value").is_null(), [("isNotNull", "value", ())]), + (~col("value").not_null(), [("isNull", "value", ())]), + ] + + for expr, expected_leaves in cases: + pushed_filters, remaining_filters, predicate = convert_filters_to_paimon(filter_table, expr._expr) + + assert len(pushed_filters) == 1 + assert remaining_filters == [] + assert _predicate_leaves(predicate) == expected_leaves + + def test_filter_pushdown_supported_not_predicates_read_path(self, local_paimon_catalog): + catalog, tmp_path = local_paimon_catalog + pa_schema = pa.schema([ + ("id", pa.int64()), + ("value", pa.string()), + ]) + paimon_schema = pypaimon.Schema.from_pyarrow_schema(pa_schema) + catalog.create_table("test_db.filter_not_read", paimon_schema, ignore_if_exists=True) + table = catalog.get_table("test_db.filter_not_read") + + data = pa.table( + { + "id": [1, 2, 3, 4, 5], + "value": ["a", "b", None, "d", None], + }, + schema=pa_schema, + ) + _write_to_paimon(table, data) + + cases = [ + (~(col("id") == 1), [2, 3, 4, 5]), + (~col("id").is_in([1, 3]), [2, 4, 5]), + (~col("id").between(2, 4), [1, 5]), + (~col("value").is_null(), [1, 2, 4]), + (~col("value").not_null(), [3, 5]), + ] + + 
for expr, expected_ids in cases: + result = _read_table(table).where(expr).select("id").sort("id").to_pydict() + + assert result["id"] == expected_ids + + def test_filter_pushdown_does_not_demorgan_not_compound_predicates(self, filter_table): + expressions = [ + ~((col("id") == 1) & (col("value") == "a")), + ~((col("id") == 1) | (col("value") == "a")), + ] + + for expr in expressions: + pushed_filters, remaining_filters, predicate = convert_filters_to_paimon(filter_table, expr._expr) + + assert pushed_filters == [] + assert remaining_filters == [expr._expr] + assert predicate is None + def test_unsupported_expression_remains_in_daft(self, filter_table): expressions = [ col("id") == lit(1).cast("int64"), @@ -718,6 +1024,32 @@ def test_mixed_string_expression_is_filtered_by_daft(self, local_paimon_catalog) assert result["id"] == [1, 3] + def test_mixed_conjunctive_expression_is_filtered_by_daft(self, local_paimon_catalog): + catalog, tmp_path = local_paimon_catalog + pa_schema = pa.schema([ + ("id", pa.int64()), + ("value", pa.string()), + ("pattern", pa.string()), + ]) + paimon_schema = pypaimon.Schema.from_pyarrow_schema(pa_schema) + catalog.create_table("test_db.filter_mixed_conjunctive", paimon_schema, ignore_if_exists=True) + table = catalog.get_table("test_db.filter_mixed_conjunctive") + + data = pa.table( + { + "id": [1, 1, 2, 3], + "value": ["alpha", "bravo", "alps", "charlie"], + "pattern": ["lp", "zz", "lp", "lie"], + } + ) + _write_to_paimon(table, data) + + df = _read_table(table).where((col("id") == 1) & col("value").contains(col("pattern"))) + result = df.sort("value").to_pydict() + + assert result["id"] == [1] + assert result["value"] == ["alpha"] + # --------------------------------------------------------------------------- # Advanced data types diff --git a/paimon-python/pypaimon/tests/daft/daft_explain_test.py b/paimon-python/pypaimon/tests/daft/daft_explain_test.py new file mode 100644 index 000000000000..0846f591214c --- /dev/null +++ 
b/paimon-python/pypaimon/tests/daft/daft_explain_test.py @@ -0,0 +1,420 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Tests for Daft-side Paimon scan explain diagnostics.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +pypaimon = pytest.importorskip("pypaimon") +daft = pytest.importorskip("daft") + +from daft import col + +from pypaimon.daft import explain_paimon_scan +from pypaimon.daft.daft_catalog import PaimonTable +from pypaimon.daft.daft_explain import ( + READER_MODE_NATIVE_PARQUET, + READER_MODE_PYPAIMON_FALLBACK, +) +from pypaimon.daft.daft_compat import has_file_range_reads +from pypaimon.daft.daft_datasource import PaimonDataSource +from pypaimon.daft.daft_paimon import _explain_table +from pypaimon.read.explain import ExplainResult, ExplainSplitInfo + + +requires_blob = pytest.mark.skipif(not has_file_range_reads(), reason="BLOB support requires daft >= 0.7.11") + + +@pytest.fixture +def catalog_options(tmp_path): + options = {"warehouse": str(tmp_path)} + catalog = pypaimon.CatalogFactory.create(options) + 
catalog.create_database("test_db", ignore_if_exists=True) + return options + + +def _create_table( + catalog_options, + table_name: str, + pa_schema: pa.Schema, + *, + partition_keys: list[str] | None = None, + primary_keys: list[str] | None = None, + options: dict[str, str] | None = None, +): + identifier = f"test_db.{table_name}" + catalog = pypaimon.CatalogFactory.create(catalog_options) + schema = pypaimon.Schema.from_pyarrow_schema( + pa_schema, + partition_keys=partition_keys, + primary_keys=primary_keys, + options=options, + ) + catalog.create_table(identifier, schema, ignore_if_exists=False) + return identifier, catalog.get_table(identifier) + + +def _write_arrow(table, data: pa.Table) -> None: + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + try: + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + finally: + table_write.close() + table_commit.close() + + +def _single_split_explain( + *, + table_identifier: str, + raw_convertible: bool, + has_deletion_vectors: bool, +) -> ExplainResult: + split = ExplainSplitInfo( + partition={}, + bucket=0, + file_count=1, + row_count=4, + merged_row_count=None, + file_size=128, + raw_convertible=raw_convertible, + has_deletion_vectors=has_deletion_vectors, + level_histogram={0: 1}, + deletion_file_count=1 if has_deletion_vectors else 0, + file_paths=["/tmp/fake.parquet"], + ) + return ExplainResult( + table_identifier=table_identifier, + is_primary_key_table=False, + bucket_mode="unaware", + deletion_vectors_enabled=has_deletion_vectors, + data_evolution_enabled=False, + snapshot_id=1, + schema_id=0, + file_count=1, + total_file_size=split.file_size, + estimated_row_count=split.row_count, + deletion_file_count=split.deletion_file_count, + level_histogram=split.level_histogram, + split_count=1, + splits_raw_convertible=1 if raw_convertible else 0, + splits_with_deletion_vectors=1 if 
has_deletion_vectors else 0, + files_per_split_min=1, + files_per_split_max=1, + files_per_split_avg=1.0, + split_size_min=split.file_size, + split_size_max=split.file_size, + split_size_avg=float(split.file_size), + split_size_p50=split.file_size, + split_size_p95=split.file_size, + splits=[split], + ) + + +def test_explain_paimon_scan_reports_native_parquet_routing(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + identifier, table = _create_table( + catalog_options, + "explain_native", + pa_schema, + options={"bucket": "-1", "file.format": "parquet"}, + ) + _write_arrow(table, pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}, schema=pa_schema)) + + result = explain_paimon_scan( + identifier, + catalog_options, + filters=col("id") == 2, + columns=["name"], + limit=1, + verbose=True, + ) + + assert result.native_parquet_split_count == result.paimon_scan.split_count + assert result.native_parquet_split_count > 0 + assert result.pypaimon_fallback_split_count == 0 + assert result.fallback_reasons == {} + assert result.requested_columns == ["name"] + assert result.requested_limit == 1 + assert result.source_limit == 1 + assert result.limit_pushed is True + assert any("id" in pushed for pushed in result.pushed_filters) + assert result.remaining_filters == [] + assert result.splits is not None + assert all(split.reader_mode == READER_MODE_NATIVE_PARQUET for split in result.splits) + assert "Daft Paimon Scan" in str(result) + assert "PyPaimon Scan Plan" in str(result) + + +def test_explain_scan_keeps_limit_above_remaining_filters(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + identifier, table = _create_table( + catalog_options, + "explain_remaining_filter", + pa_schema, + options={"bucket": "-1", "file.format": "parquet"}, + ) + _write_arrow(table, pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=pa_schema)) + + result = PaimonTable(table, 
catalog_options=catalog_options).explain_scan( + filters=~((col("id") == 1) & (col("name") == "a")), + limit=1, + ) + + assert result.native_parquet_split_count == result.paimon_scan.split_count + assert result.pypaimon_fallback_split_count == 0 + assert result.pushed_filters == [] + assert any("id" in remaining for remaining in result.remaining_filters) + assert result.source_limit is None + assert result.limit_pushed is False + assert result.splits is None + assert result.paimon_scan.splits is None + + +def test_explain_scan_pushes_supported_not_and_limit(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + identifier, table = _create_table( + catalog_options, + "explain_not_filter", + pa_schema, + options={"bucket": "-1", "file.format": "parquet"}, + ) + _write_arrow(table, pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=pa_schema)) + + result = PaimonTable(table, catalog_options=catalog_options).explain_scan( + filters=~(col("id") == 1), + limit=1, + ) + + assert result.native_parquet_split_count == result.paimon_scan.split_count + assert result.pypaimon_fallback_split_count == 0 + assert any("!=" in pushed or "not" in pushed for pushed in result.pushed_filters) + assert result.remaining_filters == [] + assert result.source_limit == 1 + assert result.limit_pushed is True + assert result.splits is None + assert result.paimon_scan.splits is None + + +def test_explain_scan_partially_pushes_conjuncts_and_keeps_limit(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + identifier, table = _create_table( + catalog_options, + "explain_partial_conjunct", + pa_schema, + options={"bucket": "-1", "file.format": "parquet"}, + ) + _write_arrow(table, pa.table({"id": [1, 2], "name": ["alpha", "bravo"]}, schema=pa_schema)) + + result = PaimonTable(table, catalog_options=catalog_options).explain_scan( + filters=(col("id") == 1) & col("name").contains(col("id")), + limit=1, + ) + + 
assert any("id" in pushed for pushed in result.pushed_filters) + assert any("contains" in remaining for remaining in result.remaining_filters) + assert result.source_limit is None + assert result.limit_pushed is False + + +def test_explain_scan_applies_partition_filters_to_reader_counts(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ("dt", pa.string()), + ]) + identifier, table = _create_table( + catalog_options, + "explain_partition_filter", + pa_schema, + partition_keys=["dt"], + options={"bucket": "1", "file.format": "parquet"}, + ) + _write_arrow( + table, + pa.table({"id": [1], "name": ["a"], "dt": ["2024-01-01"]}, schema=pa_schema), + ) + _write_arrow( + table, + pa.table({"id": [2], "name": ["b"], "dt": ["2024-01-02"]}, schema=pa_schema), + ) + + result = explain_paimon_scan( + identifier, + catalog_options, + partition_filters=col("dt") == "2024-01-02", + verbose=True, + ) + + assert result.paimon_scan.split_count == 2 + assert result.native_parquet_split_count == 1 + assert result.pypaimon_fallback_split_count == 0 + assert any("dt" in partition_filter for partition_filter in result.partition_filters) + assert result.splits is not None + assert len(result.splits) == 1 + assert result.splits[0].partition == {"dt": "2024-01-02"} + + +def test_explain_scan_reports_pk_lsm_fallback(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ("dt", pa.string()), + ]) + _, table = _create_table( + catalog_options, + "explain_pk_fallback", + pa_schema, + partition_keys=["dt"], + primary_keys=["id", "dt"], + options={"bucket": "1", "file.format": "parquet"}, + ) + _write_arrow( + table, + pa.table({"id": [1, 2], "name": ["old-a", "old-b"], "dt": ["2024-01-01", "2024-01-01"]}, schema=pa_schema), + ) + _write_arrow( + table, + pa.table({"id": [1], "name": ["new-a"], "dt": ["2024-01-01"]}, schema=pa_schema), + ) + + result = _explain_table( + table, + catalog_options=catalog_options, + 
filters=col("id") == 1, + columns=["name"], + limit=1, + verbose=True, + ) + + assert result.pypaimon_fallback_split_count > 0 + assert result.native_parquet_split_count == 0 + assert result.fallback_reasons["LSM merge required"] == result.pypaimon_fallback_split_count + assert result.fallback_read_columns is not None + assert "name" in result.fallback_read_columns + assert "id" in result.fallback_read_columns + assert result.splits is not None + assert all(split.reader_mode == READER_MODE_PYPAIMON_FALLBACK for split in result.splits) + assert all(split.fallback_reason == "LSM merge required" for split in result.splits) + + +def test_explain_scan_reports_non_parquet_fallback(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + _, table = _create_table( + catalog_options, + "explain_avro_fallback", + pa_schema, + options={"bucket": "-1", "file.format": "avro"}, + ) + _write_arrow(table, pa.table({"id": [1], "name": ["a"]}, schema=pa_schema)) + + result = _explain_table(table, catalog_options=catalog_options, verbose=True) + + assert result.pypaimon_fallback_split_count == result.paimon_scan.split_count + assert result.pypaimon_fallback_split_count > 0 + assert result.native_parquet_split_count == 0 + assert result.fallback_reasons["non-parquet format"] == result.pypaimon_fallback_split_count + assert result.splits is not None + assert all(split.fallback_reason == "non-parquet format" for split in result.splits) + + +@requires_blob +def test_explain_scan_reports_blob_fallback(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("payload", pa.large_binary()), + ]) + _, table = _create_table( + catalog_options, + "explain_blob_fallback", + pa_schema, + options={ + "bucket": "-1", + "file.format": "parquet", + "row-tracking.enabled": "true", + "data-evolution.enabled": "true", + }, + ) + _write_arrow(table, pa.table({"id": [1], "payload": [b"hello"]}, schema=pa_schema)) + + result = _explain_table(table, 
catalog_options=catalog_options, verbose=True) + + assert result.pypaimon_fallback_split_count == result.paimon_scan.split_count + assert result.pypaimon_fallback_split_count > 0 + assert result.native_parquet_split_count == 0 + assert result.fallback_reasons["blob columns present"] == result.pypaimon_fallback_split_count + assert result.splits is not None + assert all(split.reader_mode == READER_MODE_PYPAIMON_FALLBACK for split in result.splits) + assert all(split.fallback_reason == "blob columns present" for split in result.splits) + + +def test_explain_scan_reports_deletion_vector_fallback(catalog_options, monkeypatch): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + _, table = _create_table( + catalog_options, + "explain_deletion_vector_fallback", + pa_schema, + options={"bucket": "-1", "file.format": "parquet"}, + ) + + class FakeReadBuilder: + def explain(self, verbose: bool = False) -> ExplainResult: + assert verbose is True + return _single_split_explain( + table_identifier="test_db.explain_deletion_vector_fallback", + raw_convertible=True, + has_deletion_vectors=True, + ) + + def fake_scan_read_builder(self, table, read_pushdowns): + return FakeReadBuilder() + + monkeypatch.setattr(PaimonDataSource, "_scan_read_builder", fake_scan_read_builder) + + result = _explain_table(table, catalog_options=catalog_options, verbose=True) + + assert result.pypaimon_fallback_split_count == 1 + assert result.native_parquet_split_count == 0 + assert result.fallback_reasons == {"deletion vectors present": 1} + assert result.splits is not None + assert len(result.splits) == 1 + assert result.splits[0].reader_mode == READER_MODE_PYPAIMON_FALLBACK + assert result.splits[0].fallback_reason == "deletion vectors present" diff --git a/paimon-python/pypaimon/tests/daft/daft_integration_test.py b/paimon-python/pypaimon/tests/daft/daft_integration_test.py new file mode 100644 index 000000000000..3c5709739d83 --- /dev/null +++ 
b/paimon-python/pypaimon/tests/daft/daft_integration_test.py @@ -0,0 +1,236 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Tests for the public pypaimon.daft read_paimon() / write_paimon() API.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +pypaimon = pytest.importorskip("pypaimon") +daft = pytest.importorskip("daft") + +from daft import col + +from pypaimon.daft import read_paimon, write_paimon + + +@pytest.fixture +def catalog_options(tmp_path): + options = {"warehouse": str(tmp_path)} + catalog = pypaimon.CatalogFactory.create(options) + catalog.create_database("test_db", ignore_if_exists=True) + return options + + +def _create_table( + catalog_options, + table_name: str, + pa_schema: pa.Schema, + *, + partition_keys: list[str] | None = None, + options: dict[str, str] | None = None, +): + identifier = f"test_db.{table_name}" + catalog = pypaimon.CatalogFactory.create(catalog_options) + schema = pypaimon.Schema.from_pyarrow_schema( + pa_schema, + partition_keys=partition_keys, + options=options, + ) + 
catalog.create_table(identifier, schema, ignore_if_exists=False) + return identifier, catalog.get_table(identifier) + + +def _write_arrow(table, data: pa.Table) -> None: + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + try: + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + finally: + table_write.close() + table_commit.close() + + +def _create_and_populate_table( + catalog_options, + table_name: str, + data: pa.Table, + *, + partition_keys: list[str] | None = None, + options: dict[str, str] | None = None, +) -> str: + identifier, table = _create_table( + catalog_options, + table_name, + data.schema, + partition_keys=partition_keys, + options=options, + ) + _write_arrow(table, data) + return identifier + + +def test_read_paimon_basic(catalog_options): + data = pa.table( + { + "id": pa.array([1, 2, 3], pa.int64()), + "name": pa.array(["alice", "bob", "carol"], pa.string()), + "value": pa.array([10, 20, 30], pa.int64()), + } + ) + identifier = _create_and_populate_table(catalog_options, "read_basic", data) + + result = read_paimon(identifier, catalog_options).sort("id").to_pydict() + + assert result == { + "id": [1, 2, 3], + "name": ["alice", "bob", "carol"], + "value": [10, 20, 30], + } + + +def test_read_paimon_projection(catalog_options): + data = pa.table( + { + "id": pa.array([1, 2], pa.int64()), + "name": pa.array(["alice", "bob"], pa.string()), + "value": pa.array([10, 20], pa.int64()), + } + ) + identifier = _create_and_populate_table(catalog_options, "read_projection", data) + + result = read_paimon(identifier, catalog_options).select("id", "name").sort("id").to_pydict() + + assert result == { + "id": [1, 2], + "name": ["alice", "bob"], + } + + +def test_read_paimon_filter(catalog_options): + data = pa.table( + { + "id": pa.array([1, 2, 3, 4], pa.int64()), + "category": pa.array(["A", "B", "A", "C"], pa.string()), + "amount": 
pa.array([100, 200, 150, 300], pa.int64()), + } + ) + identifier = _create_and_populate_table(catalog_options, "read_filter", data) + + result = ( + read_paimon(identifier, catalog_options) + .where((col("category") == "A") & (col("amount") >= 120)) + .sort("id") + .to_pydict() + ) + + assert result == { + "id": [3], + "category": ["A"], + "amount": [150], + } + + +def test_read_paimon_limit(catalog_options): + data = pa.table( + { + "id": pa.array(list(range(10)), pa.int64()), + "name": pa.array([f"name-{i}" for i in range(10)], pa.string()), + } + ) + identifier = _create_and_populate_table(catalog_options, "read_limit", data) + + result = read_paimon(identifier, catalog_options).limit(3).to_pydict() + + assert len(result["id"]) == 3 + + +def test_read_paimon_with_snapshot_id(catalog_options): + pa_schema = pa.schema([("id", pa.int64()), ("name", pa.string())]) + identifier, table = _create_table(catalog_options, "read_snapshot_id", pa_schema) + _write_arrow(table, pa.table({"id": [1], "name": ["first"]}, schema=pa_schema)) + _write_arrow(table, pa.table({"id": [2], "name": ["second"]}, schema=pa_schema)) + + latest = read_paimon(identifier, catalog_options).sort("id").to_pydict() + snap1 = read_paimon(identifier, catalog_options, snapshot_id=1).to_pydict() + + assert latest["id"] == [1, 2] + assert snap1 == {"id": [1], "name": ["first"]} + + +def test_read_paimon_with_tag_name(catalog_options): + pa_schema = pa.schema([("id", pa.int64()), ("name", pa.string())]) + identifier, table = _create_table(catalog_options, "read_tag_name", pa_schema) + _write_arrow(table, pa.table({"id": [1], "name": ["tagged"]}, schema=pa_schema)) + table.create_tag("v1") + _write_arrow(table, pa.table({"id": [2], "name": ["latest"]}, schema=pa_schema)) + + result = read_paimon(identifier, catalog_options, tag_name="v1").to_pydict() + + assert result == {"id": [1], "name": ["tagged"]} + + +def test_read_paimon_rejects_snapshot_id_and_tag_name_together(catalog_options): + with 
pytest.raises(ValueError, match="snapshot_id and tag_name cannot be set at the same time"): + read_paimon( + "test_db.dummy", + catalog_options, + snapshot_id=1, + tag_name="v1", + ) + + +def test_write_paimon_append(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + identifier, _ = _create_table(catalog_options, "write_append", pa_schema) + df = daft.from_pydict({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + + write_paimon(df, identifier, catalog_options) + + result = read_paimon(identifier, catalog_options).sort("id").to_pydict() + assert result == {"id": [1, 2, 3], "name": ["a", "b", "c"]} + + +def test_write_paimon_overwrite(catalog_options): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("name", pa.string()), + ]) + identifier, _ = _create_table(catalog_options, "write_overwrite", pa_schema) + write_paimon( + daft.from_pydict({"id": [1, 2], "name": ["old-a", "old-b"]}), + identifier, + catalog_options, + ) + + write_paimon( + daft.from_pydict({"id": [3], "name": ["new"]}), + identifier, + catalog_options, + mode="overwrite", + ) + + result = read_paimon(identifier, catalog_options).to_pydict() + assert result == {"id": [3], "name": ["new"]} diff --git a/paimon-python/pypaimon/tests/daft/daft_sink_test.py b/paimon-python/pypaimon/tests/daft/daft_sink_test.py index 3cca5487d97f..ad963de4447e 100644 --- a/paimon-python/pypaimon/tests/daft/daft_sink_test.py +++ b/paimon-python/pypaimon/tests/daft/daft_sink_test.py @@ -303,6 +303,28 @@ def test_write_paimon_invalid_mode(append_only_table): _write_table(df, table, mode="upsert") +def test_write_paimon_sink_serializes_without_file_io(append_only_table): + """PaimonDataSink should not pickle table FileIO objects.""" + from daft.pickle import dumps, loads + + class Unpicklable: + def __reduce__(self): + raise TypeError("file io marker should not be serialized") + + table, _ = append_only_table + table.file_io._unpicklable_marker = Unpicklable() + sink = 
PaimonDataSink(table, mode="overwrite") + commit_user = sink._write_builder.commit_user + + restored = loads(dumps(sink)) + + assert restored.name() == sink.name() + assert restored._mode == "overwrite" + assert restored._write_builder.commit_user == commit_user + assert restored._write_builder.static_partition == {} + assert restored._table.identifier.get_full_name() == table.identifier.get_full_name() + + def test_write_paimon_rejects_extra_columns(local_paimon_catalog): """Extra input columns should fail instead of being silently dropped.""" catalog, _ = local_paimon_catalog diff --git a/paimon-python/pypaimon/tests/data_evolution_formats_test.py b/paimon-python/pypaimon/tests/data_evolution_formats_test.py new file mode 100644 index 000000000000..f89f3290f199 --- /dev/null +++ b/paimon-python/pypaimon/tests/data_evolution_formats_test.py @@ -0,0 +1,1063 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Data evolution tests covering parquet + blob + vector (vortex) formats. + +Each test writes data using different file format combinations and reads it +back, verifying correctness of the data evolution merge path across formats. 
+""" + +import os +import shutil +import sys +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.manifest.schema.data_file_meta import DataFileMeta + + +class DataEvolutionFormatsTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + # ------------------------------------------------------------------ + # Parquet-format data evolution + # ------------------------------------------------------------------ + + def test_parquet_column_subset_write_and_merge_read(self): + """Write disjoint column subsets as parquet, merge-read via data evolution.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('score', pa.float64()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'parquet', + }) + self.catalog.create_table('default.fmt_parquet_subset', schema, False) + table = self.catalog.get_table('default.fmt_parquet_subset') + wb = table.new_batch_write_builder() + + # commit 1: write id + name + w0 = wb.new_write().with_write_type(['id', 'name']) + w1 = wb.new_write().with_write_type(['score']) + c = wb.new_commit() + w0.write_arrow(pa.Table.from_pydict( + {'id': [1, 2, 3], 'name': ['a', 'b', 'c']}, + schema=pa.schema([('id', pa.int32()), ('name', pa.string())]))) + w1.write_arrow(pa.Table.from_pydict( + {'score': [1.1, 2.2, 3.3]}, + schema=pa.schema([('score', pa.float64())]))) + cmts = w0.prepare_commit() + w1.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + c.commit(cmts) + w0.close() + w1.close() + c.close() + + # verify file format + 
all_files = [nf for m in cmts for nf in m.new_files] + for f in all_files: + self.assertTrue(f.file_name.endswith('.parquet'), + f"Expected parquet file, got {f.file_name}") + + # read back + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + expect = pa.Table.from_pydict( + {'id': [1, 2, 3], 'name': ['a', 'b', 'c'], 'score': [1.1, 2.2, 3.3]}, + schema=pa_schema) + self.assertEqual(actual, expect) + + def test_parquet_overwrite_column(self): + """Write all columns, then overwrite one column via a second commit.""" + pa_schema = pa.schema([ + ('k', pa.int64()), + ('v', pa.string()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'parquet', + }) + self.catalog.create_table('default.fmt_parquet_overwrite', schema, False) + table = self.catalog.get_table('default.fmt_parquet_overwrite') + wb = table.new_batch_write_builder() + + # commit 1: full row + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'k': [10, 20], 'v': ['old1', 'old2']}, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # commit 2: overwrite v only (first_row_id=0) + tw = wb.new_write().with_write_type(['v']) + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'v': ['new1', 'new2']}, schema=pa.schema([('v', pa.string())]))) + cmts = tw.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + tc.commit(cmts) + tw.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + expect = pa.Table.from_pydict( + {'k': [10, 20], 'v': ['new1', 'new2']}, schema=pa_schema) + self.assertEqual(actual, expect) + + def test_parquet_append_new_rows(self): + """Append new rows (new first_row_id) with column subsets, merge-read all.""" + pa_schema = pa.schema([ + ('a', pa.int32()), + ('b', pa.string()), + 
('c', pa.float32()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'parquet', + }) + self.catalog.create_table('default.fmt_parquet_append', schema, False) + table = self.catalog.get_table('default.fmt_parquet_append') + wb = table.new_batch_write_builder() + + # commit 1: 2 full rows + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'a': [1, 2], 'b': ['x', 'y'], 'c': [0.1, 0.2]}, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # commit 2: append 2 new rows with column subsets, first_row_id=2 + w_ab = wb.new_write().with_write_type(['a', 'b']) + w_c = wb.new_write().with_write_type(['c']) + tc = wb.new_commit() + w_ab.write_arrow(pa.Table.from_pydict( + {'a': [3, 4], 'b': ['z', 'w']}, + schema=pa.schema([('a', pa.int32()), ('b', pa.string())]))) + w_c.write_arrow(pa.Table.from_pydict( + {'c': [0.3, 0.4]}, + schema=pa.schema([('c', pa.float32())]))) + cmts = w_ab.prepare_commit() + w_c.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 2 + tc.commit(cmts) + w_ab.close() + w_c.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 4) + expect = pa.Table.from_pydict( + {'a': [1, 2, 3, 4], 'b': ['x', 'y', 'z', 'w'], + 'c': [0.1, 0.2, 0.3, 0.4]}, + schema=pa_schema) + self.assertEqual(actual, expect) + + # ------------------------------------------------------------------ + # Blob-format data evolution + # ------------------------------------------------------------------ + + def test_blob_write_and_read(self): + """Write a table with normal + blob columns, read back and verify.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('payload', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 
'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_blob_basic', schema, False) + table = self.catalog.get_table('default.fmt_blob_basic') + wb = table.new_batch_write_builder() + + blobs = [b'hello world', b'\x00\x01\x02\xff', b'paimon blob'] + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': [1, 2, 3], 'payload': blobs}, schema=pa_schema)) + cmts = tw.prepare_commit() + tc.commit(cmts) + tw.close() + tc.close() + + # verify we produced both parquet and blob files + all_files = [nf for m in cmts for nf in m.new_files] + parquet_files = [f for f in all_files if f.file_name.endswith('.parquet')] + blob_files = [f for f in all_files if f.file_name.endswith('.blob')] + self.assertGreater(len(parquet_files), 0) + self.assertGreater(len(blob_files), 0) + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3]) + self.assertEqual(actual.column('payload').to_pylist(), blobs) + + def test_blob_column_subset_evolution(self): + """Write normal+blob cols in one commit, overwrite normal col in another, merge-read.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('doc', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_blob_evolution', schema, False) + table = self.catalog.get_table('default.fmt_blob_evolution') + wb = table.new_batch_write_builder() + + # commit 1: write id + doc (normal + blob together) + tw = wb.new_write().with_write_type(['id', 'doc']) + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': [1, 2], 'doc': [b'doc_alice', b'doc_bob']}, + schema=pa.schema([('id', pa.int32()), ('doc', pa.large_binary())]))) + tc.commit(tw.prepare_commit()) + tw.close() + 
tc.close() + + # commit 2: write name for the same rows (first_row_id=0) + tw = wb.new_write().with_write_type(['name']) + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'name': ['Alice', 'Bob']}, + schema=pa.schema([('name', pa.string())]))) + cmts = tw.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + tc.commit(cmts) + tw.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 2) + self.assertEqual(actual.column('id').to_pylist(), [1, 2]) + self.assertEqual(actual.column('name').to_pylist(), ['Alice', 'Bob']) + self.assertEqual(actual.column('doc').to_pylist(), [b'doc_alice', b'doc_bob']) + + def test_blob_append_with_subset_evolution(self): + """Write normal+blob subset in first commit, add remaining col via evolution.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('tag', pa.string()), + ('picture', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_blob_append_evo', schema, False) + table = self.catalog.get_table('default.fmt_blob_append_evo') + wb = table.new_batch_write_builder() + + # commit 1: id + picture (normal + blob) + tw = wb.new_write().with_write_type(['id', 'picture']) + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': [1, 2], 'picture': [b'pic1', b'pic2']}, + schema=pa.schema([('id', pa.int32()), ('picture', pa.large_binary())]))) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # commit 2: add tag for the same rows + tw = wb.new_write().with_write_type(['tag']) + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'tag': ['t1', 't2']}, + schema=pa.schema([('tag', pa.string())]))) + cmts = tw.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + tc.commit(cmts) + tw.close() + 
tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 2) + self.assertEqual(actual.column('id').to_pylist(), [1, 2]) + self.assertEqual(actual.column('tag').to_pylist(), ['t1', 't2']) + self.assertEqual(actual.column('picture').to_pylist(), [b'pic1', b'pic2']) + + def test_blob_multiple_blob_columns(self): + """Table with two blob columns, write and read both.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('audio', pa.large_binary()), + ('video', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_blob_multi', schema, False) + table = self.catalog.get_table('default.fmt_blob_multi') + wb = table.new_batch_write_builder() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict({ + 'id': [1, 2], + 'audio': [b'audio_1', b'audio_2'], + 'video': [b'video_1', b'video_2'], + }, schema=pa_schema)) + cmts = tw.prepare_commit() + tc.commit(cmts) + tw.close() + tc.close() + + # verify blob files were produced + all_files = [nf for m in cmts for nf in m.new_files] + blob_files = [f for f in all_files if f.file_name.endswith('.blob')] + self.assertGreaterEqual(len(blob_files), 2) + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 2) + self.assertEqual(actual.column('audio').to_pylist(), [b'audio_1', b'audio_2']) + self.assertEqual(actual.column('video').to_pylist(), [b'video_1', b'video_2']) + + # ------------------------------------------------------------------ + # Vortex-format data evolution + # ------------------------------------------------------------------ + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') 
is not None, + "vortex not installed") + def test_vortex_column_subset_write_and_merge_read(self): + """Write disjoint column subsets as vortex, merge-read via data evolution.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('tag', pa.string()), + ('val', pa.float64()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vortex_subset', schema, False) + table = self.catalog.get_table('default.fmt_vortex_subset') + wb = table.new_batch_write_builder() + + w0 = wb.new_write().with_write_type(['id', 'tag']) + w1 = wb.new_write().with_write_type(['val']) + c = wb.new_commit() + w0.write_arrow(pa.Table.from_pydict( + {'id': [10, 20, 30], 'tag': ['p', 'q', 'r']}, + schema=pa.schema([('id', pa.int32()), ('tag', pa.string())]))) + w1.write_arrow(pa.Table.from_pydict( + {'val': [1.5, 2.5, 3.5]}, + schema=pa.schema([('val', pa.float64())]))) + cmts = w0.prepare_commit() + w1.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + c.commit(cmts) + w0.close() + w1.close() + c.close() + + # verify vortex files + all_files = [nf for m in cmts for nf in m.new_files] + for f in all_files: + self.assertTrue(f.file_name.endswith('.vortex'), + f"Expected vortex file, got {f.file_name}") + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + expect = pa.Table.from_pydict( + {'id': [10, 20, 30], 'tag': ['p', 'q', 'r'], 'val': [1.5, 2.5, 3.5]}, + schema=pa_schema) + self.assertEqual(actual, expect) + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') is not None, + "vortex not installed") + def test_vortex_overwrite_column(self): + """Full row write then overwrite one column, all in vortex format.""" + pa_schema = pa.schema([ + ('k', pa.int64()), + 
('v', pa.string()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vortex_overwrite', schema, False) + table = self.catalog.get_table('default.fmt_vortex_overwrite') + wb = table.new_batch_write_builder() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'k': [100, 200], 'v': ['old', 'old']}, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + tw = wb.new_write().with_write_type(['v']) + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'v': ['new', 'new']}, schema=pa.schema([('v', pa.string())]))) + cmts = tw.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + tc.commit(cmts) + tw.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual, pa.Table.from_pydict( + {'k': [100, 200], 'v': ['new', 'new']}, schema=pa_schema)) + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') is not None, + "vortex not installed") + def test_vortex_append_new_rows(self): + """Append new rows with column subsets in vortex format.""" + pa_schema = pa.schema([ + ('x', pa.int32()), + ('y', pa.string()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vortex_append', schema, False) + table = self.catalog.get_table('default.fmt_vortex_append') + wb = table.new_batch_write_builder() + + # commit 1 + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'x': [1, 2], 'y': ['a', 'b']}, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + 
tw.close() + tc.close() + + # commit 2: append with subsets, first_row_id=2 + w_x = wb.new_write().with_write_type(['x']) + w_y = wb.new_write().with_write_type(['y']) + tc = wb.new_commit() + w_x.write_arrow(pa.Table.from_pydict( + {'x': [3]}, schema=pa.schema([('x', pa.int32())]))) + w_y.write_arrow(pa.Table.from_pydict( + {'y': ['c']}, schema=pa.schema([('y', pa.string())]))) + cmts = w_x.prepare_commit() + w_y.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 2 + tc.commit(cmts) + w_x.close() + w_y.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + expect = pa.Table.from_pydict( + {'x': [1, 2, 3], 'y': ['a', 'b', 'c']}, schema=pa_schema) + self.assertEqual(actual, expect) + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') is not None, + "vortex not installed") + def test_vortex_with_row_id_and_filter(self): + """Write vortex data, read with _ROW_ID projection and filter.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('val', pa.string()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vortex_rowid_filter', schema, False) + table = self.catalog.get_table('default.fmt_vortex_rowid_filter') + wb = table.new_batch_write_builder() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': list(range(10)), 'val': [f'v{i}' for i in range(10)]}, + schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # full read + rb = table.new_read_builder() + full = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(full.num_rows, 10) + + # filter by _ROW_ID + rb_rid = 
table.new_read_builder().with_projection(['id', 'val', '_ROW_ID']) + pb = rb_rid.new_predicate_builder() + rb_f = table.new_read_builder().with_filter(pb.equal('_ROW_ID', 5)) + actual = rb_f.new_read().to_arrow(rb_f.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 1) + self.assertEqual(actual.column('id')[0].as_py(), 5) + self.assertEqual(actual.column('val')[0].as_py(), 'v5') + + # ------------------------------------------------------------------ + # Vector (vortex) file format for embedding columns + # ------------------------------------------------------------------ + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') is not None, + "vortex not installed") + def test_vector_vortex_write_and_read(self): + """Write table with normal + vector columns using vortex vector format.""" + pa_schema = pa.schema([ + ('id', pa.int64()), + ('embed', pa.list_(pa.float32(), 4)), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + 'vector.file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vec_vortex', schema, False) + table = self.catalog.get_table('default.fmt_vec_vortex') + + embeddings = [1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0] + test_data = pa.table({ + 'id': pa.array([1, 2, 3], type=pa.int64()), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array(embeddings, type=pa.float32()), 4), + }) + + wb = table.new_batch_write_builder() + tw = wb.new_write() + tw.write_arrow(test_data) + cmts = tw.prepare_commit() + + # should produce both normal and vector files + all_files = [nf for m in cmts for nf in m.new_files] + normal_files = [f for f in all_files if not DataFileMeta.is_vector_file(f.file_name)] + vector_files = [f for f in all_files if DataFileMeta.is_vector_file(f.file_name)] + 
self.assertGreater(len(normal_files), 0) + self.assertGreater(len(vector_files), 0) + for vf in vector_files: + self.assertIn('.vector.vortex', vf.file_name) + + wb.new_commit().commit(cmts) + tw.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3]) + embed_col = actual.column('embed') + self.assertEqual(embed_col[0].as_py(), [1.0, 0.0, 0.0, 0.0]) + self.assertEqual(embed_col[1].as_py(), [0.0, 1.0, 0.0, 0.0]) + self.assertEqual(embed_col[2].as_py(), [0.0, 0.0, 1.0, 0.0]) + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') is not None, + "vortex not installed") + def test_vector_vortex_multiple_appends(self): + """Append multiple batches of normal+vector data and read all back.""" + pa_schema = pa.schema([ + ('id', pa.int64()), + ('label', pa.string()), + ('embed', pa.list_(pa.float32(), 3)), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + 'vector.file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vec_vortex_append', schema, False) + table = self.catalog.get_table('default.fmt_vec_vortex_append') + wb = table.new_batch_write_builder() + + # commit 1 + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([1, 2], type=pa.int64()), + 'label': pa.array(['cat', 'dog']), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], type=pa.float32()), 3), + })) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # commit 2: append + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([3], type=pa.int64()), + 'label': pa.array(['bird']), + 'embed': 
pa.FixedSizeListArray.from_arrays( + pa.array([0.7, 0.8, 0.9], type=pa.float32()), 3), + })) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3]) + self.assertEqual(actual.column('label').to_pylist(), ['cat', 'dog', 'bird']) + embed_col = actual.column('embed') + self.assertAlmostEqual(embed_col[0].as_py()[0], 0.1, places=5) + self.assertAlmostEqual(embed_col[2].as_py()[2], 0.9, places=5) + + # ------------------------------------------------------------------ + # Mixed formats: parquet + blob + vector in one table + # ------------------------------------------------------------------ + + def test_parquet_and_blob_mixed_append(self): + """Table with normal parquet cols + blob col, append new rows.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('image', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_mixed_parquet_blob', schema, False) + table = self.catalog.get_table('default.fmt_mixed_parquet_blob') + wb = table.new_batch_write_builder() + + # commit 1: first batch + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict({ + 'id': [1, 2], + 'name': ['a', 'b'], + 'image': [b'img1', b'img2'], + }, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # commit 2: append more rows + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict({ + 'id': [3, 4], + 'name': ['c', 'd'], + 'image': [b'img3', b'img4'], + }, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + 
self.assertEqual(actual.num_rows, 4) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3, 4]) + self.assertEqual(actual.column('name').to_pylist(), ['a', 'b', 'c', 'd']) + self.assertEqual(actual.column('image').to_pylist(), + [b'img1', b'img2', b'img3', b'img4']) + + @unittest.skipIf(sys.version_info < (3, 11), "vortex-data requires Python >= 3.11") + @unittest.skipUnless( + __import__('importlib').util.find_spec('vortex') is not None, + "vortex not installed") + def test_vortex_and_vector_vortex_mixed(self): + """Table with normal (vortex) + vector (vortex) columns, write and read. + + Verifies that the writer produces separate .vortex and .vector.vortex files, + and the data evolution merge reader stitches them back together. + """ + pa_schema = pa.schema([ + ('id', pa.int64()), + ('name', pa.string()), + ('embed', pa.list_(pa.float32(), 3)), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'vortex', + 'vector.file.format': 'vortex', + }) + self.catalog.create_table('default.fmt_vortex_vector', schema, False) + table = self.catalog.get_table('default.fmt_vortex_vector') + wb = table.new_batch_write_builder() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([1, 2, 3], type=pa.int64()), + 'name': pa.array(['cat', 'dog', 'bird']), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + type=pa.float32()), 3), + })) + cmts = tw.prepare_commit() + tc.commit(cmts) + tw.close() + tc.close() + + # verify two file types: .vortex + .vector.vortex + all_files = [nf for m in cmts for nf in m.new_files] + normal_files = [f for f in all_files if not DataFileMeta.is_vector_file(f.file_name)] + vector_files = [f for f in all_files if DataFileMeta.is_vector_file(f.file_name)] + self.assertGreater(len(normal_files), 0, "should produce normal vortex files") + 
self.assertGreater(len(vector_files), 0, "should produce vector files") + for nf in normal_files: + self.assertTrue(nf.file_name.endswith('.vortex')) + for vf in vector_files: + self.assertIn('.vector.vortex', vf.file_name) + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3]) + self.assertEqual(actual.column('name').to_pylist(), ['cat', 'dog', 'bird']) + embed = actual.column('embed') + self.assertAlmostEqual(embed[0].as_py()[0], 0.1, places=5) + self.assertAlmostEqual(embed[2].as_py()[2], 0.9, places=5) + + def test_blob_and_vector_inline_mixed(self): + """Table with normal + blob + vector(inline) columns, write and read. + + When blob columns are present, vector columns are stored inline in the + parquet file (not as separate .vector files). This test verifies the + blob+inline-vector path works correctly. + """ + pa_schema = pa.schema([ + ('id', pa.int64()), + ('doc', pa.large_binary()), + ('embed', pa.list_(pa.float32(), 3)), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_blob_vector_inline', schema, False) + table = self.catalog.get_table('default.fmt_blob_vector_inline') + wb = table.new_batch_write_builder() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([1, 2], type=pa.int64()), + 'doc': pa.array([b'doc1', b'doc2'], type=pa.large_binary()), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], type=pa.float32()), 3), + })) + cmts = tw.prepare_commit() + tc.commit(cmts) + tw.close() + tc.close() + + # verify parquet + blob files + all_files = [nf for m in cmts for nf in m.new_files] + parquet_files = [f for f in all_files if f.file_name.endswith('.parquet')] + blob_files = [f for f in all_files if 
f.file_name.endswith('.blob')] + self.assertGreater(len(parquet_files), 0, "should produce parquet files") + self.assertGreater(len(blob_files), 0, "should produce blob files") + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 2) + self.assertEqual(actual.column('id').to_pylist(), [1, 2]) + self.assertEqual(actual.column('doc').to_pylist(), [b'doc1', b'doc2']) + embed = actual.column('embed') + self.assertAlmostEqual(embed[0].as_py()[0], 0.1, places=5) + self.assertAlmostEqual(embed[1].as_py()[2], 0.6, places=5) + + def test_blob_and_vector_with_vector_file_format(self): + """Table with blob + vector columns and explicit vector.file.format. + + DedicatedFormatWriter splits data three ways: normal columns to .parquet, + blob columns to .blob, and vector columns to .vector. files. + """ + pa_schema = pa.schema([ + ('id', pa.int64()), + ('doc', pa.large_binary()), + ('embed', pa.list_(pa.float32(), 3)), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'vector.file.format': 'parquet', + }) + self.catalog.create_table('default.fmt_blob_vec_format', schema, False) + table = self.catalog.get_table('default.fmt_blob_vec_format') + wb = table.new_batch_write_builder() + + # commit 1: write all columns + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([1, 2, 3], type=pa.int64()), + 'doc': pa.array([b'aaa', b'bbb', b'ccc'], type=pa.large_binary()), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0], type=pa.float32()), 3), + })) + cmts = tw.prepare_commit() + tc.commit(cmts) + tw.close() + tc.close() + + # DedicatedFormatWriter produces parquet + blob + vector files + all_files = [nf for m in cmts for nf in m.new_files] + parquet_files = [f for f in all_files + if f.file_name.endswith('.parquet') + and not 
DataFileMeta.is_vector_file(f.file_name)] + blob_files = [f for f in all_files if f.file_name.endswith('.blob')] + vector_files = [f for f in all_files if DataFileMeta.is_vector_file(f.file_name)] + self.assertGreater(len(parquet_files), 0, "should produce normal parquet files") + self.assertGreater(len(blob_files), 0, "should produce blob files") + self.assertGreater(len(vector_files), 0, "should produce vector files") + for vf in vector_files: + self.assertIn('.vector.parquet', vf.file_name) + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3]) + self.assertEqual(actual.column('doc').to_pylist(), [b'aaa', b'bbb', b'ccc']) + self.assertEqual(actual.column('embed')[0].as_py(), [1.0, 0.0, 0.0]) + self.assertEqual(actual.column('embed')[1].as_py(), [0.0, 1.0, 0.0]) + self.assertEqual(actual.column('embed')[2].as_py(), [0.0, 0.0, 1.0]) + + # commit 2: append more rows + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([4, 5], type=pa.int64()), + 'doc': pa.array([b'ddd', b'eee'], type=pa.large_binary()), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([0.5, 0.5, 0.0, + 0.0, 0.5, 0.5], type=pa.float32()), 3), + })) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + actual2 = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual2.num_rows, 5) + self.assertEqual(actual2.column('id').to_pylist(), [1, 2, 3, 4, 5]) + self.assertEqual(actual2.column('doc').to_pylist(), + [b'aaa', b'bbb', b'ccc', b'ddd', b'eee']) + + def test_blob_vector_partial_write_vector_only(self): + """Blob+vector table with with_write_type(['embed']) — vector-only partial write. + + When normal_column_names is empty, the writer must still flush vector + metadata without crashing on an empty normal data path. 
+ """ + pa_schema = pa.schema([ + ('id', pa.int64()), + ('doc', pa.large_binary()), + ('embed', pa.list_(pa.float32(), 3)), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'vector.file.format': 'parquet', + }) + self.catalog.create_table('default.fmt_blob_vec_partial', schema, False) + table = self.catalog.get_table('default.fmt_blob_vec_partial') + wb = table.new_batch_write_builder() + + # commit 1: write all columns + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'id': pa.array([1, 2, 3], type=pa.int64()), + 'doc': pa.array([b'aaa', b'bbb', b'ccc'], type=pa.large_binary()), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0], type=pa.float32()), 3), + })) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # commit 2: write only vector column — no normal columns + tw = wb.new_write().with_write_type(['embed']) + tc = wb.new_commit() + tw.write_arrow(pa.table({ + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([0.5, 0.5, 0.0, + 0.0, 0.5, 0.5, + 0.5, 0.0, 0.5], type=pa.float32()), 3), + })) + cmts = tw.prepare_commit() + + # should produce only vector files, no normal or blob files + all_files = [nf for m in cmts for nf in m.new_files] + self.assertGreater(len(all_files), 0, "should produce vector files") + for f in all_files: + self.assertTrue(DataFileMeta.is_vector_file(f.file_name), + f"Expected vector file, got {f.file_name}") + + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + tc.commit(cmts) + tw.close() + tc.close() + + # read back and verify the vector column was updated + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 3) + self.assertEqual(actual.column('id').to_pylist(), [1, 2, 3]) + self.assertEqual(actual.column('doc').to_pylist(), [b'aaa', b'bbb', b'ccc']) + embed = 
actual.column('embed') + self.assertEqual(embed[0].as_py(), [0.5, 0.5, 0.0]) + self.assertEqual(embed[1].as_py(), [0.0, 0.5, 0.5]) + self.assertEqual(embed[2].as_py(), [0.5, 0.0, 0.5]) + + # ------------------------------------------------------------------ + # Projection and _ROW_ID across formats + # ------------------------------------------------------------------ + + def test_blob_with_row_id_projection(self): + """Read blob table with _ROW_ID projection.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.fmt_blob_rowid', schema, False) + table = self.catalog.get_table('default.fmt_blob_rowid') + wb = table.new_batch_write_builder() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': [10, 20], 'data': [b'aa', b'bb']}, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + rb.with_projection(['id', 'data', '_ROW_ID']) + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 2) + self.assertEqual(actual.column('_ROW_ID').to_pylist(), [0, 1]) + self.assertEqual(actual.column('id').to_pylist(), [10, 20]) + self.assertEqual(actual.column('data').to_pylist(), [b'aa', b'bb']) + + def test_parquet_large_data_evolution(self): + """Larger dataset: 1000 rows, column-subset write+merge.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('col_a', pa.string()), + ('col_b', pa.float64()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'file.format': 'parquet', + }) + self.catalog.create_table('default.fmt_parquet_large', schema, False) + table = self.catalog.get_table('default.fmt_parquet_large') + wb = table.new_batch_write_builder() + + 
n = 1000 + w0 = wb.new_write().with_write_type(['id', 'col_a']) + w1 = wb.new_write().with_write_type(['col_b']) + c = wb.new_commit() + w0.write_arrow(pa.Table.from_pydict( + {'id': list(range(n)), 'col_a': [f's{i}' for i in range(n)]}, + schema=pa.schema([('id', pa.int32()), ('col_a', pa.string())]))) + w1.write_arrow(pa.Table.from_pydict( + {'col_b': [float(i) for i in range(n)]}, + schema=pa.schema([('col_b', pa.float64())]))) + cmts = w0.prepare_commit() + w1.prepare_commit() + for m in cmts: + for nf in m.new_files: + nf.first_row_id = 0 + c.commit(cmts) + w0.close() + w1.close() + c.close() + + rb = table.new_read_builder() + actual = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, n) + self.assertEqual(actual.column('id').to_pylist(), list(range(n))) + self.assertEqual(actual.column('col_a').to_pylist(), [f's{i}' for i in range(n)]) + self.assertEqual(actual.column('col_b').to_pylist(), [float(i) for i in range(n)]) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/e2e/hdfs/README.md b/paimon-python/pypaimon/tests/e2e/hdfs/README.md new file mode 100644 index 000000000000..47f63701ed12 --- /dev/null +++ b/paimon-python/pypaimon/tests/e2e/hdfs/README.md @@ -0,0 +1,75 @@ + + +# HDFS End-to-End Tests + +Verifies the native HDFS FileIO backend (`HdfsNativeFileIO`) against a live HDFS +cluster. No local Hadoop install or JVM required on the client side. + +## Quick start (Docker) + +```sh +# 1. Bring up a single-NameNode + single-DataNode cluster. +docker compose -f pypaimon/tests/e2e/hdfs/docker-compose.yml up -d + +# Wait ~30s for the cluster to become healthy; check with: +docker compose -f pypaimon/tests/e2e/hdfs/docker-compose.yml ps + +# 2. Install the package with the hdfs extra. +pip install -e '.[hdfs]' + +# 3. Run the tests. +PYPAIMON_HDFS_E2E_URL=hdfs://localhost:8020 \ + python -m pytest pypaimon/tests/e2e/hdfs/ -v + +# 4. Teardown. 
+docker compose -f pypaimon/tests/e2e/hdfs/docker-compose.yml down -v +``` + +## REST-catalog config delivery mode (no local xml) + +The native backend accepts Hadoop key/values directly via catalog options. +Skip the `core-site.xml` / `hdfs-site.xml` dance entirely by configuring the +cluster wiring as options — exactly what a REST catalog would push to the +client in its response. Example: + +```python +catalog = CatalogFactory.create({ + "warehouse": "viewfs://cluster/warehouse", + "hdfs.client.impl": "native", + # Forwarded as-is to the underlying client: + "dfs.nameservices": "ns1,ns2", + "dfs.ha.namenodes.ns1": "nn1,nn2", + "dfs.namenode.rpc-address.ns1.nn1": "host-1:8020", + "dfs.namenode.rpc-address.ns1.nn2": "host-2:8020", + "fs.viewfs.mounttable.cluster.link./prod": "hdfs://ns1/prod", +}) +``` + +Keys matching the prefixes `dfs.` / `fs.` / `hadoop.` / `ipc.` / `io.` are +forwarded automatically. Use the `hdfs.config.` namespace for any other +key you want passed through. + +## Kerberos + +The cluster in `docker-compose.yml` runs without security to keep the +smoke test simple. For a Kerberized e2e: provision a krb5 + HDFS compose +separately, install `libgssapi-krb5-2` (or platform equivalent) on the +client, set `KRB5_CONFIG` and `KRB5CCNAME`, then either run `kinit` +yourself or pass `security.kerberos.login.principal` and `security.kerberos.login.keytab` as +catalog options (pypaimon will run `kinit` for you). diff --git a/paimon-python/pypaimon/tests/e2e/hdfs/__init__.py b/paimon-python/pypaimon/tests/e2e/hdfs/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/paimon-python/pypaimon/tests/e2e/hdfs/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/paimon-python/pypaimon/tests/e2e/hdfs/docker-compose.yml b/paimon-python/pypaimon/tests/e2e/hdfs/docker-compose.yml new file mode 100644 index 000000000000..145708c3e629 --- /dev/null +++ b/paimon-python/pypaimon/tests/e2e/hdfs/docker-compose.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Single-NameNode HDFS cluster for integration testing the native HDFS backend. +# Brings up one NameNode (RPC 8020) and one DataNode (data 9866). 
+# +# Run: +# docker compose -f pypaimon/tests/e2e/hdfs/docker-compose.yml up -d +# PYPAIMON_HDFS_E2E_URL=hdfs://localhost:8020 \ +# python -m pytest pypaimon/tests/e2e/hdfs/hdfs_native_e2e_test.py -v +# docker compose -f pypaimon/tests/e2e/hdfs/docker-compose.yml down -v + +services: + namenode: + image: apache/hadoop:3.3.6 + hostname: namenode + user: root + command: ["/bin/bash", "-c", "hdfs namenode -format -force -nonInteractive || true; hdfs namenode"] + environment: + HADOOP_HOME: /opt/hadoop + CORE-SITE.XML_fs.defaultFS: hdfs://namenode:8020 + HDFS-SITE.XML_dfs.namenode.rpc-address: namenode:8020 + HDFS-SITE.XML_dfs.replication: "1" + HDFS-SITE.XML_dfs.permissions.enabled: "false" + ports: + - "8020:8020" + - "9870:9870" + + datanode: + image: apache/hadoop:3.3.6 + user: root + command: ["hdfs", "datanode"] + environment: + HADOOP_HOME: /opt/hadoop + CORE-SITE.XML_fs.defaultFS: hdfs://namenode:8020 + HDFS-SITE.XML_dfs.datanode.use.datanode.hostname: "false" + HDFS-SITE.XML_dfs.permissions.enabled: "false" + depends_on: + - namenode + ports: + - "9866:9866" diff --git a/paimon-python/pypaimon/tests/e2e/hdfs/hdfs_native_e2e_test.py b/paimon-python/pypaimon/tests/e2e/hdfs/hdfs_native_e2e_test.py new file mode 100644 index 000000000000..5989b815faf7 --- /dev/null +++ b/paimon-python/pypaimon/tests/e2e/hdfs/hdfs_native_e2e_test.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""End-to-end tests for the native HDFS FileIO backend. + +Disabled by default. To run: + 1. Start an HDFS cluster — see docker-compose.yml in this directory. + 2. Install pypaimon with the hdfs extra: + pip install -e '.[hdfs]' + 3. Point the tests at the cluster and run: + PYPAIMON_HDFS_E2E_URL=hdfs://localhost:8020 \\ + python -m pytest pypaimon/tests/e2e/hdfs/hdfs_native_e2e_test.py -v + +To exercise the REST-catalog config-delivery path (no local xml), put the +Hadoop config k/v in catalog options under the `dfs.*` / `fs.*` namespaces +or via `hdfs.config.*` — both are forwarded to the underlying client. +""" + +import os +import unittest +import uuid + +import pandas as pd +import pyarrow as pa + +E2E_URL = os.environ.get("PYPAIMON_HDFS_E2E_URL") +SKIP_REASON = ("PYPAIMON_HDFS_E2E_URL not set; skipping HDFS e2e. " + "See docker-compose.yml in this directory.") + + +@unittest.skipIf(not E2E_URL, SKIP_REASON) +class HdfsNativeE2ETest(unittest.TestCase): + """Smoke-test the native HDFS backend end-to-end against a live cluster.""" + + @classmethod + def setUpClass(cls): + try: + import hdfs_native # noqa: F401 + except ImportError as e: + raise unittest.SkipTest( + "hdfs-native not installed. 
pip install 'pypaimon[hdfs]'" + ) from e + + from pypaimon.catalog.catalog_factory import CatalogFactory + cls.warehouse = ( + f"{E2E_URL}/pypaimon-e2e/warehouse-{uuid.uuid4().hex[:8]}" + ) + cls.catalog = CatalogFactory.create({ + "warehouse": cls.warehouse, + "hdfs.client.impl": "native", + }) + cls.catalog.create_database("default", True) + + def _make_table(self, name, schema): + from pypaimon.schema.schema import Schema + fqn = f"default.{name}" + s = Schema.from_pyarrow_schema( + schema, + options={"file.format": "parquet"}, + ) + self.catalog.create_table(fqn, s, False) + return self.catalog.get_table(fqn) + + def test_write_then_read_back(self): + pa_schema = pa.schema([ + ("id", pa.int64()), + ("payload", pa.string()), + ]) + table = self._make_table(f"t_{uuid.uuid4().hex[:8]}", pa_schema) + + data = pd.DataFrame({ + "id": list(range(100)), + "payload": [f"row-{i}" for i in range(100)], + }) + + writer = table.new_batch_write_builder().new_write() + writer.write_pandas(data) + commit_msgs = writer.prepare_commit() + committer = table.new_batch_write_builder().new_commit() + committer.commit(commit_msgs) + writer.close() + committer.close() + + scan = table.new_read_builder().new_scan() + reader = table.new_read_builder().new_read() + splits = scan.plan().splits() + result = reader.to_arrow(splits) + + self.assertEqual(result.num_rows, 100) + + +if __name__ == "__main__": + unittest.main() diff --git a/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py b/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py index 475d2e74f7c2..0c84a828bca1 100644 --- a/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py +++ b/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py @@ -277,6 +277,55 @@ def test_read_pk_table(self, file_format): # which explicitly reads KeyValue objects and checks valueKind print(f"Format: {file_format}, Python read completed. 
ValueKind verification should be done in Java test.") + def test_py_write_row_append_table(self): + """Python writes a ROW-format append-only table for Java to read.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('value', pa.float64()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={'file.format': 'row', 'bucket': '-1'} + ) + + table_name = 'default.mixed_test_append_tablep_row' + self.catalog.create_table(table_name, schema, False) + table = self.catalog.get_table(table_name) + + data = pa.table({ + 'id': pa.array([1, 2, 3, 4, 5, 6], type=pa.int32()), + 'name': pa.array(['Apple', 'Banana', 'Carrot', 'Broccoli', 'Chicken', 'Beef']), + 'value': pa.array([1.5, 0.8, 0.6, 1.2, 5.0, 8.0]), + }) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Verify Python can read it back + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 6) + expected_names = {'Apple', 'Banana', 'Carrot', 'Broccoli', 'Chicken', 'Beef'} + self.assertEqual(set(result.column('name').to_pylist()), expected_names) + + def test_read_row_append_table(self): + """Python reads a ROW-format append-only table written by Java.""" + table = self.catalog.get_table('default.mixed_test_append_tablej_row') + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 6) + expected_names = {'Apple', 'Banana', 'Carrot', 'Broccoli', 'Chicken', 'Beef'} + self.assertEqual(set(result.column('name').to_pylist()), expected_names) + def test_pk_dv_read(self): pa_schema = pa.schema([ pa.field('pt', pa.int32(), 
nullable=False), @@ -394,7 +443,9 @@ def test_read_btree_index_table(self): self._test_read_btree_index_generic("test_btree_index_bigint", 2000, pa.int64()) self._test_read_btree_index_large() self._test_read_btree_index_null() - self._test_index_manifest_inherited_after_write() + self._test_partial_append_does_not_trigger_index_action() + if sys.version_info[:2] >= (3, 7): + self._test_index_manifest_inherited_after_write() def _test_read_btree_index_generic(self, table_name: str, k, k_type): table = self.catalog.get_table('default.' + table_name) @@ -516,6 +567,72 @@ def _test_index_manifest_inherited_after_write(self): "index_manifest lost after Python data write - indexes become invisible" ) + read_builder = table.new_read_builder() + predicate_builder = read_builder.new_predicate_builder() + read_builder.with_filter(predicate_builder.equal('k', 'k2')) + read_builder.with_projection(['k', '_ROW_ID']) + splits = read_builder.new_scan().plan().splits() + row_ids = read_builder.new_read().to_arrow(splits)['_ROW_ID'].to_pylist() + self.assertTrue(len(row_ids) > 0, "k2 should exist before update") + + wb = table.new_batch_write_builder() + tu = wb.new_update().with_update_type(['k']) + update_data = pa.table({ + '_ROW_ID': pa.array(row_ids, type=pa.int64()), + 'k': ['k_updated'] * len(row_ids), + }) + msgs = tu.update_by_arrow_with_row_id(update_data) + with self.assertRaises(RuntimeError) as cm: + wb.new_commit().commit(msgs) + self.assertIn("'k'", str(cm.exception)) + self.assertIn("Conflicted columns", str(cm.exception)) + + table_drop = table.copy( + {'global-index.column-update-action': 'DROP_PARTITION_INDEX'} + ) + wb_drop = table_drop.new_batch_write_builder() + tu_drop = wb_drop.new_update().with_update_type(['k']) + wb_drop.new_commit().commit(tu_drop.update_by_arrow_with_row_id(update_data)) + + table_after = self.catalog.get_table('default.test_btree_index_string') + rb = table_after.new_read_builder() + 
rb.with_filter(rb.new_predicate_builder().equal('k', 'k_updated')) + rows_new = rb.new_read().to_arrow(rb.new_scan().plan().splits()) + self.assertGreater(len(rows_new), 0, + "after DROP_PARTITION_INDEX, new value should read") + + from pypaimon.manifest.index_manifest_file import IndexManifestFile + snap = table_after.snapshot_manager().get_latest_snapshot() + entries = (IndexManifestFile(table_after).read(snap.index_manifest) + if snap.index_manifest else []) + field_by_id = {f.id: f.name for f in table_after.fields} + remaining = [e for e in entries + if e.index_file.global_index_meta is not None + and field_by_id.get( + e.index_file.global_index_meta.index_field_id) == 'k'] + self.assertEqual(remaining, [], + "btree index entries for 'k' should be dropped") + + def _test_partial_append_does_not_trigger_index_action(self): + table = self.catalog.get_table('default.test_btree_index_string') + snap_before = table.snapshot_manager().get_latest_snapshot() + + wb = table.new_batch_write_builder() + tw = wb.new_write() + tw.with_write_type(['k']) + tw.write_arrow(pa.table({'k': ['k_new']})) + tc = wb.new_commit() + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + snap_after = table.snapshot_manager().get_latest_snapshot() + self.assertGreater(snap_after.id, snap_before.id) + self.assertIsNotNone( + snap_after.index_manifest, + "partial append should not drop index manifest" + ) + @parameterized.expand([('json',), ('csv',)]) def test_read_compressed_text_append_table(self, file_format): table = self.catalog.get_table( @@ -832,6 +949,103 @@ def test_read_tantivy_full_text_index(self): self.assertIn(1, ids3) self.assertIn(3, ids3) + ngram_table = self.catalog.get_table('default.test_tantivy_fulltext_ngram') + + # Search for Chinese fragments using the ngram tokenizer metadata written by Java. 
+ ngram_builder = ngram_table.new_full_text_search_builder() + ngram_builder.with_text_column('content') + ngram_builder.with_query_text('中文') + ngram_builder.with_limit(10) + + ngram_result = ngram_builder.execute_local() + ngram_row_ids = sorted(list(ngram_result.results())) + print(f"Tantivy ngram search for '中文': row_ids={ngram_row_ids}") + self.assertEqual(ngram_row_ids, [0, 4]) + + ngram_read_builder = ngram_table.new_read_builder() + ngram_scan = ngram_read_builder.new_scan().with_global_index_result(ngram_result) + ngram_pa_table = ngram_read_builder.new_read().to_arrow(ngram_scan.plan().splits()) + ngram_pa_table = table_sort_by(ngram_pa_table, 'id') + self.assertEqual(ngram_pa_table.column('id').to_pylist(), [0, 4]) + self.assertEqual( + ngram_pa_table.column('content').to_pylist(), + ['Apache Paimon 支持中文全文检索', '中文索引支持片段查询']) + + fragment_builder = ngram_table.new_full_text_search_builder() + fragment_builder.with_text_column('content') + fragment_builder.with_query_text('片段') + fragment_builder.with_limit(10) + + fragment_result = fragment_builder.execute_local() + fragment_row_ids = sorted(list(fragment_result.results())) + print(f"Tantivy ngram search for '片段': row_ids={fragment_row_ids}") + self.assertEqual(fragment_row_ids, [4]) + + ngram_and_builder = ngram_table.new_full_text_search_builder() + ngram_and_builder.with_text_column('content') + ngram_and_builder.with_query_text('中文 片段') + ngram_and_builder.with_query_operator('and') + ngram_and_builder.with_limit(10) + + ngram_and_result = ngram_and_builder.execute_local() + ngram_and_row_ids = sorted(list(ngram_and_result.results())) + print(f"Tantivy ngram AND search for '中文 片段': row_ids={ngram_and_row_ids}") + self.assertEqual(ngram_and_row_ids, [4]) + + simple_table = self.catalog.get_table('default.test_tantivy_fulltext_simple') + simple_builder = simple_table.new_full_text_search_builder() + simple_builder.with_text_column('content') + simple_builder.with_query_text('running') + 
simple_builder.with_limit(10) + + simple_result = simple_builder.execute_local() + simple_row_ids = sorted(list(simple_result.results())) + print(f"Tantivy simple stemmed search for 'running': row_ids={simple_row_ids}") + self.assertEqual(simple_row_ids, [0, 1, 2]) + + jieba_table = self.catalog.get_table('default.test_tantivy_fulltext_jieba') + + # Search for Chinese words using the jieba tokenizer metadata written by Java. + jieba_builder = jieba_table.new_full_text_search_builder() + jieba_builder.with_text_column('content') + jieba_builder.with_query_text('售货员') + jieba_builder.with_limit(10) + + jieba_result = jieba_builder.execute_local() + jieba_row_ids = sorted(list(jieba_result.results())) + print(f"Tantivy jieba search for '售货员': row_ids={jieba_row_ids}") + self.assertEqual(jieba_row_ids, [0]) + + jieba_read_builder = jieba_table.new_read_builder() + jieba_scan = jieba_read_builder.new_scan().with_global_index_result(jieba_result) + jieba_pa_table = jieba_read_builder.new_read().to_arrow(jieba_scan.plan().splits()) + jieba_pa_table = table_sort_by(jieba_pa_table, 'id') + self.assertEqual(jieba_pa_table.column('id').to_pylist(), [0]) + self.assertEqual( + jieba_pa_table.column('content').to_pylist(), + ['张华在百货公司当售货员']) + + jieba_phrase_builder = jieba_table.new_full_text_search_builder() + jieba_phrase_builder.with_text_column('content') + jieba_phrase_builder.with_query_text('自然') + jieba_phrase_builder.with_limit(10) + + jieba_phrase_result = jieba_phrase_builder.execute_local() + jieba_phrase_row_ids = sorted(list(jieba_phrase_result.results())) + print(f"Tantivy jieba search for '自然': row_ids={jieba_phrase_row_ids}") + self.assertEqual(jieba_phrase_row_ids, [3]) + + jieba_and_builder = jieba_table.new_full_text_search_builder() + jieba_and_builder.with_text_column('content') + jieba_and_builder.with_query_text('中文 自然') + jieba_and_builder.with_query_operator('and') + jieba_and_builder.with_limit(10) + + jieba_and_result = 
jieba_and_builder.execute_local() + jieba_and_row_ids = sorted(list(jieba_and_result.results())) + print(f"Tantivy jieba AND search for '中文 自然': row_ids={jieba_and_row_ids}") + self.assertEqual(jieba_and_row_ids, [3]) + def test_read_lumina_vector_index(self): """Test reading a Lumina vector index built by Java (orc and lance formats).""" test_cases = [('default.test_lumina_vector', 'orc')] diff --git a/paimon-python/pypaimon/tests/external_paths_test.py b/paimon-python/pypaimon/tests/external_paths_test.py index ba2c07534fe9..aab4fa38a599 100644 --- a/paimon-python/pypaimon/tests/external_paths_test.py +++ b/paimon-python/pypaimon/tests/external_paths_test.py @@ -24,8 +24,11 @@ from pypaimon import CatalogFactory, Schema from pypaimon.catalog.catalog import Identifier -from pypaimon.common.options.core_options import CoreOptions, ExternalPathStrategy -from pypaimon.common.external_path_provider import ExternalPathProvider +from pypaimon.common.external_path_provider import ( + EntropyInjectExternalPathProvider, ExternalPathProvider, + RoundRobinExternalPathProvider, WeightedExternalPathProvider, _murmur3_32) +from pypaimon.common.options.core_options import (CoreOptions, + ExternalPathStrategy) class ExternalPathProviderTest(unittest.TestCase): @@ -40,7 +43,7 @@ def test_path_selection_and_structure(self): "oss://bucket3/external", ] relative_path = "partition=value/bucket-0" - provider = ExternalPathProvider(external_paths, relative_path) + provider = RoundRobinExternalPathProvider(external_paths, relative_path) paths = [provider.get_next_external_data_path("file.parquet") for _ in range(6)] @@ -56,18 +59,247 @@ def test_path_selection_and_structure(self): self.assertIn("file.parquet", paths[0]) # Test single path - single_provider = ExternalPathProvider(["oss://bucket/external"], "bucket-0") + single_provider = RoundRobinExternalPathProvider(["oss://bucket/external"], "bucket-0") single_path = single_provider.get_next_external_data_path("data.parquet") 
self.assertIn("bucket/external", single_path) self.assertIn("bucket-0", single_path) self.assertIn("data.parquet", single_path) # Test empty relative path - empty_provider = ExternalPathProvider(["oss://bucket/external"], "") + empty_provider = RoundRobinExternalPathProvider(["oss://bucket/external"], "") empty_path = empty_provider.get_next_external_data_path("file.parquet") self.assertIn("bucket/external", empty_path) self.assertIn("file.parquet", empty_path) + def test_factory_create_round_robin(self): + """Test ExternalPathProvider.create() with round-robin strategy.""" + provider = ExternalPathProvider.create( + "round-robin", ["oss://a/path", "oss://b/path"], "bucket-0" + ) + self.assertIsInstance(provider, RoundRobinExternalPathProvider) + paths = [provider.get_next_external_data_path("f.parquet") for _ in range(4)] + schemes_used = {p.split("://")[1].split("/")[0] for p in paths} + self.assertEqual(len(schemes_used), 2) + + def test_factory_create_specific_fs(self): + """Test ExternalPathProvider.create() with specific-fs (falls through to round-robin).""" + provider = ExternalPathProvider.create( + "specific-fs", ["oss://bucket/path"] + ) + self.assertIsInstance(provider, RoundRobinExternalPathProvider) + + def test_factory_create_none(self): + """Test ExternalPathProvider.create() with none strategy returns None.""" + provider = ExternalPathProvider.create("none", ["oss://bucket/path"]) + self.assertIsNone(provider) + + def test_factory_create_entropy_inject(self): + """Test ExternalPathProvider.create() with entropy-inject strategy.""" + provider = ExternalPathProvider.create( + "entropy-inject", ["oss://a/path", "oss://b/path"], "bucket-0" + ) + self.assertIsInstance(provider, EntropyInjectExternalPathProvider) + + def test_factory_create_weighted(self): + """Test ExternalPathProvider.create() with weight-robin strategy.""" + provider = ExternalPathProvider.create( + "weight-robin", ["oss://a/path", "oss://b/path"], "bucket-0", [10, 5] + ) + 
self.assertIsInstance(provider, WeightedExternalPathProvider) + + def test_factory_create_weighted_fallback(self): + """Test weight-robin falls back to round-robin when paths < 2 or no weights.""" + provider = ExternalPathProvider.create( + "weight-robin", ["oss://a/path"], "bucket-0", [10] + ) + self.assertIsInstance(provider, RoundRobinExternalPathProvider) + + provider2 = ExternalPathProvider.create( + "weight-robin", ["oss://a/path", "oss://b/path"], "bucket-0", None + ) + self.assertIsInstance(provider2, RoundRobinExternalPathProvider) + + +class Murmur3HashTest(unittest.TestCase): + """Test murmur3_32 hash implementation for Java Guava compatibility.""" + + def test_empty_string(self): + """Empty string should produce a deterministic hash.""" + result = _murmur3_32(b'') + # Guava: Hashing.murmur3_32().hashString("", UTF_8).asInt() == 0 + self.assertEqual(result, 0) + + def test_deterministic(self): + """Same input always produces same output.""" + for s in [b'test', b'hello world', b'data-0001.blob']: + self.assertEqual(_murmur3_32(s), _murmur3_32(s)) + + def test_known_values(self): + """Verify against Guava Hashing.murmur3_32().hashString(s, UTF_8).asInt(). + + Values confirmed by running Java Guava 32.0.0. 
+ """ + self.assertEqual(_murmur3_32(b''), 0) + self.assertEqual(_murmur3_32(b'a'), 1009084850) + self.assertEqual(_murmur3_32(b'hello'), 613153351) + self.assertEqual(_murmur3_32(b'world'), -74040069) + self.assertEqual(_murmur3_32(b'test'), -1167338989) + self.assertEqual(_murmur3_32(b'data-abc.blob'), 894520562) + self.assertEqual(_murmur3_32(b'data-xyz.blob'), -822867934) + + def test_signed_32bit_range(self): + """Result should be in signed 32-bit integer range.""" + for s in [b'a', b'ab', b'abc', b'abcd', b'abcde']: + result = _murmur3_32(s) + self.assertGreaterEqual(result, -(2 ** 31)) + self.assertLessEqual(result, 2 ** 31 - 1) + + +class EntropyInjectExternalPathProviderTest(unittest.TestCase): + """Test EntropyInjectExternalPathProvider functionality.""" + + def test_hash_directory_structure(self): + """Hash directories should have depth=3 with 4-bit segments + remainder.""" + provider = EntropyInjectExternalPathProvider(["oss://bucket/ext"], "bucket-0") + hash_dirs = provider._compute_hash("test-file.blob") + parts = hash_dirs.split("/") + self.assertEqual(len(parts), 4) + self.assertEqual(len(parts[0]), 4) + self.assertEqual(len(parts[1]), 4) + self.assertEqual(len(parts[2]), 4) + self.assertEqual(len(parts[3]), 8) + for part in parts: + self.assertTrue(all(c in ('0', '1') for c in part)) + + def test_deterministic_path(self): + """Same filename always produces same hash directories.""" + provider = EntropyInjectExternalPathProvider( + ["oss://bucket/ext"], "dt=20240101/bucket-0" + ) + hash1 = provider._compute_hash("data-001.parquet") + hash2 = provider._compute_hash("data-001.parquet") + self.assertEqual(hash1, hash2) + + def test_path_format(self): + """Full path should include base/relative/hashDirs/fileName.""" + provider = EntropyInjectExternalPathProvider( + ["oss://bucket/ext"], "dt=20240101/bucket-0" + ) + path = provider.get_next_external_data_path("data-001.parquet") + self.assertIn("oss://bucket/ext", path) + 
self.assertIn("dt=20240101/bucket-0", path) + self.assertIn("data-001.parquet", path) + # Should have hash dirs between relative path and filename + parts_between = path.split("bucket-0/")[1].split("/data-001.parquet")[0] + self.assertEqual(len(parts_between.split("/")), 4) + + def test_multi_path_rotation(self): + """Paths should rotate across external paths.""" + provider = EntropyInjectExternalPathProvider( + ["oss://a/ext", "oss://b/ext", "oss://c/ext"], "" + ) + paths = [provider.get_next_external_data_path(f"file-{i}.parquet") for i in range(6)] + bases = [p.split("://")[1][0] for p in paths] + self.assertEqual(set(bases), {'a', 'b', 'c'}) + + +class WeightedExternalPathProviderTest(unittest.TestCase): + """Test WeightedExternalPathProvider functionality.""" + + def test_weight_distribution(self): + """Paths should be selected roughly proportional to weights.""" + import random as _random + _random.seed(42) + provider = WeightedExternalPathProvider( + ["oss://a/path", "oss://b/path"], "bucket-0", [90, 10] + ) + counts = {"a": 0, "b": 0} + for i in range(10000): + path = provider.get_next_external_data_path(f"file-{i}.parquet") + if "://a/" in path: + counts["a"] += 1 + else: + counts["b"] += 1 + + # With 90:10 weights, "a" should get ~90% (allow 5% tolerance) + ratio_a = counts["a"] / 10000 + self.assertGreater(ratio_a, 0.85) + self.assertLess(ratio_a, 0.95) + + def test_equal_weights(self): + """Equal weights should distribute roughly evenly.""" + import random as _random + _random.seed(42) + provider = WeightedExternalPathProvider( + ["oss://a/path", "oss://b/path", "oss://c/path"], "bucket-0", [1, 1, 1] + ) + counts = {"a": 0, "b": 0, "c": 0} + for i in range(9000): + path = provider.get_next_external_data_path(f"file-{i}.parquet") + for key in counts: + if f"://{key}/" in path: + counts[key] += 1 + + for key in counts: + ratio = counts[key] / 9000 + self.assertGreater(ratio, 0.28) + self.assertLess(ratio, 0.39) + + def test_path_format(self): + """Path 
should include base/relative/fileName.""" + provider = WeightedExternalPathProvider( + ["oss://a/path", "oss://b/path"], "dt=20240101/bucket-0", [5, 5] + ) + path = provider.get_next_external_data_path("data.parquet") + self.assertIn("dt=20240101/bucket-0", path) + self.assertIn("data.parquet", path) + + def test_mismatched_lengths_raises(self): + """Should raise ValueError if paths and weights have different lengths.""" + with self.assertRaises(ValueError): + WeightedExternalPathProvider( + ["oss://a/path", "oss://b/path"], "bucket-0", [10] + ) + + +class WeightsParsingTest(unittest.TestCase): + """Test CoreOptions.data_file_external_paths_weights() parsing and validation.""" + + def test_valid_weights(self): + """Normal comma-separated positive integers.""" + from pypaimon.common.options.core_options import CoreOptions + opts = CoreOptions.from_dict({"data-file.external-paths.weights": "10,5,15"}) + self.assertEqual(opts.data_file_external_paths_weights(), [10, 5, 15]) + + def test_none_when_not_configured(self): + """Returns None when option is not set.""" + from pypaimon.common.options.core_options import CoreOptions + opts = CoreOptions.from_dict({}) + self.assertIsNone(opts.data_file_external_paths_weights()) + + def test_zero_weight_raises(self): + """Zero weight should raise ValueError.""" + from pypaimon.common.options.core_options import CoreOptions + opts = CoreOptions.from_dict({"data-file.external-paths.weights": "10,0,5"}) + with self.assertRaises(ValueError) as ctx: + opts.data_file_external_paths_weights() + self.assertIn("positive", str(ctx.exception)) + + def test_negative_weight_raises(self): + """Negative weight should raise ValueError.""" + from pypaimon.common.options.core_options import CoreOptions + opts = CoreOptions.from_dict({"data-file.external-paths.weights": "10,-5"}) + with self.assertRaises(ValueError) as ctx: + opts.data_file_external_paths_weights() + self.assertIn("positive", str(ctx.exception)) + + def 
test_empty_element_raises(self): + """Empty element like '10,,5' should raise ValueError (align with Java NumberFormatException).""" + from pypaimon.common.options.core_options import CoreOptions + opts = CoreOptions.from_dict({"data-file.external-paths.weights": "10,,5"}) + with self.assertRaises(ValueError): + opts.data_file_external_paths_weights() + class ExternalPathsConfigTest(unittest.TestCase): """Test external paths configuration parsing through FileStoreTable._create_external_paths().""" @@ -188,6 +420,7 @@ def test_create_external_path_provider(self): # Test with external paths configured provider = path_factory.create_external_path_provider(("value1",), 0) self.assertIsNotNone(provider) + self.assertIsInstance(provider, RoundRobinExternalPathProvider) path = provider.get_next_external_data_path("file.parquet") self.assertTrue("bucket1" in str(path) or "bucket2" in str(path)) self.assertIn("dt=value1", str(path)) @@ -234,6 +467,55 @@ def test_create_external_path_provider(self): provider3 = table3.path_factory().create_external_path_provider((), 0) self.assertIsNone(provider3) + def test_create_entropy_inject_provider(self): + """Test creating EntropyInject provider from path factory.""" + table_name = "test_db.entropy_test" + # Manually delete table directory if it exists + try: + table_path = self.catalog.get_table_path(Identifier.from_string(table_name)) + if self.catalog.file_io.exists(table_path): + self.catalog.file_io.delete(table_path, recursive=True) + except Exception: + pass # Table may not exist, ignore + options = { + CoreOptions.DATA_FILE_EXTERNAL_PATHS.key(): "oss://bucket1/path1,oss://bucket2/path2", + CoreOptions.DATA_FILE_EXTERNAL_PATHS_STRATEGY.key(): ExternalPathStrategy.ENTROPY_INJECT, + } + pa_schema = pa.schema([("id", pa.int32()), ("name", pa.string())]) + schema = Schema.from_pyarrow_schema(pa_schema, options=options) + self.catalog.create_table(table_name, schema, True) + table = self.catalog.get_table(table_name) + path_factory 
= table.path_factory() + + provider = path_factory.create_external_path_provider((), 0) + self.assertIsNotNone(provider) + self.assertIsInstance(provider, EntropyInjectExternalPathProvider) + + def test_create_weighted_provider(self): + """Test creating Weighted provider from path factory.""" + table_name = "test_db.weighted_test" + # Manually delete table directory if it exists + try: + table_path = self.catalog.get_table_path(Identifier.from_string(table_name)) + if self.catalog.file_io.exists(table_path): + self.catalog.file_io.delete(table_path, recursive=True) + except Exception: + pass # Table may not exist, ignore + options = { + CoreOptions.DATA_FILE_EXTERNAL_PATHS.key(): "oss://bucket1/path1,oss://bucket2/path2", + CoreOptions.DATA_FILE_EXTERNAL_PATHS_STRATEGY.key(): ExternalPathStrategy.WEIGHTED, + CoreOptions.DATA_FILE_EXTERNAL_PATHS_WEIGHTS.key(): "10,5", + } + pa_schema = pa.schema([("id", pa.int32()), ("name", pa.string())]) + schema = Schema.from_pyarrow_schema(pa_schema, options=options) + self.catalog.create_table(table_name, schema, True) + table = self.catalog.get_table(table_name) + path_factory = table.path_factory() + + provider = path_factory.create_external_path_provider((), 0) + self.assertIsNotNone(provider) + self.assertIsInstance(provider, WeightedExternalPathProvider) + class ExternalPathsIntegrationTest(unittest.TestCase): """Integration tests for external paths feature.""" diff --git a/paimon-python/pypaimon/tests/external_storage_blob_test.py b/paimon-python/pypaimon/tests/external_storage_blob_test.py index 505e4407a9e6..e7ad8273d9a1 100644 --- a/paimon-python/pypaimon/tests/external_storage_blob_test.py +++ b/paimon-python/pypaimon/tests/external_storage_blob_test.py @@ -94,7 +94,7 @@ def test_validation_field_not_blob_type(self): }) with self.assertRaises(ValueError) as ctx: self.catalog.create_table('test_db.not_blob_type_test', schema, False) - self.assertIn('must be a BLOB type field', str(ctx.exception)) + self.assertIn('must 
be blob fields', str(ctx.exception)) def test_validation_blob_not_null_field_passes(self): """BLOB NOT NULL fields should pass validation (not be rejected by str comparison).""" diff --git a/paimon-python/pypaimon/tests/global_index_evaluator_test.py b/paimon-python/pypaimon/tests/global_index_evaluator_test.py index 3017835d78d4..7f3b49388a3a 100644 --- a/paimon-python/pypaimon/tests/global_index_evaluator_test.py +++ b/paimon-python/pypaimon/tests/global_index_evaluator_test.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. +import threading +import time import unittest from concurrent.futures import ThreadPoolExecutor from pypaimon.common.predicate import Predicate from pypaimon.globalindex.global_index_evaluator import GlobalIndexEvaluator -from pypaimon.globalindex.global_index_reader import GlobalIndexReader +from pypaimon.globalindex.global_index_reader import GlobalIndexReader, _completed_future from pypaimon.globalindex.global_index_result import GlobalIndexResult from pypaimon.schema.data_types import DataField, AtomicType from pypaimon.utils.range import Range @@ -33,7 +35,21 @@ def __init__(self, result): self._result = result def visit_equal(self, field_ref, literal): - return self._result + return _completed_future(self._result) + + def close(self): + pass + + +class AsyncStubReader(GlobalIndexReader): + """A test reader that dispatches to an executor (simulating real async readers).""" + + def __init__(self, result, executor): + self._result = result + self._executor = executor + + def visit_equal(self, field_ref, literal): + return self._executor.submit(lambda: self._result) def close(self): pass @@ -79,8 +95,7 @@ def test_and_parallel_multiple_fields(self): executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) 
predicate = Predicate( @@ -112,8 +127,7 @@ def test_or_parallel_multiple_fields(self): executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) predicate = Predicate( @@ -144,8 +158,7 @@ def readers_fn(field): return [StubGlobalIndexReader(result_a)] return [] - executor = ThreadPoolExecutor(max_workers=2) - evaluator = GlobalIndexEvaluator(fields, readers_fn, executor) + evaluator = GlobalIndexEvaluator(fields, readers_fn) predicate = Predicate( method='or', index=None, field=None, @@ -159,7 +172,6 @@ def readers_fn(field): self.assertIsNone(result) evaluator.close() - executor.shutdown(wait=False) def test_and_with_disjoint_results(self): fields = _make_fields() @@ -171,8 +183,7 @@ def test_and_with_disjoint_results(self): executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) predicate = Predicate( @@ -229,8 +240,7 @@ def test_nested_and_does_not_deadlock_with_small_pool(self): executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) # Nested binary tree: and(and(a, b), c) @@ -270,8 +280,7 @@ def test_nested_or_does_not_deadlock_with_small_pool(self): executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) # Nested binary tree: or(or(a, b), c) @@ -311,8 +320,7 @@ def test_mixed_nested_does_not_deadlock_with_small_pool(self): executor = 
ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) # AND(OR(a, b), OR(a, c)) — mixed nesting @@ -360,8 +368,7 @@ def test_deep_mixed_nested_does_not_deadlock_with_small_pool(self): executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, - lambda field: [StubGlobalIndexReader(field_results[field.id])], - executor, + lambda field: [AsyncStubReader(field_results[field.id], executor)], ) # AND(OR(AND(a, b), c), OR(AND(a, c), b)) — deep mixed nesting @@ -410,28 +417,29 @@ def test_deep_mixed_nested_does_not_deadlock_with_small_pool(self): evaluator.close() executor.shutdown(wait=False) - def test_same_field_predicates_not_accessed_concurrently(self): - import threading - import time - + def test_same_field_predicates_accessed_concurrently(self): fields = _make_fields() concurrency = [0] max_concurrency = [0] lock = threading.Lock() + executor = ThreadPoolExecutor(max_workers=4) + class ConcurrencyDetectingReader(GlobalIndexReader): def __init__(self, result): self._result = result def visit_equal(self, field_ref, literal): - with lock: - concurrency[0] += 1 - max_concurrency[0] = max(max_concurrency[0], concurrency[0]) - time.sleep(0.05) - with lock: - concurrency[0] -= 1 - return self._result + def _work(): + with lock: + concurrency[0] += 1 + max_concurrency[0] = max(max_concurrency[0], concurrency[0]) + time.sleep(0.05) + with lock: + concurrency[0] -= 1 + return self._result + return executor.submit(_work) def close(self): pass @@ -439,14 +447,12 @@ def close(self): result_a = GlobalIndexResult.from_range(Range(1, 5)) reader = ConcurrencyDetectingReader(result_a) - executor = ThreadPoolExecutor(max_workers=4) evaluator = GlobalIndexEvaluator( fields, lambda field: [reader], - executor, ) - # AND(a=1, a=2, a=3) — all same field, must not run concurrently + # 
AND(a=1, a=2, a=3) — evaluator dispatches concurrently, readers own their thread-safety predicate = Predicate( method='and', index=None, field=None, literals=[ @@ -458,32 +464,33 @@ def close(self): evaluator.evaluate(predicate) - self.assertEqual(max_concurrency[0], 1) + self.assertGreater(max_concurrency[0], 1) evaluator.close() executor.shutdown(wait=False) - def test_mixed_nested_same_field_not_accessed_concurrently(self): - import threading - import time - + def test_mixed_nested_same_field_accessed_concurrently(self): fields = _make_fields() concurrency_a = [0] max_concurrency_a = [0] lock = threading.Lock() + executor = ThreadPoolExecutor(max_workers=4) + class ConcurrencyDetectingReader(GlobalIndexReader): def __init__(self, result): self._result = result def visit_equal(self, field_ref, literal): - with lock: - concurrency_a[0] += 1 - max_concurrency_a[0] = max(max_concurrency_a[0], concurrency_a[0]) - time.sleep(0.05) - with lock: - concurrency_a[0] -= 1 - return self._result + def _work(): + with lock: + concurrency_a[0] += 1 + max_concurrency_a[0] = max(max_concurrency_a[0], concurrency_a[0]) + time.sleep(0.05) + with lock: + concurrency_a[0] -= 1 + return self._result + return executor.submit(_work) def close(self): pass @@ -497,10 +504,10 @@ def readers_fn(field): return [detecting_reader] return [normal_reader] - executor = ThreadPoolExecutor(max_workers=4) - evaluator = GlobalIndexEvaluator(fields, readers_fn, executor) + evaluator = GlobalIndexEvaluator(fields, readers_fn) # AND(OR(a=1, b=2), OR(a=3, c=4)) — field a in both OR subtrees + # evaluator dispatches concurrently, readers own their thread-safety predicate = Predicate( method='and', index=None, field=None, literals=[ @@ -523,48 +530,49 @@ def readers_fn(field): evaluator.evaluate(predicate) - self.assertEqual(max_concurrency_a[0], 1) + self.assertGreater(max_concurrency_a[0], 1) evaluator.close() executor.shutdown(wait=False) - def test_lazy_result_not_materialized_concurrently(self): - 
import threading - import time - + def test_internally_locked_reader_serializes_access(self): fields = _make_fields() concurrency = [0] max_concurrency = [0] - lock = threading.Lock() + count_lock = threading.Lock() - class LazyIOReader(GlobalIndexReader): - def visit_equal(self, field_ref, literal): - def supplier(): - with lock: - concurrency[0] += 1 - max_concurrency[0] = max(max_concurrency[0], concurrency[0]) - time.sleep(0.05) - with lock: - concurrency[0] -= 1 - return GlobalIndexResult.from_range(Range(1, 3)).results() + executor = ThreadPoolExecutor(max_workers=4) - return GlobalIndexResult.create(supplier) + class InternallyLockedReader(GlobalIndexReader): + def __init__(self): + self._lock = threading.Lock() + + def visit_equal(self, field_ref, literal): + def _work(): + with self._lock: + with count_lock: + concurrency[0] += 1 + max_concurrency[0] = max(max_concurrency[0], concurrency[0]) + time.sleep(0.05) + with count_lock: + concurrency[0] -= 1 + return GlobalIndexResult.from_range(Range(1, 3)) + return executor.submit(_work) def close(self): pass - lazy_reader = LazyIOReader() + locked_reader = InternallyLockedReader() def readers_fn(field): if field.id == 0: - return [lazy_reader] + return [locked_reader] return [StubGlobalIndexReader(GlobalIndexResult.from_range(Range(1, 5)))] - executor = ThreadPoolExecutor(max_workers=4) - evaluator = GlobalIndexEvaluator(fields, readers_fn, executor) + evaluator = GlobalIndexEvaluator(fields, readers_fn) # AND(OR(a=1, b=2), OR(a=3, c=4)) — field a in both OR subtrees - # lazy results for field a must not be materialized concurrently + # reader with internal lock serializes access predicate = Predicate( method='and', index=None, field=None, literals=[ @@ -600,10 +608,9 @@ def test_multiple_readers_per_field_combined_with_and(self): evaluator = GlobalIndexEvaluator( fields, lambda field: [ - StubGlobalIndexReader(reader_result1), - StubGlobalIndexReader(reader_result2), + AsyncStubReader(reader_result1, 
executor), + AsyncStubReader(reader_result2, executor), ], - executor, ) predicate = Predicate(method='equal', index=0, field='a', literals=[42]) @@ -622,11 +629,9 @@ def test_non_field_leaf_predicate_does_not_throw(self): fields = _make_fields() result_a = GlobalIndexResult.from_range(Range(1, 3)) - executor = ThreadPoolExecutor(max_workers=2) evaluator = GlobalIndexEvaluator( fields, lambda field: [StubGlobalIndexReader(result_a)], - executor, ) # AND(non-field leaf, a=1) — non-field leaf has field=None @@ -644,7 +649,6 @@ def test_non_field_leaf_predicate_does_not_throw(self): self.assertIsNotNone(result) self.assertEqual(result.results().cardinality(), 3) evaluator.close() - executor.shutdown(wait=False) def test_null_predicate(self): fields = _make_fields() diff --git a/paimon-python/pypaimon/tests/global_index_test.py b/paimon-python/pypaimon/tests/global_index_test.py index 691c081c10e4..873e93358455 100644 --- a/paimon-python/pypaimon/tests/global_index_test.py +++ b/paimon-python/pypaimon/tests/global_index_test.py @@ -16,8 +16,17 @@ # under the License. 
import unittest +from unittest.mock import patch + +import pyarrow as pa from pypaimon.globalindex.global_index_result import GlobalIndexResult +from pypaimon.index.index_file_handler import IndexFileHandler +from pypaimon.snapshot.snapshot_manager import SnapshotManager +from pypaimon.tests.data_evolution_test_helpers import ( + BatchModeMixin, + DataEvolutionTestBase, +) from pypaimon.utils.range import Range @@ -38,3 +47,78 @@ def test_chained_and(self): result = result.and_(other) self.assertEqual(result.results().cardinality(), 10001) + + +class PlanSnapshotFetchRegressionTest( + BatchModeMixin, DataEvolutionTestBase, unittest.TestCase): + + table_options = { + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'global-index.enabled': 'true', + 'bucket': '-1', + } + + def test_plan_fetches_latest_snapshot_only_once(self): + table = self._create_table() + self._write_arrow(table, pa.table( + {'id': [1, 2, 3], 'name': ['a', 'b', 'c'], + 'age': [10, 20, 30], 'city': ['x', 'y', 'z']}, + schema=self.pa_schema)) + + fresh_table = self.catalog.get_table(table.identifier.get_full_name()) + rb = fresh_table.new_read_builder() + rb = rb.with_filter(rb.new_predicate_builder().equal('id', 1)) + + orig_get_latest = SnapshotManager.get_latest_snapshot + call_count = [0] + + def counting(self_sm, *args, **kwargs): + call_count[0] += 1 + return orig_get_latest(self_sm, *args, **kwargs) + + with patch.object(SnapshotManager, 'get_latest_snapshot', counting): + rb.new_scan().plan().splits() + + self.assertEqual( + 1, call_count[0], + msg=f"Plan fetched latest snapshot {call_count[0]} times — " + "duplicate from #7513: manifest_scanner + " + "GlobalIndexScanner.create both fetch independently.") + + def test_time_travel_plan(self): + table = self._create_table() + self._write_arrow(table, pa.table( + {'id': [1], 'name': ['a'], 'age': [10], 'city': ['x']}, + schema=self.pa_schema)) + snapshot_1_id = table.snapshot_manager().get_latest_snapshot().id + 
self._write_arrow(table, pa.table( + {'id': [2], 'name': ['b'], 'age': [20], 'city': ['y']}, + schema=self.pa_schema)) + + travel_table = self.catalog.get_table( + table.identifier.get_full_name() + ).copy({'scan.snapshot-id': str(snapshot_1_id)}) + rb = travel_table.new_read_builder() + rb = rb.with_filter(rb.new_predicate_builder().equal('id', 1)) + + orig_scan = IndexFileHandler.scan + seen_snapshot_ids = [] + + def spy_scan(self_h, snapshot, entry_filter=None): + seen_snapshot_ids.append(snapshot.id if snapshot else None) + return orig_scan(self_h, snapshot, entry_filter) + + with patch.object(IndexFileHandler, 'scan', spy_scan): + rb.new_scan().plan().splits() + + self.assertTrue(seen_snapshot_ids, + "IndexFileHandler.scan was never called") + self.assertEqual( + snapshot_1_id, seen_snapshot_ids[0], + msg=f"Global index evaluated against snapshot " + f"{seen_snapshot_ids[0]}, expected time-travel snapshot " + f"{snapshot_1_id}. Before #7513 was fixed, " + "GlobalIndexScanner.create self-fetched latest snapshot, " + "so global index used latest while manifest used the " + "time-travel snapshot — silent correctness bug.") diff --git a/paimon-python/pypaimon/tests/hdfs_native_test.py b/paimon-python/pypaimon/tests/hdfs_native_test.py new file mode 100644 index 000000000000..957ee8f123dd --- /dev/null +++ b/paimon-python/pypaimon/tests/hdfs_native_test.py @@ -0,0 +1,917 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import tempfile +import types +import unittest +from unittest.mock import MagicMock, patch + +import pyarrow.fs as pafs + +from pypaimon.common.file_io import FileIO +from pypaimon.common.options import Options +from pypaimon.common.options.config import HdfsOptions + + +def _install_fake_hdfs_native(): + """Install a fake hdfs_native module (with .fsspec submodule) into + sys.modules. + + Returns (fake_module, mock_client_cls, mock_write_options_cls). + """ + fake = types.ModuleType("hdfs_native") + fake.Client = MagicMock(name="Client") + fake.WriteOptions = MagicMock(name="WriteOptions") + fsspec_mod = types.ModuleType("hdfs_native.fsspec") + fsspec_mod.HdfsFileSystem = MagicMock(name="HdfsFileSystem") + fsspec_mod.ViewfsFileSystem = MagicMock(name="ViewfsFileSystem") + fake.fsspec = fsspec_mod + sys.modules["hdfs_native"] = fake + sys.modules["hdfs_native.fsspec"] = fsspec_mod + return fake, fake.Client, fake.WriteOptions + + +def _uninstall_fake_hdfs_native(): + sys.modules.pop("hdfs_native", None) + sys.modules.pop("hdfs_native.fsspec", None) + # Also drop the cached HdfsNativeFileIO so a re-import sees the new fake + sys.modules.pop( + "pypaimon.filesystem.hdfs_native_file_io", None) + + +class HdfsOptionsTest(unittest.TestCase): + + def test_defaults(self): + opts = Options({}) + self.assertEqual(opts.get(HdfsOptions.HDFS_CLIENT_IMPL), "native") + self.assertTrue(opts.get(HdfsOptions.HDFS_CLIENT_FALLBACK_TO_PYARROW)) + self.assertIsNone(opts.get(HdfsOptions.HDFS_CONF_DIR)) + + def test_explicit_pyarrow(self): + opts = 
Options({"hdfs.client.impl": "pyarrow"}) + self.assertEqual(opts.get(HdfsOptions.HDFS_CLIENT_IMPL), "pyarrow") + + def test_explicit_fallback_false(self): + opts = Options({"hdfs.client.fallback-to-pyarrow": "false"}) + self.assertFalse(opts.get(HdfsOptions.HDFS_CLIENT_FALLBACK_TO_PYARROW)) + + +class HdfsNativeFileIORoutingTest(unittest.TestCase): + + def setUp(self): + _uninstall_fake_hdfs_native() + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def test_local_paths_unaffected(self): + fio = FileIO.get("file:///tmp/foo") + self.assertEqual(type(fio).__name__, "LocalFileIO") + + def test_default_hdfs_routes_to_native(self): + _install_fake_hdfs_native() + fio = FileIO.get("hdfs://ns/foo", Options({})) + self.assertEqual(type(fio).__name__, "HdfsNativeFileIO") + + def test_explicit_pyarrow_routes_to_pyarrow(self): + # No hdfs-native module needed; should go straight to pyarrow. + with patch( + "pypaimon.filesystem.pyarrow_file_io.PyArrowFileIO.__init__", + return_value=None, + ): + fio = FileIO.get( + "hdfs://ns/foo", + Options({"hdfs.client.impl": "pyarrow"}), + ) + self.assertEqual(type(fio).__name__, "PyArrowFileIO") + + def test_native_init_failure_falls_back_to_pyarrow(self): + # hdfs_native not installed; default fallback enabled. 
+ _uninstall_fake_hdfs_native() + with patch( + "pypaimon.filesystem.pyarrow_file_io.PyArrowFileIO.__init__", + return_value=None, + ): + fio = FileIO.get("hdfs://ns/foo", Options({})) + self.assertEqual(type(fio).__name__, "PyArrowFileIO") + + def test_native_init_failure_no_fallback_raises(self): + _uninstall_fake_hdfs_native() + with self.assertRaises(ImportError): + FileIO.get( + "hdfs://ns/foo", + Options({"hdfs.client.fallback-to-pyarrow": "false"}), + ) + + def test_unsupported_impl_raises(self): + with self.assertRaises(ValueError) as ctx: + FileIO.get( + "hdfs://ns/foo", + Options({"hdfs.client.impl": "bogus"}), + ) + self.assertIn("Unsupported hdfs.client.impl", str(ctx.exception)) + + def test_env_var_override_when_option_absent(self): + _install_fake_hdfs_native() + with patch( + "pypaimon.filesystem.pyarrow_file_io.PyArrowFileIO.__init__", + return_value=None, + ): + with patch.dict(os.environ, {"PYPAIMON_HDFS_IMPL": "pyarrow"}): + fio = FileIO.get("hdfs://ns/foo", Options({})) + self.assertEqual(type(fio).__name__, "PyArrowFileIO") + + def test_option_wins_over_env_var(self): + _install_fake_hdfs_native() + with patch.dict(os.environ, {"PYPAIMON_HDFS_IMPL": "pyarrow"}): + fio = FileIO.get( + "hdfs://ns/foo", + Options({"hdfs.client.impl": "native"}), + ) + self.assertEqual(type(fio).__name__, "HdfsNativeFileIO") + + def test_viewfs_scheme_routes_to_native(self): + _install_fake_hdfs_native() + fio = FileIO.get("viewfs://cluster/foo", Options({})) + self.assertEqual(type(fio).__name__, "HdfsNativeFileIO") + + def test_empty_impl_option_treated_as_unset(self): + # Templated configs sometimes blank the option to opt out — that + # should fall through to the default ("native"), not raise + # "Unsupported hdfs.client.impl ''". 
+ _install_fake_hdfs_native() + fio = FileIO.get("hdfs://ns/foo", Options({"hdfs.client.impl": ""})) + self.assertEqual(type(fio).__name__, "HdfsNativeFileIO") + + +class HdfsNativeFileIOInitTest(unittest.TestCase): + + def setUp(self): + self._fake, self._client_cls, self._wo_cls = _install_fake_hdfs_native() + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def _make(self, path, props=None): + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + return HdfsNativeFileIO(path, Options(props or {})) + + def test_constructs_client_with_url(self): + self._make("hdfs://ns1/warehouse") + self._client_cls.assert_called_once() + _, kwargs = self._client_cls.call_args + self.assertEqual(kwargs.get("url"), "hdfs://ns1") + self.assertNotIn("config", kwargs) + + def test_viewfs_scheme_passes_through(self): + self._make("viewfs://cluster1/") + _, kwargs = self._client_cls.call_args + self.assertEqual(kwargs.get("url"), "viewfs://cluster1") + + def test_native_hadoop_keys_forwarded_as_config(self): + self._make("hdfs://ns1/foo", { + "dfs.nameservices": "ns1", + "dfs.ha.namenodes.ns1": "nn1,nn2", + "dfs.namenode.rpc-address.ns1.nn1": "host1:8020", + "fs.viewfs.mounttable.cluster.link./prod": "hdfs://ns1/prod", + "warehouse": "hdfs://ns1/warehouse", # should NOT be forwarded + }) + _, kwargs = self._client_cls.call_args + config = kwargs.get("config", {}) + self.assertEqual(config.get("dfs.nameservices"), "ns1") + self.assertEqual(config.get("dfs.ha.namenodes.ns1"), "nn1,nn2") + self.assertEqual( + config.get("dfs.namenode.rpc-address.ns1.nn1"), "host1:8020") + self.assertEqual( + config.get("fs.viewfs.mounttable.cluster.link./prod"), + "hdfs://ns1/prod") + self.assertNotIn("warehouse", config) + + def test_namespaced_overrides_forwarded(self): + self._make("hdfs://ns1/foo", { + "hdfs.config.dfs.client.read.shortcircuit": "true", + }) + _, kwargs = self._client_cls.call_args + config = kwargs.get("config", {}) + self.assertEqual( + 
config.get("dfs.client.read.shortcircuit"), "true") + + def test_conf_dir_from_option(self): + self._make("hdfs://ns1/foo", { + "hdfs.conf-dir": "/tmp/conf", + }) + _, kwargs = self._client_cls.call_args + self.assertEqual(kwargs.get("config_dir"), "/tmp/conf") + + def test_conf_dir_from_env(self): + env = dict(os.environ) + env["HADOOP_CONF_DIR"] = "/env/conf" + with patch.dict(os.environ, env, clear=True): + self._make("hdfs://ns1/foo") + _, kwargs = self._client_cls.call_args + self.assertEqual(kwargs.get("config_dir"), "/env/conf") + + def test_option_conf_dir_overrides_env(self): + env = dict(os.environ) + env["HADOOP_CONF_DIR"] = "/env/conf" + with patch.dict(os.environ, env, clear=True): + self._make("hdfs://ns1/foo", {"hdfs.conf-dir": "/opt/conf"}) + _, kwargs = self._client_cls.call_args + self.assertEqual(kwargs.get("config_dir"), "/opt/conf") + + @patch("pypaimon.filesystem._kerberos.subprocess.run") + def test_kerberos_principal_keytab_triggers_kinit(self, mock_kinit): + mock_kinit.return_value = MagicMock() + with tempfile.NamedTemporaryFile(suffix=".keytab") as keytab_file: + with patch.dict(os.environ, {"KRB5CCNAME": "/tmp/kc_test"}): + self._make("hdfs://ns1/foo", { + "security.kerberos.login.principal": "user@REALM", + "security.kerberos.login.keytab": keytab_file.name, + }) + kinit_calls = [ + c for c in mock_kinit.call_args_list + if c[0][0][0] == "kinit" + ] + self.assertEqual(len(kinit_calls), 1) + self.assertEqual( + kinit_calls[0][0][0], + ["kinit", "-kt", keytab_file.name, "user@REALM"], + ) + + def test_principal_without_keytab_raises(self): + with self.assertRaises(ValueError) as ctx: + self._make("hdfs://ns1/foo", { + "security.kerberos.login.principal": "user@REALM", + }) + self.assertIn("must be both set or both unset", str(ctx.exception)) + + @patch("pypaimon.filesystem._kerberos.subprocess.run") + def test_kerberos_preserves_FILE_prefix_on_krb5ccname(self, mock_kinit): + # If KRB5CCNAME had a `FILE:` qualifier, the rewrite after kinit 
+ # must keep it so GSSAPI cache-type detection isn't perturbed. + mock_kinit.return_value = MagicMock() + with tempfile.NamedTemporaryFile(suffix=".keytab") as keytab_file: + with patch.dict(os.environ, + {"KRB5CCNAME": "FILE:/tmp/kc_test"}, + clear=True): + self._make("hdfs://ns1/foo", { + "security.kerberos.login.principal": "user@REALM", + "security.kerberos.login.keytab": keytab_file.name, + }) + self.assertEqual( + os.environ["KRB5CCNAME"], "FILE:/tmp/kc_test") + + @patch("pypaimon.filesystem._kerberos.get_ticket_cache_path", + return_value="/tmp/freshly_kinit_cache") + @patch("pypaimon.filesystem._kerberos.subprocess.run") + def test_kerberos_warns_when_overwriting_different_cache( + self, mock_kinit, _mock_cache, + ): + # Multi-principal in the same process clobbers KRB5CCNAME; we warn + # so the operator sees the race instead of silently mis-routing + # earlier instances' RPCs. + mock_kinit.return_value = MagicMock() + with tempfile.NamedTemporaryFile(suffix=".keytab") as keytab_file: + with patch.dict(os.environ, + {"KRB5CCNAME": "/tmp/some_other_cache"}, + clear=True): + with self.assertLogs( + "pypaimon.filesystem.hdfs_native_file_io", + level="WARNING", + ) as log_ctx: + self._make("hdfs://ns1/foo", { + "security.kerberos.login.principal": "user@REALM", + "security.kerberos.login.keytab": keytab_file.name, + }) + self.assertTrue( + any("Overwriting process-global KRB5CCNAME" in m + for m in log_ctx.output), + log_ctx.output, + ) + + @patch("pypaimon.filesystem._kerberos.get_ticket_cache_path", + return_value="/tmp/kc_test") + @patch("pypaimon.filesystem._kerberos.subprocess.run") + def test_kerberos_no_warning_when_cache_unchanged( + self, mock_kinit, _mock_cache, + ): + import logging as _logging + mock_kinit.return_value = MagicMock() + with tempfile.NamedTemporaryFile(suffix=".keytab") as keytab_file: + # Pre-existing value matches the kinit-resolved cache → no warn. 
+ with patch.dict(os.environ, + {"KRB5CCNAME": "/tmp/kc_test"}, + clear=True): + # assertNoLogs is 3.10+; patch warning explicitly so older + # interpreters keep working too. + logger = _logging.getLogger( + "pypaimon.filesystem.hdfs_native_file_io") + with patch.object(logger, "warning") as warn: + self._make("hdfs://ns1/foo", { + "security.kerberos.login.principal": "user@REALM", + "security.kerberos.login.keytab": keytab_file.name, + }) + warn.assert_not_called() + + def test_unsupported_scheme_raises(self): + with self.assertRaises(ValueError): + self._make("s3://bucket/key") + + +class HdfsNativeFileIOOpsTest(unittest.TestCase): + + def setUp(self): + self._fake, self._client_cls, self._wo_cls = _install_fake_hdfs_native() + self._mock_client = MagicMock(name="ClientInstance") + self._client_cls.return_value = self._mock_client + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + self.fio = HdfsNativeFileIO("hdfs://ns/", Options({})) + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def _file_status(self, path, isdir=False, length=0, mtime=0): + s = MagicMock() + s.path = path + s.isdir = isdir + s.length = length + s.modification_time = mtime + return s + + def test_exists_true(self): + self._mock_client.get_file_info.return_value = self._file_status("/x") + self.assertTrue(self.fio.exists("/x")) + + def test_exists_false(self): + self._mock_client.get_file_info.side_effect = FileNotFoundError("nope") + self.assertFalse(self.fio.exists("/missing")) + + def test_get_file_status_adapts_to_pafs_filetype(self): + self._mock_client.get_file_info.return_value = self._file_status( + "/x", isdir=False, length=42, mtime=1700000000000, + ) + info = self.fio.get_file_status("/x") + self.assertEqual(info.path, "/x") + self.assertEqual(info.size, 42) + self.assertEqual(info.type, pafs.FileType.File) + self.assertIsNotNone(info.mtime) + + def test_list_status(self): + self._mock_client.list_status.return_value = iter([ + 
self._file_status("/x/a", isdir=False, length=10), + self._file_status("/x/b", isdir=True), + ]) + infos = self.fio.list_status("/x") + self.assertEqual(len(infos), 2) + self.assertEqual(infos[0].type, pafs.FileType.File) + self.assertEqual(infos[1].type, pafs.FileType.Directory) + self.assertIsNone(infos[1].size) + + def test_delete_missing_returns_false(self): + self._mock_client.get_file_info.side_effect = FileNotFoundError("nope") + self.assertFalse(self.fio.delete("/missing")) + self._mock_client.delete.assert_not_called() + + def test_delete_file(self): + self._mock_client.get_file_info.return_value = self._file_status("/x") + self._mock_client.delete.return_value = True + self.assertTrue(self.fio.delete("/x")) + self._mock_client.delete.assert_called_once_with("/x", False) + + def test_delete_nonempty_dir_without_recursive_raises(self): + self._mock_client.get_file_info.return_value = self._file_status( + "/x", isdir=True) + self._mock_client.list_status.return_value = iter([ + self._file_status("/x/a")]) + with self.assertRaises(OSError): + self.fio.delete("/x", recursive=False) + + def test_mkdirs_creates_when_missing(self): + self._mock_client.get_file_info.side_effect = FileNotFoundError("nope") + self.assertTrue(self.fio.mkdirs("/new")) + self._mock_client.mkdirs.assert_called_once_with( + "/new", create_parent=True) + + def test_mkdirs_idempotent_on_existing_dir(self): + self._mock_client.get_file_info.return_value = self._file_status( + "/x", isdir=True) + self.assertTrue(self.fio.mkdirs("/x")) + self._mock_client.mkdirs.assert_not_called() + + def test_mkdirs_existing_file_raises(self): + self._mock_client.get_file_info.return_value = self._file_status( + "/x", isdir=False) + with self.assertRaises(FileExistsError): + self.fio.mkdirs("/x") + + +class HdfsNativeAdaptersTest(unittest.TestCase): + + def setUp(self): + _install_fake_hdfs_native() + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def 
test_writer_adapter_tracks_position_and_closes_once(self): + from pypaimon.filesystem.hdfs_native_file_io import _HdfsWriterAdapter + fw = MagicMock() + fw.write.side_effect = lambda buf: len(buf) + adapter = _HdfsWriterAdapter(fw) + adapter.write(b"abc") + adapter.write(b"defg") + self.assertEqual(adapter.tell(), 7) + adapter.close() + adapter.close() # idempotent + fw.close.assert_called_once() + + def test_reader_adapter_seek_and_read(self): + from pypaimon.filesystem.hdfs_native_file_io import _HdfsReaderAdapter + fr = MagicMock() + fr.tell.side_effect = [20, 30] + fr.read.return_value = b"x" * 10 + adapter = _HdfsReaderAdapter(fr) + self.assertEqual(adapter.seek(20), 20) + fr.seek.assert_called_once_with(20, 0) + data = adapter.read(10) + self.assertEqual(data, b"x" * 10) + fr.read.assert_called_once_with(10) + self.assertEqual(adapter.tell(), 30) + + def test_reader_adapter_read_negative_reads_all(self): + from pypaimon.filesystem.hdfs_native_file_io import _HdfsReaderAdapter + fr = MagicMock() + fr.read.return_value = b"all-content" + adapter = _HdfsReaderAdapter(fr) + self.assertEqual(adapter.read(), b"all-content") + fr.read.assert_called_once_with(-1) + + def test_reader_adapter_close_releases_underlying(self): + from pypaimon.filesystem.hdfs_native_file_io import _HdfsReaderAdapter + fr = MagicMock() + adapter = _HdfsReaderAdapter(fr) + adapter.close() + adapter.close() # idempotent + fr.close.assert_called_once() + self.assertTrue(adapter.closed) + + +def _write_hadoop_xml(path, entries): + """Write a minimal Hadoop-style xml file at `path`, one <property> element per `entries` item.""" + body = ['<?xml version="1.0"?>', "<configuration>"] + for name, value in entries.items(): + body.append( + f" <property><name>{name}</name><value>{value}</value></property>" + ) + body.append("</configuration>") + with open(path, "w") as f: + f.write("\n".join(body)) + + +class ViewFsFallbackTest(unittest.TestCase): + """Cover _load_hadoop_xml + _maybe_inject_viewfs_fallback polyfill.""" + + def setUp(self): + _install_fake_hdfs_native() + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO +
self.Fio = HdfsNativeFileIO + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def test_load_hadoop_xml_merges_two_files(self): + with tempfile.TemporaryDirectory() as d: + _write_hadoop_xml(os.path.join(d, "core-site.xml"), + {"fs.defaultFS": "viewfs://c1"}) + _write_hadoop_xml(os.path.join(d, "hdfs-site.xml"), + {"dfs.nameservices": "ns1"}) + cfg = self.Fio._load_hadoop_xml(d) + self.assertEqual(cfg.get("fs.defaultFS"), "viewfs://c1") + self.assertEqual(cfg.get("dfs.nameservices"), "ns1") + + def test_load_hadoop_xml_missing_dir_returns_empty(self): + self.assertEqual(self.Fio._load_hadoop_xml(None), {}) + self.assertEqual(self.Fio._load_hadoop_xml("/no/such/dir/xyz"), {}) + + def test_load_hadoop_xml_malformed_file_skipped(self): + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "core-site.xml"), "w") as f: + f.write(" absolute-path normalisation for hdfs-native.""" + + def setUp(self): + _install_fake_hdfs_native() + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def _make(self, root): + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + return HdfsNativeFileIO(root, Options({})) + + def test_viewfs_uri_same_cluster_returns_path(self): + fio = self._make("viewfs://cluster1/") + self.assertEqual( + fio.to_filesystem_path("viewfs://cluster1/home/hudi/x"), + "/home/hudi/x", + ) + + def test_viewfs_uri_no_path_returns_root(self): + fio = self._make("viewfs://cluster1/") + self.assertEqual(fio.to_filesystem_path("viewfs://cluster1"), "/") + + def test_viewfs_absolute_path_unchanged(self): + fio = self._make("viewfs://cluster1/") + self.assertEqual(fio.to_filesystem_path("/foo/bar"), "/foo/bar") + + def test_hdfs_uri_same_ns_returns_path(self): + fio = self._make("hdfs://ns1/") + self.assertEqual( + fio.to_filesystem_path("hdfs://ns1/foo/bar"), + "/foo/bar", + ) + + def test_hdfs_uri_different_ns_unchanged(self): + fio = self._make("hdfs://ns1/") + self.assertEqual( + 
fio.to_filesystem_path("hdfs://nsX/foo"), + "hdfs://nsX/foo", + ) + + def test_hdfs_client_with_viewfs_uri_unchanged(self): + fio = self._make("hdfs://ns1/") + # Different scheme; let hdfs-native error rather than silently rewrite. + self.assertEqual( + fio.to_filesystem_path("viewfs://cluster1/foo"), + "viewfs://cluster1/foo", + ) + + def test_exists_passes_path_only_to_client(self): + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + fio = HdfsNativeFileIO("viewfs://cluster1/", Options({})) + client = sys.modules["hdfs_native"].Client.return_value + client.get_file_info.return_value = MagicMock( + path="/home/hudi/x", isdir=False, length=0, modification_time=0) + fio.exists("viewfs://cluster1/home/hudi/x") + client.get_file_info.assert_called_once_with("/home/hudi/x") + + +class FilesystemPropertyTest(unittest.TestCase): + """Cover the lazy pyarrow.fs facade backed by hdfs_native.fsspec.""" + + def setUp(self): + _install_fake_hdfs_native() + self._patcher = patch("pyarrow.fs.PyFileSystem") + self._handler_patcher = patch("pyarrow.fs.FSSpecHandler") + self.MockPyFs = self._patcher.start() + self.MockHandler = self._handler_patcher.start() + + def tearDown(self): + self._patcher.stop() + self._handler_patcher.stop() + _uninstall_fake_hdfs_native() + + def _make(self, root, props=None, xml_entries=None): + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + if xml_entries: + d = tempfile.mkdtemp() + self.addCleanup(lambda: __import__("shutil").rmtree(d, ignore_errors=True)) + _write_hadoop_xml(os.path.join(d, "hdfs-site.xml"), xml_entries) + base_props = {"hdfs.conf-dir": d} + base_props.update(props or {}) + props = base_props + return HdfsNativeFileIO(root, Options(props or {})) + + def test_viewfs_uses_viewfs_fsspec_class(self): + fio = self._make("viewfs://cluster1/") + fs_instance = fio.filesystem # trigger lazy + VFs = sys.modules["hdfs_native.fsspec"].ViewfsFileSystem + HFs = 
sys.modules["hdfs_native.fsspec"].HdfsFileSystem + VFs.assert_called_once() + HFs.assert_not_called() + _, kwargs = VFs.call_args + self.assertEqual(kwargs.get("host"), "cluster1") + self.assertIs(fs_instance, self.MockPyFs.return_value) + + def test_hdfs_uses_hdfs_fsspec_class(self): + self._make("hdfs://ns1/").filesystem + HFs = sys.modules["hdfs_native.fsspec"].HdfsFileSystem + VFs = sys.modules["hdfs_native.fsspec"].ViewfsFileSystem + HFs.assert_called_once() + VFs.assert_not_called() + _, kwargs = HFs.call_args + self.assertEqual(kwargs.get("host"), "ns1") + + def test_lazy_caches_after_first_access(self): + fio = self._make("hdfs://ns1/") + first = fio.filesystem + second = fio.filesystem + self.assertIs(first, second) + HFs = sys.modules["hdfs_native.fsspec"].HdfsFileSystem + self.assertEqual(HFs.call_count, 1) + + def test_xml_and_catalog_options_merged_into_fsspec_storage_options(self): + fio = self._make( + "hdfs://ns1/", + props={"dfs.client.read.shortcircuit": "true"}, + xml_entries={"dfs.nameservices": "ns1"}, + ) + fio.filesystem # trigger + HFs = sys.modules["hdfs_native.fsspec"].HdfsFileSystem + _, kwargs = HFs.call_args + # Both xml and option keys should land in the fsspec kwargs. + self.assertEqual(kwargs.get("dfs.nameservices"), "ns1") + self.assertEqual(kwargs.get("dfs.client.read.shortcircuit"), "true") + + def test_catalog_option_overrides_xml(self): + fio = self._make( + "hdfs://ns1/", + props={"dfs.foo": "v_user"}, + xml_entries={"dfs.foo": "v_xml"}, + ) + fio.filesystem + HFs = sys.modules["hdfs_native.fsspec"].HdfsFileSystem + _, kwargs = HFs.call_args + self.assertEqual(kwargs.get("dfs.foo"), "v_user") + + def test_missing_fsspec_raises_clear_error(self): + fio = self._make("hdfs://ns1/") + # Remove the fsspec submodule but keep hdfs_native itself, to + # simulate an old/partial install. 
+ sys.modules.pop("hdfs_native.fsspec", None) + sys.modules["hdfs_native"].fsspec = None + with self.assertRaises(RuntimeError) as ctx: + fio.filesystem + self.assertIn("hdfs-native fsspec adapter", str(ctx.exception)) + + +class PickleTest(unittest.TestCase): + """Cover __reduce__ so Ray / multiprocessing can ship FileIO.""" + + def setUp(self): + _install_fake_hdfs_native() + # Isolate from any HADOOP_CONF_DIR on the host so __reduce__'s + # env-derived config_dir pinning is deterministic across machines. + self._env_patcher = patch.dict(os.environ, {}, clear=True) + self._env_patcher.start() + + def tearDown(self): + self._env_patcher.stop() + _uninstall_fake_hdfs_native() + + def _make(self, path, props=None): + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + return HdfsNativeFileIO(path, Options(props or {})) + + def test_reduce_returns_class_and_args(self): + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + fio = self._make("viewfs://cluster1/some/sub/path", + {"dfs.nameservices": "ns1"}) + cls, args = fio.__reduce__() + self.assertIs(cls, HdfsNativeFileIO) + path, options = args + # Path is rebuilt from scheme+netloc (path segment dropped) — that + # is intentional because __init__ ignores path beyond scheme+netloc. + self.assertEqual(path, "viewfs://cluster1") + self.assertEqual(options.to_map(), {"dfs.nameservices": "ns1"}) + + def test_reduce_for_empty_netloc(self): + fio = self._make("hdfs://") + _, (path, _) = fio.__reduce__() + self.assertEqual(path, "hdfs://") + + def test_reduce_pins_env_resolved_config_dir_into_options(self): + # config_dir resolved from $HADOOP_CONF_DIR should be carried into + # the pickled options so a worker on a host with a different env + # value still uses the driver's resolved directory. 
+ with tempfile.TemporaryDirectory() as d: + with patch.dict(os.environ, {"HADOOP_CONF_DIR": d}, clear=True): + fio = self._make("hdfs://ns1/foo") + _, (_, options) = fio.__reduce__() + self.assertEqual(options.to_map().get("hdfs.conf-dir"), d) + + def test_reduce_does_not_override_explicit_conf_dir_option(self): + with tempfile.TemporaryDirectory() as opt_dir: + with patch.dict(os.environ, + {"HADOOP_CONF_DIR": "/env/dir"}, clear=True): + fio = self._make("hdfs://ns1/foo", + {"hdfs.conf-dir": opt_dir}) + _, (_, options) = fio.__reduce__() + self.assertEqual(options.to_map().get("hdfs.conf-dir"), opt_dir) + + def test_pickle_roundtrip_preserves_type_and_options(self): + import pickle + fio = self._make("hdfs://ns1/foo", + {"dfs.foo": "bar", "fs.viewfs.x": "y"}) + client_cls = sys.modules["hdfs_native"].Client + client_cls.reset_mock() + # Roundtrip via the highest pickle protocol. + blob = pickle.dumps(fio, protocol=pickle.HIGHEST_PROTOCOL) + restored = pickle.loads(blob) + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + self.assertIsInstance(restored, HdfsNativeFileIO) + self.assertEqual(restored.properties.to_map(), + {"dfs.foo": "bar", "fs.viewfs.x": "y"}) + # The original __init__ ran once; the unpickle ran __init__ again. + self.assertEqual(client_cls.call_count, 1) + + def test_pickle_with_viewfs_scheme(self): + import pickle + fio = self._make("viewfs://cluster1/") + restored = pickle.loads(pickle.dumps(fio)) + self.assertEqual(restored._scheme, "viewfs") + self.assertEqual(restored._netloc, "cluster1") + + def test_pickle_does_not_serialise_live_client(self): + # If the live _client were pickled, the call would fail (MagicMocks + # are picklable but the real RawClient would not be). This test + # documents the contract: __reduce__ MUST sidestep _client. 
+ import pickle + fio = self._make("hdfs://ns1/") + blob = pickle.dumps(fio) + # The pickled blob should reference the constructor inputs only; + # specifically it should not embed the literal mock _client. + self.assertNotIn(b"_client", blob) + + +class HdfsNativeWriteFormatTest(unittest.TestCase): + """HdfsNativeFileIO is the default hdfs:// backend, so it must keep the + same write surface as the pyarrow backend it replaces — including + lance/vortex, which otherwise fall through to FileIO's NotImplementedError. + """ + + def setUp(self): + self._fake, self._client_cls, _ = _install_fake_hdfs_native() + self._client_cls.return_value = MagicMock(name="ClientInstance") + from pypaimon.filesystem.hdfs_native_file_io import HdfsNativeFileIO + self.fio = HdfsNativeFileIO("hdfs://ns/", Options({})) + + def tearDown(self): + _uninstall_fake_hdfs_native() + + def test_lance_and_vortex_are_overridden(self): + # The regression this guards: with these unimplemented, an HDFS table + # using file.format=lance/vortex would hit FileIO's NotImplementedError. 
+ from pypaimon.common.file_io import FileIO + self.assertIsNot( + type(self.fio).write_lance, FileIO.write_lance) + self.assertIsNot( + type(self.fio).write_vortex, FileIO.write_vortex) + + def test_write_lance_delegates_to_lance_specified(self): + import pyarrow + table = pyarrow.table({"a": [1, 2]}) + writer = MagicMock(name="LanceFileWriter") + fake_lance = types.ModuleType("lance") + fake_lance.file = types.SimpleNamespace( + LanceFileWriter=MagicMock(return_value=writer)) + with patch.dict(sys.modules, {"lance": fake_lance}), \ + patch("pypaimon.read.reader.lance_utils.to_lance_specified", + return_value=("hdfs://ns/x.lance", {"opt": "v"})) as spec: + self.fio.write_lance("hdfs://ns/x.lance", table) + spec.assert_called_once() + _, kwargs = fake_lance.file.LanceFileWriter.call_args + self.assertEqual(kwargs.get("storage_options"), {"opt": "v"}) + writer.close.assert_called_once() + + def test_write_vortex_delegates_to_vortex_specified(self): + import pyarrow + table = pyarrow.table({"a": [1, 2]}) + fake_vortex = types.ModuleType("vortex") + fake_vortex.array = MagicMock(return_value="varr") + fake_vortex.store = types.SimpleNamespace(from_url=MagicMock()) + fake_io = types.ModuleType("vortex._lib.io") + fake_io.write = MagicMock() + fake_lib = types.ModuleType("vortex._lib") + fake_lib.io = fake_io + fake_modules = { + "vortex": fake_vortex, + "vortex._lib": fake_lib, + "vortex._lib.io": fake_io, + } + with patch.dict(sys.modules, fake_modules), \ + patch("pypaimon.read.reader.vortex_utils.to_vortex_specified", + return_value=("hdfs://ns/x.vortex", None)): + self.fio.write_vortex("hdfs://ns/x.vortex", table) + fake_io.write.assert_called_once_with("varr", "hdfs://ns/x.vortex") + + +if __name__ == "__main__": + unittest.main() diff --git a/paimon-python/pypaimon/tests/index_manifest_write_test.py b/paimon-python/pypaimon/tests/index_manifest_write_test.py new file mode 100644 index 000000000000..ff61750407e0 --- /dev/null +++ 
b/paimon-python/pypaimon/tests/index_manifest_write_test.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import shutil +import tempfile +import unittest +import uuid + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.globalindex.global_index_meta import GlobalIndexMeta +from pypaimon.index.index_file_meta import IndexFileMeta +from pypaimon.manifest.index_manifest_entry import IndexManifestEntry +from pypaimon.manifest.index_manifest_file import IndexManifestFile +from pypaimon.table.row.generic_row import GenericRow + + +class IndexManifestWriteTest(unittest.TestCase): + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('vec', pa.string()), + ]) + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _table(self): + name = f'default.idx_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(self.pa_schema) + self.catalog.create_table(name, s, False) + return 
self.catalog.get_table(name) + + def _entry(self, file_name, field_id, meta=b'm'): + partition = GenericRow([], []) + index_file = IndexFileMeta( + index_type='BTREE', + file_name=file_name, + file_size=123, + row_count=10, + global_index_meta=GlobalIndexMeta( + row_range_start=0, + row_range_end=10, + index_field_id=field_id, + extra_field_ids=[field_id + 1], + index_meta=meta, + ), + ) + return IndexManifestEntry(kind=0, partition=partition, bucket=0, index_file=index_file) + + def test_write_read_roundtrip(self): + imf = IndexManifestFile(self._table()) + name = imf.write([self._entry('idx-a', 1), self._entry('idx-b', 2)]) + out = imf.read(name) + self.assertEqual(2, len(out)) + by_name = {e.index_file.file_name: e for e in out} + a = by_name['idx-a'] + self.assertEqual('BTREE', a.index_file.index_type) + self.assertEqual(123, a.index_file.file_size) + self.assertEqual(10, a.index_file.row_count) + self.assertEqual(0, a.kind) + gim = a.index_file.global_index_meta + self.assertEqual(1, gim.index_field_id) + self.assertEqual(0, gim.row_range_start) + self.assertEqual(10, gim.row_range_end) + self.assertEqual([2], gim.extra_field_ids) + self.assertEqual(b'm', bytes(gim.index_meta)) + + def test_combine_drops_named_files(self): + imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1), self._entry('idx-b', 2)]) + deletes = [self._entry('idx-a', 1)] + new_name = imf.combine_deletes(previous, deletes) + self.assertNotEqual(previous, new_name) + survivors = {e.index_file.file_name for e in imf.read(new_name)} + self.assertEqual({'idx-b'}, survivors) + + def test_combine_unknown_delete_is_noop_on_content(self): + imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1)]) + new_name = imf.combine_deletes(previous, [self._entry('idx-zzz', 9)]) + survivors = {e.index_file.file_name for e in imf.read(new_name)} + self.assertEqual({'idx-a'}, survivors) + + def test_combine_empty_deletes_returns_previous(self): + 
imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1)]) + self.assertEqual(previous, imf.combine_deletes(previous, [])) + + def test_combine_all_deleted_returns_none(self): + imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1)]) + self.assertIsNone(imf.combine_deletes(previous, [self._entry('idx-a', 1)])) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/jdbc_catalog_test.py b/paimon-python/pypaimon/tests/jdbc_catalog_test.py new file mode 100644 index 000000000000..04ca55e250f3 --- /dev/null +++ b/paimon-python/pypaimon/tests/jdbc_catalog_test.py @@ -0,0 +1,223 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# + +import os +import shutil +import sqlite3 +import tempfile +import unittest + +from pypaimon import CatalogFactory, Schema +from pypaimon.catalog.catalog_exception import ( + DatabaseAlreadyExistException, + DatabaseNotExistException, + TableAlreadyExistException, + TableNotExistException +) +from pypaimon.catalog.jdbc_catalog import JdbcCatalog, _convert_qmark_placeholders +from pypaimon.catalog.rest.property_change import PropertyChange +from pypaimon.schema.data_types import AtomicType, DataField +from pypaimon.schema.schema_change import SchemaChange + + +class JdbcCatalogTest(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp(prefix="unittest_") + self.warehouse = os.path.join(self.temp_dir, "warehouse") + self.jdbc_path = os.path.join(self.temp_dir, "catalog.db") + self.options = { + "metastore": "jdbc", + "warehouse": self.warehouse, + "uri": "jdbc:sqlite:" + self.jdbc_path, + "catalog-key": "test-jdbc-catalog", + } + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_create_jdbc_catalog(self): + catalog = CatalogFactory.create(self.options) + self.assertTrue(isinstance(catalog, JdbcCatalog)) + + with sqlite3.connect(self.jdbc_path) as conn: + tables = { + row[0] + for row in conn.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ) + } + self.assertIn("paimon_tables", tables) + self.assertIn("paimon_database_properties", tables) + self.assertIn("paimon_table_properties", tables) + + def test_jdbc_catalog_context_manager_closes_connection(self): + with CatalogFactory.create(self.options) as catalog: + self.assertTrue(isinstance(catalog, JdbcCatalog)) + + with self.assertRaises(sqlite3.ProgrammingError): + catalog.list_databases() + + def test_placeholder_conversion_skips_string_literals(self): + sql = "SELECT '?' AS q, \"?\" AS quoted, col FROM tbl WHERE a = ? 
AND b = '?'" + self.assertEqual( + _convert_qmark_placeholders(sql, "%s"), + "SELECT '?' AS q, \"?\" AS quoted, col FROM tbl WHERE a = %s AND b = '?'" + ) + + def test_database(self): + catalog = CatalogFactory.create(self.options) + catalog.create_database("test_db", False, {"owner": "owner1"}) + + with self.assertRaises(DatabaseAlreadyExistException): + catalog.create_database("test_db", False) + + self.assertEqual(catalog.list_databases(), ["test_db"]) + database = catalog.get_database("test_db") + self.assertEqual(database.name, "test_db") + self.assertEqual(database.options["owner"], "owner1") + self.assertEqual( + database.options["location"], + os.path.join(self.warehouse, "test_db.db") + ) + + reloaded = CatalogFactory.create(self.options) + self.assertEqual(reloaded.list_databases(), ["test_db"]) + reloaded.alter_database( + "test_db", + [ + PropertyChange.set_property("comment", "new comment"), + PropertyChange.remove_property("owner"), + ] + ) + updated = reloaded.get_database("test_db") + self.assertEqual(updated.options["comment"], "new comment") + self.assertNotIn("owner", updated.options) + + reloaded.drop_database("test_db") + with self.assertRaises(DatabaseNotExistException): + reloaded.get_database("test_db") + + def test_table(self): + fields = [ + DataField.from_dict({"id": 1, "name": "f0", "type": "INT"}), + DataField.from_dict({"id": 2, "name": "f1", "type": "STRING"}), + ] + catalog = CatalogFactory.create(self.options) + catalog.create_database("test_db", False) + catalog.create_table( + "test_db.test_table", + Schema(fields=fields, partition_keys=["f1"], options={"bucket": "1"}), + False + ) + + with self.assertRaises(TableAlreadyExistException): + catalog.create_table("test_db.test_table", Schema(fields=fields), False) + + self.assertEqual(catalog.list_tables("test_db"), ["test_table"]) + self.assertTrue( + os.path.exists( + os.path.join(self.warehouse, "test_db.db", "test_table", "schema", "schema-0") + ) + ) + + reloaded = 
CatalogFactory.create(self.options) + table = reloaded.get_table("test_db.test_table") + self.assertEqual(table.fields[0].name, "f0") + self.assertTrue(isinstance(table.fields[0].type, AtomicType)) + self.assertEqual(table.fields[0].type.type, "INT") + + with sqlite3.connect(self.jdbc_path) as conn: + properties = dict( + conn.execute( + "SELECT property_key, property_value FROM paimon_table_properties " + "WHERE catalog_key = ? AND database_name = ? AND table_name = ?", + ("test-jdbc-catalog", "test_db", "test_table") + ).fetchall() + ) + self.assertEqual(properties["bucket"], "1") + self.assertEqual(properties["partition"], "f1") + + reloaded.alter_table( + "test_db.test_table", + [SchemaChange.add_column("f2", AtomicType("BIGINT"))] + ) + self.assertEqual(len(reloaded.get_table("test_db.test_table").fields), 3) + + reloaded.rename_table("test_db.test_table", "test_db.renamed_table") + self.assertEqual(reloaded.list_tables("test_db"), ["renamed_table"]) + with self.assertRaises(TableNotExistException): + reloaded.get_table("test_db.test_table") + + reloaded.drop_table("test_db.renamed_table") + self.assertEqual(reloaded.list_tables("test_db"), []) + with self.assertRaises(TableNotExistException): + reloaded.get_table("test_db.renamed_table") + + def test_create_table_rolls_back_metadata_on_failure(self): + fields = [DataField.from_dict({"id": 1, "name": "f0", "type": "INT"})] + catalog = CatalogFactory.create(self.options) + catalog.create_database("test_db", False) + + def fail_insert_table_properties(identifier, properties): + raise RuntimeError("injected failure") + + catalog._insert_table_properties = fail_insert_table_properties + with self.assertRaises(RuntimeError): + catalog.create_table("test_db.test_table", Schema(fields=fields), False) + + with sqlite3.connect(self.jdbc_path) as conn: + table_count = conn.execute( + "SELECT COUNT(*) FROM paimon_tables " + "WHERE catalog_key = ? AND database_name = ? 
AND table_name = ?", + ("test-jdbc-catalog", "test_db", "test_table") + ).fetchone()[0] + self.assertEqual(table_count, 0) + self.assertFalse(os.path.exists(os.path.join(self.warehouse, "test_db.db", "test_table"))) + + def test_rename_table_keeps_metadata_when_file_move_fails(self): + fields = [DataField.from_dict({"id": 1, "name": "f0", "type": "INT"})] + catalog = CatalogFactory.create(self.options) + catalog.create_database("test_db", False) + catalog.create_table("test_db.test_table", Schema(fields=fields), False) + + def fail_rename(source, target): + raise OSError("injected failure") + + catalog.file_io.rename = fail_rename + with self.assertRaises(OSError): + catalog.rename_table("test_db.test_table", "test_db.renamed_table") + + self.assertEqual(catalog.list_tables("test_db"), ["test_table"]) + self.assertTrue(os.path.exists(os.path.join(self.warehouse, "test_db.db", "test_table"))) + + def test_drop_database_requires_cascade_for_non_empty_database(self): + fields = [DataField.from_dict({"id": 1, "name": "f0", "type": "INT"})] + catalog = CatalogFactory.create(self.options) + catalog.create_database("test_db", False) + catalog.create_table("test_db.test_table", Schema(fields=fields), False) + + with self.assertRaises(ValueError): + catalog.drop_database("test_db") + + catalog.drop_database("test_db", cascade=True) + self.assertEqual(catalog.list_databases(), []) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/kerberos_test.py b/paimon-python/pypaimon/tests/kerberos_test.py index 06d98d90d6ad..722043d77899 100644 --- a/paimon-python/pypaimon/tests/kerberos_test.py +++ b/paimon-python/pypaimon/tests/kerberos_test.py @@ -235,7 +235,7 @@ def test_get_ticket_cache_no_cache(self): @patch("pypaimon.filesystem.pyarrow_file_io.subprocess.run") @patch("pypaimon.filesystem.pyarrow_file_io.pafs.HadoopFileSystem") def test_hdfs_with_fallback_keys(self, mock_hdfs_fs, mock_subprocess_run): - """Verify that Java-compatible 
fallback keys security.principal / security.keytab work.""" + """Verify that the secondary fallback keys security.principal / security.keytab work.""" mock_subprocess_run.return_value = MagicMock(stdout="/some/classpath") with tempfile.NamedTemporaryFile(suffix=".keytab") as keytab_file: diff --git a/paimon-python/pypaimon/tests/lumina_vector_index_test.py b/paimon-python/pypaimon/tests/lumina_vector_index_test.py index ca25ef691721..5a27af20c80a 100644 --- a/paimon-python/pypaimon/tests/lumina_vector_index_test.py +++ b/paimon-python/pypaimon/tests/lumina_vector_index_test.py @@ -98,7 +98,7 @@ def test_build_and_read(self): options=paimon_options, ) as reader: vs = VectorSearch(vector=raw[:dim], limit=5, field_name="embedding") - result = reader.visit_vector_search(vs) + result = reader.visit_vector_search(vs).result() self.assertIsNotNone(result) row_ids = result.results() @@ -157,7 +157,7 @@ def test_filtered_search(self): vector=raw[:dim], limit=3, field_name="embedding", include_row_ids=include_ids, ) - result = reader.visit_vector_search(vs) + result = reader.visit_vector_search(vs).result() self.assertIsNotNone(result) for row_id in result.results(): diff --git a/paimon-python/pypaimon/tests/nested_type_read_write_test.py b/paimon-python/pypaimon/tests/nested_type_read_write_test.py new file mode 100644 index 000000000000..0b2b9da403fe --- /dev/null +++ b/paimon-python/pypaimon/tests/nested_type_read_write_test.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema + + +class NestedTypeReadWriteTest(unittest.TestCase): + """Read/write of nested map types on primary-key tables. + + Primary-key tables read through a row-based merge path that converts each + arrow batch with polars before merging, so reading a map nested inside a + struct exercises that conversion. This guards against regressions there + (e.g. polars releases that cannot decode a struct-nested arrow Map). + """ + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_pk_table(self, name, pa_schema): + self.catalog.create_table( + 'default.' + name, + Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], + options={'bucket': '1', 'file.format': 'parquet'}), + False) + return self.catalog.get_table('default.' 
+ name) + + @staticmethod + def _write(table, pa_schema, ids, values): + data = pa.table( + {'id': pa.array(ids, pa_schema.field('id').type), + pa_schema.field(1).name: pa.array( + values, type=pa_schema.field(1).type)}, + schema=pa_schema) + write_builder = table.new_batch_write_builder() + write = write_builder.new_write() + commit = write_builder.new_commit() + write.write_arrow(data) + commit.commit(write.prepare_commit()) + write.close() + commit.close() + + @staticmethod + def _read_sorted(table): + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + rows = read_builder.new_read().to_arrow(splits).to_pylist() + rows.sort(key=lambda r: r['id']) + return rows + + def test_pk_merge_read_map_nested_in_struct(self): + # row<latest_version string, latest_value row<...>, + # all_versioned_values map<string, row<...>>> + inner = pa.struct([ + pa.field('audio_vae_version', pa.string()), + pa.field('audio_vae_result_path', pa.string()), + pa.field('audio_vae_latent_shape', pa.string()), + ]) + top = pa.struct([ + pa.field('latest_version', pa.string()), + pa.field('latest_value', inner), + pa.field('all_versioned_values', pa.map_(pa.string(), inner)), + ]) + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + pa.field('info', top), + ]) + + def value(ver, path, shape): + return {'audio_vae_version': ver, + 'audio_vae_result_path': path, + 'audio_vae_latent_shape': shape} + + table = self._create_pk_table('nested_map_in_struct', pa_schema) + + # first version + self._write(table, pa_schema, [1], [{ + 'latest_version': 'v2', + 'latest_value': value('v2', '/p/v2', '[1,2,3]'), + 'all_versioned_values': [ + ('v1', value('v1', '/p/v1', '[1,2]')), + ('v2', value('v2', '/p/v2', '[1,2,3]')), + ], + }]) + # same pk again -> forces a merge read (the polars conversion path) + self._write(table, pa_schema, [1], [{ + 'latest_version': 'v3', + 'latest_value': value('v3', '/p/v3', '[1,2,3,4]'), + 'all_versioned_values': [ + ('v1', value('v1', '/p/v1', '[1,2]')), + ('v2', value('v2', 
'/p/v2', '[1,2,3]')), + ('v3', value('v3', '/p/v3', '[1,2,3,4]')), + ], + }]) + + rows = self._read_sorted(table) + self.assertEqual(1, len(rows)) + info = rows[0]['info'] + self.assertEqual('v3', info['latest_version']) + self.assertEqual(value('v3', '/p/v3', '[1,2,3,4]'), + info['latest_value']) + self.assertEqual( + [('v1', value('v1', '/p/v1', '[1,2]')), + ('v2', value('v2', '/p/v2', '[1,2,3]')), + ('v3', value('v3', '/p/v3', '[1,2,3,4]'))], + info['all_versioned_values']) + + def test_pk_merge_read_top_level_map(self): + # map<string, row<a string, b int>> as a top-level value column + row_ab = pa.struct([ + pa.field('a', pa.string()), + pa.field('b', pa.int32()), + ]) + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + pa.field('v', pa.map_(pa.string(), row_ab)), + ]) + + table = self._create_pk_table('top_level_map', pa_schema) + + self._write(table, pa_schema, [1, 2], [ + [('k1', {'a': 'OLD', 'b': 1})], + [('k2', {'a': 'keep', 'b': 2})], + ]) + # overwrite id=1, leave id=2 untouched + self._write(table, pa_schema, [1], [ + [('k1', {'a': 'NEW', 'b': 100}), ('k1b', {'a': 'extra', 'b': 101})], + ]) + + rows = self._read_sorted(table) + self.assertEqual(2, len(rows)) + self.assertEqual( + [('k1', {'a': 'NEW', 'b': 100}), ('k1b', {'a': 'extra', 'b': 101})], + rows[0]['v']) + self.assertEqual([('k2', {'a': 'keep', 'b': 2})], rows[1]['v']) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/predicates_test.py b/paimon-python/pypaimon/tests/predicates_test.py index 629f0235a00b..509efc79ca58 100644 --- a/paimon-python/pypaimon/tests/predicates_test.py +++ b/paimon-python/pypaimon/tests/predicates_test.py @@ -23,6 +23,7 @@ import pandas as pd import pyarrow as pa +import pyarrow.dataset as ds from pypaimon import CatalogFactory, Schema from pypaimon.common.predicate import Predicate @@ -463,6 +464,13 @@ def test_not_between_value(self): self.assertFalse(predicate.test(OffsetRow([3], 0, 1))) 
self.assertFalse(predicate.test(OffsetRow([None], 0, 1))) + def test_not_in_arrow_filter_excludes_nulls(self): + predicate = Predicate(method='notIn', index=0, field='val', literals=[1, 2]) + table = pa.table({"val": [None, 1, 3]}) + scanner = ds.InMemoryDataset(table).scanner(filter=predicate.to_arrow()) + + self.assertEqual(scanner.to_table().to_pydict(), {"val": [3]}) + def test_pk_reader_with_filter(self): pa_schema = pa.schema([ pa.field('key1', pa.int32(), nullable=False), diff --git a/paimon-python/pypaimon/tests/pvfs_test.py b/paimon-python/pypaimon/tests/pvfs_test.py index fd3c1785355b..813368c3fd6c 100644 --- a/paimon-python/pypaimon/tests/pvfs_test.py +++ b/paimon-python/pypaimon/tests/pvfs_test.py @@ -199,3 +199,47 @@ def test_api(self): self.assertTrue(self.pvfs.created(table_data_new_virtual_path) is not None) self.assertTrue(self.pvfs.modified(table_data_new_virtual_path) is not None) self.assertEqual('Hello World', self.pvfs.cat_file(date_file_new_virtual_path).decode('utf-8')) + + def test_path_traversal_rejected_in_extract(self): + """Paths containing '..' components must be rejected at parse time.""" + traversal_paths = [ + f'pvfs://{self.catalog}/{self.database}/{self.table}/../other_table/secret.parquet', + f'pvfs://{self.catalog}/{self.database}/{self.table}/../../other_db/t/data', + f'pvfs://{self.catalog}/{self.database}/{self.table}/../../../etc/passwd', + f'pvfs://{self.catalog}/../{self.database}/{self.table}', + ] + for path in traversal_paths: + with self.assertRaises(ValueError, msg=f"Should reject: {path}"): + self.pvfs._extract_pvfs_identifier(path) + + def test_path_traversal_rejected_in_get_actual_path(self): + """Even if '..' 
reaches get_actual_path, boundary check must block it.""" + from pypaimon.filesystem.pvfs import PVFSTableIdentifier + identifier = PVFSTableIdentifier( + endpoint="http://localhost", + catalog="cat", + database="db", + table="tbl", + sub_path="../../other_table/secret.parquet" + ) + with self.assertRaises(ValueError): + identifier.get_actual_path("/warehouse/cat/db/tbl") + + def test_null_byte_rejected(self): + """Null bytes in path components must be rejected.""" + path = f'pvfs://{self.catalog}/{self.database}/{self.table}/file\x00.parquet' + with self.assertRaises(ValueError): + self.pvfs._extract_pvfs_identifier(path) + + def test_legitimate_subpaths_allowed(self): + """Normal sub-paths without traversal must still work.""" + from pypaimon.filesystem.pvfs import PVFSTableIdentifier + identifier = PVFSTableIdentifier( + endpoint="http://localhost", + catalog="cat", + database="db", + table="tbl", + sub_path="partition=1/bucket-0/data.parquet" + ) + result = identifier.get_actual_path("/warehouse/cat/db/tbl") + self.assertEqual(result, "/warehouse/cat/db/tbl/partition=1/bucket-0/data.parquet") diff --git a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py index a0a3ad37e85d..a5aa88d7ca28 100644 --- a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py +++ b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py @@ -241,7 +241,12 @@ def test_full_data_types(self): table_scan = read_builder.new_scan() table_read = read_builder.new_read() actual_data = table_read.to_arrow(table_scan.plan().splits()) - self.assertEqual(actual_data, expect_data) + # BINARY(N) maps to variable-length binary on read (see #7518), so the + # fixed-size f9 column normalizes to binary; reflect that in the expected. 
+ f9_index = expect_data.schema.get_field_index('f9') + expected_data = expect_data.set_column( + f9_index, 'f9', expect_data.column('f9').cast(pa.binary())) + self.assertEqual(actual_data, expected_data) # to test GenericRow ability latest_snapshot = table.snapshot_manager().get_latest_snapshot() diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py new file mode 100644 index 000000000000..ca06d43e5341 --- /dev/null +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -0,0 +1,1420 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +import os +import shutil +import tempfile +import unittest +import uuid +from unittest.mock import Mock, patch + +import pyarrow as pa +import ray + +from pypaimon import CatalogFactory, Schema +from pypaimon.ray import ( + WhenMatched, WhenNotMatched, merge_into, + source_col, target_col, lit, +) + +try: + import datafusion # noqa: F401 + _HAS_DATAFUSION = True +except ImportError: + _HAS_DATAFUSION = False + +_SKIP_CONDITION = not _HAS_DATAFUSION +_SKIP_REASON = "datafusion not installed" + +_TEST_NUM_PARTITIONS = 2 + + +class RayDataEvolutionMergeIntoTest(unittest.TestCase): + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + + de_options = { + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog_options = {'warehouse': cls.warehouse} + cls.catalog = CatalogFactory.create(cls.catalog_options) + cls.catalog.create_database('default', True) + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, num_cpus=2) + + @classmethod + def tearDownClass(cls): + try: + if ray.is_initialized(): + ray.shutdown() + except Exception: + pass + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_table(self, options=None): + opts = options if options is not None else self.de_options + name = f'default.tbl_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(self.pa_schema, options=opts) + self.catalog.create_table(name, s, False) + return name + + def _source(self, ids=(1,)): + return pa.Table.from_pydict( + { + 'id': pa.array(list(ids), type=pa.int32()), + 'name': ['x'] * len(ids), + 'age': [10] * len(ids), + }, + schema=self.pa_schema, + ) + + def _write(self, target, data): + table = self.catalog.get_table(target) + wb = table.new_batch_write_builder() + 
writer = wb.new_write() + writer.write_arrow(data) + wb.new_commit().commit(writer.prepare_commit()) + writer.close() + + def _read_sorted(self, target): + table = self.catalog.get_table(target) + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + return rb.new_read().to_arrow(splits).sort_by('id').to_pydict() + + def _snapshot_id(self, target): + table = self.catalog.get_table(target) + snap = table.snapshot_manager().get_latest_snapshot() + return snap.id if snap is not None else None + + def test_paimon_source_table_pins_snapshot(self): + from pypaimon.ray import data_evolution_merge_into as m + + target = self._create_table() + source = self._create_table() + self._write(source, self._source(ids=(1,))) + expected_snapshot_id = self._snapshot_id(source) + + fake_ds = Mock() + fake_ds.schema.return_value = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + + with patch( + 'pypaimon.ray.ray_paimon.read_paimon', + return_value=fake_ds, + ) as mock_read_paimon: + m._prepare( + target, source, self.catalog_options, + [WhenMatched(update='*')], [], ['id'], + ) + + mock_read_paimon.assert_called_once_with( + source, + self.catalog_options, + snapshot_id=expected_snapshot_id, + ) + + def test_no_clause_raises(self): + target = self._create_table() + with self.assertRaises(ValueError): + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + def test_non_de_table_rejected(self): + target = self._create_table(options={'row-tracking.enabled': 'true'}) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('data-evolution.enabled', str(ctx.exception)) + + def test_no_row_tracking_rejected(self): + target = 
self._create_table(options={'data-evolution.enabled': 'true'}) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('row-tracking.enabled', str(ctx.exception)) + + def test_source_missing_on_col_raises(self): + target = self._create_table() + bad_source = pa.Table.from_pydict( + {'name': ['x'], 'age': [10]}, + schema=pa.schema([('name', pa.string()), ('age', pa.int32())]), + ) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=bad_source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn("'id'", str(ctx.exception)) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_not_matched_condition_rejects_target_refs(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*', condition='t.age > 10') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('t.', str(ctx.exception)) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_condition_unknown_source_col_rejected(self): + target = self._create_table() + self._write(target, self._source()) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.nonexistent > 0') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nonexistent', str(ctx.exception)) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_condition_unknown_target_col_rejected(self): + target = self._create_table() + self._write(target, 
self._source()) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > t.nonexistent') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nonexistent', str(ctx.exception)) + + def test_matched_update_star(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b2', 'c2', 'd'], + 'age': pa.array([22, 33, 40], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c2']) + self.assertEqual(out['age'], [10, 22, 33]) + + def test_not_matched_insert_appends_unmatched(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b2', 'c2', 'd'], + 'age': pa.array([22, 33, 40], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[WhenNotMatched(insert='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3, 4]) + 
self.assertEqual(out['name'], ['a', 'b', 'c', 'd']) + self.assertEqual(out['age'], [10, 20, 30, 40]) + + def test_combined_update_and_insert(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + metrics = merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c']) + self.assertEqual(out['age'], [10, 22, 30]) + self.assertEqual(metrics, { + 'num_matched': 1, 'num_inserted': 1, 'num_unchanged': 0, + }) + + def test_on_with_renamed_columns_star(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + 
self.assertEqual(out['name'], ['a', 'b2', 'c']) + self.assertEqual(out['age'], [10, 22, 30]) + + def test_insert_into_empty_target(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[WhenNotMatched(insert='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c']) + self.assertEqual(out['age'], [10, 20, 30]) + + def test_multi_source_match_raises_by_default(self): + # One target row matched by several source rows: the winning value is + # undefined (Spark DE's checkCardinality=false), so we refuse by default. + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([100, 200], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + with self.assertRaises(Exception) as ctx: + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn("multiple source rows", str(ctx.exception)) + + def test_blob_columns_excluded(self): + import types + + from pypaimon.ray.data_evolution_merge_into import _blob_col_names + from pypaimon.schema.data_types import AtomicType, DataField + + fake_table = types.SimpleNamespace( + table_schema=types.SimpleNamespace( + fields=[ + DataField(0, 'id', AtomicType('INT')), + DataField(1, 'payload', 
AtomicType('BLOB')), + ] + ) + ) + self.assertEqual({'payload'}, _blob_col_names(fake_table)) + + def test_combined_writes_single_snapshot(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + before = self._snapshot_id(target) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + after = self._snapshot_id(target) + self.assertEqual(after, before + 1) + + def test_empty_target_matched_update_is_noop(self): + target = self._create_table() + before = self._snapshot_id(target) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + self.assertEqual(self._snapshot_id(target), before) + + def test_partitioned_matched_update_rejected(self): + pt_schema = pa.schema([ + ('pt', pa.string()), + ('id', pa.int32()), + ('name', pa.string()), + ]) + name = f'default.tbl_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema( + pt_schema, partition_keys=['pt'], options=self.de_options, + ) + self.catalog.create_table(name, s, False) + + source = pa.Table.from_pydict( + { + 'pt': ['a'], + 'id': pa.array([1], type=pa.int32()), + 'name': ['x'], + }, + schema=pt_schema, + ) + + with self.assertRaises(ValueError) as 
ctx: + merge_into( + target=name, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('partitioned', str(ctx.exception)) + + def test_partitioned_insert_allowed(self): + pt_schema = pa.schema([ + ('pt', pa.string()), + ('id', pa.int32()), + ('name', pa.string()), + ]) + name = f'default.tbl_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema( + pt_schema, partition_keys=['pt'], options=self.de_options, + ) + self.catalog.create_table(name, s, False) + + source = pa.Table.from_pydict( + { + 'pt': ['a', 'b'], + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['x', 'y'], + }, + schema=pt_schema, + ) + + merge_into( + target=name, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[WhenNotMatched(insert='*')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + table = self.catalog.get_table(name) + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + out = rb.new_read().to_arrow(splits).sort_by('id').to_pydict() + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['pt'], ['a', 'b']) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_matched_update_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a2', 'b2', 'c2'], + 'age': pa.array([15, 25, 45], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.age > t.age + 10')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + 
self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c2']) + self.assertEqual(out['age'], [10, 20, 45]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_matched_condition_with_source_on_key(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a2', 'b2', 'c2'], + 'age': pa.array([15, 25, 35], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.id >= 2')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c2']) + self.assertEqual(out['age'], [10, 25, 35]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_not_matched_insert_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b', 'c', 'd'], + 'age': pa.array([15, 25, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*', condition='s.age >= 10') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c']) + self.assertEqual(out['age'], [10, 15, 25]) + + 
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_combined_with_conditions(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3, 4], type=pa.int32()), + 'name': ['a2', 'b2', 'c', 'd'], + 'age': pa.array([50, 5, 30, 8], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + metrics = merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.age > t.age')], + when_not_matched=[ + WhenNotMatched(insert='*', condition='s.age > 10') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a2', 'b', 'c']) + self.assertEqual(out['age'], [50, 20, 30]) + self.assertEqual(metrics['num_matched'], 1) + self.assertEqual(metrics['num_inserted'], 1) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_condition_no_rows_match_is_noop(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a2', 'b2'], + 'age': pa.array([5, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.age > t.age')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [10, 20]) + 
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_duplicate_source_filtered_by_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([5, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > t.age') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1]) + self.assertEqual(out['name'], ['y']) + self.assertEqual(out['age'], [20]) + + def test_matched_partial_update(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a2', 'b2'], + 'age': pa.array([99, 88], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'age': 's.age'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [99, 88]) + + def test_insert_partial_mapping(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + 
source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert={'id': 's.id', 'name': 's.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [None, None]) + + def test_update_with_literal(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': 'updated'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['updated']) + self.assertEqual(out['age'], [10]) + + def test_invalid_target_column_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'nonexistent': 's.id'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nonexistent', str(ctx.exception)) + + def test_invalid_target_ref_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': 't.nme'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nme', str(ctx.exception)) + + def test_empty_mapping_rejected(self): + target = self._create_table() + with 
self.assertRaises(ValueError): + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + def test_insert_target_ref_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert={'name': 't.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('t.', str(ctx.exception)) + + def test_matched_update_with_target_ref(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'age': 's.age', 'name': 't.name'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['old']) + self.assertEqual(out['age'], [99]) + + def test_callable_value_rejected(self): + target = self._create_table() + with self.assertRaises(TypeError): + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': lambda r: r})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + def test_source_missing_referenced_col(self): + target = self._create_table() + source = pa.Table.from_pydict( + {'id': pa.array([1], type=pa.int32())}, + schema=pa.schema([('id', pa.int32())]), + ) + with self.assertRaises(ValueError) as ctx: + merge_into( + 
target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': 's.name'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('name', str(ctx.exception)) + + def test_partial_insert_auto_fills_on_key(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert={'name': 's.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + + def test_partial_insert_renamed_on_key_auto_filled(self): + target = self._create_table() + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_not_matched=[ + WhenNotMatched(insert={'name': 's.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + + def test_explicit_source_ref_not_remapped_by_on_key(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = 
pa.Table.from_pydict( + { + 'uid': pa.array([1], type=pa.int32()), + 'id': pa.array([42], type=pa.int32()), + 'name': ['new'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update={ + 'age': source_col('id'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['age'], [42]) + self.assertEqual(out['name'], ['old']) + + def test_renamed_on_key_missing_source_col_rejected(self): + target = self._create_table() + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=source_schema, + ) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update={ + 'id': source_col('id'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('id', str(ctx.exception)) + + def test_lit_prevents_column_ref_interpretation(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={ + 'name': lit('s.active'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['s.active']) + 
self.assertEqual(out['age'], [10]) + + def test_source_col_helper(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['new'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={ + 'age': source_col('age'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['old']) + self.assertEqual(out['age'], [99]) + + def test_target_col_helper(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['keep'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={ + 'age': source_col('age'), + 'name': target_col('name'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['keep']) + self.assertEqual(out['age'], [99]) + + +class TargetProjectionTest(unittest.TestCase): + + def _clause(self, spec, condition=None): + from pypaimon.ray import data_evolution_merge_into as m + return m._NormalizedClause(spec=spec, condition=condition) + + def test_unconditional_set_excludes_target_update_col(self): + from pypaimon.ray import data_evolution_merge_into as m + cols = m._resolve_target_projection( + 
[self._clause({'feature': 's.feature'})], + ['id'], ['feature'], ['id', 'feature', 'image'], + ) + self.assertEqual(['id'], cols) + + def test_condition_adds_referenced_target_cols(self): + from pypaimon.ray import data_evolution_merge_into as m + cols = m._resolve_target_projection( + [self._clause({'feature': 's.feature'}, condition='s.age > t.age')], + ['id'], ['feature'], ['id', 'feature', 'age', 'image'], + ) + self.assertIn('age', cols) + self.assertIn('id', cols) + + +class MergeConditionUnitTest(unittest.TestCase): + + def test_rewrite_condition(self): + from pypaimon.ray.merge_condition import rewrite_condition + self.assertEqual( + rewrite_condition('s.age > t.age + 10'), + '"s.age" > "t.age" + 10', + ) + + def test_rewrite_condition_preserves_string_literals(self): + from pypaimon.ray.merge_condition import rewrite_condition + self.assertEqual( + rewrite_condition("s.status = 't.active' AND s.age > t.age"), + '"s.status" = \'t.active\' AND "s.age" > "t.age"', + ) + + def test_remap_source_on_keys(self): + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = rewrite_condition('s.id > 1 AND s.age > t.age') + remapped = remap_source_on_keys(rewritten, {'id': 'id'}) + self.assertEqual(remapped, '"t.id" > 1 AND "s.age" > "t.age"') + + def test_remap_source_on_keys_renamed(self): + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = rewrite_condition('s.uid > 1') + remapped = remap_source_on_keys(rewritten, {'uid': 'id'}) + self.assertEqual(remapped, '"t.id" > 1') + + def test_remap_preserves_string_literals(self): + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = rewrite_condition("s.note = '\"s.id\"' AND s.id = 1") + remapped = remap_source_on_keys(rewritten, {'id': 'id'}) + self.assertEqual( + remapped, + '"s.note" = \'\"s.id\"\' AND "t.id" = 1', + ) + + def test_extract_target_columns(self): + from 
pypaimon.ray.merge_condition import extract_target_columns + self.assertEqual( + extract_target_columns('s.name = t.name AND s.age > t.age'), + {'name', 'age'}, + ) + + def test_extract_target_columns_ignores_string_literals(self): + from pypaimon.ray.merge_condition import extract_target_columns + self.assertEqual( + extract_target_columns("s.name = 't.fake' AND s.age > t.age"), + {'age'}, + ) + + def test_extract_columns(self): + from pypaimon.ray.merge_condition import extract_columns + self.assertEqual( + extract_columns('s.id = t.id AND s.age > t.age'), + {'s.id', 't.id', 's.age', 't.age'}, + ) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_filter_batch(self): + from pypaimon.ray.merge_condition import filter_batch + batch = pa.table({ + 's.id': pa.array([1, 2, 3], type=pa.int32()), + 's.age': pa.array([10, 25, 30], type=pa.int32()), + 't.age': pa.array([20, 20, 20], type=pa.int32()), + }) + result = filter_batch(batch, 's.age > t.age') + self.assertEqual(result.column('s.id').to_pylist(), [2, 3]) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/ray_integration_test.py b/paimon-python/pypaimon/tests/ray_integration_test.py index 10464f165339..225dc710d690 100644 --- a/paimon-python/pypaimon/tests/ray_integration_test.py +++ b/paimon-python/pypaimon/tests/ray_integration_test.py @@ -19,6 +19,7 @@ import shutil import tempfile import unittest +from unittest.mock import patch import pyarrow as pa import ray @@ -183,10 +184,13 @@ def test_read_paimon_with_limit(self): self.assertLess(limited_count, 10) def test_read_paimon_empty_table(self): - """read_paimon() on a table with no data returns an empty dataset.""" + """read_paimon() on an empty table preserves the table schema.""" from pypaimon.ray import read_paimon - pa_schema = pa.schema([('id', pa.int32())]) + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ]) identifier = 'default.test_read_empty' catalog = 
CatalogFactory.create(self.catalog_options) schema = Schema.from_pyarrow_schema(pa_schema) @@ -194,6 +198,40 @@ def test_read_paimon_empty_table(self): ds = read_paimon(identifier, self.catalog_options) self.assertEqual(ds.count(), 0) + self.assertEqual(ds.schema().names, ['id', 'name']) + + def test_read_paimon_empty_table_with_projection(self): + """read_paimon() applies projection to empty table schemas.""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('value', pa.int64()), + ]) + identifier = 'default.test_read_empty_projection' + catalog = CatalogFactory.create(self.catalog_options) + schema = Schema.from_pyarrow_schema(pa_schema) + catalog.create_table(identifier, schema, False) + + ds = read_paimon( + identifier, self.catalog_options, projection=['id', 'value'] + ) + self.assertEqual(ds.count(), 0) + self.assertEqual(ds.schema().names, ['id', 'value']) + + def test_missing_ray_dependency_has_install_hint(self): + """Ray facade surfaces an actionable install hint when Ray is absent.""" + from pypaimon.ray.ray_paimon import _require_ray_data + + error = ModuleNotFoundError("No module named 'ray'") + error.name = 'ray' + + with patch('importlib.import_module', side_effect=error): + with self.assertRaises(ImportError) as ctx: + _require_ray_data() + + self.assertIn('pip install pypaimon[ray]', str(ctx.exception)) def test_read_paimon_with_snapshot_id(self): """read_paimon(snapshot_id=N) time-travels to that snapshot.""" @@ -306,6 +344,82 @@ def test_write_paimon_overwrite(self): df = result.to_pandas() self.assertEqual(list(df['id']), [3]) + def test_write_paimon_empty_overwrite_unpartitioned(self): + """write_paimon(overwrite=True) with empty data clears an unpartitioned table.""" + from pypaimon.ray import read_paimon, write_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('val', pa.int64()), + ]) + identifier = 'default.test_write_empty_overwrite_unpartitioned' + catalog = 
CatalogFactory.create(self.catalog_options) + schema = Schema.from_pyarrow_schema(pa_schema) + catalog.create_table(identifier, schema, False) + + initial = ray.data.from_arrow( + pa.Table.from_pydict({'id': [1, 2], 'val': [10, 20]}, schema=pa_schema) + ) + write_paimon(initial, identifier, self.catalog_options) + self.assertEqual(read_paimon(identifier, self.catalog_options).count(), 2) + + empty = ray.data.from_arrow( + pa.Table.from_pydict({'id': [], 'val': []}, schema=pa_schema) + ) + write_paimon(empty, identifier, self.catalog_options, overwrite=True) + + result = read_paimon(identifier, self.catalog_options) + self.assertEqual(result.count(), 0) + + def test_table_write_ray_builder_partition_overwrite(self): + """Builder-level partition overwrite is honored by write_ray().""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('val', pa.string()), + ('dt', pa.string()), + ]) + identifier = 'default.test_write_ray_partition_overwrite' + catalog = CatalogFactory.create(self.catalog_options) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=['dt'], + options={'dynamic-partition-overwrite': 'false'}, + ) + catalog.create_table(identifier, schema, False) + table = catalog.get_table(identifier) + + initial = pa.Table.from_pydict( + { + 'id': [1, 2, 3], + 'val': ['old-p1-a', 'old-p1-b', 'old-p2'], + 'dt': ['p1', 'p1', 'p2'], + }, + schema=pa_schema, + ) + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(initial) + write_builder.new_commit().commit(writer.prepare_commit()) + writer.close() + + replacement = ray.data.from_arrow( + pa.Table.from_pydict( + {'id': [4], 'val': ['new-p1'], 'dt': ['p1']}, + schema=pa_schema, + ) + ) + writer = table.new_batch_write_builder().overwrite({'dt': 'p1'}).new_write() + writer.write_ray(replacement, concurrency=1) + writer.close() + + result = read_paimon(identifier, self.catalog_options) + df = 
result.to_pandas().sort_values('id').reset_index(drop=True) + self.assertEqual(list(df['id']), [3, 4]) + self.assertEqual(list(df['val']), ['old-p2', 'new-p1']) + self.assertEqual(list(df['dt']), ['p2', 'p1']) + def test_read_paimon_primary_key(self): """read_paimon() merges PK rows correctly after an upsert.""" from pypaimon.ray import read_paimon diff --git a/paimon-python/pypaimon/tests/ray_repartition_test.py b/paimon-python/pypaimon/tests/ray_repartition_test.py index b66b014b4f10..0d7dd568c0fc 100644 --- a/paimon-python/pypaimon/tests/ray_repartition_test.py +++ b/paimon-python/pypaimon/tests/ray_repartition_test.py @@ -16,21 +16,25 @@ # limitations under the License. ################################################################################ -"""End-to-end tests for HASH_FIXED auto-clustering on ``write_paimon``. +"""End-to-end tests for HASH_FIXED Ray writes. -For HASH_FIXED tables, ``write_paimon`` automatically pre-clusters rows -by ``(partition_keys..., bucket)`` (matching Spark/Flink). These tests -cover: +For append-only HASH_FIXED tables, ``write_paimon`` writes rows to the +correct bucket by default without pre-clustering. HASH_FIXED +primary-key tables fail fast unless the legacy ``map_groups`` mode is +explicitly selected. These tests cover: - * roundtrip correctness on a HASH_FIXED PK table. + * default roundtrip correctness on an append-only HASH_FIXED table. + * default fail-fast behaviour on a HASH_FIXED PK table. * roundtrip correctness on a partitioned HASH_FIXED PK table. - * the transient bucket column is stripped from the sink-visible - schema. - * the output is one file per (partition, bucket) — i.e. the - small-file storm is eliminated. + * explicit ``map_groups`` mode strips the transient bucket column + from the sink-visible schema. + * explicit ``map_groups`` mode can produce one file per + (partition, bucket) on the small test dataset. 
* regression: a table whose schema already contains a column named ``__paimon_bucket__`` still works (collision-safe column name). - * non-HASH_FIXED tables (BUCKET_UNAWARE etc.) pass through unchanged. + * non-HASH_FIXED append-only tables pass through unchanged. + * dynamic-bucket primary-key tables fail fast, while postpone-bucket + primary-key tables pass through. """ import glob @@ -114,19 +118,18 @@ def _count_data_files(self, table_name): )) return files - # ----- HASH_FIXED auto-clustering ----- + # ----- HASH_FIXED writes ----- - def test_fixed_bucket_roundtrip(self): + def test_append_only_fixed_bucket_roundtrip(self): from pypaimon.ray import write_paimon pa_schema = pa.schema([ - pa.field('id', pa.int32(), nullable=False), + ('id', pa.int32()), ('name', pa.string()), ]) - table_name = 'test_fixed_bucket_roundtrip' + table_name = 'test_append_only_fixed_bucket_roundtrip' identifier = self._make_table( - table_name, pa_schema, - primary_keys=['id'], options={'bucket': '4'}, + table_name, pa_schema, options={'bucket': '4'}, ) rows = pa.Table.from_pydict( @@ -141,6 +144,131 @@ def test_fixed_bucket_roundtrip(self): self.assertEqual(set(result['id']), set(range(40))) self.assertNotIn('__paimon_bucket__', result.columns) + def test_primary_key_fixed_bucket_default_fails_fast(self): + from pypaimon.ray import write_paimon + + pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('name', pa.string()), + ]) + table_name = 'test_pk_fixed_bucket_default_fails_fast' + identifier = self._make_table( + table_name, pa_schema, + primary_keys=['id'], options={'bucket': '4'}, + ) + + rows = pa.Table.from_pydict( + {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]}, + schema=pa_schema, + ) + ds = ray.data.from_arrow(rows).repartition(4) + + with self.assertRaisesRegex(ValueError, "HASH_FIXED primary-key"): + write_paimon(ds, identifier, self.catalog_options) + + def test_table_write_ray_primary_key_fixed_bucket_default_fails_fast(self): + 
pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('name', pa.string()), + ]) + table_name = 'test_table_write_ray_pk_default_fails_fast' + identifier = self._make_table( + table_name, pa_schema, + primary_keys=['id'], options={'bucket': '4'}, + ) + + rows = pa.Table.from_pydict( + {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]}, + schema=pa_schema, + ) + ds = ray.data.from_arrow(rows).repartition(4) + + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(identifier) + writer = table.new_batch_write_builder().new_write() + try: + with self.assertRaisesRegex(ValueError, "HASH_FIXED primary-key"): + writer.write_ray(ds) + finally: + writer.close() + + def test_primary_key_dynamic_bucket_default_fails_fast(self): + from pypaimon.ray import write_paimon + + pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('name', pa.string()), + ]) + table_name = 'test_pk_dynamic_bucket_default_fails_fast' + identifier = self._make_table( + table_name, pa_schema, primary_keys=['id'], + ) + + rows = pa.Table.from_pydict( + {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]}, + schema=pa_schema, + ) + ds = ray.data.from_arrow(rows).repartition(4) + + with self.assertRaisesRegex(ValueError, "HASH_DYNAMIC primary-key"): + write_paimon(ds, identifier, self.catalog_options) + + def test_table_write_ray_primary_key_dynamic_bucket_default_fails_fast(self): + pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('name', pa.string()), + ]) + table_name = 'test_table_write_ray_pk_dynamic_default_fails_fast' + identifier = self._make_table( + table_name, pa_schema, primary_keys=['id'], + ) + + rows = pa.Table.from_pydict( + {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]}, + schema=pa_schema, + ) + ds = ray.data.from_arrow(rows).repartition(4) + + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(identifier) + writer = 
table.new_batch_write_builder().new_write() + try: + with self.assertRaisesRegex(ValueError, "HASH_DYNAMIC primary-key"): + writer.write_ray(ds) + finally: + writer.close() + + def test_primary_key_postpone_bucket_roundtrip_to_postpone_files(self): + from pypaimon.ray import write_paimon + + pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('dt', pa.string()), + ('value', pa.int64()), + ]) + table_name = 'test_pk_postpone_bucket_ray_write' + identifier = self._make_table( + table_name, pa_schema, + primary_keys=['id', 'dt'], partition_keys=['dt'], + options={'bucket': '-2'}, + ) + + rows = pa.Table.from_pydict({ + 'id': list(range(10)), + 'dt': ['2026-01-01'] * 5 + ['2026-01-02'] * 5, + 'value': list(range(10)), + }, schema=pa_schema) + write_paimon( + ray.data.from_arrow(rows).repartition(2), + identifier, + self.catalog_options, + ) + + files = self._count_data_files(table_name) + self.assertGreater(len(files), 0) + self.assertTrue(all('/bucket-postpone/' in path for path in files)) + self.assertEqual(len(self._read_table(identifier)), 0) + def test_partitioned_fixed_bucket_roundtrip(self): """Partitioned table — confirms the post-groupby schema does not end up with duplicated partition-key or bucket columns.""" @@ -164,16 +292,50 @@ def test_partitioned_fixed_bucket_roundtrip(self): 'value': list(range(20)), }, schema=pa_schema) ds = ray.data.from_arrow(rows).repartition(4) - write_paimon(ds, identifier, self.catalog_options) + write_paimon( + ds, + identifier, + self.catalog_options, + hash_fixed_precluster="map_groups", + ) result = self._read_table(identifier) self.assertEqual(set(result.columns), {'id', 'dt', 'value'}) self.assertEqual(len(result), 20) self.assertEqual(set(result['dt']), {'2026-01-01', '2026-01-02'}) + def test_table_write_ray_primary_key_fixed_bucket_map_groups_roundtrip(self): + pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('name', pa.string()), + ]) + table_name = 
'test_table_write_ray_pk_map_groups' + identifier = self._make_table( + table_name, pa_schema, + primary_keys=['id'], options={'bucket': '4'}, + ) + + rows = pa.Table.from_pydict( + {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]}, + schema=pa_schema, + ) + ds = ray.data.from_arrow(rows).repartition(4) + + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(identifier) + writer = table.new_batch_write_builder().new_write() + try: + writer.write_ray(ds, hash_fixed_precluster="map_groups") + finally: + writer.close() + + result = self._read_table(identifier) + self.assertEqual(len(result), 40) + self.assertEqual(set(result['id']), set(range(40))) + def test_fixed_bucket_writes_one_file_per_bucket(self): - """With multiple input blocks, auto-clustering collapses per-task - files into per-bucket files.""" + """With multiple input blocks, explicit map_groups clustering + collapses per-task files into per-bucket files.""" from pypaimon.ray import write_paimon pa_schema = pa.schema([ @@ -190,11 +352,12 @@ def test_fixed_bucket_writes_one_file_per_bucket(self): primary_keys=['id'], options={'bucket': '4'}, ) - # Materialise 4 input blocks. Without auto-clustering each task - # would emit one file per bucket it touched (up to 16 files). + # Materialise 4 input blocks. Without the explicit map_groups + # mode, each task would emit one file per bucket it touched. 
write_paimon( ray.data.from_arrow(rows).repartition(4), identifier, self.catalog_options, + hash_fixed_precluster="map_groups", ) files = self._count_data_files('test_one_file_per_bucket') @@ -223,7 +386,12 @@ def test_fixed_bucket_with_colliding_column_name(self): schema=pa_schema, ) ds = ray.data.from_arrow(rows).repartition(2) - write_paimon(ds, identifier, self.catalog_options) + write_paimon( + ds, + identifier, + self.catalog_options, + hash_fixed_precluster="map_groups", + ) result = self._read_table(identifier) self.assertEqual(len(result), 10) diff --git a/paimon-python/pypaimon/tests/ray_sink_test.py b/paimon-python/pypaimon/tests/ray_sink_test.py index fab77ecf8713..a6d761df5add 100644 --- a/paimon-python/pypaimon/tests/ray_sink_test.py +++ b/paimon-python/pypaimon/tests/ray_sink_test.py @@ -26,6 +26,7 @@ from pypaimon import CatalogFactory, Schema from pypaimon.write.ray_datasink import PaimonDatasink from pypaimon.write.commit_message import CommitMessage +from pypaimon.write.table_write import TableWrite class RaySinkTest(unittest.TestCase): @@ -69,23 +70,34 @@ def test_init_and_serialization(self): datasink = PaimonDatasink(self.table, overwrite=False) self.assertEqual(datasink.table, self.table) self.assertFalse(datasink.overwrite) + self.assertIsNone(datasink.static_partition) self.assertIsNone(datasink._writer_builder) self.assertEqual(datasink._table_name, "test_db.test_table") datasink_overwrite = PaimonDatasink(self.table, overwrite=True) self.assertTrue(datasink_overwrite.overwrite) + datasink_partition_overwrite = PaimonDatasink( + self.table, static_partition={'dt': '2024-01-01'}) + self.assertFalse(datasink_partition_overwrite.overwrite) + self.assertEqual( + datasink_partition_overwrite.static_partition, + {'dt': '2024-01-01'}, + ) + # Test serialization datasink._writer_builder = Mock() state = datasink.__getstate__() self.assertIn('table', state) self.assertIn('overwrite', state) + self.assertIn('static_partition', state) 
self.assertIn('_writer_builder', state) new_datasink = PaimonDatasink.__new__(PaimonDatasink) new_datasink.__setstate__(state) self.assertEqual(new_datasink.table, self.table) self.assertFalse(new_datasink.overwrite) + self.assertIsNone(new_datasink.static_partition) def test_table_and_writer_builder_serializable(self): import pickle @@ -120,6 +132,29 @@ def test_table_and_writer_builder_serializable(self): except Exception as e: self.fail(f"Overwrite WriterBuilder is not serializable: {e}") + def test_write_builder_new_write_carries_static_partition(self): + batch_write = ( + self.table + .new_batch_write_builder() + .overwrite({'dt': '2024-01-01'}) + .new_write() + ) + try: + self.assertEqual(batch_write.static_partition, {'dt': '2024-01-01'}) + finally: + batch_write.close() + + stream_write = ( + self.table + .new_stream_write_builder() + .overwrite({'dt': '2024-01-01'}) + .new_write() + ) + try: + self.assertEqual(stream_write.static_partition, {'dt': '2024-01-01'}) + finally: + stream_write.close() + def test_on_write_start(self): """Test on_write_start with normal and overwrite modes.""" datasink = PaimonDatasink(self.table, overwrite=False) @@ -131,6 +166,14 @@ def test_on_write_start(self): datasink_overwrite.on_write_start() self.assertIsNotNone(datasink_overwrite._writer_builder.static_partition) + datasink_partition_overwrite = PaimonDatasink( + self.table, static_partition={'dt': '2024-01-01'}) + datasink_partition_overwrite.on_write_start() + self.assertEqual( + datasink_partition_overwrite._writer_builder.static_partition, + {'dt': '2024-01-01'}, + ) + def test_write(self): """Test write method: empty blocks, multiple blocks, error handling, and resource cleanup.""" datasink = PaimonDatasink(self.table, overwrite=False) @@ -189,6 +232,25 @@ def test_write(self): datasink.write([data_table], ctx) mock_builder.assert_called_once() + partition_datasink = PaimonDatasink( + self.table, static_partition={'dt': '2024-01-01'}) + with patch.object(self.table, 
'new_batch_write_builder') as mock_builder: + mock_write_builder = Mock() + mock_write_builder.overwrite.return_value = mock_write_builder + mock_write = Mock() + mock_write.prepare_commit.return_value = [] + mock_write_builder.new_write.return_value = mock_write + mock_builder.return_value = mock_write_builder + + data_table = pa.table({ + 'id': [1], + 'name': ['Alice'], + 'value': [1.1] + }) + partition_datasink.write([data_table], ctx) + mock_write_builder.overwrite.assert_called_once_with( + {'dt': '2024-01-01'}) + invalid_table = pa.table({ 'wrong_column': [1, 2, 3] }) @@ -225,6 +287,36 @@ def test_on_write_complete(self): ) datasink.on_write_complete(write_result) + # Empty overwrite must still reach TableCommit so overwrite semantics + # can delete the target range. + datasink = PaimonDatasink(self.table, overwrite=True) + datasink.on_write_start() + write_result = WriteResult( + num_rows=0, + size_bytes=0, + write_returns=[[], []] + ) + mock_commit = Mock() + datasink._writer_builder.new_commit = Mock(return_value=mock_commit) + datasink.on_write_complete(write_result) + + mock_commit.commit.assert_called_once_with([]) + mock_commit.close.assert_called_once() + + datasink = PaimonDatasink(self.table, static_partition={'dt': '2024-01-01'}) + datasink.on_write_start() + write_result = WriteResult( + num_rows=0, + size_bytes=0, + write_returns=[[], []] + ) + mock_commit = Mock() + datasink._writer_builder.new_commit = Mock(return_value=mock_commit) + datasink.on_write_complete(write_result) + + mock_commit.commit.assert_called_once_with([]) + mock_commit.close.assert_called_once() + # Test with messages and filtering empty messages datasink = PaimonDatasink(self.table, overwrite=False) datasink.on_write_start() @@ -292,6 +384,52 @@ def test_on_write_complete(self): datasink.on_write_complete(write_result) self.assertEqual(len(datasink._pending_commit_messages), 1) + def test_table_write_ray_forwards_static_partition(self): + dataset = Mock() + table_write = 
TableWrite.__new__(TableWrite) + table_write.table = self.table + table_write.static_partition = {'dt': '2024-01-01'} + + with patch('pypaimon.ray.shuffle.maybe_apply_repartition') as mock_repartition, \ + patch('pypaimon.write.ray_datasink.PaimonDatasink') as mock_datasink_cls: + mock_repartition.return_value = dataset + datasink = mock_datasink_cls.return_value + + table_write.write_ray(dataset, concurrency=2) + + mock_repartition.assert_called_once_with(dataset, self.table, 'auto') + mock_datasink_cls.assert_called_once_with( + self.table, + overwrite=False, + static_partition={'dt': '2024-01-01'}, + ) + dataset.write_datasink.assert_called_once_with( + datasink, + concurrency=2, + ray_remote_args=None, + ) + + def test_table_write_ray_static_partition_argument_overrides_builder(self): + dataset = Mock() + table_write = TableWrite.__new__(TableWrite) + table_write.table = self.table + table_write.static_partition = {'dt': '2024-01-01'} + + with patch('pypaimon.ray.shuffle.maybe_apply_repartition') as mock_repartition, \ + patch('pypaimon.write.ray_datasink.PaimonDatasink') as mock_datasink_cls: + mock_repartition.return_value = dataset + + table_write.write_ray( + dataset, + static_partition={'dt': '2024-01-02'}, + ) + + mock_datasink_cls.assert_called_once_with( + self.table, + overwrite=False, + static_partition={'dt': '2024-01-02'}, + ) + def test_on_write_failed(self): # Test without pending messages (on_write_complete() never called) datasink = PaimonDatasink(self.table, overwrite=False) diff --git a/paimon-python/pypaimon/tests/reader_base_test.py b/paimon-python/pypaimon/tests/reader_base_test.py index 31d8205f8802..12875cabafa5 100644 --- a/paimon-python/pypaimon/tests/reader_base_test.py +++ b/paimon-python/pypaimon/tests/reader_base_test.py @@ -286,7 +286,12 @@ def test_full_data_types(self): # assert equal actual_data = table_read.to_arrow(splits) - self.assertEqual(actual_data, expect_data) + # BINARY(N) maps to variable-length binary on read (see 
#7518), so the + # fixed-size f9 column normalizes to binary; reflect that in the expected. + f9_index = expect_data.schema.get_field_index('f9') + expected_data = expect_data.set_column( + f9_index, 'f9', expect_data.column('f9').cast(pa.binary())) + self.assertEqual(actual_data, expected_data) # to test GenericRow ability latest_snapshot = table.snapshot_manager().get_latest_snapshot() diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py index 89a9e8a6a144..7cae0a77c7f4 100644 --- a/paimon-python/pypaimon/tests/reader_primary_key_test.py +++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py @@ -224,12 +224,15 @@ def test_pk_multi_write_once_commit(self): read_builder = table.new_read_builder() actual = self._read_test_table(read_builder).sort_by('user_id') - # TODO support pk merge feature when multiple write + # The in-memory merge buffer in KeyValueDataWriter folds the + # two writes for user_id=2 down to the latest row before flush + # (default merge engine is deduplicate), so the PK appears once + # with the second batch's value. 
expected = pa.Table.from_pydict({ - 'user_id': [1, 2, 2, 3, 4, 5, 7, 8], - 'item_id': [1001, 1002, 1002, 1003, 1004, 1005, 1007, 1008], - 'behavior': ['a', 'b', 'b-new', 'c', None, 'e', 'g', 'h'], - 'dt': ['p1', 'p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'], + 'user_id': [1, 2, 3, 4, 5, 7, 8], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1007, 1008], + 'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'], }, schema=self.pa_schema) self.assertEqual(actual, expected) diff --git a/paimon-python/pypaimon/tests/schema_evolution_read_test.py b/paimon-python/pypaimon/tests/schema_evolution_read_test.py index a50b8ac44b0e..0bc5c9f142d2 100644 --- a/paimon-python/pypaimon/tests/schema_evolution_read_test.py +++ b/paimon-python/pypaimon/tests/schema_evolution_read_test.py @@ -25,6 +25,8 @@ from pypaimon import CatalogFactory, Schema +from pypaimon.schema.data_types import AtomicType +from pypaimon.schema.schema_change import SchemaChange from pypaimon.schema.schema_manager import SchemaManager from pypaimon.schema.table_schema import TableSchema @@ -202,6 +204,155 @@ def test_schema_evolution_type(self): }, schema=pa_schema) self.assertEqual(expected, actual) + def test_schema_evolution_type_promotion_unpartitioned(self): + # End-to-end via public API only (create -> write -> alter column type + # -> write -> read). A non-partitioned table whose read needs no column + # reordering takes the reader fast path that skips partition padding and + # index remapping. The file written before the type change keeps its old + # physical type, so it must be aligned to the promoted type to + # concatenate with the file written after; otherwise the read crashes + # with an Arrow schema mismatch. This is not specific to INT -> BIGINT: + # it applies to any type change of an existing column -- integer/float + # widening, DECIMAL precision/scale changes, and cross-type changes. 
+ import decimal + + # Each case: (name, old arrow type, new arrow type, new Paimon type, + # value written to the old-schema file, that same value as it should + # read back under the new type, value written to the new-schema file). + # The old write value and its expected read form differ for cross-type + # changes, where the old file is materialized under the new type. + cases = [ + ("smallint_to_int", pa.int16(), pa.int32(), 'INT', + [10, 20], [10, 20], [30, 40]), + ("int_to_bigint", pa.int32(), pa.int64(), 'BIGINT', + [10, 20], [10, 20], [30, 40]), + ("float_to_double", pa.float32(), pa.float64(), 'DOUBLE', + [1.5, 2.5], [1.5, 2.5], [3.5, 4.5]), + ("decimal_precision_up", + pa.decimal128(10, 2), pa.decimal128(20, 2), 'DECIMAL(20, 2)', + [decimal.Decimal('1.23'), decimal.Decimal('4.56')], + [decimal.Decimal('1.23'), decimal.Decimal('4.56')], + [decimal.Decimal('7.89'), decimal.Decimal('0.12')]), + ("decimal_scale_up", + pa.decimal128(10, 2), pa.decimal128(10, 4), 'DECIMAL(10, 4)', + [decimal.Decimal('1.23'), decimal.Decimal('4.56')], + [decimal.Decimal('1.2300'), decimal.Decimal('4.5600')], + [decimal.Decimal('7.8901'), decimal.Decimal('0.1234')]), + ("int_to_string", pa.int32(), pa.string(), 'STRING', + [10, 20], ['10', '20'], ['a', 'b']), + # Lossy cross-type change: DOUBLE -> INT truncates (matches Java + # CastExecutors), so 1.2/2.8 read back as 1/2. + ("double_to_int", pa.float64(), pa.int32(), 'INT', + [1.2, 2.8], [1, 2], [3, 4]), + # Lossy DECIMAL scale-down: (10,4) -> (10,2) truncates the extra + # scale rather than raising. 
+ ("decimal_scale_down", + pa.decimal128(10, 4), pa.decimal128(10, 2), 'DECIMAL(10, 2)', + [decimal.Decimal('1.2345'), decimal.Decimal('4.5678')], + [decimal.Decimal('1.23'), decimal.Decimal('4.56')], + [decimal.Decimal('7.89'), decimal.Decimal('0.12')]), + ] + + for (name, old_type, new_type, new_type_str, + write_vals, old_read_vals, new_vals) in cases: + with self.subTest(case=name): + table_name = f'default.promo_{name}' + old_schema = pa.schema([('k', pa.int64()), ('v', old_type)]) + self.catalog.create_table( + table_name, Schema.from_pyarrow_schema(old_schema), False) + + # Write under the original schema (file stamped schema_id 0). + table = self.catalog.get_table(table_name) + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(pa.Table.from_pydict( + {'k': [1, 2], 'v': write_vals}, schema=old_schema)) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Widen column v through the catalog (new schema_id 1). + self.catalog.alter_table( + table_name, + [SchemaChange.update_column_type( + 'v', AtomicType(new_type_str))], + False) + + # Write under the promoted schema (file stamped schema_id 1). + table = self.catalog.get_table(table_name) + new_schema = pa.schema([('k', pa.int64()), ('v', new_type)]) + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(pa.Table.from_pydict( + {'k': [3, 4], 'v': new_vals}, schema=new_schema)) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Plain full-table read spanning both schema versions. 
+ read_builder = table.new_read_builder() + actual = read_builder.new_read().to_arrow( + self._scan_table(read_builder)) + expected = pa.Table.from_pydict( + {'k': [1, 2, 3, 4], 'v': old_read_vals + new_vals}, + schema=new_schema) + self.assertEqual(expected, actual) + + def test_schema_evolution_type_lossy_old_file_only(self): + # Reading ONLY old-schema files after a lossy type change (no + # newer-schema file in the splits). The output type must equal the + # current read schema regardless of which files the read spans, and the + # conversion must truncate to match Java CastExecutors rather than + # raise. (A previous fix that relied on pyarrow's safe cast crashed + # here on lossy evolutions.) + import decimal + + cases = [ + ("scale_down", + pa.decimal128(10, 4), pa.decimal128(10, 2), 'DECIMAL(10, 2)', + [decimal.Decimal('1.2345'), decimal.Decimal('4.5678')], + [decimal.Decimal('1.23'), decimal.Decimal('4.56')]), + ("double_to_int", pa.float64(), pa.int32(), 'INT', + [1.2, 2.8], [1, 2]), + ] + + for name, old_type, new_type, new_type_str, write_vals, read_vals \ + in cases: + with self.subTest(case=name): + table_name = f'default.lossy_old_only_{name}' + old_schema = pa.schema([('k', pa.int64()), ('v', old_type)]) + self.catalog.create_table( + table_name, Schema.from_pyarrow_schema(old_schema), False) + + # Write under the original schema, then change the type. No + # write happens afterwards, so the read sees only this file. 
+ table = self.catalog.get_table(table_name) + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(pa.Table.from_pydict( + {'k': [1, 2], 'v': write_vals}, schema=old_schema)) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + self.catalog.alter_table( + table_name, + [SchemaChange.update_column_type( + 'v', AtomicType(new_type_str))], + False) + + table = self.catalog.get_table(table_name) + new_schema = pa.schema([('k', pa.int64()), ('v', new_type)]) + read_builder = table.new_read_builder() + actual = read_builder.new_read().to_arrow( + self._scan_table(read_builder)) + expected = pa.Table.from_pydict( + {'k': [1, 2], 'v': read_vals}, schema=new_schema) + self.assertEqual(expected, actual) + def test_schema_evolution_with_scan_filter(self): # schema 0 pa_schema = pa.schema([ diff --git a/paimon-python/pypaimon/tests/system/buckets_table_test.py b/paimon-python/pypaimon/tests/system/buckets_table_test.py new file mode 100644 index 000000000000..57a352392a24 --- /dev/null +++ b/paimon-python/pypaimon/tests/system/buckets_table_test.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""End-to-end tests for the ``$buckets`` system table.""" + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.schema.data_types import DataField +from pypaimon.table.system.buckets_table import BucketsTable + + +def _read(table): + rb = table.new_read_builder() + return rb.new_read().to_arrow(rb.new_scan().plan().splits()) + + +class BucketsTableTest(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp(prefix="buckets_sys_") + warehouse = os.path.join(self.tmp, "warehouse") + self.catalog = CatalogFactory.create({"warehouse": warehouse}) + self.catalog.create_database("db", False) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _create_partitioned_table(self, num_buckets=2): + fields = [ + DataField.from_dict({"id": 0, "name": "id", "type": "INT"}), + DataField.from_dict({"id": 1, "name": "v", "type": "STRING"}), + DataField.from_dict({"id": 2, "name": "dt", "type": "STRING"}), + ] + self.catalog.create_table( + "db.t", + Schema( + fields=fields, + partition_keys=["dt"], + options={"bucket": str(num_buckets)}, + ), + False, + ) + + def _write_data(self): + table = self.catalog.get_table("db.t") + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + commit = write_builder.new_commit() + writer.write_arrow(pa.table({ + "id": pa.array([1, 2, 3, 4], type=pa.int32()), + "v": ["a", "b", "c", "d"], + "dt": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"], + })) + commit.commit(writer.prepare_commit()) + writer.close() + commit.close() + + def test_buckets_table_loaded_via_catalog(self): + self._create_partitioned_table() + table = self.catalog.get_table("db.t$buckets") + self.assertIsInstance(table, BucketsTable) + + def test_schema_column_layout(self): + 
self._create_partitioned_table() + table = self.catalog.get_table("db.t$buckets") + row_type = table.row_type() + expected = [ + ("partition", True), ("bucket", False), + ("record_count", False), ("file_size_in_bytes", False), + ("file_count", False), ("last_update_time", True), + ] + self.assertEqual([n for n, _ in expected], + [f.name for f in row_type.fields]) + for field, (_, expected_nullable) in zip(row_type.fields, expected): + self.assertEqual(expected_nullable, field.type.nullable, + "field {} nullability".format(field.name)) + self.assertEqual(["partition", "bucket"], table.primary_keys()) + + def test_empty_when_no_snapshot_exists(self): + self._create_partitioned_table() + arrow_table = _read(self.catalog.get_table("db.t$buckets")) + self.assertEqual(0, arrow_table.num_rows) + + def test_aggregates_by_partition_and_bucket(self): + self._create_partitioned_table(num_buckets=2) + self._write_data() + + arrow_table = _read(self.catalog.get_table("db.t$buckets")) + self.assertGreater(arrow_table.num_rows, 0) + + partitions = arrow_table.column("partition").to_pylist() + self.assertTrue(all(p in ("dt=2024-01-01", "dt=2024-01-02") + for p in partitions)) + + for size in arrow_table.column("file_size_in_bytes").to_pylist(): + self.assertGreater(size, 0) + for count in arrow_table.column("file_count").to_pylist(): + self.assertGreaterEqual(count, 1) + + total_records = sum(arrow_table.column("record_count").to_pylist()) + self.assertEqual(4, total_records) + + def test_rows_sorted_by_partition_then_bucket(self): + self._create_partitioned_table(num_buckets=2) + self._write_data() + + arrow_table = _read(self.catalog.get_table("db.t$buckets")) + partitions = arrow_table.column("partition").to_pylist() + buckets = arrow_table.column("bucket").to_pylist() + + rows = list(zip(partitions, buckets)) + self.assertEqual(rows, sorted(rows)) + + +if __name__ == "__main__": + unittest.main() diff --git a/paimon-python/pypaimon/tests/system/system_table_loader_test.py 
b/paimon-python/pypaimon/tests/system/system_table_loader_test.py index 1c8690a2bcb8..07b959a2c6c2 100644 --- a/paimon-python/pypaimon/tests/system/system_table_loader_test.py +++ b/paimon-python/pypaimon/tests/system/system_table_loader_test.py @@ -28,6 +28,7 @@ "manifests", "files", "partitions", + "buckets", "tags", "branches", ) @@ -41,7 +42,6 @@ "consumers", "statistics", "aggregation_fields", - "buckets", "file_key_ranges", "table_indexes", "row_tracking", diff --git a/paimon-python/pypaimon/tests/table_scan_mode_test.py b/paimon-python/pypaimon/tests/table_scan_mode_test.py new file mode 100644 index 000000000000..c33fdf9beb40 --- /dev/null +++ b/paimon-python/pypaimon/tests/table_scan_mode_test.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import unittest +import warnings +from unittest.mock import Mock + +from pypaimon.common.options.core_options import CoreOptions, StartupMode +from pypaimon.read.table_scan import TableScan + + +def _scan(options): + scan = TableScan.__new__(TableScan) + scan.table = Mock() + scan.table.options = CoreOptions.from_dict(options) + return scan + + +class TableScanModeTest(unittest.TestCase): + + def test_from_timestamp_requires_timestamp_option(self): + scan = _scan({ + CoreOptions.SCAN_MODE.key(): StartupMode.FROM_TIMESTAMP.value, + }) + + with self.assertRaisesRegex( + ValueError, + "neither scan.timestamp-millis nor scan.timestamp is set"): + scan._validate_scan_mode() + + def test_latest_conflicts_with_snapshot_id(self): + scan = _scan({ + CoreOptions.SCAN_MODE.key(): StartupMode.LATEST.value, + CoreOptions.SCAN_SNAPSHOT_ID.key(): "1", + }) + + with self.assertRaisesRegex(ValueError, "scan.snapshot-id"): + scan._validate_scan_mode() + + def test_default_with_timestamp_millis_resolves_to_from_timestamp(self): + options = CoreOptions.from_dict({ + CoreOptions.SCAN_MODE.key(): StartupMode.DEFAULT.value, + CoreOptions.SCAN_TIMESTAMP_MILLIS.key(): "123", + }) + + self.assertEqual(options.startup_mode(), StartupMode.FROM_TIMESTAMP) + _scan(options.options.to_map())._validate_scan_mode() + + def test_default_with_snapshot_id_resolves_to_from_snapshot(self): + options = CoreOptions.from_dict({ + CoreOptions.SCAN_MODE.key(): StartupMode.DEFAULT.value, + CoreOptions.SCAN_SNAPSHOT_ID.key(): "1", + }) + + self.assertEqual(options.startup_mode(), StartupMode.FROM_SNAPSHOT) + _scan(options.options.to_map())._validate_scan_mode() + + def test_unsupported_scan_modes_raise_value_error(self): + scan = _scan({ + CoreOptions.SCAN_MODE.key(): StartupMode.COMPACTED_FULL.value, + }) + + with self.assertRaisesRegex(ValueError, "not yet supported"): + scan._validate_scan_mode() + + def test_full_mode_maps_to_latest_full_with_deprecation_warning(self): + options = 
CoreOptions.from_dict({ + CoreOptions.SCAN_MODE.key(): StartupMode.FULL.value, + }) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + self.assertEqual(options.startup_mode(), StartupMode.LATEST_FULL) + + self.assertEqual(len(caught), 1) + self.assertTrue(issubclass(caught[0].category, DeprecationWarning)) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/table_update_test.py b/paimon-python/pypaimon/tests/table_update_test.py index 53a84666d17f..57ae605703a4 100644 --- a/paimon-python/pypaimon/tests/table_update_test.py +++ b/paimon-python/pypaimon/tests/table_update_test.py @@ -123,6 +123,45 @@ def test_update_multiple_columns(self): result['city'].to_pylist(), ) + def test_update_columns_fall_back_to_data_when_unset(self): + table = self._create_seeded_table() + + self._do_update(table, pa.Table.from_pydict({ + '_ROW_ID': [0, 1, 2, 3, 4], + 'id': [1, 2, 3, 4, 5], + 'name': ['A', 'B', 'C', 'D', 'E'], + 'age': [1, 2, 3, 4, 5], + 'city': ['c0', 'c1', 'c2', 'c3', 'c4'], + }), ['id', 'name', 'age', 'city']) + result = self._read_all(table) + self.assertEqual(['A', 'B', 'C', 'D', 'E'], result['name'].to_pylist()) + self.assertEqual([1, 2, 3, 4, 5], result['age'].to_pylist()) + self.assertEqual(['c0', 'c1', 'c2', 'c3', 'c4'], result['city'].to_pylist()) + + wb = self._make_write_builder(table) + tu = wb.new_update() + cid = self._next_commit_id() + msgs = self._apply_update(tu, pa.Table.from_pydict({ + '_ROW_ID': [0, 1], + 'age': [99, 98], + }), cid) + tc = wb.new_commit() + self._apply_commit(tc, msgs, cid) + tc.close() + result = self._read_all(table) + self.assertEqual([99, 98, 3, 4, 5], result['age'].to_pylist()) + self.assertEqual(['A', 'B', 'C', 'D', 'E'], result['name'].to_pylist()) + + def test_update_with_only_row_id_raises(self): + table = self._create_seeded_table() + wb = self._make_write_builder(table) + tu = wb.new_update() + cid = self._next_commit_id() + with 
self.assertRaises(ValueError): + self._apply_update(tu, pa.Table.from_pydict({ + '_ROW_ID': [0, 1], + }), cid) + def test_partitioned_table_update(self): """Updates work on a partitioned table the same as a flat one.""" table = self._create_table(partition_keys=['city']) @@ -388,7 +427,7 @@ def test_missing_row_id_column_raises(self): self.assertIn('_ROW_ID column', str(ctx.exception)) def test_invalid_row_id_raises(self): - """row_id outside [0, total_row_count) (both directions) raises.""" + """row_id outside valid row_id ranges raises.""" table = self._create_seeded_table() cases = [ ('out_of_range_high', [0, 10], [26, 100]), @@ -401,7 +440,7 @@ def test_invalid_row_id_raises(self): bad = pa.Table.from_pydict({'_ROW_ID': row_ids, 'age': ages}) with self.assertRaises(ValueError) as ctx: self._apply_update(tu, bad, self._next_commit_id()) - self.assertIn('out of valid range', str(ctx.exception)) + self.assertIn('does not belong to any valid range', str(ctx.exception)) def test_duplicate_row_id_raises(self): table = self._create_seeded_table() @@ -418,6 +457,46 @@ def test_duplicate_row_id_raises(self): ) self.assertIn('duplicate _ROW_ID', str(ctx.exception)) + def test_update_deleted_row_id_raises(self): + """Updating a row_id that fell into a hole after truncate raises.""" + partitioned_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ('region', pa.string()), + ]) + table = self._create_table( + pa_schema=partitioned_schema, + partition_keys=['region'], + ) + self._write_arrow(table, pa.Table.from_pydict({ + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['A', 'B', 'C'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + 'region': ['US', 'US', 'US'], + }, schema=partitioned_schema)) + + self._write_arrow(table, pa.Table.from_pydict({ + 'id': pa.array([4, 5], type=pa.int32()), + 'name': ['D', 'E'], + 'age': pa.array([40, 50], type=pa.int32()), + 'region': ['EU', 'EU'], + }, schema=partitioned_schema)) + + wb = 
table.new_batch_write_builder() + tc = wb.new_commit() + tc.truncate_partitions([{'region': 'US'}]) + + wb = self._make_write_builder(table) + tu = wb.new_update().with_update_type(['age']) + with self.assertRaises(ValueError) as ctx: + self._apply_update( + tu, + pa.Table.from_pydict({'_ROW_ID': [0], 'age': [99]}), + self._next_commit_id(), + ) + self.assertIn('does not belong to any valid range', str(ctx.exception)) + # ------------------------------------------------------------------ # Concurrency tests # ------------------------------------------------------------------ diff --git a/paimon-python/pypaimon/tests/table_upsert_by_key_test.py b/paimon-python/pypaimon/tests/table_upsert_by_key_test.py index 1d85ac0b99ec..f276e44ab8ea 100644 --- a/paimon-python/pypaimon/tests/table_upsert_by_key_test.py +++ b/paimon-python/pypaimon/tests/table_upsert_by_key_test.py @@ -436,6 +436,48 @@ def test_partitioned_update_cols_with_new_rows(self): self.assertEqual('Carol', names[idx3]) self.assertEqual('US', regions[idx3]) + def test_upsert_after_truncate_partition(self): + table = self._create_table( + pa_schema=self.partitioned_pa_schema, + partition_keys=['region'], + ) + self._write_arrow(table, pa.Table.from_pydict({ + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['A', 'B', 'C'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + 'region': ['US', 'US', 'US'], + }, schema=self.partitioned_pa_schema)) + + self._write_arrow(table, pa.Table.from_pydict({ + 'id': pa.array([4, 5], type=pa.int32()), + 'name': ['D', 'E'], + 'age': pa.array([40, 50], type=pa.int32()), + 'region': ['EU', 'EU'], + }, schema=self.partitioned_pa_schema)) + + wb = table.new_batch_write_builder() + tc = wb.new_commit() + tc.truncate_partitions([{'region': 'US'}]) + + upsert_data = pa.Table.from_pydict({ + 'id': pa.array([4], type=pa.int32()), + 'name': ['D_v2'], + 'age': pa.array([41], type=pa.int32()), + 'region': ['EU'], + }, schema=self.partitioned_pa_schema) + self._upsert(table, 
upsert_data, upsert_keys=['id']) + + result = self._read_all(table) + self.assertEqual(2, result.num_rows) + rows = sorted(zip( + result['id'].to_pylist(), + result['name'].to_pylist(), + result['age'].to_pylist(), + result['region'].to_pylist(), + )) + self.assertEqual((4, 'D_v2', 41, 'EU'), rows[0]) + self.assertEqual((5, 'E', 50, 'EU'), rows[1]) + # ================================================================== # update_cols partial update (non-partitioned) # ================================================================== diff --git a/paimon-python/pypaimon/tests/test_aggregation_e2e.py b/paimon-python/pypaimon/tests/test_aggregation_e2e.py new file mode 100644 index 000000000000..0d557799e9c2 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_aggregation_e2e.py @@ -0,0 +1,375 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""End-to-end tests for the ``aggregation`` merge engine. + +Each test creates a PK table with ``merge-engine=aggregation`` plus +per-field aggregator configuration, writes two or more commits against +the same PK, and reads back. 
The aggregation engine must reduce each +non-PK column independently using the configured aggregator (sum / max +/ last_value / ...). Disjoint PKs must remain unmerged. Default +behaviour when no aggregator is configured is ``last_non_null_value``. + +The second half of the file exercises the merge-engine-support guard: +tables that configure aggregation with options pypaimon does not yet +implement (retract opt-ins, sequence-group, out-of-scope aggregator +identifiers) must raise ``NotImplementedError`` at TableRead +construction rather than silently fall back to a wrong answer. +""" + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema + + +class AggregationMergeEngineE2ETest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('total', pa.int64()), + ('max_score', pa.int64()), + ('label', pa.string()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_pk_table(self, table_name, field_aggs=None, + default_agg=None, extra_options=None): + # bucket=1 forces all rows for a given PK to land in the same + # bucket, which routes reads through SortMergeReader where the + # aggregation merge function lives. Without it, fresh + # single-snapshot tables take the raw_convertible fast path and + # bypass the merge function entirely. 
+ options = { + 'bucket': '1', + 'merge-engine': 'aggregation', + } + if field_aggs: + for field_name, agg_func in field_aggs.items(): + options['fields.{}.aggregate-function'.format(field_name)] = agg_func + if default_agg: + options['fields.default-aggregate-function'] = default_agg + if extra_options: + options.update(extra_options) + schema = Schema.from_pyarrow_schema( + self.pa_schema, + primary_keys=['id'], + options=options, + ) + full = 'default.{}'.format(table_name) + self.catalog.create_table(full, schema, False) + return self.catalog.get_table(full) + + def _write(self, table, rows): + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=self.pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + def _read(self, table): + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + if not splits: + return [] + return sorted( + rb.new_read().to_arrow(splits).to_pylist(), + key=lambda r: r['id'], + ) + + # -- aggregation happy path ----------------------------------------- + + def test_sum_aggregator_across_commits(self): + table = self._create_pk_table( + 'agg_sum', + field_aggs={'total': 'sum'}, + ) + self._write(table, [{'id': 1, 'total': 10, 'max_score': 5, 'label': 'a'}]) + self._write(table, [{'id': 1, 'total': 20, 'max_score': 3, 'label': 'b'}]) + self._write(table, [{'id': 1, 'total': 30, 'max_score': 8, 'label': 'c'}]) + + rows = self._read(table) + self.assertEqual(len(rows), 1) + row = rows[0] + self.assertEqual(row['id'], 1) + # total: 10 + 20 + 30 = 60 + self.assertEqual(row['total'], 60) + # max_score and label have no aggregator configured → default + # last_non_null_value: latest non-null wins. 
+ self.assertEqual(row['max_score'], 8) + self.assertEqual(row['label'], 'c') + + def test_multiple_aggregators_compose(self): + table = self._create_pk_table( + 'agg_multi', + field_aggs={ + 'total': 'sum', + 'max_score': 'max', + 'label': 'last_value', + }, + ) + self._write(table, [{'id': 1, 'total': 10, 'max_score': 5, 'label': 'a'}]) + self._write(table, [{'id': 1, 'total': 7, 'max_score': 12, 'label': 'b'}]) + self._write(table, [{'id': 1, 'total': 3, 'max_score': 1, 'label': 'c'}]) + + row = self._read(table)[0] + self.assertEqual(row['total'], 20) # sum: 10+7+3 + self.assertEqual(row['max_score'], 12) # max: max(5,12,1) + self.assertEqual(row['label'], 'c') # last_value + + def test_null_inputs_follow_aggregator_semantics(self): + table = self._create_pk_table( + 'agg_nulls', + field_aggs={ + 'total': 'sum', + 'max_score': 'last_value', + }, + ) + self._write(table, [{'id': 1, 'total': 5, 'max_score': 7, 'label': 'x'}]) + # null total is absorbed by sum; null max_score replaces under + # last_value (last_value keeps the last input verbatim, + # including None). + self._write(table, [{'id': 1, 'total': None, 'max_score': None, 'label': None}]) + self._write(table, [{'id': 1, 'total': 4, 'max_score': 9, 'label': 'y'}]) + + row = self._read(table)[0] + self.assertEqual(row['total'], 9) # 5 + 4 (None absorbed) + self.assertEqual(row['max_score'], 9) # last_value's last input + # label: default last_non_null_value, intermediate None ignored, + # the final 'y' wins. + self.assertEqual(row['label'], 'y') + + def test_disjoint_keys_remain_unmerged(self): + table = self._create_pk_table( + 'agg_disjoint', + field_aggs={'total': 'sum'}, + ) + self._write(table, [ + {'id': 1, 'total': 10, 'max_score': 1, 'label': 'a'}, + {'id': 2, 'total': 20, 'max_score': 2, 'label': 'b'}, + {'id': 3, 'total': 30, 'max_score': 3, 'label': 'c'}, + ]) + # Second commit only touches id=2. 
+ self._write(table, [{'id': 2, 'total': 5, 'max_score': 7, 'label': 'B'}]) + + rows = self._read(table) + self.assertEqual(rows, [ + {'id': 1, 'total': 10, 'max_score': 1, 'label': 'a'}, + {'id': 2, 'total': 25, 'max_score': 7, 'label': 'B'}, + {'id': 3, 'total': 30, 'max_score': 3, 'label': 'c'}, + ]) + + def test_default_aggregator_applies_to_unconfigured_fields(self): + table = self._create_pk_table( + 'agg_default', + default_agg='max', + ) + self._write(table, [{'id': 1, 'total': 3, 'max_score': 5, 'label': 'm'}]) + self._write(table, [{'id': 1, 'total': 7, 'max_score': 2, 'label': 'a'}]) + self._write(table, [{'id': 1, 'total': 1, 'max_score': 9, 'label': 'z'}]) + + row = self._read(table)[0] + # All non-PK fields fall through to fields.default-aggregate-function=max. + self.assertEqual(row['total'], 7) + self.assertEqual(row['max_score'], 9) + self.assertEqual(row['label'], 'z') # 'z' > 'm' > 'a' lexicographically + + def test_default_behavior_is_last_non_null_value(self): + # No field-level or default aggregator configured → every non-PK + # field uses the system default last_non_null_value. + table = self._create_pk_table('agg_implicit_default') + self._write(table, [{'id': 1, 'total': 5, 'max_score': 9, 'label': 'a'}]) + self._write(table, [{'id': 1, 'total': None, 'max_score': 3, 'label': None}]) + self._write(table, [{'id': 1, 'total': 7, 'max_score': None, 'label': 'b'}]) + + row = self._read(table)[0] + self.assertEqual(row['total'], 7) # latest non-null + self.assertEqual(row['max_score'], 3) # latest non-null + self.assertEqual(row['label'], 'b') # latest non-null + + # -- unsupported-option guards -------------------------------------- + # + # Tables that opt into behaviour AggregateMergeFunction doesn't + # implement must surface a NotImplementedError at TableRead + # construction, not silently produce wrong results. 
+ + def _create_and_expect_unsupported(self, table_name, extra_options, + expected_substring, + error_type=NotImplementedError): + table = self._create_pk_table( + table_name, extra_options=extra_options + ) + # Writing is fine — the guard fires when a reader is built. + self._write(table, [{'id': 1, 'total': 1, 'max_score': 1, 'label': 'a'}]) + rb = table.new_read_builder() + with self.assertRaises(error_type) as cm: + rb.new_read() + msg = str(cm.exception) + if error_type is NotImplementedError: + self.assertIn('aggregation', msg) + self.assertIn(expected_substring, msg) + + def test_remove_record_on_delete_rejected(self): + self._create_and_expect_unsupported( + 'agg_reject_remove_on_delete', + {'aggregation.remove-record-on-delete': 'true'}, + 'aggregation.remove-record-on-delete', + ) + + def test_field_ignore_retract_rejected(self): + self._create_and_expect_unsupported( + 'agg_reject_ignore_retract', + {'fields.total.ignore-retract': 'true'}, + 'fields.total.ignore-retract', + ) + + def test_sequence_field_supported(self): + # Top-level sequence.field is honored by the aggregation engine: + # aggregators fold in sequence-field order, not file order. Here + # ``last_value`` must pick the value from the highest-``total`` row + # even though it was written first. + table = self._create_pk_table( + 'agg_sequence_field', + field_aggs={'max_score': 'last_value', 'label': 'last_value'}, + extra_options={'sequence.field': 'total'}, + ) + self._write(table, [{'id': 1, 'total': 100, 'max_score': 9, 'label': 'hi'}]) + self._write(table, [{'id': 1, 'total': 50, 'max_score': 1, 'label': 'lo'}]) + self.assertEqual( + self._read(table), + [{'id': 1, 'total': 100, 'max_score': 9, 'label': 'hi'}], + ) + + def test_aggregate_function_on_sequence_field_rejected(self): + # An explicit aggregator on the sequence column is invalid: Java + # rejects fields.{sequence-field}.aggregate-function in + # SchemaValidation.validateSequenceField.
Rather than silently + # override 'sum' with last_value, the guard must reject it. + self._create_and_expect_unsupported( + 'agg_reject_agg_on_seq', + {'sequence.field': 'total', + 'fields.total.aggregate-function': 'sum'}, + 'fields.total.aggregate-function', + error_type=ValueError, + ) + + def test_field_sequence_group_rejected(self): + self._create_and_expect_unsupported( + 'agg_reject_sequence_group', + {'fields.max_score.sequence-group': 'label'}, + 'fields.max_score.sequence-group', + ) + + def test_out_of_scope_field_aggregator_rejected(self): + # collect is one of the aggregator identifiers this engine + # doesn't support yet. The guard must reject the config rather + # than let the per-field factory build a (silently wrong) + # fallback. + self._create_and_expect_unsupported( + 'agg_reject_collect', + {'fields.label.aggregate-function': 'collect'}, + 'fields.label.aggregate-function', + ) + + def test_out_of_scope_default_aggregator_rejected(self): + self._create_and_expect_unsupported( + 'agg_reject_default_collect', + {'fields.default-aggregate-function': 'product'}, + 'fields.default-aggregate-function', + ) + + def test_supported_field_aggregator_passes_guard(self): + # Sanity check: setting one of the supported aggregators does + # NOT trip the guard introduced for out-of-scope identifiers. + table = self._create_pk_table( + 'agg_supported_passes', + field_aggs={'total': 'sum'}, + ) + self._write(table, [{'id': 1, 'total': 1, 'max_score': 1, 'label': 'a'}]) + # If the guard wrongly flagged 'sum', new_read() would raise. + # Touch it explicitly so the test fails loudly otherwise. 
+ table.new_read_builder().new_read() + + # -- partition column that is also part of the primary key ---------- + + def test_partition_pk_overlap_not_aggregated_by_default(self): + # When a partition column is also part of the primary key and a + # table-wide ``fields.default-aggregate-function`` is configured, + # the partition-PK column must be treated as PK (identity) and + # not run through the default aggregator. Regression for the + # split_read bug where the trimmed PK list (which drops + # partition columns) was passed to ``build_field_aggregators``. + pa_schema = pa.schema([ + pa.field('p', pa.int64(), nullable=False), + pa.field('id', pa.int64(), nullable=False), + pa.field('v', pa.int64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + primary_keys=['p', 'id'], + partition_keys=['p'], + options={ + 'bucket': '1', + 'merge-engine': 'aggregation', + 'fields.default-aggregate-function': 'sum', + }, + ) + self.catalog.create_table( + 'default.agg_partition_pk_overlap', schema, False) + table = self.catalog.get_table('default.agg_partition_pk_overlap') + + def write(rows): + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + write([{'p': 1, 'id': 1, 'v': 10}]) + write([{'p': 1, 'id': 1, 'v': 20}]) + + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + rows = rb.new_read().to_arrow(splits).to_pylist() + self.assertEqual(rows, [{'p': 1, 'id': 1, 'v': 30}]) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_aggregation_merge_function.py b/paimon-python/pypaimon/tests/test_aggregation_merge_function.py new file mode 100644 index 000000000000..7ff00fbb9b62 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_aggregation_merge_function.py @@ -0,0 +1,300 @@ 
+################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Direct unit tests for :class:`AggregateMergeFunction` and its +helper functions. + +Drives the merge function with synthetic :class:`KeyValue` instances +so the contract is pinned down without going through the full read +pipeline. The end-to-end behaviour on real PK tables is exercised +separately in ``test_aggregation_e2e.py``. +""" + +import unittest + +from pypaimon.common.options.core_options import CoreOptions +from pypaimon.common.options.options import Options +from pypaimon.read.reader.aggregate import create_field_aggregator +from pypaimon.read.reader.aggregation_merge_function import ( + AggregateMergeFunction, + build_field_aggregators, + resolve_agg_func_name, +) +from pypaimon.schema.data_types import AtomicType, DataField +from pypaimon.table.row.key_value import KeyValue +from pypaimon.table.row.row_kind import RowKind + + +def _kv(key, seq, row_kind, value): + """Build a fresh KeyValue for a (key, sequence, row_kind, value) + tuple. 
Same shape as the helper in + ``test_partial_update_merge_function.py`` so both test files read + consistently. + """ + kv = KeyValue(key_arity=len(key), value_arity=len(value)) + kv.replace(tuple(key) + (seq, row_kind.value) + tuple(value)) + return kv + + +def _result_value(kv): + return tuple(kv.value.get_field(i) for i in range(kv.value_arity)) + + +def _result_key(kv): + return tuple(kv.key.get_field(i) for i in range(kv.key_arity)) + + +def _make_agg(identifier, sql_type, field_name="f"): + return create_field_aggregator( + AtomicType(sql_type), field_name, identifier, options=None + ) + + +class AggregateMergeFunctionTest(unittest.TestCase): + + def test_single_insert_returns_aggregated_value(self): + mf = AggregateMergeFunction( + key_arity=1, value_arity=1, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + mf.reset() + mf.add(_kv((1,), 100, RowKind.INSERT, (10,))) + result = mf.get_result() + + self.assertIsNotNone(result) + self.assertEqual(_result_key(result), (1,)) + self.assertEqual(_result_value(result), (10,)) + self.assertEqual(result.sequence_number, 100) + self.assertEqual(result.value_row_kind_byte, RowKind.INSERT.value) + + def test_multi_row_aggregation_across_fields(self): + # field 0: sum, field 1: max, field 2: last_value + mf = AggregateMergeFunction( + key_arity=1, value_arity=3, + field_aggregators=[ + _make_agg("sum", "BIGINT", "v_sum"), + _make_agg("max", "INT", "v_max"), + _make_agg("last_value", "VARCHAR", "v_last"), + ], + ) + mf.reset() + mf.add(_kv((1,), 100, RowKind.INSERT, (10, 5, "a"))) + mf.add(_kv((1,), 101, RowKind.INSERT, (20, 3, "b"))) + mf.add(_kv((1,), 102, RowKind.INSERT, (30, 9, "c"))) + + result = mf.get_result() + self.assertEqual(_result_value(result), (60, 9, "c")) + # latest sequence is propagated through. + self.assertEqual(result.sequence_number, 102) + + def test_null_inputs_follow_aggregator_semantics(self): + # sum drops nulls; last_non_null_value drops nulls; last_value keeps them. 
+ mf = AggregateMergeFunction( + key_arity=1, value_arity=3, + field_aggregators=[ + _make_agg("sum", "BIGINT", "v_sum"), + _make_agg("last_non_null_value", "VARCHAR", "v_lnn"), + _make_agg("last_value", "VARCHAR", "v_last"), + ], + ) + mf.reset() + mf.add(_kv((1,), 100, RowKind.INSERT, (10, "x", "x"))) + mf.add(_kv((1,), 101, RowKind.INSERT, (None, None, None))) + mf.add(_kv((1,), 102, RowKind.INSERT, (5, None, "z"))) + + result = mf.get_result() + # sum: 10 + 5 = 15 (nulls absorbed) + # last_non_null: 'x' (intermediate nulls preserved earlier value) + # last_value: 'z' (the very last value, including the prior None) + self.assertEqual(_result_value(result), (15, "x", "z")) + + def test_update_after_treated_as_insert(self): + mf = AggregateMergeFunction( + key_arity=1, value_arity=1, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + mf.reset() + mf.add(_kv((1,), 100, RowKind.UPDATE_AFTER, (7,))) + result = mf.get_result() + self.assertEqual(_result_value(result), (7,)) + + def test_delete_row_raises_not_implemented(self): + mf = AggregateMergeFunction( + key_arity=1, value_arity=1, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + mf.reset() + with self.assertRaises(NotImplementedError) as ctx: + mf.add(_kv((1,), 100, RowKind.DELETE, (5,))) + self.assertIn("retract", str(ctx.exception)) + self.assertIn("-D", str(ctx.exception)) + + def test_update_before_row_raises_not_implemented(self): + mf = AggregateMergeFunction( + key_arity=1, value_arity=1, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + mf.reset() + with self.assertRaises(NotImplementedError) as ctx: + mf.add(_kv((1,), 100, RowKind.UPDATE_BEFORE, (5,))) + self.assertIn("-U", str(ctx.exception)) + + def test_reset_between_keys_clears_state(self): + mf = AggregateMergeFunction( + key_arity=1, value_arity=2, + field_aggregators=[ + _make_agg("sum", "BIGINT"), + _make_agg("first_value", "VARCHAR"), + ], + ) + # Key group 1. 
+ mf.reset() + mf.add(_kv((1,), 100, RowKind.INSERT, (5, "a"))) + mf.add(_kv((1,), 101, RowKind.INSERT, (3, "b"))) + r1 = mf.get_result() + self.assertEqual(_result_value(r1), (8, "a")) + # Key group 2 — sum must restart from 0 and first_value must + # re-arm so the new group's first row wins. + mf.reset() + mf.add(_kv((2,), 200, RowKind.INSERT, (10, "x"))) + mf.add(_kv((2,), 201, RowKind.INSERT, (20, "y"))) + r2 = mf.get_result() + self.assertEqual(_result_key(r2), (2,)) + self.assertEqual(_result_value(r2), (30, "x")) + + def test_get_result_before_any_add_returns_none(self): + mf = AggregateMergeFunction( + key_arity=1, value_arity=1, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + mf.reset() + self.assertIsNone(mf.get_result()) + + def test_result_is_decoupled_from_input_kv(self): + """Critical: upstream KeyValueWrapReader reuses one KeyValue and + rebinds its row_tuple between iterations. The merge function + must snapshot its output so the previously-returned result is + not silently mutated when the source advances. + """ + mf = AggregateMergeFunction( + key_arity=1, value_arity=1, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + mf.reset() + # Build one reusable kv and rebind it twice, like upstream does. + kv = KeyValue(key_arity=1, value_arity=1) + kv.replace((1, 100, RowKind.INSERT.value, 5)) + mf.add(kv) + kv.replace((1, 101, RowKind.INSERT.value, 7)) + mf.add(kv) + result = mf.get_result() + + # Now mutate the source kv. The previously-captured result must + # NOT change. 
+ kv.replace((999, 999, RowKind.INSERT.value, 99999)) + + self.assertEqual(_result_key(result), (1,)) + self.assertEqual(_result_value(result), (12,)) # 5 + 7 + self.assertEqual(result.sequence_number, 101) + + def test_value_arity_mismatch_at_construction_raises(self): + with self.assertRaises(ValueError): + AggregateMergeFunction( + key_arity=1, value_arity=2, + field_aggregators=[_make_agg("sum", "BIGINT")], + ) + + +class ResolveAggFuncNameTest(unittest.TestCase): + + def test_primary_key_takes_precedence(self): + name = resolve_agg_func_name( + "id", primary_keys={"id"}, options_map={ + "fields.id.aggregate-function": "sum", + } + ) + self.assertEqual(name, "primary_key") + + def test_field_level_override_wins_over_default(self): + name = resolve_agg_func_name( + "v", primary_keys=set(), options_map={ + "fields.v.aggregate-function": "max", + "fields.default-aggregate-function": "sum", + } + ) + self.assertEqual(name, "max") + + def test_table_default_used_when_no_field_override(self): + name = resolve_agg_func_name( + "v", primary_keys=set(), options_map={ + "fields.default-aggregate-function": "sum", + } + ) + self.assertEqual(name, "sum") + + def test_system_default_when_nothing_configured(self): + name = resolve_agg_func_name( + "v", primary_keys=set(), options_map={} + ) + self.assertEqual(name, "last_non_null_value") + + +class BuildFieldAggregatorsTest(unittest.TestCase): + + def _make_options(self, raw): + return CoreOptions(Options(raw)) + + def test_builds_aggregators_aligned_with_value_fields(self): + fields = [ + DataField(0, "id", AtomicType("BIGINT")), + DataField(1, "amount", AtomicType("BIGINT")), + DataField(2, "name", AtomicType("VARCHAR")), + ] + options = self._make_options({ + "fields.amount.aggregate-function": "sum", + }) + aggs = build_field_aggregators( + value_fields=fields, + primary_keys=["id"], + core_options=options, + ) + self.assertEqual(len(aggs), 3) + self.assertEqual(aggs[0].name, "primary_key") + 
self.assertEqual(aggs[1].name, "sum") + # Falls through to the system default since no override and no + # fields.default-aggregate-function is set. + self.assertEqual(aggs[2].name, "last_non_null_value") + + def test_unknown_aggregator_identifier_raises(self): + fields = [ + DataField(0, "id", AtomicType("BIGINT")), + DataField(1, "v", AtomicType("BIGINT")), + ] + options = self._make_options({ + "fields.v.aggregate-function": "no_such_aggregator", + }) + with self.assertRaises(ValueError): + build_field_aggregators( + value_fields=fields, + primary_keys=["id"], + core_options=options, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_field_aggregator_registry.py b/paimon-python/pypaimon/tests/test_field_aggregator_registry.py new file mode 100644 index 000000000000..2927f35b7bb8 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_field_aggregator_registry.py @@ -0,0 +1,103 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Unit tests for the FieldAggregator registry contract. 
+ +Drives :func:`register_aggregator` / :func:`create_field_aggregator` +without touching the read pipeline so the wiring is pinned down before +any concrete aggregators land in :mod:`aggregators`. +""" + +import unittest + +from pypaimon.read.reader.aggregate import ( + create_field_aggregator, + register_aggregator, +) +from pypaimon.read.reader.aggregate.field_aggregator import FieldAggregator +from pypaimon.schema.data_types import AtomicType + + +class _DummyAgg(FieldAggregator): + """Minimal concrete subclass used only by these tests.""" + + def agg(self, accumulator, input_field): + return input_field + + +class FieldAggregatorRegistryTest(unittest.TestCase): + + def test_register_and_create_returns_instance(self): + register_aggregator( + "_dummy_for_registry_test", + lambda field_type, field_name, options: _DummyAgg( + "_dummy_for_registry_test", field_type + ), + ) + agg = create_field_aggregator( + AtomicType("INT"), + "field0", + "_dummy_for_registry_test", + options=None, + ) + self.assertIsInstance(agg, _DummyAgg) + self.assertEqual(agg.name, "_dummy_for_registry_test") + self.assertEqual(agg.field_type, AtomicType("INT")) + + def test_re_register_replaces_existing_factory(self): + register_aggregator( + "_dummy_replaceable", + lambda ft, fn, opts: _DummyAgg("first", ft), + ) + register_aggregator( + "_dummy_replaceable", + lambda ft, fn, opts: _DummyAgg("second", ft), + ) + agg = create_field_aggregator( + AtomicType("BIGINT"), "f", "_dummy_replaceable", options=None + ) + self.assertEqual(agg.name, "second") + + def test_unknown_identifier_raises_value_error(self): + with self.assertRaises(ValueError) as ctx: + create_field_aggregator( + AtomicType("INT"), + "field0", + "this_aggregator_does_not_exist", + options=None, + ) + msg = str(ctx.exception) + self.assertIn("unsupported aggregation", msg) + self.assertIn("this_aggregator_does_not_exist", msg) + + def test_default_retract_raises_not_implemented(self): + agg = _DummyAgg("dummy", 
AtomicType("INT")) + with self.assertRaises(NotImplementedError) as ctx: + agg.retract(1, 2) + self.assertIn("does not support retract", str(ctx.exception)) + self.assertIn("dummy", str(ctx.exception)) + + def test_default_reset_is_noop(self): + # Base-class reset() must not raise so subclasses without + # per-group state can skip overriding it. + agg = _DummyAgg("dummy", AtomicType("INT")) + agg.reset() # no exception expected + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_field_aggregators.py b/paimon-python/pypaimon/tests/test_field_aggregators.py new file mode 100644 index 000000000000..f54a67cd95b2 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_field_aggregators.py @@ -0,0 +1,274 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Unit tests for the built-in :class:`FieldAggregator` subclasses. + +Drives each aggregator directly to pin down the value semantics +(reset behaviour, null handling, type validation) without going +through the merge function or the read pipeline. 
End-to-end coverage +on real PK tables lives in ``test_aggregation_e2e.py``. +""" + +import datetime +import unittest +from decimal import Decimal + +from pypaimon.read.reader.aggregate import create_field_aggregator +from pypaimon.read.reader.aggregate.aggregators import ( + FieldBoolAndAgg, + FieldBoolOrAgg, + FieldFirstNonNullValueAgg, + FieldFirstValueAgg, + FieldLastNonNullValueAgg, + FieldLastValueAgg, + FieldMaxAgg, + FieldMinAgg, + FieldPrimaryKeyAgg, + FieldSumAgg, +) +from pypaimon.schema.data_types import AtomicType + + +def _make(identifier, sql_type): + """Build an aggregator through the public registry path so we also + exercise the registered factory (including its type validation). + """ + return create_field_aggregator( + AtomicType(sql_type), "field0", identifier, options=None + ) + + +class FieldPrimaryKeyAggTest(unittest.TestCase): + + def test_returns_input_field(self): + agg = _make("primary_key", "BIGINT") + self.assertIsInstance(agg, FieldPrimaryKeyAgg) + self.assertEqual(agg.agg(None, 5), 5) + self.assertEqual(agg.agg(99, 5), 5) + self.assertIsNone(agg.agg(5, None)) + + +class FieldLastValueAggTest(unittest.TestCase): + + def test_last_value_wins_including_null(self): + agg = _make("last_value", "VARCHAR") + self.assertIsInstance(agg, FieldLastValueAgg) + self.assertEqual(agg.agg(None, "a"), "a") + self.assertEqual(agg.agg("a", "b"), "b") + # Crucially: a later null replaces the accumulator (unlike + # last_non_null_value). 
+ self.assertIsNone(agg.agg("a", None)) + + +class FieldLastNonNullValueAggTest(unittest.TestCase): + + def test_null_inputs_are_absorbed(self): + agg = _make("last_non_null_value", "INT") + self.assertIsInstance(agg, FieldLastNonNullValueAgg) + self.assertEqual(agg.agg(None, 1), 1) + self.assertEqual(agg.agg(1, 2), 2) + self.assertEqual(agg.agg(2, None), 2) + self.assertIsNone(agg.agg(None, None)) + + +class FieldFirstValueAggTest(unittest.TestCase): + + def test_first_value_locks_after_first_add(self): + agg = _make("first_value", "VARCHAR") + self.assertIsInstance(agg, FieldFirstValueAgg) + # First add returns input, even if input is None. + self.assertIsNone(agg.agg(None, None)) + # Subsequent adds preserve the accumulator (None) regardless of input. + self.assertIsNone(agg.agg(None, "later")) + + def test_reset_re_arms_first_value(self): + agg = _make("first_value", "INT") + self.assertEqual(agg.agg(None, 5), 5) + self.assertEqual(agg.agg(5, 9), 5) # locked + agg.reset() + # After reset the next add is treated as the first again. + self.assertEqual(agg.agg(None, 42), 42) + + +class FieldFirstNonNullValueAggTest(unittest.TestCase): + + def test_first_non_null_skips_nulls(self): + agg = _make("first_non_null_value", "INT") + self.assertIsInstance(agg, FieldFirstNonNullValueAgg) + # Initial null does not lock — accumulator stays None. + self.assertIsNone(agg.agg(None, None)) + # First non-null locks. + self.assertEqual(agg.agg(None, 7), 7) + # Subsequent values do not replace the locked first. 
+ self.assertEqual(agg.agg(7, 99), 7) + self.assertEqual(agg.agg(7, None), 7) + + def test_reset_re_arms_first_non_null(self): + agg = _make("first_non_null_value", "INT") + self.assertEqual(agg.agg(None, 1), 1) + self.assertEqual(agg.agg(1, 2), 1) + agg.reset() + self.assertEqual(agg.agg(None, 9), 9) + + +class FieldSumAggTest(unittest.TestCase): + + def test_int_sum(self): + agg = _make("sum", "BIGINT") + self.assertIsInstance(agg, FieldSumAgg) + self.assertEqual(agg.agg(None, 5), 5) + self.assertEqual(agg.agg(5, 7), 12) + + def test_float_sum(self): + agg = _make("sum", "DOUBLE") + self.assertAlmostEqual(agg.agg(1.5, 2.25), 3.75) + + def test_decimal_sum(self): + agg = _make("sum", "DECIMAL(10,2)") + result = agg.agg(Decimal("1.23"), Decimal("4.56")) + self.assertEqual(result, Decimal("5.79")) + + def test_null_inputs_return_non_null_operand(self): + agg = _make("sum", "INT") + self.assertEqual(agg.agg(None, 5), 5) + self.assertEqual(agg.agg(5, None), 5) + self.assertIsNone(agg.agg(None, None)) + + def test_non_numeric_type_rejected_at_construction(self): + with self.assertRaises(ValueError) as ctx: + _make("sum", "VARCHAR") + self.assertIn("numeric", str(ctx.exception)) + + +class FieldMaxAggTest(unittest.TestCase): + + def test_numeric_max(self): + agg = _make("max", "INT") + self.assertIsInstance(agg, FieldMaxAgg) + self.assertEqual(agg.agg(3, 7), 7) + self.assertEqual(agg.agg(7, 3), 7) + self.assertEqual(agg.agg(5, 5), 5) + + def test_string_max(self): + agg = _make("max", "VARCHAR") + self.assertEqual(agg.agg("apple", "banana"), "banana") + self.assertEqual(agg.agg("banana", "apple"), "banana") + + def test_date_max(self): + agg = _make("max", "DATE") + d1 = datetime.date(2020, 1, 1) + d2 = datetime.date(2025, 6, 15) + self.assertEqual(agg.agg(d1, d2), d2) + self.assertEqual(agg.agg(d2, d1), d2) + + def test_null_inputs_return_non_null_operand(self): + agg = _make("max", "INT") + self.assertEqual(agg.agg(None, 5), 5) + self.assertEqual(agg.agg(5, None), 5) 
+ self.assertIsNone(agg.agg(None, None)) + + +class FieldMinAggTest(unittest.TestCase): + + def test_numeric_min(self): + agg = _make("min", "INT") + self.assertIsInstance(agg, FieldMinAgg) + self.assertEqual(agg.agg(3, 7), 3) + self.assertEqual(agg.agg(7, 3), 3) + self.assertEqual(agg.agg(5, 5), 5) + + def test_string_min(self): + agg = _make("min", "VARCHAR") + self.assertEqual(agg.agg("apple", "banana"), "apple") + + def test_null_inputs_return_non_null_operand(self): + agg = _make("min", "INT") + self.assertEqual(agg.agg(None, 5), 5) + self.assertEqual(agg.agg(5, None), 5) + self.assertIsNone(agg.agg(None, None)) + + +class FieldBoolOrAggTest(unittest.TestCase): + + def test_truth_table(self): + agg = _make("bool_or", "BOOLEAN") + self.assertIsInstance(agg, FieldBoolOrAgg) + self.assertTrue(agg.agg(True, True)) + self.assertTrue(agg.agg(True, False)) + self.assertTrue(agg.agg(False, True)) + self.assertFalse(agg.agg(False, False)) + + def test_null_inputs_return_non_null_operand(self): + agg = _make("bool_or", "BOOLEAN") + self.assertTrue(agg.agg(None, True)) + self.assertFalse(agg.agg(False, None)) + self.assertIsNone(agg.agg(None, None)) + + def test_non_boolean_type_rejected_at_construction(self): + with self.assertRaises(ValueError) as ctx: + _make("bool_or", "INT") + self.assertIn("BOOLEAN", str(ctx.exception)) + + +class FieldBoolAndAggTest(unittest.TestCase): + + def test_truth_table(self): + agg = _make("bool_and", "BOOLEAN") + self.assertIsInstance(agg, FieldBoolAndAgg) + self.assertTrue(agg.agg(True, True)) + self.assertFalse(agg.agg(True, False)) + self.assertFalse(agg.agg(False, True)) + self.assertFalse(agg.agg(False, False)) + + def test_null_inputs_return_non_null_operand(self): + agg = _make("bool_and", "BOOLEAN") + self.assertTrue(agg.agg(None, True)) + self.assertFalse(agg.agg(False, None)) + self.assertIsNone(agg.agg(None, None)) + + def test_non_boolean_type_rejected_at_construction(self): + with self.assertRaises(ValueError) as ctx: + 
_make("bool_and", "VARCHAR") + self.assertIn("BOOLEAN", str(ctx.exception)) + + +class RegistrationTest(unittest.TestCase): + """Sanity check that all 10 expected aggregators (the primary-key + placeholder plus 9 value aggregators) are registered when the + package is imported. Guards against future refactors silently + dropping a registration. + """ + + EXPECTED = frozenset([ + "primary_key", + "last_value", "last_non_null_value", + "first_value", "first_non_null_value", + "sum", "max", "min", + "bool_or", "bool_and", + ]) + + def test_all_expected_aggregators_registered(self): + from pypaimon.read.reader.aggregate import _FACTORIES + registered = set(_FACTORIES.keys()) + missing = self.EXPECTED - registered + self.assertEqual(missing, set(), + "Missing built-in aggregators: {}".format(missing)) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_first_row_e2e.py b/paimon-python/pypaimon/tests/test_first_row_e2e.py new file mode 100644 index 000000000000..58776848b871 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_first_row_e2e.py @@ -0,0 +1,200 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +"""End-to-end tests for the ``first-row`` merge engine. + +Each test creates a PK table with ``merge-engine`` set to ``first-row``, +writes one or more batches, and reads back. The first-row engine keeps +only the earliest row per primary key. +""" + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema + + +class FirstRowMergeEngineE2ETest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('a', pa.string()), + ('b', pa.string()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_pk_table(self, table_name, extra_options=None): + options = { + 'bucket': '1', + 'merge-engine': 'first-row', + } + if extra_options: + options.update(extra_options) + schema = Schema.from_pyarrow_schema( + self.pa_schema, + primary_keys=['id'], + options=options, + ) + full = 'default.{}'.format(table_name) + self.catalog.create_table(full, schema, False) + return self.catalog.get_table(full) + + def _write(self, table, rows): + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=self.pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + def _read(self, table): + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + if not splits: + return [] + return sorted( + rb.new_read().to_arrow(splits).to_pylist(), + key=lambda r: r['id'], + ) + + def test_first_row_keeps_earliest(self): + """Two writes with the same PK — first-row keeps the first one.""" + 
table = self._create_pk_table('first_row_basic') + self._write(table, [{'id': 1, 'a': 'first', 'b': 'B1'}]) + self._write(table, [{'id': 1, 'a': 'second', 'b': 'B2'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'first', 'b': 'B1'}], + ) + + def test_first_row_multiple_keys(self): + """Multiple PKs across two writes — each key keeps its first row.""" + table = self._create_pk_table('first_row_multi_key') + self._write(table, [ + {'id': 1, 'a': 'A1', 'b': 'B1'}, + {'id': 2, 'a': 'A2', 'b': 'B2'}, + ]) + self._write(table, [ + {'id': 1, 'a': 'A1-new', 'b': 'B1-new'}, + {'id': 3, 'a': 'A3', 'b': 'B3'}, + ]) + + self.assertEqual( + self._read(table), + [ + {'id': 1, 'a': 'A1', 'b': 'B1'}, + {'id': 2, 'a': 'A2', 'b': 'B2'}, + {'id': 3, 'a': 'A3', 'b': 'B3'}, + ], + ) + + def test_first_row_three_writes(self): + """Three writes for the same PK — always the first one wins.""" + table = self._create_pk_table('first_row_three') + self._write(table, [{'id': 1, 'a': 'first', 'b': None}]) + self._write(table, [{'id': 1, 'a': 'second', 'b': 'B'}]) + self._write(table, [{'id': 1, 'a': 'third', 'b': 'C'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'first', 'b': None}], + ) + + def test_first_row_single_write(self): + """A single write should read back unchanged.""" + table = self._create_pk_table('first_row_single') + self._write(table, [ + {'id': 1, 'a': 'A', 'b': 'B'}, + {'id': 2, 'a': 'C', 'b': 'D'}, + ]) + + self.assertEqual( + self._read(table), + [ + {'id': 1, 'a': 'A', 'b': 'B'}, + {'id': 2, 'a': 'C', 'b': 'D'}, + ], + ) + + def test_first_row_intra_batch_duplicate(self): + """A single write whose batch already contains duplicate PKs. + + The whole batch is folded in one flush, so this exercises the + write-side fold rather than the cross-commit read merge. first-row + must keep the first occurrence of each PK. 
+ """ + table = self._create_pk_table('first_row_intra_batch') + self._write(table, [ + {'id': 1, 'a': 'first', 'b': 'B1'}, + {'id': 1, 'a': 'second', 'b': 'B2'}, + {'id': 1, 'a': 'third', 'b': 'B3'}, + {'id': 2, 'a': 'only', 'b': 'B'}, + ]) + + self.assertEqual( + self._read(table), + [ + {'id': 1, 'a': 'first', 'b': 'B1'}, + {'id': 2, 'a': 'only', 'b': 'B'}, + ], + ) + + def test_first_row_multiple_writes_one_commit(self): + """Several write_arrow calls committed once: the same PK across + those writes folds in a single flush. first-row keeps the first. + """ + table = self._create_pk_table('first_row_multi_write_one_commit') + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist( + [{'id': 1, 'a': 'first', 'b': 'B1'}], schema=self.pa_schema)) + w.write_arrow(pa.Table.from_pylist( + [{'id': 1, 'a': 'second', 'b': 'B2'}], schema=self.pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'first', 'b': 'B1'}], + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_first_row_merge_function.py b/paimon-python/pypaimon/tests/test_first_row_merge_function.py new file mode 100644 index 000000000000..5e565619846f --- /dev/null +++ b/paimon-python/pypaimon/tests/test_first_row_merge_function.py @@ -0,0 +1,146 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Direct unit tests for ``FirstRowMergeFunction``. + +Drives the merge function with synthetic ``KeyValue`` instances so the +contract is pinned down without going through the full read pipeline. +""" + +import unittest + +from pypaimon.read.reader.first_row_merge_function import \ + FirstRowMergeFunction +from pypaimon.table.row.key_value import KeyValue +from pypaimon.table.row.row_kind import RowKind + + +def _kv(key, seq, row_kind, value): + kv = KeyValue(key_arity=len(key), value_arity=len(value)) + kv.replace(tuple(key) + (seq, row_kind.value) + tuple(value)) + return kv + + +def _result_value(kv): + return tuple(kv.value.get_field(i) for i in range(kv.value_arity)) + + +def _result_key(kv): + return tuple(kv.key.get_field(i) for i in range(kv.key_arity)) + + +class FirstRowMergeFunctionTest(unittest.TestCase): + + def test_single_insert_returns_value(self): + mf = FirstRowMergeFunction() + mf.reset() + mf.add(_kv((1,), 1, RowKind.INSERT, (10, "a"))) + result = mf.get_result() + self.assertIsNotNone(result) + self.assertEqual(_result_key(result), (1,)) + self.assertEqual(_result_value(result), (10, "a")) + + def test_keeps_first_row_not_latest(self): + mf = FirstRowMergeFunction() + mf.reset() + mf.add(_kv((1,), 1, RowKind.INSERT, (10, "first"))) + mf.add(_kv((1,), 2, RowKind.INSERT, (20, "second"))) + mf.add(_kv((1,), 3, RowKind.UPDATE_AFTER, (30, "third"))) + result = mf.get_result() + self.assertEqual(_result_value(result), (10, "first")) + + def 
test_keeps_first_row_when_kv_is_pooled(self): + # The writer's fold (KeyValueDataWriter._merge_pending_by_pk) reuses + # a single KeyValue and replace()s it per row. add() must snapshot + # the first row; otherwise get_result tracks the pooled kv's last + # replace() and returns the LAST row -- silently turning first-row + # into last-row. This is the case the per-row _kv() tests miss. + mf = FirstRowMergeFunction() + mf.reset() + pooled = KeyValue(key_arity=1, value_arity=2) + pooled.replace((1, 1, RowKind.INSERT.value, 10, "first")) + mf.add(pooled) + pooled.replace((1, 2, RowKind.INSERT.value, 20, "second")) + mf.add(pooled) + result = mf.get_result() + self.assertEqual(_result_value(result), (10, "first")) + + def test_reset_clears_state(self): + mf = FirstRowMergeFunction() + mf.reset() + mf.add(_kv((1,), 1, RowKind.INSERT, (10,))) + self.assertIsNotNone(mf.get_result()) + + mf.reset() + self.assertIsNone(mf.get_result()) + + mf.add(_kv((2,), 2, RowKind.INSERT, (20,))) + result = mf.get_result() + self.assertEqual(_result_key(result), (2,)) + self.assertEqual(_result_value(result), (20,)) + + def test_empty_returns_none(self): + mf = FirstRowMergeFunction() + mf.reset() + self.assertIsNone(mf.get_result()) + + def test_delete_raises_by_default(self): + mf = FirstRowMergeFunction(ignore_delete=False) + mf.reset() + with self.assertRaises(ValueError): + mf.add(_kv((1,), 1, RowKind.DELETE, (10,))) + + def test_update_before_raises_by_default(self): + mf = FirstRowMergeFunction(ignore_delete=False) + mf.reset() + with self.assertRaises(ValueError): + mf.add(_kv((1,), 1, RowKind.UPDATE_BEFORE, (10,))) + + def test_ignore_delete_skips_retract(self): + mf = FirstRowMergeFunction(ignore_delete=True) + mf.reset() + mf.add(_kv((1,), 1, RowKind.DELETE, (10,))) + mf.add(_kv((1,), 2, RowKind.INSERT, (20,))) + result = mf.get_result() + self.assertIsNotNone(result) + self.assertEqual(_result_value(result), (20,)) + + def test_ignore_delete_skips_update_before(self): + mf = 
FirstRowMergeFunction(ignore_delete=True) + mf.reset() + mf.add(_kv((1,), 1, RowKind.UPDATE_BEFORE, (10,))) + self.assertIsNone(mf.get_result()) + + def test_ignore_delete_only_retract_returns_none(self): + mf = FirstRowMergeFunction(ignore_delete=True) + mf.reset() + mf.add(_kv((1,), 1, RowKind.DELETE, (10,))) + mf.add(_kv((1,), 2, RowKind.UPDATE_BEFORE, (20,))) + self.assertIsNone(mf.get_result()) + + def test_update_after_accepted_as_first(self): + mf = FirstRowMergeFunction() + mf.reset() + mf.add(_kv((1,), 1, RowKind.UPDATE_AFTER, (10,))) + result = mf.get_result() + self.assertIsNotNone(result) + self.assertEqual(_result_value(result), (10,)) + + +if __name__ == "__main__": + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_format_mosaic_reader_writer.py b/paimon-python/pypaimon/tests/test_format_mosaic_reader_writer.py new file mode 100644 index 000000000000..154d56cd3c8d --- /dev/null +++ b/paimon-python/pypaimon/tests/test_format_mosaic_reader_writer.py @@ -0,0 +1,304 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import tempfile + +import pyarrow as pa +import pytest + +import mosaic +from pypaimon.read.reader.format_mosaic_reader import FormatMosaicReader +from pypaimon.schema.data_types import AtomicType, DataField + + +class SimpleFileIO: + """Minimal FileIO for testing.""" + + def get_file_size(self, path): + return os.path.getsize(path) + + def new_input_stream(self, path): + return open(path, 'rb') + + +def _write_mosaic_file(path, data: pa.Table): + with open(path, 'wb') as f: + mosaic.write_table(data, f) + + +def _read_mosaic_file(path, read_fields, push_down_predicate=None): + file_io = SimpleFileIO() + reader = FormatMosaicReader(file_io, path, read_fields, + push_down_predicate, batch_size=1024) + batches = [] + while True: + batch = reader.read_arrow_batch() + if batch is None: + break + batches.append(batch) + reader.close() + if not batches: + return pa.table({f.name: pa.array([], type=pa.int32()) for f in read_fields}) + return pa.Table.from_batches(batches) + + +class TestFormatMosaicReaderWriter: + + def test_basic_int_string(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + ] + data = pa.table({ + "id": pa.array([1, 2, 3], type=pa.int32()), + "name": pa.array(["alice", "bob", "charlie"], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + result = _read_mosaic_file(path, fields) + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("name").to_pylist() == ["alice", "bob", "charlie"] + finally: + os.unlink(path) + + def test_all_primitive_types(self): + fields = [ + DataField(0, "bool_col", AtomicType("BOOLEAN")), + DataField(1, "tinyint_col", AtomicType("TINYINT")), + DataField(2, "smallint_col", AtomicType("SMALLINT")), + DataField(3, "int_col", AtomicType("INT")), + DataField(4, "bigint_col", AtomicType("BIGINT")), + DataField(5, "float_col", 
AtomicType("FLOAT")), + DataField(6, "double_col", AtomicType("DOUBLE")), + DataField(7, "string_col", AtomicType("STRING")), + DataField(8, "binary_col", AtomicType("BYTES")), + ] + data = pa.table({ + "bool_col": pa.array([True, False], type=pa.bool_()), + "tinyint_col": pa.array([1, -1], type=pa.int8()), + "smallint_col": pa.array([100, -100], type=pa.int16()), + "int_col": pa.array([1000, -1000], type=pa.int32()), + "bigint_col": pa.array([100000, -100000], type=pa.int64()), + "float_col": pa.array([1.5, -2.5], type=pa.float32()), + "double_col": pa.array([3.14, -2.71], type=pa.float64()), + "string_col": pa.array(["hello", "world"], type=pa.string()), + "binary_col": pa.array([b"\x01\x02", b"\x03\x04"], type=pa.binary()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + result = _read_mosaic_file(path, fields) + assert result.column("bool_col").to_pylist() == [True, False] + assert result.column("tinyint_col").to_pylist() == [1, -1] + assert result.column("smallint_col").to_pylist() == [100, -100] + assert result.column("int_col").to_pylist() == [1000, -1000] + assert result.column("bigint_col").to_pylist() == [100000, -100000] + assert result.column("float_col").to_pylist()[0] == pytest.approx(1.5) + assert result.column("double_col").to_pylist() == [pytest.approx(3.14), pytest.approx(-2.71)] + assert result.column("string_col").to_pylist() == ["hello", "world"] + assert result.column("binary_col").to_pylist() == [b"\x01\x02", b"\x03\x04"] + finally: + os.unlink(path) + + def test_nulls(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + ] + data = pa.table({ + "id": pa.array([1, None, 3], type=pa.int32()), + "name": pa.array([None, "bob", None], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + 
result = _read_mosaic_file(path, fields) + assert result.column("id").to_pylist() == [1, None, 3] + assert result.column("name").to_pylist() == [None, "bob", None] + finally: + os.unlink(path) + + def test_decimal(self): + from decimal import Decimal + + fields = [ + DataField(0, "d1", AtomicType("DECIMAL(10, 2)")), + ] + data = pa.table({ + "d1": pa.array([Decimal("123.45"), Decimal("-67.89")], type=pa.decimal128(10, 2)), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + result = _read_mosaic_file(path, fields) + assert result.column("d1").to_pylist() == [Decimal("123.45"), Decimal("-67.89")] + finally: + os.unlink(path) + + def test_timestamp(self): + fields = [ + DataField(0, "ts_millis", AtomicType("TIMESTAMP(3)")), + ] + data = pa.table({ + "ts_millis": pa.array([1000, 2000], type=pa.timestamp('ms')), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + result = _read_mosaic_file(path, fields) + assert result.num_rows == 2 + finally: + os.unlink(path) + + def test_column_projection(self): + data = pa.table({ + "id": pa.array([1, 2, 3], type=pa.int32()), + "name": pa.array(["a", "b", "c"], type=pa.string()), + "value": pa.array([1.1, 2.2, 3.3], type=pa.float64()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + projected_fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(2, "value", AtomicType("DOUBLE")), + ] + result = _read_mosaic_file(path, projected_fields) + assert result.num_columns == 2 + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("value").to_pylist() == [ + pytest.approx(1.1), pytest.approx(2.2), pytest.approx(3.3)] + finally: + os.unlink(path) + + def test_schema_evolution_missing_field(self): + """Reading a file that doesn't have a 
column added later (schema evolution).""" + data = pa.table({ + "id": pa.array([1, 2], type=pa.int32()), + "name": pa.array(["a", "b"], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + fields_read = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + DataField(2, "score", AtomicType("DOUBLE")), + ] + result = _read_mosaic_file(path, fields_read) + assert result.column("id").to_pylist() == [1, 2] + assert result.column("name").to_pylist() == ["a", "b"] + assert result.column("score").to_pylist() == [None, None] + finally: + os.unlink(path) + + def test_predicate_pushdown(self): + import pyarrow.compute as pc + + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + ] + data = pa.table({ + "id": pa.array(list(range(100)), type=pa.int32()), + "name": pa.array([f"user_{i}" for i in range(100)], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + predicate = pc.field("id") > 95 + result = _read_mosaic_file(path, fields, push_down_predicate=predicate) + assert result.num_rows == 4 + assert all(v > 95 for v in result.column("id").to_pylist()) + finally: + os.unlink(path) + + def test_large_dataset(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "data", AtomicType("STRING")), + ] + num_rows = 10000 + data = pa.table({ + "id": pa.array(list(range(num_rows)), type=pa.int32()), + "data": pa.array([f"value_{i}" for i in range(num_rows)], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + _write_mosaic_file(path, data) + result = _read_mosaic_file(path, fields) + assert result.num_rows == num_rows + assert result.column("id").to_pylist() == list(range(num_rows)) + 
finally: + os.unlink(path) + + def test_write_mosaic_local_file_io(self): + """Test write_mosaic via LocalFileIO.""" + from pypaimon.filesystem.local_file_io import LocalFileIO + + data = pa.table({ + "id": pa.array([1, 2, 3], type=pa.int32()), + "name": pa.array(["a", "b", "c"], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".mosaic", delete=False) as tmp: + path = tmp.name + + try: + file_io = LocalFileIO({}) + file_io.write_mosaic(path, data) + + assert os.path.getsize(path) > 0 + + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + ] + result = _read_mosaic_file(path, fields) + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("name").to_pylist() == ["a", "b", "c"] + finally: + os.unlink(path) diff --git a/paimon-python/pypaimon/tests/test_format_mosaic_table.py b/paimon-python/pypaimon/tests/test_format_mosaic_table.py new file mode 100644 index 000000000000..82a769cdf55e --- /dev/null +++ b/paimon-python/pypaimon/tests/test_format_mosaic_table.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Integration tests for the Mosaic file format across table types.""" + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema + + +class MosaicFormatAppendOnlyTest(unittest.TestCase): + """Test Mosaic format with append-only (non-primary-key) tables.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('user_id', pa.int32()), + ('item_id', pa.int64()), + ('behavior', pa.string()), + ('dt', pa.string()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _write_and_read(self, table, data): + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + return read_builder.new_read().to_arrow(splits) + + def test_append_only_no_partition(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, options={'file.format': 'mosaic'}) + self.catalog.create_table('default.ao_mosaic_no_part', schema, False) + table = self.catalog.get_table('default.ao_mosaic_no_part') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['buy', 'click', 'view'], + 'dt': ['2024-01-01', '2024-01-01', '2024-01-02'], + }, schema=self.pa_schema) + + result = self._write_and_read(table, data) + self.assertEqual(result.num_rows, 3) + self.assertEqual(sorted(result.column('user_id').to_pylist()), [1, 2, 3]) + + def test_append_only_with_partition(self): + schema = 
Schema.from_pyarrow_schema( + self.pa_schema, partition_keys=['dt'], + options={'file.format': 'mosaic'}) + self.catalog.create_table('default.ao_mosaic_partitioned', schema, False) + table = self.catalog.get_table('default.ao_mosaic_partitioned') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3, 4], + 'item_id': [1001, 1002, 1003, 1004], + 'behavior': ['buy', 'click', 'view', 'buy'], + 'dt': ['p1', 'p1', 'p2', 'p2'], + }, schema=self.pa_schema) + + result = self._write_and_read(table, data) + self.assertEqual(result.num_rows, 4) + + def test_append_only_multiple_commits(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, options={'file.format': 'mosaic'}) + self.catalog.create_table('default.ao_mosaic_multi', schema, False) + table = self.catalog.get_table('default.ao_mosaic_multi') + + write_builder = table.new_batch_write_builder() + + for i in range(3): + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = pa.Table.from_pydict({ + 'user_id': [i * 10 + 1, i * 10 + 2], + 'item_id': [int(i * 100 + 1), int(i * 100 + 2)], + 'behavior': ['buy', 'click'], + 'dt': ['2024-01-01', '2024-01-01'], + }, schema=self.pa_schema) + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + table = self.catalog.get_table('default.ao_mosaic_multi') + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 6) + + +class MosaicFormatPrimaryKeyTest(unittest.TestCase): + """Test Mosaic format with primary-key tables.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('pk', pa.int32()), + ('value', 
pa.string()), + ('amount', pa.float64()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def test_primary_key_deduplicate(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, primary_keys=['pk'], + options={'file.format': 'mosaic'}) + self.catalog.create_table('default.pk_mosaic_dedup', schema, False) + table = self.catalog.get_table('default.pk_mosaic_dedup') + + write_builder = table.new_batch_write_builder() + + # First write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = pa.Table.from_pydict({ + 'pk': [1, 2, 3], + 'value': ['a', 'b', 'c'], + 'amount': [1.0, 2.0, 3.0], + }, schema=self.pa_schema) + table_write.write_arrow(data1) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Second write (update pk=2) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = pa.Table.from_pydict({ + 'pk': [2], + 'value': ['updated'], + 'amount': [22.0], + }, schema=self.pa_schema) + table_write.write_arrow(data2) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + table = self.catalog.get_table('default.pk_mosaic_dedup') + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + + self.assertEqual(result.num_rows, 3) + result_sorted = result.sort_by('pk') + self.assertEqual(result_sorted.column('value').to_pylist(), ['a', 'updated', 'c']) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_format_row_reader_writer.py b/paimon-python/pypaimon/tests/test_format_row_reader_writer.py new file mode 100644 index 000000000000..cff237b5be0d --- /dev/null +++ b/paimon-python/pypaimon/tests/test_format_row_reader_writer.py @@ -0,0 +1,534 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# 
or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import tempfile +from decimal import Decimal + +import pyarrow as pa +import pytest + +from pypaimon.read.reader.format_row_reader import FormatRowReader +from pypaimon.schema.data_types import ( + ArrayType, AtomicType, DataField, MapType, RowType +) +from pypaimon.write.writer.format_row_writer import FormatRowWriter + + +class SimpleFileIO: + """Minimal FileIO for testing.""" + + def get_file_size(self, path): + return os.path.getsize(path) + + def new_input_stream(self, path): + return open(path, 'rb') + + +def _write_row_file(path, fields, data_table): + with open(path, 'wb') as f: + writer = FormatRowWriter(f, fields) + writer.write_table(data_table) + writer.close() + + +def _read_row_file(path, fields, read_field_names=None, row_indices=None): + file_io = SimpleFileIO() + if read_field_names is None: + read_field_names = [f.name for f in fields] + reader = FormatRowReader(file_io, path, read_field_names, fields, None, + row_indices=row_indices) + batches = [] + while True: + batch = reader.read_arrow_batch() + if batch is None: + break + batches.append(batch) + reader.close() + if not batches: + return pa.table({f.name: [] for f in fields}) + return pa.Table.from_batches(batches) + + +class TestFormatRowReaderWriter: + + def 
test_basic_int_string(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + ] + data = pa.table({ + "id": pa.array([1, 2, 3], type=pa.int32()), + "name": pa.array(["alice", "bob", "charlie"], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("name").to_pylist() == ["alice", "bob", "charlie"] + finally: + os.unlink(path) + + def test_all_primitive_types(self): + fields = [ + DataField(0, "bool_col", AtomicType("BOOLEAN")), + DataField(1, "tinyint_col", AtomicType("TINYINT")), + DataField(2, "smallint_col", AtomicType("SMALLINT")), + DataField(3, "int_col", AtomicType("INT")), + DataField(4, "bigint_col", AtomicType("BIGINT")), + DataField(5, "float_col", AtomicType("FLOAT")), + DataField(6, "double_col", AtomicType("DOUBLE")), + DataField(7, "string_col", AtomicType("STRING")), + DataField(8, "binary_col", AtomicType("BYTES")), + ] + data = pa.table({ + "bool_col": pa.array([True, False], type=pa.bool_()), + "tinyint_col": pa.array([1, -1], type=pa.int8()), + "smallint_col": pa.array([100, -100], type=pa.int16()), + "int_col": pa.array([1000, -1000], type=pa.int32()), + "bigint_col": pa.array([100000, -100000], type=pa.int64()), + "float_col": pa.array([1.5, -2.5], type=pa.float32()), + "double_col": pa.array([3.14, -2.71], type=pa.float64()), + "string_col": pa.array(["hello", "world"], type=pa.string()), + "binary_col": pa.array([b"\x01\x02", b"\x03\x04"], type=pa.binary()), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.column("bool_col").to_pylist() == [True, False] + assert result.column("tinyint_col").to_pylist() == [1, -1] + 
assert result.column("smallint_col").to_pylist() == [100, -100] + assert result.column("int_col").to_pylist() == [1000, -1000] + assert result.column("bigint_col").to_pylist() == [100000, -100000] + assert result.column("float_col").to_pylist()[0] == pytest.approx(1.5) + assert result.column("float_col").to_pylist()[1] == pytest.approx(-2.5) + assert result.column("double_col").to_pylist() == [pytest.approx(3.14), pytest.approx(-2.71)] + assert result.column("string_col").to_pylist() == ["hello", "world"] + assert result.column("binary_col").to_pylist() == [b"\x01\x02", b"\x03\x04"] + finally: + os.unlink(path) + + def test_nulls(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + ] + data = pa.table({ + "id": pa.array([1, None, 3], type=pa.int32()), + "name": pa.array([None, "bob", None], type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.column("id").to_pylist() == [1, None, 3] + assert result.column("name").to_pylist() == [None, "bob", None] + finally: + os.unlink(path) + + def test_decimal(self): + fields = [ + DataField(0, "d1", AtomicType("DECIMAL(10, 2)")), + DataField(1, "d2", AtomicType("DECIMAL(20, 5)")), + ] + data = pa.table({ + "d1": pa.array([Decimal("123.45"), Decimal("-67.89")], type=pa.decimal128(10, 2)), + "d2": pa.array([Decimal("12345.67890"), Decimal("-99999.12345")], type=pa.decimal128(20, 5)), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.column("d1").to_pylist() == [Decimal("123.45"), Decimal("-67.89")] + assert result.column("d2").to_pylist() == [Decimal("12345.67890"), Decimal("-99999.12345")] + finally: + os.unlink(path) + + def test_timestamp(self): + fields = [ + 
DataField(0, "ts_millis", AtomicType("TIMESTAMP(3)")), + DataField(1, "ts_micros", AtomicType("TIMESTAMP(6)")), + ] + data = pa.table({ + "ts_millis": pa.array([1000, 2000], type=pa.timestamp('ms')), + "ts_micros": pa.array([1000000, 2000000], type=pa.timestamp('us')), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.num_rows == 2 + finally: + os.unlink(path) + + def test_array_type(self): + element_type = AtomicType("INT") + fields = [ + DataField(0, "arr", ArrayType(True, element_type)), + ] + data = pa.table({ + "arr": pa.array([[1, 2, 3], [4, 5]], type=pa.list_(pa.int32())), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.column("arr").to_pylist() == [[1, 2, 3], [4, 5]] + finally: + os.unlink(path) + + def test_map_type(self): + fields = [ + DataField(0, "m", MapType(True, AtomicType("STRING"), AtomicType("INT"))), + ] + data = pa.table({ + "m": pa.array( + [[("a", 1), ("b", 2)], [("c", 3)]], + type=pa.map_(pa.string(), pa.int32()) + ), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + result_maps = result.column("m").to_pylist() + assert len(result_maps) == 2 + assert len(result_maps[0]) == 2 + assert len(result_maps[1]) == 1 + finally: + os.unlink(path) + + def test_nested_row(self): + inner_type = RowType(True, [ + DataField(0, "x", AtomicType("INT")), + DataField(1, "y", AtomicType("STRING")), + ]) + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "nested", inner_type), + ] + data = pa.table({ + "id": pa.array([1, 2], type=pa.int32()), + "nested": pa.array( + [{"x": 10, "y": "a"}, {"x": 20, "y": "b"}], + 
type=pa.struct([ + pa.field("x", pa.int32()), + pa.field("y", pa.string()), + ]) + ), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.column("id").to_pylist() == [1, 2] + nested = result.column("nested").to_pylist() + assert nested[0] == {"x": 10, "y": "a"} + assert nested[1] == {"x": 20, "y": "b"} + finally: + os.unlink(path) + + def test_multi_block(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "data", AtomicType("STRING")), + ] + num_rows = 5000 + ids = list(range(num_rows)) + strings = [f"value_{i}" for i in range(num_rows)] + data = pa.table({ + "id": pa.array(ids, type=pa.int32()), + "data": pa.array(strings, type=pa.string()), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + with open(path, 'wb') as f: + writer = FormatRowWriter(f, fields, block_size=4096) + writer.write_table(data) + writer.close() + + result = _read_row_file(path, fields) + assert result.num_rows == num_rows + assert result.column("id").to_pylist() == ids + assert result.column("data").to_pylist() == strings + finally: + os.unlink(path) + + def test_empty_file(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + ] + data = pa.table({ + "id": pa.array([], type=pa.int32()), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.num_rows == 0 + finally: + os.unlink(path) + + def test_column_projection(self): + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + DataField(2, "value", AtomicType("DOUBLE")), + ] + data = pa.table({ + "id": pa.array([1, 2, 3], type=pa.int32()), + "name": pa.array(["a", "b", "c"], type=pa.string()), + "value": pa.array([1.1, 
2.2, 3.3], type=pa.float64()), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields, read_field_names=["id", "value"]) + assert result.num_columns == 2 + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("value").to_pylist() == [pytest.approx(1.1), pytest.approx(2.2), pytest.approx(3.3)] + finally: + os.unlink(path) + + def test_date_and_time(self): + fields = [ + DataField(0, "d", AtomicType("DATE")), + DataField(1, "t", AtomicType("TIME")), + ] + data = pa.table({ + "d": pa.array([18000, 19000], type=pa.date32()), + "t": pa.array([3600000, 7200000], type=pa.time32('ms')), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + assert result.num_rows == 2 + finally: + os.unlink(path) + + def test_variant_type(self): + fields = [ + DataField(0, "v", AtomicType("VARIANT")), + ] + data = pa.table({ + "v": pa.array( + [{"value": b"\x01\x02", "metadata": b"\x03\x04"}, + {"value": b"\x05", "metadata": b"\x06\x07\x08"}], + type=pa.struct([ + pa.field("value", pa.binary(), nullable=False), + pa.field("metadata", pa.binary(), nullable=False), + ]) + ), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + result = _read_row_file(path, fields) + variants = result.column("v").to_pylist() + assert variants[0]["value"] == b"\x01\x02" + assert variants[0]["metadata"] == b"\x03\x04" + assert variants[1]["value"] == b"\x05" + assert variants[1]["metadata"] == b"\x06\x07\x08" + finally: + os.unlink(path) + + def test_row_indices_random_access(self): + """Test reading specific rows by index (O(1) row-number lookup).""" + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", 
AtomicType("VARCHAR")), + ] + data = pa.table({ + "id": pa.array(list(range(100)), type=pa.int32()), + "name": pa.array([f"row_{i}" for i in range(100)]), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + _write_row_file(path, fields, data) + + # Read specific rows: 0, 5, 50, 99 + result = _read_row_file(path, fields, row_indices=[0, 5, 50, 99]) + assert result.num_rows == 4 + assert result.column("id").to_pylist() == [0, 5, 50, 99] + assert result.column("name").to_pylist() == [ + "row_0", "row_5", "row_50", "row_99" + ] + + # Read single row + result = _read_row_file(path, fields, row_indices=[42]) + assert result.num_rows == 1 + assert result.column("id").to_pylist() == [42] + + # Read empty indices + result = _read_row_file(path, fields, row_indices=[]) + assert result.num_rows == 0 + finally: + os.unlink(path) + + def test_row_indices_multi_block(self): + """Test row_indices across multiple blocks.""" + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "value", AtomicType("VARCHAR")), + ] + # Write enough data to create multiple blocks (small block size) + n_rows = 500 + data = pa.table({ + "id": pa.array(list(range(n_rows)), type=pa.int32()), + "value": pa.array([f"val_{i}" * 10 for i in range(n_rows)]), + }) + + with tempfile.NamedTemporaryFile(suffix=".row", delete=False) as tmp: + path = tmp.name + + try: + with open(path, 'wb') as f: + writer = FormatRowWriter(f, fields, block_size=1024) + writer.write_table(data) + writer.close() + + # Read rows from different blocks + indices = [0, 100, 200, 300, 499] + result = _read_row_file(path, fields, row_indices=indices) + assert result.num_rows == 5 + assert result.column("id").to_pylist() == indices + finally: + os.unlink(path) + + def test_data_evolution_row_id_read(self): + """Test Data Evolution scenario: partial-column write then row-id based read. + + Simulates the Data Evolution pattern where: + 1. 
First commit writes columns (f0, f1) + 2. Second commit writes column (f2) with first_row_id=0 + 3. Read merges by row ID to reconstruct full rows + """ + import shutil + from pypaimon import CatalogFactory, Schema + + tempdir = tempfile.mkdtemp() + try: + warehouse = os.path.join(tempdir, 'warehouse') + catalog = CatalogFactory.create({'warehouse': warehouse}) + catalog.create_database('default', True) + + pa_schema = pa.schema([ + ('f0', pa.int32()), + ('f1', pa.string()), + ('f2', pa.string()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'file.format': 'row', + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + catalog.create_table('default.de_row_id_test', schema, False) + table = catalog.get_table('default.de_row_id_test') + + # First commit: write (f0, f1) + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write().with_write_type(['f0', 'f1']) + table_commit = write_builder.new_commit() + + data1 = pa.table({ + 'f0': pa.array([1, 2, 3, 4, 5], type=pa.int32()), + 'f1': pa.array(['a1', 'a2', 'a3', 'a4', 'a5']), + }) + table_write.write_arrow(data1) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Second commit: write (f2) with first_row_id = 0 + table_write = write_builder.new_write().with_write_type(['f2']) + table_commit = write_builder.new_commit() + + data2 = pa.table({ + 'f2': pa.array(['b1', 'b2', 'b3', 'b4', 'b5']), + }) + table_write.write_arrow(data2) + cmts = table_write.prepare_commit() + cmts[0].new_files[0].first_row_id = 0 + table_commit.commit(cmts) + table_write.close() + table_commit.close() + + # Read full table - should merge partial columns by row ID + table = catalog.get_table('default.de_row_id_test') + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + + assert result.num_rows == 5 + result_sorted = 
result.sort_by('f0') + assert result_sorted.column('f0').to_pylist() == [1, 2, 3, 4, 5] + assert result_sorted.column('f1').to_pylist() == [ + 'a1', 'a2', 'a3', 'a4', 'a5' + ] + assert result_sorted.column('f2').to_pylist() == [ + 'b1', 'b2', 'b3', 'b4', 'b5' + ] + finally: + shutil.rmtree(tempdir, ignore_errors=True) diff --git a/paimon-python/pypaimon/tests/test_format_row_table.py b/paimon-python/pypaimon/tests/test_format_row_table.py new file mode 100644 index 000000000000..c4831445475b --- /dev/null +++ b/paimon-python/pypaimon/tests/test_format_row_table.py @@ -0,0 +1,503 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Integration tests for the ROW file format across all table types.""" + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema + + +class RowFormatAppendOnlyTest(unittest.TestCase): + """Test ROW format with append-only (non-primary-key) tables.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('user_id', pa.int32()), + ('item_id', pa.int64()), + ('behavior', pa.string()), + ('dt', pa.string()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _write_and_read(self, table, data): + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + return read_builder.new_read().to_arrow(splits) + + def test_append_only_no_partition(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, options={'file.format': 'row'}) + self.catalog.create_table('default.ao_row_no_part', schema, False) + table = self.catalog.get_table('default.ao_row_no_part') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['buy', 'click', 'view'], + 'dt': ['2024-01-01', '2024-01-01', '2024-01-02'], + }, schema=self.pa_schema) + + result = self._write_and_read(table, data) + self.assertEqual(result.num_rows, 3) + self.assertEqual(sorted(result.column('user_id').to_pylist()), [1, 2, 3]) + + def test_append_only_with_partition(self): + schema = Schema.from_pyarrow_schema( + 
self.pa_schema, partition_keys=['dt'], + options={'file.format': 'row'}) + self.catalog.create_table('default.ao_row_partitioned', schema, False) + table = self.catalog.get_table('default.ao_row_partitioned') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3, 4], + 'item_id': [1001, 1002, 1003, 1004], + 'behavior': ['buy', 'click', 'view', 'buy'], + 'dt': ['p1', 'p1', 'p2', 'p2'], + }, schema=self.pa_schema) + + result = self._write_and_read(table, data) + self.assertEqual(result.num_rows, 4) + + def test_append_only_multiple_commits(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, options={'file.format': 'row'}) + self.catalog.create_table('default.ao_row_multi_commit', schema, False) + table = self.catalog.get_table('default.ao_row_multi_commit') + + data1 = pa.Table.from_pydict({ + 'user_id': [1, 2], + 'item_id': [1001, 1002], + 'behavior': ['buy', 'click'], + 'dt': ['2024-01-01', '2024-01-01'], + }, schema=self.pa_schema) + + data2 = pa.Table.from_pydict({ + 'user_id': [3, 4], + 'item_id': [1003, 1004], + 'behavior': ['view', 'buy'], + 'dt': ['2024-01-02', '2024-01-02'], + }, schema=self.pa_schema) + + write_builder = table.new_batch_write_builder() + + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data1) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data2) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 4) + self.assertEqual(sorted(result.column('user_id').to_pylist()), [1, 2, 3, 4]) + + def test_append_only_with_nulls(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, 
options={'file.format': 'row'}) + self.catalog.create_table('default.ao_row_nulls', schema, False) + table = self.catalog.get_table('default.ao_row_nulls') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [None, 1002, None], + 'behavior': ['buy', None, 'view'], + 'dt': ['2024-01-01', '2024-01-01', '2024-01-02'], + }, schema=self.pa_schema) + + result = self._write_and_read(table, data) + self.assertEqual(result.num_rows, 3) + item_ids = sorted(result.column('item_id').to_pylist(), key=lambda x: (x is None, x)) + self.assertEqual(item_ids, [1002, None, None]) + + def test_append_only_column_projection(self): + """Test that reading with column projection decodes correctly.""" + schema = Schema.from_pyarrow_schema( + self.pa_schema, options={'file.format': 'row'}) + self.catalog.create_table('default.ao_row_projection', schema, False) + table = self.catalog.get_table('default.ao_row_projection') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['buy', 'click', 'view'], + 'dt': ['2024-01-01', '2024-01-01', '2024-01-02'], + }, schema=self.pa_schema) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Read only (user_id, behavior) - skipping item_id and dt + read_builder = table.new_read_builder() + read_builder = read_builder.with_projection(['user_id', 'behavior']) + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.schema.names, ['user_id', 'behavior']) + self.assertEqual(sorted(result.column('user_id').to_pylist()), [1, 2, 3]) + self.assertEqual( + sorted(result.column('behavior').to_pylist()), + ['buy', 'click', 'view']) + + # Read only (item_id, dt) - 
skipping user_id and behavior + read_builder = table.new_read_builder() + read_builder = read_builder.with_projection(['item_id', 'dt']) + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.schema.names, ['item_id', 'dt']) + self.assertEqual( + sorted(result.column('item_id').to_pylist()), [1001, 1002, 1003]) + + +class RowFormatPrimaryKeyTest(unittest.TestCase): + """Test ROW format with primary-key tables.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + pa.field('user_id', pa.int32(), nullable=False), + ('item_id', pa.int64()), + ('behavior', pa.string()), + pa.field('dt', pa.string(), nullable=False), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def test_pk_table_basic(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, + partition_keys=['dt'], + primary_keys=['user_id', 'dt'], + options={'bucket': '1', 'file.format': 'row'}) + self.catalog.create_table('default.pk_row_basic', schema, False) + table = self.catalog.get_table('default.pk_row_basic') + + data = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['buy', 'click', 'view'], + 'dt': ['p1', 'p1', 'p2'], + }, schema=self.pa_schema) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + 
self.assertEqual(result.num_rows, 3) + + def test_pk_table_upsert(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, + partition_keys=['dt'], + primary_keys=['user_id', 'dt'], + options={'bucket': '1', 'file.format': 'row'}) + self.catalog.create_table('default.pk_row_upsert', schema, False) + table = self.catalog.get_table('default.pk_row_upsert') + + write_builder = table.new_batch_write_builder() + + # First commit + data1 = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['buy', 'click', 'view'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data1) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # Second commit - update user_id=2 + data2 = pa.Table.from_pydict({ + 'user_id': [2, 4], + 'item_id': [1002, 1004], + 'behavior': ['buy-updated', 'new'], + 'dt': ['p1', 'p1'], + }, schema=self.pa_schema) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data2) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 4) + result_dict = result.sort_by('user_id').to_pydict() + self.assertEqual(result_dict['user_id'], [1, 2, 3, 4]) + self.assertEqual(result_dict['behavior'], ['buy', 'buy-updated', 'view', 'new']) + + +class RowFormatDataEvolutionTest(unittest.TestCase): + """Test ROW format with data-evolution enabled tables.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', 
True) + + cls.pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ('city', pa.string()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def test_data_evolution_write_read(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, + options={ + 'file.format': 'row', + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.de_row_basic', schema, False) + table = self.catalog.get_table('default.de_row_basic') + + data = pa.Table.from_pydict({ + 'id': [1, 2, 3], + 'name': ['alice', 'bob', 'charlie'], + 'age': [25, 30, 35], + 'city': ['beijing', 'shanghai', 'guangzhou'], + }, schema=self.pa_schema) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 3) + self.assertEqual(sorted(result.column('id').to_pylist()), [1, 2, 3]) + + def test_data_evolution_multiple_commits(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, + options={ + 'file.format': 'row', + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.de_row_multi', schema, False) + table = self.catalog.get_table('default.de_row_multi') + + write_builder = table.new_batch_write_builder() + + data1 = pa.Table.from_pydict({ + 'id': [1, 2, 3], + 'name': ['alice', 'bob', 'charlie'], + 'age': [25, 30, 35], + 'city': ['beijing', 'shanghai', 'guangzhou'], + }, schema=self.pa_schema) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data1) + 
table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + data2 = pa.Table.from_pydict({ + 'id': [4, 5], + 'name': ['dave', 'eve'], + 'age': [40, 45], + 'city': ['shenzhen', 'hangzhou'], + }, schema=self.pa_schema) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data2) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 5) + self.assertEqual(sorted(result.column('id').to_pylist()), [1, 2, 3, 4, 5]) + + +class RowFormatBlobTableTest(unittest.TestCase): + """Test ROW format with blob tables (blob columns stored separately).""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('blob_data', pa.large_binary()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def test_blob_table_with_row_format(self): + schema = Schema.from_pyarrow_schema( + self.pa_schema, + options={ + 'file.format': 'row', + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('default.blob_row_basic', schema, False) + table = self.catalog.get_table('default.blob_row_basic') + + data = pa.Table.from_pydict({ + 'id': [1, 2, 3], + 'name': ['a', 'b', 'c'], + 'blob_data': [b'\x01\x02\x03', b'\x04\x05', b'\x06\x07\x08\x09'], + }, schema=self.pa_schema) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + 
table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 3) + self.assertEqual(sorted(result.column('id').to_pylist()), [1, 2, 3]) + + +class RowFormatVectorTableTest(unittest.TestCase): + """Test ROW format with vector tables (vector columns inline).""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('id', pa.int64()), + ('embed', pa.list_(pa.float32(), 3)), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def test_vector_inline_with_row_format(self): + """Vector stored inline in the same ROW file.""" + schema = Schema.from_pyarrow_schema( + self.pa_schema, + options={'file.format': 'row'}) + self.catalog.create_table('default.vec_row_inline', schema, False) + table = self.catalog.get_table('default.vec_row_inline') + + data = pa.table({ + 'id': pa.array([1, 2, 3], type=pa.int64()), + 'embed': pa.FixedSizeListArray.from_arrays( + pa.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], type=pa.float32()), + 3 + ), + }) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + table_write.write_arrow(data) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + result = read_builder.new_read().to_arrow(splits) + self.assertEqual(result.num_rows, 3) + self.assertEqual(sorted(result.column('id').to_pylist()), [1, 2, 3]) + + 
+if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_limit_pushdown.py b/paimon-python/pypaimon/tests/test_limit_pushdown.py index 2e717c28c7e1..a4b51ee3c00c 100644 --- a/paimon-python/pypaimon/tests/test_limit_pushdown.py +++ b/paimon-python/pypaimon/tests/test_limit_pushdown.py @@ -110,8 +110,8 @@ def test_append_only_limit_stops_within_first_split(self): exactly 3 rows — even though each partition split has 5 rows.""" table = self._create_ao_table('limit_ao_within_split') self._write_ao_partitions(table, [ - ('p1', list(range(5))), # 5 rows - ('p2', list(range(5, 10))), # 5 rows + ('p1', list(range(5))), # 5 rows + ('p2', list(range(5, 10))), # 5 rows ]) rb = table.new_read_builder().with_limit(3) result = rb.new_read().to_arrow(rb.new_scan().plan().splits()) @@ -207,6 +207,54 @@ def test_to_iterator_limit_short_circuits(self): rows = list(it) self.assertEqual(len(rows), 7) + # ---- SplitRead-level limit pushdown verification --------------------- + + def test_append_only_split_read_creates_limited_batch_reader(self): + """Verify that RawFileSplitRead.create_reader() returns a + LimitedRecordBatchReader (inherits RecordBatchReader) when limit + is set, so the arrow-batch read path is taken.""" + from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader + from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader + + table = self._create_ao_table('limit_ao_split_read') + self._write_ao_partitions(table, [('p1', list(range(10)))]) + rb = table.new_read_builder().with_limit(3) + table_read = rb.new_read() + splits = rb.new_scan().plan().splits() + self.assertGreater(len(splits), 0) + for split in splits: + split_read = table_read._create_split_read(split) + self.assertEqual(split_read.limit, 3) + reader = split_read.create_reader() + self.assertIsInstance(reader, LimitedRecordBatchReader, + "RawFileSplitRead.create_reader() should wrap with LimitedRecordBatchReader") + 
self.assertIsInstance(reader, RecordBatchReader, + "LimitedRecordBatchReader should be a RecordBatchReader") + reader.close() + + def test_append_only_split_read_limit_truncates_within_split(self): + """Directly read from a single split's reader with limit and verify + the reader itself stops at the limit boundary, not relying on + TableRead-level truncation.""" + table = self._create_ao_table('limit_ao_split_truncate') + self._write_ao_partitions(table, [('p1', list(range(20)))]) + rb = table.new_read_builder().with_limit(5) + table_read = rb.new_read() + splits = rb.new_scan().plan().splits() + self.assertEqual(len(splits), 1) + split_read = table_read._create_split_read(splits[0]) + reader = split_read.create_reader() + # Drain the reader directly, bypassing TableRead-level control + total_rows = 0 + while True: + batch = reader.read_arrow_batch() + if batch is None: + break + total_rows += batch.num_rows + reader.close() + self.assertEqual(total_rows, 5, + "SplitRead-level reader should stop at limit=5, got %d" % total_rows) + if __name__ == '__main__': unittest.main() diff --git a/paimon-python/pypaimon/tests/test_merge_engine_dispatch.py b/paimon-python/pypaimon/tests/test_merge_engine_dispatch.py new file mode 100644 index 000000000000..f8e2d717a563 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_merge_engine_dispatch.py @@ -0,0 +1,135 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Unit tests for ``pypaimon.common.merge_engine_dispatch``. + +Pins down the table-option parsing the dispatch uses to decide whether +``partial-update`` should run or be rejected. The key contract: strict +``"true"``-only boolean parsing aligned with the table-option parser +used elsewhere in Paimon, so an option string the rest of the +toolchain treats as ``false`` is not silently elevated to ``true`` +here. +""" + +import unittest + +from pypaimon.common.merge_engine_dispatch import ( + _option_is_truthy, + build_merge_function, + partial_update_unsupported_options, +) +from pypaimon.common.options.core_options import MergeEngine +from pypaimon.read.reader.partial_update_merge_function import \ + PartialUpdateMergeFunction + + +class OptionIsTruthyTest(unittest.TestCase): + """``_option_is_truthy`` accepts only ``"true"`` (case-insensitive) + as truthy; every other string -- including ``"1"``, ``"yes"``, + ``"on"`` -- is falsey. Matches the table-option parser used + elsewhere in Paimon. 
+ """ + + def test_true_string_is_truthy(self): + self.assertTrue(_option_is_truthy("true")) + + def test_true_string_is_case_insensitive(self): + for v in ("TRUE", "True", "tRuE"): + self.assertTrue(_option_is_truthy(v), v) + + def test_true_string_tolerates_surrounding_whitespace(self): + self.assertTrue(_option_is_truthy(" true ")) + + def test_python_bool_true_is_truthy(self): + self.assertTrue(_option_is_truthy(True)) + + def test_python_bool_false_is_falsey(self): + self.assertFalse(_option_is_truthy(False)) + + def test_none_is_falsey(self): + self.assertFalse(_option_is_truthy(None)) + + def test_non_true_strings_are_falsey(self): + # The table-option parser elsewhere in Paimon returns false + # for every one of these. pypaimon must do the same so a + # user-set "yes" is not silently elevated to true here while + # the rest of the toolchain treats it as false. + for v in ("1", "yes", "on", "Yes", "ON", "y", "t", "0", "no", "off", + "false", "FALSE", ""): + self.assertFalse(_option_is_truthy(v), v) + + +class PartialUpdateUnsupportedOptionsTest(unittest.TestCase): + + def test_ignore_delete_yes_is_not_flagged(self): + # ``yes`` is falsey under the upstream table-option parser, + # so partial-update must NOT be blocked here. Pre-fix pypaimon + # rejected this; the fix aligns the dispatch with the parser. 
+ unsupported = partial_update_unsupported_options( + {"partial-update.ignore-delete": "yes"}) + self.assertEqual(unsupported, set()) + + def test_ignore_delete_true_is_flagged(self): + unsupported = partial_update_unsupported_options( + {"partial-update.ignore-delete": "true"}) + self.assertEqual(unsupported, {"partial-update.ignore-delete"}) + + def test_sequence_group_is_flagged(self): + unsupported = partial_update_unsupported_options( + {"fields.a.sequence-group": "b"}) + self.assertEqual(unsupported, {"fields.a.sequence-group"}) + + def test_unrelated_options_are_not_flagged(self): + unsupported = partial_update_unsupported_options( + {"bucket": "1", "merge-engine": "partial-update"}) + self.assertEqual(unsupported, set()) + + +class BuildMergeFunctionTest(unittest.TestCase): + """``build_merge_function`` forwards ``value_field_names`` to + ``PartialUpdateMergeFunction`` so the NOT-NULL error message can + surface the offending column name. This is the only behavioural + contract the dispatch adds on top of routing. 
+ """ + + def test_partial_update_forwards_field_names(self): + mf = build_merge_function( + engine=MergeEngine.PARTIAL_UPDATE, + raw_options={}, + key_arity=1, + value_arity=2, + value_field_nullables=[True, True], + value_field_names=['col_a', 'col_b'], + ) + self.assertIsInstance(mf, PartialUpdateMergeFunction) + self.assertEqual(mf._value_field_names, ['col_a', 'col_b']) + + def test_partial_update_without_field_names_keeps_none(self): + mf = build_merge_function( + engine=MergeEngine.PARTIAL_UPDATE, + raw_options={}, + key_arity=1, + value_arity=2, + value_field_nullables=[True, True], + ) + self.assertIsInstance(mf, PartialUpdateMergeFunction) + self.assertIsNone(mf._value_field_names) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_partial_update_e2e.py b/paimon-python/pypaimon/tests/test_partial_update_e2e.py index af5f61e98f6e..90fa5f711b46 100644 --- a/paimon-python/pypaimon/tests/test_partial_update_e2e.py +++ b/paimon-python/pypaimon/tests/test_partial_update_e2e.py @@ -21,9 +21,9 @@ Each test creates a PK table with ``merge-engine`` set to a particular value, writes one or more batches, and reads back. Partial-update reads must merge non-null fields across batches; ``deduplicate`` must keep -the latest row only; ``aggregation`` and ``first-row`` must raise -``NotImplementedError`` (until they are ported), since silently -treating them as deduplicate would corrupt the user's data. +the latest row only; ``first-row`` must keep the earliest row. +``aggregation`` has its own engine-specific e2e coverage in +:mod:`test_aggregation_e2e`. """ import os @@ -88,6 +88,24 @@ def _write(self, table, rows): w.close() c.close() + def _write_many(self, table, batches): + """Multiple ``write_arrow`` calls inside a single ``prepare_commit``. 
+ + Mirrors the reviewer's question: rows that land in the same + underlying data file must still go through the merge-engine + dispatch; in-writer merging cannot silently degrade to dedupe. + """ + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + for rows in batches: + w.write_arrow(pa.Table.from_pylist(rows, schema=self.pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + def _read(self, table): rb = table.new_read_builder() splits = rb.new_scan().plan().splits() @@ -173,6 +191,46 @@ def test_partial_update_later_null_does_not_clobber_earlier_value(self): [{'id': 1, 'a': 'A', 'b': 'B', 'c': 'C'}], ) + # -- single-commit, multiple write_arrow calls ----------------------- + # + # The in-memory merge buffer added to ``KeyValueDataWriter`` runs + # the merge function on flush, so rows from multiple ``write_arrow`` + # calls that share a primary key are folded into a single row before + # the data file is written. The flushed file therefore satisfies the + # LSM "PK unique within a file" invariant the read-side + # ``raw_convertible`` fast path relies on. + + def test_partial_update_two_write_arrows_single_commit(self): + """Two ``write_arrow`` calls + one ``prepare_commit``: each + carries a disjoint non-null field; result is the per-field merge. + """ + table = self._create_pk_table('two_writes_single_commit') + self._write_many(table, [ + [{'id': 1, 'a': 'A', 'b': None, 'c': None}], + [{'id': 1, 'a': None, 'b': 'B', 'c': None}], + ]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'A', 'b': 'B', 'c': None}], + ) + + def test_partial_update_three_write_arrows_single_commit(self): + """Three ``write_arrow`` calls in a single commit compose into + the union of non-null fields. 
+ """ + table = self._create_pk_table('three_writes_single_commit') + self._write_many(table, [ + [{'id': 1, 'a': 'A', 'b': None, 'c': None}], + [{'id': 1, 'a': None, 'b': 'B', 'c': None}], + [{'id': 1, 'a': None, 'b': None, 'c': 'C'}], + ]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'A', 'b': 'B', 'c': 'C'}], + ) + # -- deduplicate (regression) ---------------------------------------- def test_deduplicate_engine_unchanged(self): @@ -188,35 +246,63 @@ def test_deduplicate_engine_unchanged(self): [{'id': 1, 'a': 'new', 'b': None, 'c': None}], ) - # -- engines we don't support yet ------------------------------------ + def test_deduplicate_two_write_arrows_single_commit(self): + """Pre-PR master silently returned both rows because the + flushed file held two records sharing a primary key. With the + in-memory merge buffer in place, ``deduplicate`` collapses + same-PK rows in a single commit too -- LSM "PK unique within a + file" invariant restored. + """ + table = self._create_pk_table( + 'dedupe_two_writes_single_commit', + merge_engine='deduplicate', + ) + self._write_many(table, [ + [{'id': 1, 'a': 'first', 'b': 'old', 'c': None}], + [{'id': 1, 'a': 'second', 'b': 'new', 'c': None}], + ]) - def test_aggregation_engine_raises_not_implemented(self): - """Until ``aggregation`` is ported, reading an aggregation table - must raise rather than silently produce dedupe results.""" - table = self._create_pk_table('agg_unsupported', - merge_engine='aggregation') - self._write(table, [{'id': 1, 'a': 'x', 'b': None, 'c': None}]) - self._write(table, [{'id': 1, 'a': 'y', 'b': None, 'c': None}]) + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'second', 'b': 'new', 'c': None}], + ) - rb = table.new_read_builder() - splits = rb.new_scan().plan().splits() - with self.assertRaises(NotImplementedError) as cm: - rb.new_read().to_arrow(splits) - self.assertIn('aggregation', str(cm.exception)) + # -- other supported engines (smoke) 
--------------------------------- + + def test_first_row_engine_keeps_first(self): + """The ``first-row`` engine must keep the earliest row per PK. - def test_first_row_engine_raises_not_implemented(self): - """Until ``first-row`` is ported, reading a first-row table must - raise rather than silently produce dedupe results.""" - table = self._create_pk_table('first_row_unsupported', + Both the writer-side merge buffer and the reader-side merge + function go through ``merge_engine_dispatch``, so first-row is + a real supported engine (no dedupe fallback / no NotImplemented + raise) on both sides. + """ + table = self._create_pk_table('first_row_supported', merge_engine='first-row') - self._write(table, [{'id': 1, 'a': 'x', 'b': None, 'c': None}]) - self._write(table, [{'id': 1, 'a': 'y', 'b': None, 'c': None}]) + self._write(table, [{'id': 1, 'a': 'first', 'b': None, 'c': None}]) + self._write(table, [{'id': 1, 'a': 'second', 'b': 'B', 'c': 'C'}]) - rb = table.new_read_builder() - splits = rb.new_scan().plan().splits() - with self.assertRaises(NotImplementedError) as cm: - rb.new_read().to_arrow(splits) - self.assertIn('first-row', str(cm.exception)) + self.assertEqual( + self._read(table), + [{'id': 1, 'a': 'first', 'b': None, 'c': None}], + ) + + def test_aggregation_engine_write_logs_fallback_warning(self): + """The write-side fallback to deduplicate for unsupported engines + is silent in terms of return value -- a ``logging.warning`` is + the only signal that file contents will not match the table's + declared semantics. Important when the same table is read back + by a reader that honours the declared engine; the pypaimon + read-side raise wouldn't fire there. 
+ """ + table = self._create_pk_table('agg_warning', + merge_engine='aggregation') + with self.assertLogs( + 'pypaimon.write.file_store_write', level='WARNING') as cm: + self._write(table, [{'id': 1, 'a': 'x', 'b': None, 'c': None}]) + combined = '\n'.join(cm.output) + self.assertIn('aggregation', combined) + self.assertIn('deduplicate', combined) # -- partial-update + out-of-scope option combinations --------------- # @@ -229,15 +315,14 @@ def test_first_row_engine_raises_not_implemented(self): def _assert_partial_update_unsupported(self, table_name, extra_options, expected_keys): + # Shared dispatch runs at write time too, so the unsupported- + # option error surfaces inside the first ``write_arrow`` call + # (when ``FileStoreWrite._create_data_writer`` first runs) + # rather than waiting for read. table = self._create_pk_table( table_name, extra_options=extra_options) - self._write(table, [{'id': 1, 'a': 'A', 'b': None, 'c': None}]) - self._write(table, [{'id': 1, 'a': None, 'b': 'B', 'c': None}]) - - rb = table.new_read_builder() - splits = rb.new_scan().plan().splits() with self.assertRaises(NotImplementedError) as cm: - rb.new_read().to_arrow(splits) + self._write(table, [{'id': 1, 'a': 'A', 'b': None, 'c': None}]) msg = str(cm.exception) self.assertIn("partial-update", msg) for key in expected_keys: @@ -287,25 +372,25 @@ def test_partial_update_with_remove_record_on_sequence_group_raises(self): ) def test_partial_update_unsupported_options_guard_covers_raw_convertible(self): - """The unsupported-options guard must fire even when the scan - would dispatch every split through ``RawFileSplitRead`` (i.e. a - single-snapshot table where rows don't overlap). + """The read-side guard at ``TableRead.__init__`` must fire even + when the scan would dispatch every split through + ``RawFileSplitRead`` (single-snapshot, non-overlapping rows). 
Before the guard moved to ``TableRead.__init__`` this case silently bypassed validation because raw-convertible splits skip - ``MergeFileSplitRead`` entirely — and an option like - ``partial-update.remove-record-on-delete`` would be ignored on - the read path while the user assumed it was honoured. + ``MergeFileSplitRead`` entirely -- the read path's + ``_build_merge_function`` never ran, so an option like + ``partial-update.remove-record-on-delete`` was ignored on read. + + The shared dispatch now also fires on the write path's first + flush (see ``_assert_partial_update_unsupported``), so we skip + ``_write`` here: the read-side guard runs at ``new_read()`` + construction time regardless of whether data exists. """ table = self._create_pk_table( 'pu_rrod_raw_convertible', extra_options={'partial-update.remove-record-on-delete': 'true'}, ) - # Single write -> single snapshot -> splits are raw-convertible. - self._write(table, [ - {'id': 1, 'a': 'A', 'b': None, 'c': None}, - {'id': 2, 'a': 'B', 'b': None, 'c': None}, - ]) rb = table.new_read_builder() with self.assertRaises(NotImplementedError) as cm: rb.new_read() diff --git a/paimon-python/pypaimon/tests/test_partial_update_merge_function.py b/paimon-python/pypaimon/tests/test_partial_update_merge_function.py index 60dfc7198dfb..b412fdd38851 100644 --- a/paimon-python/pypaimon/tests/test_partial_update_merge_function.py +++ b/paimon-python/pypaimon/tests/test_partial_update_merge_function.py @@ -127,8 +127,8 @@ def test_get_result_before_any_add_returns_none(self): self.assertIsNone(mf.get_result()) def test_update_after_is_treated_as_insert(self): - # Java's PartialUpdate accepts UPDATE_AFTER alongside INSERT in - # non-sequence-group mode (both are "add" kinds). Mirror that. + # UPDATE_AFTER is treated as an "add" alongside INSERT in + # non-sequence-group mode, matching the upstream contract. 
mf = PartialUpdateMergeFunction(key_arity=1, value_arity=2) mf.reset() mf.add(_kv((1,), 100, RowKind.INSERT, ('a', None))) @@ -172,28 +172,54 @@ def test_result_is_decoupled_from_input_kv(self): self.assertEqual(_result_key(result), (1,)) self.assertEqual(_result_value(result), ('a', 'x')) - # -- NOT-NULL input validation (mirrors Java's updateNonNullFields) ---- + # -- NOT-NULL input validation ---- def test_first_insert_with_null_for_not_null_field_raises(self): - """If the very first row writes null to a NOT NULL field, raise — - same input-validation Java does in updateNonNullFields().""" + """If the very first row writes null to a NOT NULL field, raise -- + the schema's NOT NULL declaration is enforced on every add().""" mf = PartialUpdateMergeFunction( key_arity=1, value_arity=2, nullables=[True, False]) mf.reset() with self.assertRaises(ValueError) as cm: mf.add(_kv((1,), 1, RowKind.INSERT, ('a', None))) - self.assertIn("Field 1", str(cm.exception)) + msg = str(cm.exception) + # Without field names we fall back to the index, but the + # actionable hint must still be there. 
+ self.assertIn("at index 1", msg) + self.assertIn("Declare the field nullable", msg) def test_subsequent_insert_with_null_for_not_null_field_raises(self): - """A later null on a NOT NULL field must also raise — Java checks - on every add(), not just the first one.""" + """A later null on a NOT NULL field must also raise -- the + NOT NULL check fires on every add(), not just the first one.""" mf = PartialUpdateMergeFunction( key_arity=1, value_arity=2, nullables=[True, False]) mf.reset() mf.add(_kv((1,), 1, RowKind.INSERT, ('a', 'x'))) with self.assertRaises(ValueError) as cm: mf.add(_kv((1,), 2, RowKind.INSERT, (None, None))) - self.assertIn("Field 1", str(cm.exception)) + self.assertIn("at index 1", str(cm.exception)) + + def test_not_null_error_message_uses_field_name_when_given(self): + """When ``value_field_names`` is supplied, the NOT-NULL error + names the offending field so the message is directly actionable + instead of citing a bare positional index.""" + mf = PartialUpdateMergeFunction( + key_arity=1, value_arity=2, + nullables=[True, False], + value_field_names=['a', 'b']) + mf.reset() + with self.assertRaises(ValueError) as cm: + mf.add(_kv((1,), 1, RowKind.INSERT, ('a', None))) + msg = str(cm.exception) + self.assertIn("'b'", msg) + self.assertIn("Declare the field nullable", msg) + + def test_value_field_names_length_mismatch_raises(self): + with self.assertRaises(ValueError): + PartialUpdateMergeFunction( + key_arity=1, value_arity=2, + nullables=[True, True], + value_field_names=['only_one']) def test_null_for_nullable_field_is_absorbed(self): """A null input on a nullable field is silently absorbed (existing diff --git a/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py b/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py index a849f2788efa..6cfa5ea5f68f 100644 --- a/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py +++ b/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py @@ -25,16 +25,29 @@ in 
``pypaimon/tests/ray_repartition_test.py``. """ +import importlib.util +from pathlib import Path import unittest from unittest.mock import MagicMock import pyarrow as pa -from pypaimon.ray.shuffle import (BUCKET_KEY_COL, _coerce_large_string_types, - _make_bucket_udf, _pick_bucket_col_name, - maybe_apply_repartition) from pypaimon.table.bucket_mode import BucketMode +_SHUFFLE_PATH = ( + Path(__file__).resolve().parents[1] / "ray" / "shuffle.py" +) +_SHUFFLE_SPEC = importlib.util.spec_from_file_location( + "pypaimon_ray_shuffle_under_test", _SHUFFLE_PATH) +_SHUFFLE = importlib.util.module_from_spec(_SHUFFLE_SPEC) +_SHUFFLE_SPEC.loader.exec_module(_SHUFFLE) + +BUCKET_KEY_COL = _SHUFFLE.BUCKET_KEY_COL +_coerce_large_string_types = _SHUFFLE._coerce_large_string_types +_make_bucket_udf = _SHUFFLE._make_bucket_udf +_pick_bucket_col_name = _SHUFFLE._pick_bucket_col_name +maybe_apply_repartition = _SHUFFLE.maybe_apply_repartition + class BucketUdfTest(unittest.TestCase): """The bucket-key UDF appends a deterministic int32 column.""" @@ -135,12 +148,13 @@ def test_casts_large_binary_back_to_binary(self): class BucketModeDispatchTest(unittest.TestCase): - """``maybe_apply_repartition`` clusters HASH_FIXED tables and - returns other bucket modes unchanged.""" + """``maybe_apply_repartition`` clusters only supported HASH_FIXED + writes and rejects unsafe primary-key Ray writes.""" - def _make_table(self, bucket_mode): + def _make_table(self, bucket_mode, is_primary_key_table=False): table = MagicMock() table.bucket_mode.return_value = bucket_mode + table.is_primary_key_table = is_primary_key_table return table def test_bucket_unaware_returns_dataset_unchanged(self): @@ -155,13 +169,89 @@ def test_hash_dynamic_returns_dataset_unchanged(self): self.assertIs(maybe_apply_repartition(dataset, table), dataset) + def test_hash_dynamic_primary_key_raises_value_error(self): + dataset = MagicMock(name="dataset") + table = self._make_table( + BucketMode.HASH_DYNAMIC, 
is_primary_key_table=True) + + with self.assertRaisesRegex(ValueError, "HASH_DYNAMIC primary-key"): + maybe_apply_repartition(dataset, table) + dataset.map_batches.assert_not_called() + + def test_hash_dynamic_primary_key_map_groups_raises_value_error(self): + dataset = MagicMock(name="dataset") + table = self._make_table( + BucketMode.HASH_DYNAMIC, is_primary_key_table=True) + + with self.assertRaisesRegex(ValueError, "HASH_DYNAMIC primary-key"): + maybe_apply_repartition(dataset, table, "map_groups") + dataset.map_batches.assert_not_called() + def test_cross_partition_returns_dataset_unchanged(self): dataset = object() table = self._make_table(BucketMode.CROSS_PARTITION) self.assertIs(maybe_apply_repartition(dataset, table), dataset) - def test_hash_fixed_runs_map_batches_groupby_chain(self): + def test_cross_partition_primary_key_raises_value_error(self): + dataset = MagicMock(name="dataset") + table = self._make_table( + BucketMode.CROSS_PARTITION, is_primary_key_table=True) + + with self.assertRaisesRegex(ValueError, "CROSS_PARTITION primary-key"): + maybe_apply_repartition(dataset, table) + dataset.map_batches.assert_not_called() + + def test_postpone_primary_key_returns_dataset_unchanged(self): + dataset = MagicMock(name="dataset") + table = self._make_table( + BucketMode.POSTPONE_MODE, is_primary_key_table=True) + + self.assertIs(maybe_apply_repartition(dataset, table), dataset) + dataset.map_batches.assert_not_called() + + def test_hash_fixed_default_returns_dataset_unchanged(self): + dataset = MagicMock(name="dataset") + table = MagicMock() + table.bucket_mode.return_value = BucketMode.HASH_FIXED + table.is_primary_key_table = False + + self.assertIs(maybe_apply_repartition(dataset, table), dataset) + dataset.map_batches.assert_not_called() + + def test_hash_fixed_off_returns_dataset_unchanged(self): + dataset = MagicMock(name="dataset") + table = MagicMock() + table.bucket_mode.return_value = BucketMode.HASH_FIXED + table.is_primary_key_table = False + + 
self.assertIs( + maybe_apply_repartition(dataset, table, "off"), + dataset, + ) + dataset.map_batches.assert_not_called() + + def test_hash_fixed_primary_key_default_raises_value_error(self): + dataset = MagicMock(name="dataset") + table = MagicMock() + table.bucket_mode.return_value = BucketMode.HASH_FIXED + table.is_primary_key_table = True + + with self.assertRaises(ValueError): + maybe_apply_repartition(dataset, table) + dataset.map_batches.assert_not_called() + + def test_hash_fixed_primary_key_off_raises_value_error(self): + dataset = MagicMock(name="dataset") + table = MagicMock() + table.bucket_mode.return_value = BucketMode.HASH_FIXED + table.is_primary_key_table = True + + with self.assertRaises(ValueError): + maybe_apply_repartition(dataset, table, "off") + dataset.map_batches.assert_not_called() + + def test_hash_fixed_map_groups_runs_map_batches_groupby_chain(self): dataset = MagicMock(name="dataset") dataset.map_batches.return_value.groupby.return_value \ .map_groups.return_value.drop_columns.return_value = "clustered" @@ -173,7 +263,7 @@ def test_hash_fixed_runs_map_batches_groupby_chain(self): type("F", (), {"name": "value"})(), ] - out = maybe_apply_repartition(dataset, table) + out = maybe_apply_repartition(dataset, table, "map_groups") self.assertEqual(out, "clustered") # The helper appends a transient bucket column, groups by it, @@ -199,12 +289,19 @@ def test_hash_fixed_groups_include_partition_keys(self): type("F", (), {"name": "dt"})(), ] - maybe_apply_repartition(dataset, table) + maybe_apply_repartition(dataset, table, "map_groups") group_call = dataset.map_batches.return_value.groupby.call_args passed_keys = group_call.args[0] self.assertEqual(passed_keys, ["dt", BUCKET_KEY_COL]) + def test_invalid_precluster_mode_raises_value_error(self): + dataset = object() + table = self._make_table(BucketMode.HASH_FIXED) + + with self.assertRaises(ValueError): + maybe_apply_repartition(dataset, table, "hash_shuffle") + if __name__ == "__main__": 
unittest.main() diff --git a/paimon-python/pypaimon/tests/test_sequence_field_read.py b/paimon-python/pypaimon/tests/test_sequence_field_read.py new file mode 100644 index 000000000000..ed32768c2ed6 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_sequence_field_read.py @@ -0,0 +1,545 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""End-to-end tests for the ``sequence.field`` option on the read path. + +``sequence.field`` lets the user pick an explicit column (or columns) +whose value -- not the file-level sequence number -- decides which record +is the "latest" for a primary key. The tricky case is when the +write/file order disagrees with the ``sequence.field`` order: a row +written *later* (higher file sequence number) carrying a *lower* +``sequence.field`` value must lose to the earlier-written row. The Java +merge path applies a ``UserDefinedSeqComparator`` on the value row before +falling back to the file sequence number; pypaimon mirrors that via +``builtin_seq_comparator`` wired into ``SortMergeReaderWithMinHeap``. 
+""" + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema + + +class SequenceFieldReadE2ETest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('ts', pa.int64()), + ('ts2', pa.int64()), + ('val', pa.string()), + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_pk_table(self, table_name, merge_engine='deduplicate', + extra_options=None, partition_keys=None): + # bucket=1 forces all rows for a PK into one bucket so the read + # goes through SortMergeReader (where sequence ordering matters) + # instead of the raw-convertible fast path. + options = { + 'bucket': '1', + 'merge-engine': merge_engine, + } + if extra_options: + options.update(extra_options) + schema = Schema.from_pyarrow_schema( + self.pa_schema, + primary_keys=['id'], + partition_keys=partition_keys or [], + options=options, + ) + full = 'default.{}'.format(table_name) + self.catalog.create_table(full, schema, False) + return self.catalog.get_table(full) + + def _write(self, table, rows): + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=self.pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + def _read(self, table, projection=None, predicate=None): + rb = table.new_read_builder() + if projection is not None: + rb = rb.with_projection(projection) + if predicate is not None: + rb = rb.with_filter(predicate) + splits = rb.new_scan().plan().splits() + if not splits: + return [] + return sorted( + rb.new_read().to_arrow(splits).to_pylist(), + key=lambda r: 
r['id'], + ) + + # -- basic ordering -------------------------------------------------- + + def test_later_write_with_lower_sequence_field_loses(self): + """The row written second has a higher file sequence number but a + lower ``sequence.field`` value, so the earlier (higher-ts) row + must win. + """ + table = self._create_pk_table( + 'seq_basic', extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}], + ) + + def test_later_write_with_higher_sequence_field_wins(self): + """Sanity check the non-inverted case still works.""" + table = self._create_pk_table( + 'seq_basic_fwd', extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}], + ) + + # -- multiple sequence fields ---------------------------------------- + + def test_multi_sequence_field_left_to_right(self): + """When the first sequence field ties, the second breaks it.""" + table = self._create_pk_table( + 'seq_multi', extra_options={'sequence.field': 'ts,ts2'}) + # Same ts; ts2 decides. Write the ts2-winner first so file order + # disagrees with the sequence-field order. + self._write(table, [{'id': 1, 'ts': 10, 'ts2': 99, 'val': 'win'}]) + self._write(table, [{'id': 1, 'ts': 10, 'ts2': 1, 'val': 'lose'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 10, 'ts2': 99, 'val': 'win'}], + ) + + # -- sort order ------------------------------------------------------ + + def test_descending_sort_order_lowest_wins(self): + """With descending sort order, the lowest ``sequence.field`` value + is considered the latest. 
+ """ + table = self._create_pk_table( + 'seq_desc', + extra_options={'sequence.field': 'ts', + 'sequence.field.sort-order': 'descending'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}], + ) + + def test_descending_sort_order_null_sequence_sorts_first(self): + """Null ordering must stay independent of sort order: Java builds + the sequence comparator with ``nullIsLast=false`` and applies + descending only to non-null value comparisons, so a null + ``sequence.field`` value always sorts first (loses) -- even under + descending order. A non-null row must therefore beat a null-seq + row regardless of write order. + """ + table = self._create_pk_table( + 'seq_desc_null', + extra_options={'sequence.field': 'ts', + 'sequence.field.sort-order': 'descending'}) + # null-seq row written second (higher file sequence number). With + # nulls-first ordering it still loses to the earlier non-null row. + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'real'}]) + self._write(table, [{'id': 1, 'ts': None, 'ts2': 0, 'val': 'null'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'real'}], + ) + + def test_ascending_sort_order_null_sequence_sorts_first(self): + """Mirror of the descending case under the default ascending order: + a null ``sequence.field`` value sorts first (loses) to a non-null + row written earlier. 
+ """ + table = self._create_pk_table( + 'seq_asc_null', extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'real'}]) + self._write(table, [{'id': 1, 'ts': None, 'ts2': 0, 'val': 'null'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'real'}], + ) + + # -- projection drops the sequence field ----------------------------- + + def test_projection_dropping_sequence_field(self): + """Projecting columns that exclude the sequence field must still + return the sequence-field-correct row, and the output schema must + contain exactly the requested columns (no leaked ``ts``). + """ + table = self._create_pk_table( + 'seq_proj', extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + + rows = self._read(table, projection=['id', 'val']) + self.assertEqual(rows, [{'id': 1, 'val': 'high'}]) + # No injected sequence column leaks into the output. + self.assertEqual(set(rows[0].keys()), {'id', 'val'}) + + def test_projection_dropping_sequence_field_with_predicate(self): + """Projection drops the seq field AND a predicate filters on a + kept column -- predicate coordinates must stay correct against the + widened (seq-injected) read type. 
+ """ + table = self._create_pk_table( + 'seq_proj_pred', extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'keep'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'drop'}]) + self._write(table, [{'id': 2, 'ts': 5, 'ts2': 0, 'val': 'other'}]) + + rb = table.new_read_builder().with_projection(['id', 'val']) + pb = rb.new_predicate_builder() + rows = self._read(table, projection=['id', 'val'], + predicate=pb.equal('val', 'keep')) + self.assertEqual(rows, [{'id': 1, 'val': 'keep'}]) + + # -- per merge engine ------------------------------------------------ + + def test_partial_update_respects_sequence_field(self): + """partial-update folds non-null fields in sequence-field order, so + a later-written but lower-ts row must not overwrite a field set by + the higher-ts row. + """ + table = self._create_pk_table( + 'seq_pu', merge_engine='partial-update', + extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}], + ) + + def test_first_row_with_sequence_field_rejected(self): + """sequence.field on the first-row merge engine is an invalid + configuration that Java rejects at schema validation + (SchemaValidation.validateSequenceField). pypaimon has no + schema-creation validation, so the read-builder guard must reject + it rather than silently apply a sequence ordering first-row never + honors on write. 
+ """ + table = self._create_pk_table( + 'seq_fr', merge_engine='first-row', + extra_options={'sequence.field': 'ts'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + with self.assertRaises(ValueError) as ctx: + table.new_read_builder().new_read() + self.assertIn('FIRST_ROW', str(ctx.exception)) + + def test_aggregation_last_value_respects_sequence_field(self): + """``last_value`` must pick the value from the highest-sequence-field + row, even when that row was written first. + """ + table = self._create_pk_table( + 'seq_agg', merge_engine='aggregation', + extra_options={ + 'sequence.field': 'ts', + 'fields.val.aggregate-function': 'last_value', + 'fields.ts2.aggregate-function': 'last_value', + }) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}], + ) + + # -- unsupported sub-features still rejected ------------------------- + + def test_missing_sequence_field_rejected(self): + """A sequence.field naming a column absent from the schema is + invalid (Java SchemaValidation). The guard must reject it with a + clear message before any read execution. + """ + table = self._create_pk_table( + 'seq_missing', extra_options={'sequence.field': 'nope'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + with self.assertRaises(ValueError) as ctx: + table.new_read_builder().new_read() + self.assertIn('nope', str(ctx.exception)) + + def test_duplicate_sequence_field_rejected(self): + """A sequence.field listing the same column twice is invalid + (Java SchemaValidation rejects repeated sequence fields). 
+ """ + table = self._create_pk_table( + 'seq_dup', extra_options={'sequence.field': 'ts,ts'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + with self.assertRaises(ValueError) as ctx: + table.new_read_builder().new_read() + self.assertIn('ts', str(ctx.exception)) + + def test_empty_segment_sequence_field_rejected(self): + """A malformed ``sequence.field`` with an empty segment (e.g. + ``'ts,,ts2'``) leaves an empty field name after trimming -- matching + Java ``CoreOptions.sequenceField()``, which trims but does not drop + empty segments -- and must be rejected by validation rather than + silently accepted as ``['ts', 'ts2']``. + """ + table = self._create_pk_table( + 'seq_empty_seg', extra_options={'sequence.field': 'ts,,ts2'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + with self.assertRaises(ValueError) as ctx: + table.new_read_builder().new_read() + # The empty field name is the one that can't be found in the schema. + self.assertIn('can not be found', str(ctx.exception)) + + def test_cross_partition_update_with_sequence_field_rejected(self): + """sequence.field is invalid under cross-partition update (the PK + does not include all partition fields), matching Java + SchemaValidation. + """ + table = self._create_pk_table( + 'seq_xpart', extra_options={'sequence.field': 'ts'}, + partition_keys=['ts2']) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + with self.assertRaises(ValueError) as ctx: + table.new_read_builder().new_read() + self.assertIn('cross partition', str(ctx.exception).lower()) + + def test_aggregate_function_on_sequence_field_rejected(self): + """Defining an aggregator on the sequence column is invalid: Java + rejects fields..aggregate-function outright in + SchemaValidation.validateSequenceField. The read-builder guard + must reject it rather than silently override the user's + aggregator with last_value. 
+ """ + table = self._create_pk_table( + 'seq_agg_on_seq', merge_engine='aggregation', + extra_options={'sequence.field': 'ts', + 'fields.ts.aggregate-function': 'sum'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + with self.assertRaises(ValueError) as ctx: + table.new_read_builder().new_read() + self.assertIn('fields.ts.aggregate-function', str(ctx.exception)) + + def test_sequence_group_still_rejected(self): + """Top-level sequence.field is supported, but per-field + sequence-group is not -- it must still be rejected. The shared + merge-engine dispatch now rejects this combination fail-fast on + the write path, so the write (not the read) is what raises. + """ + table = self._create_pk_table( + 'seq_group', merge_engine='partial-update', + extra_options={'sequence.field': 'ts', + 'fields.ts2.sequence-group': 'val'}) + with self.assertRaises(NotImplementedError): + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + + def test_nested_sequence_field_rejected(self): + """nested-sequence-field is unimplemented and must be rejected + rather than silently ignored by the top-level comparator. + """ + table = self._create_pk_table( + 'seq_nested', merge_engine='deduplicate', + extra_options={'sequence.field': 'ts', + 'fields.val.nested-sequence-field': 'ts2'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'x'}]) + with self.assertRaises(NotImplementedError): + self._read(table) + + def test_trailing_comma_sequence_field_tolerated(self): + """A trailing comma (``'ts,'``) must be tolerated, matching Java + ``String.split(',')`` which drops trailing empty segments. It + behaves exactly like ``'ts'`` -- not rejected as an empty field. 
+ """ + table = self._create_pk_table( + 'seq_trailing', extra_options={'sequence.field': 'ts,'}) + self._write(table, [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}]) + self._write(table, [{'id': 1, 'ts': 50, 'ts2': 0, 'val': 'low'}]) + + self.assertEqual( + self._read(table), + [{'id': 1, 'ts': 100, 'ts2': 0, 'val': 'high'}], + ) + + def test_complex_type_sequence_field_rejected(self): + """A complex (non-atomic) sequence field is valid in Java (handled + via RecordComparator) but unimplemented in pypaimon's atomic-only + comparator. It must be rejected with a clear NotImplementedError + rather than failing later with an obscure attribute error. + """ + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('seq', pa.list_(pa.int64())), + ('val', pa.string()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], + options={'bucket': '1', 'merge-engine': 'deduplicate', + 'sequence.field': 'seq'}) + self.catalog.create_table('default.seq_complex', schema, False) + table = self.catalog.get_table('default.seq_complex') + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist( + [{'id': 1, 'seq': [1, 2], 'val': 'x'}], schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + with self.assertRaises(NotImplementedError) as ctx: + table.new_read_builder().new_read() + self.assertIn('seq', str(ctx.exception)) + + +class SequenceFieldComparabilityUnitTest(unittest.TestCase): + """Unit-level coverage of ``is_comparable_seq_field`` -- the predicate + behind the read-builder guard. VARIANT in particular is an + ``AtomicType`` but has no ordering, so it must be rejected like the + complex types rather than slipping through an ``isinstance(AtomicType)`` + check. 
+ """ + + def test_variant_sequence_field_not_comparable(self): + from pypaimon.read.reader.sort_merge_reader import ( + is_comparable_seq_field) + from pypaimon.schema.data_types import AtomicType, DataField + + variant = DataField(0, 'seq', AtomicType('VARIANT')) + self.assertFalse(is_comparable_seq_field(variant)) + + def test_atomic_types_are_comparable(self): + from pypaimon.read.reader.sort_merge_reader import ( + is_comparable_seq_field) + from pypaimon.schema.data_types import AtomicType, DataField + + for type_str in ('BIGINT', 'INT', 'TIMESTAMP(6)', 'DECIMAL(10, 2)', + 'STRING', 'BIGINT NOT NULL'): + field = DataField(0, 'seq', AtomicType(type_str)) + self.assertTrue(is_comparable_seq_field(field), + '{} should be comparable'.format(type_str)) + + def test_complex_types_not_comparable(self): + from pypaimon.read.reader.sort_merge_reader import ( + is_comparable_seq_field) + from pypaimon.schema.data_types import ( + ArrayType, AtomicType, DataField) + + array = DataField(0, 'seq', ArrayType(True, AtomicType('INT'))) + self.assertFalse(is_comparable_seq_field(array)) + + +class SequenceFieldParameterizedTypeTest(unittest.TestCase): + """The comparability check must accept parameterized atomic types + (TIMESTAMP(p), DECIMAL(p, s), TIME(p)) as sequence fields -- their + type string carries ``(...)`` which must not be mistaken for a + non-comparable type. 
+ """ + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.catalog = CatalogFactory.create( + {'warehouse': os.path.join(cls.tempdir, 'warehouse')}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _run(self, table_name, pa_schema, rows_first, rows_second, expected): + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], + options={'bucket': '1', 'merge-engine': 'deduplicate', + 'sequence.field': 'seq'}) + full = 'default.{}'.format(table_name) + self.catalog.create_table(full, schema, False) + table = self.catalog.get_table(full) + for batch in (rows_first, rows_second): + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(batch, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + rows = rb.new_read().to_arrow(splits).to_pylist() + self.assertEqual(rows, expected) + + def test_timestamp_sequence_field(self): + import datetime + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('seq', pa.timestamp('us')), + ('val', pa.string()), + ]) + hi = datetime.datetime(2020, 1, 2) + lo = datetime.datetime(2020, 1, 1) + # Later write has the lower timestamp -> earlier (higher-ts) wins. 
+ self._run('seq_ts', + pa_schema, + [{'id': 1, 'seq': hi, 'val': 'high'}], + [{'id': 1, 'seq': lo, 'val': 'low'}], + [{'id': 1, 'seq': hi, 'val': 'high'}]) + + def test_decimal_sequence_field(self): + from decimal import Decimal + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('seq', pa.decimal128(10, 2)), + ('val', pa.string()), + ]) + self._run('seq_dec', + pa_schema, + [{'id': 1, 'seq': Decimal('100.50'), 'val': 'high'}], + [{'id': 1, 'seq': Decimal('50.25'), 'val': 'low'}], + [{'id': 1, 'seq': Decimal('100.50'), 'val': 'high'}]) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/test_write_merge_buffer.py b/paimon-python/pypaimon/tests/test_write_merge_buffer.py new file mode 100644 index 000000000000..0b86b59138f1 --- /dev/null +++ b/paimon-python/pypaimon/tests/test_write_merge_buffer.py @@ -0,0 +1,366 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Unit tests for ``KeyValueDataWriter`` buffer behaviour. 
+ +Covers the fold algorithm (`_merge_pending_by_pk`), the flush lifecycle +(`_flush_all` empties the buffer + clears pending_data), and the +roll-write helper (`_roll_write` splits oversized buffers across +multiple files). Drives a thin harness that bypasses +``DataWriter.__init__`` so tests can exercise these paths without +spinning up the real catalog/write stack. +""" + +import unittest +from unittest.mock import Mock + +import pyarrow as pa + +from pypaimon.read.reader.deduplicate_merge_function import \ + DeduplicateMergeFunction +from pypaimon.read.reader.partial_update_merge_function import \ + PartialUpdateMergeFunction +from pypaimon.write.writer.key_value_data_writer import KeyValueDataWriter + + +# Layout matches what ``KeyValueDataWriter._add_system_fields`` emits: +# ``[_KEY_id, _SEQUENCE_NUMBER, _VALUE_KIND, id, a, b]``. ``id`` is +# duplicated on the value side because the value layout in Paimon's +# row tuple includes every original column. +_SCHEMA = pa.schema([ + pa.field('_KEY_id', pa.int64(), nullable=False), + pa.field('_SEQUENCE_NUMBER', pa.int64(), nullable=False), + pa.field('_VALUE_KIND', pa.int8(), nullable=False), + pa.field('id', pa.int64(), nullable=False), + pa.field('a', pa.string()), + pa.field('b', pa.string()), +]) + + +def _row(pk, seq, a, b): + return { + '_KEY_id': pk, + '_SEQUENCE_NUMBER': seq, + '_VALUE_KIND': 0, + 'id': pk, + 'a': a, + 'b': b, + } + + +class _Harness(KeyValueDataWriter): + """Bypass ``DataWriter.__init__`` to keep tests focused. + + Provides just the attributes ``_merge_pending_by_pk`` / ``_flush_all`` + / ``_roll_write`` read, plus a recording stub for + ``_write_data_to_file`` so the roll-write path can be exercised + without touching the filesystem. 
+ """ + + def __init__(self, merge_function, target_file_size: int = 10 ** 12): + self.trimmed_primary_keys = ['id'] + self._merge_function = merge_function + # Large enough that ``_check_and_roll_if_needed`` does not + # trigger on its own in tests that don't care about rolling. + self.target_file_size = target_file_size + self.pending_data = None + self.committed_files = [] + self.written_chunks = [] + + def _write_data_to_file(self, data): + # Record each chunk instead of writing to disk; mirrors the + # base writer's contract of appending to ``committed_files``. + self.written_chunks.append(data) + + +class WriteMergeBufferTest(unittest.TestCase): + + # -- deduplicate ------------------------------------------------------ + + def test_dedupe_collapses_same_pk_run_to_latest(self): + writer = _Harness(DeduplicateMergeFunction()) + data = pa.Table.from_pylist( + [_row(1, 1, 'old', None), _row(1, 2, 'new', None)], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.num_rows, 1) + self.assertEqual( + out.to_pylist(), + [_row(1, 2, 'new', None)], + ) + + def test_dedupe_keeps_disjoint_keys(self): + writer = _Harness(DeduplicateMergeFunction()) + data = pa.Table.from_pylist( + [_row(1, 1, 'A', None), + _row(2, 2, 'B', None), + _row(3, 3, 'C', None)], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.num_rows, 3) + self.assertEqual( + sorted(out.to_pylist(), key=lambda r: r['id']), + [_row(1, 1, 'A', None), + _row(2, 2, 'B', None), + _row(3, 3, 'C', None)], + ) + + # -- partial-update --------------------------------------------------- + + def _partial_update(self): + # Value-side carries 3 columns (id, a, b). The PK column ``id`` + # is duplicated into the value side so partial-update can apply + # last-non-null semantics uniformly across every original + # user column. 
+ return PartialUpdateMergeFunction( + key_arity=1, value_arity=3, nullables=[True, True, True]) + + def test_partial_update_merges_non_null_per_field(self): + writer = _Harness(self._partial_update()) + data = pa.Table.from_pylist( + [_row(1, 1, 'A', None), _row(1, 2, None, 'B')], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.num_rows, 1) + self.assertEqual(out.to_pylist(), [_row(1, 2, 'A', 'B')]) + + def test_partial_update_three_writes_compose(self): + writer = _Harness(self._partial_update()) + data = pa.Table.from_pylist( + [_row(1, 1, 'A', None), + _row(1, 2, None, 'B'), + _row(1, 3, None, None)], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.to_pylist(), [_row(1, 3, 'A', 'B')]) + + def test_partial_update_later_null_does_not_clobber_earlier_value(self): + writer = _Harness(self._partial_update()) + data = pa.Table.from_pylist( + [_row(1, 1, 'KEEP', 'B'), _row(1, 2, None, None)], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.to_pylist(), [_row(1, 2, 'KEEP', 'B')]) + + # -- edge cases ------------------------------------------------------- + + def test_empty_buffer_returns_empty(self): + writer = _Harness(DeduplicateMergeFunction()) + empty = pa.Table.from_pylist([], schema=_SCHEMA) + out = writer._merge_pending_by_pk(empty) + self.assertEqual(out.num_rows, 0) + + def test_single_row_buffer_skips_merge(self): + # Mock to confirm the merge function isn't invoked: a single + # row cannot have duplicates, so we sidestep the to_pylist + # round-trip. 
+ mock_mf = Mock() + writer = _Harness(mock_mf) + data = pa.Table.from_pylist( + [_row(1, 1, 'X', None)], schema=_SCHEMA) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.num_rows, 1) + mock_mf.reset.assert_not_called() + mock_mf.add.assert_not_called() + mock_mf.get_result.assert_not_called() + + def test_get_result_none_drops_pk_run(self): + # Future-proof: contract says ``get_result`` returning ``None`` + # means the entire PK group should be dropped. + class DropAll: + def reset(self): + pass + + def add(self, _): + pass + + def get_result(self): + return None + + writer = _Harness(DropAll()) + data = pa.Table.from_pylist( + [_row(1, 1, 'A', None), _row(1, 2, 'B', None)], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual(out.num_rows, 0) + + # -- KeyValue pooling ------------------------------------------------- + + def test_keyvalue_pool_does_not_alias_results_across_runs(self): + # Pooling reuses one KeyValue across the whole fold. If the + # PartialUpdateMergeFunction's get_result snapshotting were + # broken, run 1's result would mutate when run 2's data is + # written into the pooled instance. This test would catch that + # regression: build a buffer with two distinct PK runs and + # verify both results stand on their own. + writer = _Harness(self._partial_update()) + data = pa.Table.from_pylist( + [_row(1, 1, 'A1', None), + _row(1, 2, None, 'B1'), + _row(2, 3, 'A2', None), + _row(2, 4, None, 'B2')], + schema=_SCHEMA, + ) + out = writer._merge_pending_by_pk(data) + self.assertEqual( + sorted(out.to_pylist(), key=lambda r: r['id']), + [_row(1, 2, 'A1', 'B1'), _row(2, 4, 'A2', 'B2')], + ) + + # -- _process_data (no longer sorts) ---------------------------------- + + def test_process_data_adds_system_fields_without_sorting(self): + # Deferred-sort design: ``_process_data`` must not pre-sort the + # incoming batch. The global sort happens once inside + # ``_flush_all`` over the concatenated buffer. 
+ writer = _Harness(DeduplicateMergeFunction()) + writer.sequence_generator = _StubSeqGen() + # Intentionally pass rows in descending PK order; if _process_data + # were still sorting, the output would come back ascending. + batch = pa.RecordBatch.from_pylist( + [{'id': 3, 'a': 'C', 'b': None}, + {'id': 1, 'a': 'A', 'b': None}, + {'id': 2, 'a': 'B', 'b': None}], + schema=pa.schema([ + pa.field('id', pa.int64(), nullable=False), + pa.field('a', pa.string()), + pa.field('b', pa.string()), + ]), + ) + out = writer._process_data(batch) + self.assertEqual( + [r['id'] for r in out.to_pylist()], + [3, 1, 2], + ) + + # -- _flush_all ------------------------------------------------------- + + def test_flush_all_sorts_folds_and_writes_one_file(self): + # Buffer with duplicate PKs in arbitrary order. _flush_all is + # responsible for sorting before folding, so unsorted input is + # the right stress case. + writer = _Harness(DeduplicateMergeFunction()) + writer.pending_data = pa.Table.from_pylist( + [_row(2, 5, 'B2-new', None), + _row(1, 2, 'A1-mid', None), + _row(1, 1, 'A1-old', None), + _row(2, 4, 'B2-old', None), + _row(1, 3, 'A1-new', None)], + schema=_SCHEMA, + ) + writer._flush_all() + + # Buffer cleared. + self.assertIsNone(writer.pending_data) + # Exactly one file written (size well under target). + self.assertEqual(len(writer.written_chunks), 1) + flushed = writer.written_chunks[0] + result = sorted(flushed.to_pylist(), key=lambda r: r['id']) + # Dedup -> 1 row per PK, with the highest seq value retained. 
+ self.assertEqual(result, [ + _row(1, 3, 'A1-new', None), + _row(2, 5, 'B2-new', None), + ]) + + def test_flush_all_on_empty_buffer_is_noop(self): + writer = _Harness(DeduplicateMergeFunction()) + writer.pending_data = None + writer._flush_all() + self.assertIsNone(writer.pending_data) + self.assertEqual(writer.written_chunks, []) + + def test_flush_all_clears_buffer_even_when_fold_drops_everything(self): + # MergeFunction that returns None for every group; verifies + # ``_flush_all`` still resets ``pending_data`` so a subsequent + # write starts from a clean slate. + class DropAll: + def reset(self): + pass + + def add(self, _): + pass + + def get_result(self): + return None + + writer = _Harness(DropAll()) + writer.pending_data = pa.Table.from_pylist( + [_row(1, 1, 'A', None), _row(1, 2, 'B', None)], + schema=_SCHEMA, + ) + writer._flush_all() + self.assertIsNone(writer.pending_data) + self.assertEqual(writer.written_chunks, []) + + # -- _roll_write ------------------------------------------------------ + + def test_roll_write_single_chunk_when_under_target(self): + writer = _Harness(DeduplicateMergeFunction(), + target_file_size=10 ** 9) + data = pa.Table.from_pylist( + [_row(1, 1, 'A', None), _row(2, 2, 'B', None)], + schema=_SCHEMA, + ) + writer._roll_write(data) + self.assertEqual(len(writer.written_chunks), 1) + self.assertEqual(writer.written_chunks[0].num_rows, 2) + + def test_roll_write_splits_oversized_buffer_into_multiple_files(self): + # Build a buffer whose nbytes comfortably exceeds the chosen + # target. With a small target_file_size the writer should hand + # back at least two files. Use long strings so nbytes scales + # predictably with row count. + rows = [ + _row(i, i, 'x' * 64, 'y' * 64) for i in range(1, 401) + ] + data = pa.Table.from_pylist(rows, schema=_SCHEMA) + # Target small enough that 400 rows will not fit in one file. 
+ target = data.nbytes // 4 + writer = _Harness(DeduplicateMergeFunction(), + target_file_size=target) + writer._roll_write(data) + + self.assertGreaterEqual(len(writer.written_chunks), 2) + total_rows = sum(c.num_rows for c in writer.written_chunks) + self.assertEqual(total_rows, data.num_rows) + # Each chunk except possibly the last should respect the target. + for chunk in writer.written_chunks[:-1]: + self.assertLessEqual(chunk.nbytes, target) + + +class _StubSeqGen: + """Stand-in for ``SequenceGenerator`` so the harness can call + ``_process_data`` without going through the real ``DataWriter.__init__``. + """ + + def __init__(self): + self._n = 0 + + def next(self) -> int: + self._n += 1 + return self._n + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/vector_search_filter_test.py b/paimon-python/pypaimon/tests/vector_search_filter_test.py index 3cb072bfa5f9..2932ea8dac0c 100644 --- a/paimon-python/pypaimon/tests/vector_search_filter_test.py +++ b/paimon-python/pypaimon/tests/vector_search_filter_test.py @@ -21,13 +21,20 @@ redundancy. 
""" +import io +import json +import struct +import sys +import types import unittest from typing import List from unittest import mock from pypaimon.common.predicate import Predicate from pypaimon.common.predicate_builder import PredicateBuilder +from pypaimon.globalindex.btree.btree_index_meta import BTreeIndexMeta from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta, GlobalIndexMeta +from pypaimon.globalindex.global_index_reader import _completed_future from pypaimon.globalindex.global_index_result import GlobalIndexResult from pypaimon.globalindex.vector_search_result import ScoredGlobalIndexResult from pypaimon.index.index_file_meta import IndexFileMeta @@ -117,6 +124,214 @@ def _scan(snapshot, entry_filter=None): testcase._travel_patch.start() +def _java_tantivy_meta(tokenizer="ngram", min_gram=2, max_gram=2, + prefix_only=False, lower_case=True, + max_token_length=40, ascii_folding=False, + stem=False, language="english", + remove_stop_words=False, stop_words="", + with_position=True): + config = {} + if tokenizer != "default": + config["tokenizer"] = tokenizer + if min_gram != 2: + config["ngram.min-gram"] = min_gram + if max_gram != 2: + config["ngram.max-gram"] = max_gram + if prefix_only: + config["ngram.prefix-only"] = prefix_only + if not lower_case: + config["lower-case"] = lower_case + if max_token_length != 40: + config["max-token-length"] = max_token_length + if ascii_folding: + config["ascii-folding"] = ascii_folding + if stem: + config["stem"] = stem + if language != "english": + config["language"] = language + if remove_stop_words: + config["remove-stop-words"] = remove_stop_words + if stop_words: + config["stop-words"] = stop_words + if not with_position: + config["with-position"] = with_position + return json.dumps(config, separators=(",", ":")).encode("utf-8") + + +class _FakeFileIO: + def new_input_stream(self, path): + buf = io.BytesIO() + buf.write(struct.pack(">i", 1)) + name = b"meta.json" + buf.write(struct.pack(">i", 
len(name))) + buf.write(name) + data = b"{}" + buf.write(struct.pack(">q", len(data))) + buf.write(data) + buf.seek(0) + return buf + + +class _FakeSchemaBuilder: + def __init__(self): + self.fields = {} + + def add_unsigned_field(self, name, stored=False, indexed=True, fast=False): + self.fields[name] = {"fast": fast, "stored": stored} + + def add_text_field(self, name, stored=False, tokenizer_name=None, **kwargs): + if "index_option" in kwargs and kwargs["index_option"] is None: + raise TypeError("index_option must not be None") + self.fields[name] = { + "stored": stored, + "tokenizer_name": tokenizer_name or "default", + } + if "index_option" in kwargs: + self.fields[name]["index_option"] = kwargs["index_option"] + + def build(self): + return types.SimpleNamespace(fields=self.fields) + + +class _FakeTokenizer: + @staticmethod + def ngram(min_gram=2, max_gram=3, prefix_only=False): + return ("ngram", min_gram, max_gram, prefix_only) + + @staticmethod + def simple(): + return ("simple",) + + @staticmethod + def whitespace(): + return ("whitespace",) + + @staticmethod + def raw(): + return ("raw",) + + +class _FakeFilter: + @staticmethod + def lowercase(): + return "lowercase" + + @staticmethod + def remove_long(length_limit): + return ("remove_long", length_limit) + + @staticmethod + def ascii_fold(): + return "ascii_fold" + + @staticmethod + def stemmer(language): + return ("stemmer", language) + + @staticmethod + def stopword(language): + return ("stopword", language) + + @staticmethod + def custom_stopword(stopwords): + return ("custom_stopword", tuple(stopwords)) + + +class _FakeTextAnalyzerBuilder: + def __init__(self, tokenizer): + self._tokenizer = tokenizer + self._filters = [] + + def filter(self, filter_): + result = _FakeTextAnalyzerBuilder(self._tokenizer) + result._filters = self._filters + [filter_] + return result + + def build(self): + return self._tokenizer + (tuple(self._filters),) + + +class _FakeQuery: + @staticmethod + def empty_query(): + 
return ("empty",) + + @staticmethod + def term_query(schema, field_name, field_value, index_option="position"): + return ("term", schema, field_name, field_value, index_option) + + @staticmethod + def boolean_query(subqueries, minimum_number_should_match=None): + return ("boolean", tuple(subqueries), minimum_number_should_match) + + +class _FakeOccur: + Should = "should" + Must = "must" + + +class _FakeSearchResults: + hits = [(2.0, "addr")] + + +class _FakeSearcher: + def __init__(self): + self.query = None + + def search(self, query, limit): + self.query = query + return _FakeSearchResults() + + def fast_field_values(self, name, addresses): + return [7] + + +class _FakeIndex: + def __init__(self, schema, directory=None): + self.schema = schema + self.directory = directory + self.registered_tokenizer = None + + def register_tokenizer(self, name, analyzer): + self.registered_tokenizer = (name, analyzer) + + def reload(self): + pass + + def searcher(self): + self.searcher_instance = _FakeSearcher() + return self.searcher_instance + + def parse_query(self, query_text, fields, **kwargs): + return (query_text, tuple(fields), kwargs) + + +class _FakeTantivy(types.SimpleNamespace): + def __init__(self): + super().__init__() + self.Tokenizer = _FakeTokenizer + self.Filter = _FakeFilter + self.TextAnalyzerBuilder = _FakeTextAnalyzerBuilder + self.Query = _FakeQuery + self.Occur = _FakeOccur + self.last_schema = None + self.last_index = None + parent = self + + class SchemaBuilder(_FakeSchemaBuilder): + def build(self_inner): + parent.last_schema = super().build() + return parent.last_schema + + class Index(_FakeIndex): + def __init__(self_inner, schema, directory=None): + super().__init__(schema, directory=directory) + parent.last_index = self_inner + + self.SchemaBuilder = SchemaBuilder + self.Index = Index + + # ----------------------------- tests --------------------------------------- @@ -140,6 +355,414 @@ def 
test_lumina_reader_accepts_new_and_legacy_identifiers(self): reader.close() +class TantivyFullTextIndexOptionsTest(unittest.TestCase): + """Tantivy full-text tokenizer metadata compatibility.""" + + def test_empty_metadata_uses_default_tokenizer(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize(b"") + + self.assertEqual("default", options.tokenizer) + self.assertEqual(2, options.ngram_min_gram) + self.assertEqual(2, options.ngram_max_gram) + self.assertFalse(options.ngram_prefix_only) + self.assertTrue(options.lower_case) + self.assertEqual(40, options.max_token_length) + self.assertFalse(options.ascii_folding) + self.assertFalse(options.stem) + self.assertEqual("english", options.language) + self.assertFalse(options.remove_stop_words) + self.assertEqual("", options.stop_words) + self.assertTrue(options.with_position) + self.assertEqual("default", options.tokenizer_name()) + + def test_empty_json_metadata_uses_default_tokenizer(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize(b"{}") + + self.assertEqual("default", options.tokenizer) + self.assertEqual(2, options.ngram_min_gram) + self.assertEqual(2, options.ngram_max_gram) + self.assertFalse(options.ngram_prefix_only) + self.assertTrue(options.lower_case) + self.assertEqual(40, options.max_token_length) + self.assertFalse(options.ascii_folding) + self.assertFalse(options.stem) + self.assertEqual("english", options.language) + self.assertFalse(options.remove_stop_words) + self.assertEqual("", options.stop_words) + self.assertTrue(options.with_position) + self.assertEqual("default", options.tokenizer_name()) + + def test_deserializes_java_ngram_metadata(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + 
TANTIVY_NGRAM_TOKENIZER, + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize( + _java_tantivy_meta( + tokenizer=" NGRAM ", min_gram=2, max_gram=3, + prefix_only=True, lower_case=False)) + + self.assertEqual("ngram", options.tokenizer) + self.assertEqual(2, options.ngram_min_gram) + self.assertEqual(3, options.ngram_max_gram) + self.assertTrue(options.ngram_prefix_only) + self.assertFalse(options.lower_case) + self.assertEqual(TANTIVY_NGRAM_TOKENIZER, options.tokenizer_name()) + + def test_deserializes_java_json_ngram_metadata(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TANTIVY_NGRAM_TOKENIZER, + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize( + _java_tantivy_meta( + tokenizer=" NGRAM ", min_gram=2, max_gram=3, + prefix_only=True, lower_case=False)) + + self.assertEqual("ngram", options.tokenizer) + self.assertEqual(2, options.ngram_min_gram) + self.assertEqual(3, options.ngram_max_gram) + self.assertTrue(options.ngram_prefix_only) + self.assertFalse(options.lower_case) + self.assertEqual(TANTIVY_NGRAM_TOKENIZER, options.tokenizer_name()) + + def test_deserializes_extended_analyzer_metadata(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize( + _java_tantivy_meta( + tokenizer=" WHITESPACE ", lower_case=False, max_token_length=12, + ascii_folding=True, stem=True, language="English", + remove_stop_words=True, stop_words="paimon;lake", + with_position=False)) + + self.assertEqual("whitespace", options.tokenizer) + self.assertFalse(options.lower_case) + self.assertEqual(12, options.max_token_length) + self.assertTrue(options.ascii_folding) + self.assertTrue(options.stem) + self.assertEqual("english", options.language) + self.assertTrue(options.remove_stop_words) + self.assertEqual("paimon;lake", options.stop_words) + 
self.assertEqual(["paimon", "lake"], options.stop_word_list()) + self.assertFalse(options.with_position) + self.assertEqual("paimon_custom", options.tokenizer_name()) + + def test_deserializes_java_json_analyzer_metadata(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize( + _java_tantivy_meta( + tokenizer=" WHITESPACE ", lower_case=False, max_token_length=12, + ascii_folding=True, stem=True, language="English", + remove_stop_words=True, stop_words=["paimon", "lake"], + with_position=False)) + + self.assertEqual("whitespace", options.tokenizer) + self.assertFalse(options.lower_case) + self.assertEqual(12, options.max_token_length) + self.assertTrue(options.ascii_folding) + self.assertTrue(options.stem) + self.assertEqual("english", options.language) + self.assertTrue(options.remove_stop_words) + self.assertEqual("paimon;lake", options.stop_words) + self.assertEqual(["paimon", "lake"], options.stop_word_list()) + self.assertFalse(options.with_position) + self.assertEqual("paimon_custom", options.tokenizer_name()) + + def test_deserializes_java_jieba_metadata(self): + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TANTIVY_JIEBA_TOKENIZER, + TantivyFullTextIndexOptions, + ) + + options = TantivyFullTextIndexOptions.deserialize( + _java_tantivy_meta(tokenizer=" JIEBA ")) + + self.assertEqual("jieba", options.tokenizer) + self.assertEqual(2, options.ngram_min_gram) + self.assertEqual(2, options.ngram_max_gram) + self.assertFalse(options.ngram_prefix_only) + self.assertTrue(options.lower_case) + self.assertEqual(TANTIVY_JIEBA_TOKENIZER, options.tokenizer_name()) + + def test_ngram_reader_registers_matching_tantivy_analyzer(self): + from pypaimon.globalindex.full_text_search import FullTextSearch + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TANTIVY_NGRAM_TOKENIZER, + 
TantivyFullTextGlobalIndexReader, + ) + + tantivy = _FakeTantivy() + old_tantivy = sys.modules.get("tantivy") + sys.modules["tantivy"] = tantivy + try: + reader = TantivyFullTextGlobalIndexReader( + _FakeFileIO(), + "/unused", + [GlobalIndexIOMeta( + file_name="ft.index", + file_size=1, + metadata=_java_tantivy_meta( + min_gram=2, max_gram=3, + prefix_only=True, lower_case=True))]) + try: + result = reader.visit_full_text_search( + FullTextSearch("中文", 10, "content")).result() + finally: + reader.close() + finally: + if old_tantivy is None: + sys.modules.pop("tantivy", None) + else: + sys.modules["tantivy"] = old_tantivy + + self.assertEqual({"row_id": {"fast": True, "stored": False}, + "text": {"stored": False, + "tokenizer_name": TANTIVY_NGRAM_TOKENIZER}}, + tantivy.last_schema.fields) + self.assertEqual( + (TANTIVY_NGRAM_TOKENIZER, + ("ngram", 2, 3, True, ("lowercase",))), + tantivy.last_index.registered_tokenizer) + self.assertEqual([7], sorted(list(result.results()))) + + query = tantivy.last_index.searcher_instance.query + self.assertEqual(("中文", ("text",), {}), query) + + def test_schema_fallback_for_pre_7670_indexes(self): + from pypaimon.globalindex.full_text_search import FullTextSearch + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextGlobalIndexReader, + ) + + call_count = [0] + + class _FakeTantivyWithSchemaFallback(_FakeTantivy): + def __init__(self_outer): + super().__init__() + parent = self_outer + + class SchemaBuilder(_FakeSchemaBuilder): + def build(self_inner): + parent.last_schema = super().build() + return parent.last_schema + + class Index(_FakeIndex): + def __init__(self_inner, schema, directory=None): + call_count[0] += 1 + row_id_opts = schema.fields.get("row_id", {}) + if not row_id_opts.get("stored", False): + raise ValueError( + "Schema error: 'An index exists but " + "the schema does not match.'" + ) + super().__init__(schema, directory=directory) + parent.last_index = self_inner + + 
self_outer.SchemaBuilder = SchemaBuilder + self_outer.Index = Index + + tantivy = _FakeTantivyWithSchemaFallback() + old_tantivy = sys.modules.get("tantivy") + sys.modules["tantivy"] = tantivy + try: + reader = TantivyFullTextGlobalIndexReader( + _FakeFileIO(), "/unused", + [GlobalIndexIOMeta(file_name="ft.index", file_size=1)]) + try: + reader.visit_full_text_search( + FullTextSearch("hello", 5, "content")).result() + finally: + reader.close() + finally: + if old_tantivy is None: + sys.modules.pop("tantivy", None) + else: + sys.modules["tantivy"] = old_tantivy + + self.assertEqual(2, call_count[0]) + self.assertTrue( + tantivy.last_schema.fields["row_id"].get("stored", False)) + + def test_custom_analyzer_reader_registers_matching_tantivy_analyzer(self): + from pypaimon.globalindex.full_text_search import FullTextSearch + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TANTIVY_CUSTOM_TOKENIZER, + TantivyFullTextGlobalIndexReader, + ) + + tantivy = _FakeTantivy() + old_tantivy = sys.modules.get("tantivy") + sys.modules["tantivy"] = tantivy + try: + reader = TantivyFullTextGlobalIndexReader( + _FakeFileIO(), + "/unused", + [GlobalIndexIOMeta( + file_name="ft.index", + file_size=1, + metadata=_java_tantivy_meta( + tokenizer="simple", max_token_length=16, + ascii_folding=True, stem=True, language="english", + remove_stop_words=True, stop_words="paimon;lake", + with_position=False))]) + try: + reader.visit_full_text_search( + FullTextSearch("running", 10, "content")).result() + finally: + reader.close() + finally: + if old_tantivy is None: + sys.modules.pop("tantivy", None) + else: + sys.modules["tantivy"] = old_tantivy + + self.assertEqual({"row_id": {"fast": True, "stored": False}, + "text": {"stored": False, + "tokenizer_name": TANTIVY_CUSTOM_TOKENIZER, + "index_option": "freq"}}, + tantivy.last_schema.fields) + self.assertEqual( + (TANTIVY_CUSTOM_TOKENIZER, + ("simple", + (("remove_long", 16), "lowercase", "ascii_fold", 
("stemmer", "english"), + ("stopword", "english"), ("custom_stopword", ("paimon", "lake"))))), + tantivy.last_index.registered_tokenizer) + + def test_ngram_reader_requires_custom_tokenizer_api(self): + from pypaimon.globalindex.full_text_search import FullTextSearch + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextGlobalIndexReader, + ) + + old_tantivy = sys.modules.get("tantivy") + sys.modules["tantivy"] = types.SimpleNamespace(SchemaBuilder=object) + try: + reader = TantivyFullTextGlobalIndexReader( + _FakeFileIO(), + "/unused", + [GlobalIndexIOMeta( + file_name="ft.index", + file_size=1, + metadata=_java_tantivy_meta())]) + with self.assertRaisesRegex( + RuntimeError, "ngram tokenizer support"): + reader.visit_full_text_search( + FullTextSearch("中文", 10, "content")).result() + finally: + if old_tantivy is None: + sys.modules.pop("tantivy", None) + else: + sys.modules["tantivy"] = old_tantivy + + def test_jieba_reader_builds_token_query(self): + from pypaimon.globalindex.full_text_search import FullTextSearch + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TANTIVY_JIEBA_TOKENIZER, + TantivyFullTextGlobalIndexReader, + ) + + tantivy = _FakeTantivy() + jieba = types.SimpleNamespace( + tokenize=lambda text, mode, HMM: [ + ("售货", 0, 2), + ("货员", 1, 3), + ("售货员", 0, 3), + ("售货员", 0, 3)]) + old_tantivy = sys.modules.get("tantivy") + old_jieba = sys.modules.get("jieba") + sys.modules["tantivy"] = tantivy + sys.modules["jieba"] = jieba + try: + reader = TantivyFullTextGlobalIndexReader( + _FakeFileIO(), + "/unused", + [GlobalIndexIOMeta( + file_name="ft.index", + file_size=1, + metadata=_java_tantivy_meta(tokenizer="jieba"))]) + try: + result = reader.visit_full_text_search( + FullTextSearch("售货员", 10, "content", "and")).result() + finally: + reader.close() + finally: + if old_tantivy is None: + sys.modules.pop("tantivy", None) + else: + sys.modules["tantivy"] = old_tantivy + if 
old_jieba is None: + sys.modules.pop("jieba", None) + else: + sys.modules["jieba"] = old_jieba + + self.assertEqual({"row_id": {"fast": True, "stored": False}, + "text": {"stored": False, + "tokenizer_name": TANTIVY_JIEBA_TOKENIZER}}, + tantivy.last_schema.fields) + self.assertIsNone(tantivy.last_index.registered_tokenizer) + self.assertEqual([7], sorted(list(result.results()))) + + query = tantivy.last_index.searcher_instance.query + self.assertEqual("boolean", query[0]) + self.assertEqual( + ("售货", "货员", "售货员"), + tuple(sub_query[1][3] for sub_query in query[1])) + self.assertEqual( + ("must", "must", "must"), + tuple(sub_query[0] for sub_query in query[1])) + + def test_jieba_reader_requires_jieba_package(self): + from pypaimon.globalindex.full_text_search import FullTextSearch + from pypaimon.globalindex.tantivy.tantivy_full_text_global_index_reader import ( + TantivyFullTextGlobalIndexReader, + ) + + tantivy = _FakeTantivy() + old_tantivy = sys.modules.get("tantivy") + old_jieba = sys.modules.get("jieba") + sys.modules["tantivy"] = tantivy + sys.modules["jieba"] = None + try: + reader = TantivyFullTextGlobalIndexReader( + _FakeFileIO(), + "/unused", + [GlobalIndexIOMeta( + file_name="ft.index", + file_size=1, + metadata=_java_tantivy_meta(tokenizer="jieba"))]) + try: + with self.assertRaisesRegex(RuntimeError, "pip install jieba"): + reader.visit_full_text_search( + FullTextSearch("售货员", 10, "content")).result() + finally: + reader.close() + finally: + if old_tantivy is None: + sys.modules.pop("tantivy", None) + else: + sys.modules["tantivy"] = old_tantivy + if old_jieba is None: + sys.modules.pop("jieba", None) + else: + sys.modules["jieba"] = old_jieba + + class VectorSearchFilterTest(unittest.TestCase): """Non-partitioned wiring: scan + read + external_path plumbing.""" @@ -210,7 +833,7 @@ def test_read_threads_prefilter_bitmap_as_include_row_ids(self): for rid in range(5, 10): bitmap.add(rid) scanner = mock.MagicMock() - scanner.scan.return_value = 
GlobalIndexResult.create(lambda: bitmap) + scanner.scan.return_value = GlobalIndexResult.create(bitmap) captured_searches = [] captured_io_metas = [] @@ -222,7 +845,7 @@ def _capture_create(index_type, file_io, index_path, class _FakeReader: def visit_vector_search(self_inner, vs): captured_searches.append(vs) - return ScoredGlobalIndexResult.create_empty() + return _completed_future(ScoredGlobalIndexResult.create_empty()) def close(self_inner): pass @@ -266,33 +889,34 @@ def test_scanner_threads_external_path_to_btree_reader(self): from pypaimon.globalindex.global_index_scanner import GlobalIndexScanner scalar_file = self.entries[2].index_file - scanner = GlobalIndexScanner( - fields=self.table.fields, - file_io=self.table.file_io, - index_path="/unused/index-path", - index_files=[scalar_file], - ) - captured = [] - class _FakeBTreeReader: + captured_io_metas = [] + + class _FakeLazyReader: def __init__(self_inner, key_serializer, file_io, index_path, - io_meta): - captured.append(io_meta) + io_metas, executor=None): + captured_io_metas.append(list(io_metas)) def close(self_inner): pass - try: - with mock.patch( - "pypaimon.globalindex.btree.BTreeIndexReader", - _FakeBTreeReader): + with mock.patch( + "pypaimon.globalindex.btree.lazy_filtered_btree_reader.LazyFilteredBTreeReader", + _FakeLazyReader): + scanner = GlobalIndexScanner( + fields=self.table.fields, + file_io=self.table.file_io, + index_path="/unused/index-path", + index_files=[scalar_file], + ) + try: list(scanner._evaluator._readers_function(self.id_field)) - finally: - scanner.close() + finally: + scanner.close() - self.assertEqual(1, len(captured)) + self.assertEqual(1, len(captured_io_metas)) self.assertEqual("oss://bucket/id-btree-0.index", - captured[0].external_path) + captured_io_metas[0][0].external_path) class VectorSearchMultiShardScalarTest(unittest.TestCase): @@ -306,7 +930,6 @@ class VectorSearchMultiShardScalarTest(unittest.TestCase): """ def 
test_hit_only_in_later_shard_returns_global_row_id(self): - from pypaimon.globalindex.global_index_reader import GlobalIndexReader from pypaimon.globalindex.global_index_result import GlobalIndexResult from pypaimon.globalindex.global_index_scanner import ( GlobalIndexScanner, @@ -327,34 +950,44 @@ def test_hit_only_in_later_shard_returns_global_row_id(self): # Stub BTreeIndexReader: shard_a returns empty, shard_b returns {2} # (local row id). After Offset wrapping the scanner should emit {7}. - class _StubBTreeReader(GlobalIndexReader): + class _StubBTreeReader: def __init__(self_inner, key_serializer, file_io, index_path, io_meta): self_inner._file = io_meta.file_name - def visit_equal(self_inner, field_ref, literal): + def visit_equal(self_inner, literal): bm = RoaringBitmap64() if self_inner._file == "id-1.index": bm.add(2) # local offset inside [5,9] - return GlobalIndexResult.create(lambda b=bm: b) + return GlobalIndexResult.create(bm) def close(self_inner): pass - with mock.patch("pypaimon.globalindex.btree.BTreeIndexReader", - _StubBTreeReader): - scanner = GlobalIndexScanner( - fields=table.fields, - file_io=table.file_io, - index_path="/unused", - index_files=[shard_a, shard_b], - ) - try: - result = scanner.scan( - Predicate(method="equal", index=0, field="id", - literals=[7])) - finally: - scanner.close() + import struct + wide_meta = BTreeIndexMeta( + first_key=struct.pack(' None: + self.references_by_field.setdefault(view_struct.field_id, []).append(view_struct) + self.row_ids.append(int(view_struct.row_id)) + + +class TableReadPlan: + """A plan for reading blob descriptors from one upstream table.""" + + def __init__(self, identifier: Identifier, upstream_table, + read_fields: List, row_ranges: List[Range]): + self.identifier: Identifier = identifier + self.upstream_table = upstream_table + self.read_fields: List = read_fields + self.row_ranges: List[Range] = row_ranges + + +class BlobViewLookup: + """Resolve BlobViewStruct references by reading 
upstream blob descriptors.""" + + def __init__(self, table): + self._table = table + self._descriptor_cache: Dict[BlobViewStruct, BlobDescriptor] = {} + self._null_value_cache: Set[BlobViewStruct] = set() + + def preload(self, view_structs: List[BlobViewStruct]): + if not view_structs: + return + + grouped: Dict[str, TableReferences] = self._group_by_table(view_structs) + plans: List[TableReadPlan] = [] + for table_refs in grouped.values(): + plans.append(self._create_table_read_plan(table_refs)) + + target_rows: int = self._target_rows_per_task(plans) + tasks: List[Tuple[TableReadPlan, List[Range]]] = [] + for plan in plans: + for range_chunk in self._split_row_ranges(plan.row_ranges, target_rows): + tasks.append((plan, range_chunk)) + + if len(tasks) <= 1: + for plan, range_chunk in tasks: + descriptors, null_values = self._load_descriptor_chunk(plan, range_chunk) + self._descriptor_cache.update(descriptors) + self._null_value_cache.update(null_values) + return + + with ThreadPoolExecutor(max_workers=min(_PRELOAD_THREAD_NUM, len(tasks))) as executor: + futures = { + executor.submit(self._load_descriptor_chunk, plan, range_chunk): (plan, range_chunk) + for plan, range_chunk in tasks + } + for future in as_completed(futures): + try: + descriptors, null_values = future.result() + self._descriptor_cache.update(descriptors) + self._null_value_cache.update(null_values) + except Exception as exc: + # Cancel remaining futures that have not started yet so a single + # failure can abort the rest of the preload work as early as possible. 
+ for pending_future in futures: + pending_future.cancel() + raise RuntimeError("Failed to preload blob descriptors.") from exc + + def resolve_descriptor(self, view_struct: BlobViewStruct) -> BlobDescriptor: + descriptor: BlobDescriptor = self._descriptor_cache.get(view_struct) + if descriptor is None: + if view_struct in self._null_value_cache: + raise ValueError( + "BlobViewStruct {} resolves to a null blob value.".format(view_struct) + ) + raise ValueError( + "Cannot resolve BlobViewStruct {} because row id {} was not found " + "in upstream table.".format(view_struct, view_struct.row_id) + ) + return descriptor + + def resolve_to_null(self, view_struct: BlobViewStruct) -> bool: + if view_struct in self._null_value_cache: + return True + if view_struct not in self._descriptor_cache: + raise ValueError( + "Cannot resolve BlobViewStruct {} because row id {} was not found " + "in upstream table.".format(view_struct, view_struct.row_id) + ) + return False + + def _group_by_table( + self, view_structs: List[BlobViewStruct] + ) -> Dict[str, TableReferences]: + grouped: Dict[str, TableReferences] = {} + for view_struct in view_structs: + key = view_struct.identifier.get_full_name() + if key not in grouped: + grouped[key] = TableReferences(view_struct.identifier) + grouped[key].add(view_struct) + return grouped + + def _create_table_read_plan(self, table_refs: TableReferences) -> TableReadPlan: + upstream_table = self._load_table(table_refs.identifier) + + fields: List = [] + for field_id in table_refs.references_by_field: + fields.append(self._field_by_id(upstream_table, field_id)) + + read_fields = SpecialFields.row_type_with_row_id(fields) + return TableReadPlan( + table_refs.identifier, upstream_table, read_fields, + Range.to_ranges(table_refs.row_ids)) + + def _load_descriptor_chunk( + self, plan: TableReadPlan, row_ranges: List[Range] + ) -> Tuple[Dict[BlobViewStruct, BlobDescriptor], set]: + identifier: Identifier = plan.identifier + upstream_table = 
plan.upstream_table + read_fields = plan.read_fields + + projection_field_names: List[str] = [f.name for f in read_fields] + + descriptor_table = upstream_table.copy({CoreOptions.BLOB_AS_DESCRIPTOR.key(): "true"}) + read_builder = descriptor_table.new_read_builder().with_projection(projection_field_names) + + if SpecialFields.ROW_ID.name not in [ + data_field.name for data_field in read_builder.read_type() + ]: + raise ValueError( + "Cannot resolve blob view for table {} because row tracking is not readable." + .format(identifier.get_full_name()) + ) + + predicate_builder = read_builder.new_predicate_builder() + range_predicates: List = [] + for r in row_ranges: + if r.from_ == r.to: + range_predicates.append( + predicate_builder.equal(SpecialFields.ROW_ID.name, r.from_)) + else: + range_predicates.append( + predicate_builder.between(SpecialFields.ROW_ID.name, r.from_, r.to)) + if len(range_predicates) == 1: + predicate = range_predicates[0] + else: + predicate = predicate_builder.or_predicates(range_predicates) + read_builder.with_filter(predicate) + result = read_builder.new_read().to_arrow(read_builder.new_scan().plan().splits()) + + if SpecialFields.ROW_ID.name not in result.schema.names: + raise ValueError( + "Cannot resolve blob view for table {} because row tracking is not readable." 
+ .format(identifier.get_full_name()) + ) + + row_id_values: List = result.column(SpecialFields.ROW_ID.name).to_pylist() + resolved: Dict[BlobViewStruct, BlobDescriptor] = {} + null_values: set = set() + for field in read_fields: + if field.name == SpecialFields.ROW_ID.name: + continue + if field.name not in result.schema.names: + continue + values = result.column(field.name).to_pylist() + for row_id, value in zip(row_id_values, values): + view_struct = BlobViewStruct( + identifier.get_full_name(), field.id, int(row_id)) + if value is None: + null_values.add(view_struct) + continue + descriptor = BlobDescriptor.deserialize(value) + resolved[view_struct] = descriptor + return resolved, null_values + + @staticmethod + def _split_row_ranges( + row_ranges: List[Range], target_rows_per_task: int + ) -> List[List[Range]]: + """ + Split row ranges into multiple chunks for parallel task processing. + """ + if not row_ranges: + return [] + + chunks: List[List[Range]] = [] + current_chunk: List[Range] = [] + current_chunk_rows: int = 0 + + for r in row_ranges: + next_from = r.from_ + # Process current range until all rows are allocated + while next_from <= r.to: + # If current chunk is full, save it and start a new one + if current_chunk_rows == target_rows_per_task: + chunks.append(current_chunk) + current_chunk = [] + current_chunk_rows = 0 + + # Calculate remaining capacity in current chunk + remaining = target_rows_per_task - current_chunk_rows + # Determine the end position for this allocation (don't exceed range boundary) + next_to = min(r.to, next_from + remaining - 1) + + # Add the allocated range to current chunk + current_chunk.append(Range(next_from, next_to)) + current_chunk_rows += next_to - next_from + 1 + + # Move to next unallocated position + next_from = next_to + 1 + + # Don't forget the last chunk if it has any ranges + if current_chunk: + chunks.append(current_chunk) + + return chunks + + @staticmethod + def _target_rows_per_task(plans: 
List[TableReadPlan]) -> int: + total_rows: int = 0 + for plan in plans: + for r in plan.row_ranges: + total_rows += r.count() + if total_rows <= 0: + return _MIN_ROWS_PER_TASK + + return max(_MIN_ROWS_PER_TASK, (total_rows + _PRELOAD_THREAD_NUM - 1) // _PRELOAD_THREAD_NUM) + + def _load_table(self, identifier: Identifier): + catalog = self._table.catalog_environment.catalog_loader.load() + return catalog.get_table(identifier) + + @staticmethod + def _field_by_id(table, field_id: int) -> 'DataField': + for field in table.table_schema.fields: + if field.id == field_id: + return field + raise ValueError( + "Cannot find blob fieldId {} in upstream table {}." + .format(field_id, table.identifier.get_full_name()) + ) diff --git a/paimon-python/pypaimon/utils/file_store_path_factory.py b/paimon-python/pypaimon/utils/file_store_path_factory.py index d959ae4ffa55..3816a2ce29a9 100644 --- a/paimon-python/pypaimon/utils/file_store_path_factory.py +++ b/paimon-python/pypaimon/utils/file_store_path_factory.py @@ -55,6 +55,8 @@ def __init__( file_compression: str, data_file_path_directory: Optional[str] = None, external_paths: Optional[List[str]] = None, + external_path_strategy: str = "round-robin", + external_path_weights: Optional[List[int]] = None, index_file_in_data_file_dir: bool = False, ): self._root = root.rstrip('/') @@ -67,6 +69,8 @@ def __init__( self.file_compression = file_compression self.data_file_path_directory = data_file_path_directory self.external_paths = external_paths or [] + self.external_path_strategy = external_path_strategy + self.external_path_weights = external_path_weights self.index_file_in_data_file_dir = index_file_in_data_file_dir self.legacy_partition_name = legacy_partition_name @@ -124,7 +128,12 @@ def create_external_path_provider( return None relative_bucket_path = self.relative_bucket_path(partition, bucket) - return ExternalPathProvider(self.external_paths, relative_bucket_path) + return ExternalPathProvider.create( + 
self.external_path_strategy, + self.external_paths, + relative_bucket_path, + self.external_path_weights, + ) def global_index_path_factory(self) -> 'IndexPathFactory': return IndexPathFactory(self.index_path()) diff --git a/paimon-python/pypaimon/utils/file_type.py b/paimon-python/pypaimon/utils/file_type.py index 384f927da722..866c70c3e9a8 100644 --- a/paimon-python/pypaimon/utils/file_type.py +++ b/paimon-python/pypaimon/utils/file_type.py @@ -32,7 +32,7 @@ class FileType(Enum): hint files, _SUCCESS, consumer, service files - DATA: data files and any unrecognized files (default) - BUCKET_INDEX: bucket level index files (Hash, DV) - - GLOBAL_INDEX: table level global index files (btree, bitmap, lumina, tantivy) + - GLOBAL_INDEX: table level global index files (btree, lumina, tantivy) - FILE_INDEX: data-file index files (bloom filter, bitmap, etc.) """ META = "META" diff --git a/paimon-python/pypaimon/write/blob_format_writer.py b/paimon-python/pypaimon/write/blob_format_writer.py index 92257f4ca97a..6004425b7181 100644 --- a/paimon-python/pypaimon/write/blob_format_writer.py +++ b/paimon-python/pypaimon/write/blob_format_writer.py @@ -17,9 +17,9 @@ import struct import zlib -from typing import BinaryIO, List +from typing import BinaryIO, List, Optional -from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor +from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor, BlobConsumer from pypaimon.common.delta_varint_compressor import DeltaVarintCompressor @@ -31,8 +31,12 @@ class BlobFormatWriter: BUFFER_SIZE = 4096 METADATA_SIZE = 12 # 8-byte length + 4-byte CRC - def __init__(self, output_stream: BinaryIO): + def __init__(self, output_stream: BinaryIO, + blob_consumer: Optional[BlobConsumer] = None, + file_path: Optional[str] = None): self.output_stream = output_stream + self._blob_consumer = blob_consumer + self._file_path = file_path self.lengths: List[int] = [] self.position = 0 @@ -40,9 +44,12 @@ def add_element(self, row) -> None: if not 
hasattr(row, 'values') or len(row.values) != 1: raise ValueError("BlobFormatWriter only supports one field") + blob_field_name = row.fields[0].name blob_value = row.values[0] if blob_value is None: self.lengths.append(self.NULL_LENGTH) + if self._blob_consumer is not None: + self._blob_consumer(blob_field_name, None) return if not isinstance(blob_value, Blob): @@ -59,6 +66,8 @@ def add_element(self, row) -> None: magic_bytes = struct.pack(' None: finally: stream.close() + blob_length = self.position - blob_pos + # Calculate total length including magic + data + metadata (length + CRC) bin_length = self.position - previous_pos + self.METADATA_SIZE self.lengths.append(bin_length) @@ -88,6 +99,12 @@ def add_element(self, row) -> None: self.output_stream.write(crc_bytes) self.position += 4 + if self._blob_consumer is not None: + descriptor = BlobDescriptor(self._file_path, blob_pos, blob_length) + flush = self._blob_consumer(blob_field_name, descriptor) + if flush: + self.output_stream.flush() + def _write_with_crc(self, data: bytes, crc32: int) -> int: crc32 = zlib.crc32(data, crc32) self.output_stream.write(data) diff --git a/paimon-python/pypaimon/write/commit_message.py b/paimon-python/pypaimon/write/commit_message.py index 7bce06d8ab13..552df0000fc1 100644 --- a/paimon-python/pypaimon/write/commit_message.py +++ b/paimon-python/pypaimon/write/commit_message.py @@ -15,11 +15,14 @@ # specific language governing permissions and limitations # under the License. 
-from dataclasses import dataclass -from typing import List, Tuple, Optional +from dataclasses import dataclass, field +from typing import List, Tuple, Optional, TYPE_CHECKING from pypaimon.manifest.schema.data_file_meta import DataFileMeta +if TYPE_CHECKING: + from pypaimon.manifest.index_manifest_entry import IndexManifestEntry + @dataclass class CommitMessage: @@ -27,6 +30,7 @@ class CommitMessage: bucket: int new_files: List[DataFileMeta] check_from_snapshot: Optional[int] = -1 + index_deletes: List['IndexManifestEntry'] = field(default_factory=list) def is_empty(self): - return not self.new_files + return not self.new_files and not self.index_deletes diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index b5c9976aadc6..41440ae862c2 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -143,6 +143,31 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): logger.info("Finished collecting changes, including: %d entries", len(commit_entries)) + index_deletes = [] + for msg in commit_messages: + index_deletes.extend(msg.index_deletes) + + if not index_deletes: + from pypaimon.write.global_index_update_checker import ( + apply_global_index_update_action, + ) + updated_cols = set() + written_partitions = set() + for msg in commit_messages: + if msg.check_from_snapshot == -1: + continue + for f in msg.new_files: + if f.write_cols: + updated_cols.update(f.write_cols) + written_partitions.add(msg.partition) + if updated_cols: + snapshot = self.snapshot_manager.get_latest_snapshot() + index_msgs = apply_global_index_update_action( + self.table, snapshot, list(updated_cols), written_partitions, + ) + for m in index_msgs: + index_deletes.extend(m.index_deletes) + commit_kind = "APPEND" detect_conflicts = False allow_rollback = False @@ -158,7 +183,8 @@ def commit(self, commit_messages: List[CommitMessage], 
commit_identifier: int): commit_identifier=commit_identifier, commit_entries_plan=lambda snapshot: commit_entries, detect_conflicts=detect_conflicts, - allow_rollback=allow_rollback) + allow_rollback=allow_rollback, + index_deletes=index_deletes) def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], commit_identifier: int): """Commit the given commit messages in overwrite mode.""" @@ -244,7 +270,7 @@ def truncate_table(self, commit_identifier: int) -> None: ) def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, - detect_conflicts=False, allow_rollback=False): + detect_conflicts=False, allow_rollback=False, index_deletes=None): retry_count = 0 retry_result = None @@ -255,7 +281,7 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, # No entries to commit (e.g. drop_partitions with no matching data): skip commit # to avoid creating manifest/snapshot with empty partition_stats (causes read errors). - if not commit_entries: + if not commit_entries and not index_deletes: break result = self._try_commit_once( @@ -266,6 +292,7 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, latest_snapshot=latest_snapshot, detect_conflicts=detect_conflicts, allow_rollback=allow_rollback, + index_deletes=index_deletes, ) if result.is_success(): @@ -317,7 +344,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str commit_entries: List[ManifestEntry], commit_identifier: int, latest_snapshot: Optional[Snapshot], detect_conflicts: bool = False, - allow_rollback: bool = False) -> CommitResult: + allow_rollback: bool = False, + index_deletes=None) -> CommitResult: start_millis = int(time.time() * 1000) if self._is_duplicate_commit(retry_result, latest_snapshot, commit_identifier, commit_kind): return SuccessResult() @@ -328,6 +356,7 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str # process new_manifest new_manifest_file = 
f"manifest-{str(uuid.uuid4())}-0" + new_index_manifest = None # process snapshot new_snapshot_id = latest_snapshot.id + 1 if latest_snapshot else 1 @@ -378,6 +407,13 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str index_manifest = None if latest_snapshot and commit_kind == "APPEND": index_manifest = latest_snapshot.index_manifest + if index_deletes: + from pypaimon.manifest.index_manifest_file import IndexManifestFile + previous_index_manifest = index_manifest + index_manifest = IndexManifestFile(self.table).combine_deletes( + previous_index_manifest, index_deletes) + if index_manifest != previous_index_manifest: + new_index_manifest = index_manifest snapshot_data = Snapshot( version=3, @@ -397,7 +433,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str # Generate partition statistics for the commit statistics = self._generate_partition_statistics(commit_entries) except Exception as e: - self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list) + self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list, + new_index_manifest) logger.warning(f"Exception occurs when preparing snapshot: {e}", exc_info=True) raise RuntimeError(f"Failed to prepare snapshot: {e}") @@ -417,7 +454,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str commit_kind, commit_time_s, ) - self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list) + self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list, + new_index_manifest) return RetryResult(latest_snapshot, None) except Exception as e: # Commit exception, not sure about the situation and should not clean up the files @@ -598,10 +636,14 @@ def _commit_retry_wait(self, retry_count: int): def _cleanup_preparation_failure(self, delta_manifest_list: Optional[str], - base_manifest_list: Optional[str]): + base_manifest_list: Optional[str], + index_manifest: Optional[str] = None): try: 
manifest_path = self.manifest_list_manager.manifest_path + if index_manifest: + self.table.file_io.delete_quietly(f"{manifest_path}/{index_manifest}") + if delta_manifest_list: manifest_files = self.manifest_list_manager.read(delta_manifest_list) for manifest_meta in manifest_files: diff --git a/paimon-python/pypaimon/write/file_store_write.py b/paimon-python/pypaimon/write/file_store_write.py index c31a9f8a91f3..c77f88e907be 100644 --- a/paimon-python/pypaimon/write/file_store_write.py +++ b/paimon-python/pypaimon/write/file_store_write.py @@ -15,15 +15,19 @@ # specific language governing permissions and limitations # under the License. +import logging import random from typing import Dict, List, Tuple import pyarrow as pa + +logger = logging.getLogger(__name__) + from pypaimon.common.options.core_options import CoreOptions from pypaimon.write.commit_message import CommitMessage from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter -from pypaimon.write.writer.data_blob_writer import DataBlobWriter +from pypaimon.write.writer.dedicated_format_writer import DedicatedFormatWriter from pypaimon.write.writer.data_vector_writer import DataVectorWriter from pypaimon.write.writer.data_writer import DataWriter from pypaimon.write.writer.key_value_data_writer import KeyValueDataWriter @@ -40,6 +44,7 @@ def __init__(self, table, commit_user): self.data_writers: Dict[Tuple, DataWriter] = {} self.max_seq_numbers: dict = {} self.write_cols = None + self.blob_consumer = None self.commit_identifier = 0 self.options = CoreOptions.copy(table.options) if self.table.bucket_mode() == BucketMode.POSTPONE_MODE: @@ -65,13 +70,14 @@ def max_seq_number(): # Check if table has blob columns if self._has_blob_columns(): - return DataBlobWriter( + return DedicatedFormatWriter( table=self.table, partition=partition, bucket=bucket, max_seq_number=0, options=options, write_cols=self.write_cols, + blob_consumer=self.blob_consumer, ) elif self._has_vector_columns() and 
options.with_vector_format(): return DataVectorWriter( @@ -88,7 +94,8 @@ def max_seq_number(): partition=partition, bucket=bucket, max_seq_number=max_seq_number(), - options=options) + options=options, + merge_function=self._build_pk_merge_function()) else: seq_number = 0 if self.table.bucket_mode() == BucketMode.BUCKET_UNAWARE else max_seq_number() return AppendOnlyDataWriter( @@ -100,6 +107,111 @@ def max_seq_number(): write_cols=self.write_cols ) + def _build_pk_merge_function(self): + """Build the merge function for the in-memory write buffer. + + Shares ``merge_engine_dispatch.build_merge_function`` with the + read path so the supported engines (deduplicate, first-row, + partial-update with no out-of-scope options) cannot drift + between sides. + + For wholly unsupported engines (``aggregation``) the writer + falls back to ``DeduplicateMergeFunction`` so the flushed file + still maintains the LSM "PK unique within a file" invariant. + The read path's dispatch still raises ``NotImplementedError``, + so the user gets an explicit error before they observe + wrong-engine data; the fallback only narrows the damage to + "file is deduped, not aggregated" rather than the silent + multi-row-per-PK corruption that existed pre-PR. + + Partial-update with out-of-scope options (sequence-group, + per-field aggregator, ignore-delete, remove-record-on-*) does + **not** fall back: ``partial_update_unsupported_options`` sees + the configured keys and re-raises, so the first + ``write_arrow`` call (where ``_create_data_writer`` first runs) + surfaces the error. Silently degrading to dedupe there is the + same live corruption pattern this PR exists to close. + + ``with_write_type`` (column-subset writes) on a PK table is + also rejected here. 
The buffer layout + ``_add_system_fields`` produces would carry only the subset + on the value side, while a ``MergeFunction`` such as + ``PartialUpdateMergeFunction`` is built against the full table + arity -- the two sides would mismatch on + ``KeyValue.value.get_field`` and raise ``IndexError`` at + flush time. Refusing it explicitly avoids that obscure failure + and keeps the supported surface narrow. + + The value-side schema must match the layout + ``KeyValueDataWriter`` flushes -- ``_add_system_fields`` keeps + every original user column on the value side (the primary keys + are duplicated as ``_KEY_`` columns to the left of the + value side). So ``value_arity`` here is ``len(table.fields)``, + not ``len(table.fields) - len(primary_keys)``. + """ + from pypaimon.common.merge_engine_dispatch import ( + build_merge_function, partial_update_unsupported_options) + from pypaimon.common.options.core_options import MergeEngine + from pypaimon.read.reader.deduplicate_merge_function import \ + DeduplicateMergeFunction + + engine = self.options.merge_engine() + raw_options = self.options.options.to_map() + + if self.write_cols is not None: + raise NotImplementedError( + "with_write_type is not yet supported on primary-key " + "tables: the writer-side merge buffer assumes the " + "input batch carries the full table schema. Drop the " + "with_write_type call or write the missing columns as " + "nulls in the input batch." + ) + + # PARTIAL_UPDATE + out-of-scope option: never silently fall + # back -- forward the read-side error verbatim so writes fail + # before the first flush rather than corrupt the file. 
+ if engine == MergeEngine.PARTIAL_UPDATE \ + and partial_update_unsupported_options(raw_options): + return build_merge_function( + engine=engine, raw_options=raw_options, + key_arity=len(self.table.trimmed_primary_keys), + value_arity=len(self.table.table_schema.fields), + value_field_nullables=[ + f.type.nullable for f in self.table.table_schema.fields], + value_field_names=[ + f.name for f in self.table.table_schema.fields], + ) + + # Catch the dispatch's "wholly unsupported engine" raise only + # for the engines we know are out of scope today; any other + # NotImplementedError is a bug we want to surface, not swallow. + if engine == MergeEngine.AGGREGATE: + # Surface the silent semantic mismatch in logs: the file + # will be PK-unique (better than the pre-PR multi-row + # corruption), but any reader that honours the declared + # engine will see wrong values. Users sharing tables + # across writers especially need to see this. + logger.warning( + "merge-engine '%s' is not implemented on the pypaimon " + "write path; falling back to deduplicate so the flushed " + "file stays PK-unique. The file contents reflect " + "deduplicate semantics (latest writer wins), not %s " + "semantics. Any reader that interprets the file under " + "the declared engine will return incorrect results. 
" + "Avoid the pypaimon writer for tables on this engine.", + engine.value, engine.value) + return DeduplicateMergeFunction() + + all_value_fields = self.table.table_schema.fields + return build_merge_function( + engine=engine, raw_options=raw_options, + key_arity=len(self.table.trimmed_primary_keys), + value_arity=len(all_value_fields), + value_field_nullables=[ + f.type.nullable for f in all_value_fields], + value_field_names=[f.name for f in all_value_fields], + ) + def _has_blob_columns(self) -> bool: """Check if the table schema contains blob columns.""" for field in self.table.table_schema.fields: diff --git a/paimon-python/pypaimon/write/global_index_update_checker.py b/paimon-python/pypaimon/write/global_index_update_checker.py new file mode 100644 index 000000000000..ac405fc4bcd5 --- /dev/null +++ b/paimon-python/pypaimon/write/global_index_update_checker.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Sequence, Set, Tuple + +from pypaimon.common.options.core_options import GlobalIndexColumnUpdateAction + + +def scan_global_index_entries(table, snapshot): + from pypaimon.index.index_file_handler import IndexFileHandler + + handler = IndexFileHandler(table=table) + return handler.scan( + snapshot, lambda e: e.index_file.global_index_meta is not None + ) + + +def build_index_delete_msgs(entries) -> list: + from pypaimon.manifest.index_manifest_entry import IndexManifestEntry + from pypaimon.write.commit_message import CommitMessage + + by_partition = {} + for e in entries: + key = tuple(e.partition.values) + by_partition.setdefault(key, []).append( + IndexManifestEntry( + kind=1, partition=e.partition, bucket=e.bucket, index_file=e.index_file + ) + ) + return [ + CommitMessage(partition=key, bucket=0, new_files=[], index_deletes=dels) + for key, dels in by_partition.items() + ] + + +def apply_global_index_update_action( + table, + snapshot, + updated_cols: Sequence[str], + written_partitions: Set[Tuple], +) -> list: + if snapshot is None or not updated_cols or not written_partitions: + return [] + entries = scan_global_index_entries(table, snapshot) + if not entries: + return [] + field_by_id = {f.id: f.name for f in table.fields} + update_set = set(updated_cols) + affected = [ + e for e in entries + if field_by_id.get(e.index_file.global_index_meta.index_field_id) in update_set + and tuple(e.partition.values) in written_partitions + ] + if not affected: + return [] + action = table.options.global_index_column_update_action() + if action is None: + action = GlobalIndexColumnUpdateAction.THROW_ERROR + if action == GlobalIndexColumnUpdateAction.DROP_PARTITION_INDEX: + return build_index_delete_msgs(affected) + conflicted = sorted( + {field_by_id.get(e.index_file.global_index_meta.index_field_id) for e in affected} + ) + raise RuntimeError( + f"Update columns contain globally indexed columns, not supported now.\n" + f"Updated columns: 
{sorted(update_set)}\n" + f"Conflicted columns: {conflicted}" + ) diff --git a/paimon-python/pypaimon/write/ray_datasink.py b/paimon-python/pypaimon/write/ray_datasink.py index a01387b32712..6d48906f9fd8 100644 --- a/paimon-python/pypaimon/write/ray_datasink.py +++ b/paimon-python/pypaimon/write/ray_datasink.py @@ -20,7 +20,7 @@ """ import logging -from typing import TYPE_CHECKING, Any, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional from ray.data.datasource.datasink import Datasink @@ -72,13 +72,18 @@ def __init__( self, table: "Table", overwrite: bool = False, + static_partition: Optional[Dict[str, Any]] = None, ): self.table = table self.overwrite = overwrite + self.static_partition = static_partition self._table_name = table.identifier.get_full_name() self._writer_builder: Optional["WriteBuilder"] = None self._pending_commit_messages: List["CommitMessage"] = [] + def _is_overwrite(self) -> bool: + return self.overwrite or self.static_partition is not None + def __getstate__(self) -> dict: state = self.__dict__.copy() return state @@ -90,13 +95,15 @@ def __setstate__(self, state: dict) -> None: self._writer_builder = None if not hasattr(self, '_table_name'): self._table_name = self.table.identifier.get_full_name() + if not hasattr(self, 'static_partition'): + self.static_partition = None def on_write_start(self, schema=None) -> None: logger.info(f"Starting write job for table {self._table_name}") self._writer_builder = self.table.new_batch_write_builder() - if self.overwrite: - self._writer_builder = self._writer_builder.overwrite() + if self._is_overwrite(): + self._writer_builder = self._writer_builder.overwrite(self.static_partition) def write( self, @@ -108,8 +115,8 @@ def write( try: writer_builder = self.table.new_batch_write_builder() - if self.overwrite: - writer_builder = writer_builder.overwrite() + if self._is_overwrite(): + writer_builder = writer_builder.overwrite(self.static_partition) table_write = 
writer_builder.new_write() @@ -135,22 +142,26 @@ def write( return commit_messages_list + @staticmethod + def _extract_write_returns(write_result: Any): + """Normalize WriteResult.write_returns (Ray 2.44+) vs list of returns + (older Ray) into a list of per-task commit-message lists.""" + if hasattr(write_result, "write_returns"): + return write_result.write_returns + if isinstance(write_result, list): + return write_result + raise TypeError( + f"Unexpected write_result type {type(write_result).__name__}: " + "expected object with .write_returns or list of commit message " + "lists. Refusing to proceed to avoid silent data loss." + ) + def on_write_complete( self, write_result: Any ): table_commit = None try: - # WriteResult.write_returns (Ray 2.44+); older Ray may pass list of returns - if hasattr(write_result, "write_returns"): - write_returns = write_result.write_returns - elif isinstance(write_result, list): - write_returns = write_result - else: - raise TypeError( - f"Unexpected write_result type {type(write_result).__name__}: " - "expected object with .write_returns or list of commit message lists. " - "Refusing to proceed to avoid silent data loss." 
- ) + write_returns = self._extract_write_returns(write_result) all_commit_messages = [ commit_message for commit_messages in write_returns @@ -163,7 +174,7 @@ def on_write_complete( self._pending_commit_messages = non_empty_messages - if not non_empty_messages: + if not non_empty_messages and not self._is_overwrite(): logger.info("No data to commit (all commit messages are empty)") self._pending_commit_messages = [] return diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py index fe2fb9a64b79..4b063dfa7bf5 100644 --- a/paimon-python/pypaimon/write/table_update.py +++ b/paimon-python/pypaimon/write/table_update.py @@ -26,6 +26,7 @@ from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.read.split import DataSplit from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER +from pypaimon.table.special_fields import SpecialFields from pypaimon.write.commit_message import CommitMessage from pypaimon.write.table_update_by_row_id import TableUpdateByRowId from pypaimon.write.table_upsert_by_key import TableUpsertByKey @@ -136,15 +137,12 @@ def new_shard_updator(self, shard_num: int, total_shard_count: int): def _update_by_arrow_with_row_id( self, table: pa.Table, commit_identifier: int ) -> List[CommitMessage]: - """Shared implementation for ``update_by_arrow_with_row_id``. - - The public method lives on the concrete subclasses so each can - expose the signature appropriate to its mode (batch vs stream). - Produced commit messages are tagged with ``commit_identifier``. 
- """ + cols = self.update_cols if self.update_cols is not None else [ + c for c in table.column_names if c != SpecialFields.ROW_ID.name + ] return TableUpdateByRowId( self.table, self.commit_user, commit_identifier, - ).update_columns(table, self.update_cols) + ).update_columns(table, cols) def _upsert_by_arrow_with_key( self, diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py index ac9c68c3623b..34816d6ffba5 100644 --- a/paimon-python/pypaimon/write/table_update_by_row_id.py +++ b/paimon-python/pypaimon/write/table_update_by_row_id.py @@ -16,6 +16,7 @@ # under the License. import bisect +from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple import pyarrow as pa @@ -24,12 +25,29 @@ from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.read.split import DataSplit from pypaimon.read.table_read import TableRead -from pypaimon.utils.range import Range from pypaimon.schema.data_types import DataField +from pypaimon.table.row.blob import Blob from pypaimon.table.row.generic_row import GenericRow from pypaimon.table.special_fields import SpecialFields +from pypaimon.utils.range import Range from pypaimon.write.commit_message import CommitMessage from pypaimon.write.file_store_write import FileStoreWrite +from pypaimon.write.writer.blob_writer import BlobWriter + + +@dataclass(frozen=True) +class _FilesInfo: + """Snapshot view of target data files keyed by first_row_id. + + Built once per merge by the driver and broadcast to workers so each task + avoids re-scanning the manifest. 
+ """ + snapshot_id: int + first_row_ids: List[int] + first_row_id_index: Dict[int, Tuple[DataSplit, List[DataFileMeta]]] = ( + field(default_factory=dict) + ) + valid_row_id_ranges: List[Range] = field(default_factory=list) class TableUpdateByRowId: @@ -42,32 +60,40 @@ class TableUpdateByRowId: FIRST_ROW_ID_COLUMN = '_FIRST_ROW_ID' - def __init__(self, table, commit_user: str, commit_identifier: int): + def __init__( + self, table, commit_user: str, commit_identifier: int, + _precomputed_files_info: Optional[_FilesInfo] = None, + ): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table self.commit_user = commit_user self.commit_identifier = commit_identifier - # Snapshot the current state once: a single ``first_row_id -> (split, files)`` - # map is enough to drive every downstream lookup (partition, row-count, read). - (self.snapshot_id, - self.first_row_ids, - self._first_row_id_index, - self.total_row_count) = self._load_existing_files_info() + info = _precomputed_files_info or self._load_existing_files_info() + self.snapshot_id = info.snapshot_id + self.first_row_ids = info.first_row_ids + self._first_row_id_index = info.first_row_id_index + self.valid_row_id_ranges = info.valid_row_id_ranges self.commit_messages: List[CommitMessage] = [] - def _load_existing_files_info( - self, - ) -> Tuple[int, List[int], Dict[int, Tuple[DataSplit, List[DataFileMeta]]], int]: + def _snapshot_files_info(self) -> _FilesInfo: + """Internal: return the current snapshot's file index for broadcast.""" + return _FilesInfo( + snapshot_id=self.snapshot_id, + first_row_ids=self.first_row_ids, + first_row_id_index=self._first_row_id_index, + valid_row_id_ranges=self.valid_row_id_ranges, + ) + + def _load_existing_files_info(self) -> _FilesInfo: """Scan the latest snapshot once and index files by ``first_row_id``. 
- Returns: - A 4-tuple of ``(snapshot_id, sorted_unique_first_row_ids, index, total_row_count)`` - where ``index`` maps each ``first_row_id`` to the owning split and - the list of files with that id (a single id may belong to multiple - files when data evolution has split a logical row range). + Returns a :class:`_FilesInfo` whose ``first_row_id_index`` maps each + ``first_row_id`` to the owning split and the list of files with that + id (a single id may belong to multiple files when data evolution has + split a logical row range). """ plan = self.table.new_read_builder().new_scan().plan() splits = plan.splits() @@ -75,27 +101,52 @@ def _load_existing_files_info( index: Dict[int, Tuple[DataSplit, List[DataFileMeta]]] = {} row_id_ranges: List[Range] = [] for split in splits: + files_with_row_id = [ + file for file in split.files if file.first_row_id is not None + ] + data_files = [ + file for file in files_with_row_id + if not DataFileMeta.is_blob_file(file.file_name) + ] for file in split.files: - if file.first_row_id is None or file.file_name.endswith('.blob'): + if file.first_row_id is None or DataFileMeta.is_blob_file(file.file_name): continue row_id_ranges.append(file.row_id_range()) + for file in data_files: + target_files = [ + target_file + for target_file in files_with_row_id + if self._overlaps(file.row_id_range(), target_file.row_id_range()) + ] + entry = index.get(file.first_row_id) if entry is None: - index[file.first_row_id] = (split, [file]) + index[file.first_row_id] = (split, target_files) else: - entry[1].append(file) + existing_files = entry[1] + existing_names = {existing.file_name for existing in existing_files} + existing_files.extend( + target_file + for target_file in target_files + if target_file.file_name not in existing_names + ) - # Multiple physical files may share the same first_row_id (data evolution); - # summing row_count per file would over-count logical rows and widen - # the _ROW_ID validation range incorrectly. 
if row_id_ranges: merged = Range.sort_and_merge_overlap(row_id_ranges, True, True) - total_row_count = sum(r.count() for r in merged) else: - total_row_count = 0 + merged = [] snapshot_id = plan.snapshot_id if plan.snapshot_id is not None else -1 - return snapshot_id, sorted(index.keys()), index, total_row_count + return _FilesInfo( + snapshot_id=snapshot_id, + first_row_ids=sorted(index.keys()), + first_row_id_index=index, + valid_row_id_ranges=merged, + ) + + @staticmethod + def _overlaps(left: Range, right: Range) -> bool: + return left.from_ <= right.to and right.from_ <= left.to def update_columns(self, data: pa.Table, column_names: List[str]) -> List[CommitMessage]: """ @@ -128,8 +179,8 @@ def update_columns(self, data: pa.Table, column_names: List[str]) -> List[Commit def _calculate_first_row_id(self, data: pa.Table) -> pa.Table: """Append ``_FIRST_ROW_ID`` to *data* by looking up each ``_ROW_ID``. - Validates that every input ``_ROW_ID`` is unique and falls in - ``[0, total_row_count)``. Supports partial / non-consecutive updates. + Validates that every input ``_ROW_ID`` is unique and belongs to + a valid row_id range. Supports partial / non-consecutive updates. """ row_id_arr = data[SpecialFields.ROW_ID.name] row_ids = row_id_arr.to_pylist() @@ -142,15 +193,12 @@ def _calculate_first_row_id(self, data: pa.Table) -> pa.Table: self.FIRST_ROW_ID_COLUMN, pa.array([], type=pa.int64()), ) - # Vectorised range check (avoids a Python-level per-row loop). 
- min_id = pc.min(row_id_arr).as_py() - max_id = pc.max(row_id_arr).as_py() - if min_id < 0 or max_id >= self.total_row_count: - offending = min_id if min_id < 0 else max_id - raise ValueError( - f"Row ID {offending} is out of valid range " - f"[0, {self.total_row_count})" - ) + for row_id in row_ids: + if not any(r.contains(row_id) for r in self.valid_row_id_ranges): + raise ValueError( + f"Row ID {row_id} does not belong to any valid range " + f"{[f'[{r.from_}, {r.to}]' for r in self.valid_row_id_ranges]}" + ) if not self.first_row_ids: raise ValueError("The input sorted sequence is empty.") @@ -197,7 +245,7 @@ def _read_original_file_data(self, first_row_id: int, column_names: List[str]) - """ wanted = set(column_names) read_fields: List[DataField] = [ - field for field in self.table.fields if field.name in wanted + table_field for table_field in self.table.fields if table_field.name in wanted ] if not read_fields: return None @@ -216,8 +264,12 @@ def _read_original_file_data(self, first_row_id: int, column_names: List[str]) - table_read = TableRead(self.table, predicate=None, read_type=read_fields) return table_read.to_arrow([origin_split]) - def _merge_update_with_original(self, original_data: Optional[pa.Table], update_data: pa.Table, - column_names: List[str], first_row_id: int) -> pa.Table: + def _merge_update_with_original( + self, + original_data: Optional[pa.Table], + update_data: pa.Table, + column_names: List[str], + first_row_id: int) -> Tuple[Optional[pa.Table], Dict[str, List[object]]]: """Merge update data with original data, preserving row order. For rows that have updates, use the update values. @@ -230,7 +282,7 @@ def _merge_update_with_original(self, original_data: Optional[pa.Table], update_ first_row_id: The first_row_id of this file group Returns: - Merged PyArrow Table with all rows + Normal merged PyArrow Table and blob values to write row-by-row. 
""" # Get the _ROW_ID values from update_data to determine which rows are being updated @@ -245,18 +297,40 @@ def _merge_update_with_original(self, original_data: Optional[pa.Table], update_ # Build the merged table column by column merged_columns = {} + blob_columns = {} + update_by_col = { + col_name: update_data[col_name].combine_chunks() + for col_name in column_names + } + update_positions = { + int(relative_index.as_py()): idx + for idx, relative_index in enumerate(relative_indices) + } for col_name in column_names: - update_col = update_data[col_name].combine_chunks() + update_col = update_by_col[col_name] original_col = original_data[col_name].combine_chunks() - # replace_with_mask fills mask=True positions with update values in order - merged_columns[col_name] = pc.replace_with_mask( - original_col, mask, update_col.cast(original_col.type) - ) - - # Create the merged table - merged_table = pa.table(merged_columns) - - return merged_table + if self._is_blob_column(col_name): + blob_columns[col_name] = [ + update_col[update_positions[i]].as_py() + if i in update_positions + else Blob.PLACE_HOLDER + for i in range(original_data.num_rows) + ] + else: + # replace_with_mask fills mask=True positions with update values in order + merged_columns[col_name] = pc.replace_with_mask( + original_col, mask, update_col.cast(original_col.type) + ) + + merged_table = pa.table(merged_columns) if merged_columns else None + + return merged_table, blob_columns + + def _is_blob_column(self, column_name: str) -> bool: + for table_field in self.table.fields: + if table_field.name == column_name: + return getattr(table_field.type, 'type', None) == 'BLOB' + return False def _write_group(self, partition: GenericRow, first_row_id: int, data: pa.Table, column_names: List[str]): @@ -266,25 +340,54 @@ def _write_group(self, partition: GenericRow, first_row_id: int, writes a single output file (rolling disabled) for the group. 
""" original_data = self._read_original_file_data(first_row_id, column_names) - merged_data = self._merge_update_with_original( + merged_data, blob_columns = self._merge_update_with_original( original_data, data, column_names, first_row_id, ) - file_store_write = FileStoreWrite(self.table, self.commit_user) + partition_tuple = tuple(partition.values) + new_files = [] + file_store_write = None + blob_writers = [] try: - file_store_write.disable_rolling() - file_store_write.write_cols = column_names - - partition_tuple = tuple(partition.values) - for batch in merged_data.to_batches(): - file_store_write.write(partition_tuple, 0, batch) - - new_messages = file_store_write.prepare_commit(self.commit_identifier) - for msg in new_messages: - msg.check_from_snapshot = self.snapshot_id - for file in msg.new_files: + if merged_data is not None: + file_store_write = FileStoreWrite(self.table, self.commit_user) + file_store_write.disable_rolling() + file_store_write.write_cols = list(merged_data.column_names) + for batch in merged_data.to_batches(): + file_store_write.write(partition_tuple, 0, batch) + new_messages = file_store_write.prepare_commit(self.commit_identifier) + for msg in new_messages: + new_files.extend(msg.new_files) + + for column_name, values in blob_columns.items(): + blob_writer = BlobWriter( + self.table, + partition_tuple, + 0, + 0, + column_name, + self.table.options, + ) + blob_writers.append(blob_writer) + arrow_type = original_data.schema.field(column_name).type + for value in values: + blob_writer.write_blob(value, arrow_type) + new_files.extend(blob_writer.prepare_commit()) + + if new_files: + for file in new_files: file.first_row_id = first_row_id - file.write_cols = column_names - self.commit_messages.extend(new_messages) + file.write_cols = file.write_cols or column_names + self.commit_messages.append( + CommitMessage( + partition=partition_tuple, + bucket=0, + new_files=new_files, + check_from_snapshot=self.snapshot_id, + ) + ) finally: - 
file_store_write.close() + if file_store_write is not None: + file_store_write.close() + for blob_writer in blob_writers: + blob_writer.close() diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py index 80ef5a3572bb..411ddd9ceb21 100644 --- a/paimon-python/pypaimon/write/table_write.py +++ b/paimon-python/pypaimon/write/table_write.py @@ -22,6 +22,7 @@ from pypaimon.schema.data_types import PyarrowFieldParser from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER +from pypaimon.table.row.blob import BlobConsumer from pypaimon.write.commit_message import CommitMessage from pypaimon.write.file_store_write import FileStoreWrite @@ -30,7 +31,7 @@ class TableWrite: - def __init__(self, table, commit_user): + def __init__(self, table, commit_user, static_partition: Optional[dict] = None): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table @@ -38,6 +39,7 @@ def __init__(self, table, commit_user): self.file_store_write = FileStoreWrite(self.table, commit_user) self.row_key_extractor = self.table.create_row_key_extractor() self.commit_user = commit_user + self.static_partition = static_partition def write_arrow(self, table: pa.Table): batches_iterator = table.to_batches() @@ -71,12 +73,22 @@ def with_write_type(self, write_cols: List[str]): self.file_store_write.write_cols = write_cols return self + def with_blob_consumer(self, blob_consumer: BlobConsumer): + if self.file_store_write.data_writers: + raise RuntimeError( + "with_blob_consumer must be called before any write operation." + ) + self.file_store_write.blob_consumer = blob_consumer + return self + def write_ray( self, dataset: "Dataset", overwrite: bool = False, concurrency: Optional[int] = None, ray_remote_args: Optional[Dict[str, Any]] = None, + hash_fixed_precluster: str = "auto", + static_partition: Optional[dict] = None, ) -> None: """ Write a Ray Dataset to Paimon table. 
@@ -85,13 +97,35 @@ def write_ray( dataset: Ray Dataset to write. This is a distributed data collection from Ray Data (ray.data.Dataset). overwrite: Whether to overwrite existing data. Defaults to False. + Builder-level or static_partition overwrite mode takes precedence. concurrency: Optional max number of Ray tasks to run concurrently. By default, dynamically decided based on available resources. ray_remote_args: Optional kwargs passed to :func:`ray.remote` in write tasks. For example, ``{"num_cpus": 2, "max_retries": 3}``. + hash_fixed_precluster: HASH_FIXED pre-clustering mode. ``"auto"`` + and ``"off"`` write append-only HASH_FIXED tables directly + and reject HASH_FIXED primary-key tables. ``"map_groups"`` + preserves the legacy small-file optimization and its single + group memory bound for HASH_FIXED primary-key tables. + static_partition: Optional partition spec to overwrite. When set, + the Ray write runs in overwrite mode for this partition and + overrides any builder-level partition spec. 
""" + from pypaimon.ray.shuffle import maybe_apply_repartition from pypaimon.write.ray_datasink import PaimonDatasink - datasink = PaimonDatasink(self.table, overwrite=overwrite) + + dataset = maybe_apply_repartition( + dataset, self.table, hash_fixed_precluster) + + overwrite_partition = self.static_partition + if static_partition is not None: + overwrite_partition = static_partition + + datasink = PaimonDatasink( + self.table, + overwrite=overwrite, + static_partition=overwrite_partition, + ) dataset.write_datasink( datasink, concurrency=concurrency, @@ -130,8 +164,8 @@ def _is_binary_family(arrow_type) -> bool: class BatchTableWrite(TableWrite): - def __init__(self, table, commit_user): - super().__init__(table, commit_user) + def __init__(self, table, commit_user, static_partition: Optional[dict] = None): + super().__init__(table, commit_user, static_partition) self.batch_committed = False def prepare_commit(self) -> List[CommitMessage]: diff --git a/paimon-python/pypaimon/write/write_builder.py b/paimon-python/pypaimon/write/write_builder.py index 724e5d7a3ff9..f7a0459305de 100644 --- a/paimon-python/pypaimon/write/write_builder.py +++ b/paimon-python/pypaimon/write/write_builder.py @@ -59,7 +59,7 @@ def _create_commit_user(self): class BatchWriteBuilder(WriteBuilder): def new_write(self) -> BatchTableWrite: - return BatchTableWrite(self.table, self.commit_user) + return BatchTableWrite(self.table, self.commit_user, self.static_partition) def new_update(self) -> BatchTableUpdate: return BatchTableUpdate(self.table, self.commit_user) @@ -72,7 +72,7 @@ def new_commit(self) -> BatchTableCommit: class StreamWriteBuilder(WriteBuilder): def new_write(self) -> StreamTableWrite: - return StreamTableWrite(self.table, self.commit_user) + return StreamTableWrite(self.table, self.commit_user, self.static_partition) def new_update(self) -> StreamTableUpdate: return StreamTableUpdate(self.table, self.commit_user) diff --git 
a/paimon-python/pypaimon/write/writer/blob_file_writer.py b/paimon-python/pypaimon/write/writer/blob_file_writer.py index 31d945a4ede8..e4aa66a1953c 100644 --- a/paimon-python/pypaimon/write/writer/blob_file_writer.py +++ b/paimon-python/pypaimon/write/writer/blob_file_writer.py @@ -22,7 +22,7 @@ from pypaimon.write.blob_format_writer import BlobFormatWriter from pypaimon.table.row.generic_row import GenericRow, RowKind -from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor +from pypaimon.table.row.blob import Blob, BlobConsumer, BlobData, BlobDescriptor from pypaimon.schema.data_types import DataField, PyarrowFieldParser @@ -32,11 +32,16 @@ class BlobFileWriter: Writes rows one by one and tracks file size. """ - def __init__(self, file_io, file_path: Path): + def __init__(self, file_io, file_path: Path, blob_consumer: Optional[BlobConsumer] = None): self.file_io = file_io self.file_path = file_path + self._blob_consumer = blob_consumer self.output_stream = file_io.new_output_stream(file_path) - self.writer = BlobFormatWriter(self.output_stream) + self.writer = BlobFormatWriter( + self.output_stream, + blob_consumer=blob_consumer, + file_path=str(file_path), + ) self.row_count = 0 self.closed = False @@ -52,8 +57,17 @@ def write_row(self, row_data: pa.Table): blob_data = self._to_blob(col_data) - # Create GenericRow - fields = [DataField(0, field_name, PyarrowFieldParser.to_paimon_type(row_data.schema[0].type, False))] + self.write_blob(field_name, row_data.schema[0].type, blob_data) + + def write_blob(self, field_name: str, arrow_type, blob_data): + blob_data = self._to_blob(blob_data) + fields = [ + DataField( + 0, + field_name, + PyarrowFieldParser.to_paimon_type(arrow_type, False), + ) + ] row = GenericRow([blob_data], fields, RowKind.INSERT) # Write to blob format writer @@ -61,6 +75,8 @@ def write_row(self, row_data: pa.Table): self.row_count += 1 def _to_blob(self, col_data) -> Optional[Blob]: + if col_data is Blob.PLACE_HOLDER: + return 
Blob.PLACE_HOLDER if hasattr(col_data, 'as_py'): col_data = col_data.as_py() if col_data is None: @@ -107,7 +123,7 @@ def close(self) -> int: return file_size def abort(self): - """Abort the writer and delete the file.""" + """Abort the writer and delete the file (unless a blob consumer holds references).""" if not self.closed: try: if hasattr(self.output_stream, 'close'): @@ -116,5 +132,5 @@ def abort(self): pass self.closed = True - # Delete the file - self.file_io.delete_quietly(self.file_path) + if self._blob_consumer is None: + self.file_io.delete_quietly(self.file_path) diff --git a/paimon-python/pypaimon/write/writer/blob_writer.py b/paimon-python/pypaimon/write/writer/blob_writer.py index 4ebed16785e6..24f64e4dff78 100644 --- a/paimon-python/pypaimon/write/writer/blob_writer.py +++ b/paimon-python/pypaimon/write/writer/blob_writer.py @@ -22,6 +22,7 @@ from pypaimon.common.options.core_options import CoreOptions from pypaimon.data.timestamp import Timestamp +from pypaimon.table.row.blob import BlobConsumer from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter from pypaimon.write.writer.blob_file_writer import BlobFileWriter @@ -31,7 +32,7 @@ class BlobWriter(AppendOnlyDataWriter): def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, blob_column: str, - options: Dict[str, str] = None): + options: Dict[str, str] = None, blob_consumer: Optional[BlobConsumer] = None): super().__init__(table, partition, bucket, max_seq_number, options, write_cols=[blob_column]) @@ -44,6 +45,7 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, bl options = self.table.options self.blob_target_file_size = CoreOptions.blob_target_file_size(options) + self._blob_consumer = blob_consumer self.current_writer: Optional[BlobFileWriter] = None self.current_file_path: Optional[str] = None self.record_count = 0 @@ -81,13 +83,24 @@ def _write_row_to_file(self, row_data: pa.Table): # This ensures each row has a 
unique sequence number for data versioning and consistency self.sequence_generator.next() + def write_blob(self, value, arrow_type=pa.large_binary()): + if self.current_writer is None: + self.open_current_writer() + + self.current_writer.write_blob(self.blob_column, arrow_type, value) + self.sequence_generator.next() + self.record_count += 1 + + if self.rolling_file(): + self.close_current_writer() + def open_current_writer(self): file_name = (f"{CoreOptions.data_file_prefix(self.options)}" f"{self.file_uuid}-{self.file_count}.{self.file_format}") self.file_count += 1 # Increment counter for next file file_path = self._generate_file_path(file_name) self.current_file_path = file_path - self.current_writer = BlobFileWriter(self.file_io, file_path) + self.current_writer = BlobFileWriter(self.file_io, file_path, blob_consumer=self._blob_consumer) def rolling_file(self) -> bool: if self.current_writer is None: @@ -225,7 +238,11 @@ def abort(self): logger.warning(f"Error aborting blob writer: {e}", exc_info=e) self.current_writer = None self.current_file_path = None - super().abort() + if self._blob_consumer is not None: + self.pending_data = None + self.committed_files.clear() + else: + super().abort() @staticmethod def _get_column_stats(data_or_batch, column_name: str): diff --git a/paimon-python/pypaimon/write/writer/data_vector_writer.py b/paimon-python/pypaimon/write/writer/data_vector_writer.py index d2b0adb568ab..d06d42593665 100644 --- a/paimon-python/pypaimon/write/writer/data_vector_writer.py +++ b/paimon-python/pypaimon/write/writer/data_vector_writer.py @@ -212,6 +212,10 @@ def _write_normal_data_to_file(self, data: pa.Table) -> Optional[DataFileMeta]: self.file_io.write_lance(file_path, data) elif self.file_format == CoreOptions.FILE_FORMAT_VORTEX: self.file_io.write_vortex(file_path, data) + elif self.file_format == CoreOptions.FILE_FORMAT_MOSAIC: + self.file_io.write_mosaic(file_path, data) + elif self.file_format == CoreOptions.FILE_FORMAT_ROW: + 
self.file_io.write_row(file_path, data, zstd_level=self.zstd_level) else: raise ValueError(f"Unsupported file format: {self.file_format}") diff --git a/paimon-python/pypaimon/write/writer/data_writer.py b/paimon-python/pypaimon/write/writer/data_writer.py index f420a813c0c7..313caa7f6d08 100644 --- a/paimon-python/pypaimon/write/writer/data_writer.py +++ b/paimon-python/pypaimon/write/writer/data_writer.py @@ -198,6 +198,10 @@ def _write_data_to_file(self, data: pa.Table): self.file_io.write_lance(file_path, data) elif self.file_format == CoreOptions.FILE_FORMAT_VORTEX: self.file_io.write_vortex(file_path, data) + elif self.file_format == CoreOptions.FILE_FORMAT_MOSAIC: + self.file_io.write_mosaic(file_path, data) + elif self.file_format == CoreOptions.FILE_FORMAT_ROW: + self.file_io.write_row(file_path, data, zstd_level=self.zstd_level) else: raise ValueError(f"Unsupported file format: {self.file_format}") diff --git a/paimon-python/pypaimon/write/writer/data_blob_writer.py b/paimon-python/pypaimon/write/writer/dedicated_format_writer.py similarity index 64% rename from paimon-python/pypaimon/write/writer/data_blob_writer.py rename to paimon-python/pypaimon/write/writer/dedicated_format_writer.py index 4c3289f5aa44..01216b36cd4d 100644 --- a/paimon-python/pypaimon/write/writer/data_blob_writer.py +++ b/paimon-python/pypaimon/write/writer/dedicated_format_writer.py @@ -25,65 +25,40 @@ from pypaimon.data.timestamp import Timestamp from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.simple_stats import SimpleStats +from pypaimon.schema.data_types import VectorType +from pypaimon.table.row.blob import BlobConsumer from pypaimon.table.row.generic_row import GenericRow from pypaimon.write.writer.data_writer import DataWriter logger = logging.getLogger(__name__) -class DataBlobWriter(DataWriter): - """ - A rolling file writer that handles both normal data and blob data. 
This writer creates separate - files for normal columns and blob columns, managing their lifecycle independently. - - For example, given a table schema with normal columns (id INT, name STRING) and blob columns - (pic1 BLOB, pic2 BLOB), this writer will create separate files for normal columns and each - blob-file column. - - Key features: - - Blob data can roll independently when normal data doesn't need rolling - - When normal data rolls, blob data MUST also be closed (Java behavior) - - Blob data uses more aggressive rolling (smaller target size) to prevent memory issues - - One normal data file may correspond to multiple blob data files - - Blob data is written immediately to disk to prevent memory corruption - - Blob file metadata is stored as separate DataFileMeta objects after normal file metadata - - When TableWrite.with_write_type narrows columns, incoming batches only carry that subset; - column lists are narrowed accordingly so splitting never selects missing columns. - - Rolling behavior: - - Normal data rolls: Both normal and blob writers are closed together, blob metadata added after normal metadata - - Blob data rolls independently: Only blob writer is closed, blob metadata is cached until normal data rolls - - Metadata organization: - - Normal file metadata is added first to committed_files - - Blob file metadata is added after normal file metadata in committed_files - - When blob rolls independently, metadata is cached until normal data rolls - - Result: [normal_meta, blob_meta1, blob_meta2, blob_meta3, ...] - - Example file organization: - committed_files = [ - normal_file1_meta, # f1.parquet metadata - blob_file1_meta, # b1.blob metadata - blob_file2_meta, # b2.blob metadata - blob_file3_meta, # b3.blob metadata - normal_file2_meta, # f1-2.parquet metadata - blob_file4_meta, # b4.blob metadata - blob_file5_meta, # b5.blob metadata - ] - - This matches the Java RollingBlobFileWriter behavior exactly. 
+class DedicatedFormatWriter(DataWriter): + """A rolling file writer that writes normal, blob, and vector columns to dedicated files. + + Splits incoming data three ways: + - Normal columns → standard data files (.parquet / .orc / .vortex / …) + - Blob columns (large_binary) → .blob files + - Vector columns (when vector.file.format is configured) → .vector. files + + This mirrors Java's DedicatedFormatRollingFileWriter. + + Metadata order in committed_files: + [normal_meta, blob_meta1, …, vector_meta1, …] """ # Constant for checking rolling condition periodically CHECK_ROLLING_RECORD_CNT = 1000 def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, options: CoreOptions = None, - write_cols: Optional[List[str]] = None): + write_cols: Optional[List[str]] = None, blob_consumer: Optional[BlobConsumer] = None): super().__init__(table, partition, bucket, max_seq_number, options, write_cols=write_cols) # Determine blob columns from table schema self.blob_column_names = self._get_blob_columns_from_schema() self.blob_descriptor_fields = CoreOptions.blob_descriptor_fields(self.options) + self.blob_view_fields = CoreOptions.blob_view_fields(self.options) + self.blob_inline_fields = self.blob_descriptor_fields.union(self.blob_view_fields) unknown_descriptor_fields = self.blob_descriptor_fields.difference( set(self.blob_column_names) @@ -95,27 +70,37 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, op ) # Blob fields that should still be written to `.blob` files. - full_blob_file_column_names = [ - col for col in self.blob_column_names if col not in self.blob_descriptor_fields + self.blob_file_column_names = [ + col for col in self.blob_column_names if col not in self.blob_inline_fields ] - full_blob_file_set = set(full_blob_file_column_names) + full_blob_file_set = set(self.blob_file_column_names) all_column_names = self.table.field_names + # Detect vector columns that should be written to dedicated files. 
+ full_vector_column_names = self._get_vector_columns_from_schema() + full_vector_set = set(full_vector_column_names) + # Only split vector columns when vector.file.format is configured. + has_dedicated_vector = bool(full_vector_column_names) and options.with_vector_format() + dedicated_set = full_blob_file_set | (full_vector_set if has_dedicated_vector else set()) + # Narrow columns when TableWrite.with_write_type(...) supplies a partial column list. # Incoming RecordBatches only contain those columns; selecting full normal/blob lists # would raise KeyError. if write_cols is not None: write_col_set = set(write_cols) self.blob_file_column_names = [ - col for col in full_blob_file_column_names if col in write_col_set + col for col in self.blob_file_column_names if col in write_col_set ] + self.vector_write_columns = [ + col for col in full_vector_column_names if col in write_col_set + ] if has_dedicated_vector else [] self.normal_column_names = [ - col for col in write_cols if col not in full_blob_file_set + col for col in write_cols if col not in dedicated_set ] else: - self.blob_file_column_names = list(full_blob_file_column_names) + self.vector_write_columns = list(full_vector_column_names) if has_dedicated_vector else [] self.normal_column_names = [ - col for col in all_column_names if col not in full_blob_file_set + col for col in all_column_names if col not in dedicated_set ] normal_name_set = set(self.normal_column_names) self.normal_columns = [ @@ -140,7 +125,22 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, op bucket=self.bucket, max_seq_number=max_seq_number, blob_column=blob_column, - options=options + options=options, + blob_consumer=blob_consumer, + ) + + # Initialize vector writer when vector.file.format is configured. 
+ from pypaimon.write.writer.vector_writer import VectorWriter + self.vector_writer: Optional[VectorWriter] = None + if self.vector_write_columns: + self.vector_writer = VectorWriter( + table=self.table, + partition=self.partition, + bucket=self.bucket, + max_seq_number=max_seq_number, + vector_columns=self.vector_write_columns, + vector_file_format=options.vector_file_format(), + options=options, ) # Initialize ExternalStorageBlobWriter if configured @@ -159,12 +159,14 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, op ) logger.info( - "Initialized DataBlobWriter with blob columns: %s, blob file columns: %s, descriptor " - "stored columns: %s, external storage fields: %s", + "Initialized DedicatedFormatWriter with blob columns: %s, blob file columns: %s, " + "vector columns: %s, descriptor stored columns: %s, external storage fields: %s, view stored columns: %s", self.blob_column_names, self.blob_file_column_names, + self.vector_write_columns, sorted(self.blob_descriptor_fields), sorted(external_storage_fields) if external_storage_fields else [], + sorted(self.blob_view_fields) ) def _get_blob_columns_from_schema(self) -> List[str]: @@ -178,8 +180,14 @@ def _get_blob_columns_from_schema(self) -> List[str]: raise ValueError("No blob field found in table schema.") return blob_columns + def _get_vector_columns_from_schema(self) -> List[str]: + return [ + field.name for field in self.table.table_schema.fields + if isinstance(field.type, VectorType) + ] + def _process_data(self, data: pa.RecordBatch) -> pa.RecordBatch: - normal_data, _ = self._split_data(data) + normal_data, _, _ = self._split_data(data) return normal_data def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table: @@ -192,22 +200,27 @@ def write(self, data: pa.RecordBatch): if self._external_storage_writer: data = self._external_storage_writer.transform_batch(data) - # Split data into normal and blob parts - normal_data, blob_data_map = 
self._split_data(data) - self._validate_descriptor_stored_fields_input(data) + # Split data into normal, blob, and vector parts + normal_data, blob_data_map, vector_data = self._split_data(data) + self._validate_inline_stored_fields_input(data) - # Process and accumulate normal data + # Process and accumulate normal data (may be None for partial writes) processed_normal = self._process_normal_data(normal_data) - if self.pending_normal_data is None: - self.pending_normal_data = processed_normal - else: - self.pending_normal_data = self._merge_normal_data(self.pending_normal_data, processed_normal) + if processed_normal is not None: + if self.pending_normal_data is None: + self.pending_normal_data = processed_normal + else: + self.pending_normal_data = self._merge_normal_data(self.pending_normal_data, processed_normal) # Write blob-file columns to dedicated blob writers. for blob_column, blob_data in blob_data_map.items(): if blob_data is not None and blob_data.num_rows > 0: self.blob_writers[blob_column].write(blob_data) + # Write vector columns to dedicated vector writer. 
+ if self.vector_writer is not None and vector_data is not None and vector_data.num_rows > 0: + self.vector_writer.write(vector_data) + self.record_count += data.num_rows # Check if normal data rolling is needed @@ -231,8 +244,7 @@ def close(self): return try: - if self.pending_normal_data is not None and self.pending_normal_data.num_rows > 0: - self._close_current_writers() + self._close_current_writers() if self._external_storage_writer: self._external_storage_writer.close() except Exception as e: @@ -246,24 +258,33 @@ def abort(self): """Abort all writers and clean up resources.""" for blob_writer in self.blob_writers.values(): blob_writer.abort() + if self.vector_writer is not None: + self.vector_writer.abort() if self._external_storage_writer: self._external_storage_writer.abort() self.pending_normal_data = None self.committed_files.clear() - def _split_data(self, data: pa.RecordBatch) -> Tuple[pa.RecordBatch, Dict[str, pa.RecordBatch]]: - """Split data into normal and blob parts based on column names.""" + def _split_data(self, data: pa.RecordBatch) -> Tuple[ + pa.RecordBatch, Dict[str, pa.RecordBatch], Optional[pa.RecordBatch]]: + """Split data into normal, blob, and vector parts based on column names.""" normal_data = data.select(self.normal_column_names) if self.normal_column_names else None blob_data_map = { blob_column: data.select([blob_column]) for blob_column in self.blob_file_column_names } - return normal_data, blob_data_map + vector_data = ( + pa.RecordBatch.from_arrays( + [data.column(name) for name in self.vector_write_columns], + names=self.vector_write_columns, + ) if self.vector_write_columns else None + ) + return normal_data, blob_data_map, vector_data - def _validate_descriptor_stored_fields_input(self, data: pa.RecordBatch): - if not self.blob_descriptor_fields: + def _validate_inline_stored_fields_input(self, data: pa.RecordBatch): + if not self.blob_inline_fields: return - from pypaimon.table.row.blob import BlobDescriptor + from 
pypaimon.table.row.blob import BlobDescriptor, BlobViewStruct for field_name in self.blob_descriptor_fields: if field_name not in data.schema.names: @@ -292,11 +313,38 @@ def _validate_descriptor_stored_fields_input(self, data: pa.RecordBatch): "BlobDescriptor." ) from e + for field_name in self.blob_view_fields: + if field_name not in data.schema.names: + continue + values = data.column(data.schema.get_field_index(field_name)).to_pylist() + for value in values: + if value is None: + continue + if hasattr(value, 'as_py'): + value = value.as_py() + if isinstance(value, str): + value = value.encode('utf-8') + if not isinstance(value, (bytes, bytearray)): + raise ValueError( + "blob-view-field requires blob field value to be a serialized " + "BlobViewStruct." + ) + try: + view_bytes = bytes(value) + view_struct = BlobViewStruct.deserialize(view_bytes) + if view_struct.serialize() != view_bytes: + raise ValueError("BlobViewStruct payload contains trailing bytes.") + except Exception as e: + raise ValueError( + "blob-view-field requires blob field value to be a serialized " + "BlobViewStruct." 
+ ) from e + @staticmethod - def _process_normal_data(data: pa.RecordBatch) -> pa.Table: + def _process_normal_data(data: pa.RecordBatch) -> Optional[pa.Table]: """Process normal data (similar to base DataWriter).""" if data is None or data.num_rows == 0: - return pa.Table.from_batches([]) + return None return pa.Table.from_batches([data]) @staticmethod @@ -316,30 +364,36 @@ def _should_roll_normal(self) -> bool: return current_size > self.target_file_size def _close_current_writers(self): - """Close both normal and blob writers and add blob metadata after normal metadata (Java behavior).""" - if self.pending_normal_data is None or self.pending_normal_data.num_rows == 0: - return - - # Close normal writer and get metadata - normal_meta = self._write_normal_data_to_file(self.pending_normal_data) + """Close normal, blob, and vector writers; add metadata in order: normal, blob, vector.""" + normal_meta = None + if self.pending_normal_data is not None and self.pending_normal_data.num_rows > 0: + normal_meta = self._write_normal_data_to_file(self.pending_normal_data) blob_metas = [] for blob_column in self.blob_file_column_names: writer_metas = self.blob_writers[blob_column].prepare_commit() - self._validate_consistency(normal_meta, writer_metas, blob_column) + if normal_meta is not None: + self._validate_consistency(normal_meta, writer_metas, blob_column) blob_metas.extend(writer_metas) - # Add normal file metadata first - self.committed_files.append(normal_meta) + vector_metas = [] + if self.vector_writer is not None: + vector_metas = self.vector_writer.prepare_commit() + self.vector_writer.committed_files.clear() + if vector_metas and normal_meta is not None: + self._validate_consistency(normal_meta, vector_metas, 'vector') - # Add blob file metadata after normal metadata + if normal_meta is not None: + self.committed_files.append(normal_meta) self.committed_files.extend(blob_metas) + self.committed_files.extend(vector_metas) - # Reset pending data 
self.pending_normal_data = None - logger.info(f"Closed both writers - normal: {normal_meta.file_name}, " - f"added {len(blob_metas)} blob file metadata after normal metadata") + if normal_meta is not None or blob_metas or vector_metas: + normal_name = normal_meta.file_name if normal_meta is not None else '' + logger.info(f"Closed writers - normal: {normal_name}, " + f"{len(blob_metas)} blob metas, {len(vector_metas)} vector metas") def _write_normal_data_to_file(self, data: pa.Table) -> Optional[DataFileMeta]: if data.num_rows == 0: @@ -359,6 +413,10 @@ def _write_normal_data_to_file(self, data: pa.Table) -> Optional[DataFileMeta]: self.file_io.write_lance(file_path, data) elif self.file_format == CoreOptions.FILE_FORMAT_VORTEX: self.file_io.write_vortex(file_path, data) + elif self.file_format == CoreOptions.FILE_FORMAT_MOSAIC: + self.file_io.write_mosaic(file_path, data) + elif self.file_format == CoreOptions.FILE_FORMAT_ROW: + self.file_io.write_row(file_path, data, zstd_level=self.zstd_level) else: raise ValueError(f"Unsupported file format: {self.file_format}") diff --git a/paimon-python/pypaimon/write/writer/format_row_writer.py b/paimon-python/pypaimon/write/writer/format_row_writer.py new file mode 100644 index 000000000000..f31c1eb58c4f --- /dev/null +++ b/paimon-python/pypaimon/write/writer/format_row_writer.py @@ -0,0 +1,408 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import re +import struct +from decimal import Decimal +from typing import Any, List + +import pyarrow as pa + +from pypaimon.common.delta_varint_compressor import DeltaVarintCompressor +from pypaimon.schema.data_types import ( + ArrayType, AtomicType, DataField, MapType, MultisetType, RowType, VectorType +) + +FOOTER_SIZE = 32 +MAGIC = 0x524F5753 # "ROWS" +VERSION = 1 +DEFAULT_BLOCK_SIZE = 65536 + + +class FormatRowWriter: + + def __init__(self, output_stream, fields: List[DataField], + block_size: int = DEFAULT_BLOCK_SIZE, zstd_level: int = 1): + self._out = output_stream + self._fields = fields + self._block_size = block_size + self._zstd_level = zstd_level + + self._block_buf = _BlockBuffer() + self._row_offsets: List[int] = [] + + self._block_compressed_sizes: List[int] = [] + self._block_uncompressed_sizes: List[int] = [] + self._block_row_starts: List[int] = [] + self._total_row_count = 0 + self._bytes_written = 0 + + def write_table(self, data: pa.Table): + columns = {field.name: data.column(field.name).to_pylist() + for field in self._fields if field.name in data.column_names} + num_rows = data.num_rows + + for row_idx in range(num_rows): + row_values = [columns[field.name][row_idx] for field in self._fields] + self._write_row(row_values) + self._total_row_count += 1 + + if self._block_buf.position + len(self._row_offsets) * 4 + 4 >= self._block_size: + self._flush_block() + + def close(self): + self._flush_block() + + index_offset = self._bytes_written + self._write_block_index() + index_length = self._bytes_written - 
index_offset + + self._write_footer(index_offset, index_length) + self._out.flush() + + def _write_row(self, values: List[Any]): + self._row_offsets.append(self._block_buf.position) + arity = len(self._fields) + header_size = (arity + 7) // 8 + + header_start = self._block_buf.position + self._block_buf.write_zeros(header_size) + + for i, field in enumerate(self._fields): + if values[i] is None: + byte_idx = header_start + i // 8 + self._block_buf.buffer[byte_idx] |= (1 << (i % 8)) + else: + _write_field(self._block_buf, values[i], field.type) + + def _flush_block(self): + if not self._row_offsets: + return + + import zstandard as zstd + + self._block_row_starts.append(self._total_row_count - len(self._row_offsets)) + + for offset in self._row_offsets: + self._block_buf.write_int_le(offset) + self._block_buf.write_int_le(len(self._row_offsets)) + + uncompressed = bytes(self._block_buf.buffer[:self._block_buf.position]) + self._block_uncompressed_sizes.append(len(uncompressed)) + + compressor = zstd.ZstdCompressor(level=self._zstd_level) + compressed = compressor.compress(uncompressed) + self._block_compressed_sizes.append(len(compressed)) + + self._out.write(compressed) + self._bytes_written += len(compressed) + + self._block_buf.reset() + self._row_offsets.clear() + + def _write_block_index(self): + encoded_compressed = DeltaVarintCompressor.compress(self._block_compressed_sizes) + encoded_uncompressed = DeltaVarintCompressor.compress(self._block_uncompressed_sizes) + encoded_row_starts = DeltaVarintCompressor.compress(self._block_row_starts) + + self._write_varint_prefixed(encoded_compressed) + self._write_varint_prefixed(encoded_uncompressed) + self._write_varint_prefixed(encoded_row_starts) + + def _write_varint_prefixed(self, data: bytes): + varint_bytes = _encode_var_int(len(data)) + self._out.write(varint_bytes) + self._out.write(data) + self._bytes_written += len(varint_bytes) + len(data) + + def _write_footer(self, index_offset: int, index_length: int): + 
buf = bytearray(FOOTER_SIZE) + struct.pack_into(' len(self.buffer): + new_size = max(len(self.buffer) * 2, required) + new_buf = bytearray(new_size) + new_buf[:self.position] = self.buffer[:self.position] + self.buffer = new_buf + + def write_zeros(self, count: int): + self._ensure_capacity(count) + for i in range(count): + self.buffer[self.position + i] = 0 + self.position += count + + def write_boolean(self, value: bool): + self._ensure_capacity(1) + self.buffer[self.position] = 1 if value else 0 + self.position += 1 + + def write_byte(self, value: int): + self._ensure_capacity(1) + self.buffer[self.position] = value & 0xFF + self.position += 1 + + def write_short_le(self, value: int): + self._ensure_capacity(2) + struct.pack_into('>= 7 + self.buffer[self.position] = value & 0x7F + self.position += 1 + + def write_bytes_with_length(self, data: bytes): + self.write_var_int(len(data)) + self._ensure_capacity(len(data)) + self.buffer[self.position:self.position + len(data)] = data + self.position += len(data) + + def write_string(self, value: str): + encoded = value.encode('utf-8') + self.write_bytes_with_length(encoded) + + +def _encode_var_int(value: int) -> bytes: + result = bytearray() + while (value & ~0x7F) != 0: + result.append((value & 0x7F) | 0x80) + value >>= 7 + result.append(value & 0x7F) + return bytes(result) + + +def _write_field(buf: _BlockBuffer, value: Any, data_type) -> None: + if isinstance(data_type, AtomicType): + type_name = data_type.type.upper() + if type_name == 'BOOLEAN': + buf.write_boolean(value) + elif type_name == 'TINYINT': + buf.write_byte(value) + elif type_name == 'SMALLINT': + buf.write_short_le(value) + elif type_name in ('INT', 'INTEGER'): + buf.write_int_le(value) + elif type_name == 'DATE': + if isinstance(value, datetime.date): + epoch = datetime.date(1970, 1, 1) + days = (value - epoch).days + buf.write_int_le(days) + else: + buf.write_int_le(value) + elif type_name == 'TIME' or (type_name.startswith('TIME') and not 
type_name.startswith('TIMESTAMP')): + if isinstance(value, datetime.time): + millis = (value.hour * 3600 + value.minute * 60 + value.second) * 1000 + value.microsecond // 1000 + buf.write_int_le(millis) + else: + buf.write_int_le(value) + elif type_name == 'BIGINT': + buf.write_long_le(value) + elif type_name == 'FLOAT': + buf.write_float_le(value) + elif type_name == 'DOUBLE': + buf.write_double_le(value) + elif type_name == 'STRING' or type_name.startswith('CHAR') or type_name.startswith('VARCHAR'): + buf.write_string(value) + elif type_name == 'BYTES' or type_name.startswith('BINARY') or type_name.startswith('VARBINARY'): + buf.write_bytes_with_length(value) + elif type_name == 'BLOB': + buf.write_bytes_with_length(value) + elif type_name.startswith('DECIMAL'): + precision, scale = _parse_decimal_params(type_name) + if precision <= 18: + if isinstance(value, Decimal): + unscaled = int(value * (10 ** scale)) + else: + unscaled = int(Decimal(str(value)) * (10 ** scale)) + buf.write_long_le(unscaled) + else: + if isinstance(value, Decimal): + unscaled = int(value * (10 ** scale)) + else: + unscaled = int(Decimal(str(value)) * (10 ** scale)) + raw = unscaled.to_bytes( + (unscaled.bit_length() + 8) // 8, byteorder='big', signed=True) + buf.write_bytes_with_length(raw) + elif type_name.startswith('TIMESTAMP'): + precision = _parse_timestamp_precision(type_name) + if isinstance(value, datetime.datetime): + epoch = datetime.datetime(1970, 1, 1, tzinfo=value.tzinfo) + delta = value - epoch + total_micros = int(delta.total_seconds() * 1_000_000) + if precision <= 3: + buf.write_long_le(total_micros // 1000) + else: + millis = total_micros // 1000 + nano_of_milli = (total_micros % 1000) * 1000 + buf.write_long_le(millis) + buf.write_var_int(nano_of_milli) + elif precision <= 3: + buf.write_long_le(value) + else: + if isinstance(value, int): + millis = value // 1000 + nano_of_milli = (value % 1000) * 1000 + else: + millis = int(value) // 1000 + nano_of_milli = (int(value) % 
1000) * 1000 + buf.write_long_le(millis) + buf.write_var_int(nano_of_milli) + elif type_name == 'VARIANT': + if isinstance(value, dict): + buf.write_bytes_with_length(value['value']) + buf.write_bytes_with_length(value['metadata']) + else: + buf.write_bytes_with_length(value.value if hasattr(value, 'value') else b'') + buf.write_bytes_with_length(value.metadata if hasattr(value, 'metadata') else b'') + else: + raise ValueError(f"Unsupported atomic type for writing: {type_name}") + + elif isinstance(data_type, ArrayType): + _write_array_elements(buf, value, data_type.element) + + elif isinstance(data_type, VectorType): + _write_vector(buf, value, data_type.element) + + elif isinstance(data_type, MapType): + if isinstance(value, dict): + keys = list(value.keys()) + values = list(value.values()) + else: + keys = [pair[0] for pair in value] + values = [pair[1] for pair in value] + _write_array_elements(buf, keys, data_type.key) + _write_array_elements(buf, values, data_type.value) + + elif isinstance(data_type, MultisetType): + if isinstance(value, dict): + keys = list(value.keys()) + counts = list(value.values()) + else: + keys = [pair[0] for pair in value] + counts = [pair[1] for pair in value] + _write_array_elements(buf, keys, data_type.element) + _write_array_elements(buf, counts, AtomicType("INT")) + + elif isinstance(data_type, RowType): + _write_nested_row(buf, value, data_type) + + else: + raise ValueError(f"Unsupported data type for writing: {data_type}") + + +def _write_array_elements(buf: _BlockBuffer, elements: list, element_type) -> None: + size = len(elements) + buf.write_var_int(size) + null_bitmap_bytes = (size + 7) // 8 + + null_start = buf.position + buf.write_zeros(null_bitmap_bytes) + + for i, elem in enumerate(elements): + if elem is None: + buf.buffer[null_start + i // 8] |= (1 << (i % 8)) + else: + _write_field(buf, elem, element_type) + + +def _write_vector(buf: _BlockBuffer, elements: list, element_type) -> None: + size = len(elements) + 
buf.write_var_int(size) + for elem in elements: + _write_field(buf, elem, element_type) + + +def _write_nested_row(buf: _BlockBuffer, value, row_type: RowType) -> None: + fields = row_type.fields + arity = len(fields) + header_size = (arity + 7) // 8 + + header_start = buf.position + buf.write_zeros(header_size) + + for i, field in enumerate(fields): + if isinstance(value, dict): + field_value = value.get(field.name) + else: + field_value = value[i] if i < len(value) else None + + if field_value is None: + buf.buffer[header_start + i // 8] |= (1 << (i % 8)) + else: + _write_field(buf, field_value, field.type) + + +def _parse_decimal_params(type_name: str) -> tuple: + match = re.fullmatch(r'DECIMAL\((\d+),\s*(\d+)\)', type_name) + if match: + return int(match.group(1)), int(match.group(2)) + match = re.fullmatch(r'DECIMAL\((\d+)\)', type_name) + if match: + return int(match.group(1)), 0 + return 10, 0 + + +def _parse_timestamp_precision(type_name: str) -> int: + match = re.search(r'\((\d+)\)', type_name) + if match: + return int(match.group(1)) + return 6 diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py index 6c6f292f575c..64f6003a0694 100644 --- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py +++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py @@ -15,22 +15,227 @@ # specific language governing permissions and limitations # under the License. 
+from typing import List, Union + import pyarrow as pa import pyarrow.compute as pc +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.read.reader.deduplicate_merge_function import \ + DeduplicateMergeFunction +from pypaimon.table.row.key_value import KeyValue from pypaimon.write.writer.data_writer import DataWriter class KeyValueDataWriter(DataWriter): - """Data writer for primary key tables with system fields and sorting.""" + """Data writer for primary key tables with system fields and sorting. + + Accumulates incoming batches in ``pending_data`` without sorting or + folding on the write path. Sort and ``MergeFunction``-based fold + are deferred to flush time (``_flush_all``), where the result is + roll-written into one or more data files. This enforces the LSM + "PK unique within a file" invariant the read-side + ``raw_convertible`` fast path relies on, while keeping per-write + cost bounded. + """ + + def __init__(self, table, partition, bucket, max_seq_number, + options=None, write_cols=None, merge_function=None): + super().__init__(table, partition, bucket, max_seq_number, + options, write_cols) + # Defaults to deduplicate so direct callers (tests / future code + # paths that don't go through FileStoreWrite) don't accidentally + # skip the merge step entirely. + self._merge_function = merge_function or DeduplicateMergeFunction() def _process_data(self, data: pa.RecordBatch) -> pa.Table: + # No sort here: sorting once at flush is strictly cheaper than + # per-batch sort + a final global sort. ``pending_data`` ends + # up as a concat of unsorted batches; ``_flush_all`` sorts it + # exactly once before folding. 
enhanced_data = self._add_system_fields(data) - return pa.Table.from_batches([self._sort_by_primary_key(enhanced_data)]) + return pa.Table.from_batches([enhanced_data]) def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table: - combined = pa.concat_tables([existing_data, new_data]) - return self._sort_by_primary_key(combined) + # Plain concat. Sort + fold both run inside ``_flush_all`` so + # N writes incur 1 sort instead of N sorts. + return pa.concat_tables([existing_data, new_data]) + + def prepare_commit(self) -> List[DataFileMeta]: + if self.pending_data is not None and self.pending_data.num_rows > 0: + self._flush_all() + # ``_flush_all`` leaves ``pending_data = None``, so super's + # prepare_commit just returns ``committed_files``. + return super().prepare_commit() + + def _check_and_roll_if_needed(self): + # Buffer overflowed target_file_size: sort + fold + roll-write + # the whole buffer as multiple files in one pass. Unlike the + # base class's slice loop, we never keep a slice remainder in + # ``pending_data`` -- flush empties the buffer outright. + if (self.pending_data is not None + and self.pending_data.num_rows > 0 + and self.pending_data.nbytes > self.target_file_size): + self._flush_all() + + def close(self): + # Override the base ``close`` because its straight + # ``_write_data_to_file(pending_data)`` would land an unsorted, + # un-folded buffer on disk -- violating the file-internal + # PK-unique invariant. Route the final flush through + # ``_flush_all`` so the contract holds even on the + # close-without-prepare_commit path. + try: + if self.pending_data is not None and self.pending_data.num_rows > 0: + self._flush_all() + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.warning( + "Exception occurs when closing writer. 
Cleaning up.", + exc_info=e) + self.abort() + raise e + finally: + self.pending_data = None + + def _flush_all(self) -> None: + """Sort + fold the entire buffer, then roll-write as files. + + On return, ``pending_data is None`` and every flushed chunk + has been recorded in ``committed_files``. The buffer is + always fully drained per flush: no slice remainder is + carried back into ``pending_data``. + """ + if self.pending_data is None or self.pending_data.num_rows == 0: + self.pending_data = None + return + sorted_data = self._sort_by_primary_key(self.pending_data) + folded = self._merge_pending_by_pk(sorted_data) + self.pending_data = None + if folded.num_rows == 0: + return + self._roll_write(folded) + + def _roll_write(self, data: pa.Table) -> None: + """Write ``data`` as one or more files, each <= target_file_size. + + ``data`` is required to be PK-unique (the fold guarantees + that), so any slice of it is also PK-unique -- splitting for + size does not violate the LSM file-internal invariant. + Reuses ``_find_optimal_split_point`` / ``_write_data_to_file`` + from the base class. + """ + while data.num_rows > 0: + if data.nbytes <= self.target_file_size: + self._write_data_to_file(data) + return + split_row = self._find_optimal_split_point( + data, self.target_file_size) + if split_row <= 0: + # Single row already exceeds target_file_size; nothing + # to gain from further slicing, write it as-is. + self._write_data_to_file(data) + return + self._write_data_to_file(data.slice(0, split_row)) + data = data.slice(split_row) + + def _merge_pending_by_pk(self, data: pa.Table) -> pa.Table: + """Fold same-PK runs in ``data`` using ``self._merge_function``. + + ``data`` is required to already be sorted by + ``(primary_key, _SEQUENCE_NUMBER)``. ``_flush_all`` is the + only caller and runs ``_sort_by_primary_key`` immediately + before this method, so the precondition holds. 
+ + NOTE(follow-up): the merge runs row-by-row over + ``data.to_pydict()`` / ``pa.Table.from_pydict``. Arrow types + with non-trivial Python representations (Decimal128 with + specific precision/scale, timestamps with timezone or + sub-millisecond units, durations, deeply nested structs) can + drift across this round-trip. A columnar merge implementation + would close the gap and is tracked as a follow-up; until + then, partial-update on those types should be avoided in + pypaimon. + """ + n = data.num_rows + if n < 2: + # Single-row buffer cannot have duplicates; sidestep the + # row-by-row pyarrow round-trip in the common streaming case. + return data + + col_names = data.schema.names + # ``to_pydict`` works on pyarrow >= 6 (Python 3.6 CI ships 6.0.1), + # unlike ``to_pylist`` which only landed in pyarrow 7. + col_dict = data.to_pydict() + rows = [{name: col_dict[name][i] for name in col_names} + for i in range(n)] + key_arity = len(self.trimmed_primary_keys) + # System fields sit at indices [key_arity, key_arity + 1] (the + # _SEQUENCE_NUMBER and _VALUE_KIND columns added by + # _add_system_fields). Everything to the right is the value side. + value_arity = len(col_names) - key_arity - 2 + + # Pool one ``KeyValue`` for the whole fold. Safe because: + # - ``DeduplicateMergeFunction.add`` stores the kv reference; the + # reused instance always carries the most recent ``replace``, + # which is exactly the "latest wins" the engine wants. + # - ``PartialUpdateMergeFunction.add`` also stores a reference, + # but ``get_result`` snapshots key + sequence into a fresh + # tuple before returning, so the consumed result is decoupled + # from any later ``replace`` on the pooled kv. + # - ``FirstRowMergeFunction.add`` ``copy()``s the first kv, so it + # keeps the first row rather than tracking later ``replace``s on + # the pooled kv (which would otherwise yield the last row). 
+ # This drops per-row ``KeyValue``/``OffsetRow`` allocations and + # the resulting GC churn on large buffers. + pooled_kv = KeyValue(key_arity, value_arity) + + merged_rows: List[dict] = [] + i = 0 + while i < n: + j = i + first_key = self._key_tuple(rows[i], col_names, key_arity) + while j < n and \ + self._key_tuple(rows[j], col_names, key_arity) == first_key: + j += 1 + run = rows[i:j] + self._merge_function.reset() + for r in run: + pooled_kv.replace(self._row_to_tuple(r, col_names)) + self._merge_function.add(pooled_kv) + result_kv = self._merge_function.get_result() + if result_kv is not None: + merged_rows.append( + self._kv_to_row(result_kv, col_names, + key_arity, value_arity)) + i = j + + if not merged_rows: + return data.slice(0, 0) + result_dict = {name: [r[name] for r in merged_rows] + for name in data.schema.names} + return pa.Table.from_pydict(result_dict, schema=data.schema) + + @staticmethod + def _key_tuple(row: dict, col_names: List[str], key_arity: int) -> tuple: + return tuple(row[col_names[i]] for i in range(key_arity)) + + @staticmethod + def _row_to_tuple(row: dict, col_names: List[str]) -> tuple: + return tuple(row[name] for name in col_names) + + @staticmethod + def _kv_to_row(kv: KeyValue, col_names: List[str], + key_arity: int, value_arity: int) -> dict: + out = {} + for i in range(key_arity): + out[col_names[i]] = kv.key.get_field(i) + out[col_names[key_arity]] = kv.sequence_number + out[col_names[key_arity + 1]] = kv.value_row_kind_byte + for i in range(value_arity): + out[col_names[key_arity + 2 + i]] = kv.value.get_field(i) + return out def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND.""" @@ -61,11 +266,15 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: return pa.RecordBatch.from_arrays(new_arrays, schema=pa.schema(new_fields)) - def _sort_by_primary_key(self, data: pa.RecordBatch) -> pa.RecordBatch: + def 
_sort_by_primary_key( + self, data: Union[pa.RecordBatch, pa.Table] + ) -> Union[pa.RecordBatch, pa.Table]: + # pc.sort_indices + .take work uniformly over RecordBatch and + # Table, so this serves both the per-batch entry path (legacy) + # and the buffer-wide sort path (used by ``_flush_all``). sort_keys = [(key, 'ascending') for key in self.trimmed_primary_keys] if '_SEQUENCE_NUMBER' in data.schema.names: sort_keys.append(('_SEQUENCE_NUMBER', 'ascending')) sorted_indices = pc.sort_indices(data, sort_keys=sort_keys) - sorted_batch = data.take(sorted_indices) - return sorted_batch + return data.take(sorted_indices) diff --git a/paimon-python/setup.py b/paimon-python/setup.py index aaba800fecb5..e4974bdc9b9e 100644 --- a/paimon-python/setup.py +++ b/paimon-python/setup.py @@ -174,6 +174,9 @@ def read_requirements(): 'vortex': [ 'vortex-data==0.70.0; python_version>="3.11"', ], + 'mosaic': [ + 'paimon-mosaic>=0.1.0', + ], 'lumina': [ 'lumina-data>=0.1.0' ], @@ -181,6 +184,9 @@ def read_requirements(): 'pypaimon-rust; python_version>="3.10"', 'datafusion>=52; python_version>="3.10"', ], + 'hdfs': [ + 'hdfs-native>=0.13,<1; python_version >= "3.10" and platform_system != "Windows"', + ], }, description="Apache Paimon Python API", long_description=long_description, diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 770bd8f802ba..ef1f68c09f02 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -25,13 +25,27 @@ import org.apache.spark.sql.connector.read.Scan class PaimonScanBuilder(val table: InnerTable) extends PaimonBaseScanBuilder { override def build(): Scan = { + val (actualTable, vectorSearch, fullTextSearch) = table match { + case vst: 
org.apache.paimon.table.VectorSearchTable => + val tableVectorSearch = Option(vst.vectorSearch()) + val vs = (tableVectorSearch, pushedVectorSearch) match { + case (Some(_), _) => tableVectorSearch + case (None, Some(_)) => pushedVectorSearch + case (None, None) => None + } + (vst.origin(), vs, None) + case ftst: org.apache.paimon.table.FullTextSearchTable => + (ftst.origin(), None, Option(ftst.fullTextSearch())) + case _ => (table, pushedVectorSearch, pushedFullTextSearch) + } PaimonScan( - table, + actualTable, requiredSchema, pushedPartitionFilters, pushedDataFilters, pushedLimit, pushedTopN, - pushedVectorSearch) + vectorSearch, + fullTextSearch) } } diff --git a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/PaimonScanBuilderTest.scala b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/PaimonScanBuilderTest.scala new file mode 100755 index 000000000000..bbd79c8d78c6 --- /dev/null +++ b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/PaimonScanBuilderTest.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark + +/** Tests for [[PaimonScanBuilder]] in spark-3.2 module. 
*/ +class PaimonScanBuilderTest extends PaimonSparkTestBase { + + test("PaimonScanBuilder: read vector table normally on spark-3.2") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id BIGINT, embs ARRAY) + |TBLPROPERTIES ( + | 'vector.file.format' = 'lance', + | 'vector-field' = 'embs', + | 'field.embs.vector-dim' = '3', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true' + |) + |""".stripMargin) + + var rows = spark.sql("SELECT id, embs FROM T ORDER BY id") + assert(rows.isEmpty) + rows = + spark.sql("select id, embs from vector_search('T', 'embs', array(1.0f, 2.0f, 3.0f), 5)") + assert(rows.isEmpty) + } + } +} diff --git a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala index d1515031e525..e7eb9a0d9516 100644 --- a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala +++ b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala @@ -18,4 +18,4 @@ package org.apache.paimon.spark.sql -class CopyIntoTest extends CopyIntoTestBase {} +class CopyIntoTest extends CopyIntoTestBase with CopyIntoOnErrorTest {} diff --git a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/DataFrameWriteTest.scala index cb449edb4ccb..ac5c6c3a178f 100644 --- a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/DataFrameWriteTest.scala +++ b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/DataFrameWriteTest.scala @@ -40,7 +40,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase { .option("bucket", "-1") .option("target-file-size", "256MB") .option("write.merge-schema", "true") - .option("write.merge-schema.explicit-cast", "true") .saveAsTable("test_ctas") val paimonTable = loadTable("test_ctas") @@ 
-53,7 +52,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase { // non-core options should not be here. Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema")) - Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema.explicit-cast")) } } diff --git a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala index d1515031e525..e7eb9a0d9516 100644 --- a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala +++ b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala @@ -18,4 +18,4 @@ package org.apache.paimon.spark.sql -class CopyIntoTest extends CopyIntoTestBase {} +class CopyIntoTest extends CopyIntoTestBase with CopyIntoOnErrorTest {} diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala index ab4a9bcd9dbf..d213f257c7f3 100644 --- a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala +++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala @@ -260,6 +260,7 @@ class PaimonSinkTest extends PaimonSparkTestBase with StreamTest { .writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .option("write.merge-schema", "true") + .option("write.merge-schema.type-widening", "true") .option("write.merge-schema.explicit-cast", "true") .format("paimon") .start(location) diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala index d1515031e525..e7eb9a0d9516 100644 --- 
a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala +++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala @@ -18,4 +18,4 @@ package org.apache.paimon.spark.sql -class CopyIntoTest extends CopyIntoTestBase {} +class CopyIntoTest extends CopyIntoTestBase with CopyIntoOnErrorTest {} diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala index d1515031e525..e7eb9a0d9516 100644 --- a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala +++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala @@ -18,4 +18,4 @@ package org.apache.paimon.spark.sql -class CopyIntoTest extends CopyIntoTestBase {} +class CopyIntoTest extends CopyIntoTestBase with CopyIntoOnErrorTest {} diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala index 9f96840a7788..60bfd244b2de 100644 --- a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala +++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala @@ -18,4 +18,21 @@ package org.apache.paimon.spark.sql -class RowTrackingTest extends RowTrackingTestBase {} +import org.apache.paimon.spark.SparkTable + +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations + +class RowTrackingTest extends RowTrackingTestBase { + + test("Row Tracking: Spark 3.5 keeps row-tracking tables on V1 DML path") { + withSparkSQLConf("spark.paimon.write.use-v2-write" -> "true") { + withTable("t", "rt") { + sql("CREATE TABLE t (id INT, data INT)") + 
assert(SparkTable.of(loadTable("t")).isInstanceOf[SupportsRowLevelOperations]) + + sql("CREATE TABLE rt (id INT, data INT) TBLPROPERTIES ('row-tracking.enabled' = 'true')") + assert(!SparkTable.of(loadTable("rt")).isInstanceOf[SupportsRowLevelOperations]) + } + } + } +} diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala new file mode 100644 index 000000000000..fe4833a03cb4 --- /dev/null +++ b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.catalyst.parser.extensions + +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.extensions.AbstractPaimonSparkSqlExtensionsParser +import org.apache.spark.sql.types.StructType + +class PaimonSpark4SqlExtensionsParser(override val delegate: ParserInterface) + extends AbstractPaimonSparkSqlExtensionsParser(delegate) + with ParserInterface { + + override def parseRoutineParam(sqlText: String): StructType = delegate.parseRoutineParam(sqlText) +} diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala index 492d64bbf5bf..ad6f5b95011a 100644 --- a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala +++ b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala @@ -27,16 +27,20 @@ import org.apache.paimon.spark.SparkTable import org.apache.paimon.spark.catalyst.analysis.PaimonRelation import org.apache.paimon.spark.catalyst.analysis.PaimonRelation.isPaimonTable import org.apache.paimon.spark.catalyst.analysis.PaimonUpdateTable.toColumn +import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand import org.apache.paimon.spark.util.ScanPlanHelper.createNewScanPlan import org.apache.paimon.table.FileStoreTable import org.apache.paimon.table.sink.{CommitMessage, CommitMessageImpl} import org.apache.paimon.table.source.DataSplit +import org.apache.paimon.table.source.snapshot.SnapshotReader +import org.apache.paimon.types.RowType +import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.PaimonUtils._ import 
org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolver -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Expression, ExprId, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, AttributeReference, EqualTo, Expression, ExprId, Literal, PythonUDF, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter} import org.apache.spark.sql.catalyst.plans.logical._ @@ -60,7 +64,9 @@ case class MergeIntoPaimonDataEvolutionTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction]) extends PaimonLeafRunnableCommand - with WithFileStoreTable { + with WithFileStoreTable + with ExpressionHelper + with Logging { private lazy val writer = PaimonSparkWriter(table) @@ -136,12 +142,16 @@ case class MergeIntoPaimonDataEvolutionTable( lazy val tableSchema: StructType = v2Table.schema override def run(sparkSession: SparkSession): Seq[Row] = { + // Persist the schema that the analyzer evolved in memory (commit deferred to execution). 
+ SchemaEvolutionHelper.commitEvolvedSchemaAtExecution(table, targetRelation, sparkSession) invokeMergeInto(sparkSession) Seq.empty[Row] } private def invokeMergeInto(sparkSession: SparkSession): Unit = { - val plan = table.newSnapshotReader().read() + val snapshotReader = table.newSnapshotReader() + pushDownMergePartitionFilter(snapshotReader) + val plan = snapshotReader.read() val tableSplits: Seq[DataSplit] = plan .splits() .asScala @@ -174,44 +184,114 @@ case class MergeIntoPaimonDataEvolutionTable( map.toMap } - // step 1: find the related data splits, make it target file plan - val dataSplits: Seq[DataSplit] = - targetRelatedSplits(sparkSession, tableSplits, firstRowIds, firstRowIdToBlobFirstRowIds) - val touchedFileTargetRelation = - createNewScanPlan(dataSplits, targetRelation) - - // step 2: invoke update action - val updateCommit = - if (matchedActions.nonEmpty) { - val updateResult = - updateActionInvoke(dataSplits, sparkSession, touchedFileTargetRelation, firstRowIds) - checkUpdateResult(updateResult) - } else Nil - - // step 3: invoke insert action - val insertCommit = - if (notMatchedActions.nonEmpty) - insertActionInvoke(sparkSession, touchedFileTargetRelation) - else Nil - - if (plan.snapshotId() != null) { - writer.rowIdCheckConflict(plan.snapshotId()) + val persistSourceDss: Option[Dataset[Row]] = + if ( + table.coreOptions().dataEvolutionMergeIntoSourcePersist() + && (matchedActions.nonEmpty || notMatchedActions.nonEmpty) + ) { + val dss = createDataset(sparkSession, sourceTable) + dss.persist() + Some(dss) + } else { + None + } + + try { + // step 1: find the related data splits, make it target file plan + val dataSplits: Seq[DataSplit] = targetRelatedSplits( + sparkSession, + tableSplits, + firstRowIds, + firstRowIdToBlobFirstRowIds, + persistSourceDss) + val touchedFileTargetRelation = + createNewScanPlan(dataSplits, targetRelation) + + // step 2: invoke update action + val updateCommit = + if (matchedActions.nonEmpty) { + val updateResult = + 
updateActionInvoke( + dataSplits, + sparkSession, + touchedFileTargetRelation, + firstRowIds, + persistSourceDss) + checkUpdateResult(updateResult) + } else Nil + + // step 3: invoke insert action + val insertCommit = + if (notMatchedActions.nonEmpty) + insertActionInvoke(sparkSession, touchedFileTargetRelation, persistSourceDss) + else Nil + + if (plan.snapshotId() != null) { + writer.rowIdCheckConflict(plan.snapshotId()) + } + writer.commit(updateCommit ++ insertCommit) + } finally { + if (persistSourceDss.isDefined) { + persistSourceDss.get.unpersist(blocking = false) + } + } + } + + private def pushDownMergePartitionFilter(snapshotReader: SnapshotReader): Unit = { + val partitionRowType = table.schema().logicalPartitionType() + if (partitionRowType.getFieldCount == 0) { + return + } + + // matchedCondition comes from MergeIntoTable.mergeCondition, which is the MERGE ON condition. + val partitionPredicates = getExpressionOnlyRelated(matchedCondition, targetTable) + .map(splitConjunctivePredicates) + .map(extractMergePartitionFilters(_, partitionRowType)) + .getOrElse(Seq.empty) + + if (partitionPredicates.nonEmpty) { + val filter = convertConditionToPaimonPredicate( + partitionPredicates.reduce(And), + targetRelation.output, + rowType, + ignorePartialFailure = true) + filter.foreach(snapshotReader.withFilter) + } + } + + private def extractMergePartitionFilters( + filters: Seq[Expression], + partitionRowType: RowType): Seq[Expression] = { + val partitionColumns = partitionRowType.getFieldNames.asScala.toSet + filters.filter { + f => + f.deterministic && + f.references.forall(attr => partitionColumns.exists(_.equalsIgnoreCase(attr.name))) && + !SubqueryExpression.hasSubquery(f) && + f.collect { case _: PythonUDF => true }.isEmpty } - writer.commit(updateCommit ++ insertCommit) } private def targetRelatedSplits( sparkSession: SparkSession, tableSplits: Seq[DataSplit], firstRowIds: immutable.IndexedSeq[Long], - firstRowIdToBlobFirstRowIds: Map[Long, List[Long]]): 
Seq[DataSplit] = { + firstRowIdToBlobFirstRowIds: Map[Long, List[Long]], + persistSourceDss: Option[Dataset[Row]]): Seq[DataSplit] = { // Self-Merge shortcut: // In Self-Merge mode, every row in the table may be updated, so we scan all splits. if (isSelfMergeOnRowId) { return tableSplits } - val sourceDss = createDataset(sparkSession, sourceTable) + if (!table.coreOptions().dataEvolutionMergeIntoFilePruning()) { + logInfo( + "Skip file-level pruning for MergeInto partial column update on data-evolution table " + + s"${table.name()}.") + return tableSplits + } + + val sourceDss = persistSourceDss.getOrElse(createDataset(sparkSession, sourceTable)) val firstRowIdsTouched = extractSourceRowIdMapping match { case Some(sourceRowIdAttr) => @@ -248,7 +328,8 @@ case class MergeIntoPaimonDataEvolutionTable( dataSplits: Seq[DataSplit], sparkSession: SparkSession, touchedFileTargetRelation: DataSourceV2Relation, - firstRowIds: immutable.IndexedSeq[Long]): Seq[CommitMessage] = { + firstRowIds: immutable.IndexedSeq[Long], + persistSourceDss: Option[Dataset[Row]]): Seq[CommitMessage] = { val mergeFields = extractFields(matchedCondition) val allFields = mutable.SortedSet.empty[AttributeReference]( (o1, o2) => { @@ -371,7 +452,8 @@ case class MergeIntoPaimonDataEvolutionTable( val sourceTableProjExprs = allReadFieldsOnSource.toSeq :+ Alias(TrueLiteral, ROW_FROM_SOURCE)() - val sourceTableProj = Project(sourceTableProjExprs, sourceTable) + val sourceChild = persistSourceDss.map(_.queryExecution.logical).getOrElse(sourceTable) + val sourceTableProj = Project(sourceTableProjExprs, sourceChild) val joinPlan = Join(targetTableProj, sourceTableProj, LeftOuter, Some(matchedCondition), JoinHint.NONE) @@ -414,16 +496,18 @@ case class MergeIntoPaimonDataEvolutionTable( private def insertActionInvoke( sparkSession: SparkSession, - touchedFileTargetRelation: DataSourceV2Relation): Seq[CommitMessage] = { + touchedFileTargetRelation: DataSourceV2Relation, + persistSourceDss: 
Option[Dataset[Row]]): Seq[CommitMessage] = { val mergeFields = extractFields(matchedCondition) val allReadFieldsOnTarget = mergeFields.filter(field => targetTable.output.exists(attr => attr.equals(field))) val targetReadPlan = touchedFileTargetRelation.copy(targetRelation.table, allReadFieldsOnTarget.toSeq) + val sourceReadPlan = persistSourceDss.map(_.queryExecution.logical).getOrElse(sourceTable) val joinPlan = - Join(sourceTable, targetReadPlan, LeftAnti, Some(matchedCondition), JoinHint.NONE) + Join(sourceReadPlan, targetReadPlan, LeftAnti, Some(matchedCondition), JoinHint.NONE) // merge rows as there are multiple not matched actions val mergeRows = MergeRows( diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala index 1822929854e5..89689e108c45 100644 --- a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala +++ b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala @@ -71,6 +71,8 @@ case class MergeIntoPaimonTable( override def run(sparkSession: SparkSession): Seq[Row] = { // Avoid that more than one source rows match the same target row. checkMatchRationality(sparkSession) + // Persist the schema that the analyzer evolved in memory (commit deferred to execution). 
+ SchemaEvolutionHelper.commitEvolvedSchemaAtExecution(table, relation, sparkSession) val commitMessages = if (withPrimaryKeys) { performMergeForPkTable(sparkSession) } else { diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala deleted file mode 100644 index 3529944f37ef..000000000000 --- a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala +++ /dev/null @@ -1,468 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.parser.extensions - -import org.apache.paimon.spark.SparkProcedures - -import org.antlr.v4.runtime._ -import org.antlr.v4.runtime.atn.PredictionMode -import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} -import org.antlr.v4.runtime.tree.TerminalNodeImpl -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, PaimonSparkSession, SparkSession} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} -import org.apache.spark.sql.catalyst.parser.extensions.PaimonSqlExtensionsParser.{NonReservedContext, QuotedIdentifierContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.VariableSubstitution -import org.apache.spark.sql.types.{DataType, StructType} - -import java.util.Locale - -import scala.collection.JavaConverters._ - -/* This file is based on source code from the Iceberg Project (http://iceberg.apache.org/), licensed by the Apache - * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. */ - -/** - * The implementation of [[ParserInterface]] that parsers the sql extension. - * - *

    Most of the content of this class is referenced from Iceberg's - * IcebergSparkSqlExtensionsParser. - * - * @param delegate - * The extension parser. - */ -// Keep this class in the Spark 4.0 module so it is compiled against Spark 4.0's ParserInterface. -abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterface) - extends org.apache.spark.sql.catalyst.parser.ParserInterface - with Logging { - - private lazy val substitutor = new VariableSubstitution() - private lazy val astBuilder = new PaimonSqlExtensionsAstBuilder(delegate) - private val nonReservedIdentifierTokenTypes = Set( - PaimonSqlExtensionsParser.ALTER, - PaimonSqlExtensionsParser.AS, - PaimonSqlExtensionsParser.CALL, - PaimonSqlExtensionsParser.CREATE, - PaimonSqlExtensionsParser.DAYS, - PaimonSqlExtensionsParser.DELETE, - PaimonSqlExtensionsParser.EXISTS, - PaimonSqlExtensionsParser.HOURS, - PaimonSqlExtensionsParser.IF, - PaimonSqlExtensionsParser.LIKE, - PaimonSqlExtensionsParser.NOT, - PaimonSqlExtensionsParser.OF, - PaimonSqlExtensionsParser.OR, - PaimonSqlExtensionsParser.TABLE, - PaimonSqlExtensionsParser.REPLACE, - PaimonSqlExtensionsParser.RETAIN, - PaimonSqlExtensionsParser.VERSION, - PaimonSqlExtensionsParser.TAG, - PaimonSqlExtensionsParser.TRUE, - PaimonSqlExtensionsParser.FALSE, - PaimonSqlExtensionsParser.MAP, - PaimonSqlExtensionsParser.COPY, - PaimonSqlExtensionsParser.INTO, - PaimonSqlExtensionsParser.FROM, - PaimonSqlExtensionsParser.FILE_FORMAT, - PaimonSqlExtensionsParser.PATTERN, - PaimonSqlExtensionsParser.FORCE, - PaimonSqlExtensionsParser.ON_ERROR, - PaimonSqlExtensionsParser.ABORT_STATEMENT, - PaimonSqlExtensionsParser.OVERWRITE, - PaimonSqlExtensionsParser.CSV - ) - - /** Parses a string to a LogicalPlan. 
*/ - override def parsePlan(sqlText: String): LogicalPlan = { - val sqlTextAfterSubstitution = substitutor.substitute(sqlText) - if (isPaimonCommand(sqlTextAfterSubstitution)) { - parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) - .asInstanceOf[LogicalPlan] - } else { - var plan = - try { - delegate.parsePlan(sqlText) - } catch { - case _: ParseException if maybeCatalogCreateTableLike(sqlTextAfterSubstitution) => - parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) - .asInstanceOf[LogicalPlan] - } - val sparkSession = PaimonSparkSession.active - parserRules(sparkSession).foreach( - rule => { - plan = rule.apply(plan) - }) - plan - } - } - - private def parserRules(sparkSession: SparkSession): Seq[Rule[LogicalPlan]] = { - Seq( - RewritePaimonViewCommands(sparkSession), - RewritePaimonFunctionCommands(sparkSession), - RewriteCreateTableLikeCommand(sparkSession), - RewriteSparkDDLCommands(sparkSession) - ) - } - - /** Parses a string to an Expression. */ - override def parseExpression(sqlText: String): Expression = - delegate.parseExpression(sqlText) - - /** Parses a string to a TableIdentifier. */ - override def parseTableIdentifier(sqlText: String): TableIdentifier = - delegate.parseTableIdentifier(sqlText) - - /** Parses a string to a FunctionIdentifier. */ - override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = - delegate.parseFunctionIdentifier(sqlText) - - /** - * Creates StructType for a given SQL string, which is a comma separated list of field definitions - * which will preserve the correct Hive metadata. - */ - override def parseTableSchema(sqlText: String): StructType = - delegate.parseTableSchema(sqlText) - - /** Parses a string to a DataType. */ - override def parseDataType(sqlText: String): DataType = - delegate.parseDataType(sqlText) - - /** Parses a string to a multi-part identifier. 
*/ - override def parseMultipartIdentifier(sqlText: String): Seq[String] = - delegate.parseMultipartIdentifier(sqlText) - - /** Returns whether SQL text is command. */ - private def isPaimonCommand(sqlText: String): Boolean = { - val normalized = sqlText - .toLowerCase(Locale.ROOT) - .trim() - .replaceAll("--.*?\\n", " ") - .replaceAll("\\s+", " ") - .replaceAll("/\\*.*?\\*/", " ") - .replaceAll("`", "") - .trim() - isPaimonProcedure(normalized) || isTagRefDdl(normalized) || isCopyInto(normalized) - } - - // All builtin paimon procedures are under the 'sys' namespace - private def isPaimonProcedure(normalized: String): Boolean = { - normalized.startsWith("call") && - SparkProcedures.names().asScala.map("sys." + _).exists(normalized.contains) - } - - private def isTagRefDdl(normalized: String): Boolean = { - normalized.startsWith("show tags") || - (normalized.startsWith("alter table") && - (normalized.contains("create tag") || - normalized.contains("replace tag") || - normalized.contains("rename tag") || - normalized.contains("delete tag"))) - } - - private def isCopyInto(normalized: String): Boolean = { - normalized.startsWith("copy into") - } - - /** - * Cheap token-level check for `CREATE TABLE [IF NOT EXISTS] x.y[.z] LIKE ...` shape. Used as a - * gate for the Paimon parser fallback when the delegate parser rejects a catalog-qualified CREATE - * TABLE LIKE statement. 
- */ - private def maybeCatalogCreateTableLike(sqlText: String): Boolean = { - if (org.apache.spark.SPARK_VERSION < "3.4") { - return false - } - if (!startsWithCreateTable(sqlText)) { - return false - } - - tokenStream(sqlText) match { - case Some(tokens) => maybeCreateTableLike(tokens) - case None => false - } - } - - private def tokenStream(sqlText: String): Option[CommonTokenStream] = { - try { - val lexer = new PaimonSqlExtensionsLexer( - new UpperCaseCharStream(CharStreams.fromString(sqlText))) - lexer.removeErrorListeners() - lexer.addErrorListener(PaimonParseErrorListener) - - val tokens = new CommonTokenStream(lexer) - tokens.fill() - Some(tokens) - } catch { - case _: PaimonParseException => None - } - } - - private def maybeCreateTableLike(tokenStream: CommonTokenStream): Boolean = { - val tokens = tokenStream.getTokens.asScala - .filter(token => token.getChannel == Token.DEFAULT_CHANNEL) - .filterNot(token => token.getType == Token.EOF) - - if (tokens.length < 5) return false - if (tokens(0).getType != PaimonSqlExtensionsParser.CREATE) return false - if (tokens(1).getType != PaimonSqlExtensionsParser.TABLE) return false - - var idx = 2 - if ( - idx + 2 < tokens.length && - tokens(idx).getType == PaimonSqlExtensionsParser.IF && - tokens(idx + 1).getType == PaimonSqlExtensionsParser.NOT && - tokens(idx + 2).getType == PaimonSqlExtensionsParser.EXISTS - ) { - idx += 3 - } - - if (idx >= tokens.length || !isIdentifierToken(tokens(idx))) return false - idx += 1 - - while ( - idx + 1 < tokens.length && - tokens(idx).getText == "." 
&& - isIdentifierToken(tokens(idx + 1)) - ) { - idx += 2 - } - - idx < tokens.length && tokens(idx).getType == PaimonSqlExtensionsParser.LIKE - } - - private def isIdentifierToken(token: Token): Boolean = { - token.getType == PaimonSqlExtensionsParser.IDENTIFIER || - token.getType == PaimonSqlExtensionsParser.BACKQUOTED_IDENTIFIER || - nonReservedIdentifierTokenTypes.contains(token.getType) - } - - private def startsWithCreateTable(sqlText: String): Boolean = { - val createIndex = skipWhitespaceAndComments(sqlText, 0) - if (!matchesWord(sqlText, createIndex, "create")) { - return false - } - - val tableIndex = skipWhitespaceAndComments(sqlText, createIndex + "create".length) - matchesWord(sqlText, tableIndex, "table") - } - - private def skipWhitespaceAndComments(sqlText: String, start: Int): Int = { - var index = start - var continue = true - - while (continue) { - while (index < sqlText.length && sqlText.charAt(index).isWhitespace) { - index += 1 - } - - if ( - index + 1 < sqlText.length && - sqlText.charAt(index) == '-' && - sqlText.charAt(index + 1) == '-' - ) { - index += 2 - while ( - index < sqlText.length && - sqlText.charAt(index) != '\n' && - sqlText.charAt(index) != '\r' - ) { - index += 1 - } - } else if ( - index + 1 < sqlText.length && - sqlText.charAt(index) == '/' && - sqlText.charAt(index + 1) == '*' - ) { - val close = sqlText.indexOf("*/", index + 2) - index = if (close >= 0) close + 2 else sqlText.length - } else { - continue = false - } - } - - index - } - - private def matchesWord(sqlText: String, index: Int, word: String): Boolean = { - index + word.length <= sqlText.length && - sqlText.regionMatches(true, index, word, 0, word.length) && - (index + word.length == sqlText.length || - !isIdentifierPart(sqlText.charAt(index + word.length))) - } - - private def isIdentifierPart(char: Char): Boolean = { - char.isLetterOrDigit || char == '_' - } - - protected def parse[T](command: String)(toResult: PaimonSqlExtensionsParser => T): T = { - val lexer 
= new PaimonSqlExtensionsLexer( - new UpperCaseCharStream(CharStreams.fromString(command))) - lexer.removeErrorListeners() - lexer.addErrorListener(PaimonParseErrorListener) - - val tokenStream = new CommonTokenStream(lexer) - val parser = new PaimonSqlExtensionsParser(tokenStream) - parser.addParseListener(PaimonSqlExtensionsPostProcessor) - parser.removeErrorListeners() - parser.addErrorListener(PaimonParseErrorListener) - - try { - try { - parser.getInterpreter.setPredictionMode(PredictionMode.SLL) - toResult(parser) - } catch { - case _: ParseCancellationException => - tokenStream.seek(0) - parser.reset() - parser.getInterpreter.setPredictionMode(PredictionMode.LL) - toResult(parser) - } - } catch { - case e: PaimonParseException if e.command.isDefined => - throw e - case e: PaimonParseException => - throw e.withCommand(command) - case e: AnalysisException => - val position = Origin(e.line, e.startPosition) - throw new PaimonParseException(Option(command), e.message, position, position) - } - } - - def parseQuery(sqlText: String): LogicalPlan = - parsePlan(sqlText) -} - -/* Copied from Apache Spark's to avoid dependency on Spark Internals */ -class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { - override def consume(): Unit = wrapped.consume() - override def getSourceName: String = wrapped.getSourceName - override def index(): Int = wrapped.index - override def mark(): Int = wrapped.mark - override def release(marker: Int): Unit = wrapped.release(marker) - override def seek(where: Int): Unit = wrapped.seek(where) - override def size(): Int = wrapped.size - - override def getText(interval: Interval): String = wrapped.getText(interval) - - // scalastyle:off - override def LA(i: Int): Int = { - val la = wrapped.LA(i) - if (la == 0 || la == IntStream.EOF) la - else Character.toUpperCase(la) - } - // scalastyle:on -} - -/** The post-processor validates & cleans-up the parse tree during the parse process. 
*/ -case object PaimonSqlExtensionsPostProcessor extends PaimonSqlExtensionsBaseListener { - - /** Removes the back ticks from an Identifier. */ - override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { - replaceTokenByIdentifier(ctx, 1) { - token => - // Remove the double back ticks in the string. - token.setText(token.getText.replace("``", "`")) - token - } - } - - /** Treats non-reserved keywords as Identifiers. */ - override def exitNonReserved(ctx: NonReservedContext): Unit = { - replaceTokenByIdentifier(ctx, 0)(identity) - } - - private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)( - f: CommonToken => CommonToken = identity): Unit = { - val parent = ctx.getParent - parent.removeLastChild() - val token = ctx.getChild(0).getPayload.asInstanceOf[Token] - val newToken = new CommonToken( - new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), - PaimonSqlExtensionsParser.IDENTIFIER, - token.getChannel, - token.getStartIndex + stripMargins, - token.getStopIndex - stripMargins - ) - parent.addChild(new TerminalNodeImpl(f(newToken))) - } -} - -/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */ -case object PaimonParseErrorListener extends BaseErrorListener { - override def syntaxError( - recognizer: Recognizer[_, _], - offendingSymbol: scala.Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException): Unit = { - val (start, stop) = offendingSymbol match { - case token: CommonToken => - val start = Origin(Some(line), Some(token.getCharPositionInLine)) - val length = token.getStopIndex - token.getStartIndex + 1 - val stop = Origin(Some(line), Some(token.getCharPositionInLine + length)) - (start, stop) - case _ => - val start = Origin(Some(line), Some(charPositionInLine)) - (start, start) - } - throw new PaimonParseException(None, msg, start, stop) - } -} - -/** - * Copied from Apache Spark [[ParseException]], it contains fields and an 
extended error message - * that make reporting and diagnosing errors easier. - */ -class PaimonParseException( - val command: Option[String], - message: String, - start: Origin, - stop: Origin) - extends Exception { - - override def getMessage: String = { - val builder = new StringBuilder - builder ++= "\n" ++= message - start match { - case Origin(Some(l), Some(p), Some(_), Some(_), Some(_), Some(_), Some(_)) => - builder ++= s"(line $l, pos $p)\n" - command.foreach { - cmd => - val (above, below) = cmd.split("\n").splitAt(l) - builder ++= "\n== SQL ==\n" - above.foreach(builder ++= _ += '\n') - builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" - below.foreach(builder ++= _ += '\n') - } - case _ => - command.foreach(cmd => builder ++= "\n== SQL ==\n" ++= cmd) - } - builder.toString - } - - def withCommand(cmd: String): PaimonParseException = - new PaimonParseException(Option(cmd), message, start, stop) -} diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala index d1515031e525..e7eb9a0d9516 100644 --- a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala +++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala @@ -18,4 +18,4 @@ package org.apache.paimon.spark.sql -class CopyIntoTest extends CopyIntoTestBase {} +class CopyIntoTest extends CopyIntoTestBase with CopyIntoOnErrorTest {} diff --git a/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala b/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala index d1515031e525..e7eb9a0d9516 100644 --- a/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala +++ b/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTest.scala @@ -18,4 +18,4 @@ 
package org.apache.paimon.spark.sql -class CopyIntoTest extends CopyIntoTestBase {} +class CopyIntoTest extends CopyIntoTestBase with CopyIntoOnErrorTest {} diff --git a/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala b/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala index 9f96840a7788..382aa1e77880 100644 --- a/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala +++ b/paimon-spark/paimon-spark-4.1/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTest.scala @@ -18,4 +18,111 @@ package org.apache.paimon.spark.sql -class RowTrackingTest extends RowTrackingTestBase {} +import org.apache.paimon.spark.SparkTable +import org.apache.paimon.spark.schema.PaimonMetadataColumn + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.types.Metadata + +class RowTrackingTest extends RowTrackingTestBase { + + test("Row Tracking: metadata columns expose Spark preserve flags") { + val rowIdMetadata = Metadata.fromJson(PaimonMetadataColumn.ROW_ID.metadataInJSON()) + assert(rowIdMetadata.getBoolean("__preserve_on_delete")) + assert(rowIdMetadata.getBoolean("__preserve_on_update")) + assert(!rowIdMetadata.getBoolean("__preserve_on_reinsert")) + + val sequenceNumberMetadata = + Metadata.fromJson(PaimonMetadataColumn.SEQUENCE_NUMBER.metadataInJSON()) + assert(sequenceNumberMetadata.getBoolean("__preserve_on_delete")) + assert(!sequenceNumberMetadata.getBoolean("__preserve_on_update")) + assert(!sequenceNumberMetadata.getBoolean("__preserve_on_reinsert")) + } + + test("Row Tracking: Spark 4.1 uses V2 copy-on-write for DML") { + withSparkSQLConf("spark.paimon.write.use-v2-write" -> "true") { + withTable("s", "t") { + sql("CREATE TABLE t (id INT, data INT) TBLPROPERTIES ('row-tracking.enabled' = 'true')") + sql("INSERT INTO t VALUES (1, 1), (2, 2)") + sql("INSERT INTO t 
VALUES (3, 3), (4, 4)") + + assertPlanContains("DELETE FROM t WHERE id = 2", "ReplaceData") + sql("DELETE FROM t WHERE id = 2") + + assertPlanContains("UPDATE t SET data = 30 WHERE id = 3", "ReplaceData") + sql("UPDATE t SET data = 30 WHERE id = 3") + + sql("CREATE TABLE s (id INT, data INT)") + sql("INSERT INTO s VALUES (3, 300), (5, 500)") + assertPlanContains( + """ + |MERGE INTO t + |USING s + |ON t.id = s.id + |WHEN MATCHED THEN UPDATE SET data = s.data + |WHEN NOT MATCHED THEN INSERT * + |""".stripMargin, + "ReplaceData" + ) + sql(""" + |MERGE INTO t + |USING s + |ON t.id = s.id + |WHEN MATCHED THEN UPDATE SET data = s.data + |WHEN NOT MATCHED THEN INSERT * + |""".stripMargin) + + checkAnswer( + sql("SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM t ORDER BY id"), + Seq(Row(1, 1, 0, 1), Row(3, 300, 2, 5), Row(4, 4, 3, 2), Row(5, 500, 4, 5)) + ) + } + } + } + + test("Row Tracking: nested CHAR columns do not expose V2 row-level capability") { + withSparkSQLConf("spark.paimon.write.use-v2-write" -> "true") { + withTable("t") { + sql(""" + |CREATE TABLE t ( + | id INT, + | info STRUCT + |) TBLPROPERTIES ('row-tracking.enabled' = 'true') + |""".stripMargin) + + assert(!SparkTable.of(loadTable("t")).isInstanceOf[SupportsRowLevelOperations]) + } + } + } + + test("Row Tracking: Spark 4.1 restores metadata-only delete fast path") { + withSparkSQLConf("spark.paimon.write.use-v2-write" -> "true") { + withTable("t") { + sql(""" + |CREATE TABLE t (id INT, data INT, dt STRING) + |PARTITIONED BY (dt) + |TBLPROPERTIES ('row-tracking.enabled' = 'true') + |""".stripMargin) + sql("INSERT INTO t VALUES (1, 1, 'p1'), (2, 2, 'p1'), (3, 3, 'p2')") + + assertPlanContains("DELETE FROM t WHERE dt = 'p1'", "DeleteFromPaimonTableCommand") + sql("DELETE FROM t WHERE dt = 'p1'") + + checkAnswer( + sql("SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM t ORDER BY id"), + Seq(Row(3, 3, "p2", 0, 1)) + ) + } + } + } + + private def assertPlanContains(sqlText: String, fragment: String): Unit = { + val plan = 
explain(sqlText) + assert(plan.contains(fragment), plan) + } + + private def explain(sqlText: String): String = { + sql(s"EXPLAIN EXTENDED $sqlText").collect().map(_.getString(0)).mkString("\n") + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 b/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 index 12a5bc8c51b6..620a2bb95abc 100644 --- a/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 +++ b/paimon-spark/paimon-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/PaimonSqlExtensions.g4 @@ -145,7 +145,7 @@ forceClause ; onErrorClause - : ON_ERROR '=' ABORT_STATEMENT + : ON_ERROR '=' (ABORT_STATEMENT | CONTINUE | SKIP_FILE) ; overwriteClause @@ -203,8 +203,10 @@ nonReserved | NOT | OF | OR | TABLE | REPLACE | RETAIN | VERSION | TAG | TRUE | FALSE | MAP - | COPY | INTO | FROM | FILE_FORMAT | PATTERN | FORCE | ON_ERROR | ABORT_STATEMENT | OVERWRITE + | COPY | INTO | FROM | FILE_FORMAT | PATTERN | FORCE | ON_ERROR | ABORT_STATEMENT | CONTINUE | SKIP_FILE | OVERWRITE | CSV + | JSON + | PARQUET ; ALTER: 'ALTER'; @@ -244,8 +246,12 @@ PATTERN: 'PATTERN'; FORCE: 'FORCE'; ON_ERROR: 'ON_ERROR'; ABORT_STATEMENT: 'ABORT_STATEMENT'; +CONTINUE: 'CONTINUE'; +SKIP_FILE: 'SKIP_FILE'; OVERWRITE: 'OVERWRITE'; CSV: 'CSV'; +JSON: 'JSON'; +PARQUET: 'PARQUET'; PLUS: '+'; MINUS: '-'; diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java index 283077430e19..5910c6074450 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java +++ 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java @@ -25,6 +25,7 @@ import org.apache.paimon.types.DataType; import org.apache.paimon.types.DataTypeChecks; import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VectorType; import org.apache.paimon.utils.InternalRowUtils; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -172,7 +173,13 @@ public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numF @Override public ArrayData getArray(int ordinal) { - return fromPaimon(row.getArray(ordinal), (ArrayType) rowType.getTypeAt(ordinal)); + DataType type = rowType.getTypeAt(ordinal); + if (type instanceof ArrayType) { + return fromPaimon(row.getArray(ordinal), (ArrayType) type); + } else if (type instanceof VectorType) { + return DataConverter.fromPaimon(row.getVector(ordinal), (VectorType) type); + } + throw new UnsupportedOperationException("Not an array type: " + type); } @Override diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/DataConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/DataConverter.java index 0b5ea899476e..5c8026f461df 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/DataConverter.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/DataConverter.java @@ -22,6 +22,7 @@ import org.apache.paimon.data.InternalArray; import org.apache.paimon.data.InternalMap; import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.InternalVector; import org.apache.paimon.data.Timestamp; import org.apache.paimon.spark.data.SparkArrayData; import org.apache.paimon.spark.data.SparkInternalRow; @@ -32,6 +33,7 @@ import org.apache.paimon.types.MapType; import org.apache.paimon.types.MultisetType; import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VectorType; import 
org.apache.spark.sql.catalyst.util.ArrayBasedMapData; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -58,6 +60,8 @@ public static Object fromPaimon(Object o, DataType type) { return fromPaimon((org.apache.paimon.data.Decimal) o); case ARRAY: return fromPaimon((InternalArray) o, (ArrayType) type); + case VECTOR: + return fromPaimon((InternalVector) o, (VectorType) type); case MAP: case MULTISET: return fromPaimon((InternalMap) o, type); @@ -93,6 +97,16 @@ public static ArrayData fromPaimon(InternalArray array, ArrayType arrayType) { return fromPaimonArrayElementType(array, arrayType.getElementType()); } + public static ArrayData fromPaimon(InternalVector vector, VectorType vectorType) { + if (vector.size() != vectorType.getLength()) { + throw new IllegalArgumentException( + String.format( + "Vector length mismatch. Expected %d but was %d.", + vectorType.getLength(), vector.size())); + } + return fromPaimonArrayElementType(vector, vectorType.getElementType()); + } + private static ArrayData fromPaimonArrayElementType(InternalArray array, DataType elementType) { return SparkArrayData.create(elementType).replace(array); } diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java index 913d4f582af5..165be98980d9 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java @@ -44,7 +44,9 @@ import org.apache.paimon.types.BlobType; import org.apache.paimon.types.DataField; import org.apache.paimon.types.DataType; +import org.apache.paimon.types.DataTypes; import org.apache.paimon.utils.ExceptionUtils; +import org.apache.paimon.utils.Preconditions; import org.apache.spark.sql.PaimonSparkSession$; import org.apache.spark.sql.SparkSession; @@ -73,6 +75,7 @@ import 
org.apache.spark.sql.execution.datasources.DataSource; import org.apache.spark.sql.execution.datasources.FileFormat; import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2; +import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -99,6 +102,7 @@ import static org.apache.paimon.spark.SparkTypeUtils.toPaimonType; import static org.apache.paimon.spark.util.OptionUtils.checkRequiredConfigurations; import static org.apache.paimon.spark.util.OptionUtils.copyWithSQLConf; +import static org.apache.paimon.spark.util.OptionUtils.withBranchFromOptions; import static org.apache.paimon.spark.utils.CatalogUtils.checkNamespace; import static org.apache.paimon.spark.utils.CatalogUtils.checkNoDefaultValue; import static org.apache.paimon.spark.utils.CatalogUtils.isUpdateColumnDefaultValue; @@ -563,13 +567,20 @@ private Schema toInitialSchema( List blobFields = CoreOptions.blobField(properties); Set blobDescriptorFields = new CoreOptions(properties).blobDescriptorField(); List blobViewFields = CoreOptions.blobViewField(properties); + Set vectorFields = CoreOptions.fromMap(properties).vectorField(); String provider = properties.get(TableCatalog.PROP_PROVIDER); if (!usePaimon(provider)) { if (isFormatTable(provider)) { normalizedProperties.put(TYPE.key(), FORMAT_TABLE.toString()); normalizedProperties.put(FILE_FORMAT.key(), provider.toLowerCase()); } else { - throw new UnsupportedOperationException("Provider is not supported: " + provider); + throw new UnsupportedOperationException( + String.format( + "Provider '%s' is not supported by catalog '%s' (implementation: %s). 
Supported providers: [paimon, %s]", + provider, + catalogName, + getClass().getSimpleName(), + SparkSource.FORMAT_NAMES().mkString(", "))); } } normalizedProperties.remove(TableCatalog.PROP_PROVIDER); @@ -603,6 +614,22 @@ private Schema toInitialSchema( field.dataType() instanceof org.apache.spark.sql.types.BinaryType, "The type of blob field must be binary"); type = new BlobType(); + } else if (vectorFields.contains(field.name())) { + Preconditions.checkArgument( + field.dataType() instanceof ArrayType, + "The type of blob field must be array"); + ArrayType arrayType = (ArrayType) field.dataType(); + String dimKey = String.format("field.%s.vector-dim", field.name()); + Preconditions.checkArgument( + properties.containsKey(dimKey), + "When setting '" + + CoreOptions.VECTOR_FIELD.key() + + "', you must also set 'field.%s.vector-dim'," + + " where %s is the name of the vector field."); + type = + DataTypes.VECTOR( + Integer.parseInt(properties.get(dimKey)), + toPaimonType(arrayType.elementType())); } else { type = toPaimonType(field.dataType()).copy(field.nullable()); } @@ -759,7 +786,9 @@ public void dropV1Function(FunctionIdentifier funcIdent, boolean ifExists) throw protected org.apache.spark.sql.connector.catalog.Table loadSparkTable( Identifier ident, Map extraOptions) throws NoSuchTableException { try { - org.apache.paimon.catalog.Identifier tblIdent = toIdentifier(ident, catalogName); + org.apache.paimon.catalog.Identifier tblIdent = + withBranchFromOptions( + catalogName, toIdentifier(ident, catalogName), extraOptions); org.apache.paimon.table.Table table = copyWithSQLConf( catalog.getTable(tblIdent), catalogName, tblIdent, extraOptions); diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConnectorOptions.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConnectorOptions.java index 13305637eed3..5217ea05135e 100644 --- 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConnectorOptions.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConnectorOptions.java @@ -37,14 +37,28 @@ public class SparkConnectorOptions { .booleanType() .defaultValue(false) .withDescription( - "If true, merge the data schema and the table schema automatically before write data."); + "If true, evolve the table schema to accept new columns from the incoming data. " + + "Existing column types are preserved and incoming values are cast to them; " + + "to also widen existing types, enable 'write.merge-schema.type-widening'."); + + public static final ConfigOption TYPE_WIDENING = + key("write.merge-schema.type-widening") + .booleanType() + .defaultValue(false) + .withDescription( + "Only effective when 'write.merge-schema' is true. " + + "If true, widen an existing column type when the incoming data has a wider " + + "compatible type (e.g. INT -> BIGINT, DECIMAL precision increase). " + + "Lossy changes are still rejected unless 'write.merge-schema.explicit-cast' is also true."); public static final ConfigOption EXPLICIT_CAST = key("write.merge-schema.explicit-cast") .booleanType() .defaultValue(false) .withDescription( - "If true, allow to merge data types if the two types meet the rules for explicit casting."); + "Only effective when 'write.merge-schema.type-widening' is true. " + + "If true, also allow lossy type changes between compatible types " + + "(e.g. 
BIGINT -> INT, STRING -> DATE)."); public static final ConfigOption USE_V2_WRITE = key("write.use-v2-write") diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java index c0b8cfd66be1..31e4c8aec92a 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java @@ -23,6 +23,8 @@ import org.apache.paimon.types.DataType; import org.apache.paimon.types.RowType; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysTrue; import org.apache.spark.sql.sources.And; import org.apache.spark.sql.sources.EqualNullSafe; import org.apache.spark.sql.sources.EqualTo; @@ -54,6 +56,8 @@ public class SparkFilterConverter { public static final List SUPPORT_FILTERS = Arrays.asList( + "AlwaysTrue", + "AlwaysFalse", "EqualTo", "EqualNullSafe", "GreaterThan", @@ -97,10 +101,16 @@ public Predicate convert(Filter filter, boolean ignoreFailure) { } public Predicate convert(Filter filter) { - if (filter instanceof EqualTo) { + if (filter instanceof AlwaysTrue) { + return PredicateBuilder.alwaysTrue(); + } else if (filter instanceof AlwaysFalse) { + return PredicateBuilder.alwaysFalse(); + } else if (filter instanceof EqualTo) { EqualTo eq = (EqualTo) filter; - // TODO deal with isNaN int index = fieldIndex(eq.attribute()); + if (isNaN(eq.value())) { + return builder.isNaN(index); + } Object literal = convertLiteral(index, eq.value()); return builder.equal(index, literal); } else if (filter instanceof EqualNullSafe) { @@ -173,11 +183,20 @@ public Predicate convert(Filter filter) { return builder.contains(index, literal); } - // TODO: AlwaysTrue, AlwaysFalse throw new UnsupportedOperationException( filter + " is unsupported. 
Support Filters: " + SUPPORT_FILTERS); } + private static boolean isNaN(Object value) { + if (value instanceof Float) { + return Float.isNaN((Float) value); + } + if (value instanceof Double) { + return Double.isNaN((Double) value); + } + return false; + } + public Object convertLiteral(String field, Object value) { return convertLiteral(fieldIndex(field), value); } diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java index 45a7c0af41ee..97b5771594d4 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java @@ -259,7 +259,20 @@ public InternalArray getArray(int pos) { @Override public InternalVector getVector(int pos) { - throw new UnsupportedOperationException("Not support VectorType yet."); + int actualPos = getActualFieldPosition(pos); + if (actualPos == -1 || internalRow.isNullAt(actualPos)) { + return null; + } + DataType dataType = tableSchema.fields()[pos].dataType(); + return toSparkInternalVector(dataType, internalRow.getArray(actualPos)); + } + + private static InternalVector toSparkInternalVector(DataType dataType, ArrayData arrayData) { + if (!(dataType instanceof ArrayType)) { + throw new UnsupportedOperationException("Not a vector type: " + dataType); + } + ArrayType arrayType = (ArrayType) dataType; + return new SparkInternalVector(arrayData, arrayType.elementType()); } @Override @@ -435,7 +448,7 @@ public InternalArray getArray(int pos) { @Override public InternalVector getVector(int pos) { - throw new UnsupportedOperationException("Not support VectorType yet."); + return toSparkInternalVector(elementType, arrayData.getArray(pos)); } @Override @@ -452,6 +465,13 @@ public InternalRow getRow(int pos, int numFields) { } } + /** adapt 
to spark internal vector. */ + public static class SparkInternalVector extends SparkInternalArray implements InternalVector { + public SparkInternalVector(ArrayData arrayData, DataType elementType) { + super(arrayData, elementType); + } + } + /** adapt to spark internal map. */ public static class SparkInternalMap implements InternalMap { diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java index 84767db9abac..6dea5a5d2981 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java @@ -20,6 +20,7 @@ import org.apache.paimon.catalog.CatalogContext; import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.BinaryVector; import org.apache.paimon.data.Blob; import org.apache.paimon.data.Decimal; import org.apache.paimon.data.InternalArray; @@ -35,6 +36,7 @@ import org.apache.paimon.types.MapType; import org.apache.paimon.types.RowKind; import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VectorType; import org.apache.paimon.utils.DateTimeUtils; import org.apache.paimon.utils.UriReaderFactory; @@ -48,6 +50,7 @@ import java.time.LocalDateTime; import java.time.ZoneId; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -168,7 +171,10 @@ public InternalArray getArray(int i) { @Override public InternalVector getVector(int pos) { - throw new UnsupportedOperationException("Not support VectorType yet."); + if (row.isNullAt(pos)) { + return null; + } + return toPaimonVector((VectorType) type.getTypeAt(pos), row.get(pos)); } @Override @@ -426,4 +432,92 @@ public double[] toDoubleArray() { return res; } } + + private static InternalVector toPaimonVector(VectorType vectorType, Object vector) { + if (vector == null) { + return null; + } + 
if (vector instanceof boolean[]) { + return BinaryVector.fromPrimitiveArray((boolean[]) vector); + } else if (vector instanceof byte[]) { + return BinaryVector.fromPrimitiveArray((byte[]) vector); + } else if (vector instanceof short[]) { + return BinaryVector.fromPrimitiveArray((short[]) vector); + } else if (vector instanceof int[]) { + return BinaryVector.fromPrimitiveArray((int[]) vector); + } else if (vector instanceof long[]) { + return BinaryVector.fromPrimitiveArray((long[]) vector); + } else if (vector instanceof float[]) { + return BinaryVector.fromPrimitiveArray((float[]) vector); + } else if (vector instanceof double[]) { + return BinaryVector.fromPrimitiveArray((double[]) vector); + } + if (vector instanceof scala.collection.Seq) { + vector = JavaConverters.seqAsJavaList((scala.collection.Seq) vector); + } else if (vector.getClass().isArray()) { + vector = Arrays.asList((Object[]) vector); + } + if (!(vector instanceof List)) { + throw new UnsupportedOperationException( + "Unsupported vector object: " + vector.getClass().getName()); + } + return toPaimonVector(vectorType, (List) vector); + } + + private static InternalVector toPaimonVector(VectorType vectorType, List list) { + int expectedLength = vectorType.getLength(); + if (list.size() != expectedLength) { + throw new IllegalArgumentException( + String.format( + "Vector length mismatch. 
Expected %d but was %d.", + expectedLength, list.size())); + } + switch (vectorType.getElementType().getTypeRoot()) { + case BOOLEAN: + boolean[] booleanValues = new boolean[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + booleanValues[i] = (Boolean) list.get(i); + } + return BinaryVector.fromPrimitiveArray(booleanValues); + case TINYINT: + byte[] byteValues = new byte[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + byteValues[i] = ((Number) list.get(i)).byteValue(); + } + return BinaryVector.fromPrimitiveArray(byteValues); + case SMALLINT: + short[] shortValues = new short[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + shortValues[i] = ((Number) list.get(i)).shortValue(); + } + return BinaryVector.fromPrimitiveArray(shortValues); + case INTEGER: + int[] intValues = new int[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + intValues[i] = ((Number) list.get(i)).intValue(); + } + return BinaryVector.fromPrimitiveArray(intValues); + case BIGINT: + long[] longValues = new long[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + longValues[i] = ((Number) list.get(i)).longValue(); + } + return BinaryVector.fromPrimitiveArray(longValues); + case FLOAT: + float[] floatValues = new float[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + floatValues[i] = ((Number) list.get(i)).floatValue(); + } + return BinaryVector.fromPrimitiveArray(floatValues); + case DOUBLE: + double[] doubleValues = new double[expectedLength]; + for (int i = 0; i < expectedLength; i++) { + doubleValues[i] = ((Number) list.get(i)).doubleValue(); + } + return BinaryVector.fromPrimitiveArray(doubleValues); + default: + throw new UnsupportedOperationException( + "Unsupported element type for vector " + vectorType.getElementType()); + } + } } diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/schema/PaimonMetadataColumnBase.java 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/schema/PaimonMetadataColumnBase.java new file mode 100644 index 000000000000..3f4cc090cd48 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/schema/PaimonMetadataColumnBase.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.schema; + +import org.apache.spark.sql.connector.catalog.MetadataColumn; + +abstract class PaimonMetadataColumnBase implements MetadataColumn { + + abstract boolean preserveOnDelete(); + + abstract boolean preserveOnUpdate(); + + abstract boolean preserveOnReinsert(); + + public String metadataInJSON() { + return "{\"__preserve_on_delete\":" + + preserveOnDelete() + + ",\"__preserve_on_update\":" + + preserveOnUpdate() + + ",\"__preserve_on_reinsert\":" + + preserveOnReinsert() + + "}"; + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/write/PaimonV2MetadataAwareDataWriter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/write/PaimonV2MetadataAwareDataWriter.java new file mode 100644 index 000000000000..fd6786a885b3 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/write/PaimonV2MetadataAwareDataWriter.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.write; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.catalog.CatalogContext; +import org.apache.paimon.table.sink.BatchWriteBuilder; +import org.apache.paimon.types.RowType; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +import scala.Option; + +/** + * Spark 4.x calls DataWriter.write(metadata, data) for metadata-aware writes. Keep this method in + * Java so the common sources still compile against Spark 3.5, where that interface method does not + * exist; Spark 4.x compilation generates the erased bridge required by the runtime call. + */ +public class PaimonV2MetadataAwareDataWriter extends PaimonV2DataWriter { + + public PaimonV2MetadataAwareDataWriter( + BatchWriteBuilder writeBuilder, + StructType writeSchema, + StructType rowTrackingWriteSchema, + StructType dataSchema, + StructType metadataSchema, + CoreOptions coreOptions, + CatalogContext catalogContext, + RowType paimonWriteType) { + super( + writeBuilder, + rowTrackingWriteSchema, + dataSchema, + coreOptions, + catalogContext, + Option.empty(), + Option.apply(paimonWriteType), + Option.apply(metadataSchema), + Option.apply(writeSchema)); + } + + public void write(InternalRow metadata, InternalRow data) { + writeWithMetadata(metadata, data); + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala index 444b6d6c642e..0d67001f0ec8 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala @@ -20,11 +20,11 @@ package org.apache.paimon.spark import org.apache.paimon.data.{BinaryString, GenericRow, InternalRow => PaimonInternalRow, JoinedRow} import 
org.apache.paimon.fs.Path -import org.apache.paimon.reader.{FileRecordIterator, RecordReader} +import org.apache.paimon.reader.{FileRecordIterator, RecordReader, ScoreRecordIterator} import org.apache.paimon.spark.schema.PaimonMetadataColumn -import org.apache.paimon.spark.schema.PaimonMetadataColumn.{PARTITION_AND_BUCKET_META_COLUMNS, PATH_AND_INDEX_META_COLUMNS} +import org.apache.paimon.spark.schema.PaimonMetadataColumn.{PARTITION_AND_BUCKET_META_COLUMNS, PATH_AND_INDEX_META_COLUMNS, VECTOR_SEARCH_SCORE_COLUMN} import org.apache.paimon.table.source.{DataSplit, Split} -import org.apache.paimon.utils.CloseableIterator +import org.apache.paimon.utils.{CloseableIterator, Preconditions} import org.apache.spark.sql.PaimonUtils @@ -48,6 +48,10 @@ case class PaimonRecordReaderIterator( private val needMetadata = metadataColumns.nonEmpty private val needPathAndIndexMetadata = metadataColumns.exists(c => PATH_AND_INDEX_META_COLUMNS.contains(c.name)) + private val needScoreMetadata = { + metadataColumns.exists(_.name == VECTOR_SEARCH_SCORE_COLUMN) + } + Preconditions.checkArgument(!needScoreMetadata || metadataColumns.size == 1) private val metadataRow: GenericRow = GenericRow.of(Array.fill(metadataColumns.size)(null.asInstanceOf[AnyRef]): _*) @@ -122,7 +126,11 @@ case class PaimonRecordReaderIterator( while (!stop) { val dataRow = currentIterator.next() if (dataRow != null) { - if (needMetadata) { + if (needScoreMetadata) { + updateScoreMetadata( + currentIterator.asInstanceOf[ScoreRecordIterator[PaimonInternalRow]]) + currentResult = joinedRow.replace(dataRow, metadataRow) + } else if (needMetadata) { updateMetadataRow(currentIterator.asInstanceOf[FileRecordIterator[PaimonInternalRow]]) currentResult = joinedRow.replace(dataRow, metadataRow) } else { @@ -165,4 +173,15 @@ case class PaimonRecordReaderIterator( } } } + + private def updateScoreMetadata( + fileRecordIterator: ScoreRecordIterator[PaimonInternalRow]): Unit = { + metadataColumns.zipWithIndex.foreach { + case 
(metadataColumn, index) => + metadataColumn.name match { + case PaimonMetadataColumn.VECTOR_SEARCH_SCORE_COLUMN => + metadataRow.setField(index, fileRecordIterator.returnedScore()) + } + } + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala index 0fc4bd9eb5f7..94b912844430 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala @@ -118,6 +118,9 @@ abstract class PaimonSparkTableBase(val table: Table) _metadataColumns.append(PaimonMetadataColumn.ROW_ID) _metadataColumns.append(PaimonMetadataColumn.SEQUENCE_NUMBER) } + if (table.isInstanceOf[VectorSearchTable]) { + _metadataColumns.append(PaimonMetadataColumn.VECTOR_SEARCH_SCORE) + } _metadataColumns.appendAll( Seq( diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala index 0196ea6404ca..9ea20de1909d 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala @@ -40,14 +40,15 @@ import java.util.{EnumSet => JEnumSet, Set => JSet} * If this base class implemented `SupportsRowLevelOperations`, Spark 4.1 would immediately call * `newRowLevelOperationBuilder` on tables whose V2 write is disabled (e.g. dynamic bucket or * primary-key tables that fall back to V1 write) and fail before Paimon has a chance to rewrite the - * plan to a V1 command. 
Likewise, deletion-vector, row-tracking, and data-evolution tables need to - * stay on Paimon's V1 postHoc path even when `useV2Write=true`, so they must also not expose - * `SupportsRowLevelOperations`. + * plan to a V1 command. Likewise, deletion-vector, data-evolution, and fixed-length CHAR tables + * need to stay on Paimon's V1 postHoc path even when `useV2Write=true`, so they must also not + * expose `SupportsRowLevelOperations`. * * Tables that DO support V2 row-level operations use the [[SparkTableWithRowLevelOps]] subclass * instead; the [[SparkTable.of]] factory picks the right variant via - * [[SparkTable.supportsV2RowLevelOps]], which is kept in lockstep with - * `RowLevelHelper.shouldFallbackToV1`. + * [[SparkTable.supportsV2RowLevelOps]]. Append-only tables, including row-tracking-only tables, + * expose `SupportsRowLevelOperations` so DELETE, UPDATE, and MERGE INTO can go through the V2 + * copy-on-write path when the table has no PK, deletion vectors, data evolution, or CHAR columns. */ case class SparkTable(override val table: Table) extends PaimonSparkTableBase(table) @@ -93,12 +94,11 @@ object SparkTable { * Whether the given table supports Paimon's V2 row-level operations, i.e. whether it is safe to * expose [[SupportsRowLevelOperations]] to Spark. * - * This must stay in sync with - * `org.apache.paimon.spark.catalyst.analysis.RowLevelHelper#shouldFallbackToV1` — the two - * predicates are logical complements. If they diverge, Spark 4.1's row-level rewrite rules (which - * fire in the main Resolution batch) will intercept DML on tables that Paimon expects to handle - * through its postHoc V1 fallback, leaving primary-key / deletion-vector / row-tracking / - * data-evolution tables with broken MERGE/UPDATE/DELETE dispatch. + * Append-only tables return `true` here so that `SparkTable.of` wraps them as + * `SparkTableWithRowLevelOps`, enabling Spark's V2 copy-on-write DELETE, UPDATE, and MERGE INTO + * paths. 
Row-tracking append-only tables require Spark 4.0+ because Spark 3.5 does not have the + * metadata-aware `DataWriter.write(metadata, data)` path needed to preserve row-tracking metadata + * for rewritten rows. * * Per-version shims for Spark 3.2/3.3/3.4 each ship their own * `org.apache.paimon.spark.SparkTable` (class + companion) that shadows this one at packaging @@ -113,10 +113,13 @@ object SparkTable { if (!sparkTable.useV2Write) return false sparkTable.getTable match { case fs: FileStoreTable => + val supportsRowTrackingCopyOnWrite = + !sparkTable.coreOptions.rowTrackingEnabled() || org.apache.spark.SPARK_VERSION >= "4.0" fs.primaryKeys().isEmpty && + supportsRowTrackingCopyOnWrite && !sparkTable.coreOptions.deletionVectorsEnabled() && - !sparkTable.coreOptions.rowTrackingEnabled() && - !sparkTable.coreOptions.dataEvolutionEnabled() + !sparkTable.coreOptions.dataEvolutionEnabled() && + !SparkTypeUtils.containsCharType(fs.rowType()) case _ => false } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTypeUtils.java b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTypeUtils.java index dc2f8b30acab..80a27c35ac90 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTypeUtils.java +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTypeUtils.java @@ -44,6 +44,7 @@ import org.apache.paimon.types.VarBinaryType; import org.apache.paimon.types.VarCharType; import org.apache.paimon.types.VariantType; +import org.apache.paimon.types.VectorType; import org.apache.spark.sql.paimon.shims.SparkShimLoader; import org.apache.spark.sql.types.DataType; @@ -98,6 +99,23 @@ public static org.apache.paimon.types.DataType toPaimonType(DataType dataType) { return SparkToPaimonTypeVisitor.visit(dataType); } + public static boolean containsCharType(org.apache.paimon.types.DataType type) { + if (type instanceof CharType) { + return true; + } else if 
(type instanceof RowType) { + return ((RowType) type).getFields().stream() + .anyMatch(field -> containsCharType(field.type())); + } else if (type instanceof ArrayType) { + return containsCharType(((ArrayType) type).getElementType()); + } else if (type instanceof MapType) { + MapType mapType = (MapType) type; + return containsCharType(mapType.getKeyType()) || containsCharType(mapType.getValueType()); + } else if (type instanceof MultisetType) { + return containsCharType(((MultisetType) type).getElementType()); + } + return false; + } + /** * Prune Paimon `RowType` by required Spark `StructType`, use this method instead of {@link * #toPaimonType(DataType)} when need to retain the field id. @@ -127,8 +145,12 @@ private static org.apache.paimon.types.DataType prunePaimonType( } else if (sparkDataType instanceof org.apache.spark.sql.types.ArrayType) { org.apache.spark.sql.types.ArrayType s = (org.apache.spark.sql.types.ArrayType) sparkDataType; - ArrayType r = (ArrayType) paimonDataType; - return r.newElementType(prunePaimonType(s.elementType(), r.getElementType())); + if (paimonDataType instanceof VectorType) { + return paimonDataType; + } else { + ArrayType r = (ArrayType) paimonDataType; + return r.newElementType(prunePaimonType(s.elementType(), r.getElementType())); + } } else { return paimonDataType; } @@ -242,6 +264,11 @@ public DataType visit(ArrayType arrayType) { return DataTypes.createArrayType(elementType.accept(this), elementType.isNullable()); } + @Override + public DataType visit(VectorType vectorType) { + return DataTypes.createArrayType(vectorType.getElementType().accept(this), false); + } + @Override public DataType visit(MultisetType multisetType) { return DataTypes.createMapType( diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala index 52ce726c3598..8b0a57cc52d8 
100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala @@ -20,7 +20,7 @@ package org.apache.paimon.spark.aggregate import org.apache.paimon.table.FileStoreTable import org.apache.paimon.table.source.{DataSplit, ReadBuilder, Split} -import org.apache.paimon.table.source.PushDownUtils.minmaxAvailable +import org.apache.paimon.table.source.PushDownUtils.{minmaxAvailable, tightBoundsAvailable} import org.apache.paimon.types._ import org.apache.spark.sql.connector.expressions.Expression @@ -37,7 +37,6 @@ object AggregatePushDownUtils { table: FileStoreTable, aggregation: Aggregation, readBuilder: ReadBuilder): Option[LocalAggregator] = { - val options = table.coreOptions() val rowType = table.rowType val partitionKeys = table.partitionKeys() @@ -53,13 +52,16 @@ object AggregatePushDownUtils { if (columns.isEmpty) { generateSplits(readBuilder.dropStats()) } else { - if (options.deletionVectorsEnabled() || !table.primaryKeys().isEmpty) { + if (!table.primaryKeys().isEmpty) { return None } val splits = generateSplits(readBuilder) if (!splits.forall(minmaxAvailable(_, columns.asJava))) { return None } + if (!splits.forall(tightBoundsAvailable)) { + return None + } splits } case None => return None diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/MergeSchemaEvolutionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/MergeSchemaEvolutionHelper.scala index e94621c2de1b..9275918398ba 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/MergeSchemaEvolutionHelper.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/MergeSchemaEvolutionHelper.scala @@ -20,8 +20,7 @@ package 
org.apache.paimon.spark.catalyst.analysis import org.apache.paimon.spark.{SparkTable, SparkTypeUtils} import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper -import org.apache.paimon.spark.commands.SchemaHelper -import org.apache.paimon.spark.schema.SparkSystemColumns +import org.apache.paimon.spark.commands.SchemaEvolutionHelper import org.apache.paimon.spark.util.OptionUtils import org.apache.paimon.table.FileStoreTable @@ -29,17 +28,17 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExprId, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, MergeAction, MergeIntoTable} -import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.paimon.shims.SparkShimLoader import org.apache.spark.sql.types.{StructField, StructType} /** - * Shared MERGE INTO `merge-schema=true` evolution. Triggers on `UPDATE *` / `INSERT *` (via - * [[PaimonMergeActionTags]]) or on any explicit assignment whose key resolved to a source-bound - * attribute (via [[PaimonMergeIntoResolver.resolveAssignments]] fallback — that shape is the - * marker, no extra tag). Evolution is scoped to source columns referenced by matched / not-matched - * actions; NOT MATCHED BY SOURCE can't reference source columns. + * MERGE INTO schema evolution (merge-schema=true). Computes the evolved schema in memory and + * rewrites the merge plan so that action alignment targets the new columns; the actual schema + * commit is deferred to execution (the merge command's `run`). + * + * Triggered by `UPDATE *` / `INSERT *` or explicit source-bound assignment keys. Scoped to source + * columns referenced in matched/not-matched actions. 
*/ trait MergeSchemaEvolutionHelper extends ExpressionHelper { @@ -77,25 +76,24 @@ trait MergeSchemaEvolutionHelper extends ExpressionHelper { } val fileStoreTable = v2Table.getTable.asInstanceOf[FileStoreTable] + // Pass raw source types: the core merge decides whether to keep or widen the target type, and + // the action alignment layer casts incoming values to the result. val sourceSchema = StructType( merge.sourceTable.output .filter(a => scopedNames.exists(n => resolver(n, a.name))) .map(a => StructField(a.name, a.dataType, a.nullable))) - val filteredSourceSchema = SparkSystemColumns.filterSparkSystemColumns(sourceSchema) - val allowExplicitCast = OptionUtils.writeMergeSchemaExplicitCastEnabled() - val updatedFileStoreTable = SchemaHelper - .mergeAndCommitSchema(fileStoreTable, filteredSourceSchema, allowExplicitCast) + // Compute the evolved schema in memory only; the actual schema commit is deferred to execution + // (the merge command's run / the V2 write's toBatch) so analysis stays side-effect-free. The + // evolved relation presents the new columns to the plan; existing rows read them as NULL. + val updatedFileStoreTable = SchemaEvolutionHelper + .evolvedTableInMemory(fileStoreTable, sourceSchema, spark) .getOrElse(return None) - // Invalidate Spark catalog cache so subsequent queries see the new schema. 
- for (catalog <- relation.catalog; ident <- relation.identifier) { - catalog.asInstanceOf[TableCatalog].invalidateTable(ident) - } - val updatedV2Table = v2Table.copy(table = updatedFileStoreTable) - val mergedSparkSchema = - SparkTypeUtils.fromPaimonRowType(updatedFileStoreTable.schema().logicalRowType()) - val newOutput = buildEvolvedOutput(mergedSparkSchema, relation.output, resolver) + val newOutput = buildEvolvedOutput( + SparkTypeUtils.fromPaimonRowType(updatedFileStoreTable.schema().logicalRowType()), + relation.output, + resolver) val updatedRelation = SparkShimLoader.shim.copyDataSourceV2Relation(relation, updatedV2Table, newOutput) val updatedTargetTable = merge.targetTable.transform { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala index 5400e9865fab..c26eeedd3f62 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala @@ -18,12 +18,12 @@ package org.apache.paimon.spark.catalyst.analysis -import org.apache.paimon.spark.{SparkConnectorOptions, SparkTable} +import org.apache.paimon.options.Options +import org.apache.paimon.spark.SparkTable import org.apache.paimon.spark.catalyst.Compatibility import org.apache.paimon.spark.catalyst.analysis.PaimonRelation.isPaimonTable import org.apache.paimon.spark.catalyst.plans.logical.PaimonDropPartitions -import org.apache.paimon.spark.commands.{PaimonAnalyzeTableColumnCommand, PaimonDynamicPartitionOverwriteCommand, PaimonShowColumnsCommand} -import org.apache.paimon.spark.util.OptionUtils +import org.apache.paimon.spark.commands.{PaimonAnalyzeTableColumnCommand, PaimonDynamicPartitionOverwriteCommand, PaimonShowColumnsCommand, 
SchemaEvolutionHelper} import org.apache.paimon.table.FileStoreTable import org.apache.spark.sql.{PaimonUtils, SparkSession} @@ -35,6 +35,8 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.TableCapability import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation} +import scala.collection.JavaConverters._ + class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] { import DataSourceV2Implicits._ import PaimonAnalysis._ @@ -43,12 +45,17 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] { case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table) && a.query.getTagValue(PAIMON_WRITE_RESOLVED).isEmpty => - val mergeSchemaEnabled = - writeOptions(a).get(SparkConnectorOptions.MERGE_SCHEMA.key()).contains("true") || - OptionUtils.writeMergeSchemaEnabled() + val options = Options.fromMap(writeOptions(a).asJava) + val mergeSchemaEnabled = SchemaEvolutionHelper.mergeSchemaEnabled(options) + val expected = SchemaEvolutionHelper.expectedAttrsForCatalogWrite( + table, + a.query.schema, + options, + a.isByName, + session) val newQuery = PaimonOutputResolver.resolveOutputColumns( table.name, - table.output, + expected, a.query, a.isByName, mergeSchemaEnabled) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAssignmentUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAssignmentUtils.scala index 2b2c13bba8bd..a1e5dedc7560 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAssignmentUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAssignmentUtils.scala @@ -22,9 +22,9 @@ import org.apache.paimon.spark.SparkTypeUtils.CURRENT_DEFAULT_COLUMN_METADATA_KE import 
org.apache.paimon.spark.catalyst.Compatibility import org.apache.paimon.spark.catalyst.analysis.PaimonOutputResolver.MissingFieldBehavior -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{PaimonUtils, SparkSession} import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, CreateNamedStruct, Expression, GetStructField, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, InsertStarAction, MergeAction, UpdateStarAction} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.paimon.shims.SparkShimLoader @@ -192,7 +192,7 @@ object PaimonAssignmentUtils extends SQLConfHelper { col.dataType match { case structType: StructType => - val fieldAttrs = toAttributes(structType) + val fieldAttrs = PaimonUtils.toAttributes(structType) val fieldExprs = structType.fields.zipWithIndex.map { case (field, ordinal) => GetStructField(colExpr, ordinal, Some(field.name)) } @@ -225,10 +225,6 @@ object PaimonAssignmentUtils extends SQLConfHelper { attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)) } - private def toAttributes(structType: StructType): Seq[Attribute] = { - structType.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - } - /** Parse the column's `CURRENT_DEFAULT` SQL (cast to column type) or fall back to NULL. 
*/ private def getDefaultValueOrNull(attr: Attribute): Expression = { val nullLit = Literal(null, attr.dataType) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeInto.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeInto.scala index eaa5cb19b828..be8424bf9b55 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeInto.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeInto.scala @@ -66,7 +66,8 @@ case class PaimonMergeInto(spark: SparkSession) primaryKeys) } - // Commit schema changes before alignment so the aligned plan sees new columns. + // Evolve the target schema in memory before alignment so the aligned plan sees the new + // columns; the commit is deferred to execution (the merge command's run). val (resolvedMerge, targetOutput) = evolveTargetIfNeeded(merge, relation, v2Table, spark, resolveNotMatchedBySourceActions) .map { case (m, r, t) => v2Table = t; (m, r.output) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonOutputResolver.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonOutputResolver.scala index 20da69c32123..37ffebe8280f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonOutputResolver.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonOutputResolver.scala @@ -20,6 +20,7 @@ package org.apache.paimon.spark.catalyst.analysis import org.apache.paimon.spark.catalyst.Compatibility +import org.apache.spark.sql.PaimonUtils import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull @@ -40,7 +41,8 @@ object PaimonOutputResolver extends SQLConfHelper { * How nested struct field misalignment is handled: * - [[FailMissing]]: strict — nested missing target / source-extra throws. * - [[NullForMissing]]: merge-schema for INSERT / explicit UPDATE — missing NULL-fills, - * source-extras kept so [[org.apache.paimon.spark.commands.SchemaHelper]] evolves the table. + * source-extras kept so [[org.apache.paimon.spark.commands.SchemaEvolutionHelper]] evolves + * the table. * - [[PreserveTarget]]: merge-schema for `UPDATE *` struct — missing source field substitutes * `GetStructField(targetExpr, ordinal)` to keep the current target value. */ @@ -277,12 +279,12 @@ object PaimonOutputResolver extends SQLConfHelper { reorderColumnsByName( tableName, fields, - toAttributes(targetType), + PaimonUtils.toAttributes(targetType), behavior, targetExpr, colPath) } else { - resolveColumnsByPosition(tableName, fields, toAttributes(targetType), colPath) + resolveColumnsByPosition(tableName, fields, PaimonUtils.toAttributes(targetType), colPath) } val targetNamedStruct = CreateStruct(resolved) val res = maybeWrapWithNullPreservation( @@ -472,10 +474,6 @@ object PaimonOutputResolver extends SQLConfHelper { case other => other } - private def toAttributes(structType: StructType): Seq[Attribute] = { - structType.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - } - private def restoreActualType(attr: Attribute): Attribute = { attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelHelper.scala index 4bbdb8bbd8c2..da4394867611 100644 --- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelHelper.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelHelper.scala @@ -76,17 +76,14 @@ trait RowLevelHelper extends SQLConfHelper { } } - /** - * Determines if DataSourceV2 is not supported for the given table. This is the logical complement - * of [[SparkTable.supportsV2RowLevelOps]]; the two predicates must stay in sync so that Spark - * 4.1's row-level rewrite rules (which key on `SupportsRowLevelOperations`) and Paimon's V1 - * postHoc fallback rules (which gate on this predicate) agree about which tables go down which - * path. - */ protected def shouldFallbackToV1(table: SparkTable): Boolean = { !SparkTable.supportsV2RowLevelOps(table) } + // `SparkTable.supportsV2RowLevelOps` controls whether the table exposes Spark row-level + // capability at all. These per-operation checks are the remaining V1 fallbacks for cases Spark's + // V2 rewrite cannot safely handle: metadata-only DELETE, non-rewritable UPDATE/MERGE, or + // assignments that have not been aligned yet. /** Determines if DataSourceV2 delete is not supported for the given table. */ protected def shouldFallbackToV1Delete(table: SparkTable, condition: Expression): Boolean = { shouldFallbackToV1(table) || @@ -106,13 +103,6 @@ trait RowLevelHelper extends SQLConfHelper { protected def shouldFallbackToV1MergeInto(m: MergeIntoTable): Boolean = { val relation = PaimonRelation.getPaimonRelation(m.targetTable) val table = relation.table.asInstanceOf[SparkTable] - // Note for Spark 4.1+: `shouldFallbackToV1(table)` returns `false` for pure append-only - // tables (no PK / RT / DE / DV), so this predicate lets the aligned `MergeIntoTable` node - // return untouched. 
Spark's own `RewriteMergeIntoTable` in the Resolution batch can't fire - // (`resolveOperators` short-circuits on `analyzed=true`), so the rewrite is performed by - // `Spark41MergeIntoRewrite` (paimon-spark4-common) which aligns + transcribes Spark's - // `ReplaceData` / `AppendData` branches for non-`SupportsDelta` sources. Non-append-only - // tables still fall back to V1 (`MergeIntoPaimonTable` / `MergeIntoPaimonDataEvolutionTable`). shouldFallbackToV1(table) || !m.rewritable || !m.aligned diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyIntoTableCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyIntoTableCommand.scala index eedad9763ed6..b7cc456f3b73 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyIntoTableCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyIntoTableCommand.scala @@ -29,14 +29,17 @@ case class CopyIntoTableCommand( sourcePath: String, fileFormat: CopyFileFormat, pattern: Option[String], - force: Boolean) + force: Boolean, + onError: OnErrorMode = OnErrorMode.AbortStatement) extends PaimonLeafCommand { override def output: Seq[Attribute] = Seq( AttributeReference("file_name", StringType, nullable = false)(), AttributeReference("status", StringType, nullable = false)(), AttributeReference("rows_loaded", LongType, nullable = false)(), - AttributeReference("rows_parsed", LongType, nullable = false)() + AttributeReference("rows_parsed", LongType, nullable = false)(), + AttributeReference("errors_seen", LongType, nullable = false)(), + AttributeReference("first_error", StringType, nullable = true)() ) override def simpleString(maxFields: Int): String = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyOptions.scala 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyOptions.scala index 2e2f7e2ec1e0..3e5e35b2b78d 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyOptions.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/CopyOptions.scala @@ -18,10 +18,20 @@ package org.apache.paimon.spark.catalyst.plans.logical +sealed trait OnErrorMode + +object OnErrorMode { + case object AbortStatement extends OnErrorMode + case object Continue extends OnErrorMode + case object SkipFile extends OnErrorMode +} + sealed trait FileFormatType object FileFormatType { case object CSV extends FileFormatType + case object JSON extends FileFormatType + case object PARQUET extends FileFormatType case class Unsupported(name: String) extends FileFormatType } @@ -29,33 +39,77 @@ case class CopyFileFormat(formatType: FileFormatType, options: Map[String, Strin def toSparkReaderOptions: Map[String, String] = { val mapped = scala.collection.mutable.Map[String, String]("mode" -> "FAILFAST") - options.foreach { - case (k, v) => - k match { - case "FIELD_DELIMITER" => mapped("sep") = v - case "QUOTE" => mapped("quote") = v - case "ESCAPE" => mapped("escape") = v - case "COMPRESSION" => mapped("compression") = v - case "SKIP_HEADER" => - mapped("header") = if (v == "1" || v.equalsIgnoreCase("TRUE")) "true" else "false" - case _ => + formatType match { + case FileFormatType.CSV => + options.foreach { + case (k, v) => + k match { + case "FIELD_DELIMITER" => mapped("sep") = v + case "QUOTE" => mapped("quote") = v + case "ESCAPE" => mapped("escape") = v + case "COMPRESSION" => mapped("compression") = v + case "SKIP_HEADER" => + mapped("header") = if (v == "1" || v.equalsIgnoreCase("TRUE")) "true" else "false" + case _ => + } + } + case FileFormatType.JSON => + options.foreach { + case (k, v) => + k match { + case "MULTI_LINE" => mapped("multiLine") = 
v.toLowerCase + case "COMPRESSION" => mapped("compression") = v + case _ => + } + } + case FileFormatType.PARQUET => + mapped.remove("mode") + options.foreach { + case (k, v) => + k match { + case "COMPRESSION" => mapped("compression") = v + case _ => + } } + case _ => } mapped.toMap } def toSparkWriterOptions: Map[String, String] = { val mapped = scala.collection.mutable.Map[String, String]() - options.foreach { - case (k, v) => - k match { - case "FIELD_DELIMITER" => mapped("sep") = v - case "HEADER" => mapped("header") = v.toLowerCase - case "QUOTE" => mapped("quote") = v - case "ESCAPE" => mapped("escape") = v - case "COMPRESSION" => mapped("compression") = v - case _ => + formatType match { + case FileFormatType.CSV => + options.foreach { + case (k, v) => + k match { + case "FIELD_DELIMITER" => mapped("sep") = v + case "HEADER" => mapped("header") = v.toLowerCase + case "QUOTE" => mapped("quote") = v + case "ESCAPE" => mapped("escape") = v + case "COMPRESSION" => mapped("compression") = v + case _ => + } + } + case FileFormatType.JSON => + options.foreach { + case (k, v) => + k match { + case "COMPRESSION" => mapped("compression") = v + case "DATE_FORMAT" => mapped("dateFormat") = v + case "TIMESTAMP_FORMAT" => mapped("timestampFormat") = v + case _ => + } } + case FileFormatType.PARQUET => + options.foreach { + case (k, v) => + k match { + case "COMPRESSION" => mapped("compression") = v + case _ => + } + } + case _ => } mapped.toMap } @@ -78,25 +132,37 @@ case class CopyFileFormat(formatType: FileFormatType, options: Map[String, Strin throw new IllegalArgumentException( "MODE cannot be specified in FILE_FORMAT options; it is reserved for ON_ERROR handling") } - val invalid = options.keys.filterNot(CopyFileFormat.VALID_IMPORT_KEYS.contains) + val validKeys = formatType match { + case FileFormatType.JSON => CopyFileFormat.VALID_JSON_IMPORT_KEYS + case FileFormatType.PARQUET => CopyFileFormat.VALID_PARQUET_IMPORT_KEYS + case _ => 
CopyFileFormat.VALID_CSV_IMPORT_KEYS + } + val invalid = options.keys.filterNot(validKeys.contains) if (invalid.nonEmpty) { throw new IllegalArgumentException( s"Unsupported FILE_FORMAT options for import: ${invalid.mkString(", ")}") } - options.get("SKIP_HEADER").foreach { - v => - val intVal = - try v.toInt - catch { case _: NumberFormatException => -1 } - if (intVal != 0 && intVal != 1) { - throw new IllegalArgumentException(s"SKIP_HEADER supports only 0 or 1, got: $v") - } + if (formatType == FileFormatType.CSV) { + options.get("SKIP_HEADER").foreach { + v => + val intVal = + try v.toInt + catch { case _: NumberFormatException => -1 } + if (intVal != 0 && intVal != 1) { + throw new IllegalArgumentException(s"SKIP_HEADER supports only 0 or 1, got: $v") + } + } } } def validateForExport(): Unit = { validateFormatType() - val invalid = options.keys.filterNot(CopyFileFormat.VALID_EXPORT_KEYS.contains) + val validKeys = formatType match { + case FileFormatType.JSON => CopyFileFormat.VALID_JSON_EXPORT_KEYS + case FileFormatType.PARQUET => CopyFileFormat.VALID_PARQUET_EXPORT_KEYS + case _ => CopyFileFormat.VALID_CSV_EXPORT_KEYS + } + val invalid = options.keys.filterNot(validKeys.contains) if (invalid.nonEmpty) { throw new IllegalArgumentException( s"Unsupported FILE_FORMAT options for export: ${invalid.mkString(", ")}") @@ -106,16 +172,18 @@ case class CopyFileFormat(formatType: FileFormatType, options: Map[String, Strin private def validateFormatType(): Unit = { formatType match { case FileFormatType.CSV => + case FileFormatType.JSON => + case FileFormatType.PARQUET => case FileFormatType.Unsupported(name) => throw new IllegalArgumentException( - s"Unsupported file format type: $name. Only CSV is currently supported") + s"Unsupported file format type: $name. 
Supported types: CSV, JSON, PARQUET") } } } object CopyFileFormat { - val VALID_IMPORT_KEYS: Set[String] = Set( + val VALID_CSV_IMPORT_KEYS: Set[String] = Set( "FIELD_DELIMITER", "SKIP_HEADER", "QUOTE", @@ -125,7 +193,14 @@ object CopyFileFormat { "COMPRESSION" ) - val VALID_EXPORT_KEYS: Set[String] = Set( + val VALID_JSON_IMPORT_KEYS: Set[String] = Set( + "MULTI_LINE", + "COMPRESSION", + "NULL_IF", + "EMPTY_FIELD_AS_NULL" + ) + + val VALID_CSV_EXPORT_KEYS: Set[String] = Set( "FIELD_DELIMITER", "HEADER", "QUOTE", @@ -133,12 +208,28 @@ object CopyFileFormat { "COMPRESSION" ) + val VALID_JSON_EXPORT_KEYS: Set[String] = Set( + "COMPRESSION", + "DATE_FORMAT", + "TIMESTAMP_FORMAT" + ) + + val VALID_PARQUET_IMPORT_KEYS: Set[String] = Set( + "COMPRESSION" + ) + + val VALID_PARQUET_EXPORT_KEYS: Set[String] = Set( + "COMPRESSION" + ) + // Unit Separator (U+001F) used to encode multi-value lists in a single string val LIST_SEPARATOR: String = "\u001f" def parseFormatType(typeStr: String): FileFormatType = { typeStr.toUpperCase match { case "CSV" => FileFormatType.CSV + case "JSON" => FileFormatType.JSON + case "PARQUET" => FileFormatType.PARQUET case other => FileFormatType.Unsupported(other) } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DataEvolutionPaimonWriter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DataEvolutionPaimonWriter.scala index 2b7833322393..c49368d1bf38 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DataEvolutionPaimonWriter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DataEvolutionPaimonWriter.scala @@ -19,7 +19,6 @@ package org.apache.paimon.spark.commands import org.apache.paimon.CoreOptions -import org.apache.paimon.data.BinaryRow import org.apache.paimon.format.blob.BlobFileFormat.isBlobFile import org.apache.paimon.spark.write.{DataEvolutionTableDataWrite, 
WriteHelper, WriteTaskResult} import org.apache.paimon.table.FileStoreTable @@ -28,6 +27,7 @@ import org.apache.paimon.table.source.DataSplit import org.apache.paimon.types.DataType import org.apache.paimon.types.DataTypeRoot.BLOB import org.apache.paimon.types.VectorType.isVectorStoreFile +import org.apache.paimon.utils.SerializationUtils import org.apache.spark.sql._ @@ -39,21 +39,6 @@ import scala.collection.mutable case class DataEvolutionPaimonWriter(paimonTable: FileStoreTable, dataSplits: Seq[DataSplit]) extends WriteHelper { - private lazy val firstRowIdToPartitionMap: mutable.HashMap[Long, (BinaryRow, Long)] = { - val firstRowIdToPartitionMap = new mutable.HashMap[Long, (BinaryRow, Long)] - dataSplits.foreach( - split => - split - .dataFiles() - .asScala - .filter(file => !isBlobFile(file.fileName()) && !isVectorStoreFile(file.fileName())) - .foreach( - file => - firstRowIdToPartitionMap - .put(file.firstRowId(), (split.partition(), file.rowCount())))) - firstRowIdToPartitionMap - } - // File rolling will never be performed override val table: FileStoreTable = paimonTable.copy(Collections.singletonMap(CoreOptions.TARGET_FILE_SIZE.key(), "99999 G")) @@ -77,14 +62,35 @@ case class DataEvolutionPaimonWriter(paimonTable: FileStoreTable, dataSplits: Se CoreOptions.BLOB_EXTERNAL_STORAGE_FIELD.key() + "') can be updated.") } + val firstRowIdToPartitionMap = new mutable.HashMap[Long, (Array[Byte], Long)] + dataSplits.foreach( + split => + split + .dataFiles() + .asScala + .filter(file => !isBlobFile(file.fileName()) && !isVectorStoreFile(file.fileName())) + .foreach( + file => + firstRowIdToPartitionMap + .put( + file.firstRowId(), + // BinaryRow stores data in transient memory segments and relies on Java + // serialization hooks to restore them. Store bytes in Spark closures and + // broadcasts so Kryo does not serialize BinaryRow internals directly. 
+ (SerializationUtils.serializeBinaryRow(split.partition()), file.rowCount()) + ))) + val firstRowIdToPartitionMapBroadcast = + sparkSession.sparkContext.broadcast(firstRowIdToPartitionMap) + val writeBuilder = table.newBatchWriteBuilder() + val written = data.mapPartitions { iter => { val write = DataEvolutionTableDataWrite( - table.newBatchWriteBuilder(), + writeBuilder, writeType, - firstRowIdToPartitionMap, + firstRowIdToPartitionMapBroadcast.value, catalogContextForBlobDescriptor) try { iter.foreach(row => write.write(row)) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala index 96f8c0c5cc9f..cd1b000a361f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala @@ -27,17 +27,21 @@ import org.apache.paimon.spark.SparkTable import org.apache.paimon.spark.catalyst.analysis.PaimonRelation import org.apache.paimon.spark.catalyst.analysis.PaimonRelation.isPaimonTable import org.apache.paimon.spark.catalyst.analysis.PaimonUpdateTable.toColumn +import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand import org.apache.paimon.spark.util.ScanPlanHelper.createNewScanPlan import org.apache.paimon.table.FileStoreTable import org.apache.paimon.table.sink.{CommitMessage, CommitMessageImpl} import org.apache.paimon.table.source.DataSplit +import org.apache.paimon.table.source.snapshot.SnapshotReader +import org.apache.paimon.types.RowType import org.apache.paimon.types.VectorType.isVectorStoreFile +import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Row, 
SparkSession} import org.apache.spark.sql.PaimonUtils._ import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolver -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Expression, ExprId, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, AttributeReference, EqualTo, Expression, ExprId, Literal, Or, PythonUDF, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter} import org.apache.spark.sql.catalyst.plans.logical._ @@ -61,7 +65,9 @@ case class MergeIntoPaimonDataEvolutionTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction]) extends PaimonLeafRunnableCommand - with WithFileStoreTable { + with WithFileStoreTable + with ExpressionHelper + with Logging { private lazy val writer = PaimonSparkWriter(table) @@ -136,12 +142,16 @@ case class MergeIntoPaimonDataEvolutionTable( lazy val tableSchema: StructType = v2Table.schema override def run(sparkSession: SparkSession): Seq[Row] = { + // Persist the schema that the analyzer evolved in memory (commit deferred to execution). 
+ SchemaEvolutionHelper.commitEvolvedSchemaAtExecution(table, targetRelation, sparkSession) invokeMergeInto(sparkSession) Seq.empty[Row] } private def invokeMergeInto(sparkSession: SparkSession): Unit = { - val plan = table.newSnapshotReader().read() + val snapshotReader = table.newSnapshotReader() + pushDownMergePartitionFilter(snapshotReader) + val plan = snapshotReader.read() val tableSplits: Seq[DataSplit] = plan .splits() .asScala @@ -179,44 +189,113 @@ case class MergeIntoPaimonDataEvolutionTable( map.toMap } - // step 1: find the related data splits, make it target file plan - val dataSplits: Seq[DataSplit] = - targetRelatedSplits(sparkSession, tableSplits, firstRowIds, firstRowIdToBlobFirstRowIds) - val touchedFileTargetRelation = - createNewScanPlan(dataSplits, targetRelation) - - // step 2: invoke update action - val updateCommit = - if (matchedActions.nonEmpty) { - val updateResult = - updateActionInvoke(dataSplits, sparkSession, touchedFileTargetRelation, firstRowIds) - checkUpdateResult(updateResult) - } else Nil - - // step 3: invoke insert action - val insertCommit = - if (notMatchedActions.nonEmpty) - insertActionInvoke(sparkSession, touchedFileTargetRelation) - else Nil - - if (plan.snapshotId() != null) { - writer.rowIdCheckConflict(plan.snapshotId()) + val persistSourceDss: Option[Dataset[Row]] = + if ( + table.coreOptions().dataEvolutionMergeIntoSourcePersist() + && (matchedActions.nonEmpty || notMatchedActions.nonEmpty) + ) { + val dss = createDataset(sparkSession, sourceTable) + dss.persist() + Some(dss) + } else { + None + } + + try { + // step 1: find the related data splits, make it target file plan + val dataSplits: Seq[DataSplit] = targetRelatedSplits( + sparkSession, + tableSplits, + firstRowIds, + firstRowIdToBlobFirstRowIds, + persistSourceDss) + val touchedFileTargetRelation = + createNewScanPlan(dataSplits, targetRelation) + + // step 2: invoke update action + val updateCommit = + if (matchedActions.nonEmpty) { + val updateResult = 
updateActionInvoke( + dataSplits, + sparkSession, + touchedFileTargetRelation, + firstRowIds, + persistSourceDss) + checkUpdateResult(updateResult) + } else Nil + + // step 3: invoke insert action + val insertCommit = + if (notMatchedActions.nonEmpty) + insertActionInvoke(sparkSession, touchedFileTargetRelation, persistSourceDss) + else Nil + + if (plan.snapshotId() != null) { + writer.rowIdCheckConflict(plan.snapshotId()) + } + writer.commit(updateCommit ++ insertCommit) + } finally { + if (persistSourceDss.isDefined) { + persistSourceDss.get.unpersist(blocking = false) + } + } + } + + private def pushDownMergePartitionFilter(snapshotReader: SnapshotReader): Unit = { + val partitionRowType = table.schema().logicalPartitionType() + if (partitionRowType.getFieldCount == 0) { + return + } + + // matchedCondition comes from MergeIntoTable.mergeCondition, which is the MERGE ON condition. + val partitionPredicates = getExpressionOnlyRelated(matchedCondition, targetTable) + .map(splitConjunctivePredicates) + .map(extractMergePartitionFilters(_, partitionRowType)) + .getOrElse(Seq.empty) + + if (partitionPredicates.nonEmpty) { + val filter = convertConditionToPaimonPredicate( + partitionPredicates.reduce(And), + targetRelation.output, + rowType, + ignorePartialFailure = true) + filter.foreach(snapshotReader.withFilter) + } + } + + private def extractMergePartitionFilters( + filters: Seq[Expression], + partitionRowType: RowType): Seq[Expression] = { + val partitionColumns = partitionRowType.getFieldNames.asScala.toSet + filters.filter { + f => + f.deterministic && + f.references.forall(attr => partitionColumns.exists(_.equalsIgnoreCase(attr.name))) && + !SubqueryExpression.hasSubquery(f) && + f.collect { case _: PythonUDF => true }.isEmpty } - writer.commit(updateCommit ++ insertCommit) } private def targetRelatedSplits( sparkSession: SparkSession, tableSplits: Seq[DataSplit], firstRowIds: immutable.IndexedSeq[Long], - firstRowIdToBlobFirstRowIds: Map[Long, List[Long]]): 
Seq[DataSplit] = { + firstRowIdToBlobFirstRowIds: Map[Long, List[Long]], + persistSourceDss: Option[Dataset[Row]]): Seq[DataSplit] = { // Self-Merge shortcut: // In Self-Merge mode, every row in the table may be updated, so we scan all splits. if (isSelfMergeOnRowId) { return tableSplits } - val sourceDss = createDataset(sparkSession, sourceTable) + if (!table.coreOptions().dataEvolutionMergeIntoFilePruning()) { + logInfo( + "Skip file-level pruning for MergeInto partial column update on data-evolution table " + + s"${table.name()}.") + return tableSplits + } + + val sourceDss = persistSourceDss.getOrElse(createDataset(sparkSession, sourceTable)) val firstRowIdsTouched = extractSourceRowIdMapping match { case Some(sourceRowIdAttr) => @@ -253,7 +332,8 @@ case class MergeIntoPaimonDataEvolutionTable( dataSplits: Seq[DataSplit], sparkSession: SparkSession, touchedFileTargetRelation: DataSourceV2Relation, - firstRowIds: immutable.IndexedSeq[Long]): Seq[CommitMessage] = { + firstRowIds: immutable.IndexedSeq[Long], + persistSourceDss: Option[Dataset[Row]]): Seq[CommitMessage] = { val mergeFields = extractFields(matchedCondition) val allFields = mutable.SortedSet.empty[AttributeReference]( (o1, o2) => { @@ -374,7 +454,8 @@ case class MergeIntoPaimonDataEvolutionTable( val sourceTableProjExprs = allReadFieldsOnSource.toSeq :+ Alias(TrueLiteral, ROW_FROM_SOURCE)() - val sourceTableProj = Project(sourceTableProjExprs, sourceTable) + val sourceChild = persistSourceDss.map(_.queryExecution.logical).getOrElse(sourceTable) + val sourceTableProj = Project(sourceTableProjExprs, sourceChild) val joinPlan = Join(targetTableProj, sourceTableProj, LeftOuter, Some(matchedCondition), JoinHint.NONE) @@ -417,16 +498,18 @@ case class MergeIntoPaimonDataEvolutionTable( private def insertActionInvoke( sparkSession: SparkSession, - touchedFileTargetRelation: DataSourceV2Relation): Seq[CommitMessage] = { + touchedFileTargetRelation: DataSourceV2Relation, + persistSourceDss: 
Option[Dataset[Row]]): Seq[CommitMessage] = { val mergeFields = extractFields(matchedCondition) val allReadFieldsOnTarget = mergeFields.filter(field => targetTable.output.exists(attr => attr.equals(field))) val targetReadPlan = touchedFileTargetRelation.copy(targetRelation.table, allReadFieldsOnTarget.toSeq) + val sourceReadPlan = persistSourceDss.map(_.queryExecution.logical).getOrElse(sourceTable) val joinPlan = - Join(sourceTable, targetReadPlan, LeftAnti, Some(matchedCondition), JoinHint.NONE) + Join(sourceReadPlan, targetReadPlan, LeftAnti, Some(matchedCondition), JoinHint.NONE) // merge rows as there are multiple not matched actions val mergeRows = MergeRows( diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala index 1822929854e5..89689e108c45 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala @@ -71,6 +71,8 @@ case class MergeIntoPaimonTable( override def run(sparkSession: SparkSession): Seq[Row] = { // Avoid that more than one source rows match the same target row. checkMatchRationality(sparkSession) + // Persist the schema that the analyzer evolved in memory (commit deferred to execution). 
+ SchemaEvolutionHelper.commitEvolvedSchemaAtExecution(table, relation, sparkSession) val commitMessages = if (withPrimaryKeys) { performMergeForPkTable(sparkSession) } else { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommand.scala index 1edcc99b8ec5..0f722d4c8066 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommand.scala @@ -23,7 +23,7 @@ import org.apache.paimon.spark.DynamicOverWrite import org.apache.paimon.table.FileStoreTable import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.PaimonUtils.createDataset +import org.apache.spark.sql.PaimonUtils.{createDataset, createNewDataFrame} import org.apache.spark.sql.catalyst.analysis.NamedRelation import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, V2WriteCommand} import org.apache.spark.sql.execution.command.RunnableCommand @@ -64,7 +64,7 @@ case class PaimonDynamicPartitionOverwriteCommand( WriteIntoPaimonTable( fileStoreTable, DynamicOverWrite, - createDataset(sparkSession, query), + createNewDataFrame(createDataset(sparkSession, query)), Options.fromMap(fileStoreTable.options() ++ writeOptions) ).run(sparkSession) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaEvolutionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaEvolutionHelper.scala new file mode 100644 index 000000000000..7bfb7ee90501 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaEvolutionHelper.scala @@ -0,0 +1,200 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.commands
+
+import org.apache.paimon.options.Options
+import org.apache.paimon.schema.{SchemaMergingUtils, TableSchema}
+import org.apache.paimon.spark.{SparkConnectorOptions, SparkTable, SparkTypeUtils}
+import org.apache.paimon.spark.catalyst.analysis.PaimonOutputResolver
+import org.apache.paimon.spark.schema.SparkSystemColumns
+import org.apache.paimon.spark.util.OptionUtils
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.{DataFrame, PaimonUtils, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.StructType
+
+import scala.collection.JavaConverters._
+
+/**
+ * Schema evolution flags resolved from write options and session conf.
+ *
+ * The three values are passed, in this order, to the schema merge calls in
+ * [[SchemaEvolutionHelper]].
+ *
+ * @param typeWidening
+ *   true when the `TYPE_WIDENING` write option or the session-level type-widening setting is on
+ * @param allowExplicitCast
+ *   true when the `EXPLICIT_CAST` write option or the session-level explicit-cast setting is on
+ * @param caseSensitive
+ *   the session's `caseSensitiveAnalysis` setting
+ */
+private[spark] case class SchemaEvolutionFlags(
+    typeWidening: Boolean,
+    allowExplicitCast: Boolean,
+    caseSensitive: Boolean)
+
+/**
+ * Schema evolution entry points for catalog writes.
The two `mergeSchema` overloads commit the
+ * evolved schema at execution (`WriteIntoPaimonTable.run` for V1, `PaimonV2Write.toBatch` for V2).
+ */
+private[spark] trait SchemaEvolutionHelper extends WithFileStoreTable {
+
+  /** The table this helper starts from; `table` returns it until a replacement is installed. */
+  val originTable: FileStoreTable
+
+  /** Replacement installed by a schema commit or an option override; `None` until then. */
+  protected var newTable: Option[FileStoreTable] = None
+
+  /** The table to operate on: the replacement when one was installed, else `originTable`. */
+  override def table: FileStoreTable = newTable match {
+    case Some(replacement) => replacement
+    case None => originTable
+  }
+
+  /**
+   * Entry point for V1 catalog writes (`WriteIntoPaimonTable.run`). Does nothing unless schema
+   * merging is enabled; when it is, the schema of `input` is committed. The rows themselves are
+   * not touched here because [[PaimonOutputResolver]] already cast them during analysis.
+   */
+  def mergeSchema(sparkSession: SparkSession, input: DataFrame, options: Options): Unit = {
+    val evolutionRequested = SchemaEvolutionHelper.mergeSchemaEnabled(options)
+    if (evolutionRequested) {
+      commitEvolution(sparkSession, input.schema, options)
+    }
+  }
+
+  /**
+   * Entry point for V2 catalog writes (`PaimonV2Write.toBatch`). Returns `dataSchema` as-is when
+   * schema merging is disabled or nothing was committed; otherwise returns the Spark form of the
+   * evolved table's logical row type.
+   */
+  def mergeSchema(dataSchema: StructType, options: Options): StructType = {
+    if (SchemaEvolutionHelper.mergeSchemaEnabled(options)) {
+      commitEvolution(SparkSession.active, dataSchema, options)
+        .map(evolved => SparkTypeUtils.fromPaimonRowType(evolved.schema().logicalRowType()))
+        .getOrElse(dataSchema)
+    } else {
+      dataSchema
+    }
+  }
+
+  /** Commit the evolved schema for the incoming data; updates `newTable` and returns it.
*/
+  private def commitEvolution(
+      sparkSession: SparkSession,
+      dataSchema: StructType,
+      options: Options): Option[FileStoreTable] = {
+    val committed =
+      SchemaEvolutionHelper.commitSchemaEvolution(table, dataSchema, sparkSession, options)
+    if (committed) {
+      // Reload so the replacement table carries the latest stored schema.
+      val reloaded = table.copyWithLatestSchema()
+      newTable = Some(reloaded)
+      newTable
+    } else {
+      None
+    }
+  }
+
+  /** Replace the working table with a copy of the current one that carries `options`. */
+  def updateTableWithOptions(options: Map[String, String]): Unit = {
+    val reconfigured = table.copy(options.asJava)
+    newTable = Some(reconfigured)
+  }
+}
+
+private[spark] object SchemaEvolutionHelper {
+
+  /**
+   * Merge step without persistence: strips Spark system columns from `dataSchema`, merges the
+   * remainder into the table's current schema and yields the result, or `None` when the merged
+   * schema equals the current one. [[commitSchemaEvolution]] is the variant that stores the
+   * change. Used by [[expectedAttrsForCatalogWrite]] and [[evolvedTableInMemory]].
+   */
+  private def computeMergedSchema(
+      table: FileStoreTable,
+      dataSchema: StructType,
+      flags: SchemaEvolutionFlags): Option[TableSchema] = {
+    val userColumns = SparkSystemColumns.filterSparkSystemColumns(dataSchema)
+    val incomingRowType = SparkTypeUtils.toPaimonType(userColumns).asInstanceOf[RowType]
+    val storedSchema = table.schema()
+    val candidate = SchemaMergingUtils.mergeSchemas(
+      storedSchema,
+      incomingRowType,
+      flags.typeWidening,
+      flags.allowExplicitCast,
+      flags.caseSensitive)
+    Some(candidate).filterNot(_.equals(storedSchema))
+  }
+
+  /**
+   * Filter system columns, resolve flags, and commit the merge of `dataSchema` into the stored
+   * schema. Returns whether a schema change was committed. Shared by catalog writes and MERGE INTO;
+   * callers that need the evolved table reload it via `copyWithLatestSchema`.
+   */
+  def commitSchemaEvolution(
+      table: FileStoreTable,
+      dataSchema: StructType,
+      sparkSession: SparkSession,
+      options: Options = new Options()): Boolean = {
+    val userColumns = SparkSystemColumns.filterSparkSystemColumns(dataSchema)
+    val SchemaEvolutionFlags(typeWidening, allowExplicitCast, caseSensitive) =
+      readFlags(sparkSession, options)
+    val incomingRowType = SparkTypeUtils.toPaimonType(userColumns).asInstanceOf[RowType]
+    table.store().mergeSchema(incomingRowType, typeWidening, allowExplicitCast, caseSensitive)
+  }
+
+  /**
+   * Store the schema that MERGE INTO evolved in memory (see [[evolvedTableInMemory]]). Does nothing
+   * unless the session-level merge-schema setting is on. When the commit reports a change, the
+   * relation's catalog entry is invalidated so later queries see the new schema. Invoked from the
+   * merge command's `run`, which keeps the commit at execution time rather than analysis time.
+   */
+  def commitEvolvedSchemaAtExecution(
+      table: FileStoreTable,
+      relation: DataSourceV2Relation,
+      sparkSession: SparkSession): Unit = {
+    if (OptionUtils.writeMergeSchemaEnabled()) {
+      val inMemorySchema = SparkTypeUtils.fromPaimonRowType(table.schema().logicalRowType())
+      val schemaChanged = commitSchemaEvolution(table, inMemorySchema, sparkSession)
+      if (schemaChanged) {
+        // Invalidation needs both the catalog and the identifier; skip it when either is absent.
+        relation.catalog.foreach {
+          plugin =>
+            relation.identifier.foreach {
+              ident => plugin.asInstanceOf[TableCatalog].invalidateTable(ident)
+            }
+        }
+      }
+    }
+  }
+
+  /**
+   * Build the post-evolution table in memory without persisting anything, so MERGE INTO can show
+   * the new columns in the plan. `mergeSchemas` assigns the next schema id deterministically, so
+   * the deferred [[commitSchemaEvolution]] reproduces an identical schema. Yields `None` when the
+   * merge leaves the schema unchanged.
+   */
+  def evolvedTableInMemory(
+      table: FileStoreTable,
+      dataSchema: StructType,
+      sparkSession: SparkSession,
+      options: Options = new Options()): Option[FileStoreTable] = {
+    val flags = readFlags(sparkSession, options)
+    computeMergedSchema(table, dataSchema, flags).map(merged => table.copy(merged))
+  }
+
+  /** Whether schema evolution is enabled, from the per-write options or the session conf.
*/
+  def mergeSchemaEnabled(options: Options): Boolean = {
+    val enabledForThisWrite: Boolean = options.get(SparkConnectorOptions.MERGE_SCHEMA)
+    enabledForThisWrite || OptionUtils.writeMergeSchemaEnabled()
+  }
+
+  /**
+   * Attributes the output resolver should expect for a catalog write. The post-evolution
+   * attributes are returned only when all of these hold: resolution is `byName`, schema merging
+   * is enabled, type widening is enabled, the relation wraps a `FileStoreTable`, and merging
+   * `querySchema` actually changes the schema. That lets the resolver cast incoming data to the
+   * widened target types. In every other case `table.output` is returned unchanged.
+   */
+  def expectedAttrsForCatalogWrite(
+      table: DataSourceV2Relation,
+      querySchema: StructType,
+      options: Options,
+      isByName: Boolean,
+      sparkSession: SparkSession): Seq[Attribute] = {
+    val flags = readFlags(sparkSession, options)
+    val widenByName = isByName && mergeSchemaEnabled(options) && flags.typeWidening
+    if (!widenByName) {
+      table.output
+    } else {
+      table.table.asInstanceOf[SparkTable].getTable match {
+        case fst: FileStoreTable =>
+          computeMergedSchema(fst, querySchema, flags) match {
+            case Some(widened) =>
+              val sparkSchema = SparkTypeUtils.fromPaimonRowType(widened.logicalRowType())
+              PaimonUtils.toAttributes(sparkSchema)
+            case None => table.output
+          }
+        case _ => table.output
+      }
+    }
+  }
+
+  /** Resolve schema evolution flags from write options and session conf.
*/
+  private def readFlags(sparkSession: SparkSession, options: Options): SchemaEvolutionFlags = {
+    // A flag is on when either the per-write option or the session-level setting enables it.
+    SchemaEvolutionFlags(
+      typeWidening = options.get(SparkConnectorOptions.TYPE_WIDENING) ||
+        OptionUtils.writeMergeSchemaTypeWideningEnabled(),
+      allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST) ||
+        OptionUtils.writeMergeSchemaExplicitCastEnabled(),
+      caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    )
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
deleted file mode 100644
index 1416602a5ff2..000000000000
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.paimon.spark.commands - -import org.apache.paimon.options.Options -import org.apache.paimon.spark.{SparkConnectorOptions, SparkTypeUtils} -import org.apache.paimon.spark.schema.SparkSystemColumns -import org.apache.paimon.spark.util.OptionUtils -import org.apache.paimon.table.FileStoreTable -import org.apache.paimon.types.RowType - -import org.apache.spark.sql.{Column, DataFrame, PaimonUtils, SparkSession} -import org.apache.spark.sql.functions.{col, lit, struct, transform, transform_values} -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} - -import scala.collection.JavaConverters._ - -private[spark] trait SchemaHelper extends WithFileStoreTable { - - val originTable: FileStoreTable - - protected var newTable: Option[FileStoreTable] = None - - override def table: FileStoreTable = newTable.getOrElse(originTable) - - def mergeSchema(sparkSession: SparkSession, input: DataFrame, options: Options): DataFrame = { - val dataSchema = SparkSystemColumns.filterSparkSystemColumns(input.schema) - val writeSchema = mergeSchema(dataSchema, options) - if (!PaimonUtils.sameType(writeSchema, dataSchema)) { - val resolve = sparkSession.sessionState.conf.resolver - val cols = SchemaHelper.alignColumns(writeSchema, dataSchema, resolve) - input.select(cols: _*) - } else { - input - } - } - - def mergeSchema(dataSchema: StructType, options: Options): StructType = { - val mergeSchemaEnabled = - options.get(SparkConnectorOptions.MERGE_SCHEMA) || OptionUtils.writeMergeSchemaEnabled() - if (!mergeSchemaEnabled) { - return dataSchema - } - - val filteredDataSchema = SparkSystemColumns.filterSparkSystemColumns(dataSchema) - val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST) || OptionUtils - .writeMergeSchemaExplicitCastEnabled() - SchemaHelper.mergeAndCommitSchema(table, filteredDataSchema, allowExplicitCast).foreach { - updatedTable => newTable = Some(updatedTable) - } - - val writeSchema = 
SparkTypeUtils.fromPaimonRowType(table.schema().logicalRowType()) - if (!PaimonUtils.sameType(writeSchema, filteredDataSchema)) { - writeSchema - } else { - filteredDataSchema - } - } - - def updateTableWithOptions(options: Map[String, String]): Unit = { - newTable = Some(table.copy(options.asJava)) - } -} - -private[spark] object SchemaHelper { - - /** - * Merge the given dataSchema into the table's schema. If the schema changed, commit the change - * and return the updated table; otherwise return None. - */ - def mergeAndCommitSchema( - table: FileStoreTable, - dataSchema: StructType, - allowExplicitCast: Boolean): Option[FileStoreTable] = { - val dataRowType = SparkTypeUtils.toPaimonType(dataSchema).asInstanceOf[RowType] - if (table.store().mergeSchema(dataRowType, allowExplicitCast)) { - Some(table.copyWithLatestSchema()) - } else { - None - } - } - - /** - * Recursively align columns from dataSchema to targetSchema by name. For nested struct fields, - * reorder and fill nulls for missing sub-fields. 
- */ - def alignColumns( - targetSchema: StructType, - dataSchema: StructType, - resolve: (String, String) => Boolean): Seq[Column] = { - targetSchema.map { - targetField => - dataSchema.find(f => resolve(f.name, targetField.name)) match { - case Some(dataField) => - alignColumn(col(dataField.name), dataField.dataType, targetField, resolve) - case _ => - lit(null).cast(targetField.dataType).as(targetField.name) - } - } - } - - private def alignColumn( - sourceCol: Column, - sourceType: DataType, - targetField: StructField, - resolve: (String, String) => Boolean): Column = { - (sourceType, targetField.dataType) match { - case (s: StructType, t: StructType) if !PaimonUtils.sameType(s, t) => - alignStruct(sourceCol, s, t, resolve).as(targetField.name) - case (ArrayType(s: StructType, _), ArrayType(t: StructType, _)) - if !PaimonUtils.sameType(s, t) => - transform(sourceCol, elem => alignStruct(elem, s, t, resolve)) - .as(targetField.name) - case (MapType(sKey, sVal: StructType, _), MapType(tKey, tVal: StructType, _)) - if !PaimonUtils.sameType(sVal, tVal) => - transform_values(sourceCol, (_, v) => alignStruct(v, sVal, tVal, resolve)) - .as(targetField.name) - case _ => - sourceCol.as(targetField.name) - } - } - - private def alignStruct( - sourceCol: Column, - sourceType: StructType, - targetType: StructType, - resolve: (String, String) => Boolean): Column = { - val subCols = targetType.map { - subTargetField => - sourceType.find(f => resolve(f.name, subTargetField.name)) match { - case Some(subDataField) => - alignColumn( - sourceCol.getField(subDataField.name), - subDataField.dataType, - subTargetField, - resolve) - case _ => - lit(null).cast(subTargetField.dataType).as(subTargetField.name) - } - } - struct(subCols: _*) - } -} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala index 
623c88d72ff1..fcf061ec6733 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala @@ -41,7 +41,7 @@ case class WriteIntoPaimonTable( batchId: Option[Long] = None) extends RunnableCommand with ExpressionHelper - with SchemaHelper + with SchemaEvolutionHelper with Logging { override def run(sparkSession: SparkSession): Seq[Row] = { @@ -49,7 +49,7 @@ case class WriteIntoPaimonTable( PaimonUtils.createDataset( sparkSession, ReplacePaimonFunctions(sparkSession)(_data.queryExecution.analyzed)) - val data = mergeSchema(sparkSession, replacedData, options) + mergeSchema(sparkSession, replacedData, options) val (dynamicPartitionOverwriteMode, overwritePartition) = parseSaveMode() // use the extra options to rebuild the table object @@ -60,7 +60,7 @@ case class WriteIntoPaimonTable( if (overwritePartition != null) { writer.writeBuilder.withOverwrite(overwritePartition.asJava) } - val commitMessages = writer.write(data) + val commitMessages = writer.write(replacedData) writer.commit(commitMessages) Seq.empty diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/copyinto/CopyIntoResultBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/copyinto/CopyIntoResultBuilder.scala new file mode 100644 index 000000000000..a165a4523389 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/copyinto/CopyIntoResultBuilder.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.copyinto
+
+import org.apache.paimon.table.FileStoreTable
+
+import org.apache.hadoop.fs.FileStatus
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+
+/** Builds per-file result rows for COPY INTO commands. */
+object CopyIntoResultBuilder {
+
+  /**
+   * Assemble one result row with six columns, in this order: file name, status, rows loaded, rows
+   * parsed, errors seen, first error. The two text columns and the error column are stored as
+   * `UTF8String`; a null `firstError` stays null in the row.
+   */
+  def buildResultRow(
+      fileName: String,
+      status: String,
+      rowsLoaded: Long,
+      rowsParsed: Long,
+      errorsSeen: Long,
+      firstError: String): InternalRow = {
+    val fileNameCell = UTF8String.fromString(fileName)
+    val statusCell = UTF8String.fromString(status)
+    val firstErrorCell = Option(firstError).map(msg => UTF8String.fromString(msg)).orNull
+    InternalRow(fileNameCell, statusCell, rowsLoaded, rowsParsed, errorsSeen, firstErrorCell)
+  }
+
+  /**
+   * Build per-file result rows for ON_ERROR=CONTINUE mode. Merges parse errors and cast errors per
+   * file, determines status (LOADED / PARTIALLY_LOADED / LOAD_FAILED), and records load history for
+   * files that loaded at least one row or had no errors (including empty files).
+   */
+  def buildContinueResults(
+      paimonTable: FileStoreTable,
+      filesToLoad: Array[FileStatus],
+      skippedFiles: Array[FileStatus],
+      totalRowsPerFile: Map[String, Long],
+      parseErrors: Map[String, Long],
+      castErrors: Map[String, Long],
+      firstParseErrorPerFile: Map[String, String],
+      firstCastErrorPerFile: Map[String, String]): Seq[InternalRow] = {
+    val tableRoot = new org.apache.paimon.fs.Path(paimonTable.location().toString)
+    val history = new CopyLoadHistoryManager(paimonTable.fileIO(), tableRoot)
+    val snapshotId = paimonTable.snapshotManager().latestSnapshotId()
+    val loadedAt = System.currentTimeMillis()
+
+    // Every per-file map is looked up by the file's base name, not by its full path.
+    def describe(fileStatus: FileStatus): InternalRow = {
+      val path = fileStatus.getPath
+      val baseName = path.getName
+      val parsed = totalRowsPerFile.getOrElse(baseName, 0L)
+      val errors = parseErrors.getOrElse(baseName, 0L) + castErrors.getOrElse(baseName, 0L)
+      val loaded = math.max(0L, parsed - errors)
+
+      // History is kept for files that contributed rows and for error-free files.
+      if (loaded > 0 || errors == 0) {
+        history.recordLoaded(
+          CopyLoadRecord(
+            filePath = path.toString,
+            fileSize = fileStatus.getLen,
+            lastModified = fileStatus.getModificationTime,
+            loadedAt = loadedAt,
+            snapshotId = snapshotId,
+            rowsLoaded = loaded
+          ))
+      }
+
+      val status =
+        if (errors <= 0) "LOADED"
+        else if (loaded == 0) "LOAD_FAILED"
+        else "PARTIALLY_LOADED"
+      // A parse error, when present, takes precedence over a cast error.
+      val firstError =
+        firstParseErrorPerFile.get(baseName).orElse(firstCastErrorPerFile.get(baseName)).orNull
+      buildResultRow(baseName, status, loaded, parsed, errors, firstError)
+    }
+
+    filesToLoad.map(describe).toSeq ++ buildSkippedResults(skippedFiles)
+  }
+
+  /**
+   * Record load history and build results using pre-computed per-file row counts (ABORT mode).
+   * Avoids re-reading source files just to count rows.
+   */
+  def recordHistoryAndBuildResultsDirect(
+      paimonTable: FileStoreTable,
+      filesToLoad: Array[FileStatus],
+      skippedFiles: Array[FileStatus],
+      rowCountsPerFile: Map[String, Long]): Seq[InternalRow] = {
+    val tableRoot = new org.apache.paimon.fs.Path(paimonTable.location().toString)
+    val history = new CopyLoadHistoryManager(paimonTable.fileIO(), tableRoot)
+    val snapshotId = paimonTable.snapshotManager().latestSnapshotId()
+    val loadedAt = System.currentTimeMillis()
+
+    val loadedRows = for (fileStatus <- filesToLoad) yield {
+      val path = fileStatus.getPath
+      val baseName = path.getName
+      // Counts are looked up by base name; a file without an entry is reported with 0 rows.
+      val rows = rowCountsPerFile.getOrElse(baseName, 0L)
+
+      history.recordLoaded(
+        CopyLoadRecord(
+          filePath = path.toString,
+          fileSize = fileStatus.getLen,
+          lastModified = fileStatus.getModificationTime,
+          loadedAt = loadedAt,
+          snapshotId = snapshotId,
+          rowsLoaded = rows
+        ))
+
+      buildResultRow(baseName, "LOADED", rows, rows, 0L, null)
+    }
+
+    loadedRows.toSeq ++ buildSkippedResults(skippedFiles)
+  }
+
+  /** One `SKIPPED` row with zero counts and no error for every file in `files`. */
+  def buildSkippedResults(files: Array[FileStatus]): Seq[InternalRow] = {
+    val skippedRows = for (file <- files) yield {
+      buildResultRow(file.getPath.getName, "SKIPPED", 0L, 0L, 0L, null)
+    }
+    skippedRows.toSeq
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoCastValidator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoCastValidator.scala
new file mode 100644
index 000000000000..971b9803d4d1
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoCastValidator.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.execution + +import org.apache.paimon.spark.catalyst.Compatibility +import org.apache.paimon.types.DataField + +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.paimon.shims.SparkShimLoader +import org.apache.spark.sql.types.{DataType, StringType} + +/** + * Handles cast validation for COPY INTO operations. Validates that source data can be safely cast + * to target table types without data loss. + */ +private[execution] class CopyIntoCastValidator(spark: org.apache.spark.sql.SparkSession) { + + /** + * Build cast validation columns for Parquet import. For each writable column that exists in both + * targetColumns and source, adds a casted temp column for comparison. Returns the augmented + * DataFrame, the (source, cast) column pairs, and the bad-cast filter expression. 
+   */
+  def buildParquetCastValidation(
+      rawDf: DataFrame,
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField],
+      excludeCols: Set[String] = Set.empty): CastValidationSetup = {
+    val resolver = spark.sessionState.conf.resolver
+    val candidateSources = rawDf.columns.toSeq.filterNot(excludeCols.contains)
+
+    // Insertion-ordered so the filter terms follow the order of the writable columns.
+    val probeBySource = scala.collection.mutable.LinkedHashMap[String, String]()
+    var takenNames = rawDf.columns.toSet ++ writableColumns.toSet
+    var augmented = rawDf
+
+    for {
+      (colName, field) <- writableColumns.zip(fields)
+      if targetColumns.exists(tc => resolver(tc, colName))
+      srcColName <- candidateSources.find(s => resolver(s, colName))
+    } {
+      val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`())
+      val probeName = safeTempCol("__pq_cv_" + colName, takenNames)
+      takenNames += probeName
+      augmented = augmented.withColumn(probeName, nonAnsiCast(col(srcColName), sparkType))
+      probeBySource(srcColName) = probeName
+    }
+
+    // A row is bad when any source value is non-null but its probe column is null.
+    val anyBadCast = probeBySource.toSeq
+      .map { case (src, dst) => col(src).isNotNull && col(dst).isNull }
+      .reduceOption(_ || _)
+
+    CastValidationSetup(augmented, probeBySource.toMap, anyBadCast)
+  }
+
+  /**
+   * Build cast validation columns for text-based import. For each non-string writable column, adds
+   * a casted temp column for comparison. Returns a setup holding the augmented DataFrame, the
+   * mapping from original to temp column names, and the bad-cast filter (absent when no column
+   * needs validation).
+   */
+  def buildTextCastValidation(
+      df: DataFrame,
+      writableColumns: Seq[String],
+      fields: Seq[DataField]): CastValidationSetup = {
+    val tempByColumn = scala.collection.mutable.LinkedHashMap[String, String]()
+    var takenNames = df.columns.toSet ++ writableColumns.toSet
+    var augmented = df
+
+    for ((colName, field) <- writableColumns.zip(fields)) {
+      val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`())
+      // String-typed targets are skipped; only non-string columns get a probe column.
+      if (sparkType != StringType) {
+        val tempName = safeTempCol("__cv_" + colName, takenNames)
+        takenNames += tempName
+        tempByColumn(colName) = tempName
+        augmented = augmented.withColumn(tempName, nonAnsiCast(col(colName), sparkType))
+      }
+    }
+
+    if (tempByColumn.isEmpty) {
+      // Nothing to validate: return the input DataFrame with no mapping and no filter.
+      CastValidationSetup(df, Map.empty, None)
+    } else {
+      val anyBadCast = tempByColumn.toSeq
+        .map { case (src, dst) => col(src).isNotNull && col(dst).isNull }
+        .reduce(_ || _)
+      CastValidationSetup(augmented, tempByColumn.toMap, Some(anyBadCast))
+    }
+  }
+
+  /**
+   * Validate a Parquet import and abort on the first cast failure. A failure is detected when a
+   * source value is non-null but becomes null after casting (e.g., a string "abc" cast to
+   * IntegerType → null). Builds the validation setup with no excluded columns and hands it to
+   * [[abortOnCastFailure]].
+   */
+  def validateParquetCast(
+      rawDf: DataFrame,
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField]): Unit = {
+    abortOnCastFailure(
+      buildParquetCastValidation(rawDf, targetColumns, writableColumns, fields))
+  }
+
+  /** Abort on first cast failure found in the validation setup.
*/
+  def abortOnCastFailure(setup: CastValidationSetup): Unit = {
+    for (filter <- setup.badCastFilter) {
+      // One offending row is enough to abort, so fetch at most one.
+      val offending = setup.validationDf.filter(filter).limit(1).collect()
+      if (offending.nonEmpty) {
+        val row = offending(0)
+        val schema = setup.validationDf.schema
+        // Name the first mapped column whose source is set but whose cast result is null.
+        val culprit = setup.castColMapping
+          .find {
+            case (src, dst) =>
+              val srcIdx = schema.fieldIndex(src)
+              val dstIdx = schema.fieldIndex(dst)
+              !row.isNullAt(srcIdx) && row.isNullAt(dstIdx)
+          }
+          .map(_._1)
+          .getOrElse("unknown")
+        throw new IllegalArgumentException(
+          s"ON_ERROR = ABORT_STATEMENT: Cast failure in column '${culprit}'. Source data contains values that cannot be converted to the target type.")
+      }
+    }
+  }
+
+  /**
+   * Text-based ABORT mode: derive the casted DataFrame, then validate the uncasted input and
+   * throw on the first cast failure. The casted DataFrame is returned only when validation
+   * passes.
+   */
+  def castAndValidate(
+      finalDf: DataFrame,
+      writableColumns: Seq[String],
+      fields: Seq[DataField]): DataFrame = {
+    val casted = castColumns(finalDf, writableColumns, fields)
+    abortOnCastFailure(buildTextCastValidation(finalDf, writableColumns, fields))
+    casted
+  }
+
+  /**
+   * Replace each writable column by itself cast to the Spark type of the field paired with it
+   * by position in `fields`.
+   */
+  def castColumns(
+      df: DataFrame,
+      writableColumns: Seq[String],
+      fields: Seq[DataField]): DataFrame = {
+    var result = df
+    for ((colName, field) <- writableColumns.zip(fields)) {
+      val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`())
+      result = result.withColumn(colName, col(colName).cast(sparkType))
+    }
+    result
+  }
+
+  /** Delegate to `CopyIntoHelper.safeTempCol`, passing this validator's session. */
+  private def safeTempCol(baseName: String, existingColumns: Set[String]): String = {
+    CopyIntoHelper.safeTempCol(spark, baseName, existingColumns)
+  }
+
+  /**
+   * Cast a column with ANSI disabled so a failed cast yields NULL instead of throwing. This is the
+   * basis for bad-row detection: a source value that is non-null but becomes null after casting
+   * could not be converted to the target type.
Under Spark's default ANSI mode a plain `.cast`
+   * would raise `CAST_INVALID_INPUT` before any filtering could run.
+   */
+  private def nonAnsiCast(column: Column, dataType: DataType): Column = {
+    val classicApi = SparkShimLoader.shim.classicApi
+    val sourceExpr = classicApi.expression(spark, column)
+    val lenientCast = Compatibility.cast(sourceExpr, dataType, ansiEnabled = false)
+    classicApi.column(lenientCast)
+  }
+}
+
+/**
+ * Cast validation setup shared by the Parquet and text (CSV/JSON) paths.
+ *
+ * @param validationDf
+ *   the input DataFrame, extended with one temp cast column per validated column
+ * @param castColMapping
+ *   validated column name mapped to the name of its temp cast column
+ * @param badCastFilter
+ *   matches rows where a validated column is non-null but its temp cast column is null; `None`
+ *   when no column is validated
+ */
+private[execution] case class CastValidationSetup(
+    validationDf: DataFrame,
+    castColMapping: Map[String, String],
+    badCastFilter: Option[org.apache.spark.sql.Column])
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilder.scala
new file mode 100644
index 000000000000..e44ba8c63aec
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilder.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.paimon.spark.execution + +import org.apache.paimon.spark.catalyst.plans.logical.{CopyFileFormat, FileFormatType} +import org.apache.paimon.types.DataField + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.functions.{col, lit, when} +import org.apache.spark.sql.paimon.shims.SparkShimLoader +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +/** + * Handles DataFrame construction and transformation for COPY INTO operations. Responsible for + * building DataFrames from source files, applying transformations, and preparing data for writing + * to Paimon tables. + */ +private[execution] class CopyIntoDataFrameBuilder( + spark: SparkSession, + fileFormat: CopyFileFormat, + columns: Option[Seq[String]]) + extends Logging { + + /** + * Build the projection DataFrame for Parquet import. Maps source columns to target table columns + * by name (case-insensitive). 
For each writable column:
+   *   - If it's in targetColumns AND exists in source: cast source column to target type
+   *   - If it's in targetColumns but missing from source: fill with NULL
+   *   - If it's NOT in targetColumns (unmapped): fill with default value or NULL
+   */
+  def buildParquetDataFrame(
+      rawDf: DataFrame,
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField]): DataFrame = {
+    val resolver = spark.sessionState.conf.resolver
+    val sourceNames = rawDf.columns.toSeq
+
+    // Produces the output expression for one writable column.
+    def projection(colName: String): Column = {
+      val isTarget = targetColumns.exists(resolver(_, colName))
+      // Every branch needs the table field; `.get` fails if `fields` lacks the column.
+      val field = fields.find(_.name() == colName).get
+      if (isTarget) {
+        val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`())
+        // Use the matching source column when present, otherwise a NULL of the target type.
+        val value = sourceNames.find(resolver(_, colName)).map(s => col(s)).getOrElse(lit(null))
+        value.cast(sparkType).as(colName)
+      } else {
+        resolveDefaultColumn(field, colName)
+      }
+    }
+
+    // Output columns follow the order of `writableColumns`.
+    rawDf.select(writableColumns.map(projection): _*)
+  }
+
+  /**
+   * Build the final DataFrame for text-based import. Handles column renaming (CSV positional →
+   * named), default value filling for unmapped columns, and preserves extra columns like file name.
+   */
+  def buildFinalDataFrame(
+      sourceDf: DataFrame,
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField],
+      extraCols: Seq[String] = Seq.empty): DataFrame = {
+    // JSON is read with named columns already; every other format arrives as positional
+    // `_c0`, `_c1`, ... columns that are renamed to the target column at the same position.
+    val namedDf = fileFormat.formatType match {
+      case FileFormatType.JSON => sourceDf
+      case _ =>
+        targetColumns.zipWithIndex.foldLeft(sourceDf) {
+          (acc, entry) =>
+            val (targetCol, position) = entry
+            acc.withColumnRenamed(s"_c$position", targetCol)
+        }
+    }
+
+    columns match {
+      // No explicit column list: the renamed frame is returned as is.
+      case None => namedDf
+      case Some(_) =>
+        // Project in table order, filling columns the user did not list with their default.
+        val projected: Seq[Column] = writableColumns.map {
+          colName =>
+            if (targetColumns.contains(colName)) col(colName)
+            else resolveDefaultColumn(fields.find(_.name() == colName).get, colName)
+        }
+        namedDf.select((projected ++ extraCols.map(col)): _*)
+    }
+  }
+
+  /**
+   * Builds the all-string read schema for text-based formats (CSV/JSON): target column names for
+   * JSON, positional `_c<i>` names otherwise. Every field is a nullable string.
+   */
+  def buildStringSchema(targetColumns: Seq[String]): StructType = {
+    val fieldNames: Seq[String] = fileFormat.formatType match {
+      case FileFormatType.JSON => targetColumns
+      case _ => (0 until targetColumns.size).map(i => s"_c$i")
+    }
+    StructType(fieldNames.map(name => StructField(name, StringType, nullable = true)))
+  }
+
+  /**
+   * Reads the source files as strings using `stringSchema` (JSON reader for JSON, CSV reader
+   * otherwise) and applies the NULL transforms to every column that was read.
+   */
+  def readSourceData(
+      filePaths: Array[String],
+      stringSchema: StructType,
+      readerOptions: Map[String, String]): DataFrame = {
+    val reader = spark.read.options(readerOptions).schema(stringSchema)
+    val loaded = fileFormat.formatType match {
+      case FileFormatType.JSON => reader.json(filePaths: _*)
+      case _ => reader.csv(filePaths: _*)
+    }
+
+    applyNullTransforms(loaded, loaded.columns)
+  }
+
+  /** Apply NULL_IF and EMPTY_FIELD_AS_NULL transforms to the specified columns.
*/
+  def applyNullTransforms(df: DataFrame, columns: Seq[String]): DataFrame = {
+    // Typed NULL so every rewritten column stays a string column.
+    val nullString = lit(null).cast(StringType)
+
+    // Rewrites each listed column in turn, replacing values matched by `shouldNull` with NULL.
+    def nullOut(input: DataFrame)(shouldNull: Column => Column): DataFrame =
+      columns.foldLeft(input) {
+        (acc, colName) =>
+          acc.withColumn(
+            colName,
+            when(shouldNull(col(colName)), nullString).otherwise(col(colName)))
+      }
+
+    // NULL_IF is applied first, EMPTY_FIELD_AS_NULL second; each pass is skipped when disabled.
+    val nullIfVals = fileFormat.nullIfValues
+    val afterNullIf =
+      if (nullIfVals.nonEmpty) nullOut(df)(c => c.isin(nullIfVals: _*)) else df
+
+    if (fileFormat.emptyFieldAsNull) nullOut(afterNullIf)(c => c === lit("")) else afterNullIf
+  }
+
+  /**
+   * Resolve the default value expression for a column not populated from source data.
+   *
+   * The result is always cast to the column's target Spark type and aliased to `colName`:
+   *   - a declared default is parsed as a SQL expression;
+   *   - a default that fails to parse is logged and replaced by NULL (best effort);
+   *   - no default yields NULL.
+   *
+   * @param field target field definition supplying the type and the optional default value
+   * @param colName output column name
+   */
+  private def resolveDefaultColumn(field: DataField, colName: String): Column = {
+    val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`())
+    // Typed NULL: an untyped `lit(null)` would produce a NullType column, unlike every other
+    // branch that builds a column for the target table.
+    val typedNull = lit(null).cast(sparkType).as(colName)
+
+    Option(field.defaultValue()) match {
+      case None => typedNull
+      case Some(defaultVal) =>
+        try {
+          val parsed = spark.sessionState.sqlParser.parseExpression(defaultVal)
+          SparkShimLoader.shim.classicApi.column(parsed).cast(sparkType).as(colName)
+        } catch {
+          case e: Exception =>
+            logWarning(
+              s"Failed to parse default value '$defaultVal' for column '$colName'; using NULL instead.",
+              e)
+            typedNull
+        }
+    }
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoErrorHandler.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoErrorHandler.scala
new file mode 100644
index 000000000000..466134900f52
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoErrorHandler.scala
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.execution + +import org.apache.paimon.spark.catalyst.plans.logical.FileFormatType +import org.apache.paimon.spark.copyinto.{CopyIntoResultBuilder, CopyLoadHistoryManager, CopyLoadRecord} +import org.apache.paimon.table.FileStoreTable +import org.apache.paimon.types.DataField + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +/** + * Handles error detection and result building for COPY INTO operations. Supports both row-level + * (CONTINUE) and file-level (SKIP_FILE) error handling. + */ +private[execution] class CopyIntoErrorHandler( + spark: org.apache.spark.sql.SparkSession, + castValidator: CopyIntoCastValidator, + dataFrameBuilder: CopyIntoDataFrameBuilder) { + + /** Detect cast errors in Parquet files. Returns error statistics and good rows DataFrame. 
*/
+  def detectParquetErrors(
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField])(rawDfWithFile: DataFrame, fileCol: String): ErrorDetectionResult = {
+    // The second parameter list has the `(DataFrame, String) => ErrorDetectionResult` shape, so a
+    // call that supplies only the first list can be passed on as a detection callback.
+
+    // Totals are taken on the unfiltered input, so they include rows that are rejected below.
+    val totalRowsPerFile = CopyIntoUtils.countPerFile(rawDfWithFile, fileCol)
+
+    val castResult =
+      filterParquetCastErrors(rawDfWithFile, targetColumns, writableColumns, fields, fileCol)
+
+    ErrorDetectionResult(
+      totalRowsPerFile = totalRowsPerFile,
+      parseErrors = Map.empty, // Parquet has no parse errors
+      castErrors = castResult.errorsPerFile,
+      firstParseError = Map.empty,
+      firstCastError = castResult.firstErrorPerFile,
+      goodRowsDf = castResult.df
+    )
+  }
+
+  /**
+   * Detect parse and cast errors in text files (CSV/JSON). Returns error statistics and good rows
+   * DataFrame (already processed and cast).
+   *
+   * @param stringSchema schema whose field names are the columns the NULL transforms apply to
+   * @param corruptCol name of the column that is non-null for rows that could not be parsed
+   * @param rawDfWithFile raw rows, including `corruptCol` and the per-row file column
+   * @param fileCol name of the column holding each row's source file
+   */
+  def detectTextErrors(
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField],
+      stringSchema: StructType,
+      corruptCol: String)(rawDfWithFile: DataFrame, fileCol: String): ErrorDetectionResult = {
+
+    // Totals are taken on the unfiltered input, so they include corrupt and badly typed rows.
+    val totalRowsPerFile = CopyIntoUtils.countPerFile(rawDfWithFile, fileCol)
+
+    // Separate corrupt rows (parse errors) from valid rows
+    val corruptDf = rawDfWithFile.filter(col(corruptCol).isNotNull)
+    val validDf = rawDfWithFile.filter(col(corruptCol).isNull).drop(corruptCol)
+
+    val parseErrors = CopyIntoUtils.countPerFile(corruptDf, fileCol)
+    // Only collect samples when something failed to parse.
+    val firstParseError = if (parseErrors.nonEmpty) {
+      // `dropDuplicates(fileCol)` leaves one corrupt row per file; it is a sample and not
+      // necessarily the first bad record in file order.
+      val samplesPerFile = corruptDf
+        .select(col(fileCol), col(corruptCol))
+        .dropDuplicates(fileCol)
+        .collect()
+      samplesPerFile.map {
+        row =>
+          // Index 0 is the file column, index 1 the corrupt-record text (see the select above).
+          CopyIntoUtils.extractBaseName(
+            row.getString(0)) -> s"Malformed record: ${row.getString(1)}"
+      }.toMap
+    } else Map.empty[String, String]
+
+    // Process valid rows: apply null transforms, build final DataFrame, detect cast errors
+    val processedDf = dataFrameBuilder.applyNullTransforms(validDf, stringSchema.fieldNames)
+    // `fileCol` is passed as an extra column so it survives the projection and stays available
+    // for per-file error accounting.
+    val finalDfWithFile = dataFrameBuilder.buildFinalDataFrame(
+      processedDf,
+      targetColumns,
+      writableColumns,
+      fields,
+      extraCols = Seq(fileCol))
+
+    val castResult = castAndFilterErrors(finalDfWithFile, writableColumns, fields, fileCol)
+
+    ErrorDetectionResult(
+      totalRowsPerFile = totalRowsPerFile,
+      parseErrors = parseErrors,
+      castErrors = castResult.errorsPerFile,
+      firstParseError = firstParseError,
+      firstCastError = castResult.firstErrorPerFile,
+      goodRowsDf = castResult.df
+    )
+  }
+
+  /**
+   * Build results for error-tolerant modes (CONTINUE/SKIP_FILE). Handles both row-level and
+   * file-level error reporting.
+   *
+   * @param filesToLoad files that were read in this statement
+   * @param skippedFiles files excluded before reading; reported through `buildSkippedResults`
+   *   in the file-level branch and handed to `buildContinueResults` in the row-level branch
+   * @param errorResult per-file statistics produced by error detection
+   * @param filesWithErrors base names of the files that had at least one error; only consulted
+   *   in the file-level branch
+   * @param errorGranularity selects row-level (CONTINUE) or file-level (SKIP_FILE) reporting
+   */
+  def buildErrorTolerantResults(
+      paimonTable: FileStoreTable,
+      filesToLoad: Array[FileStatus],
+      skippedFiles: Array[FileStatus],
+      errorResult: ErrorDetectionResult,
+      filesWithErrors: Set[String],
+      errorGranularity: ErrorGranularity): Seq[InternalRow] = {
+
+    errorGranularity match {
+      case ErrorGranularity.RowLevel =>
+        // CONTINUE mode: use existing buildContinueResults
+        CopyIntoResultBuilder.buildContinueResults(
+          paimonTable,
+          filesToLoad,
+          skippedFiles,
+          errorResult.totalRowsPerFile,
+          errorResult.parseErrors,
+          errorResult.castErrors,
+          errorResult.firstParseError,
+          errorResult.firstCastError
+        )
+
+      case ErrorGranularity.FileLevel =>
+        // SKIP_FILE mode: files are either fully loaded or fully failed
+        val paimonPath = new org.apache.paimon.fs.Path(paimonTable.location().toString)
+        val historyManager = new CopyLoadHistoryManager(paimonTable.fileIO(), paimonPath)
+        // Read once, before the loop: every record written below carries the same snapshot id
+        // and timestamp.
+        val snapshotId = paimonTable.snapshotManager().latestSnapshotId()
+        val loadedAt = System.currentTimeMillis()
+
+        val results = filesToLoad.map {
+          fileStatus =>
+            // The error maps are looked up by base name; history records use the full path.
+            val baseName = fileStatus.getPath.getName
+            val fullPath = fileStatus.getPath.toString
+
+            if (filesWithErrors.contains(baseName)) {
+              // File had errors, mark as LOAD_FAILED
+              val parseErrorCount = errorResult.parseErrors.getOrElse(baseName, 0L)
+              val castErrorCount = errorResult.castErrors.getOrElse(baseName, 0L)
+              val totalErrors = parseErrorCount + castErrorCount
+              // A parse error message takes precedence over a cast error message.
+              val firstError = errorResult.firstParseError
+                .get(baseName)
+                .orElse(errorResult.firstCastError.get(baseName))
+
+              // Third argument 0L: a failed file is reported with zero loaded rows, and no
+              // history record is written for it.
+              CopyIntoResultBuilder.buildResultRow(
+                baseName,
+                "LOAD_FAILED",
+                0L,
+                errorResult.totalRowsPerFile.getOrElse(baseName, 0L),
+                totalErrors,
+                firstError.orNull)
+            } else {
+              // File had no errors, mark as LOADED
+              val rowCount = errorResult.totalRowsPerFile.getOrElse(baseName, 0L)
+              historyManager.recordLoaded(
+                CopyLoadRecord(
+                  filePath = fullPath,
+                  fileSize = fileStatus.getLen,
+                  lastModified = fileStatus.getModificationTime,
+                  loadedAt = loadedAt,
+                  snapshotId = snapshotId,
+                  rowsLoaded = rowCount
+                ))
+
+              CopyIntoResultBuilder.buildResultRow(baseName, "LOADED", rowCount, rowCount, 0L, null)
+            }
+        }.toSeq
+
+        results ++ CopyIntoResultBuilder.buildSkippedResults(skippedFiles)
+    }
+  }
+
+  /**
+   * Filter out rows with cast errors and collect per-file error stats. Shared by the Parquet and
+   * text error-tolerant paths.
+   *
+   * @param setup validation DataFrame, source -> cast-result column mapping and bad-row predicate
+   * @param fileCol name of the column holding each row's source file
+   * @return the good rows (validation columns removed) with per-file error counts and one sample
+   *   message per file; when `setup` has no predicate, the validation DataFrame unchanged and
+   *   empty statistics
+   */
+  private def filterCastErrors(setup: CastValidationSetup, fileCol: String): CastFilterResult = {
+    setup.badCastFilter match {
+      case Some(filter) =>
+        val badRowsDf = setup.validationDf.filter(filter)
+
+        // Count errors and sample first error per file
+        // (`dropDuplicates` keeps one bad row per file; it is a sample, not necessarily the
+        // first bad row in file order.)
+        val sampleRows = badRowsDf.dropDuplicates(fileCol).collect()
+        val errorsPerFile = CopyIntoUtils.countPerFile(badRowsDf, fileCol)
+
+        val firstErrorPerFile = if (sampleRows.nonEmpty) {
+          sampleRows.map {
+            sampleRow =>
+              val fileName =
+                CopyIntoUtils.extractBaseName(sampleRow.getString(sampleRow.fieldIndex(fileCol)))
+              // The reported column is the first mapping entry whose source value is non-null
+              // while its cast result is null.
+              val example = setup.castColMapping.find {
+                case (src, dst) =>
+                  val srcIdx = setup.validationDf.schema.fieldIndex(src)
+                  val dstIdx = setup.validationDf.schema.fieldIndex(dst)
+                  !sampleRow.isNullAt(srcIdx) && sampleRow.isNullAt(dstIdx)
+              }
+              fileName -> s"Cast failure in column '${example.map(_._1).getOrElse("unknown")}'. Source data contains values that cannot be converted to the target type."
+          }.toMap
+        } else Map.empty[String, String]
+
+        // Keep good rows and drop validation temp columns
+        // NOTE(review): a row for which the predicate evaluates to NULL would match neither
+        // `filter` nor `!filter`; this relies on the predicate never being NULL - confirm against
+        // how the validator builds it.
+        var goodDf = setup.validationDf.filter(!filter)
+        setup.castColMapping.values.foreach(c => goodDf = goodDf.drop(c))
+
+        CastFilterResult(goodDf, errorsPerFile, firstErrorPerFile)
+      case None =>
+        CastFilterResult(setup.validationDf, Map.empty, Map.empty)
+    }
+  }
+
+  /**
+   * Parquet variant: builds the cast validation setup on the raw rows (the file column is passed
+   * as an excluded column) and filters out the rows that fail it.
+   */
+  private def filterParquetCastErrors(
+      rawDfWithFile: DataFrame,
+      targetColumns: Seq[String],
+      writableColumns: Seq[String],
+      fields: Seq[DataField],
+      fileCol: String): CastFilterResult = {
+    val setup =
+      castValidator.buildParquetCastValidation(
+        rawDfWithFile,
+        targetColumns,
+        writableColumns,
+        fields,
+        excludeCols = Set(fileCol))
+    filterCastErrors(setup, fileCol)
+  }
+
+  /**
+   * Text variant: filters out the rows that fail cast validation, then casts the writable columns
+   * of the remaining rows to their target types. Error statistics are passed through unchanged.
+   */
+  private def castAndFilterErrors(
+      dfWithFile: DataFrame,
+      writableColumns: Seq[String],
+      fields: Seq[DataField],
+      fileCol: String): CastFilterResult = {
+    val setup = castValidator.buildTextCastValidation(dfWithFile, writableColumns, fields)
+    val result = filterCastErrors(setup, fileCol)
+    // Apply final cast to target types for good rows
+    CastFilterResult(
+      castValidator.castColumns(result.df, writableColumns, fields),
+      result.errorsPerFile,
+      result.firstErrorPerFile)
+  }
+}
+
+/**
+ * Outcome of cast-error filtering.
+ *
+ * @param df rows that passed cast validation
+ * @param errorsPerFile number of rows with cast errors per file
+ * @param firstErrorPerFile one sample cast error message per file
+ */
+private[execution] case class CastFilterResult(
+    df: DataFrame,
+    errorsPerFile: Map[String, Long],
+    firstErrorPerFile: Map[String, String])
+
+/** Error granularity for error-tolerant modes (CONTINUE/SKIP_FILE). */
+sealed private[execution] trait ErrorGranularity
+private[execution] object ErrorGranularity {
+  // Row-level reporting, used for CONTINUE.
+  case object RowLevel extends ErrorGranularity
+  // File-level reporting, used for SKIP_FILE.
+  case object FileLevel extends ErrorGranularity
+}
+
+/**
+ * Unified error detection result for both Parquet and text formats. The per-file maps are looked
+ * up by file base name when results are built.
+ *
+ * @param totalRowsPerFile row count per file, taken before any row is filtered out
+ * @param parseErrors number of unparseable rows per file (empty for Parquet)
+ * @param castErrors number of rows with cast errors per file
+ * @param firstParseError one sample parse error message per file (empty for Parquet)
+ * @param firstCastError one sample cast error message per file
+ * @param goodRowsDf rows that had neither a parse error nor a cast error
+ */
+private[execution] case class ErrorDetectionResult(
+    totalRowsPerFile: Map[String, Long],
+    parseErrors: Map[String, Long],
+    castErrors: Map[String, Long],
+    firstParseError: Map[String, String],
+    firstCastError: Map[String, String],
+    goodRowsDf: DataFrame)
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoHelper.scala
new file mode 100644
index 000000000000..c00ac34f7bd7
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoHelper.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.execution
+
+import org.apache.paimon.spark.copyinto.CopyLoadHistoryManager
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.types.DataField
+
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Utility methods for COPY INTO operations. Provides helper functions for column resolution, file
+ * filtering, and validation.
+ */
+private[execution] object CopyIntoHelper {
+
+  /**
+   * Maps the optional user-supplied column list onto the table's writable column names. Names are
+   * compared with the session's column resolver. Without a column list, all writable columns are
+   * returned unchanged.
+   *
+   * @return the matching entries of `writableColumns`, in the order the user listed them
+   * @throws IllegalArgumentException if the list repeats a column or names a column that is not
+   *   writable
+   */
+  def resolveTargetColumns(
+      spark: SparkSession,
+      columns: Option[Seq[String]],
+      writableColumns: Seq[String]): Seq[String] = {
+    columns match {
+      case None => writableColumns
+      case Some(requested) =>
+        val resolver = spark.sessionState.conf.resolver
+
+        // Compare every pair once (i < j); the earlier name of the first clashing pair is reported.
+        for {
+          i <- requested.indices
+          j <- requested.indices.drop(i + 1)
+          if resolver(requested(i), requested(j))
+        } throw new IllegalArgumentException(
+          s"Duplicate columns in column list: ${requested(i)}")
+
+        requested.map {
+          name =>
+            writableColumns
+              .find(resolver(_, name))
+              .getOrElse(
+                throw new IllegalArgumentException(
+                  s"Column '$name' does not exist in target table. Available columns: ${writableColumns.mkString(", ")}"))
+        }
+    }
+  }
+
+  /**
+   * Checks that every writable column left out of the target column list can still be filled,
+   * i.e. is nullable or has a default value. Nothing is checked when no column list was given.
+   *
+   * @throws IllegalArgumentException for the first omitted column that is non-nullable and has no
+   *   default value
+   */
+  def validateNonNullableDefaults(
+      columns: Option[Seq[String]],
+      writableColumns: Seq[String],
+      targetColumns: Seq[String],
+      fields: Seq[DataField]): Unit = {
+    if (columns.isDefined) {
+      for (colName <- writableColumns if !targetColumns.contains(colName)) {
+        val field = fields.find(_.name() == colName).get
+        val mustBeProvided = !field.`type`().isNullable && field.defaultValue() == null
+        if (mustBeProvided) {
+          throw new IllegalArgumentException(
+            s"Non-nullable column '$colName' is not in the column list and has no default value")
+        }
+      }
+    }
+  }
+
+  /**
+   * List files from source path and filter based on pattern and load history. Returns (files to
+   * load, files to skip).
+   */
+  def listAndFilterFiles(
+      spark: SparkSession,
+      paimonTable: FileStoreTable,
+      sourcePath: String,
+      pattern: Option[String],
+      force: Boolean): (Array[FileStatus], Array[FileStatus]) = {
+    val hadoopConf = spark.sessionState.newHadoopConf()
+    val fsPath = new Path(sourcePath)
+    // Only entries returned by `listStatus` that are regular files are considered.
+    val candidates = fsPath.getFileSystem(hadoopConf).listStatus(fsPath).filter(_.isFile)
+
+    // The pattern is a regex searched for anywhere in the file name (it is not anchored).
+    val matching = pattern match {
+      case None => candidates
+      case Some(p) =>
+        val nameRegex = p.r
+        candidates.filter(f => nameRegex.findFirstIn(f.getPath.getName).isDefined)
+    }
+
+    if (matching.isEmpty) {
+      // Nothing matched: the load history is not consulted at all.
+      (Array.empty[FileStatus], Array.empty[FileStatus])
+    } else {
+      val paimonPath = new org.apache.paimon.fs.Path(paimonTable.location().toString)
+      val historyManager = new CopyLoadHistoryManager(paimonTable.fileIO(), paimonPath)
+
+      if (force) {
+        // FORCE loads every matching file regardless of history.
+        (matching, Array.empty[FileStatus])
+      } else {
+        // The history check receives each file's path, size and modification time.
+        val (alreadyLoaded, pending) = matching.partition {
+          f => historyManager.isLoaded(f.getPath.toString, f.getLen, f.getModificationTime)
+        }
+        (pending, alreadyLoaded)
+      }
+    }
+  }
+
+  /**
+   * Generate a safe temporary column name that doesn't conflict with existing columns.
+   * Conflicts are detected with the session resolver (case-insensitive unless
+   * `spark.sql.caseSensitive` is enabled).
+   */
+  def safeTempCol(spark: SparkSession, baseName: String, existingColumns: Set[String]): String = {
+    val resolver = spark.sessionState.conf.resolver
+    def clashes(name: String): Boolean = existingColumns.exists(c => resolver(c, name))
+    // Candidates are baseName, "_" + baseName, "__" + baseName, ...; take the first free one.
+    Iterator.iterate(baseName)(name => "_" + name).dropWhile(clashes).next()
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoLocationExec.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoLocationExec.scala
index f4f0720b289a..45471bab8458 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoLocationExec.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoLocationExec.scala
@@ -18,7 +18,7 @@
 
 package org.apache.paimon.spark.execution
 
-import org.apache.paimon.spark.catalyst.plans.logical.CopyFileFormat
+import org.apache.paimon.spark.catalyst.plans.logical.{CopyFileFormat, FileFormatType}
 import org.apache.paimon.spark.leafnode.PaimonLeafV2CommandExec
 
 import org.apache.hadoop.fs.Path
@@ -51,7 +51,14 @@ case class CopyIntoLocationExec(
     val writerOptions = fileFormat.toSparkWriterOptions
     val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists
 
-    df.write.options(writerOptions).mode(saveMode).csv(targetPath)
+    fileFormat.formatType match {
+      case FileFormatType.JSON =>
+        df.write.options(writerOptions).mode(saveMode).json(targetPath)
+      case FileFormatType.PARQUET =>
+        df.write.options(writerOptions).mode(saveMode).parquet(targetPath)
+      case _ =>
+        df.write.options(writerOptions).mode(saveMode).csv(targetPath)
+    }
 
     val hadoopConf = spark.sessionState.newHadoopConf()
     val fsPath = new Path(targetPath)
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoTableExec.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoTableExec.scala
index
260add434969..4701262f13d1 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoTableExec.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoTableExec.scala @@ -19,24 +19,22 @@ package org.apache.paimon.spark.execution import org.apache.paimon.spark.SparkTable -import org.apache.paimon.spark.catalyst.plans.logical.CopyFileFormat -import org.apache.paimon.spark.copyinto.{CopyLoadHistoryManager, CopyLoadRecord} +import org.apache.paimon.spark.catalyst.plans.logical.{CopyFileFormat, FileFormatType, OnErrorMode} +import org.apache.paimon.spark.copyinto.CopyIntoResultBuilder import org.apache.paimon.spark.leafnode.PaimonLeafV2CommandExec import org.apache.paimon.table.FileStoreTable import org.apache.paimon.types.DataField -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} -import org.apache.spark.sql.functions.{col, input_file_name, lit, when} -import org.apache.spark.sql.paimon.shims.SparkShimLoader -import org.apache.spark.sql.types.{StringType, StructField, StructType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.functions.{col, input_file_name, substring_index} +import org.apache.spark.sql.types.{StringType, StructField} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer case class CopyIntoTableExec( spark: SparkSession, @@ -47,8 +45,15 @@ case class CopyIntoTableExec( fileFormat: CopyFileFormat, pattern: Option[String], force: Boolean, + onError: OnErrorMode, out: Seq[Attribute]) - extends PaimonLeafV2CommandExec { + extends 
PaimonLeafV2CommandExec + with Logging { + + // Initialize helper classes + private val castValidator = new CopyIntoCastValidator(spark) + private val dataFrameBuilder = new CopyIntoDataFrameBuilder(spark, fileFormat, columns) + private val errorHandler = new CopyIntoErrorHandler(spark, castValidator, dataFrameBuilder) override def output: Seq[Attribute] = out @@ -61,284 +66,310 @@ case class CopyIntoTableExec( val tableSchema = paimonTable.schema() val writableColumns = tableSchema.fieldNames().asScala.toSeq val fields = tableSchema.fields().asScala.toSeq - val targetColumns = resolveTargetColumns(writableColumns) + val targetColumns = CopyIntoHelper.resolveTargetColumns(spark, columns, writableColumns) - validateNonNullableDefaults(writableColumns, targetColumns, fields) + CopyIntoHelper.validateNonNullableDefaults(columns, writableColumns, targetColumns, fields) - val (filesToLoad, skippedFiles) = listAndFilterFiles(paimonTable) + val (filesToLoad, skippedFiles) = + CopyIntoHelper.listAndFilterFiles(spark, paimonTable, sourcePath, pattern, force) if (filesToLoad.isEmpty) { - return buildSkippedResults(skippedFiles) + return CopyIntoResultBuilder.buildSkippedResults(skippedFiles) } val filePaths = filesToLoad.map(_.getPath.toString) - val stringSchema = StructType( - (0 until targetColumns.size).map(i => StructField(s"_c$i", StringType, nullable = true))) val readerOptions = fileFormat.toSparkReaderOptions - val csvDf = readAndProcessCsv(filePaths, stringSchema, readerOptions) - val finalDf = - buildFinalDataFrame(csvDf, targetColumns, writableColumns, fields) - val castedDf = castAndValidate(finalDf, writableColumns, fields) - - val tableName = CopyIntoUtils.quoteIdentifier(catalog.name(), ident) - castedDf.write.format("paimon").mode("append").insertInto(tableName) - - recordHistoryAndBuildResults( - paimonTable, - filesToLoad, - skippedFiles, - filePaths, - stringSchema, - readerOptions) - } - - private def resolveTargetColumns(writableColumns: Seq[String]): 
Seq[String] = { - columns match { - case Some(cols) => - val resolver = spark.sessionState.conf.resolver - cols.indices.foreach { - i => - cols.indices.filter(_ > i).foreach { - j => - if (resolver(cols(i), cols(j))) { - throw new IllegalArgumentException( - s"Duplicate columns in column list: ${cols(i)}") - } - } - } - cols.map { - c => - writableColumns.find(w => resolver(w, c)).getOrElse { - throw new IllegalArgumentException( - s"Column '$c' does not exist in target table. Available columns: ${writableColumns.mkString(", ")}") - } - } - case None => writableColumns + fileFormat.formatType match { + case FileFormatType.PARQUET => + runParquetImport( + paimonTable, + filePaths, + targetColumns, + writableColumns, + fields, + filesToLoad, + skippedFiles, + readerOptions) + case _ => + runTextImport( + paimonTable, + filePaths, + targetColumns, + writableColumns, + fields, + filesToLoad, + skippedFiles, + readerOptions) } } - private def validateNonNullableDefaults( - writableColumns: Seq[String], + /** + * Unified error-tolerant import for both CONTINUE and SKIP_FILE modes. Both start from + * `goodRowsDf`, the row-clean DataFrame produced by error detection (bad rows already removed + * and, for text formats, already cast to target types). + * - CONTINUE: writes every good row, so a file with errors is partially loaded. + * - SKIP_FILE: additionally drops every row belonging to a file that had any error, so such a + * file is loaded all-or-nothing. + * + * Both modes use a single batch write (one commit) regardless of file count. 
+ */ + private def runErrorTolerantMode( + paimonTable: FileStoreTable, + rawDf: DataFrame, targetColumns: Seq[String], - fields: Seq[DataField]): Unit = { - if (columns.isEmpty) return - val unmapped = writableColumns.filterNot(targetColumns.contains) - unmapped.foreach { - colName => - val field = fields.find(_.name() == colName).get - if (!field.`type`().isNullable && field.defaultValue() == null) { - throw new IllegalArgumentException( - s"Non-nullable column '$colName' is not in the column list and has no default value") - } - } - } + writableColumns: Seq[String], + fields: Seq[DataField], + filesToLoad: Array[FileStatus], + skippedFiles: Array[FileStatus], + errorGranularity: ErrorGranularity, + detectErrors: (DataFrame, String) => ErrorDetectionResult): Seq[InternalRow] = { + + val allTargetCols = writableColumns.toSet ++ targetColumns.toSet + val inputFileCol = CopyIntoHelper.safeTempCol(spark, "__input_file", allTargetCols) + val rawDfWithFile = rawDf.withColumn(inputFileCol, input_file_name()).cache() + + try { + val errorResult = detectErrors(rawDfWithFile, inputFileCol) + + // `goodRowsDf` carries `inputFileCol` (full path); error keys are base names, so compare on + // the base name extracted from the path. + val filesWithErrors = errorResult.parseErrors.keySet ++ errorResult.castErrors.keySet + val dfToWrite = errorGranularity match { + case ErrorGranularity.RowLevel => + // CONTINUE: keep all good rows, including good rows from files that also had errors. + errorResult.goodRowsDf + + case ErrorGranularity.FileLevel => + // SKIP_FILE: drop every row whose file had any error. 
+ if (filesWithErrors.nonEmpty) { + val baseName = substring_index(col(inputFileCol), "/", -1) + errorResult.goodRowsDf.filter(!baseName.isin(filesWithErrors.toSeq: _*)) + } else { + errorResult.goodRowsDf + } + } - private def listAndFilterFiles( - paimonTable: FileStoreTable): (Array[FileStatus], Array[FileStatus]) = { - val hadoopConf = spark.sessionState.newHadoopConf() - val fsPath = new Path(sourcePath) - val fs = fsPath.getFileSystem(hadoopConf) - val allFiles = fs.listStatus(fsPath).filter(_.isFile) - - val patternFiltered = pattern match { - case Some(p) => - val regex = p.r - allFiles.filter(f => regex.findFirstIn(f.getPath.getName).isDefined) - case None => allFiles - } + val tableName = CopyIntoUtils.quoteIdentifier(catalog.name(), ident) + val finalDf = fileFormat.formatType match { + case FileFormatType.PARQUET => + dataFrameBuilder.buildParquetDataFrame(dfToWrite, targetColumns, writableColumns, fields) + case _ => + // For text formats, goodRowsDf is already processed and cast. + dfToWrite + } - if (patternFiltered.isEmpty) { - return (Array.empty, Array.empty) - } + finalDf.drop(inputFileCol).write.format("paimon").mode("append").insertInto(tableName) - val paimonPath = new org.apache.paimon.fs.Path(paimonTable.location().toString) - val historyManager = new CopyLoadHistoryManager(paimonTable.fileIO(), paimonPath) + errorHandler.buildErrorTolerantResults( + paimonTable, + filesToLoad, + skippedFiles, + errorResult, + filesWithErrors, + errorGranularity) - if (!force) { - val (skip, load) = patternFiltered.partition { - f => historyManager.isLoaded(f.getPath.toString, f.getLen, f.getModificationTime) - } - (load, skip) - } else { - (patternFiltered, Array.empty[FileStatus]) + } finally { + rawDfWithFile.unpersist() } } - private def readAndProcessCsv( + /** + * Import files in ABORT mode: read, validate, write, and return per-file row counts. 
Validation + * aborts the whole statement on the first parse or cast error, so either all files are written or + * none are. + */ + private def importAbort( filePaths: Array[String], - stringSchema: StructType, - readerOptions: Map[String, String]): DataFrame = { - var df = spark.read - .options(readerOptions) - .schema(stringSchema) - .csv(filePaths: _*) - - val nullIfVals = fileFormat.nullIfValues - if (nullIfVals.nonEmpty) { - df.columns.foreach { - colName => - df = df.withColumn( - colName, - when(col(colName).isin(nullIfVals: _*), lit(null).cast(StringType)) - .otherwise(col(colName))) - } - } - - if (fileFormat.emptyFieldAsNull) { - df.columns.foreach { - colName => - df = df.withColumn( - colName, - when(col(colName) === lit(""), lit(null).cast(StringType)) - .otherwise(col(colName))) - } + targetColumns: Seq[String], + writableColumns: Seq[String], + fields: Seq[DataField], + readerOptions: Map[String, String], + tableName: String): Map[String, Long] = { + val allTargetCols = writableColumns.toSet ++ targetColumns.toSet + val fileCol = CopyIntoHelper.safeTempCol(spark, "__file__", allTargetCols) + fileFormat.formatType match { + case FileFormatType.PARQUET => + val rawDf = spark.read.options(readerOptions).parquet(filePaths: _*) + castValidator.validateParquetCast(rawDf, targetColumns, writableColumns, fields) + val selectedDf = + dataFrameBuilder + .buildParquetDataFrame(rawDf, targetColumns, writableColumns, fields) + .withColumn(fileCol, input_file_name()) + .cache() + try { + val counts = CopyIntoUtils.countPerFile(selectedDf, fileCol) + selectedDf.drop(fileCol).write.format("paimon").mode("append").insertInto(tableName) + counts + } finally { + selectedDf.unpersist() + } + case _ => + val stringSchema = dataFrameBuilder.buildStringSchema(targetColumns) + val sourceDf = dataFrameBuilder.readSourceData(filePaths, stringSchema, readerOptions) + val finalDf = + dataFrameBuilder.buildFinalDataFrame(sourceDf, targetColumns, writableColumns, fields) + val 
castedDf = castValidator + .castAndValidate(finalDf, writableColumns, fields) + .withColumn(fileCol, input_file_name()) + .cache() + try { + val counts = CopyIntoUtils.countPerFile(castedDf, fileCol) + castedDf.drop(fileCol).write.format("paimon").mode("append").insertInto(tableName) + counts + } finally { + castedDf.unpersist() + } } - - df } - private def buildFinalDataFrame( - csvDf: DataFrame, + /** + * Parquet import pipeline. Unlike CSV/JSON which read as strings then cast, Parquet files already + * have typed columns, so the flow is: + * 1. Read source Parquet with native types + * 2. Project and cast columns to match target table schema (by column name, not position) + * 3. Validate that no non-null values become null after casting (detect type incompatibility) + * 4. Write to Paimon table + * 5. Record load history for idempotent re-runs (FORCE=FALSE dedup) + */ + private def runParquetImport( + paimonTable: FileStoreTable, + filePaths: Array[String], targetColumns: Seq[String], writableColumns: Seq[String], - fields: Seq[DataField]): DataFrame = { - val renamedDf = targetColumns.zipWithIndex.foldLeft(csvDf) { - case (df, (targetCol, idx)) => df.withColumnRenamed(s"_c$idx", targetCol) - } + fields: Seq[DataField], + filesToLoad: Array[FileStatus], + skippedFiles: Array[FileStatus], + readerOptions: Map[String, String]): Seq[InternalRow] = { + val rawDf = spark.read.options(readerOptions).parquet(filePaths: _*) + + onError match { + case OnErrorMode.Continue | OnErrorMode.SkipFile => + val errorGranularity = if (onError == OnErrorMode.Continue) { + ErrorGranularity.RowLevel + } else { + ErrorGranularity.FileLevel + } - if (columns.isDefined) { - val selectExprs: Seq[Column] = writableColumns.map { - colName => - if (targetColumns.contains(colName)) { - col(colName) - } else { - val field = fields.find(_.name() == colName).get - val defaultVal = field.defaultValue() - if (defaultVal != null) { - val sparkType = - 
org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`()) - try { - val parsed = spark.sessionState.sqlParser.parseExpression(defaultVal) - SparkShimLoader.shim.classicApi.column(parsed).cast(sparkType).as(colName) - } catch { - case _: Exception => lit(null).cast(sparkType).as(colName) - } - } else { - lit(null).as(colName) - } - } - } - renamedDf.select(selectExprs: _*) - } else { - renamedDf + runErrorTolerantMode( + paimonTable, + rawDf, + targetColumns, + writableColumns, + fields, + filesToLoad, + skippedFiles, + errorGranularity, + errorHandler.detectParquetErrors(targetColumns, writableColumns, fields) + ) + + case _ => + val tableName = CopyIntoUtils.quoteIdentifier(catalog.name(), ident) + val countsPerFile = + importAbort(filePaths, targetColumns, writableColumns, fields, readerOptions, tableName) + CopyIntoResultBuilder.recordHistoryAndBuildResultsDirect( + paimonTable, + filesToLoad, + skippedFiles, + countsPerFile) } } - private def castAndValidate( - finalDf: DataFrame, + /** + * Text-based (CSV/JSON) import pipeline. Reads all columns as strings first, then: + * 1. Rename positional columns (CSV) or keep named columns (JSON) + * 2. Fill unmapped columns with default values + * 3. Cast all string columns to target types with validation + * 4. Write to Paimon table + * 5. 
Record load history + */ + private def runTextImport( + paimonTable: FileStoreTable, + filePaths: Array[String], + targetColumns: Seq[String], writableColumns: Seq[String], - fields: Seq[DataField]): DataFrame = { - val nonStringCastCols = ArrayBuffer[String]() - var castedDf = finalDf - writableColumns.zip(fields).foreach { - case (colName, field) => - val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`()) - castedDf = castedDf.withColumn(colName, col(colName).cast(sparkType)) - if (sparkType != StringType) { - nonStringCastCols += colName + fields: Seq[DataField], + filesToLoad: Array[FileStatus], + skippedFiles: Array[FileStatus], + readerOptions: Map[String, String]): Seq[InternalRow] = { + val stringSchema = dataFrameBuilder.buildStringSchema(targetColumns) + + onError match { + case OnErrorMode.Continue | OnErrorMode.SkipFile => + val errorGranularity = if (onError == OnErrorMode.Continue) { + ErrorGranularity.RowLevel + } else { + ErrorGranularity.FileLevel } - } - if (nonStringCastCols.nonEmpty) { - val castSuffix = "__cv" - val validationDf = nonStringCastCols.zipWithIndex.foldLeft(finalDf) { - case (df, (colName, idx)) => - val field = fields.find(_.name() == colName).get - val sparkType = org.apache.paimon.spark.SparkTypeUtils.fromPaimonType(field.`type`()) - df.withColumn(castSuffix + idx, col(colName).cast(sparkType)) - } - val badCastFilter = nonStringCastCols.zipWithIndex - .map { case (cn, idx) => col(cn).isNotNull && col(castSuffix + idx).isNull } - .reduce(_ || _) - val badRows = validationDf.filter(badCastFilter).limit(1).collect() - if (badRows.nonEmpty) { - val example = nonStringCastCols.zipWithIndex.find { - case (cn, idx) => - val row = badRows(0) - val srcIdx = validationDf.schema.fieldIndex(cn) - val dstIdx = validationDf.schema.fieldIndex(castSuffix + idx) - !row.isNullAt(srcIdx) && row.isNullAt(dstIdx) + // For error-tolerant modes, we need to read with corrupt record tracking + val allTargetCols = 
writableColumns.toSet ++ targetColumns.toSet + val corruptCol = CopyIntoHelper.safeTempCol(spark, "_corrupt_record", allTargetCols) + val schemaWithCorrupt = + stringSchema.add(StructField(corruptCol, StringType, nullable = true)) + val corruptRecordOption = + Map("columnNameOfCorruptRecord" -> corruptCol, "mode" -> "PERMISSIVE") + + val rawDf = fileFormat.formatType match { + case FileFormatType.JSON => + spark.read + .options(readerOptions ++ corruptRecordOption) + .schema(schemaWithCorrupt) + .json(filePaths: _*) + case _ => + spark.read + .options(readerOptions ++ corruptRecordOption) + .schema(schemaWithCorrupt) + .csv(filePaths: _*) } - throw new IllegalArgumentException( - s"ON_ERROR = ABORT_STATEMENT: Cast failure in column '${example.map(_._1).getOrElse("unknown")}'. Source data contains values that cannot be converted to the target type.") - } - } - castedDf + runErrorTolerantMode( + paimonTable, + rawDf, + targetColumns, + writableColumns, + fields, + filesToLoad, + skippedFiles, + errorGranularity, + errorHandler.detectTextErrors( + targetColumns, + writableColumns, + fields, + stringSchema, + corruptCol) + ) + + case _ => + runTextImportAbort( + paimonTable, + filePaths, + targetColumns, + writableColumns, + fields, + filesToLoad, + skippedFiles, + readerOptions) + } } - private def recordHistoryAndBuildResults( + private def runTextImportAbort( paimonTable: FileStoreTable, + filePaths: Array[String], + targetColumns: Seq[String], + writableColumns: Seq[String], + fields: Seq[DataField], filesToLoad: Array[FileStatus], skippedFiles: Array[FileStatus], - filePaths: Array[String], - stringSchema: StructType, readerOptions: Map[String, String]): Seq[InternalRow] = { - val paimonPath = new org.apache.paimon.fs.Path(paimonTable.location().toString) - val historyManager = new CopyLoadHistoryManager(paimonTable.fileIO(), paimonPath) - val snapshotId = paimonTable.snapshotManager().latestSnapshotId() - val loadedAt = System.currentTimeMillis() - - val 
rowCounts = spark.read - .options(readerOptions) - .schema(stringSchema) - .csv(filePaths: _*) - .groupBy(input_file_name().as("file")) - .count() - .collect() - - val fileCountMap = rowCounts.map { - row => - val fullPath = row.getString(0) - val baseName = fullPath.substring(fullPath.lastIndexOf('/') + 1) - baseName -> row.getLong(1) - }.toMap - - val loadedResults = filesToLoad.map { - fileStatus => - val baseName = fileStatus.getPath.getName - val rowCount = fileCountMap.getOrElse(baseName, 0L) - - historyManager.recordLoaded( - CopyLoadRecord( - filePath = fileStatus.getPath.toString, - fileSize = fileStatus.getLen, - lastModified = fileStatus.getModificationTime, - loadedAt = loadedAt, - snapshotId = snapshotId, - rowsLoaded = rowCount - )) - - InternalRow( - UTF8String.fromString(baseName), - UTF8String.fromString("LOADED"), - rowCount, - rowCount) - }.toSeq - - val skippedResults = buildSkippedResults(skippedFiles) - loadedResults ++ skippedResults - } + val tableName = CopyIntoUtils.quoteIdentifier(catalog.name(), ident) + val countsPerFile = + importAbort(filePaths, targetColumns, writableColumns, fields, readerOptions, tableName) - private def buildSkippedResults(files: Array[FileStatus]): Seq[InternalRow] = { - files.map { - f => - InternalRow( - UTF8String.fromString(f.getPath.getName), - UTF8String.fromString("SKIPPED"), - 0L, - 0L) - }.toSeq + CopyIntoResultBuilder.recordHistoryAndBuildResultsDirect( + paimonTable, + filesToLoad, + skippedFiles, + countsPerFile) } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoUtils.scala index dd85af058d56..fdf76b8dc6c5 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/CopyIntoUtils.scala @@ -18,7 +18,9 @@ 
package org.apache.paimon.spark.execution +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.functions.col object CopyIntoUtils { @@ -28,4 +30,17 @@ object CopyIntoUtils { Seq(ident.name()) parts.filter(_.nonEmpty).map(p => s"`${p.replace("`", "``")}`").mkString(".") } + + def extractBaseName(fullPath: String): String = { + fullPath.substring(fullPath.lastIndexOf('/') + 1) + } + + /** Count rows per file, keyed by base file name. */ + def countPerFile(df: DataFrame, fileCol: String): Map[String, Long] = { + df.groupBy(col(fileCol)) + .count() + .collect() + .map(row => extractBaseName(row.getString(0)) -> row.getLong(1)) + .toMap + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala index 76bd6beb1559..6870ce8832a4 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala @@ -149,7 +149,7 @@ case class PaimonStrategy(spark: SparkSession) partitionPredicate: Option[PartitionPredicate]) => TruncatePaimonTableWithFilterExec(table, partitionPredicate) :: Nil - case c @ CopyIntoTableCommand(PaimonCatalogAndIdentifier(catalog, ident), _, _, _, _, _) => + case c @ CopyIntoTableCommand(PaimonCatalogAndIdentifier(catalog, ident), _, _, _, _, _, _) => CopyIntoTableExec( spark, catalog, @@ -159,6 +159,7 @@ case class PaimonStrategy(spark: SparkSession) c.fileFormat, c.pattern, c.force, + c.onError, c.output) :: Nil case c @ CopyIntoLocationCommand(_, PaimonCatalogAndIdentifier(catalog, ident), _, _) => diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/read/BinPackingSplits.scala 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/read/BinPackingSplits.scala index 24de25bce7f1..f2d06b45eaad 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/read/BinPackingSplits.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/read/BinPackingSplits.scala @@ -78,14 +78,18 @@ case class BinPackingSplits(coreOptions: CoreOptions, readRowSizeRatio: Double = def pack(splits: Array[Split]): Seq[PaimonInputPartition] = { val (toReshuffle, reserved) = splits.partition { case _: FallbackSplit => false - case split: DataSplit => split.rawConvertible() + case split: DataSplit => split.rawConvertible() || coreOptions.dataEvolutionEnabled() // Currently, format table reader only supports reading one file. case _: FormatDataSplit => false case _ => false } if (toReshuffle.nonEmpty) { val startTS = System.currentTimeMillis() - val reshuffled = packDataSplit(toReshuffle.collect { case ds: DataSplit => ds }) + val reshuffled = if (coreOptions.dataEvolutionEnabled()) { + packDataEvolutionSplit(toReshuffle.collect { case ds: DataSplit => ds }) + } else { + packDataSplit(toReshuffle.collect { case ds: DataSplit => ds }) + } val all = reserved.map(PaimonInputPartition.apply) ++ reshuffled val duration = System.currentTimeMillis() - startTS logInfo( @@ -156,6 +160,41 @@ case class BinPackingSplits(coreOptions: CoreOptions, readRowSizeRatio: Double = partitions.toArray } + private def packDataEvolutionSplit(splits: Array[DataSplit]): Array[PaimonInputPartition] = { + val maxSplitBytes = computeMaxSplitBytes(splits) + + var currentSize = 0L + val currentSplits = new ArrayBuffer[DataSplit] + val partitions = new ArrayBuffer[PaimonInputPartition] + + def closeInputPartition(): Unit = { + if (currentSplits.nonEmpty) { + partitions += PaimonInputPartition(currentSplits.toArray) + currentSplits.clear() + currentSize = 0L + } + } + + splits.foreach { + split => + val ddFiles = 
dataFileAndDeletionFiles(split) + val size = ddFiles.map { + case (dataFile, deletionFile) => + (dataFile.fileSize() * readRowSizeRatio).toLong + openCostInBytes + Option(deletionFile) + .map(_.length()) + .getOrElse(0L) + }.sum + if (currentSplits.nonEmpty && currentSize + size > maxSplitBytes) { + closeInputPartition() + } + currentSplits += split + currentSize += size + } + + closeInputPartition() + partitions.toArray + } + private def copyDataSplit( split: DataSplit, dataFiles: Seq[DataFileMeta], diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/rowops/PaimonSparkCopyOnWriteOperation.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/rowops/PaimonSparkCopyOnWriteOperation.scala index e1e2ddd4d9f1..24c3c19761c2 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/rowops/PaimonSparkCopyOnWriteOperation.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/rowops/PaimonSparkCopyOnWriteOperation.scala @@ -20,7 +20,7 @@ package org.apache.paimon.spark.rowops import org.apache.paimon.options.Options import org.apache.paimon.spark.PaimonBaseScanBuilder -import org.apache.paimon.spark.schema.PaimonMetadataColumn.FILE_PATH_COLUMN +import org.apache.paimon.spark.schema.PaimonMetadataColumn.{FILE_PATH_COLUMN, ROW_ID_COLUMN, SEQUENCE_NUMBER_COLUMN} import org.apache.paimon.spark.write.PaimonV2WriteBuilder import org.apache.paimon.table.FileStoreTable @@ -57,6 +57,11 @@ class PaimonSparkCopyOnWriteOperation(table: FileStoreTable, info: RowLevelOpera } override def requiredMetadataAttributes(): Array[NamedReference] = { - Array(Expressions.column(FILE_PATH_COLUMN)) + val base = Array(Expressions.column(FILE_PATH_COLUMN)) + if (table.coreOptions().rowTrackingEnabled()) { + base ++ Array(Expressions.column(ROW_ID_COLUMN), Expressions.column(SEQUENCE_NUMBER_COLUMN)) + } else { + base + } } } diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala index 4b8ede097c3d..8438d5b17581 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala @@ -24,11 +24,16 @@ import org.apache.paimon.types.DataField import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.connector.catalog.MetadataColumn -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, FloatType, IntegerType, LongType, StringType, StructField, StructType} -case class PaimonMetadataColumn(id: Int, override val name: String, override val dataType: DataType) - extends MetadataColumn { +case class PaimonMetadataColumn( + id: Int, + override val name: String, + override val dataType: DataType, + preserveOnDelete: Boolean = true, + preserveOnUpdate: Boolean = true, + preserveOnReinsert: Boolean = false) + extends PaimonMetadataColumnBase { def toPaimonDataField: DataField = { new DataField(id, name, SparkTypeUtils.toPaimonType(dataType)); @@ -51,6 +56,7 @@ object PaimonMetadataColumn { val BUCKET_COLUMN = "__paimon_bucket" val ROW_ID_COLUMN: String = SpecialFields.ROW_ID.name() val SEQUENCE_NUMBER_COLUMN: String = SpecialFields.SEQUENCE_NUMBER.name() + val VECTOR_SEARCH_SCORE_COLUMN: String = "__paimon_vector_search_score" val PATH_AND_INDEX_META_COLUMNS: Seq[String] = Seq(FILE_PATH_COLUMN, ROW_INDEX_COLUMN) val PARTITION_AND_BUCKET_META_COLUMNS: Seq[String] = Seq(PARTITION_COLUMN, BUCKET_COLUMN) @@ -62,7 +68,8 @@ object PaimonMetadataColumn { PARTITION_COLUMN, BUCKET_COLUMN, ROW_ID_COLUMN, - 
SEQUENCE_NUMBER_COLUMN + SEQUENCE_NUMBER_COLUMN, + VECTOR_SEARCH_SCORE_COLUMN ) val ROW_INDEX: PaimonMetadataColumn = @@ -77,7 +84,13 @@ object PaimonMetadataColumn { val ROW_ID: PaimonMetadataColumn = PaimonMetadataColumn(Int.MaxValue - 104, ROW_ID_COLUMN, LongType) val SEQUENCE_NUMBER: PaimonMetadataColumn = - PaimonMetadataColumn(Int.MaxValue - 105, SEQUENCE_NUMBER_COLUMN, LongType) + PaimonMetadataColumn( + Int.MaxValue - 105, + SEQUENCE_NUMBER_COLUMN, + LongType, + preserveOnUpdate = false) + val VECTOR_SEARCH_SCORE: PaimonMetadataColumn = + PaimonMetadataColumn(Integer.MAX_VALUE - 106, VECTOR_SEARCH_SCORE_COLUMN, FloatType) def dvMetaCols: Seq[PaimonMetadataColumn] = Seq(FILE_PATH, ROW_INDEX) @@ -89,6 +102,7 @@ object PaimonMetadataColumn { case BUCKET_COLUMN => BUCKET case ROW_ID_COLUMN => ROW_ID case SEQUENCE_NUMBER_COLUMN => SEQUENCE_NUMBER + case VECTOR_SEARCH_SCORE_COLUMN => VECTOR_SEARCH_SCORE case _ => throw new IllegalArgumentException(s"$metadataColumn metadata column is not supported.") } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonSink.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonSink.scala index 62712950e70c..9d0a1795b589 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonSink.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonSink.scala @@ -20,7 +20,7 @@ package org.apache.paimon.spark.sources import org.apache.paimon.options.Options import org.apache.paimon.spark.{InsertInto, Overwrite} -import org.apache.paimon.spark.commands.{SchemaHelper, WriteIntoPaimonTable} +import org.apache.paimon.spark.commands.{SchemaEvolutionHelper, WriteIntoPaimonTable} import org.apache.paimon.table.FileStoreTable import org.apache.spark.sql.{DataFrame, PaimonUtils, SQLContext} @@ -35,7 +35,7 @@ class PaimonSink( outputMode: OutputMode, options: Options) extends Sink - 
with SchemaHelper { + with SchemaEvolutionHelper { override def addBatch(batchId: Long, data: DataFrame): Unit = { val saveMode = if (outputMode == OutputMode.Complete()) { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/OptionUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/OptionUtils.scala index 7a6fa547c2a9..b21b83de26e7 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/OptionUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/OptionUtils.scala @@ -18,6 +18,7 @@ package org.apache.paimon.spark.util +import org.apache.paimon.CoreOptions import org.apache.paimon.catalog.Identifier import org.apache.paimon.options.ConfigOption import org.apache.paimon.spark.{SparkCatalogOptions, SparkConnectorOptions} @@ -113,6 +114,10 @@ object OptionUtils extends SQLConfHelper with Logging { getOptionString(SparkConnectorOptions.EXPLICIT_CAST).toBoolean } + def writeMergeSchemaTypeWideningEnabled(): Boolean = { + getOptionString(SparkConnectorOptions.TYPE_WIDENING).toBoolean + } + def v1FunctionEnabled(): Boolean = { getOptionString(SparkCatalogOptions.V1FUNCTION_ENABLED).toBoolean } @@ -165,15 +170,41 @@ object OptionUtils extends SQLConfHelper with Logging { catalogName: String = null, ident: Identifier = null, extraOptions: JMap[String, String] = new JHashMap[String, String]()): T = { - val mergedOptions = if (catalogName != null && ident != null) { - mergeSQLConfWithIdentifier(extraOptions, catalogName, ident) - } else { - mergeSQLConf(extraOptions) - } + val mergedOptions = getMergedOptions(catalogName, ident, extraOptions) if (mergedOptions.isEmpty) { table } else { table.copy(mergedOptions).asInstanceOf[T] } } + + private def getMergedOptions( + catalogName: String = null, + ident: Identifier = null, + extraOptions: JMap[String, String] = new JHashMap[String, String]()): JMap[String, String] = { + if 
(catalogName != null && ident != null) { + mergeSQLConfWithIdentifier(extraOptions, catalogName, ident) + } else { + mergeSQLConf(extraOptions) + } + } + + def withBranchFromOptions( + catalogName: String = null, + identifier: Identifier = null, + extraOptions: JMap[String, String] = new JHashMap[String, String]() + ): Identifier = { + if (identifier != null && !identifier.isSystemTable) { + val branch = + getMergedOptions(catalogName, identifier, extraOptions).get(CoreOptions.BRANCH.key) + if (branch != null && identifier.getBranchName == null) { + logWarning( + s"Using deprecated 'spark.paimon.branch=$branch' to access table '${identifier.getTableName}'. " + + s"Please migrate to '${identifier.getTableName}$$branch_$branch' syntax, as 'spark.paimon.branch' " + + s"will be removed in a future version.") + return new Identifier(identifier.getDatabaseName, identifier.getTableName, branch) + } + } + identifier + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/DataEvolutionTableDataWrite.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/DataEvolutionTableDataWrite.scala index c0e5190d5dfa..ba2e84ef8a44 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/DataEvolutionTableDataWrite.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/DataEvolutionTableDataWrite.scala @@ -28,6 +28,7 @@ import org.apache.paimon.spark.util.SparkRowUtils import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage, CommitMessageImpl, TableWriteImpl} import org.apache.paimon.types.RowType import org.apache.paimon.utils.RecordWriter +import org.apache.paimon.utils.SerializationUtils import org.apache.spark.sql.Row @@ -39,7 +40,7 @@ import scala.collection.mutable.ListBuffer case class DataEvolutionTableDataWrite( writeBuilder: BatchWriteBuilder, writeType: RowType, - firstRowIdToPartitionMap: mutable.HashMap[Long, (BinaryRow, 
Long)], + firstRowIdToPartitionMap: mutable.HashMap[Long, (Array[Byte], Long)], catalogContext: CatalogContext) extends InnerTableV1DataWrite { @@ -73,7 +74,8 @@ case class DataEvolutionTableDataWrite( s"Available first row IDs: ${firstRowIdToPartitionMap.keys.mkString(", ")}") } - val (partition, numRecords) = pair + val (partitionBytes, numRecords) = pair + val partition = SerializationUtils.deserializeBinaryRow(partitionBytes) val writer = writeBuilder .newWrite() diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWriteBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWriteBase.scala index d19f1a709646..42d2ebcd8590 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWriteBase.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWriteBase.scala @@ -19,11 +19,13 @@ package org.apache.paimon.spark.write import org.apache.paimon.io.{CompactIncrement, DataFileMeta, DataIncrement} +import org.apache.paimon.spark.SparkTypeUtils import org.apache.paimon.spark.catalyst.Compatibility import org.apache.paimon.spark.commands.SparkDataFileMeta import org.apache.paimon.spark.metric.SparkMetricRegistry import org.apache.paimon.spark.rowops.PaimonCopyOnWriteScan -import org.apache.paimon.table.FileStoreTable +import org.apache.paimon.spark.schema.PaimonMetadataColumn.{FILE_PATH, ROW_ID, SEQUENCE_NUMBER} +import org.apache.paimon.table.{FileStoreTable, SpecialFields} import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage, CommitMessageImpl} import org.apache.spark.sql.PaimonSparkSession @@ -67,18 +69,46 @@ abstract class PaimonBatchWriteBase( builder } + private val writeRowTracking: Boolean = + coreOptions.rowTrackingEnabled() && copyOnWriteScan.isDefined + + private lazy val rtPaimonWriteType = + SpecialFields.rowTypeWithRowTracking(table.rowType(), false, true) + 
+ private lazy val rtWriteSchema = + SparkTypeUtils.fromPaimonRowType(rtPaimonWriteType) + + private lazy val rtMetadataSchema = + StructType(Seq(FILE_PATH, ROW_ID, SEQUENCE_NUMBER).map(_.toStructField)) + protected def createPaimonDataWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { (_: Int, _: Long) => { - PaimonV2DataWriter( - batchWriteBuilder, - writeSchema, - dataSchema, - coreOptions, - catalogContextForBlobDescriptor) + if (writeRowTracking) { + createPaimonMetadataAwareDataWriter() + } else { + PaimonV2DataWriter( + batchWriteBuilder, + writeSchema, + dataSchema, + coreOptions, + catalogContextForBlobDescriptor) + } } } + private def createPaimonMetadataAwareDataWriter(): PaimonV2DataWriter = { + new PaimonV2MetadataAwareDataWriter( + batchWriteBuilder, + writeSchema, + rtWriteSchema, + dataSchema, + rtMetadataSchema, + coreOptions, + catalogContextForBlobDescriptor, + rtPaimonWriteType) + } + protected def commitMessages(messages: Array[WriterCommitMessage]): Unit = { commitStarted = true logInfo(s"Committing to table ${table.name()}") diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2DataWriter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2DataWriter.scala index aa2dfcdf8f56..eadd056cf2a5 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2DataWriter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2DataWriter.scala @@ -23,8 +23,11 @@ import org.apache.paimon.catalog.CatalogContext import org.apache.paimon.spark.{SparkInternalRowWrapper, SparkUtils} import org.apache.paimon.spark.metric.SparkMetricRegistry import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage, TableWriteImpl} +import org.apache.paimon.types.RowType +import org.apache.paimon.utils.IOUtils import org.apache.spark.sql.catalyst.InternalRow +import 
org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.types.StructType @@ -36,7 +39,10 @@ case class PaimonV2DataWriter( dataSchema: StructType, coreOptions: CoreOptions, catalogContext: CatalogContext, - batchId: Option[Long] = None) + batchId: Option[Long] = None, + paimonWriteType: Option[RowType] = None, + metadataSchema: Option[StructType] = None, + plainWriteSchema: Option[StructType] = None) extends abstractInnerTableDataWrite[InternalRow] with InnerTableV2DataWrite { @@ -46,35 +52,79 @@ case class PaimonV2DataWriter( val fullCompactionDeltaCommits: Option[Int] = Option.apply(coreOptions.fullCompactionDeltaCommits()) - val write: TableWriteImpl[InternalRow] = { - writeBuilder + private def createTableWrite(writeType: Option[RowType]): TableWriteImpl[InternalRow] = { + val w = writeBuilder .newWrite() .withIOManager(ioManager) .withMetricRegistry(metricRegistry) .asInstanceOf[TableWriteImpl[InternalRow]] + writeType.foreach(w.withWriteType) + w } - private val rowConverter: InternalRow => SparkInternalRowWrapper = { + val write: TableWriteImpl[InternalRow] = createTableWrite(paimonWriteType) + + private var plainWrite: Option[TableWriteImpl[InternalRow]] = None + + private def getPlainWrite: TableWriteImpl[InternalRow] = { + plainWrite.getOrElse { + val w = createTableWrite(None) + plainWrite = Some(w) + w + } + } + + private def createRowConverter( + writeSchema: StructType, + schema: StructType): InternalRow => SparkInternalRowWrapper = { val numFields = writeSchema.fields.length val reusableWrapper = - new SparkInternalRowWrapper(writeSchema, numFields, dataSchema, catalogContext) + new SparkInternalRowWrapper(writeSchema, numFields, schema, catalogContext) record => reusableWrapper.replace(record) } + private val rowConverter: InternalRow => SparkInternalRowWrapper = + createRowConverter(writeSchema, dataSchema) + + private val plainRowConverter: Option[InternalRow => 
SparkInternalRowWrapper] = + plainWriteSchema.map(schema => createRowConverter(schema, dataSchema)) + + private val metadataAwareRowConverter: Option[InternalRow => SparkInternalRowWrapper] = + metadataSchema.map( + schema => createRowConverter(writeSchema, StructType(dataSchema.fields ++ schema.fields))) + + private val joinedRow = new JoinedRow() + override def write(record: InternalRow): Unit = { - postWrite(write.writeAndReturn(rowConverter.apply(record))) + plainRowConverter match { + case Some(converter) => + postWrite(getPlainWrite.writeAndReturn(converter.apply(record))) + case _ => + postWrite(write.writeAndReturn(rowConverter.apply(record))) + } + } + + def writeWithMetadata(metadata: InternalRow, record: InternalRow): Unit = { + metadataAwareRowConverter match { + case Some(converter) => + postWrite(write.writeAndReturn(converter.apply(joinedRow(record, metadata)))) + case None => + write(record) + } } override def commitImpl(): Seq[CommitMessage] = { - write.prepareCommit().asScala.toSeq + val metadataMessages = write.prepareCommit().asScala.toSeq + val plainMessages = plainWrite.map(_.prepareCommit().asScala.toSeq).getOrElse(Seq.empty) + metadataMessages ++ plainMessages } override def abort(): Unit = close() override def close(): Unit = { try { - write.close() - ioManager.close() + val closeables = Seq[AutoCloseable](write) ++ plainWrite.toSeq ++ Seq(ioManager) + IOUtils.closeAll(closeables.asJava) } catch { case e: Exception => throw new RuntimeException(e) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala index c7a9a9ff3ab8..2ae1dd53a367 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala @@ -21,7 +21,7 @@ package 
org.apache.paimon.spark.write import org.apache.paimon.CoreOptions.ChangelogProducer import org.apache.paimon.options.Options import org.apache.paimon.spark._ -import org.apache.paimon.spark.commands.SchemaHelper +import org.apache.paimon.spark.commands.SchemaEvolutionHelper import org.apache.paimon.spark.rowops.PaimonCopyOnWriteScan import org.apache.paimon.table.BucketMode.BUCKET_UNAWARE import org.apache.paimon.table.FileStoreTable @@ -44,10 +44,9 @@ class PaimonV2Write( options: Options ) extends Write with RequiresDistributionAndOrdering - with SchemaHelper + with SchemaEvolutionHelper with Logging { - private val writeSchema = mergeSchema(dataSchema, options) private val writeRequirement = PaimonWriteRequirement(table) override def requiredDistribution(): Distribution = { @@ -63,6 +62,8 @@ class PaimonV2Write( } override def toBatch: BatchWrite = { + // Commit the evolved schema at execution (not at planning), then write to the evolved table. + val writeSchema = mergeSchema(dataSchema, options) SparkShimLoader.shim.createPaimonBatchWrite( table, writeSchema, diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala index d8b70f68f24c..16d34ecf4c4a 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.OutputMetrics import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import 
org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -136,6 +136,10 @@ object PaimonUtils { DataType.equalsIgnoreCompatibleNullability(from, to) } + /** `StructType` to fresh `AttributeReference`s (the `StructType.toAttributes` removed in 3.4+). */ + def toAttributes(schema: StructType): Seq[AttributeReference] = + schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + def classIsLoadable(clazz: String): Boolean = { SparkUtils.classIsLoadable(clazz) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala index 1e0c13a573fc..4fd533f40122 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/AbstractPaimonSparkSqlExtensionsParser.scala @@ -43,18 +43,8 @@ import scala.collection.JavaConverters._ * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * The implementation of [[ParserInterface]] that parsers the sql extension. - * - *

    Most of the content of this class is referenced from Iceberg's - * IcebergSparkSqlExtensionsParser. - * - * @param delegate - * The extension parser. - */ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterface) - extends org.apache.spark.sql.catalyst.parser.ParserInterface - with Logging { + extends Logging { private lazy val substitutor = new VariableSubstitution() private lazy val astBuilder = new PaimonSqlExtensionsAstBuilder(delegate) @@ -88,12 +78,14 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf PaimonSqlExtensionsParser.FORCE, PaimonSqlExtensionsParser.ON_ERROR, PaimonSqlExtensionsParser.ABORT_STATEMENT, + PaimonSqlExtensionsParser.CONTINUE, + PaimonSqlExtensionsParser.SKIP_FILE, PaimonSqlExtensionsParser.OVERWRITE, PaimonSqlExtensionsParser.CSV ) /** Parses a string to a LogicalPlan. */ - override def parsePlan(sqlText: String): LogicalPlan = { + def parsePlan(sqlText: String): LogicalPlan = { val sqlTextAfterSubstitution = substitutor.substitute(sqlText) if (isPaimonCommand(sqlTextAfterSubstitution)) { parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement())) @@ -126,30 +118,30 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf } /** Parses a string to an Expression. */ - override def parseExpression(sqlText: String): Expression = + def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText) /** Parses a string to a TableIdentifier. */ - override def parseTableIdentifier(sqlText: String): TableIdentifier = + def parseTableIdentifier(sqlText: String): TableIdentifier = delegate.parseTableIdentifier(sqlText) /** Parses a string to a FunctionIdentifier. 
*/ - override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = delegate.parseFunctionIdentifier(sqlText) /** * Creates StructType for a given SQL string, which is a comma separated list of field definitions * which will preserve the correct Hive metadata. */ - override def parseTableSchema(sqlText: String): StructType = + def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText) /** Parses a string to a DataType. */ - override def parseDataType(sqlText: String): DataType = + def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText) /** Parses a string to a multi-part identifier. */ - override def parseMultipartIdentifier(sqlText: String): Seq[String] = + def parseMultipartIdentifier(sqlText: String): Seq[String] = delegate.parseMultipartIdentifier(sqlText) /** Returns whether SQL text is command. */ diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala index da716ced11c5..0d54c698e6a7 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/PaimonSqlExtensionsAstBuilder.scala @@ -172,7 +172,15 @@ class PaimonSqlExtensionsAstBuilder(delegate: ParserInterface) val fileFormat = buildFileFormat(ctx.fileFormatClause()) val pattern = Option(ctx.patternClause()).map(p => unquoteString(p.STRING().getText)) val force = Option(ctx.forceClause()).exists(_.booleanValue().TRUE() != null) - logical.CopyIntoTableCommand(table, columns, sourcePath, fileFormat, pattern, force) + val onError = Option(ctx.onErrorClause()) + .map { + clause => 
+ if (clause.CONTINUE() != null) OnErrorMode.Continue + else if (clause.SKIP_FILE() != null) OnErrorMode.SkipFile + else OnErrorMode.AbortStatement + } + .getOrElse(OnErrorMode.AbortStatement) + logical.CopyIntoTableCommand(table, columns, sourcePath, fileFormat, pattern, force, onError) } /** Create a COPY INTO LOCATION (export) logical command. */ diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkChainTableITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkChainTableITCase.java index 1907c7fcf3cc..c2833c223ea0 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkChainTableITCase.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkChainTableITCase.java @@ -2298,4 +2298,68 @@ public void testChainTableWithMultiGroupPartition(@TempDir java.nio.file.Path te spark.close(); } + + @Test + public void testChainTableWithBranchOption(@TempDir java.nio.file.Path tempDir) + throws IOException { + Path warehousePath = new Path("file:" + tempDir.toString()); + SparkSession.Builder builder = createSparkSessionBuilder(warehousePath); + SparkSession spark = builder.getOrCreate(); + spark.sql("CREATE DATABASE IF NOT EXISTS my_db1"); + spark.sql("USE spark_catalog.my_db1"); + spark.sql("DROP TABLE IF EXISTS `my_db1`.`chain_test`;"); + spark.sql( + "CREATE TABLE IF NOT EXISTS `chain_test` (\n" + + " `t1` BIGINT,\n" + + " `t2` BIGINT,\n" + + " `t3` STRING\n" + + ") PARTITIONED BY (`dt` STRING)\n" + + "TBLPROPERTIES (\n" + + " 'bucket-key' = 't1',\n" + + " 'primary-key' = 'dt,t1',\n" + + " 'partition.timestamp-pattern' = '$dt',\n" + + " 'partition.timestamp-formatter' = 'yyyyMMdd',\n" + + " 'chain-table.enabled' = 'true',\n" + + " 'bucket' = '1',\n" + + " 'merge-engine' = 'deduplicate',\n" + + " 'sequence.field' = 't2'\n" + + ")"); + setupChainTableBranches(spark, "chain_test"); + // Write main branch + spark.sql( + "INSERT OVERWRITE TABLE 
`my_db1`.`chain_test` PARTITION (dt = '20250810') VALUES (1, 3, '0')"); + // Write delta branch + spark.sql("SET spark.paimon.branch = delta"); + spark.sql( + "INSERT OVERWRITE TABLE `my_db1`.`chain_test` PARTITION (dt = '20250810') VALUES (1, 2, '1')"); + spark.sql( + "INSERT OVERWRITE TABLE `my_db1`.`chain_test$branch_delta` PARTITION (dt = '20250811') VALUES (2, 2, '1')"); + assertThat(spark.sql("SELECT * FROM `my_db1`.`chain_test$snapshots`").count()).isEqualTo(2); + spark.sql("RESET spark.paimon.branch"); + assertThat( + spark.sql("SELECT * FROM `my_db1`.`chain_test` where dt = '20250811'") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder("[1,2,1,20250811]", "[2,2,1,20250811]"); + assertThat( + spark + .sql( + "SELECT * FROM `my_db1`.`chain_test$branch_snapshot` WHERE dt = '20250811'") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .isEmpty(); + + spark.sql("SET spark.paimon.branch = snapshot"); + assertThat( + spark.sql("SELECT * FROM `my_db1`.`chain_test` where dt = '20250811'") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .isEmpty(); + assertThat(spark.sql("SELECT * FROM `my_db1`.`chain_test$snapshots`").count()).isEqualTo(2); + spark.sql("DROP TABLE IF EXISTS `my_db1`.`chain_test`;"); + spark.close(); + } } diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkDataEvolutionITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkDataEvolutionITCase.java new file mode 100644 index 000000000000..856b0cc57efb --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkDataEvolutionITCase.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +import org.apache.paimon.fs.Path; +import org.apache.paimon.hive.TestHiveMetastore; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for data evolution on Spark. 
*/ +public class SparkDataEvolutionITCase { + + private static TestHiveMetastore testHiveMetastore; + private static final int PORT = 9092; + + @BeforeAll + public static void startMetastore() { + testHiveMetastore = new TestHiveMetastore(); + testHiveMetastore.start(PORT); + } + + @AfterAll + public static void closeMetastore() throws Exception { + testHiveMetastore.stop(); + } + + private SparkSession.Builder createSparkSessionBuilder(Path warehousePath) { + return SparkSession.builder() + .config("spark.sql.warehouse.dir", warehousePath.toString()) + // with hive metastore + .config("spark.sql.catalogImplementation", "hive") + .config("hive.metastore.uris", "thrift://localhost:" + PORT) + .config("spark.sql.catalog.spark_catalog", SparkCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.metastore", "hive") + .config( + "spark.sql.catalog.spark_catalog.hive.metastore.uris", + "thrift://localhost:" + PORT) + .config("spark.sql.catalog.spark_catalog.format-table.enabled", "true") + .config("spark.sql.catalog.spark_catalog.warehouse", warehousePath.toString()) + .config( + "spark.sql.extensions", + "org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions") + .master("local[2]"); + } + + @Test + public void testDataEvolution(@TempDir java.nio.file.Path tempDir) throws IOException { + Path warehousePath = new Path("file:" + tempDir.toString()); + SparkSession.Builder builder = + SparkSession.builder() + .config("spark.sql.warehouse.dir", warehousePath.toString()) + // with hive metastore + .config("spark.sql.catalogImplementation", "hive") + .config("hive.metastore.uris", "thrift://localhost:" + PORT) + .config("spark.sql.catalog.spark_catalog", SparkCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.metastore", "hive") + .config( + "spark.sql.catalog.spark_catalog.hive.metastore.uris", + "thrift://localhost:" + PORT) + .config("spark.sql.catalog.spark_catalog.format-table.enabled", "true") + .config( + 
"spark.sql.catalog.spark_catalog.warehouse", + warehousePath.toString()) + .config( + "spark.sql.extensions", + "org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions") + .master("local[2]"); + SparkSession spark = builder.getOrCreate(); + spark.sql("CREATE DATABASE IF NOT EXISTS my_db1"); + spark.sql("USE spark_catalog.my_db1"); + + /** Create table */ + spark.sql( + "CREATE TABLE IF NOT EXISTS \n" + + " `my_db1`.`data_evolution_test` (\n" + + " `id` BIGINT COMMENT 'id',\n" + + " `g_1_1` BIGINT COMMENT 'g_1_1',\n" + + " `g_1_2` BIGINT COMMENT 'g_1_2',\n" + + " `g_2_1` BIGINT COMMENT 'g_2_1',\n" + + " `g_2_2` BIGINT COMMENT 'g_2_2'\n" + + " ) PARTITIONED BY (`dt` STRING COMMENT 'dt') ROW FORMAT SERDE 'org.apache.paimon.hive.PaimonSerDe'\n" + + "WITH\n" + + " SERDEPROPERTIES ('serialization.format' = '1') STORED AS INPUTFORMAT 'org.apache.paimon.hive.mapred.PaimonInputFormat' OUTPUTFORMAT 'org.apache.paimon.hive.mapred.PaimonOutputFormat' TBLPROPERTIES (\n" + + " 'file.compression' = 'snappy',\n" + + " 'manifest.compression' = 'snappy',\n" + + " 'row-tracking.enabled' = 'true',\n" + + " 'data-evolution.enabled' = 'true',\n" + + " 'data-evolution.merge-into.source-persist' = 'true',\n" + + " 'partition.timestamp-pattern' = '$dt',\n" + + " 'partition.timestamp-formatter' = 'yyyyMMdd',\n" + + " 'metastore.partitioned-table' = 'true'" + + " )"); + + spark.sql( + "CREATE TABLE IF NOT EXISTS \n" + + " `my_db1`.`data_evolution_source` (\n" + + " `id` BIGINT COMMENT 'id',\n" + + " `g_2_1` BIGINT COMMENT 'g_2_1',\n" + + " `g_2_2` BIGINT COMMENT 'g_2_2'\n" + + " ) PARTITIONED BY (`dt` STRING COMMENT 'dt') ROW FORMAT SERDE 'org.apache.paimon.hive.PaimonSerDe'\n" + + "WITH\n" + + " SERDEPROPERTIES ('serialization.format' = '1') STORED AS INPUTFORMAT 'org.apache.paimon.hive.mapred.PaimonInputFormat' OUTPUTFORMAT 'org.apache.paimon.hive.mapred.PaimonOutputFormat' TBLPROPERTIES (\n" + + " 'file.compression' = 'snappy',\n" + + " 'manifest.compression' = 'snappy',\n" + + " 
'partition.timestamp-pattern' = '$dt',\n" + + " 'partition.timestamp-formatter' = 'yyyyMMdd',\n" + + " 'metastore.partitioned-table' = 'true'" + + " )"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "insert overwrite table\n" + + " `my_db1`.`data_evolution_test` partition (dt='20260305')\n" + + "values (1, 1, 1, null, null),\n" + + " (1, 2, 1, null, null),\n" + + " (2, 1, 1, null, null);"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "insert overwrite table\n" + + " `my_db1`.`data_evolution_test` partition (dt='20260304')\n" + + "values (10, 10, 10, null, null),\n" + + " (10, 20, 10, null, null),\n" + + " (20, 10, 10, null, null)"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "insert overwrite table\n" + + " `my_db1`.`data_evolution_source` partition (dt='20260304')\n" + + "values (10, 20, 20),\n" + + " (40, 10, 10);"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "insert overwrite table\n" + + " `my_db1`.`data_evolution_source` partition (dt='20260305')\n" + + "values (1, 2, 2),\n" + + " (4, 1, 1);"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "MERGE INTO `my_db1`.`data_evolution_test` AS t\n" + + "USING `my_db1`.`data_evolution_source` AS s\n" + + "ON t.id = s.id\n" + + "AND t.dt = s.dt\n" + + "AND s.dt = '20260304'\n" + + "AND t.dt = '20260304'\n" + + " WHEN matched THEN\n" + + "UPDATE\n" + + "SET t.g_2_1 = s.g_2_1,\n" + + " t.g_2_2 = s.g_2_2;"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "MERGE INTO `my_db1`.`data_evolution_test` AS t\n" + + "USING `my_db1`.`data_evolution_source` AS s\n" + + "ON t.id = s.id\n" + + "AND t.dt = s.dt\n" + + "AND s.dt = '20260305'\n" + + "AND t.dt = '20260305'\n" + + " WHEN matched THEN\n" + + "UPDATE\n" + + "SET t.g_2_1 = s.g_2_1,\n" + + " t.g_2_2 = s.g_2_2;"); + spark.close(); + + spark = builder.getOrCreate(); + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from 
`my_db1`.`data_evolution_test` where dt='20260304'") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder("[10,10,20]", "[10,20,20]", "[20,10,null]"); + + long recordCount = + spark.sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260304' and g_1_1 = 10 and g_2_1 = 10") + .count(); + assertThat(recordCount).isEqualTo(0); + + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260304' and g_1_1 = 10 and g_2_1 =20") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder("[10,10,20]"); + + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260304' and (g_1_1 = 10 or g_2_1 =10)") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder("[10,10,20]", "[20,10,null]"); + + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260304' and (g_1_1 = 10 or g_2_1 =20)") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder("[10,10,20]", "[10,20,20]", "[20,10,null]"); + spark.close(); + + spark = builder.getOrCreate(); + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260305'") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder("[1,1,2]", "[1,2,2]", "[2,1,null]"); + + recordCount = + spark.sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260305' and g_1_1 = 1 and g_2_1 =1") + .count(); + assertThat(recordCount).isEqualTo(0); + + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260305' and g_1_1 = 1 and g_2_1 =2") + .collectAsList().stream() + .map(Row::toString) + 
.collect(Collectors.toList())) + .containsExactlyInAnyOrder("[1,1,2]"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "insert overwrite table\n" + + " `my_db1`.`data_evolution_test` partition (dt='20260604')\n" + + "values (10, 10, 10, null, null),\n" + + " (20, 10, 10, null, null)"); + spark.sql( + "insert overwrite table\n" + + " `my_db1`.`data_evolution_source` partition (dt='20260604')\n" + + "values (30, 10, 10),\n" + + " (40, 10, 10);"); + spark.sql( + "MERGE INTO `my_db1`.`data_evolution_test` AS t\n" + + "USING `my_db1`.`data_evolution_source` AS s\n" + + "ON t.id = s.id\n" + + "AND t.dt = s.dt\n" + + "AND s.dt = '20260604'\n" + + "AND t.dt = '20260604'\n" + + " WHEN NOT MATCHED THEN\n" + + "INSERT (id,g_2_1, g_2_2, dt) VALUES (s.id, s.g_2_1, s.g_2_2, s.dt);"); + assertThat( + spark + .sql( + "select id,g_1_1,g_2_1 from `my_db1`.`data_evolution_test` where dt='20260604'") + .collectAsList().stream() + .map(Row::toString) + .collect(Collectors.toList())) + .containsExactlyInAnyOrder( + "[10,10,null]", "[20,10,null]", "[30,null,10]", "[40,null,10]"); + spark.close(); + + spark = builder.getOrCreate(); + /** Drop table */ + spark.sql("DROP TABLE IF EXISTS `my_db1`.`data_evolution_test`;"); + spark.sql("DROP TABLE IF EXISTS `my_db1`.`data_evolution_source`;"); + spark.close(); + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java index 4248d07d769f..8b5457c9dff6 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java @@ -26,11 +26,15 @@ import org.apache.paimon.types.CharType; import org.apache.paimon.types.DataField; import org.apache.paimon.types.DateType; +import org.apache.paimon.types.DoubleType; +import 
org.apache.paimon.types.FloatType; import org.apache.paimon.types.IntType; import org.apache.paimon.types.RowType; import org.apache.paimon.types.TimestampType; import org.apache.paimon.types.VarCharType; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysTrue; import org.apache.spark.sql.sources.EqualNullSafe; import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.GreaterThan; @@ -229,6 +233,36 @@ public void testDate() { assertThat(localDateExpression).isEqualTo(rawExpression); } + @Test + public void testAlwaysTrueFalse() { + RowType rowType = + new RowType(Collections.singletonList(new DataField(0, "id", new IntType()))); + SparkFilterConverter converter = new SparkFilterConverter(rowType); + + assertThat(converter.convert(new AlwaysTrue())).isEqualTo(PredicateBuilder.alwaysTrue()); + assertThat(converter.convert(new AlwaysFalse())).isEqualTo(PredicateBuilder.alwaysFalse()); + } + + @Test + public void testEqualToNaN() { + RowType rowType = + new RowType( + Arrays.asList( + new DataField(0, "f", new FloatType()), + new DataField(1, "d", new DoubleType()))); + SparkFilterConverter converter = new SparkFilterConverter(rowType); + PredicateBuilder builder = new PredicateBuilder(rowType); + + EqualTo eqNaNFloat = EqualTo.apply("f", Float.NaN); + assertThat(converter.convert(eqNaNFloat)).isEqualTo(builder.isNaN(0)); + + EqualTo eqNaNDouble = EqualTo.apply("d", Double.NaN); + assertThat(converter.convert(eqNaNDouble)).isEqualTo(builder.isNaN(1)); + + EqualTo eqFloat = EqualTo.apply("f", 1.0f); + assertThat(converter.convert(eqFloat)).isEqualTo(builder.equal(0, 1.0f)); + } + @Test public void testIgnoreFailure() { List dataFields = new ArrayList<>(); diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java new file mode 100644 index 
000000000000..4100f54f61dd --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +import org.apache.paimon.fs.Path; +import org.apache.paimon.hive.TestHiveMetastore; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for Paimon Multimodality type support on Spark. 
*/ +public class SparkMultimodalITCase { + + private static TestHiveMetastore testHiveMetastore; + private static final int PORT = 9092; + + @BeforeAll + public static void startMetastore() { + testHiveMetastore = new TestHiveMetastore(); + testHiveMetastore.start(PORT); + } + + @AfterAll + public static void closeMetastore() throws Exception { + testHiveMetastore.stop(); + } + + private SparkSession.Builder createSparkSessionBuilder(Path warehousePath) { + return SparkSession.builder() + .config("spark.sql.warehouse.dir", warehousePath.toString()) + // with hive metastore + .config("spark.sql.catalogImplementation", "hive") + .config("hive.metastore.uris", "thrift://localhost:" + PORT) + .config("spark.sql.catalog.spark_catalog", SparkCatalog.class.getName()) + .config("spark.sql.catalog.spark_catalog.metastore", "hive") + .config( + "spark.sql.catalog.spark_catalog.hive.metastore.uris", + "thrift://localhost:" + PORT) + .config("spark.sql.catalog.spark_catalog.format-table.enabled", "true") + .config("spark.sql.catalog.spark_catalog.warehouse", warehousePath.toString()) + .config( + "spark.sql.extensions", + "org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions") + .master("local[2]"); + } + + @Test + public void testVector(@TempDir java.nio.file.Path tempDir) throws IOException { + Path warehousePath = new Path("file:" + tempDir.toString()); + SparkSession.Builder builder = createSparkSessionBuilder(warehousePath); + SparkSession spark = builder.getOrCreate(); + spark.sql("CREATE DATABASE IF NOT EXISTS my_db1"); + spark.sql("USE spark_catalog.my_db1"); + + spark.sql( + "\n" + + "CREATE TABLE my_db1.vector_test (gid BIGINT, sid STRING, embs ARRAY)" + + " PARTITIONED BY (`date` STRING COMMENT 'date') ROW FORMAT SERDE 'org.apache.paimon.hive.PaimonSerDe'\n" + + "WITH\n" + + " SERDEPROPERTIES ('serialization.format' = '1') STORED AS INPUTFORMAT 'org.apache.paimon.hive.mapred.PaimonInputFormat' OUTPUTFORMAT 
'org.apache.paimon.hive.mapred.PaimonOutputFormat' TBLPROPERTIES (\n" + + " 'vector.file.format'='lance',\n" + + " 'vector-field'='embs',\n" + + " 'field.embs.vector-dim'='4',\n" + + " 'row-tracking.enabled'='true',\n" + + " 'data-evolution.enabled'='true',\n" + + " 'global-index.enabled' = 'true'\n" + + ");"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "insert overwrite table my_db1.vector_test\n" + + "VALUES (1, '1', array(cast(1.0 as float), cast(2.0 as float), cast(3.0 as float), cast(4.0 as float)), '20260420'),\n" + + "(2, '2', array(cast(2.0 as float), cast(3.0 as float), cast(4.0 as float), cast(5.0 as float)), '20260420'),\n" + + "(3, '3', array(cast(3.0 as float), cast(4.0 as float), cast(5.0 as float), cast(6.0 as float)), '20260420'),\n" + + "(4, '4', array(cast(4.0 as float), cast(5.0 as float), cast(6.0 as float), cast(7.0 as float)), '20260420'),\n" + + "(5, '5', array(cast(5.0 as float), cast(6.0 as float), cast(7.0 as float), cast(8.0 as float)), '20260420'),\n" + + "(6, '6', array(cast(6.0 as float), cast(7.0 as float), cast(8.0 as float), cast(9.0 as float)), '20260420'),\n" + + "(7, '7', array(cast(7.0 as float), cast(8.0 as float), cast(9.0 as float), cast(10.0 as float)), '20260420'),\n" + + "(8, '8', array(cast(8.0 as float), cast(9.0 as float), cast(10.0 as float), cast(11.0 as float)), '20260420');"); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql( + "\n" + + "CALL sys.create_global_index(\n" + + " `table` => 'my_db1.vector_test',\n" + + " `partitions` => \"date='20260420'\",\n" + + " `index_column` => 'embs',\n" + + " `index_type` => 'lumina-vector-ann',\n" + + " `options` => 'lumina.index.dimension=4'\n" + + ");"); + spark.close(); + + spark = builder.getOrCreate(); + List rows = + spark.sql("select gid, sid, embs from my_db1.vector_test where date = '20260420';") + .collectAsList(); + assertThat(rows).hasSize(8); + rows = + spark.sql( + "select gid, sid, embs from 
vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5) where date = '20260420'") + .collectAsList(); + assertThat(rows).hasSize(5); + Dataset df = + spark.sql( + "select gid, sid, embs, __paimon_vector_search_score from vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5) where date = '20260420'"); + assertThat(df.columns()).hasSize(4); + rows = df.collectAsList(); + assertThat(rows).hasSize(5); + spark.close(); + + spark = builder.getOrCreate(); + spark.sql("DROP TABLE IF EXISTS `my_db1`.`vector_test`;"); + spark.close(); + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkSchemaEvolutionITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkSchemaEvolutionITCase.java index fc8b4adb6ebd..7afe3c76cf78 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkSchemaEvolutionITCase.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkSchemaEvolutionITCase.java @@ -1080,4 +1080,38 @@ public void testUpdateNestedColumnTypeInMap(String formatType) { .containsExactlyInAnyOrder( "[1,APPLE,1000000000000]", "[2,cat,200]", "[3,FLOWER,3000000000000]"); } + + private static final String BLOB_TABLE_PROPS = + "'row-tracking.enabled'='true', 'data-evolution.enabled'='true', 'bucket'='-1'"; + + @Test + public void testAddBlobColumnViaCommentDirective() { + String table = "paimon.default.blob_add_col"; + spark.sql( + "CREATE TABLE " + + table + + " (id INT, data STRING) TBLPROPERTIES (" + + BLOB_TABLE_PROPS + + ")"); + + // bare directive — no user comment + spark.sql( + "ALTER TABLE " + + table + + " ADD COLUMN desc_col BINARY COMMENT '__BLOB_DESCRIPTOR_FIELD'"); + // directive + user comment + spark.sql( + "ALTER TABLE " + + table + + " ADD COLUMN picture BINARY COMMENT '__BLOB_FIELD; profile picture'"); + + String createSql = + spark.sql("SHOW CREATE TABLE " + table).collectAsList().get(0).toString(); + 
assertThat(createSql).doesNotContain("__BLOB"); + assertThat(createSql).contains("desc_col"); + assertThat(createSql).contains("picture"); + assertThat(createSql).contains("profile picture"); + assertThat(createSql).contains("'blob-field' = 'picture'"); + assertThat(createSql).contains("'blob-descriptor-field' = 'desc_col'"); + } } diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkTypeTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkTypeTest.java index fdc7558fd5f4..424981916fe3 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkTypeTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkTypeTest.java @@ -21,6 +21,7 @@ import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; +import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Test; @@ -107,4 +108,22 @@ public void testAllTypes() { assertThat(toPaimonType(sparkType)).isEqualTo(ALL_TYPES); } + + @Test + public void testVectorType() { + RowType rowType = + RowType.builder() + .field("nullable_vec", DataTypes.VECTOR(3, DataTypes.FLOAT())) + .field("notnull_vec", DataTypes.VECTOR(3, DataTypes.FLOAT()).notNull()) + .build(); + StructType sparkType = fromPaimonRowType(rowType); + + assertThat(sparkType.apply("nullable_vec").nullable()).isTrue(); + ArrayType nullableArray = (ArrayType) sparkType.apply("nullable_vec").dataType(); + assertThat(nullableArray.containsNull()).isFalse(); + + assertThat(sparkType.apply("notnull_vec").nullable()).isFalse(); + ArrayType notNullArray = (ArrayType) sparkType.apply("notnull_vec").dataType(); + assertThat(notNullArray.containsNull()).isFalse(); + } } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/BinPackingSplitsTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/BinPackingSplitsTest.scala index 
2e2d311c85da..9173681953b4 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/BinPackingSplitsTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/BinPackingSplitsTest.scala @@ -23,7 +23,7 @@ import org.apache.paimon.data.BinaryRow import org.apache.paimon.io.DataFileMeta import org.apache.paimon.manifest.FileSource import org.apache.paimon.spark.read.BinPackingSplits -import org.apache.paimon.table.source.{DataSplit, Split} +import org.apache.paimon.table.source.{DataSplit, DeletionFile, Split} import org.junit.jupiter.api.Assertions @@ -35,43 +35,19 @@ import scala.collection.mutable class BinPackingSplitsTest extends PaimonSparkTestBase { test("Paimon: reshuffle splits") { - withSparkSQLConf(("spark.sql.leafNodeDefaultParallelism", "20")) { + withSparkSQLConf("spark.sql.files.minPartitionNum" -> "20") { val splitNum = 5 val fileNum = 100 val files = scala.collection.mutable.ListBuffer.empty[DataFileMeta] - 0.until(fileNum).foreach { - i => - val path = s"f$i.parquet" - files += DataFileMeta.forAppend( - path, - 750000, - 30000, - null, - 0, - 29999, - 1, - new java.util.ArrayList[String](), - null, - FileSource.APPEND, - null, - null, - null, - null) - } + 0.until(fileNum).foreach(i => files += newDataFile(s"f$i.parquet", 750000, 30000, 29999)) val dataSplits = mutable.ArrayBuffer.empty[Split] 0.until(splitNum).foreach { i => - dataSplits += DataSplit - .builder() - .withSnapshot(1) - .withBucket(0) - .withPartition(BinaryRow.EMPTY_ROW) - .withDataFiles(files.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toList.asJava) - .rawConvertible(true) - .withBucketPath("no use") - .build() + dataSplits += newDataSplitFromFiles( + files.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toSeq, + rawConvertible = true) } val binPacking = BinPackingSplits(CoreOptions.fromMap(new JHashMap())) @@ -81,41 +57,69 @@ class BinPackingSplitsTest extends PaimonSparkTestBase { } test("Paimon: reshuffle one split") { 
- val files = List( - DataFileMeta.forAppend( - "f1.parquet", - 750000, - 30000, - null, - 0, - 29999, - 1, - new java.util.ArrayList[String](), - null, - FileSource.APPEND, - null, - null, - null, - null) - ).asJava - - val dataSplits: Array[Split] = Array( - DataSplit - .builder() - .withSnapshot(1) - .withBucket(0) - .withPartition(BinaryRow.EMPTY_ROW) - .withDataFiles(files) - .rawConvertible(true) - .withBucketPath("no use") - .build() - ) + val split = newDataSplitFromFiles( + Seq(newDataFile("f1.parquet", 750000, 30000, 29999)), + rawConvertible = true) + + val dataSplits: Array[Split] = Array(split) val binPacking = BinPackingSplits(CoreOptions.fromMap(new JHashMap())) val reshuffled = binPacking.pack(dataSplits) Assertions.assertEquals(1, reshuffled.length) } + test("Paimon: pack data evolution splits by split granularity") { + withSparkSQLConf("spark.sql.files.minPartitionNum" -> "1") { + val split1 = newDataSplit("split1", Seq(40L, 40L), deletionFileLength = Some(5L)) + val split2 = newDataSplit("split2", Seq(40L, 40L), deletionFileLength = Some(5L)) + val split3 = newDataSplit("split3", Seq(40L, 40L), deletionFileLength = Some(5L)) + + val binPacking = BinPackingSplits( + CoreOptions.fromMap( + Map( + "data-evolution.enabled" -> "true", + "deletion-vectors.enabled" -> "true", + "source.split.open-file-cost" -> "5 B", + "source.split.target-size" -> "150 B").asJava), + readRowSizeRatio = 0.5 + ) + val reshuffled = binPacking.pack(Array[Split](split1, split2, split3)) + + // Each split size is 2 * (40 * 0.5 + 5 open cost + 5 deletion file length) = 60. + // Therefore two whole splits fit into the 150 B target while three do not. 
+ Assertions.assertEquals(2, reshuffled.length) + Assertions.assertEquals(2, reshuffled.head.splits.length) + Assertions.assertSame(split1, reshuffled.head.splits.head) + Assertions.assertSame(split2, reshuffled.head.splits(1)) + Assertions.assertEquals(1, reshuffled(1).splits.length) + Assertions.assertSame(split3, reshuffled(1).splits.head) + reshuffled.flatMap(_.splits).foreach { + split => Assertions.assertEquals(2, split.asInstanceOf[DataSplit].dataFiles().size()) + } + } + } + + test("Paimon: data evolution split packing keeps oversized split whole") { + withSparkSQLConf("spark.sql.files.minPartitionNum" -> "1") { + val split = newDataSplit("oversized", Seq(40L, 40L)) + + val binPacking = BinPackingSplits( + CoreOptions.fromMap( + Map( + "data-evolution.enabled" -> "true", + "source.split.open-file-cost" -> "0 B", + "source.split.target-size" -> "50 B").asJava)) + val reshuffled = binPacking.pack(Array[Split](split)) + + Assertions.assertEquals(1, reshuffled.length) + Assertions.assertEquals(1, reshuffled.head.splits.length) + Assertions.assertSame(split, reshuffled.head.splits.head) + Assertions.assertEquals( + 2, + reshuffled.head.splits.head.asInstanceOf[DataSplit].dataFiles().size()) + } + } + test("Paimon: set open-file-cost to 0") { withTable("t") { sql("CREATE TABLE t (a INT, b STRING)") @@ -126,21 +130,21 @@ class BinPackingSplitsTest extends PaimonSparkTestBase { def paimonScan() = getPaimonScan("SELECT * FROM t") // default openCostInBytes is 4m, so we will get 400 / 128 = 4 partitions - withSparkSQLConf("spark.sql.leafNodeDefaultParallelism" -> "1") { + withSparkSQLConf("spark.sql.files.minPartitionNum" -> "1") { assert(paimonScan().inputPartitions.length == 4) } withSparkSQLConf( - "spark.sql.files.openCostInBytes" -> "0", - "spark.sql.leafNodeDefaultParallelism" -> "1") { + "spark.sql.files.minPartitionNum" -> "1", + "spark.sql.files.openCostInBytes" -> "0") { assert(paimonScan().inputPartitions.length == 1) } // Paimon's conf takes precedence 
over Spark's withSparkSQLConf( + "spark.sql.files.minPartitionNum" -> "1", "spark.sql.files.openCostInBytes" -> "4194304", - "spark.paimon.source.split.open-file-cost" -> "0", - "spark.sql.leafNodeDefaultParallelism" -> "1") { + "spark.paimon.source.split.open-file-cost" -> "0") { assert(paimonScan().inputPartitions.length == 1) } } @@ -176,4 +180,60 @@ class BinPackingSplitsTest extends PaimonSparkTestBase { } } } + + private def newDataSplit( + prefix: String, + fileSizes: Seq[Long], + rawConvertible: Boolean = false, + deletionFileLength: Option[Long] = None): DataSplit = { + val files = fileSizes.zipWithIndex.map { + case (fileSize, index) => newDataFile(s"$prefix-$index.parquet", fileSize) + } + newDataSplitFromFiles(files, rawConvertible, deletionFileLength, prefix) + } + + private def newDataSplitFromFiles( + files: Seq[DataFileMeta], + rawConvertible: Boolean, + deletionFileLength: Option[Long] = None, + deletionFilePrefix: String = "delete"): DataSplit = { + val builder = DataSplit + .builder() + .withSnapshot(1) + .withBucket(0) + .withPartition(BinaryRow.EMPTY_ROW) + .withDataFiles(files.asJava) + .rawConvertible(rawConvertible) + .withBucketPath("no use") + deletionFileLength.foreach { + length => + builder.withDataDeletionFiles( + files.indices + .map(index => new DeletionFile(s"$deletionFilePrefix-$index.dv", 0, length, null)) + .asJava) + } + builder.build() + } + + private def newDataFile( + fileName: String, + fileSize: Long, + rowCount: Long = 1, + maxSequenceNumber: Long = 0): DataFileMeta = { + DataFileMeta.forAppend( + fileName, + fileSize, + rowCount, + null, + 0, + maxSequenceNumber, + 1, + new java.util.ArrayList[String](), + null, + FileSource.APPEND, + null, + null, + null, + null) + } } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala index dcb99224cd1d..f22dab4b9cd0 100644 --- 
a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSinkTest.scala @@ -263,6 +263,7 @@ class PaimonSinkTest extends PaimonSparkTestBase with StreamTest { .writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .option("write.merge-schema", "true") + .option("write.merge-schema.type-widening", "true") .option("write.merge-schema.explicit-cast", "true") .format("paimon") .start(location) diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommandTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommandTest.scala new file mode 100644 index 000000000000..ef4046ccc912 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/commands/PaimonDynamicPartitionOverwriteCommandTest.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.commands + +import org.apache.paimon.spark.PaimonSparkTestBase + +import org.apache.spark.sql.PaimonUtils.{createDataset, createNewDataFrame} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +import java.io.File + +class PaimonDynamicPartitionOverwriteCommandTest extends PaimonSparkTestBase { + + import testImplicits._ + + test("dynamic overwrite consumes optimizer-safe child plan") { + withTempDir { + tempDir => + withTable("paimon_target") { + try { + sql("CREATE TABLE paimon_target (id INT, pt STRING) PARTITIONED BY (pt)") + sql("INSERT INTO paimon_target VALUES (3, 'p3')") + + val srcPath = new File(tempDir, "parquet_src").getCanonicalPath + val dimPath = new File(tempDir, "dim").getCanonicalPath + Seq((1, "p1"), (2, "p2"), (3, "p3")) + .toDF("id", "pt") + .write + .partitionBy("pt") + .parquet(srcPath) + spark.read.parquet(srcPath).createOrReplaceTempView("parquet_src") + + Seq(("p1", "use"), ("p2", "use"), ("p3", "skip")) + .toDF("pt", "tag") + .write + .parquet(dimPath) + spark.read.parquet(dimPath).createOrReplaceTempView("dim") + + withSparkSQLConf( + "spark.sql.sources.partitionOverwriteMode" -> "dynamic", + "spark.paimon.write.use-v2-write" -> "false", + "spark.sql.optimizer.dynamicPartitionPruning.enabled" -> "true", + "spark.sql.optimizer.dynamicPartitionPruning.useStats" -> "false", + "spark.sql.optimizer.dynamicPartitionPruning.fallbackFilterRatio" -> "1.0", + "spark.sql.autoBroadcastJoinThreshold" -> "10485760", + "spark.sql.adaptive.enabled" -> "false" + ) { + val insertSql = + """ + |INSERT OVERWRITE paimon_target + |SELECT /*+ BROADCAST(dim) */ s.id + 100 AS id, s.pt + |FROM parquet_src s JOIN dim ON s.pt = dim.pt + |WHERE dim.tag = 'use' + |""".stripMargin + + val parsed = spark.sessionState.sqlParser.parsePlan(insertSql) + val 
analyzed = + spark.sessionState.analyzer.executeAndCheck(parsed, new QueryPlanningTracker) + val cmd = analyzed.asInstanceOf[PaimonDynamicPartitionOverwriteCommand] + + val optimizedChild = spark.sessionState.optimizer.execute(cmd.query) + assert( + hasDynamicPruningSubquery(optimizedChild), + s"Expected dynamic pruning in optimized child, got:\n$optimizedChild") + + val cmdWithOptimizedQuery = cmd.copy(query = optimizedChild) + val writeDataFrame = + createNewDataFrame(createDataset(spark, cmdWithOptimizedQuery.query)) + assert( + !hasDynamicPruningSubquery(writeDataFrame.queryExecution.logical), + s"Expected writer DataFrame to be free of dynamic pruning, got:\n" + + writeDataFrame.queryExecution.logical + ) + + cmdWithOptimizedQuery.run(spark) + checkAnswer( + sql("SELECT * FROM paimon_target ORDER BY id"), + Seq(Row(3, "p3"), Row(101, "p1"), Row(102, "p2"))) + } + } finally { + spark.catalog.dropTempView("parquet_src") + spark.catalog.dropTempView("dim") + } + } + } + } + + private def hasDynamicPruningSubquery(plan: LogicalPlan): Boolean = { + plan.exists(_.expressions.exists(_.exists { + case _: DynamicPruningSubquery => true + case _ => false + })) + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoCastValidatorTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoCastValidatorTest.scala new file mode 100644 index 000000000000..ab006e2570ad --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoCastValidatorTest.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.execution + +import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.paimon.types.{DataField, DataTypes} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +class CopyIntoCastValidatorTest extends PaimonSparkTestBase { + + private lazy val validator = new CopyIntoCastValidator(spark) + + test("buildParquetCastValidation: no validation needed when all columns are compatible") { + val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(1, "Alice"), Row(2, "Bob"))), + schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val setup = validator.buildParquetCastValidation(df, targetColumns, writableColumns, fields) + assert(setup.castColMapping.nonEmpty) + assert(setup.badCastFilter.isDefined) + } + + test("buildParquetCastValidation: exclude specified columns") { + val schema = StructType( + Seq( + StructField("id", IntegerType), + StructField("name", StringType), + StructField("__file__", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(1, "Alice", "file1.csv"))), + 
schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val setup = validator.buildParquetCastValidation( + df, + targetColumns, + writableColumns, + fields, + excludeCols = Set("__file__")) + assert(!setup.castColMapping.contains("__file__")) + } + + test("buildTextCastValidation: no validation for string columns") { + val schema = StructType(Seq(StructField("id", StringType), StructField("name", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "Alice"), Row("2", "Bob"))), + schema) + + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.STRING()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val setup = validator.buildTextCastValidation(df, writableColumns, fields) + assert(setup.castColMapping.isEmpty) + assert(setup.badCastFilter.isEmpty) + } + + test("buildTextCastValidation: add validation columns for non-string types") { + val schema = StructType(Seq(StructField("id", StringType), StructField("age", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "30"), Row("2", "25"))), + schema) + + val writableColumns = Seq("id", "age") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "age", DataTypes.INT()) + ) + + val setup = validator.buildTextCastValidation(df, writableColumns, fields) + assert(setup.castColMapping.size == 2) + assert(setup.badCastFilter.isDefined) + assert(setup.castColMapping.contains("id")) + assert(setup.castColMapping.contains("age")) + } + + test("validateParquetCast: pass when all casts are valid") { + val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(1, "Alice"), Row(2, "Bob"))), + schema) + + val 
targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + // Should not throw + validator.validateParquetCast(df, targetColumns, writableColumns, fields) + } + + test("validateParquetCast: throw exception when cast fails") { + val schema = StructType(Seq(StructField("id", StringType), StructField("name", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("abc", "Alice"), Row("2", "Bob"))), + schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val exception = intercept[IllegalArgumentException] { + validator.validateParquetCast(df, targetColumns, writableColumns, fields) + } + assert(exception.getMessage.contains("ON_ERROR = ABORT_STATEMENT")) + assert(exception.getMessage.contains("Cast failure")) + } + + test("castColumns: cast all columns to target types") { + val schema = StructType(Seq(StructField("id", StringType), StructField("age", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "30"), Row("2", "25"))), + schema) + + val writableColumns = Seq("id", "age") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "age", DataTypes.INT()) + ) + + val castedDf = validator.castColumns(df, writableColumns, fields) + assert(castedDf.schema("id").dataType == IntegerType) + assert(castedDf.schema("age").dataType == IntegerType) + + val rows = castedDf.collect() + assert(rows(0).getInt(0) == 1) + assert(rows(0).getInt(1) == 30) + } + + test("castAndValidate: pass when all casts are valid") { + val schema = StructType(Seq(StructField("id", StringType), StructField("age", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "30"), 
Row("2", "25"))), + schema) + + val writableColumns = Seq("id", "age") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "age", DataTypes.INT()) + ) + + val castedDf = validator.castAndValidate(df, writableColumns, fields) + assert(castedDf.schema("id").dataType == IntegerType) + assert(castedDf.schema("age").dataType == IntegerType) + } + + test("castAndValidate: throw exception when cast fails") { + val schema = StructType(Seq(StructField("id", StringType), StructField("age", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "invalid"), Row("2", "25"))), + schema) + + val writableColumns = Seq("id", "age") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "age", DataTypes.INT()) + ) + + val exception = intercept[IllegalArgumentException] { + validator.castAndValidate(df, writableColumns, fields) + } + assert(exception.getMessage.contains("ON_ERROR = ABORT_STATEMENT")) + assert(exception.getMessage.contains("Cast failure")) + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilderTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilderTest.scala new file mode 100644 index 000000000000..542932b53c6c --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoDataFrameBuilderTest.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.execution + +import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.paimon.spark.catalyst.plans.logical.{CopyFileFormat, FileFormatType} +import org.apache.paimon.types.{DataField, DataTypes} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +class CopyIntoDataFrameBuilderTest extends PaimonSparkTestBase { + + private def createBuilder( + formatType: FileFormatType, + columns: Option[Seq[String]] = None): CopyIntoDataFrameBuilder = { + val fileFormat = new CopyFileFormat(formatType = formatType, options = Map.empty) + new CopyIntoDataFrameBuilder(spark, fileFormat, columns) + } + + test("buildStringSchema: CSV format with positional columns") { + val builder = createBuilder(FileFormatType.CSV) + val targetColumns = Seq("id", "name", "age") + val schema = builder.buildStringSchema(targetColumns) + + assert(schema.fields.length == 3) + assert(schema.fields(0).name == "_c0") + assert(schema.fields(1).name == "_c1") + assert(schema.fields(2).name == "_c2") + assert(schema.fields.forall(_.dataType == StringType)) + } + + test("buildStringSchema: JSON format with named columns") { + val builder = createBuilder(FileFormatType.JSON) + val targetColumns = Seq("id", "name", "age") + val schema = builder.buildStringSchema(targetColumns) + + assert(schema.fields.length == 3) + assert(schema.fields(0).name == "id") + assert(schema.fields(1).name == "name") + assert(schema.fields(2).name == "age") + assert(schema.fields.forall(_.dataType == 
StringType)) + } + + test("buildParquetDataFrame: map source columns to target by name") { + val builder = createBuilder(FileFormatType.PARQUET) + val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(1, "Alice"), Row(2, "Bob"))), + schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val result = builder.buildParquetDataFrame(df, targetColumns, writableColumns, fields) + assert(result.columns.toSeq == Seq("id", "name")) + assert(result.count() == 2) + } + + test("buildParquetDataFrame: fill NULL for missing source columns") { + val builder = createBuilder(FileFormatType.PARQUET) + val schema = StructType(Seq(StructField("id", IntegerType))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(1), Row(2))), schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val result = builder.buildParquetDataFrame(df, targetColumns, writableColumns, fields) + assert(result.columns.toSeq == Seq("id", "name")) + val rows = result.collect() + assert(rows(0).isNullAt(1)) // name should be NULL + } + + test("buildParquetDataFrame: fill default value for unmapped columns") { + val builder = createBuilder(FileFormatType.PARQUET, Some(Seq("id"))) + val schema = StructType(Seq(StructField("id", IntegerType))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(1), Row(2))), schema) + + val targetColumns = Seq("id") + val writableColumns = Seq("id", "status") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "status", DataTypes.STRING(), "'active'") + ) + + val result = 
builder.buildParquetDataFrame(df, targetColumns, writableColumns, fields) + assert(result.columns.toSeq == Seq("id", "status")) + } + + test("buildFinalDataFrame: rename CSV positional columns to named columns") { + val builder = createBuilder(FileFormatType.CSV, Some(Seq("id", "name"))) + val schema = StructType(Seq(StructField("_c0", StringType), StructField("_c1", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "Alice"), Row("2", "Bob"))), + schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val result = builder.buildFinalDataFrame(df, targetColumns, writableColumns, fields) + assert(result.columns.toSeq == Seq("id", "name")) + } + + test("buildFinalDataFrame: keep JSON named columns as-is") { + val builder = createBuilder(FileFormatType.JSON) + val schema = StructType(Seq(StructField("id", StringType), StructField("name", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "Alice"), Row("2", "Bob"))), + schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val result = builder.buildFinalDataFrame(df, targetColumns, writableColumns, fields) + assert(result.columns.toSeq == Seq("id", "name")) + } + + test("buildFinalDataFrame: preserve extra columns") { + val builder = createBuilder(FileFormatType.CSV, Some(Seq("id", "name"))) + val schema = StructType( + Seq( + StructField("_c0", StringType), + StructField("_c1", StringType), + StructField("__file__", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("1", "Alice", "file1.csv"))), + schema) + + val targetColumns = Seq("id", "name") + val writableColumns = Seq("id", "name") + val 
fields = Seq( + new DataField(0, "id", DataTypes.INT()), + new DataField(1, "name", DataTypes.STRING()) + ) + + val result = + builder.buildFinalDataFrame(df, targetColumns, writableColumns, fields, Seq("__file__")) + assert(result.columns.toSeq == Seq("id", "name", "__file__")) + } + + test("applyNullTransforms: replace NULL_IF values with null") { + val fileFormat = new CopyFileFormat( + formatType = FileFormatType.CSV, + options = Map("NULL_IF" -> Seq("N/A", "NULL").mkString(CopyFileFormat.LIST_SEPARATOR))) + val builder = new CopyIntoDataFrameBuilder(spark, fileFormat, None) + + val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("value", "N/A"), Row("NULL", "data"))), + schema) + + val result = builder.applyNullTransforms(df, Seq("col1", "col2")) + val rows = result.collect() + assert(rows(0).getString(0) == "value") + assert(rows(0).isNullAt(1)) // "N/A" -> null + assert(rows(1).isNullAt(0)) // "NULL" -> null + assert(rows(1).getString(1) == "data") + } + + test("applyNullTransforms: replace empty strings with null when EMPTY_FIELD_AS_NULL is true") { + val fileFormat = new CopyFileFormat( + formatType = FileFormatType.CSV, + options = Map("EMPTY_FIELD_AS_NULL" -> "TRUE")) + val builder = new CopyIntoDataFrameBuilder(spark, fileFormat, None) + + val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("value", ""), Row("", "data"))), + schema) + + val result = builder.applyNullTransforms(df, Seq("col1", "col2")) + val rows = result.collect() + assert(rows(0).getString(0) == "value") + assert(rows(0).isNullAt(1)) // "" -> null + assert(rows(1).isNullAt(0)) // "" -> null + assert(rows(1).getString(1) == "data") + } + + test("applyNullTransforms: apply both NULL_IF and EMPTY_FIELD_AS_NULL") { + val fileFormat = new CopyFileFormat( + 
formatType = FileFormatType.CSV, + options = Map( + "NULL_IF" -> Seq("N/A").mkString(CopyFileFormat.LIST_SEPARATOR), + "EMPTY_FIELD_AS_NULL" -> "TRUE")) + val builder = new CopyIntoDataFrameBuilder(spark, fileFormat, None) + + val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("N/A", ""), Row("value", "data"))), + schema) + + val result = builder.applyNullTransforms(df, Seq("col1", "col2")) + val rows = result.collect() + assert(rows(0).isNullAt(0)) // "N/A" -> null + assert(rows(0).isNullAt(1)) // "" -> null + assert(rows(1).getString(0) == "value") + assert(rows(1).getString(1) == "data") + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoHelperTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoHelperTest.scala new file mode 100644 index 000000000000..d09b374efb93 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/execution/CopyIntoHelperTest.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.execution + +import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.paimon.types.{DataField, DataTypes} + +class CopyIntoHelperTest extends PaimonSparkTestBase { + + test("resolveTargetColumns: use all writable columns when columns is None") { + val writableColumns = Seq("id", "name", "age") + val result = CopyIntoHelper.resolveTargetColumns(spark, None, writableColumns) + assert(result == writableColumns) + } + + test("resolveTargetColumns: resolve specified columns case-insensitively") { + val writableColumns = Seq("id", "name", "age") + val columns = Some(Seq("ID", "Name")) + val result = CopyIntoHelper.resolveTargetColumns(spark, columns, writableColumns) + assert(result == Seq("id", "name")) + } + + test("resolveTargetColumns: throw exception for non-existent column") { + val writableColumns = Seq("id", "name", "age") + val columns = Some(Seq("id", "invalid_col")) + val exception = intercept[IllegalArgumentException] { + CopyIntoHelper.resolveTargetColumns(spark, columns, writableColumns) + } + assert(exception.getMessage.contains("invalid_col")) + assert(exception.getMessage.contains("does not exist")) + } + + test("resolveTargetColumns: throw exception for duplicate columns") { + val writableColumns = Seq("id", "name", "age") + val columns = Some(Seq("id", "ID")) + val exception = intercept[IllegalArgumentException] { + CopyIntoHelper.resolveTargetColumns(spark, columns, writableColumns) + } + assert(exception.getMessage.contains("Duplicate columns")) + } + + test("validateNonNullableDefaults: pass when all non-nullable columns are mapped") { + val writableColumns = Seq("id", "name", "age") + val targetColumns = Seq("id", "name", "age") + val fields = Seq( + new DataField(0, "id", DataTypes.INT().notNull()), + new DataField(1, "name", DataTypes.STRING()), + new DataField(2, "age", DataTypes.INT()) + ) + // Should not throw + CopyIntoHelper.validateNonNullableDefaults( + Some(targetColumns), + 
writableColumns, + targetColumns, + fields) + } + + test("validateNonNullableDefaults: pass when unmapped non-nullable column has default") { + val writableColumns = Seq("id", "name", "status") + val targetColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT().notNull()), + new DataField(1, "name", DataTypes.STRING()), + new DataField(2, "status", DataTypes.STRING().notNull(), null, "default value") + ) + // Should not throw + CopyIntoHelper.validateNonNullableDefaults( + Some(targetColumns), + writableColumns, + targetColumns, + fields) + } + + test("validateNonNullableDefaults: throw when unmapped non-nullable column has no default") { + val writableColumns = Seq("id", "name", "status") + val targetColumns = Seq("id", "name") + val fields = Seq( + new DataField(0, "id", DataTypes.INT().notNull()), + new DataField(1, "name", DataTypes.STRING()), + new DataField(2, "status", DataTypes.STRING().notNull()) + ) + val exception = intercept[IllegalArgumentException] { + CopyIntoHelper.validateNonNullableDefaults( + Some(targetColumns), + writableColumns, + targetColumns, + fields) + } + assert(exception.getMessage.contains("status")) + assert(exception.getMessage.contains("not in the column list")) + assert(exception.getMessage.contains("no default value")) + } + + test("safeTempCol: generate unique column name") { + val existingColumns = Set("col1", "col2", "__temp") + val result = CopyIntoHelper.safeTempCol(spark, "__new", existingColumns) + assert(result == "__new") + } + + test("safeTempCol: add prefix when name conflicts") { + val existingColumns = Set("col1", "col2", "__temp") + val result = CopyIntoHelper.safeTempCol(spark, "__temp", existingColumns) + assert(result == "___temp") + } + + test("safeTempCol: handle case-insensitive conflicts") { + val existingColumns = Set("Col1", "COL2") + val result = CopyIntoHelper.safeTempCol(spark, "col1", existingColumns) + assert(result == "_col1") + } + + test("safeTempCol: add multiple 
prefixes for multiple conflicts") { + val existingColumns = Set("__temp", "___temp") + val result = CopyIntoHelper.safeTempCol(spark, "__temp", existingColumns) + assert(result == "____temp") + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CreateGlobalIndexProcedureTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CreateGlobalIndexProcedureTest.scala index b45475ff0c82..cba3c3cd63f3 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CreateGlobalIndexProcedureTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CreateGlobalIndexProcedureTest.scala @@ -34,115 +34,6 @@ import scala.collection.immutable class CreateGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest { - test("create bitmap global index") { - withTable("T") { - spark.sql(""" - |CREATE TABLE T (id INT, name STRING) - |TBLPROPERTIES ( - | 'bucket' = '-1', - | 'global-index.row-count-per-shard' = '10000', - | 'row-tracking.enabled' = 'true', - | 'data-evolution.enabled' = 'true') - |""".stripMargin) - - val values = - (0 until 100000).map(i => s"($i, 'name_$i')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - val output = - spark - .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap')") - .collect() - .head - - assert(output.getBoolean(0)) - - val table = loadTable("T") - val bitmapEntries = table - .store() - .newIndexFileHandler() - .scanEntries() - .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.nonEmpty) - val totalRowCount = bitmapEntries.map(_.indexFile().rowCount()).sum - assert(totalRowCount == 100000L) - } - } - - test("create bitmap global index with partition") { - withTable("T") { - spark.sql(""" - |CREATE TABLE T (id INT, name STRING, pt STRING) - |TBLPROPERTIES ( - | 'bucket' = '-1', - | 
'global-index.row-count-per-shard' = '10000', - | 'row-tracking.enabled' = 'true', - | 'data-evolution.enabled' = 'true') - | PARTITIONED BY (pt) - |""".stripMargin) - - var values = - (0 until 65000).map(i => s"($i, 'name_$i', 'p0')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - values = (0 until 35000).map(i => s"($i, 'name_$i', 'p1')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - values = (0 until 22222).map(i => s"($i, 'name_$i', 'p0')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - values = (0 until 100).map(i => s"($i, 'name_$i', 'p1')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - values = (0 until 100).map(i => s"($i, 'name_$i', 'p2')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - values = (0 until 33333).map(i => s"($i, 'name_$i', 'p2')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - values = (0 until 33333).map(i => s"($i, 'name_$i', 'p1')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - val output = - spark - .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap')") - .collect() - .head - - assert(output.getBoolean(0)) - - val table = loadTable("T") - val bitmapEntries = table - .store() - .newIndexFileHandler() - .scanEntries() - .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.nonEmpty) - - val ranges = bitmapEntries - .map( - s => - new Range( - s.indexFile().globalIndexMeta().rowRangeStart(), - s.indexFile().globalIndexMeta().rowRangeEnd())) - .toList - .asJava - val mergedRange = Range.sortAndMergeOverlap(ranges, true) - assert(mergedRange.size() == 1) - assert(mergedRange.get(0).equals(new Range(0, 189087))) - val totalRowCount = bitmapEntries - .map( - x => - x.indexFile() - .globalIndexMeta() - .rowRangeEnd() - x.indexFile().globalIndexMeta().rowRangeStart() + 1) - .sum - assert(totalRowCount == 189088L) - } - } - test("create btree global index") { 
withTable("T") { spark.sql(""" @@ -300,52 +191,6 @@ class CreateGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest } } - test("create bitmap global index with external path") { - withTable("T") { - val tempIndexDir: File = Utils.createTempDir - val indexPath = "file:" + tempIndexDir.toString - spark.sql(s""" - |CREATE TABLE T (id INT, name STRING) - |TBLPROPERTIES ( - | 'bucket' = '-1', - | 'global-index.row-count-per-shard' = '10000', - | 'global-index.external-path' = '$indexPath', - | 'row-tracking.enabled' = 'true', - | 'data-evolution.enabled' = 'true') - |""".stripMargin) - - val values = - (0 until 100000).map(i => s"($i, 'name_$i')").mkString(",") - spark.sql(s"INSERT INTO T VALUES $values") - - val output = - spark - .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap')") - .collect() - .head - - assert(output.getBoolean(0)) - - val table = loadTable("T") - val bitmapEntries = table - .store() - .newIndexFileHandler() - .scanEntries() - .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.nonEmpty) - val totalRowCount = bitmapEntries.map(_.indexFile().rowCount()).sum - assert(totalRowCount == 100000L) - for (entry <- bitmapEntries) { - assert( - entry - .indexFile() - .externalPath() - .startsWith(indexPath + "/" + entry.indexFile().fileName())) - } - } - } - private def assertMultiplePartitionsResult( tableName: String, rowCount: Long, diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/DropGlobalIndexProcedureTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/DropGlobalIndexProcedureTest.scala index 88caadcfdaf1..fd76da2cb8c6 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/DropGlobalIndexProcedureTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/DropGlobalIndexProcedureTest.scala @@ -29,7 +29,7 @@ 
import scala.collection.JavaConverters._ class DropGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest { - test("drop bitmap global index") { + test("drop btree global index") { withTable("T") { spark.sql(""" |CREATE TABLE T (id INT, name STRING) @@ -46,42 +46,42 @@ class DropGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest { var output = spark - .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap')") + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'btree')") .collect() .head assert(output.getBoolean(0)) var table = loadTable("T") - var bitmapEntries = table + var btreeEntries = table .store() .newIndexFileHandler() .scanEntries() .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.nonEmpty) - val totalRowCount = bitmapEntries.map(_.indexFile().rowCount()).sum + .filter(_.indexFile().indexType() == "btree") + assert(btreeEntries.nonEmpty) + val totalRowCount = btreeEntries.map(_.indexFile().rowCount()).sum assert(totalRowCount == 100000L) output = spark - .sql("CALL sys.drop_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap')") + .sql("CALL sys.drop_global_index(table => 'test.T', index_column => 'name', index_type => 'btree')") .collect() .head assert(output.getBoolean(0)) table = loadTable("T") - bitmapEntries = table + btreeEntries = table .store() .newIndexFileHandler() .scanEntries() .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.isEmpty) + .filter(_.indexFile().indexType() == "btree") + assert(btreeEntries.isEmpty) } } - test("create bitmap global index with partition") { + test("create btree global index with partition") { withTable("T") { spark.sql(""" |CREATE TABLE T (id INT, name STRING, pt STRING) @@ -117,22 +117,22 @@ class DropGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest { var output = spark - .sql("CALL 
sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap')") + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'name', index_type => 'btree')") .collect() .head assert(output.getBoolean(0)) val table = loadTable("T") - var bitmapEntries = table + var btreeEntries = table .store() .newIndexFileHandler() .scanEntries() .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.nonEmpty) + .filter(_.indexFile().indexType() == "btree") + assert(btreeEntries.nonEmpty) - var ranges = bitmapEntries + var ranges = btreeEntries .map( s => new Range( @@ -143,7 +143,7 @@ class DropGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest { var mergedRange = Range.sortAndMergeOverlap(ranges, true) assert(mergedRange.size() == 1) assert(mergedRange.get(0).equals(new Range(0, 189087))) - val totalRowCount = bitmapEntries + val totalRowCount = btreeEntries .map( x => x.indexFile() @@ -153,19 +153,19 @@ class DropGlobalIndexProcedureTest extends PaimonSparkTestBase with StreamTest { assert(totalRowCount == 189088L) output = spark - .sql("CALL sys.drop_global_index(table => 'test.T', index_column => 'name', index_type => 'bitmap', partitions => 'pt=\"p1\"')") + .sql("CALL sys.drop_global_index(table => 'test.T', index_column => 'name', index_type => 'btree', partitions => 'pt=\"p1\"')") .collect() .head - bitmapEntries = table + btreeEntries = table .store() .newIndexFileHandler() .scanEntries() .asScala - .filter(_.indexFile().indexType() == "bitmap") - assert(bitmapEntries.nonEmpty) + .filter(_.indexFile().indexType() == "btree") + assert(btreeEntries.nonEmpty) - ranges = bitmapEntries + ranges = btreeEntries .map( s => new Range( diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/MigrateTableProcedureTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/MigrateTableProcedureTest.scala index 8befd3082c5e..dcd832ba0d43 
100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/MigrateTableProcedureTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/MigrateTableProcedureTest.scala @@ -153,6 +153,21 @@ class MigrateTableProcedureTest extends PaimonHiveTestBase { } }) + test("Paimon migrate table procedure: migrate empty partitioned table") { + withTable(s"hive_tbl$random") { + spark.sql(s""" + |CREATE TABLE hive_tbl$random (id STRING, name STRING, pt STRING) + |USING parquet + |PARTITIONED BY (pt) + |""".stripMargin) + + spark.sql( + s"CALL sys.migrate_table(source_type => 'hive', table => '$hiveDbName.hive_tbl$random', options => 'file.format=parquet')") + + checkAnswer(spark.sql(s"SELECT * FROM hive_tbl$random"), Nil) + } + } + test(s"Paimon migrate table procedure: migrate partitioned table with null partition") { withTable(s"hive_tbl$random") { // create hive table diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala index 03e109919b53..c4eb2cd6443a 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CatalogQualifiedCreateTableLikeTest.scala @@ -129,6 +129,16 @@ class CatalogQualifiedCreateTableLikeTest extends PaimonSparkTestBase { Assertions.assertEquals("tag", nonReservedIdentifierCommand.targetIdent.name()) Assertions.assertEquals(Seq("test"), nonReservedIdentifierCommand.targetIdent.namespace().toSeq) + val continueCommand = + parseCreateTableLikeCommand("CREATE TABLE paimon.test.continue LIKE paimon.test.source_tbl") + Assertions.assertEquals("continue", continueCommand.targetIdent.name()) + Assertions.assertEquals(Seq("test"), 
continueCommand.targetIdent.namespace().toSeq) + + val skipFileCommand = + parseCreateTableLikeCommand("CREATE TABLE paimon.test.skip_file LIKE paimon.test.source_tbl") + Assertions.assertEquals("skip_file", skipFileCommand.targetIdent.name()) + Assertions.assertEquals(Seq("test"), skipFileCommand.targetIdent.namespace().toSeq) + val nestedIdentifierCommand = parseCreateTableLikeCommand( "CREATE TABLE paimon.test.extra.target_tbl LIKE paimon.test.extra.source_tbl") diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoOnErrorTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoOnErrorTest.scala new file mode 100644 index 000000000000..11942c7dfda4 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoOnErrorTest.scala @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +import org.apache.spark.sql.Row + +trait CopyIntoOnErrorTest { self: CopyIntoTestBase => + + test("COPY INTO: ON_ERROR = CONTINUE skips bad CSV rows") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_csv") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_csv (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "data.csv", "1,Alice\nabc,Bob\n3,Charlie\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_csv + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getLong(4) > 0) + assert(rows(0).getString(1) == "PARTIALLY_LOADED") + + val data = spark.sql(s"SELECT * FROM $dbName0.copy_continue_csv ORDER BY id").collect() + assert( + data.length == 2, + s"Expected 2 rows but got ${data.length} — bad row should NOT be in table") + assert(data(0).getInt(0) == 1 && data(0).getString(1) == "Alice") + assert(data(1).getInt(0) == 3 && data(1).getString(1) == "Charlie") + assert( + !data.exists(r => r.getString(1) == "Bob"), + "Bad row with 'abc' as id should not be in the table") + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_csv") + } + + test("COPY INTO: ON_ERROR = CONTINUE with no errors behaves like ABORT") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_ok") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_ok (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "data.csv", "1,Alice\n2,Bob\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_ok + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getString(1) == "LOADED") + assert(rows(0).getLong(4) == 0L) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_continue_ok ORDER BY id"), + Seq(Row(1, "Alice"), 
Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_ok") + } + + test("COPY INTO: ON_ERROR = SKIP_FILE skips file with bad data") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile") + spark.sql(s"CREATE TABLE $dbName0.copy_skipfile (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "good.csv", "1,Alice\n2,Bob\n") + createCsvFile(dir, "bad.csv", "abc,Charlie\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_skipfile + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = SKIP_FILE + |""".stripMargin) + + val rows = result.collect() + val loaded = rows.filter(_.getString(1) == "LOADED") + val failed = rows.filter(_.getString(1) == "LOAD_FAILED") + assert(loaded.length == 1) + assert(failed.length == 1) + assert(failed(0).getLong(4) >= 1L) + assert(failed(0).getString(5) != null) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_skipfile ORDER BY id"), + Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile") + } + + test("COPY INTO: ON_ERROR = SKIP_FILE discards good rows in a file that also has bad rows") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_mixed") + spark.sql(s"CREATE TABLE $dbName0.copy_skipfile_mixed (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "good.csv", "1,Alice\n2,Bob\n") + // 2 good rows (Carol, Eve) and 1 bad row (non-numeric id) + createCsvFile(dir, "mixed.csv", "3,Carol\nabc,Dave\n4,Eve\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_skipfile_mixed + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = SKIP_FILE + |""".stripMargin) + + val rows = result.collect() + assert(rows.count(_.getString(1) == "LOADED") == 1) + assert(rows.count(_.getString(1) == "LOAD_FAILED") == 1) + + // The whole mixed.csv is rejected: its good rows (Carol, Eve) must not be loaded. 
+ checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_skipfile_mixed ORDER BY id"), + Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_mixed") + } + + test("COPY INTO: ON_ERROR = SKIP_FILE with all good files") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_ok") + spark.sql(s"CREATE TABLE $dbName0.copy_skipfile_ok (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "a.csv", "1,Alice\n") + createCsvFile(dir, "b.csv", "2,Bob\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_skipfile_ok + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = SKIP_FILE + |""".stripMargin) + + val rows = result.collect() + assert(rows.forall(_.getString(1) == "LOADED")) + assert(rows.length == 2) + + assert(spark.sql(s"SELECT * FROM $dbName0.copy_skipfile_ok").count() == 2) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_ok") + } + + test("COPY INTO: ON_ERROR = CONTINUE with JSON malformed records") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_json") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_json (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id": "1", "name": "Alice"} + |{bad json here} + |{"id": "3", "name": "Charlie"} + |""".stripMargin + ) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_json + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + assert(rows(0).getLong(4) > 0) + + val data = spark.sql(s"SELECT * FROM $dbName0.copy_continue_json ORDER BY id").collect() + assert( + data.length == 2, + s"Expected 2 rows but got ${data.length} — malformed JSON should NOT be in table") + assert(data(0).getInt(0) == 1) + assert(data(1).getInt(0) == 3) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_json") + } + + test("COPY INTO: ON_ERROR = SKIP_FILE with JSON") { + 
spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_json") + spark.sql(s"CREATE TABLE $dbName0.copy_skipfile_json (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile(dir, "good.json", """{"id": "1", "name": "Alice"}""" + "\n") + createJsonFile(dir, "bad.json", "{bad json}\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_skipfile_json + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |ON_ERROR = SKIP_FILE + |""".stripMargin) + + val rows = result.collect() + val loaded = rows.filter(_.getString(1) == "LOADED") + val failed = rows.filter(_.getString(1) == "LOAD_FAILED") + assert(loaded.length == 1) + assert(failed.length == 1) + + checkAnswer(spark.sql(s"SELECT * FROM $dbName0.copy_skipfile_json"), Seq(Row(1, "Alice"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_json") + } + + test("COPY INTO: ABORT_STATEMENT still fails on bad data") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_abort_explicit") + spark.sql(s"CREATE TABLE $dbName0.copy_abort_explicit (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "data.csv", "abc,Alice\n") + + val e = intercept[Exception] { + spark.sql(s"""COPY INTO $dbName0.copy_abort_explicit + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = ABORT_STATEMENT + |""".stripMargin) + } + assert( + e.getMessage.contains("Cast failure") || + e.getMessage.contains("ABORT_STATEMENT") || + e.getMessage.contains("CAST_INVALID_INPUT") || + e.getCause != null) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_abort_explicit") + } + + test("COPY INTO: ON_ERROR = CONTINUE with Parquet cast errors") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_pq") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_pq (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = StructType(Seq(StructField("id", StringType), StructField("name", StringType))) + createParquetSingleFile( + dir, + 
"data.parquet", + Seq(Row("1", "Alice"), Row("abc", "Bad"), Row("3", "Charlie")), + schema) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_pq + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = PARQUET) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getLong(4) > 0) + + val data = spark.sql(s"SELECT * FROM $dbName0.copy_continue_pq ORDER BY id").collect() + assert( + data.length == 2, + s"Expected 2 rows but got ${data.length} — bad row should NOT be in table") + assert(data(0).getInt(0) == 1) + assert(data(1).getInt(0) == 3) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_pq") + } + + test("COPY INTO: ON_ERROR = SKIP_FILE with Parquet") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_pq") + spark.sql(s"CREATE TABLE $dbName0.copy_skipfile_pq (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + // Both files share the same physical schema (id STRING) so the directory read succeeds; + // good rows cast cleanly to the target INT column while bad.parquet's "abc" fails the cast, + // which is what SKIP_FILE must catch. Using a mixed INT/STRING footer here would instead + // crash Spark's vectorized Parquet reader before any cast validation runs. 
+ val schema = + StructType(Seq(StructField("id", StringType), StructField("name", StringType))) + createParquetSingleFile( + dir, + "good.parquet", + Seq(Row("1", "Alice"), Row("2", "Bob")), + schema) + + createParquetSingleFile(dir, "bad.parquet", Seq(Row("abc", "Bad")), schema) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_skipfile_pq + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = PARQUET) + |ON_ERROR = SKIP_FILE + |""".stripMargin) + + val rows = result.collect() + val loaded = rows.filter(_.getString(1) == "LOADED") + val failed = rows.filter(_.getString(1) == "LOAD_FAILED") + assert(loaded.length == 1) + assert(failed.length == 1) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_skipfile_pq ORDER BY id"), + Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_skipfile_pq") + } + + test("COPY INTO: ON_ERROR = CONTINUE clean file has no first_error") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_multi") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_multi (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "good.csv", "1,Alice\n2,Bob\n") + createCsvFile(dir, "bad.csv", "abc,Charlie\n3,Dave\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_multi + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + val goodFile = rows.find(_.getString(0) == "good.csv").get + val badFile = rows.find(_.getString(0) == "bad.csv").get + + assert(goodFile.getString(1) == "LOADED") + assert(goodFile.getLong(4) == 0L) + assert(goodFile.isNullAt(5), "Clean file should have null first_error") + + assert(badFile.getString(1) == "PARTIALLY_LOADED") + assert(badFile.getLong(4) > 0) + assert(!badFile.isNullAt(5), "Bad file should have non-null first_error") + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_multi") + } + + test("COPY INTO: ON_ERROR = CONTINUE with JSON cast 
errors reports correctly") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_json_cast") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_json_cast (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id": "1", "name": "Alice"} + |{"id": "abc", "name": "Bad"} + |{"id": "3", "name": "Charlie"} + |""".stripMargin + ) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_json_cast + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getLong(4) > 0, "Should report cast errors") + assert( + rows(0).getString(1) == "PARTIALLY_LOADED", + s"Expected PARTIALLY_LOADED but got ${rows(0).getString(1)}") + + val data = + spark.sql(s"SELECT * FROM $dbName0.copy_continue_json_cast ORDER BY id").collect() + assert(data.length == 2, s"Expected 2 rows but got ${data.length}") + assert(data(0).getInt(0) == 1) + assert(data(1).getInt(0) == 3) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_json_cast") + } + + test("COPY INTO: ON_ERROR = CONTINUE all rows fail reports LOAD_FAILED") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_allfail") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_allfail (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "allfail.csv", "abc,Alice\nxyz,Bob\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_allfail + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + + val rows = result.collect() + assert(rows.length == 1) + assert( + rows(0).getString(1) == "LOAD_FAILED", + s"Expected LOAD_FAILED when all rows fail, got ${rows(0).getString(1)}") + assert(rows(0).getLong(2) == 0L, "rows_loaded should be 0") + assert(rows(0).getLong(3) == 2L, "rows_parsed should be 2") + assert(rows(0).getLong(4) == 2L, "errors_seen should be 2") + + assert( + 
spark.sql(s"SELECT * FROM $dbName0.copy_continue_allfail").count() == 0, + "No rows should be in table") + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_allfail") + } + + test("COPY INTO: ON_ERROR = CONTINUE all rows fail does not block re-run") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_retry") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_retry (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "data.csv", "abc,Alice\n") + + // First run: all rows fail + val result1 = spark.sql(s"""COPY INTO $dbName0.copy_continue_retry + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + val rows1 = result1.collect() + assert(rows1(0).getString(1) == "LOAD_FAILED") + assert(spark.sql(s"SELECT * FROM $dbName0.copy_continue_retry").count() == 0) + + // Re-run without FORCE: file should NOT be skipped since 0 rows were loaded + val result2 = spark.sql(s"""COPY INTO $dbName0.copy_continue_retry + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + val rows2 = result2.collect() + assert( + rows2(0).getString(1) != "SKIPPED", + "File with 0 rows loaded should not be skipped on re-run") + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_retry") + } + + test("COPY INTO: ON_ERROR = CONTINUE skips CSV rows with fewer columns") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_fewer_cols") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_fewer_cols (id INT, name STRING, age INT)") + + withCsvDir { + dir => + createCsvFile(dir, "data.csv", "1,Alice\n2,Bob,20\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_fewer_cols + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + + // A row with fewer columns than the target schema is a malformed record + // (Spark CSV PERMISSIVE mode), so CONTINUE skips it and keeps the well-formed row. 
+ val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getString(1) == "PARTIALLY_LOADED") + assert(rows(0).getLong(2) == 1L, "rows_loaded should be 1") + assert(rows(0).getLong(4) == 1L, "errors_seen should be 1") + + val data = + spark.sql(s"SELECT * FROM $dbName0.copy_continue_fewer_cols ORDER BY id").collect() + assert(data.length == 1) + assert(data(0).getInt(0) == 2 && data(0).getString(1) == "Bob" && data(0).getInt(2) == 20) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_fewer_cols") + } + + test("COPY INTO: ON_ERROR = CONTINUE skips CSV rows with extra columns") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_extra_cols") + spark.sql(s"CREATE TABLE $dbName0.copy_continue_extra_cols (id INT, name STRING)") + + withCsvDir { + dir => + createCsvFile(dir, "data.csv", "1,Alice,extra\n2,Bob\n") + + val result = spark.sql(s"""COPY INTO $dbName0.copy_continue_extra_cols + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = CSV) + |ON_ERROR = CONTINUE + |""".stripMargin) + + // A row with more columns than the target schema is a malformed record + // (Spark CSV PERMISSIVE mode), so CONTINUE skips it and keeps the well-formed row. 
+ val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getString(1) == "PARTIALLY_LOADED") + assert(rows(0).getLong(2) == 1L, "rows_loaded should be 1") + assert(rows(0).getLong(4) == 1L, "errors_seen should be 1") + + val data = + spark.sql(s"SELECT * FROM $dbName0.copy_continue_extra_cols ORDER BY id").collect() + assert(data.length == 1) + assert(data(0).getInt(0) == 2 && data(0).getString(1) == "Bob") + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_continue_extra_cols") + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTestBase.scala index f3fb75a8c6d2..23365e35e68d 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/CopyIntoTestBase.scala @@ -27,7 +27,7 @@ import java.nio.file.Files class CopyIntoTestBase extends PaimonSparkTestBase { - private def createCsvFile(dir: File, name: String, content: String): File = { + protected def createCsvFile(dir: File, name: String, content: String): File = { val file = new File(dir, name) val writer = new PrintWriter(file) try writer.write(content) @@ -35,13 +35,13 @@ class CopyIntoTestBase extends PaimonSparkTestBase { file } - private def withCsvDir(testBody: File => Unit): Unit = { + protected def withCsvDir(testBody: File => Unit): Unit = { val dir = Files.createTempDirectory("copy_into_test").toFile try testBody(dir) finally deleteRecursively(dir) } - private def deleteRecursively(file: File): Unit = { + protected def deleteRecursively(file: File): Unit = { if (file.isDirectory) { file.listFiles().foreach(deleteRecursively) } @@ -234,10 +234,12 @@ class CopyIntoTestBase extends PaimonSparkTestBase { val e = intercept[IllegalArgumentException] { spark.sql(s"""COPY INTO $dbName0.copy_unsup |FROM '${dir.getAbsolutePath}' 
- |FILE_FORMAT = (TYPE = PARQUET) + |FILE_FORMAT = (TYPE = ORC) |""".stripMargin) } - assert(e.getMessage.contains("Unsupported file format type")) + assert( + e.getMessage.contains("Unsupported file format type") || + e.getMessage.contains("Supported types")) } spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_unsup") @@ -646,6 +648,423 @@ class CopyIntoTestBase extends PaimonSparkTestBase { spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_dup_col") } + protected def createJsonFile(dir: File, name: String, content: String): File = { + val file = new File(dir, name) + val writer = new PrintWriter(file) + try writer.write(content) + finally writer.close() + file + } + + protected def withJsonDir(testBody: File => Unit): Unit = { + val dir = Files.createTempDirectory("copy_into_json_test").toFile + try testBody(dir) + finally deleteRecursively(dir) + } + + test("COPY INTO: basic JSON import") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_basic") + spark.sql(s"CREATE TABLE $dbName0.copy_json_basic (id INT, name STRING, age INT)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","name":"Alice","age":"30"} + |{"id":"2","name":"Bob","age":"25"} + |""".stripMargin) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_json_basic + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + assert(result.collect().length > 0) + assert(result.collect()(0).getString(1) == "LOADED") + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_basic ORDER BY id"), + Seq(Row(1, "Alice", 30), Row(2, "Bob", 25))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_basic") + } + + test("COPY INTO: JSON column name matching ignores order") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_order") + spark.sql(s"CREATE TABLE $dbName0.copy_json_order (id INT, name STRING, age INT)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"age":"30","name":"Alice","id":"1"} + 
|{"name":"Bob","id":"2","age":"25"} + |""".stripMargin) + + spark.sql(s"""COPY INTO $dbName0.copy_json_order + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_order ORDER BY id"), + Seq(Row(1, "Alice", 30), Row(2, "Bob", 25))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_order") + } + + test("COPY INTO: JSON with MULTI_LINE") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_ml") + spark.sql(s"CREATE TABLE $dbName0.copy_json_ml (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """[ + | {"id":"1","name":"Alice"}, + | {"id":"2","name":"Bob"} + |]""".stripMargin) + + spark.sql(s"""COPY INTO $dbName0.copy_json_ml + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON, MULTI_LINE = TRUE) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_ml ORDER BY id"), + Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_ml") + } + + test("COPY INTO: JSON with explicit column list") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_cols") + spark.sql(s"CREATE TABLE $dbName0.copy_json_cols (id INT, name STRING, age INT)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","name":"Alice"} + |{"id":"2","name":"Bob"} + |""".stripMargin) + + spark.sql(s"""COPY INTO $dbName0.copy_json_cols (id, name) + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_cols ORDER BY id"), + Seq(Row(1, "Alice", null), Row(2, "Bob", null))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_cols") + } + + test("COPY INTO: JSON NULL_IF") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_null") + spark.sql(s"CREATE TABLE $dbName0.copy_json_null (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + 
"data.json", + """{"id":"1","name":"NULL"} + |{"id":"2","name":"\\N"} + |{"id":"3","name":"Alice"} + |""".stripMargin + ) + + spark.sql(s"""COPY INTO $dbName0.copy_json_null + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON, NULL_IF = ('NULL', '\\N')) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_null ORDER BY id"), + Seq(Row(1, null), Row(2, null), Row(3, "Alice"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_null") + } + + test("COPY INTO: JSON export") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_export") + spark.sql(s"CREATE TABLE $dbName0.copy_json_export (id INT, name STRING)") + spark.sql(s"INSERT INTO $dbName0.copy_json_export VALUES (1, 'Alice'), (2, 'Bob')") + + withJsonDir { + dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val result = spark.sql(s"""COPY INTO '$outputPath' + |FROM $dbName0.copy_json_export + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + val row = result.collect()(0) + assert(row.getString(0) == outputPath) + assert(row.getLong(2) == 2L) + + val exported = spark.read.json(outputPath) + assert(exported.count() == 2) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_export") + } + + test("COPY INTO: JSON rejects CSV-only options") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_bad_opt") + spark.sql(s"CREATE TABLE $dbName0.copy_json_bad_opt (id INT)") + + withJsonDir { + dir => + createJsonFile(dir, "data.json", """{"id":"1"}""") + + val e = intercept[Exception] { + spark.sql(s"""COPY INTO $dbName0.copy_json_bad_opt + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON, FIELD_DELIMITER = ',') + |""".stripMargin) + } + assert(e.getMessage.contains("Unsupported FILE_FORMAT options")) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_bad_opt") + } + + test("COPY INTO: JSON with date and timestamp columns") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_date") + spark.sql(s"CREATE TABLE 
$dbName0.copy_json_date (id INT, dt DATE, ts TIMESTAMP)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","dt":"2024-01-15","ts":"2024-01-15T10:30:00"} + |{"id":"2","dt":"2024-06-20","ts":"2024-06-20T14:45:30"} + |""".stripMargin + ) + + spark.sql(s"""COPY INTO $dbName0.copy_json_date + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + val rows = spark.sql(s"SELECT * FROM $dbName0.copy_json_date ORDER BY id").collect() + assert(rows.length == 2) + assert(rows(0).getInt(0) == 1) + assert(rows(0).getDate(1).toString == "2024-01-15") + assert(rows(1).getInt(0) == 2) + assert(rows(1).getDate(1).toString == "2024-06-20") + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_date") + } + + test("COPY INTO: JSON rows_loaded count is accurate") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_count") + spark.sql(s"CREATE TABLE $dbName0.copy_json_count (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","name":"Alice"} + |{"id":"2","name":"Bob"} + |{"id":"3","name":"Charlie"} + |""".stripMargin + ) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_json_count + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + val rows = result.collect() + assert(rows.length == 1) + assert(rows(0).getString(1) == "LOADED") + assert(rows(0).getLong(2) == 3L) + assert(rows(0).getLong(3) == 3L) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_count") + } + + test("COPY INTO: JSON export then import round-trip") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_rt_src") + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_rt_dst") + spark.sql(s"CREATE TABLE $dbName0.copy_json_rt_src (id INT, name STRING, score DOUBLE)") + spark.sql(s"INSERT INTO $dbName0.copy_json_rt_src VALUES (1, 'Alice', 95.5), (2, 'Bob', 87.3)") + spark.sql(s"CREATE TABLE $dbName0.copy_json_rt_dst (id INT, name STRING, score 
DOUBLE)") + + withJsonDir { + dir => + val exportPath = new File(dir, "exported").getAbsolutePath + + spark.sql(s"""COPY INTO '$exportPath' + |FROM $dbName0.copy_json_rt_src + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + spark.sql(s"""COPY INTO $dbName0.copy_json_rt_dst + |FROM '$exportPath' + |FILE_FORMAT = (TYPE = JSON) + |PATTERN = '.*\\.json' + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_rt_dst ORDER BY id"), + Seq(Row(1, "Alice", 95.5), Row(2, "Bob", 87.3))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_rt_src") + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_rt_dst") + } + + test("COPY INTO: JSON extra fields are ignored") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_extra") + spark.sql(s"CREATE TABLE $dbName0.copy_json_extra (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","name":"Alice","extra_field":"ignored","another":"also_ignored"} + |{"id":"2","name":"Bob","extra_field":"ignored2"} + |""".stripMargin + ) + + spark.sql(s"""COPY INTO $dbName0.copy_json_extra + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_extra ORDER BY id"), + Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_extra") + } + + test("COPY INTO: JSON missing fields become null") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_missing") + spark.sql(s"CREATE TABLE $dbName0.copy_json_missing (id INT, name STRING, age INT)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","name":"Alice"} + |{"id":"2"} + |""".stripMargin + ) + + spark.sql(s"""COPY INTO $dbName0.copy_json_missing + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_json_missing ORDER BY id"), + Seq(Row(1, "Alice", null), 
Row(2, null, null))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_missing") + } + + test("COPY INTO: JSON malformed data fails with ABORT_STATEMENT") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_malformed") + spark.sql(s"CREATE TABLE $dbName0.copy_json_malformed (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile( + dir, + "data.json", + """{"id":"1","name":"Alice"} + |{this is not valid json} + |""".stripMargin + ) + + intercept[Exception] { + spark.sql(s"""COPY INTO $dbName0.copy_json_malformed + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + } + assert(spark.sql(s"SELECT * FROM $dbName0.copy_json_malformed").count() == 0) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_malformed") + } + + test("COPY INTO: JSON bad cast fails with ABORT_STATEMENT") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_badcast") + spark.sql(s"CREATE TABLE $dbName0.copy_json_badcast (id INT, name STRING)") + + withJsonDir { + dir => + createJsonFile(dir, "data.json", """{"id":"not_a_number","name":"Alice"}""") + + val e = intercept[Exception] { + spark.sql(s"""COPY INTO $dbName0.copy_json_badcast + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = JSON) + |""".stripMargin) + } + val msg = e.getMessage + assert( + msg.contains("Cast failure") || + msg.contains("ABORT_STATEMENT") || + msg.contains("CAST_INVALID_INPUT") || + msg.contains("cannot be cast to") || + e.getCause != null) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_badcast") + } + + test("COPY INTO: JSON export with COMPRESSION") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_compress") + spark.sql(s"CREATE TABLE $dbName0.copy_json_compress (id INT, name STRING)") + spark.sql(s"INSERT INTO $dbName0.copy_json_compress VALUES (1, 'Alice'), (2, 'Bob')") + + withJsonDir { + dir => + val outputPath = new File(dir, "compressed").getAbsolutePath + + val result = spark.sql(s"""COPY INTO '$outputPath' + |FROM 
$dbName0.copy_json_compress + |FILE_FORMAT = (TYPE = JSON, COMPRESSION = GZIP) + |OVERWRITE = TRUE + |""".stripMargin) + + assert(result.collect()(0).getLong(2) == 2L) + + val outputDir = new File(outputPath) + val gzFiles = outputDir.listFiles().filter(_.getName.endsWith(".gz")) + assert(gzFiles.nonEmpty) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_json_compress") + } + test("COPY INTO: case-insensitive column matching") { spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_case") spark.sql(s"CREATE TABLE $dbName0.copy_case (id INT, name STRING, age INT)") @@ -664,4 +1083,343 @@ class CopyIntoTestBase extends PaimonSparkTestBase { spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_case") } + + // ========== Parquet Tests ========== + + protected def withParquetDir(testBody: File => Unit): Unit = { + val dir = Files.createTempDirectory("copy_into_parquet_test").toFile + try testBody(dir) + finally deleteRecursively(dir) + } + + private def createParquetFile( + dir: File, + name: String, + data: Seq[Row], + schema: org.apache.spark.sql.types.StructType): Unit = { + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.coalesce(1).write.parquet(new File(dir, name).getAbsolutePath) + } + + protected def createParquetSingleFile( + dir: File, + fileName: String, + data: Seq[Row], + schema: org.apache.spark.sql.types.StructType): Unit = { + val tmpDir = new File(dir, s"_tmp_$fileName") + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.coalesce(1).write.parquet(tmpDir.getAbsolutePath) + val partFile = tmpDir.listFiles().find(_.getName.endsWith(".parquet")).get + partFile.renameTo(new File(dir, fileName)) + deleteRecursively(tmpDir) + } + + test("COPY INTO: basic Parquet import") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_basic") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_basic (id INT, name STRING, age INT)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val 
schema = StructType( + Seq( + StructField("id", IntegerType), + StructField("name", StringType), + StructField("age", IntegerType))) + createParquetFile(dir, "data", Seq(Row(1, "Alice", 30), Row(2, "Bob", 25)), schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_basic + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_basic ORDER BY id"), + Seq(Row(1, "Alice", 30), Row(2, "Bob", 25))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_basic") + } + + test("COPY INTO: Parquet column name matching ignores order") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_order") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_order (id INT, name STRING, age INT)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = StructType( + Seq( + StructField("age", IntegerType), + StructField("name", StringType), + StructField("id", IntegerType))) + createParquetFile(dir, "data", Seq(Row(30, "Alice", 1), Row(25, "Bob", 2)), schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_order + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_order ORDER BY id"), + Seq(Row(1, "Alice", 30), Row(2, "Bob", 25))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_order") + } + + test("COPY INTO: Parquet with explicit column list") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_cols") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_cols (id INT, name STRING, age INT)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + createParquetFile(dir, "data", Seq(Row(1, "Alice"), Row(2, "Bob")), schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_cols (id, name) + |FROM 
'${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_cols ORDER BY id"), + Seq(Row(1, "Alice", null), Row(2, "Bob", null))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_cols") + } + + test("COPY INTO: Parquet export") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_export") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_export (id INT, name STRING)") + spark.sql(s"INSERT INTO $dbName0.copy_parquet_export VALUES (1, 'Alice'), (2, 'Bob')") + + withParquetDir { + dir => + val outputPath = new File(dir, "output").getAbsolutePath + spark.sql(s"""COPY INTO '$outputPath' + |FROM $dbName0.copy_parquet_export + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + val readBack = spark.read.parquet(outputPath) + checkAnswer(readBack.orderBy("id"), Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_export") + } + + test("COPY INTO: Parquet export with COMPRESSION") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_compress") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_compress (id INT, name STRING)") + spark.sql(s"INSERT INTO $dbName0.copy_parquet_compress VALUES (1, 'Alice'), (2, 'Bob')") + + withParquetDir { + dir => + val outputPath = new File(dir, "output").getAbsolutePath + spark.sql(s"""COPY INTO '$outputPath' + |FROM $dbName0.copy_parquet_compress + |FILE_FORMAT = (TYPE = PARQUET, COMPRESSION = GZIP) + |""".stripMargin) + + val readBack = spark.read.parquet(outputPath) + checkAnswer(readBack.orderBy("id"), Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_compress") + } + + test("COPY INTO: Parquet export then import round-trip") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_rt_src") + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_rt_dst") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_rt_src (id INT, name 
STRING, score DOUBLE)") + spark.sql( + s"INSERT INTO $dbName0.copy_parquet_rt_src VALUES (1, 'Alice', 95.5), (2, 'Bob', 87.3)") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_rt_dst (id INT, name STRING, score DOUBLE)") + + withParquetDir { + dir => + val outputPath = new File(dir, "export").getAbsolutePath + spark.sql(s"""COPY INTO '$outputPath' + |FROM $dbName0.copy_parquet_rt_src + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_rt_dst + |FROM '$outputPath' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_rt_dst ORDER BY id"), + Seq(Row(1, "Alice", 95.5), Row(2, "Bob", 87.3))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_rt_src") + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_rt_dst") + } + + test("COPY INTO: Parquet extra fields are ignored") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_extra") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_extra (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = StructType( + Seq( + StructField("id", IntegerType), + StructField("name", StringType), + StructField("extra", StringType))) + createParquetFile( + dir, + "data", + Seq(Row(1, "Alice", "ignore"), Row(2, "Bob", "ignore")), + schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_extra + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_extra ORDER BY id"), + Seq(Row(1, "Alice"), Row(2, "Bob"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_extra") + } + + test("COPY INTO: Parquet missing fields become null") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_missing") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_missing (id INT, name STRING, age INT)") + + withParquetDir { + dir => + import 
org.apache.spark.sql.types._ + val schema = + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + createParquetFile(dir, "data", Seq(Row(1, "Alice"), Row(2, "Bob")), schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_missing + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_missing ORDER BY id"), + Seq(Row(1, "Alice", null), Row(2, "Bob", null))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_missing") + } + + test("COPY INTO: Parquet FORCE=FALSE skips already-loaded files") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_force") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_force (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + createParquetFile(dir, "data", Seq(Row(1, "Alice")), schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_force + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_force + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |FORCE = FALSE + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_force ORDER BY id"), + Seq(Row(1, "Alice"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_force") + } + + test("COPY INTO: Parquet PATTERN filters files") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_pattern") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_pattern (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + createParquetSingleFile(dir, "include_data.parquet", Seq(Row(1, "Alice")), schema) + 
createParquetSingleFile(dir, "exclude_data.parquet", Seq(Row(2, "Bob")), schema) + + spark.sql(s"""COPY INTO $dbName0.copy_parquet_pattern + |FROM '${dir.getAbsolutePath}' + |FILE_FORMAT = (TYPE = PARQUET) + |PATTERN = 'include.*' + |""".stripMargin) + + checkAnswer( + spark.sql(s"SELECT * FROM $dbName0.copy_parquet_pattern ORDER BY id"), + Seq(Row(1, "Alice"))) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_pattern") + } + + test("COPY INTO: Parquet unsupported option errors") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_opt_err") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_opt_err (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + createParquetFile(dir, "data", Seq(Row(1, "Alice")), schema) + + intercept[IllegalArgumentException] { + spark.sql(s"""COPY INTO $dbName0.copy_parquet_opt_err + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET, FIELD_DELIMITER = ',') + |""".stripMargin) + } + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_opt_err") + } + + test("COPY INTO: Parquet rows_loaded count is accurate") { + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_count") + spark.sql(s"CREATE TABLE $dbName0.copy_parquet_count (id INT, name STRING)") + + withParquetDir { + dir => + import org.apache.spark.sql.types._ + val schema = + StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))) + createParquetFile( + dir, + "data", + Seq(Row(1, "Alice"), Row(2, "Bob"), Row(3, "Charlie")), + schema) + + val result = spark.sql(s"""COPY INTO $dbName0.copy_parquet_count + |FROM '${dir.getAbsolutePath}/data' + |FILE_FORMAT = (TYPE = PARQUET) + |""".stripMargin) + + val rows = result.collect() + val totalLoaded = rows.map(_.getLong(2)).sum + assert(totalLoaded == 3) + } + + spark.sql(s"DROP TABLE IF EXISTS $dbName0.copy_parquet_count") + } } diff 
--git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DDLTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DDLTestBase.scala index b44561a09cff..98cac73f4aed 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DDLTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DDLTestBase.scala @@ -860,7 +860,7 @@ abstract class DDLTestBase extends PaimonSparkTestBase { test("Paimon DDL: create unsupported table") { assert(intercept[Exception] { sql("CREATE TABLE t (id INT) USING paimon1") - }.getMessage.contains("Provider is not supported: paimon1")) + }.getMessage.contains("Provider 'paimon1' is not supported")) } test("Paimon DDL: Drop Partition by partial spec") { diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTestBase.scala index 87d80cb8972b..de31d8ccc45d 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTestBase.scala @@ -166,7 +166,6 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { .option("bucket", "-1") .option("target-file-size", "256MB") .option("write.merge-schema", "true") - .option("write.merge-schema.explicit-cast", "true") .saveAsTable("test_ctas") val paimonTable = loadTable("test_ctas") @@ -179,7 +178,6 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { // non-core options should not be here. 
Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema")) - Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema.explicit-cast")) } } @@ -597,6 +595,7 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { .format("paimon") .mode("append") .option("write.merge-schema", "true") + .option("write.merge-schema.type-widening", "true") .save(location) val expected3 = if (hasPk) { Row(1L, "a2", BigDecimal.decimal(123), Map("k" -> 11.1)) :: Row( @@ -641,6 +640,7 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { .format("paimon") .mode("append") .option("write.merge-schema", "true") + .option("write.merge-schema.type-widening", "true") .save(location) val expected4 = expected3 ++ Seq(Row(99L, "df4", BigDecimal.decimal(4.0), Map("4" -> 4.1))) @@ -712,19 +712,21 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { "c", "d") - // throw UnsupportedOperationException if write.merge-schema.explicit-cast = false + // throw UnsupportedOperationException when type-widening is on but explicit-cast = false assertThrows[UnsupportedOperationException] { df3.write .format("paimon") .mode("append") .option("write.merge-schema", "true") + .option("write.merge-schema.type-widening", "true") .save(location) } - // merge schema and write data when write.merge-schema.explicit-cast = true + // merge schema and write data when type-widening + explicit-cast = true df3.write .format("paimon") .mode("append") .option("write.merge-schema", "true") + .option("write.merge-schema.type-widening", "true") .option("write.merge-schema.explicit-cast", "true") .save(location) val expected3 = if (hasPk) { @@ -864,56 +866,50 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { } test("Paimon Schema Evolution: some columns is absent in the coming data") { + withTable("T") { + spark.sql("CREATE TABLE T (a INT, b STRING)") - spark.sql(s""" - |CREATE TABLE T (a INT, b STRING) - |""".stripMargin) + 
val df1 = Seq((1, "2023-08-01"), (2, "2023-08-02")).toDF("a", "b") + df1.write.format("paimon").mode("append").saveAsTable("T") + checkAnswer( + spark.sql("SELECT * FROM T ORDER BY a, b"), + Row(1, "2023-08-01") :: Row(2, "2023-08-02") :: Nil) + + // Case 1: two additional fields: DoubleType and TimestampType + val ts = java.sql.Timestamp.valueOf("2023-08-01 10:00:00.0") + val df2 = Seq((1, "2023-08-01", 12.3d, ts), (3, "2023-08-03", 34.5d, ts)) + .toDF("a", "b", "c", "d") + df2.write + .format("paimon") + .mode("append") + .option("write.merge-schema", "true") + .saveAsTable("T") - val paimonTable = loadTable("T") - val location = paimonTable.location().toString - - val df1 = Seq((1, "2023-08-01"), (2, "2023-08-02")).toDF("a", "b") - df1.write.format("paimon").mode("append").save(location) - checkAnswer( - spark.sql("SELECT * FROM T ORDER BY a, b"), - Row(1, "2023-08-01") :: Row(2, "2023-08-02") :: Nil) - - // Case 1: two additional fields: DoubleType and TimestampType - val ts = java.sql.Timestamp.valueOf("2023-08-01 10:00:00.0") - val df2 = Seq((1, "2023-08-01", 12.3d, ts), (3, "2023-08-03", 34.5d, ts)) - .toDF("a", "b", "c", "d") - df2.write - .format("paimon") - .mode("append") - .option("write.merge-schema", "true") - .save(location) - - // Case 2: colum b and d are absent in the coming data - val df3 = Seq((4, 45.6d), (5, 56.7d)) - .toDF("a", "c") - df3.write - .format("paimon") - .mode("append") - .option("write.merge-schema", "true") - .save(location) - val expected3 = - Row(1, "2023-08-01", null, null) :: Row(1, "2023-08-01", 12.3d, ts) :: Row( - 2, - "2023-08-02", - null, - null) :: Row(3, "2023-08-03", 34.5d, ts) :: Row(4, null, 45.6d, null) :: Row( - 5, - null, - 56.7d, - null) :: Nil - checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected3) + // Case 2: column b and d are absent in the coming data + val df3 = Seq((4, 45.6d), (5, 56.7d)) + .toDF("a", "c") + df3.write + .format("paimon") + .mode("append") + .option("write.merge-schema", 
"true") + .saveAsTable("T") + val expected3 = + Row(1, "2023-08-01", null, null) :: Row(1, "2023-08-01", 12.3d, ts) :: Row( + 2, + "2023-08-02", + null, + null) :: Row(3, "2023-08-03", 34.5d, ts) :: Row(4, null, 45.6d, null) :: Row( + 5, + null, + 56.7d, + null) :: Nil + checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected3) + } } test("Paimon write merge-schema conflict: deep nested array element bigint -> string") { for (useV2Write <- Seq("true", "false")) { - withSparkSQLConf( - "spark.paimon.write.use-v2-write" -> useV2Write, - "spark.paimon.write.merge-schema.explicit-cast" -> "true") { + withSparkSQLConf("spark.paimon.write.use-v2-write" -> useV2Write) { withTable("target") { sql(""" |CREATE TABLE target ( @@ -953,9 +949,7 @@ abstract class DataFrameWriteTestBase extends PaimonSparkTestBase { test("Paimon write merge-schema conflict: top-level same-name column string vs bigint") { for (useV2Write <- Seq("true", "false")) { - withSparkSQLConf( - "spark.paimon.write.use-v2-write" -> useV2Write, - "spark.paimon.write.merge-schema.explicit-cast" -> "true") { + withSparkSQLConf("spark.paimon.write.use-v2-write" -> useV2Write) { withTable("target") { sql("CREATE TABLE target (id STRING, value BIGINT) USING paimon") sql("INSERT INTO target VALUES ('r0', 1000L)") diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala index 766fac386ffb..feb20e8da298 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala @@ -268,6 +268,25 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH }) } + test("Push down aggregate - non-primary-key DV table with tight bounds") { + withTable("T") { + sql(""" + |CREATE TABLE T (id 
INT) + |TBLPROPERTIES ( + | 'deletion-vectors.enabled' = 'true', + | 'bucket-key' = 'id', + | 'bucket' = '1' + |) + |""".stripMargin) + sql("INSERT INTO T SELECT id FROM range (0, 5000)") + // No deleted rows, so split stats can answer MIN/MAX. + runAndCheckAggregate("SELECT COUNT(*), MIN(id), MAX(id) FROM T", Row(5000, 0, 4999) :: Nil, 0) + sql("DELETE FROM T WHERE id > 100 AND id <= 400") + // Deleted rows make file stats wide, so Spark keeps MIN/MAX aggregation. + runAndCheckAggregate("SELECT MIN(id), MAX(id) FROM T", Row(0, 4999) :: Nil, 2) + } + } + test("Push down aggregate: group by partial partition of a multi partition table") { sql(s""" |CREATE TABLE T ( diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala index 644eb49847b6..1da318e2825d 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala @@ -19,12 +19,18 @@ package org.apache.paimon.spark.sql import org.apache.paimon.Snapshot.CommitKind +import org.apache.paimon.spark.PaimonMetrics.RESULTED_TABLE_FILES import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.paimon.spark.read.PaimonSplitScan import org.apache.paimon.table.source.DataSplit import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Join, LogicalPlan, MergeRows, RepartitionByExpression, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Join, LogicalPlan, MergeRows, RepartitionByExpression, Sort, SubqueryAlias} +import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import 
org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.paimon.Utils import org.apache.spark.sql.util.QueryExecutionListener import java.util.concurrent.{CountDownLatch, TimeUnit} @@ -34,7 +40,7 @@ import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.DurationInt -abstract class RowTrackingTestBase extends PaimonSparkTestBase { +abstract class RowTrackingTestBase extends PaimonSparkTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -328,6 +334,21 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase { } } + test("Row Tracking: delete preserves row tracking metadata for update") { + withTable("t") { + sql("CREATE TABLE t (id INT, data INT) TBLPROPERTIES ('row-tracking.enabled' = 'true')") + sql("INSERT INTO t SELECT /*+ REPARTITION(1) */ id, id AS data FROM range(1, 4)") + + sql("DELETE FROM t WHERE id = 2") + sql("UPDATE t SET data = 33 WHERE _ROW_ID = 2") + + checkAnswer( + sql("SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM t ORDER BY id"), + Seq(Row(1, 1, 0, 1), Row(3, 33, 2, 3)) + ) + } + } + test("Row Tracking: update table") { withTable("t") { // only enable row tracking @@ -628,6 +649,174 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase { } } + Seq(false, true).foreach { + filePruning => + test(s"Data Evolution: merge into file pruning: $filePruning") { + withSparkSQLConf( + "spark.paimon.data-evolution.merge-into.file-pruning" -> + filePruning.toString) { + withTable("source", "target") { + sql("CREATE TABLE source (id INT, b INT, dt STRING)") + sql("INSERT INTO source VALUES (1, 100, '2026-05-28'), (3, 300, '2026-05-28')") + + sql(""" + |CREATE TABLE target (id INT, b INT, c STRING, dt STRING) + |TBLPROPERTIES ( + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |PARTITIONED BY (dt) + |""".stripMargin) + sql("INSERT INTO target VALUES (1, 10, 'old-1', '2026-05-28'), (2, 20, 'old-2', 
'2026-05-28'), (4, 40, 'old-4', '2026-05-29')") + + executeMergeIntoAndAssertFilePruning( + """ + |MERGE INTO target + |USING source + |ON target.id = source.id AND target.dt = source.dt + |WHEN MATCHED THEN UPDATE SET target.b = source.b + |WHEN NOT MATCHED THEN INSERT (id, b, c, dt) VALUES (id, b, 'new', dt) + |""".stripMargin, + filePruning + ) + + checkAnswer( + sql("SELECT id, b, c, dt FROM target ORDER BY id"), + Seq( + Row(1, 100, "old-1", "2026-05-28"), + Row(2, 20, "old-2", "2026-05-28"), + Row(3, 300, "new", "2026-05-28"), + Row(4, 40, "old-4", "2026-05-29")) + ) + } + } + } + } + + private def executeMergeIntoAndAssertFilePruning(mergeSql: String, filePruning: Boolean): Unit = { + @volatile var hasTargetFilePruningJoin = false + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + checkPlan(qe.analyzed) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + checkPlan(qe.analyzed) + } + + private def checkPlan(plan: LogicalPlan): Unit = { + if (isTargetFilePruningJoinPlan(plan)) { + hasTargetFilePruningJoin = true + assert( + filePruning, + s"File pruning join should be skipped when file pruning is disabled: $plan") + } + } + } + + spark.listenerManager.register(listener) + try { + sql(mergeSql) + Utils.waitUntilEventEmpty(spark) + } finally { + spark.listenerManager.unregister(listener) + } + + if (filePruning) { + assert(hasTargetFilePruningJoin, "Expected target file pruning join plan.") + } + } + + private def isTargetFilePruningJoinPlan(plan: LogicalPlan): Boolean = { + plan.collectFirst { case _: Deduplicate => true }.nonEmpty && + plan.collectFirst { case _: Join => true }.nonEmpty && + plan.collectFirst { + case SubqueryAlias(identifier, _) if identifier.name == "_left" => true + }.nonEmpty && + plan.collectFirst { case _: MergeRows => true }.isEmpty + } + + test("Data Evolution: merge into skip file pruning push down 
partition filter in on condition") { + withSparkSQLConf("spark.paimon.data-evolution.merge-into.file-pruning" -> "false") { + withTempView("source") { + withTable("target") { + Seq((1, 100), (2, 200), (3, 300)).toDF("id", "b").createOrReplaceTempView("source") + + sql(""" + |CREATE TABLE target (id INT, b INT, dt STRING) + |TBLPROPERTIES ( + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |PARTITIONED BY (dt) + |""".stripMargin) + sql(""" + |INSERT INTO target VALUES + | (1, 10, '2026-05-28'), + | (2, 20, '2026-05-29'), + | (3, 30, '2026-05-30') + |""".stripMargin) + + val mergeSql = + """ + |MERGE INTO target + |USING source + |ON target.id = source.id AND target.dt = '2026-05-28' + |WHEN MATCHED THEN UPDATE SET target.b = source.b + |""".stripMargin + + executeMergeIntoAndAssertPartitionPruned(mergeSql) + checkAnswer( + sql("SELECT id, b, dt FROM target ORDER BY id"), + Seq(Row(1, 100, "2026-05-28"), Row(2, 20, "2026-05-29"), Row(3, 30, "2026-05-30")) + ) + } + } + } + } + + private def executeMergeIntoAndAssertPartitionPruned(mergeSql: String): Unit = { + val resultedTableFiles = new java.util.concurrent.CopyOnWriteArrayList[Long]() + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + checkPlan(qe) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + checkPlan(qe) + } + + private def checkPlan(qe: QueryExecution): Unit = { + collect(qe.executedPlan) { + case scanExec: BatchScanExec + if scanExec.scan.isInstanceOf[PaimonSplitScan] && + scanExec.scan.description().startsWith("PaimonSplitScan: [target]") => + val scan = scanExec.scan.asInstanceOf[PaimonSplitScan] + metric(scan.reportDriverMetrics(), RESULTED_TABLE_FILES) + }.foreach(resultedTableFile => resultedTableFiles.add(resultedTableFile)) + } + } + + spark.listenerManager.register(listener) + try { + sql(mergeSql) + 
Utils.waitUntilEventEmpty(spark) + } finally { + spark.listenerManager.unregister(listener) + } + + val metrics = resultedTableFiles.asScala + assert(metrics.nonEmpty, "Expected target PaimonSplitScan in merge into executed plans.") + assert( + metrics.contains(1), + s"Expected target scan to read only one partition file, but got resulted table files: " + + metrics.mkString(", ") + ) + } + + private def metric(metrics: Array[CustomTaskMetric], name: String): Long = { + metrics.find(_.name() == name).get.value() + } + test("Data Evolution: merge into table with data-evolution on _ROW_ID") { withTable("source", "target") { sql( diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala index 2a17d6297ea5..b1bf57193b4e 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala @@ -30,6 +30,7 @@ class V2WriteMergeSchemaTest extends PaimonSparkTestBase { .set("spark.sql.catalog.paimon.cache-enabled", "false") .set("spark.paimon.write.use-v2-write", "true") .set("spark.paimon.write.merge-schema", "true") + .set("spark.paimon.write.merge-schema.type-widening", "true") .set("spark.paimon.write.merge-schema.explicit-cast", "true") } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WriteMergeSchemaTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WriteMergeSchemaTest.scala index 83b423438aa8..094bd60e48fd 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WriteMergeSchemaTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WriteMergeSchemaTest.scala @@ -482,7 +482,7 @@ class WriteMergeSchemaTest extends PaimonSparkTestBase 
{ withTable("t") { withSparkSQLConf( "spark.paimon.write.merge-schema" -> "true", - "spark.paimon.write.merge-schema.explicit-cast" -> "true") { + "spark.paimon.write.merge-schema.type-widening" -> "true") { sql("CREATE TABLE t (id INT, info STRUCT)") sql("INSERT INTO t VALUES (1, struct(10, 'a'))") @@ -499,6 +499,370 @@ class WriteMergeSchemaTest extends PaimonSparkTestBase { } } + test("Write merge schema: case-insensitive column matching with new column") { + withTable("t") { + sql("CREATE TABLE t (id INT, name STRING)") + sql("INSERT INTO t VALUES (1, 'a'), (2, 'b')") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("INSERT INTO t BY NAME SELECT 3 AS ID, 'c' AS NAME, 100 AS extra") + } + + val columnNames = spark.table("t").schema.fieldNames + assert( + columnNames.length == 3, + s"Expected 3 columns (id, name, extra) but got ${columnNames.length}: ${columnNames.mkString(", ")}") + + checkAnswer( + sql("SELECT * FROM t ORDER BY id"), + Seq(Row(1, "a", null), Row(2, "b", null), Row(3, "c", 100))) + } + } + + test("Write merge schema: case-insensitive dataframe write") { + withTable("t") { + sql("CREATE TABLE t (id INT, name STRING)") + Seq((1, "a"), (2, "b")) + .toDF("id", "name") + .write + .format("paimon") + .mode("append") + .saveAsTable("t") + + Seq((3, "c", 100)) + .toDF("ID", "NAME", "extra") + .write + .format("paimon") + .mode("append") + .option("write.merge-schema", "true") + .saveAsTable("t") + + val columnNames = spark.table("t").schema.fieldNames + assert( + columnNames.length == 3, + s"Expected 3 columns but got ${columnNames.length}: ${columnNames.mkString(", ")}") + + checkAnswer( + sql("SELECT * FROM t ORDER BY id"), + Seq(Row(1, "a", null), Row(2, "b", null), Row(3, "c", 100))) + } + } + + test("Write merge schema: only case differs, no schema change") { + withTable("t") { + sql("CREATE TABLE t (id INT, name STRING)") + sql("INSERT INTO t VALUES (1, 'a')") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> 
"true") { + sql("INSERT INTO t BY NAME SELECT 2 AS ID, 'b' AS NAME") + } + + val columnNames = spark.table("t").schema.fieldNames + assert( + columnNames.toSeq == Seq("id", "name"), + s"Schema changed unexpectedly: ${columnNames.mkString(", ")}") + + checkAnswer(sql("SELECT * FROM t ORDER BY id"), Seq(Row(1, "a"), Row(2, "b"))) + } + } + + test("Write merge schema: repeated writes with alternating case") { + withTable("t") { + sql("CREATE TABLE t (id INT, name STRING)") + sql("INSERT INTO t VALUES (1, 'a')") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("INSERT INTO t BY NAME SELECT 2 AS ID, 'b' AS NAME") + sql("INSERT INTO t BY NAME SELECT 3 AS Id, 'c' AS NaMe") + sql("INSERT INTO t BY NAME SELECT 4 AS iD, 'd' AS nAmE") + } + + val columnNames = spark.table("t").schema.fieldNames + assert(columnNames.length == 2, s"Expected 2 columns but got: ${columnNames.mkString(", ")}") + + checkAnswer( + sql("SELECT * FROM t ORDER BY id"), + Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"), Row(4, "d"))) + } + } + + test("Write merge schema: nested struct case mismatch with new sub-field") { + withTable("t") { + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("CREATE TABLE t (id INT, info STRUCT)") + sql("INSERT INTO t VALUES (1, named_struct('key1', 'a', 'key2', 'b'))") + + sql( + "INSERT INTO t BY NAME SELECT 2 AS id, " + + "named_struct('KEY1', 'A', 'KEY2', 'B', 'key3', 'C') AS info") + + val infoFields = spark + .table("t") + .schema("info") + .dataType + .asInstanceOf[org.apache.spark.sql.types.StructType] + .fieldNames + assert( + infoFields.length == 3, + s"Expected 3 sub-fields but got ${infoFields.length}: ${infoFields.mkString(", ")}") + + checkAnswer( + sql("SELECT id, info.key1, info.key2, info.key3 FROM t ORDER BY id"), + Seq(Row(1, "a", "b", null), Row(2, "A", "B", "C"))) + } + } + } + + test("Merge into with merge-schema: source uppercase should not create duplicate columns") { + withTable("t") { + sql("""CREATE TABLE 
t (id INT, name STRING, value INT) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 'a', 10), (2, 'b', 20)") + + spark + .sql("SELECT 1 AS ID, 'A' AS NAME, 100 AS VALUE UNION ALL SELECT 3 AS ID, 'c' AS NAME, 30 AS VALUE") + .createOrReplaceTempView("s") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.ID + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + + val columnNames = spark.table("t").schema.fieldNames.toSeq + assert( + columnNames.size == 3, + s"Expected 3 columns but got ${columnNames.size}: ${columnNames.mkString(", ")}") + + checkAnswer( + sql("SELECT id, name, value FROM t ORDER BY id"), + Seq(Row(1, "A", 100), Row(2, "b", 20), Row(3, "c", 30))) + } + } + + test("Merge into with merge-schema: append-only target evolves schema") { + withTable("t") { + sql("CREATE TABLE t (id INT, name STRING) USING paimon") + sql("INSERT INTO t VALUES (1, 'a'), (2, 'b')") + + spark + .sql("SELECT 1 AS id, 'a2' AS name, 100 AS value UNION ALL SELECT 3 AS id, 'c' AS name, 30 AS value") + .createOrReplaceTempView("s") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.id + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + + assert(spark.table("t").schema.fieldNames.toSeq == Seq("id", "name", "value")) + checkAnswer( + sql("SELECT id, name, value FROM t ORDER BY id"), + Seq(Row(1, "a2", 100), Row(2, "b", null), Row(3, "c", 30))) + } + } + + test("Merge into with merge-schema: extra column with case-mismatched existing columns") { + withTable("t") { + sql("""CREATE TABLE t (id INT, name STRING) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 'a'), (2, 'b')") + + spark + .sql("SELECT 1 AS ID, 'A' AS NAME, 100 AS extra_col UNION ALL 
SELECT 3 AS ID, 'c' AS NAME, 30 AS extra_col") + .createOrReplaceTempView("s") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.ID + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + + val columnNames = spark.table("t").schema.fieldNames.toSeq + assert( + columnNames.size == 3, + s"Expected 3 columns (id, name, extra_col) but got ${columnNames.size}: ${columnNames.mkString(", ")}") + + checkAnswer( + sql("SELECT id, name, extra_col FROM t ORDER BY id"), + Seq(Row(1, "A", 100), Row(2, "b", null), Row(3, "c", 30))) + } + } + + test("Merge into with merge-schema: only case differs, schema should not change") { + withTable("t") { + sql("""CREATE TABLE t (id INT, name STRING) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 'a')") + + spark.sql("SELECT 1 AS ID, 'A' AS NAME").createOrReplaceTempView("s") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.ID + | WHEN MATCHED THEN UPDATE SET *""".stripMargin) + } + + val columnNames = spark.table("t").schema.fieldNames + assert( + columnNames.toSeq == Seq("id", "name"), + s"Schema changed unexpectedly: ${columnNames.mkString(", ")}") + + checkAnswer(sql("SELECT * FROM t"), Seq(Row(1, "A"))) + } + } + + test("Merge into with merge-schema: nested struct fields case mismatch") { + withTable("t") { + sql("""CREATE TABLE t (id INT, info STRUCT) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, named_struct('key1', 'a', 'key2', 'b'))") + + spark + .sql("SELECT 1 AS id, named_struct('KEY1', 'A', 'KEY2', 'B', 'key3', 'C') AS info") + .createOrReplaceTempView("s") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.id + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED 
THEN INSERT *""".stripMargin) + } + + val infoFields = spark + .table("t") + .schema("info") + .dataType + .asInstanceOf[org.apache.spark.sql.types.StructType] + .fieldNames + assert( + infoFields.length == 3, + s"Expected 3 sub-fields but got ${infoFields.length}: ${infoFields.mkString(", ")}") + + checkAnswer( + sql("SELECT id, info.key1, info.key2, info.key3 FROM t"), + Seq(Row(1, "A", "B", "C"))) + } + } + + test("Merge into with merge-schema: repeated writes with alternating case") { + withTable("t") { + sql("""CREATE TABLE t (id INT, name STRING) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 'a')") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + spark.sql("SELECT 2 AS ID, 'b' AS NAME").createOrReplaceTempView("s1") + sql("""MERGE INTO t USING s1 ON t.id = s1.ID + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + + spark.sql("SELECT 3 AS Id, 'c' AS NaMe").createOrReplaceTempView("s2") + sql("""MERGE INTO t USING s2 ON t.id = s2.Id + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + + val columnNames = spark.table("t").schema.fieldNames + assert(columnNames.length == 2, s"Expected 2 columns but got: ${columnNames.mkString(", ")}") + + checkAnswer(sql("SELECT * FROM t ORDER BY id"), Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("Merge into without merge-schema: case-insensitive matching works normally") { + withTable("t") { + sql("""CREATE TABLE t (id INT, name STRING) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 'a'), (2, 'b')") + + spark + .sql("SELECT 1 AS ID, 'A' AS NAME UNION ALL SELECT 3 AS ID, 'c' AS NAME") + .createOrReplaceTempView("s") + + sql("""MERGE INTO t USING s ON t.id = s.ID + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + + val schema = spark.table("t").schema + assert( + schema.fieldNames.length == 2, + 
s"Expected 2 columns but got: ${schema.fieldNames.mkString(", ")}") + + checkAnswer( + sql("SELECT id, name FROM t ORDER BY id"), + Seq(Row(1, "A"), Row(2, "b"), Row(3, "c"))) + } + } + + test("Merge into without merge-schema: ARRAY target with ARRAY source") { + withTable("t") { + sql("""CREATE TABLE t (id INT, ports ARRAY) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, array(80, 443))") + + spark + .sql("SELECT 1 AS id, array(cast(8080 as bigint), cast(9090 as bigint)) AS ports") + .createOrReplaceTempView("s") + + sql("""MERGE INTO t USING s ON t.id = s.id + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + + val portsType = spark.table("t").schema("ports").dataType + assert( + portsType.simpleString == "array", + s"Expected array but got ${portsType.simpleString}") + + checkAnswer(sql("SELECT * FROM t"), Seq(Row(1, Seq(8080, 9090)))) + } + } + + test("Write merge schema: INT to BIGINT keeps target type") { + withTable("t") { + sql("CREATE TABLE t (id INT, value INT)") + sql("INSERT INTO t VALUES (1, 100)") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("INSERT INTO t BY NAME SELECT 2 AS id, cast(200 as bigint) AS value") + } + + val valueType = spark.table("t").schema("value").dataType + assert(valueType.simpleString == "int", s"Expected int but got ${valueType.simpleString}") + + checkAnswer(sql("SELECT * FROM t ORDER BY id"), Seq(Row(1, 100), Row(2, 200))) + } + } + + test("Merge into with merge-schema: case-sensitive mode treats different case as new columns") { + withTable("t") { + sql("""CREATE TABLE t (id INT, name STRING) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 'a')") + + spark + .sql("SELECT 1 AS ID, 'A' AS NAME, 100 AS extra") + .createOrReplaceTempView("s") + + withSparkSQLConf( + "spark.paimon.write.merge-schema" -> "true", + 
"spark.sql.caseSensitive" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.ID + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + + val columnNames = spark.table("t").schema.fieldNames + assert( + columnNames.length == 5, + s"Expected 5 columns (id, name, ID, NAME, extra) but got ${columnNames.length}: ${columnNames.mkString(", ")}") + } + } + test("Write merge schema: array of struct missing nested field by dataframe") { withTable("t") { sql(""" @@ -546,4 +910,95 @@ class WriteMergeSchemaTest extends PaimonSparkTestBase { Seq(Row(Seq(1L, 2L), 1L, 2L, Seq(Row(10, "v2", "v3", "v4", "v5", null))))) } } + + test("Merge into with merge-schema: ARRAY to ARRAY keeps target type") { + withTable("t") { + sql("""CREATE TABLE t (id INT, ports ARRAY) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, array(80, 443))") + + spark + .sql("SELECT 1 AS id, array(cast(8080 as bigint), cast(9090 as bigint)) AS ports") + .createOrReplaceTempView("s") + + withSparkSQLConf("spark.paimon.write.merge-schema" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.id + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + + val portsType = spark.table("t").schema("ports").dataType + assert( + portsType.simpleString == "array", + s"Expected array but got ${portsType.simpleString}") + + checkAnswer(sql("SELECT * FROM t"), Seq(Row(1, Seq(8080, 9090)))) + } + } + + test("Write merge schema: case-sensitive mode treats different case as new columns") { + withTable("t") { + sql("CREATE TABLE t (id INT, name STRING)") + sql("INSERT INTO t VALUES (1, 'a')") + + withSparkSQLConf( + "spark.paimon.write.merge-schema" -> "true", + "spark.sql.caseSensitive" -> "true") { + sql("INSERT INTO t BY NAME SELECT 2 AS ID, 'b' AS NAME, 100 AS extra") + } + + val columnNames = spark.table("t").schema.fieldNames + assert( + columnNames.length == 5, + s"Expected 5 
columns (id, name, ID, NAME, extra) but got ${columnNames.length}: ${columnNames.mkString(", ")}") + } + } + + test("Merge into with type-widening=true: INT to BIGINT widens existing column") { + withTable("t") { + sql("""CREATE TABLE t (id INT, v INT) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, 10)") + + spark.sql("SELECT 1 AS id, cast(200 AS bigint) AS v").createOrReplaceTempView("s") + + withSparkSQLConf( + "spark.paimon.write.merge-schema" -> "true", + "spark.paimon.write.merge-schema.type-widening" -> "true") { + sql("""MERGE INTO t USING s ON t.id = s.id + | WHEN MATCHED THEN UPDATE SET *""".stripMargin) + } + + assert(spark.table("t").schema("v").dataType.simpleString == "bigint") + checkAnswer(sql("SELECT * FROM t"), Seq(Row(1, 200L))) + } + } + + test("Merge into with type-widening=true: ARRAY element widening throws (known limitation)") { + withTable("t") { + sql("""CREATE TABLE t (id INT, ports ARRAY) + | USING paimon + | TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '1')""".stripMargin) + sql("INSERT INTO t VALUES (1, array(80))") + + spark + .sql("SELECT 2 AS id, array(cast(9090 AS bigint)) AS ports") + .createOrReplaceTempView("s") + + // KNOWN LIMITATION: widening the element type of ARRAY/MAP is not yet supported by + // SchemaManager.generateTableSchema (CastExecutors has no ARRAY -> ARRAY + // rule). Tracked as a follow-up; until then type-widening on complex element types throws. 
+ withSparkSQLConf( + "spark.paimon.write.merge-schema" -> "true", + "spark.paimon.write.merge-schema.type-widening" -> "true") { + intercept[Exception] { + sql("""MERGE INTO t USING s ON t.id = s.id + | WHEN MATCHED THEN UPDATE SET * + | WHEN NOT MATCHED THEN INSERT *""".stripMargin) + } + } + } + } } diff --git a/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark3SqlExtensionsParser.scala b/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark3SqlExtensionsParser.scala index 07481b6f639f..04f824145499 100644 --- a/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark3SqlExtensionsParser.scala +++ b/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark3SqlExtensionsParser.scala @@ -22,4 +22,5 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.parser.extensions.AbstractPaimonSparkSqlExtensionsParser class PaimonSpark3SqlExtensionsParser(override val delegate: ParserInterface) - extends AbstractPaimonSparkSqlExtensionsParser(delegate) {} + extends AbstractPaimonSparkSqlExtensionsParser(delegate) + with ParserInterface diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala index 9bd395f3336d..fe4833a03cb4 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalyst/parser/extensions/PaimonSpark4SqlExtensionsParser.scala @@ -23,7 +23,8 @@ import 
org.apache.spark.sql.catalyst.parser.extensions.AbstractPaimonSparkSqlExt import org.apache.spark.sql.types.StructType class PaimonSpark4SqlExtensionsParser(override val delegate: ParserInterface) - extends AbstractPaimonSparkSqlExtensionsParser(delegate) { + extends AbstractPaimonSparkSqlExtensionsParser(delegate) + with ParserInterface { override def parseRoutineParam(sqlText: String): StructType = delegate.parseRoutineParam(sqlText) } diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PureAppendOnlyScope.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PureAppendOnlyScope.scala index c005bd40c717..4fdf2bafc02f 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PureAppendOnlyScope.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PureAppendOnlyScope.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.paimon.spark.SparkTable +import org.apache.paimon.spark.{SparkTable, SparkTypeUtils} import org.apache.paimon.table.FileStoreTable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -26,14 +26,13 @@ import org.apache.spark.sql.execution.datasources.v2.ExtractV2Table /** * Shared scope predicates for the Spark 4.1 Resolution-batch row-level rewrite rules - * ([[Spark41AppendOnlyRowLevelRewrite]] for UPDATE + metadata-only DELETE reverse-optimization, + * ([[Spark41UpdateTableRewrite]] for UPDATE + metadata-only DELETE reverse-optimization, * [[Spark41MergeIntoRewrite]] for MERGE). * - * Both rules only intercept operations against **pure append-only** Paimon tables: no primary key, - * row tracking, data evolution, deletion vectors, or fixed-length `CHAR(n)` columns. 
Tables that - * violate any of these constraints either have a working V2 rewrite path on 4.1 (PK / DV / RT / DE - * go through Paimon's own postHoc V1 commands) or race with Spark's `CharVarcharCodegenUtils` - * padding Project (CHAR columns — see [[hasCharColumn]]). + * These rules only intercept operations against Paimon tables that are valid for Spark's V2 + * copy-on-write rewrite: no primary key, data evolution, deletion vectors, or fixed-length + * `CHAR(n)` columns. Row-tracking-only tables are included; tables that violate any of these + * constraints go through Paimon's postHoc V1 commands or Spark's built-in analysis path. * * Kept as a mix-in trait so the two rewrite objects stay single-responsibility (one rule per Spark * row-level command, mirroring Spark's own `RewriteUpdateTable` / `RewriteMergeIntoTable` layout) @@ -41,36 +40,25 @@ import org.apache.spark.sql.execution.datasources.v2.ExtractV2Table */ trait PureAppendOnlyScope { - /** - * Whether the target of a row-level operation is a pure append-only Paimon table that Spark 4.1's - * built-in rewrite rules can't handle (see the two rule class docs for why). 
- */ - protected def targetsPureAppendOnly(aliasedTable: LogicalPlan): Boolean = { + protected def targetsV2CopyOnWriteTable(aliasedTable: LogicalPlan): Boolean = { + targetsPaimonFileStoreTable(aliasedTable) { + case (sparkTable, fs) => + fs.primaryKeys().isEmpty && + !sparkTable.coreOptions.dataEvolutionEnabled() && + !sparkTable.coreOptions.deletionVectorsEnabled() && + !SparkTypeUtils.containsCharType(fs.rowType()) + } + } + + private def targetsPaimonFileStoreTable(aliasedTable: LogicalPlan)( + predicate: (SparkTable, FileStoreTable) => Boolean): Boolean = { EliminateSubqueryAliases(aliasedTable) match { case ExtractV2Table(sparkTable: SparkTable) => sparkTable.getTable match { - case fs: FileStoreTable => - fs.primaryKeys().isEmpty && - !sparkTable.coreOptions.rowTrackingEnabled() && - !sparkTable.coreOptions.dataEvolutionEnabled() && - !sparkTable.coreOptions.deletionVectorsEnabled() && - !hasCharColumn(fs) + case fs: FileStoreTable => predicate(sparkTable, fs) case _ => false } case _ => false } } - - /** - * Tables with fixed-length `CHAR(n)` columns go through Spark's - * `CharVarcharCodegenUtils.readSidePadding` Project that gets inserted between the - * `DataSourceV2Relation` and its consumers. If we intercept before that padding project settles, - * CheckAnalysis trips on mismatched attribute ids (see PR 7648 history). Let those plans fall - * through to Paimon's postHoc V1 fallback rules which run after the padding project stabilizes. 
- */ - protected def hasCharColumn(fs: FileStoreTable): Boolean = { - import org.apache.paimon.types.CharType - import scala.collection.JavaConverters._ - fs.rowType().getFields.asScala.exists(_.`type`().isInstanceOf[CharType]) - } } diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41DeleteMetadataRestore.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41DeleteMetadataRestore.scala index 0efbe6a8bb96..d21bc8098e49 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41DeleteMetadataRestore.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41DeleteMetadataRestore.scala @@ -42,8 +42,8 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation * fast path that a `DeleteFromPaimonTableCommand` would enable. * * This rule pattern-matches the `ReplaceData` Spark produced (tagged with - * `RowLevelOperation.Command.DELETE`) and, if the target is a pure append-only Paimon table (see - * [[PureAppendOnlyScope]]) and the predicate is metadata-only, rewrites back to + * `RowLevelOperation.Command.DELETE`) and, if the target is a Paimon table eligible for V2 + * copy-on-write (see [[PureAppendOnlyScope]]) and the predicate is metadata-only, rewrites back to * `DeleteFromPaimonTableCommand`. Non-metadata-only DELETE is left alone (Spark's `ReplaceData` is * correct for data deletes). 
This is **not** a rewrite of `DeleteFromTable` — it's a restoration * layered on top of Spark's existing rewrite output, hence the `…Restore` naming rather than @@ -64,8 +64,8 @@ object Spark41DeleteMetadataRestore extends RewriteRowLevelCommand with PureAppe } /** - * Whether a `ReplaceData` node (Spark 4.1's post-rewrite DELETE form) targets a pure append-only - * Paimon table with a metadata-only predicate, such that converting back to + * Whether a `ReplaceData` node (Spark 4.1's post-rewrite DELETE form) targets a Paimon table + * eligible for V2 copy-on-write with a metadata-only predicate, such that converting back to * `DeleteFromPaimonTableCommand` would let the optimizer fold to `TruncatePaimonTableWithFilter`. */ private def isMetadataOnlyDeleteOnAppendOnlyPaimon(rd: ReplaceData): Boolean = { @@ -78,7 +78,7 @@ object Spark41DeleteMetadataRestore extends RewriteRowLevelCommand with PureAppe case _ => false } writeIsDelete && (rd.originalTable match { - case r: DataSourceV2Relation if targetsPureAppendOnly(r) => + case r: DataSourceV2Relation if targetsV2CopyOnWriteTable(r) => r.table match { case spk: SparkTable => spk.getTable match { diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41MergeIntoRewrite.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41MergeIntoRewrite.scala index 0b2e6c607e17..f3dd21f15f45 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41MergeIntoRewrite.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41MergeIntoRewrite.scala @@ -38,9 +38,9 @@ import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * Spark 4.1-only Resolution-batch rule that rewrites MERGE INTO on pure append-only Paimon tables - * (no PK / RT / DE / DV) into V2 `ReplaceData` / `AppendData` plans, 
mirroring Spark's built-in - * `RewriteMergeIntoTable` for non-`SupportsDelta` row-level tables. + * Spark 4.1-only Resolution-batch rule that rewrites MERGE INTO on Paimon tables eligible for V2 + * copy-on-write (no PK / DE / DV / CHAR) into V2 `ReplaceData` / `AppendData` plans, mirroring + * Spark's built-in `RewriteMergeIntoTable` for non-`SupportsDelta` row-level tables. * * In Spark 4.1, `RewriteMergeIntoTable` runs in the Resolution batch via `resolveOperators`, which * short-circuits on `analyzed=true` plans — by the time it would fire, the `MergeIntoTable` is @@ -52,9 +52,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * We fire before `ResolveAssignments`, so `m.aligned` is `false`. The rule pre-aligns each action * list via `PaimonAssignmentUtils.alignActions` (shared with the postHoc `PaimonMergeInto` rule). * - * CHAR columns are excluded — `readSidePadding` races with the rewrite and trips CheckAnalysis; - * those plans fall back to the postHoc `PaimonMergeInto` V1 path, which also owns PK / RT / DE / DV - * tables via `RowLevelHelper.shouldFallbackToV1MergeInto`. + * Row-tracking-only tables use the same V2 copy-on-write rewrite. CHAR columns are excluded — + * `readSidePadding` races with the rewrite and trips CheckAnalysis; those plans fall back to the + * postHoc `PaimonMergeInto` V1 path, which also owns PK / DE / DV tables via + * `RowLevelHelper.shouldFallbackToV1MergeInto`. */ object Spark41MergeIntoRewrite extends RewriteRowLevelCommand @@ -71,7 +72,7 @@ object Spark41MergeIntoRewrite plan.transformDown { case m: MergeIntoTable if m.resolved && m.rewritable && !m.needSchemaEvolution && - targetsPureAppendOnly(m.targetTable) => + targetsV2CopyOnWriteTable(m.targetTable) => // Pure append-only tables skip postHoc `PaimonMergeInto`, so evolve schema here. 
val evolved = evolveSchemaIfPaimon(m) rewrite(alignAllMergeActions(evolved, evolved.targetTable.output)) diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41UpdateTableRewrite.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41UpdateTableRewrite.scala index d8082b592e08..97edbdc780c3 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41UpdateTableRewrite.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/Spark41UpdateTableRewrite.scala @@ -47,8 +47,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * We fire before `ResolveAssignments`, so `u.aligned` is `false`; the rule pre-aligns via * `PaimonAssignmentUtils.alignUpdateAssignments` before building the plan. * - * PK tables go through the postHoc rule; RT / DE / DV tables go through Spark's V2 path. DELETE is - * handled by [[Spark41DeleteMetadataRestore]]; MERGE by [[Spark41MergeIntoRewrite]]. + * Row-tracking-only tables use the same V2 copy-on-write rewrite. PK / DE / DV tables go through + * the postHoc V1 rule because they do not expose `SupportsRowLevelOperations`. DELETE is handled by + * [[Spark41DeleteMetadataRestore]]; MERGE by [[Spark41MergeIntoRewrite]]. 
*/ object Spark41UpdateTableRewrite extends RewriteRowLevelCommand with PureAppendOnlyScope { @@ -57,7 +58,7 @@ object Spark41UpdateTableRewrite extends RewriteRowLevelCommand with PureAppendO AnalysisHelper.allowInvokingTransformsInAnalyzer { plan.transformDown { case u @ UpdateTable(aliasedTable, assignments, cond) - if u.resolved && u.rewritable && targetsPureAppendOnly(aliasedTable) => + if u.resolved && u.rewritable && targetsV2CopyOnWriteTable(aliasedTable) => EliminateSubqueryAliases(aliasedTable) match { case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) => val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty()) diff --git a/paimon-tantivy/paimon-tantivy-index/README.md b/paimon-tantivy/paimon-tantivy-index/README.md index 18ea8bfc5846..0b93d15875a4 100644 --- a/paimon-tantivy/paimon-tantivy-index/README.md +++ b/paimon-tantivy/paimon-tantivy-index/README.md @@ -113,6 +113,55 @@ CALL sys.create_global_index( ); ``` +### Tokenizers + +By default, Tantivy uses its built-in tokenizer. For Chinese or other languages where users often +search by short character fragments, build the index with the `ngram` tokenizer: + +```sql +CALL sys.create_global_index( + table => 'db.my_table', + index_column => 'content', + index_type => 'tantivy-fulltext', + options => 'tantivy.tokenizer=ngram,tantivy.ngram.min-gram=2,tantivy.ngram.max-gram=2' +); +``` + +For Chinese word segmentation, build the index with the `jieba` tokenizer: + +```sql +CALL sys.create_global_index( + table => 'db.my_table', + index_column => 'content', + index_type => 'tantivy-fulltext', + options => 'tantivy.tokenizer=jieba' +); +``` + +Available tokenizer options: + +| Option | Default | Description | +|--------|---------|-------------| +| `tantivy.tokenizer` | `default` | Tokenizer used by the full-text index. Supported values: `default`, `simple`, `whitespace`, `raw`, `ngram`, `jieba`. | +| `tantivy.ngram.min-gram` | `2` | Minimum gram length for the `ngram` tokenizer. 
| +| `tantivy.ngram.max-gram` | `2` | Maximum gram length for the `ngram` tokenizer. | +| `tantivy.ngram.prefix-only` | `false` | Whether the `ngram` tokenizer only emits prefix ngrams. | +| `tantivy.lower-case` | `true` | Whether configurable tokenizers lowercase emitted tokens. | +| `tantivy.max-token-length` | `40` | Maximum token length kept by configurable tokenizers. | +| `tantivy.ascii-folding` | `false` | Whether to normalize non-ASCII Latin characters to ASCII. | +| `tantivy.stem` | `false` | Whether to apply stemming to emitted tokens. | +| `tantivy.language` | `english` | Language used by stemming and built-in stop word filters. | +| `tantivy.remove-stop-words` | `false` | Whether to remove built-in stop words for the configured language. | +| `tantivy.stop-words` | ` ` | Semicolon-separated custom stop words to remove. | +| `tantivy.with-position` | `true` | Whether to store term positions for phrase queries. | + +Tokenizer settings are persisted in each global index file's metadata. Readers use that metadata +when reopening an index, so changing table/procedure options later does not make existing index +files use a different analyzer. +Custom analysis is provided by composing the supported tokenizer and filter options above; Paimon +does not load arbitrary Rust tokenizer plugins from configuration. +PyPaimon can query `jieba` indexes when the Python `jieba` package is installed. 
+ ### Search ```sql diff --git a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexReader.java b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexReader.java index 64d06af2c836..67a0762a8ef1 100644 --- a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexReader.java +++ b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexReader.java @@ -34,10 +34,13 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; import static org.apache.paimon.utils.Preconditions.checkArgument; @@ -57,6 +60,8 @@ public class TantivyFullTextGlobalIndexReader implements GlobalIndexReader { private final Map layoutCache; private final TantivySearcherPool searcherPool; private final String poolKey; + private final ExecutorService executor; + private final TantivyFullTextIndexOptions indexOptions; private volatile TantivySearcherPool.PooledEntry borrowed; @@ -64,25 +69,41 @@ public TantivyFullTextGlobalIndexReader( GlobalIndexFileReader fileReader, List ioMetas, Map layoutCache, - TantivySearcherPool searcherPool) { + TantivySearcherPool searcherPool, + ExecutorService executor) { checkArgument(ioMetas.size() == 1, "Expected exactly one index file per shard"); + this.executor = executor; this.fileReader = fileReader; this.ioMeta = ioMetas.get(0); this.layoutCache = layoutCache; this.searcherPool = searcherPool; - this.poolKey = this.ioMeta.filePath().toString() + "@" + this.ioMeta.fileSize(); + this.poolKey = + this.ioMeta.filePath().toString() + + "@" + + this.ioMeta.fileSize() + + "#" + + 
Arrays.hashCode(this.ioMeta.metadata()); + this.indexOptions = deserializeIndexOptions(this.ioMeta.metadata()); } @Override - public Optional visitFullTextSearch(FullTextSearch fullTextSearch) { - try { - ensureLoaded(); - SearchResult result = - borrowed.searcher.search(fullTextSearch.queryText(), fullTextSearch.limit()); - return Optional.of(toScoredResult(result)); - } catch (IOException e) { - throw new RuntimeException("Failed to search Tantivy full-text index", e); - } + public CompletableFuture> visitFullTextSearch( + FullTextSearch fullTextSearch) { + return CompletableFuture.supplyAsync( + () -> { + try { + ensureLoaded(); + SearchResult result = + borrowed.searcher.search( + fullTextSearch.queryText(), + fullTextSearch.limit(), + fullTextSearch.queryOperator()); + return Optional.of(toScoredResult(result)); + } catch (IOException e) { + throw new RuntimeException("Failed to search Tantivy full-text index", e); + } + }, + executor); } private ScoredGlobalIndexResult toScoredResult(SearchResult result) { @@ -121,7 +142,11 @@ private TantivySearcherPool.PooledEntry createEntry() throws IOException { StreamFileInput streamInput = new SynchronizedStreamFileInput(in); TantivySearcher searcher = new TantivySearcher( - layout.fileNames, layout.fileOffsets, layout.fileLengths, streamInput); + layout.fileNames, + layout.fileOffsets, + layout.fileLengths, + streamInput, + indexOptions.toNativeConfigJson()); return new TantivySearcherPool.PooledEntry(searcher, in); } catch (Exception e) { in.close(); @@ -129,6 +154,15 @@ private TantivySearcherPool.PooledEntry createEntry() throws IOException { } } + static TantivyFullTextIndexOptions deserializeIndexOptions(byte[] metadata) { + try { + return TantivyFullTextIndexOptions.deserialize(metadata); + } catch (IOException e) { + throw new IllegalArgumentException( + "Failed to deserialize Tantivy full-text index meta", e); + } + } + /** * Parse the archive header to extract file names, offsets, and lengths. 
The archive format is: * [fileCount(4)] then for each file: [nameLen(4)] [name(utf8)] [dataLen(8)] [data]. @@ -200,73 +234,85 @@ public void close() throws IOException { // =================== unsupported ===================== @Override - public Optional visitIsNotNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNotNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIsNull(FieldRef fieldRef) { - return Optional.empty(); + public CompletableFuture> visitIsNull(FieldRef fieldRef) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitStartsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitStartsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEndsWith(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEndsWith( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitContains(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitContains( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLike(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLike( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { 
- return Optional.empty(); + public CompletableFuture> visitGreaterOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitNotEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitLessOrEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitEqual(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitEqual( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { - return Optional.empty(); + public CompletableFuture> visitGreaterThan( + FieldRef fieldRef, Object literal) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } @Override - public Optional visitNotIn(FieldRef fieldRef, List literals) { - return Optional.empty(); + public CompletableFuture> visitNotIn( + FieldRef fieldRef, List literals) { + return CompletableFuture.completedFuture(Optional.empty()); } /** diff --git a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexWriter.java b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexWriter.java index 6def9cb69e4f..66b6c19f6dc5 100644 --- 
a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexWriter.java +++ b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexWriter.java @@ -55,20 +55,29 @@ public class TantivyFullTextGlobalIndexWriter implements GlobalIndexSingletonWri LoggerFactory.getLogger(TantivyFullTextGlobalIndexWriter.class); private final GlobalIndexFileWriter fileWriter; + private final TantivyFullTextIndexOptions indexOptions; private File tempIndexDir; private TantivyIndexWriter writer; private long rowId; private boolean closed; public TantivyFullTextGlobalIndexWriter(GlobalIndexFileWriter fileWriter) { + this(fileWriter, TantivyFullTextIndexOptions.defaults()); + } + + public TantivyFullTextGlobalIndexWriter( + GlobalIndexFileWriter fileWriter, TantivyFullTextIndexOptions indexOptions) { this.fileWriter = fileWriter; + this.indexOptions = indexOptions; this.rowId = 0; this.closed = false; try { this.tempIndexDir = Files.createTempDirectory("tantivy-index-").toFile(); this.tempIndexDir.deleteOnExit(); - this.writer = new TantivyIndexWriter(tempIndexDir.getAbsolutePath()); + this.writer = + new TantivyIndexWriter( + tempIndexDir.getAbsolutePath(), indexOptions.toNativeConfigJson()); } catch (IOException e) { throw new RuntimeException("Failed to create temp index directory", e); } @@ -162,7 +171,7 @@ private ResultEntry packIndex() throws IOException { } LOG.info("Tantivy index packed: {} documents", rowId); - return new ResultEntry(fileName, rowId, null); + return new ResultEntry(fileName, rowId, indexOptions.serialize()); } private static void writeInt(PositionOutputStream out, int value) throws IOException { diff --git a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexer.java b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexer.java index 
add95cd8c577..07dcf8289c9c 100644 --- a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexer.java +++ b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexer.java @@ -28,25 +28,36 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; /** Tantivy full-text global indexer. */ public class TantivyFullTextGlobalIndexer implements GlobalIndexer { private final Map layoutCache = new ConcurrentHashMap<>(); private final TantivySearcherPool searcherPool; + private final TantivyFullTextIndexOptions indexOptions; - public TantivyFullTextGlobalIndexer(TantivySearcherPool searcherPool) { + public TantivyFullTextGlobalIndexer( + TantivySearcherPool searcherPool, TantivyFullTextIndexOptions indexOptions) { this.searcherPool = searcherPool; + this.indexOptions = indexOptions; + } + + public TantivyFullTextGlobalIndexer(TantivySearcherPool searcherPool) { + this(searcherPool, TantivyFullTextIndexOptions.defaults()); } @Override public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) { - return new TantivyFullTextGlobalIndexWriter(fileWriter); + return new TantivyFullTextGlobalIndexWriter(fileWriter, indexOptions); } @Override public GlobalIndexReader createReader( - GlobalIndexFileReader fileReader, List files) { - return new TantivyFullTextGlobalIndexReader(fileReader, files, layoutCache, searcherPool); + GlobalIndexFileReader fileReader, + List files, + ExecutorService executor) { + return new TantivyFullTextGlobalIndexReader( + fileReader, files, layoutCache, searcherPool, executor); } } diff --git a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexerFactory.java b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexerFactory.java index 
e54294340c1d..4fc0e9cf6b15 100644 --- a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexerFactory.java +++ b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexerFactory.java @@ -23,6 +23,9 @@ import org.apache.paimon.options.Options; import org.apache.paimon.types.DataField; +import java.util.LinkedHashMap; +import java.util.Map; + /** Factory for creating Tantivy full-text index. */ public class TantivyFullTextGlobalIndexerFactory implements GlobalIndexerFactory { @@ -45,6 +48,19 @@ public GlobalIndexer create(DataField field, Options options) { } } } - return new TantivyFullTextGlobalIndexer(searcherPool); + return new TantivyFullTextGlobalIndexer( + searcherPool, new TantivyFullTextIndexOptions(removeTantivyPrefix(options))); + } + + static Map removeTantivyPrefix(Options options) { + Map result = new LinkedHashMap<>(); + for (String key : options.keySet()) { + if (key.startsWith(TantivyFullTextIndexOptions.TANTIVY_PREFIX)) { + result.put( + key.substring(TantivyFullTextIndexOptions.TANTIVY_PREFIX.length()), + options.get(key)); + } + } + return result; } } diff --git a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptions.java b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptions.java index dbc4cd396079..f6f7133dada9 100644 --- a/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptions.java +++ b/paimon-tantivy/paimon-tantivy-index/src/main/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptions.java @@ -20,10 +20,102 @@ import org.apache.paimon.options.ConfigOption; import org.apache.paimon.options.ConfigOptions; +import org.apache.paimon.utils.JsonSerdeUtil; +import org.apache.paimon.utils.Preconditions; + +import 
org.apache.paimon.shade.jackson2.com.fasterxml.jackson.core.type.TypeReference; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; /** Options for the Tantivy full-text index. */ public class TantivyFullTextIndexOptions { + static final String TANTIVY_PREFIX = "tantivy."; + + public static final ConfigOption TOKENIZER = + ConfigOptions.key("tantivy.tokenizer") + .stringType() + .defaultValue("default") + .withDescription( + "Tokenizer for Tantivy full-text index. Supported values are 'default', 'simple', 'whitespace', 'raw', 'ngram', and 'jieba'."); + + public static final ConfigOption NGRAM_MIN_GRAM = + ConfigOptions.key("tantivy.ngram.min-gram") + .intType() + .defaultValue(2) + .withDescription("Minimum ngram length when 'tantivy.tokenizer' is 'ngram'."); + + public static final ConfigOption NGRAM_MAX_GRAM = + ConfigOptions.key("tantivy.ngram.max-gram") + .intType() + .defaultValue(2) + .withDescription("Maximum ngram length when 'tantivy.tokenizer' is 'ngram'."); + + public static final ConfigOption NGRAM_PREFIX_ONLY = + ConfigOptions.key("tantivy.ngram.prefix-only") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether ngram tokenizer should only emit prefix ngrams when 'tantivy.tokenizer' is 'ngram'."); + + public static final ConfigOption LOWER_CASE = + ConfigOptions.key("tantivy.lower-case") + .booleanType() + .defaultValue(true) + .withDescription("Whether to lowercase tokens for configurable tokenizers."); + + public static final ConfigOption MAX_TOKEN_LENGTH = + ConfigOptions.key("tantivy.max-token-length") + .intType() + .defaultValue(40) + .withDescription("Maximum token length kept by configurable tokenizers."); + + public static final ConfigOption ASCII_FOLDING = + 
ConfigOptions.key("tantivy.ascii-folding") + .booleanType() + .defaultValue(false) + .withDescription("Whether to normalize non-ASCII latin characters to ASCII."); + + public static final ConfigOption STEM = + ConfigOptions.key("tantivy.stem") + .booleanType() + .defaultValue(false) + .withDescription("Whether to apply stemming to emitted tokens."); + + public static final ConfigOption LANGUAGE = + ConfigOptions.key("tantivy.language") + .stringType() + .defaultValue("english") + .withDescription("Language used by stemming and built-in stop word filters."); + + public static final ConfigOption REMOVE_STOP_WORDS = + ConfigOptions.key("tantivy.remove-stop-words") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether to remove built-in stop words for the configured language."); + + public static final ConfigOption STOP_WORDS = + ConfigOptions.key("tantivy.stop-words") + .stringType() + .defaultValue("") + .withDescription("Semicolon-separated custom stop words to remove."); + + public static final ConfigOption WITH_POSITION = + ConfigOptions.key("tantivy.with-position") + .booleanType() + .defaultValue(true) + .withDescription("Whether to store term positions for phrase queries."); + public static final ConfigOption SEARCHER_POOL_MAX_SIZE = ConfigOptions.key("tantivy.searcher-pool.max-size") .intType() @@ -34,4 +126,332 @@ public class TantivyFullTextIndexOptions { + "Rust memory (including the FST term dictionary), so memory " + "usage scales with this value times the index size per shard. 
" + "Set to 0 to disable pooling."); + + private final Map config; + + public TantivyFullTextIndexOptions(Map options) { + this.config = + Collections.unmodifiableMap(toNativeConfigMap(Preconditions.checkNotNull(options))); + } + + public static TantivyFullTextIndexOptions defaults() { + return new TantivyFullTextIndexOptions(Collections.emptyMap()); + } + + public String tokenizer() { + return getString(optionKey(TOKENIZER), TOKENIZER.defaultValue()); + } + + public int ngramMinGram() { + return getInt(optionKey(NGRAM_MIN_GRAM), NGRAM_MIN_GRAM.defaultValue()); + } + + public int ngramMaxGram() { + return getInt(optionKey(NGRAM_MAX_GRAM), NGRAM_MAX_GRAM.defaultValue()); + } + + public boolean ngramPrefixOnly() { + return getBoolean(optionKey(NGRAM_PREFIX_ONLY), NGRAM_PREFIX_ONLY.defaultValue()); + } + + public boolean lowerCase() { + return getBoolean(optionKey(LOWER_CASE), LOWER_CASE.defaultValue()); + } + + public int maxTokenLength() { + return getInt(optionKey(MAX_TOKEN_LENGTH), MAX_TOKEN_LENGTH.defaultValue()); + } + + public boolean asciiFolding() { + return getBoolean(optionKey(ASCII_FOLDING), ASCII_FOLDING.defaultValue()); + } + + public boolean stem() { + return getBoolean(optionKey(STEM), STEM.defaultValue()); + } + + public String language() { + return getString(optionKey(LANGUAGE), LANGUAGE.defaultValue()); + } + + public boolean removeStopWords() { + return getBoolean(optionKey(REMOVE_STOP_WORDS), REMOVE_STOP_WORDS.defaultValue()); + } + + public String stopWords() { + return joinStopWords(stopWordList()); + } + + public boolean withPosition() { + return getBoolean(optionKey(WITH_POSITION), WITH_POSITION.defaultValue()); + } + + public List stopWordList() { + return getStringList(optionKey(STOP_WORDS)); + } + + public String toNativeConfigJson() { + return JsonSerdeUtil.toFlatJson(config); + } + + public byte[] serialize() throws IOException { + return toNativeConfigJson().getBytes(StandardCharsets.UTF_8); + } + + public static 
TantivyFullTextIndexOptions deserialize(byte[] data) throws IOException { + if (data == null || data.length == 0) { + return defaults(); + } + + return fromNativeConfigJson(new String(data, StandardCharsets.UTF_8)); + } + + private static Map toNativeConfigMap(Map options) { + Map config = new LinkedHashMap<>(); + String tokenizer = + normalizeTokenizer( + getString(options, optionKey(TOKENIZER), TOKENIZER.defaultValue())); + validateTokenizer(tokenizer); + if (!tokenizer.equals(TOKENIZER.defaultValue())) { + config.put(optionKey(TOKENIZER), tokenizer); + } + int ngramMinGram = + getInt(options, optionKey(NGRAM_MIN_GRAM), NGRAM_MIN_GRAM.defaultValue()); + int ngramMaxGram = + getInt(options, optionKey(NGRAM_MAX_GRAM), NGRAM_MAX_GRAM.defaultValue()); + validateNgramGrams(ngramMinGram, ngramMaxGram); + if (ngramMinGram != NGRAM_MIN_GRAM.defaultValue()) { + config.put(optionKey(NGRAM_MIN_GRAM), ngramMinGram); + } + if (ngramMaxGram != NGRAM_MAX_GRAM.defaultValue()) { + config.put(optionKey(NGRAM_MAX_GRAM), ngramMaxGram); + } + boolean ngramPrefixOnly = + getBoolean(options, optionKey(NGRAM_PREFIX_ONLY), NGRAM_PREFIX_ONLY.defaultValue()); + if (ngramPrefixOnly != NGRAM_PREFIX_ONLY.defaultValue()) { + config.put(optionKey(NGRAM_PREFIX_ONLY), ngramPrefixOnly); + } + boolean lowerCase = getBoolean(options, optionKey(LOWER_CASE), LOWER_CASE.defaultValue()); + if (lowerCase != LOWER_CASE.defaultValue()) { + config.put(optionKey(LOWER_CASE), lowerCase); + } + int maxTokenLength = + getInt(options, optionKey(MAX_TOKEN_LENGTH), MAX_TOKEN_LENGTH.defaultValue()); + validateMaxTokenLength(maxTokenLength); + if (maxTokenLength != MAX_TOKEN_LENGTH.defaultValue()) { + config.put(optionKey(MAX_TOKEN_LENGTH), maxTokenLength); + } + boolean asciiFolding = + getBoolean(options, optionKey(ASCII_FOLDING), ASCII_FOLDING.defaultValue()); + if (asciiFolding != ASCII_FOLDING.defaultValue()) { + config.put(optionKey(ASCII_FOLDING), asciiFolding); + } + boolean stem = getBoolean(options, 
optionKey(STEM), STEM.defaultValue()); + if (stem != STEM.defaultValue()) { + config.put(optionKey(STEM), stem); + } + String language = + normalizeLanguage(getString(options, optionKey(LANGUAGE), LANGUAGE.defaultValue())); + validateLanguage(language); + if (!language.equals(LANGUAGE.defaultValue())) { + config.put(optionKey(LANGUAGE), language); + } + boolean removeStopWords = + getBoolean(options, optionKey(REMOVE_STOP_WORDS), REMOVE_STOP_WORDS.defaultValue()); + if (removeStopWords != REMOVE_STOP_WORDS.defaultValue()) { + config.put(optionKey(REMOVE_STOP_WORDS), removeStopWords); + } + List stopWordList = getStopWordList(options, optionKey(STOP_WORDS)); + if (!stopWordList.isEmpty()) { + config.put(optionKey(STOP_WORDS), Collections.unmodifiableList(stopWordList)); + } + boolean withPosition = + getBoolean(options, optionKey(WITH_POSITION), WITH_POSITION.defaultValue()); + if (withPosition != WITH_POSITION.defaultValue()) { + config.put(optionKey(WITH_POSITION), withPosition); + } + return config; + } + + private static TantivyFullTextIndexOptions fromNativeConfigJson(String json) + throws IOException { + try { + Map config = + JsonSerdeUtil.fromJson(json, new TypeReference>() {}); + return new TantivyFullTextIndexOptions(config); + } catch (UncheckedIOException e) { + throw e.getCause(); + } + } + + private String getString(String key, String defaultValue) { + Object value = config.get(key); + return value == null ? defaultValue : (String) value; + } + + private int getInt(String key, int defaultValue) { + Object value = config.get(key); + return value == null ? defaultValue : (int) value; + } + + private boolean getBoolean(String key, boolean defaultValue) { + Object value = config.get(key); + return value == null ? 
defaultValue : (boolean) value; + } + + @SuppressWarnings("unchecked") + private List getStringList(String key) { + Object value = config.get(key); + if (value == null) { + return new ArrayList<>(); + } + return new ArrayList<>((List) value); + } + + private static String joinStopWords(List stopWords) { + if (stopWords == null || stopWords.isEmpty()) { + return ""; + } + + List words = new ArrayList<>(); + for (String stopWord : stopWords) { + if (stopWord != null) { + words.add(stopWord); + } + } + return String.join(";", words); + } + + private static List toStopWordList(Object value) { + if (value == null) { + return new ArrayList<>(); + } + if (value instanceof List) { + List words = new ArrayList<>(); + for (Object word : (List) value) { + if (word != null) { + words.add(word.toString()); + } + } + return words; + } + return normalizeStopWordList(value.toString()); + } + + private static String getString(Map options, String key, String defaultValue) { + Object value = options.get(key); + return value == null ? 
defaultValue : value.toString(); + } + + private static int getInt(Map options, String key, int defaultValue) { + Object value = options.get(key); + if (value == null) { + return defaultValue; + } + if (value instanceof Number) { + return ((Number) value).intValue(); + } + return Integer.parseInt(value.toString()); + } + + private static boolean getBoolean( + Map options, String key, boolean defaultValue) { + Object value = options.get(key); + if (value == null) { + return defaultValue; + } + if (value instanceof Boolean) { + return (boolean) value; + } + return Boolean.parseBoolean(value.toString()); + } + + private static List getStopWordList(Map options, String key) { + Object value = options.get(key); + if (value == null) { + return new ArrayList<>(); + } + if (value instanceof List) { + return toStopWordList(value); + } + return normalizeStopWordList(value.toString()); + } + + private static String optionKey(ConfigOption option) { + return option.key().substring(TANTIVY_PREFIX.length()); + } + + private static String normalizeTokenizer(String tokenizer) { + return tokenizer == null ? "" : tokenizer.trim().toLowerCase(); + } + + private static String normalizeLanguage(String language) { + return language == null ? 
"" : language.trim().toLowerCase(Locale.ROOT); + } + + private static List normalizeStopWordList(String stopWords) { + List words = new ArrayList<>(); + if (stopWords == null) { + return words; + } + for (String word : stopWords.split(";")) { + String trimmed = word.trim(); + if (!trimmed.isEmpty()) { + words.add(trimmed); + } + } + return words; + } + + private static void validateTokenizer(String tokenizer) { + Preconditions.checkArgument( + "default".equals(tokenizer) + || "simple".equals(tokenizer) + || "whitespace".equals(tokenizer) + || "raw".equals(tokenizer) + || "ngram".equals(tokenizer) + || "jieba".equals(tokenizer), + "Unsupported Tantivy tokenizer: %s", + tokenizer); + } + + private static void validateNgramGrams(int ngramMinGram, int ngramMaxGram) { + Preconditions.checkArgument(ngramMinGram > 0, "ngram min gram must be positive."); + Preconditions.checkArgument(ngramMaxGram > 0, "ngram max gram must be positive."); + Preconditions.checkArgument( + ngramMinGram <= ngramMaxGram, "ngram min gram must not be greater than max gram."); + } + + private static void validateMaxTokenLength(int maxTokenLength) { + Preconditions.checkArgument(maxTokenLength > 0, "max token length must be positive."); + } + + private static void validateLanguage(String language) { + Preconditions.checkArgument( + supportedLanguages().contains(language), + "Unsupported Tantivy language: %s", + language); + } + + private static List supportedLanguages() { + return Arrays.asList( + "arabic", + "danish", + "dutch", + "english", + "finnish", + "french", + "german", + "greek", + "hungarian", + "italian", + "norwegian", + "portuguese", + "romanian", + "russian", + "spanish", + "swedish", + "tamil", + "turkish"); + } } diff --git a/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/JavaPyTantivyE2ETest.java b/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/JavaPyTantivyE2ETest.java index fa46e898c377..fcb2ef7acb3b 100644 
--- a/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/JavaPyTantivyE2ETest.java +++ b/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/JavaPyTantivyE2ETest.java @@ -56,6 +56,7 @@ import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -103,7 +104,47 @@ public void before() throws Exception { @Test @EnabledIfSystemProperty(named = "run.e2e.tests", matches = "true") public void testTantivyFullTextIndexWrite() throws Exception { - String tableName = "test_tantivy_fulltext"; + writeTableWithTantivyIndex( + "test_tantivy_fulltext", + Arrays.asList( + "Apache Paimon is a streaming data lake platform", + "Tantivy is a full-text search engine written in Rust", + "Paimon supports real-time data ingestion and analytics", + "Full-text search enables efficient text retrieval", + "Data lake platforms like Paimon handle large-scale data"), + "default"); + + writeTableWithTantivyIndex( + "test_tantivy_fulltext_ngram", + Arrays.asList( + "Apache Paimon 支持中文全文检索", + "Tantivy ngram tokenizer helps Chinese search", + "湖仓表支持实时数据分析", + "默认分词适合英文内容", + "中文索引支持片段查询"), + "ngram"); + + writeTableWithTantivyIndex( + "test_tantivy_fulltext_simple", + Arrays.asList( + "Running runners search Apache Paimon", + "Run search with Paimon lake", + "The connector runs analytics"), + "simple"); + + writeTableWithTantivyIndex( + "test_tantivy_fulltext_jieba", + Arrays.asList( + "张华在百货公司当售货员", + "Apache Paimon supports full text search", + "李萍进入中等技术学校学习", + "中文分词支持更自然的全文检索", + "默认英文分词不适合中文语义"), + "jieba"); + } + + private void writeTableWithTantivyIndex( + String tableName, List contents, String tokenizer) throws Exception { Path tablePath = new Path(warehouse.toString() + "/default.db/" + tableName); RowType rowType = @@ -138,37 +179,25 @@ public void testTantivyFullTextIndexWrite() throws Exception { BatchWriteBuilder writeBuilder = 
table.newBatchWriteBuilder(); try (BatchTableWrite write = writeBuilder.newWrite(); BatchTableCommit commit = writeBuilder.newCommit()) { - write.write( - GenericRow.of( - 0, - BinaryString.fromString( - "Apache Paimon is a streaming data lake platform"))); - write.write( - GenericRow.of( - 1, - BinaryString.fromString( - "Tantivy is a full-text search engine written in Rust"))); - write.write( - GenericRow.of( - 2, - BinaryString.fromString( - "Paimon supports real-time data ingestion and analytics"))); - write.write( - GenericRow.of( - 3, - BinaryString.fromString( - "Full-text search enables efficient text retrieval"))); - write.write( - GenericRow.of( - 4, - BinaryString.fromString( - "Data lake platforms like Paimon handle large-scale data"))); + for (int i = 0; i < contents.size(); i++) { + write.write(GenericRow.of(i, BinaryString.fromString(contents.get(i)))); + } commit.commit(write.prepareCommit()); } // Build tantivy full-text index on the "content" column DataField contentField = table.rowType().getField("content"); Options indexOptions = table.coreOptions().toConfiguration(); + if (!"default".equals(tokenizer)) { + indexOptions.set(TantivyFullTextIndexOptions.TOKENIZER, tokenizer); + } + if ("ngram".equals(tokenizer)) { + indexOptions.set(TantivyFullTextIndexOptions.NGRAM_MIN_GRAM, 2); + indexOptions.set(TantivyFullTextIndexOptions.NGRAM_MAX_GRAM, 2); + } else if ("simple".equals(tokenizer)) { + indexOptions.set(TantivyFullTextIndexOptions.STEM, true); + indexOptions.set(TantivyFullTextIndexOptions.REMOVE_STOP_WORDS, true); + } GlobalIndexSingletonWriter writer = (GlobalIndexSingletonWriter) @@ -178,21 +207,25 @@ public void testTantivyFullTextIndexWrite() throws Exception { contentField, indexOptions); - // Write the same text data to the index - writer.write(BinaryString.fromString("Apache Paimon is a streaming data lake platform")); - writer.write( - BinaryString.fromString("Tantivy is a full-text search engine written in Rust")); - writer.write( - 
BinaryString.fromString("Paimon supports real-time data ingestion and analytics")); - writer.write(BinaryString.fromString("Full-text search enables efficient text retrieval")); - writer.write( - BinaryString.fromString("Data lake platforms like Paimon handle large-scale data")); + // Write the same text data to the index. + for (String content : contents) { + writer.write(BinaryString.fromString(content)); + } List entries = writer.finish(); assertThat(entries).hasSize(1); - assertThat(entries.get(0).rowCount()).isEqualTo(5); + assertThat(entries.get(0).rowCount()).isEqualTo(contents.size()); + TantivyFullTextIndexOptions persistedOptions = + TantivyFullTextIndexOptions.deserialize(entries.get(0).meta()); + assertThat(persistedOptions.tokenizer()).isEqualTo(tokenizer); + assertThat(persistedOptions.ngramMinGram()).isEqualTo(2); + assertThat(persistedOptions.ngramMaxGram()).isEqualTo(2); + if ("simple".equals(tokenizer)) { + assertThat(persistedOptions.stem()).isTrue(); + assertThat(persistedOptions.removeStopWords()).isTrue(); + } - Range rowRange = new Range(0, 4); + Range rowRange = new Range(0, contents.size() - 1); List indexFiles = GlobalIndexBuilderUtils.toIndexFileMetas( table.fileIO(), diff --git a/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexTest.java b/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexTest.java index 8f3f13fadeb4..de13f5164528 100644 --- a/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexTest.java +++ b/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextGlobalIndexTest.java @@ -28,12 +28,12 @@ import org.apache.paimon.globalindex.ScoredGlobalIndexResult; import org.apache.paimon.globalindex.io.GlobalIndexFileReader; import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import 
org.apache.paimon.options.Options; import org.apache.paimon.predicate.FullTextSearch; import org.apache.paimon.tantivy.NativeLoader; import org.apache.paimon.utils.RoaringNavigableMap64; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -46,6 +46,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import static org.apache.paimon.shade.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -54,11 +55,6 @@ */ public class TantivyFullTextGlobalIndexTest { - @BeforeAll - static void checkNativeLibrary() { - assumeTrue(isNativeAvailable(), "Tantivy native library not available, skipping tests"); - } - private static boolean isNativeAvailable() { try { NativeLoader.loadJni(); @@ -77,6 +73,7 @@ private static boolean isNativeAvailable() { @BeforeEach public void setup() { + assumeTrue(isNativeAvailable(), "Tantivy native library not available, skipping tests"); fileIO = new LocalFileIO(); indexPath = new Path(tempDir.toString()); layoutCache = new ConcurrentHashMap<>(); @@ -119,7 +116,8 @@ private List toIOMetas(List results, Path path) private TantivyFullTextGlobalIndexReader createReader( GlobalIndexFileReader fileReader, List metas) { - return new TantivyFullTextGlobalIndexReader(fileReader, metas, layoutCache, pool); + return new TantivyFullTextGlobalIndexReader( + fileReader, metas, layoutCache, pool, newDirectExecutorService()); } @Test @@ -140,7 +138,8 @@ public void testEndToEnd() throws IOException { try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { FullTextSearch search = new FullTextSearch("paimon", 10, "text"); - Optional searchResult = reader.visitFullTextSearch(search); + Optional searchResult = + 
reader.visitFullTextSearch(search).join(); assertThat(searchResult).isPresent(); ScoredGlobalIndexResult scored = searchResult.get(); @@ -158,6 +157,66 @@ public void testEndToEnd() throws IOException { } } + @Test + public void testWriterPersistsTokenizerMeta() throws IOException { + Options options = new Options(); + options.set(TantivyFullTextIndexOptions.TOKENIZER, "ngram"); + options.set(TantivyFullTextIndexOptions.NGRAM_MIN_GRAM, 2); + options.set(TantivyFullTextIndexOptions.NGRAM_MAX_GRAM, 2); + + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + TantivyFullTextGlobalIndexWriter writer = + new TantivyFullTextGlobalIndexWriter( + fileWriter, + new TantivyFullTextIndexOptions( + TantivyFullTextGlobalIndexerFactory.removeTantivyPrefix(options))); + + writer.write(BinaryString.fromString("Apache Paimon supports Chinese text")); + List results = writer.finish(); + + assertThat(results).hasSize(1); + TantivyFullTextIndexOptions indexOptions = + TantivyFullTextIndexOptions.deserialize(results.get(0).meta()); + assertThat(indexOptions.tokenizer()).isEqualTo("ngram"); + assertThat(indexOptions.ngramMinGram()).isEqualTo(2); + assertThat(indexOptions.ngramMaxGram()).isEqualTo(2); + } + + @Test + public void testJiebaTokenizerFindsChineseWord() throws IOException { + Options options = new Options(); + options.set(TantivyFullTextIndexOptions.TOKENIZER, "jieba"); + + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + TantivyFullTextGlobalIndexWriter writer = + new TantivyFullTextGlobalIndexWriter( + fileWriter, + new TantivyFullTextIndexOptions( + TantivyFullTextGlobalIndexerFactory.removeTantivyPrefix(options))); + + writer.write(BinaryString.fromString("张华在百货公司当售货员")); + writer.write(BinaryString.fromString("Apache Paimon supports full text search")); + + List results = writer.finish(); + TantivyFullTextIndexOptions indexOptions = + TantivyFullTextIndexOptions.deserialize(results.get(0).meta()); + 
assertThat(indexOptions.tokenizer()).isEqualTo("jieba"); + + List metas = toIOMetas(results, indexPath); + GlobalIndexFileReader fileReader = createFileReader(); + + try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { + FullTextSearch search = new FullTextSearch("售货员", 10, "text"); + Optional searchResult = + reader.visitFullTextSearch(search).join(); + assertThat(searchResult).isPresent(); + + RoaringNavigableMap64 rowIds = searchResult.get().results(); + assertThat(rowIds.getLongCardinality()).isEqualTo(1); + assertThat(rowIds.contains(0L)).isTrue(); + } + } + @Test public void testSearchNoResults() throws IOException { GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); @@ -172,7 +231,8 @@ public void testSearchNoResults() throws IOException { try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { FullTextSearch search = new FullTextSearch("nonexistent", 10, "text"); - Optional searchResult = reader.visitFullTextSearch(search); + Optional searchResult = + reader.visitFullTextSearch(search).join(); assertThat(searchResult).isPresent(); RoaringNavigableMap64 rowIds = searchResult.get().results(); @@ -197,7 +257,8 @@ public void testNullFieldSkipped() throws IOException { try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { FullTextSearch search = new FullTextSearch("paimon", 10, "text"); - Optional searchResult = reader.visitFullTextSearch(search); + Optional searchResult = + reader.visitFullTextSearch(search).join(); assertThat(searchResult).isPresent(); ScoredGlobalIndexResult scored = searchResult.get(); @@ -242,7 +303,8 @@ public void testLargeDataset() throws IOException { try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { // Search for the special keyword — should match every 10th doc FullTextSearch search = new FullTextSearch("special_keyword", 1000, "text"); - Optional searchResult = reader.visitFullTextSearch(search); + Optional 
searchResult = + reader.visitFullTextSearch(search).join(); assertThat(searchResult).isPresent(); ScoredGlobalIndexResult scored = searchResult.get(); @@ -271,7 +333,8 @@ public void testLimitRespected() throws IOException { try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { // Limit to 5 results FullTextSearch search = new FullTextSearch("paimon", 5, "text"); - Optional searchResult = reader.visitFullTextSearch(search); + Optional searchResult = + reader.visitFullTextSearch(search).join(); assertThat(searchResult).isPresent(); RoaringNavigableMap64 rowIds = searchResult.get().results(); @@ -293,14 +356,14 @@ public void testPoolReuse() throws IOException { // First query: pool miss, searcher is loaded and returned to pool on close. try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { - Optional result = reader.visitFullTextSearch(search); + Optional result = reader.visitFullTextSearch(search).join(); assertThat(result).isPresent(); assertThat(result.get().results().contains(0L)).isTrue(); } // Second query: pool hit, reuses the same searcher. Results must be identical. 
try (TantivyFullTextGlobalIndexReader reader = createReader(fileReader, metas)) { - Optional result = reader.visitFullTextSearch(search); + Optional result = reader.visitFullTextSearch(search).join(); assertThat(result).isPresent(); assertThat(result.get().results().getLongCardinality()).isEqualTo(1); assertThat(result.get().results().contains(0L)).isTrue(); @@ -324,9 +387,11 @@ public void testViaIndexer() throws IOException { GlobalIndexFileReader fileReader = createFileReader(); try (TantivyFullTextGlobalIndexReader reader = - (TantivyFullTextGlobalIndexReader) indexer.createReader(fileReader, metas)) { + (TantivyFullTextGlobalIndexReader) + indexer.createReader(fileReader, metas, newDirectExecutorService())) { FullTextSearch search = new FullTextSearch("indexer", 10, "text"); - Optional searchResult = reader.visitFullTextSearch(search); + Optional searchResult = + reader.visitFullTextSearch(search).join(); assertThat(searchResult).isPresent(); assertThat(searchResult.get().results().getLongCardinality()).isEqualTo(1); assertThat(searchResult.get().results().contains(0L)).isTrue(); diff --git a/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptionsTest.java b/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptionsTest.java new file mode 100644 index 000000000000..79e12dfd3d13 --- /dev/null +++ b/paimon-tantivy/paimon-tantivy-index/src/test/java/org/apache/paimon/tantivy/index/TantivyFullTextIndexOptionsTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.tantivy.index; + +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TantivyFullTextIndexOptions}. */ +public class TantivyFullTextIndexOptionsTest { + + @Test + public void testDeserializeEmptyMetaUsesDefaults() throws Exception { + TantivyFullTextIndexOptions indexOptions = + TantivyFullTextIndexOptions.deserialize(new byte[0]); + + assertThat(indexOptions.tokenizer()).isEqualTo("default"); + assertThat(indexOptions.ngramMinGram()).isEqualTo(2); + assertThat(indexOptions.ngramMaxGram()).isEqualTo(2); + assertThat(indexOptions.ngramPrefixOnly()).isFalse(); + assertThat(indexOptions.lowerCase()).isTrue(); + assertThat(indexOptions.maxTokenLength()).isEqualTo(40); + assertThat(indexOptions.asciiFolding()).isFalse(); + assertThat(indexOptions.stem()).isFalse(); + assertThat(indexOptions.language()).isEqualTo("english"); + assertThat(indexOptions.removeStopWords()).isFalse(); + assertThat(indexOptions.stopWords()).isEmpty(); + assertThat(indexOptions.withPosition()).isTrue(); + } + + @Test + public void testDefaultOptionsSerializeSparseJson() throws Exception { + TantivyFullTextIndexOptions indexOptions = TantivyFullTextIndexOptions.defaults(); + + assertThat(indexOptions.toNativeConfigJson()).isEqualTo("{}"); + 
assertThat(indexOptions.serialize()).isEqualTo("{}".getBytes(StandardCharsets.UTF_8)); + assertThat(TantivyFullTextIndexOptions.deserialize(indexOptions.serialize()).tokenizer()) + .isEqualTo("default"); + } + + @Test + public void testSerializeDeserializeNgramOptions() throws Exception { + Map options = new HashMap<>(); + options.put("tokenizer", " NGRAM "); + options.put("ngram.min-gram", 2); + options.put("ngram.max-gram", 3); + options.put("ngram.prefix-only", true); + options.put("lower-case", false); + + TantivyFullTextIndexOptions indexOptions = + TantivyFullTextIndexOptions.deserialize( + new TantivyFullTextIndexOptions(options).serialize()); + + assertThat(indexOptions.tokenizer()).isEqualTo("ngram"); + assertThat(indexOptions.ngramMinGram()).isEqualTo(2); + assertThat(indexOptions.ngramMaxGram()).isEqualTo(3); + assertThat(indexOptions.ngramPrefixOnly()).isTrue(); + assertThat(indexOptions.lowerCase()).isFalse(); + assertThat(indexOptions.toNativeConfigJson()) + .isEqualTo( + "{\"tokenizer\":\"ngram\",\"ngram.max-gram\":3," + + "\"ngram.prefix-only\":true,\"lower-case\":false}"); + } + + @Test + public void testSerializeDeserializeAnalyzerOptions() throws Exception { + Map options = new HashMap<>(); + options.put("tokenizer", " WHITESPACE "); + options.put("lower-case", false); + options.put("max-token-length", 12); + options.put("ascii-folding", true); + options.put("stem", true); + options.put("language", "English"); + options.put("remove-stop-words", true); + options.put("stop-words", "paimon;lake"); + options.put("with-position", false); + + TantivyFullTextIndexOptions indexOptions = + TantivyFullTextIndexOptions.deserialize( + new TantivyFullTextIndexOptions(options).serialize()); + + assertThat(indexOptions.tokenizer()).isEqualTo("whitespace"); + assertThat(indexOptions.lowerCase()).isFalse(); + assertThat(indexOptions.maxTokenLength()).isEqualTo(12); + assertThat(indexOptions.asciiFolding()).isTrue(); + assertThat(indexOptions.stem()).isTrue(); + 
assertThat(indexOptions.language()).isEqualTo("english"); + assertThat(indexOptions.removeStopWords()).isTrue(); + assertThat(indexOptions.stopWords()).isEqualTo("paimon;lake"); + assertThat(indexOptions.stopWordList()).containsExactly("paimon", "lake"); + assertThat(indexOptions.withPosition()).isFalse(); + assertThat(indexOptions.toNativeConfigJson()).contains("\"tokenizer\":\"whitespace\""); + assertThat(indexOptions.toNativeConfigJson()).doesNotContain("\"ngram.min-gram\""); + assertThat(indexOptions.toNativeConfigJson()).doesNotContain("\"ngram.max-gram\""); + assertThat(indexOptions.serialize()) + .isEqualTo(indexOptions.toNativeConfigJson().getBytes(StandardCharsets.UTF_8)); + } + + @Test + public void testSerializeDeserializeJiebaOptions() throws Exception { + Map options = new HashMap<>(); + options.put("tokenizer", " JIEBA "); + options.put("lower-case", true); + + TantivyFullTextIndexOptions indexOptions = + TantivyFullTextIndexOptions.deserialize( + new TantivyFullTextIndexOptions(options).serialize()); + + assertThat(indexOptions.tokenizer()).isEqualTo("jieba"); + assertThat(indexOptions.ngramMinGram()).isEqualTo(2); + assertThat(indexOptions.ngramMaxGram()).isEqualTo(2); + assertThat(indexOptions.ngramPrefixOnly()).isFalse(); + assertThat(indexOptions.lowerCase()).isTrue(); + } + + @Test + public void testValidateTokenizerOptions() { + Map unsupportedTokenizerOptions = new HashMap<>(); + unsupportedTokenizerOptions.put("tokenizer", "ik"); + assertThatThrownBy(() -> new TantivyFullTextIndexOptions(unsupportedTokenizerOptions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported Tantivy tokenizer"); + + Map invalidNgramOptions = new HashMap<>(); + invalidNgramOptions.put("tokenizer", "ngram"); + invalidNgramOptions.put("ngram.min-gram", 3); + invalidNgramOptions.put("ngram.max-gram", 2); + assertThatThrownBy(() -> new TantivyFullTextIndexOptions(invalidNgramOptions)) + .isInstanceOf(IllegalArgumentException.class) + 
.hasMessageContaining("ngram min gram must not be greater than max gram"); + + Map unsupportedLanguageOptions = new HashMap<>(); + unsupportedLanguageOptions.put("stem", true); + unsupportedLanguageOptions.put("language", "klingon"); + assertThatThrownBy(() -> new TantivyFullTextIndexOptions(unsupportedLanguageOptions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported Tantivy language"); + } +} diff --git a/paimon-tantivy/paimon-tantivy-jni/README.md b/paimon-tantivy/paimon-tantivy-jni/README.md index 0627885a85e5..12c77aa6b384 100644 --- a/paimon-tantivy/paimon-tantivy-jni/README.md +++ b/paimon-tantivy/paimon-tantivy-jni/README.md @@ -69,3 +69,54 @@ try (TantivySearcher searcher = new TantivySearcher("/tmp/my_index")) { } } ``` + +### Tokenizers + +Use the tokenizer-aware constructors when indexing Chinese text or other content that should match +short character fragments: + +```java +try (TantivyIndexWriter writer = + new TantivyIndexWriter("/tmp/my_index", "ngram", 2, 2, false, true)) { + writer.addDocument(1L, "Apache Paimon 支持中文全文检索"); + writer.commit(); +} + +try (TantivySearcher searcher = + new TantivySearcher("/tmp/my_index", "ngram", 2, 2, false, true)) { + SearchResult result = searcher.search("中文", 10); + System.out.println(result.size()); +} +``` + +Use `jieba` for Chinese word segmentation: + +```java +try (TantivyIndexWriter writer = + new TantivyIndexWriter("/tmp/my_index", "jieba", 2, 2, false, true)) { + writer.addDocument(1L, "张华在百货公司当售货员"); + writer.commit(); +} + +try (TantivySearcher searcher = + new TantivySearcher("/tmp/my_index", "jieba", 2, 2, false, true)) { + SearchResult result = searcher.search("售货员", 10); + System.out.println(result.size()); +} +``` + +For analyzer customization, pass a JSON tokenizer config. 
This supports the same base tokenizers +and token filters as the global index options: + +```java +String configJson = + "{\"tokenizer\":\"simple\",\"stem\":true,\"remove-stop-words\":true," + + "\"language\":\"english\"}"; +try (TantivyIndexWriter writer = new TantivyIndexWriter("/tmp/my_index", configJson); + TantivySearcher searcher = new TantivySearcher("/tmp/my_index", configJson)) { + writer.addDocument(1L, "running with Apache Paimon"); + writer.commit(); + SearchResult result = searcher.search("run", 10, "and"); + System.out.println(result.size()); +} +``` diff --git a/paimon-tantivy/paimon-tantivy-jni/rust/Cargo.toml b/paimon-tantivy/paimon-tantivy-jni/rust/Cargo.toml index a064ad8dd498..45f15d6da9b1 100644 --- a/paimon-tantivy/paimon-tantivy-jni/rust/Cargo.toml +++ b/paimon-tantivy/paimon-tantivy-jni/rust/Cargo.toml @@ -10,4 +10,6 @@ crate-type = ["cdylib"] [dependencies] jni = "0.21" tantivy = "0.22" +tantivy-jieba = "0.11.0" +serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/paimon-tantivy/paimon-tantivy-jni/rust/src/lib.rs b/paimon-tantivy/paimon-tantivy-jni/rust/src/lib.rs index aec47eaa8564..ef8983837d4d 100644 --- a/paimon-tantivy/paimon-tantivy-jni/rust/src/lib.rs +++ b/paimon-tantivy/paimon-tantivy-jni/rust/src/lib.rs @@ -18,13 +18,22 @@ mod jni_directory; use jni::objects::{JClass, JObject, JString, JValue}; -use jni::sys::{jfloat, jint, jlong, jobject}; +use jni::sys::{jboolean, jfloat, jint, jlong, jobject}; use jni::JNIEnv; +use serde::Deserialize; use std::ptr; use tantivy::collector::TopDocs; use tantivy::query::QueryParser; -use tantivy::schema::{Field, IndexRecordOption, NumericOptions, Schema, TextFieldIndexing, TextOptions}; +use tantivy::schema::{ + Field, IndexRecordOption, NumericOptions, Schema, TextFieldIndexing, TextOptions, +}; +use tantivy::tokenizer::{ + AsciiFoldingFilter, Language, LowerCaser, NgramTokenizer, RawTokenizer, RemoveLongFilter, + SimpleTokenizer, Stemmer, StopWordFilter, TextAnalyzer, 
TextAnalyzerBuilder, + WhitespaceTokenizer, +}; use tantivy::{Index, IndexReader, IndexWriter, ReloadPolicy}; +use tantivy_jieba::JiebaTokenizer; use crate::jni_directory::JniDirectory; @@ -52,36 +61,249 @@ struct TantivySearcherHandle { text_field: Field, } -fn build_schema() -> (Schema, Field, Field) { +#[derive(Clone, Deserialize)] +#[serde(default)] +struct TokenizerConfig { + tokenizer: String, + #[serde(rename = "ngram.min-gram")] + ngram_min_gram: usize, + #[serde(rename = "ngram.max-gram")] + ngram_max_gram: usize, + #[serde(rename = "ngram.prefix-only")] + ngram_prefix_only: bool, + #[serde(rename = "lower-case")] + lower_case: bool, + #[serde(rename = "max-token-length")] + max_token_length: usize, + #[serde(rename = "ascii-folding")] + ascii_folding: bool, + stem: bool, + language: String, + #[serde(rename = "remove-stop-words")] + remove_stop_words: bool, + #[serde(rename = "stop-words")] + stop_words: Vec, + #[serde(rename = "with-position")] + with_position: bool, +} + +impl Default for TokenizerConfig { + fn default() -> Self { + Self { + tokenizer: "default".to_string(), + ngram_min_gram: 2, + ngram_max_gram: 2, + ngram_prefix_only: false, + lower_case: true, + max_token_length: 40, + ascii_folding: false, + stem: false, + language: "english".to_string(), + remove_stop_words: false, + stop_words: Vec::new(), + with_position: true, + } + } +} + +impl TokenizerConfig { + fn tokenizer_name(&self) -> &str { + match self.tokenizer.as_str() { + "ngram" => "paimon_ngram", + "jieba" => "paimon_jieba", + "simple" | "whitespace" | "raw" => "paimon_custom", + "default" if self.needs_custom_default_tokenizer() => "paimon_custom", + _ => &self.tokenizer, + } + } + + fn needs_custom_default_tokenizer(&self) -> bool { + self.max_token_length != 40 + || !self.lower_case + || self.ascii_folding + || self.stem + || self.remove_stop_words + || !self.stop_words.is_empty() + } + + fn normalize(mut self) -> Result { + self.tokenizer = 
self.tokenizer.trim().to_lowercase(); + self.language = self.language.trim().to_lowercase(); + self.stop_words = self + .stop_words + .into_iter() + .map(|word| word.trim().to_string()) + .filter(|word| !word.is_empty()) + .collect(); + self.validate()?; + Ok(self) + } + + fn validate(&self) -> Result<(), String> { + match self.tokenizer.as_str() { + "default" | "simple" | "whitespace" | "raw" | "ngram" | "jieba" => {} + _ => return Err(format!("Unsupported tokenizer: {}", self.tokenizer)), + } + if self.ngram_min_gram == 0 { + return Err("ngram.min-gram must be positive, got 0".to_string()); + } + if self.ngram_max_gram == 0 { + return Err("ngram.max-gram must be positive, got 0".to_string()); + } + if self.ngram_min_gram > self.ngram_max_gram { + return Err(format!( + "ngram.min-gram must not be greater than ngram.max-gram, got {} > {}", + self.ngram_min_gram, self.ngram_max_gram + )); + } + if self.max_token_length == 0 { + return Err("max-token-length must be positive, got 0".to_string()); + } + self.language()?; + Ok(()) + } + + fn language(&self) -> Result { + match self.language.as_str() { + "arabic" => Ok(Language::Arabic), + "danish" => Ok(Language::Danish), + "dutch" => Ok(Language::Dutch), + "english" => Ok(Language::English), + "finnish" => Ok(Language::Finnish), + "french" => Ok(Language::French), + "german" => Ok(Language::German), + "greek" => Ok(Language::Greek), + "hungarian" => Ok(Language::Hungarian), + "italian" => Ok(Language::Italian), + "norwegian" => Ok(Language::Norwegian), + "portuguese" => Ok(Language::Portuguese), + "romanian" => Ok(Language::Romanian), + "russian" => Ok(Language::Russian), + "spanish" => Ok(Language::Spanish), + "swedish" => Ok(Language::Swedish), + "tamil" => Ok(Language::Tamil), + "turkish" => Ok(Language::Turkish), + _ => Err(format!("Unsupported language: {}", self.language)), + } + } +} + +fn build_base_analyzer(config: &TokenizerConfig) -> Result { + match config.tokenizer.as_str() { + "default" | "simple" => 
Ok(TextAnalyzer::builder(SimpleTokenizer::default()).dynamic()), + "whitespace" => Ok(TextAnalyzer::builder(WhitespaceTokenizer::default()).dynamic()), + "raw" => Ok(TextAnalyzer::builder(RawTokenizer::default()).dynamic()), + "ngram" => Ok(TextAnalyzer::builder( + NgramTokenizer::new( + config.ngram_min_gram, + config.ngram_max_gram, + config.ngram_prefix_only, + ) + .map_err(|e| e.to_string())?, + ) + .dynamic()), + "jieba" => Ok(TextAnalyzer::builder(JiebaTokenizer {}).dynamic()), + _ => Err(format!("Unsupported tokenizer: {}", config.tokenizer)), + } +} + +fn build_analyzer(config: &TokenizerConfig) -> Result { + let mut analyzer_builder = build_base_analyzer(config)?; + analyzer_builder = + analyzer_builder.filter_dynamic(RemoveLongFilter::limit(config.max_token_length)); + if config.lower_case { + analyzer_builder = analyzer_builder.filter_dynamic(LowerCaser); + } + if config.ascii_folding { + analyzer_builder = analyzer_builder.filter_dynamic(AsciiFoldingFilter); + } + if config.stem { + analyzer_builder = analyzer_builder.filter_dynamic(Stemmer::new(config.language()?)); + } + if config.remove_stop_words { + let stop_word_filter = StopWordFilter::new(config.language()?).ok_or_else(|| { + format!( + "Removing stop words for language '{}' is not supported", + config.language + ) + })?; + analyzer_builder = analyzer_builder.filter_dynamic(stop_word_filter); + } + if !config.stop_words.is_empty() { + analyzer_builder = + analyzer_builder.filter_dynamic(StopWordFilter::remove(config.stop_words.clone())); + } + Ok(analyzer_builder.build()) +} + +fn register_tokenizer(index: &Index, config: &TokenizerConfig) -> Result<(), String> { + index + .tokenizers() + .register(config.tokenizer_name(), build_analyzer(config)?); + Ok(()) +} + +fn tokenizer_config_from_java( + env: &mut JNIEnv, + tokenizer_name: JString, + min_gram: jint, + max_gram: jint, + prefix_only: jboolean, + lower_case: jboolean, +) -> Result { + let name: String = env + .get_string(&tokenizer_name) + 
.map_err(|e| format!("Failed to get tokenizer name: {}", e))? + .into(); + let name = name.trim().to_lowercase(); + TokenizerConfig { + tokenizer: name, + ngram_min_gram: min_gram as usize, + ngram_max_gram: max_gram as usize, + ngram_prefix_only: prefix_only != 0, + lower_case: lower_case != 0, + ..TokenizerConfig::default() + } + .normalize() +} + +fn tokenizer_config_from_json( + env: &mut JNIEnv, + config_json: JString, +) -> Result { + let json: String = env + .get_string(&config_json) + .map_err(|e| format!("Failed to get tokenizer config: {}", e))? + .into(); + serde_json::from_str::(&json) + .map_err(|e| format!("Failed to parse tokenizer config: {}", e))? + .normalize() +} + +fn build_schema(config: &TokenizerConfig) -> (Schema, Field, Field) { let mut builder = Schema::builder(); - let row_id_field = builder.add_u64_field( - "row_id", - NumericOptions::default().set_indexed().set_fast(), - ); + let row_id_field = + builder.add_u64_field("row_id", NumericOptions::default().set_indexed().set_fast()); + let index_option = if config.with_position { + IndexRecordOption::WithFreqsAndPositions + } else { + IndexRecordOption::WithFreqs + }; let text_options = TextOptions::default().set_indexing_options( TextFieldIndexing::default() - .set_tokenizer("default") - .set_index_option(IndexRecordOption::WithFreqsAndPositions), + .set_tokenizer(config.tokenizer_name()) + .set_index_option(index_option), ); let text_field = builder.add_text_field("text", text_options); (builder.build(), row_id_field, text_field) } -// --------------------------------------------------------------------------- -// TantivyIndexWriter native methods -// --------------------------------------------------------------------------- - -#[no_mangle] -pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_createIndex( - mut env: JNIEnv, - _class: JClass, - index_path: JString, -) -> jlong { +fn create_index_internal(mut env: JNIEnv, index_path: JString, config: TokenizerConfig) 
-> jlong { let path: String = match env.get_string(&index_path) { Ok(s) => s.into(), Err(e) => return throw_and_return(&mut env, &format!("Failed to get index path: {}", e)), }; - let (schema, row_id_field, text_field) = build_schema(); + let (schema, row_id_field, text_field) = build_schema(&config); let dir = std::path::Path::new(&path); if let Err(e) = std::fs::create_dir_all(dir) { @@ -91,6 +313,9 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_createI Ok(i) => i, Err(e) => return throw_and_return(&mut env, &format!("Failed to create index: {}", e)), }; + if let Err(e) = register_tokenizer(&index, &config) { + return throw_and_return(&mut env, &format!("Failed to register tokenizer: {}", e)); + } let writer = match index.writer(50_000_000) { Ok(w) => w, Err(e) => return throw_and_return(&mut env, &format!("Failed to create writer: {}", e)), @@ -104,64 +329,7 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_createI Box::into_raw(handle) as jlong } -#[no_mangle] -pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_writeDocument( - mut env: JNIEnv, - _class: JClass, - index_ptr: jlong, - row_id: jlong, - text: JString, -) { - let handle = unsafe { &mut *(index_ptr as *mut TantivyIndex) }; - let text_str: String = match env.get_string(&text) { - Ok(s) => s.into(), - Err(e) => { - throw_and_return::<()>(&mut env, &format!("Failed to get text string: {}", e)); - return; - } - }; - - let mut doc = tantivy::TantivyDocument::new(); - doc.add_u64(handle.row_id_field, row_id as u64); - doc.add_text(handle.text_field, &text_str); - if let Err(e) = handle.writer.add_document(doc) { - throw_and_return::<()>(&mut env, &format!("Failed to add document: {}", e)); - } -} - -#[no_mangle] -pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_commitIndex( - mut env: JNIEnv, - _class: JClass, - index_ptr: jlong, -) { - let handle = unsafe { &mut *(index_ptr as *mut TantivyIndex) }; - if 
let Err(e) = handle.writer.commit() { - throw_and_return::<()>(&mut env, &format!("Failed to commit index: {}", e)); - } -} - -#[no_mangle] -pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_freeIndex( - _env: JNIEnv, - _class: JClass, - index_ptr: jlong, -) { - unsafe { - let _ = Box::from_raw(index_ptr as *mut TantivyIndex); - } -} - -// --------------------------------------------------------------------------- -// TantivySearcher native methods -// --------------------------------------------------------------------------- - -#[no_mangle] -pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openIndex( - mut env: JNIEnv, - _class: JClass, - index_path: JString, -) -> jlong { +fn open_index_internal(mut env: JNIEnv, index_path: JString, config: TokenizerConfig) -> jlong { let path: String = match env.get_string(&index_path) { Ok(s) => s.into(), Err(e) => return throw_and_return(&mut env, &format!("Failed to get index path: {}", e)), @@ -170,6 +338,9 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openIndex( Ok(i) => i, Err(e) => return throw_and_return(&mut env, &format!("Failed to open index: {}", e)), }; + if let Err(e) = register_tokenizer(&index, &config) { + return throw_and_return(&mut env, &format!("Failed to register tokenizer: {}", e)); + } let schema = index.schema(); let text_field = schema.get_field("text").unwrap(); @@ -183,27 +354,17 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openIndex( Err(e) => return throw_and_return(&mut env, &format!("Failed to create reader: {}", e)), }; - let handle = Box::new(TantivySearcherHandle { - reader, - text_field, - }); + let handle = Box::new(TantivySearcherHandle { reader, text_field }); Box::into_raw(handle) as jlong } -/// Open an index from a Java StreamFileInput callback object. 
-/// -/// fileNames: String[] — names of files in the archive -/// fileOffsets: long[] — byte offset of each file in the stream -/// fileLengths: long[] — byte length of each file -/// streamInput: StreamFileInput — Java object with seek(long) and read(byte[], int, int) methods -#[no_mangle] -pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromStream( +fn open_from_stream_internal( mut env: JNIEnv, - _class: JClass, file_names: jni::objects::JObjectArray, file_offsets: jni::objects::JLongArray, file_lengths: jni::objects::JLongArray, stream_input: JObject, + config: TokenizerConfig, ) -> jlong { // Parse file metadata from Java arrays let count = match env.get_array_length(&file_names) { @@ -223,12 +384,19 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromSt for i in 0..count { let obj = match env.get_object_array_element(&file_names, i as i32) { Ok(o) => o, - Err(e) => return throw_and_return(&mut env, &format!("Failed to get file name at {}: {}", i, e)), + Err(e) => { + return throw_and_return( + &mut env, + &format!("Failed to get file name at {}: {}", i, e), + ) + } }; let jstr = JString::from(obj); let name: String = match env.get_string(&jstr) { Ok(s) => s.into(), - Err(e) => return throw_and_return(&mut env, &format!("Failed to convert file name: {}", e)), + Err(e) => { + return throw_and_return(&mut env, &format!("Failed to convert file name: {}", e)) + } }; files.push((name, offsets_buf[i] as u64, lengths_buf[i] as u64)); } @@ -240,14 +408,24 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromSt }; let stream_ref = match env.new_global_ref(stream_input) { Ok(r) => r, - Err(e) => return throw_and_return(&mut env, &format!("Failed to create global ref: {}", e)), + Err(e) => { + return throw_and_return(&mut env, &format!("Failed to create global ref: {}", e)) + } }; let directory = JniDirectory::new(jvm, stream_ref, files); let index = match Index::open(directory) { Ok(i) 
=> i, - Err(e) => return throw_and_return(&mut env, &format!("Failed to open index from stream: {}", e)), + Err(e) => { + return throw_and_return( + &mut env, + &format!("Failed to open index from stream: {}", e), + ) + } }; + if let Err(e) = register_tokenizer(&index, &config) { + return throw_and_return(&mut env, &format!("Failed to register tokenizer: {}", e)); + } let schema = index.schema(); let text_field = schema.get_field("text").unwrap(); @@ -261,13 +439,246 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromSt Err(e) => return throw_and_return(&mut env, &format!("Failed to create reader: {}", e)), }; - let handle = Box::new(TantivySearcherHandle { - reader, - text_field, - }); + let handle = Box::new(TantivySearcherHandle { reader, text_field }); Box::into_raw(handle) as jlong } +// --------------------------------------------------------------------------- +// TantivyIndexWriter native methods +// --------------------------------------------------------------------------- + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_createIndex( + env: JNIEnv, + _class: JClass, + index_path: JString, +) -> jlong { + create_index_internal(env, index_path, TokenizerConfig::default()) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_createIndexWithTokenizer( + mut env: JNIEnv, + _class: JClass, + index_path: JString, + tokenizer_name: JString, + min_gram: jint, + max_gram: jint, + prefix_only: jboolean, + lower_case: jboolean, +) -> jlong { + let config = match tokenizer_config_from_java( + &mut env, + tokenizer_name, + min_gram, + max_gram, + prefix_only, + lower_case, + ) { + Ok(config) => config, + Err(e) => return throw_and_return(&mut env, &e), + }; + create_index_internal(env, index_path, config) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_createIndexWithTokenizerConfig( + mut env: JNIEnv, + _class: 
JClass, + index_path: JString, + config_json: JString, +) -> jlong { + let config = match tokenizer_config_from_json(&mut env, config_json) { + Ok(config) => config, + Err(e) => return throw_and_return(&mut env, &e), + }; + create_index_internal(env, index_path, config) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_writeDocument( + mut env: JNIEnv, + _class: JClass, + index_ptr: jlong, + row_id: jlong, + text: JString, +) { + let handle = unsafe { &mut *(index_ptr as *mut TantivyIndex) }; + let text_str: String = match env.get_string(&text) { + Ok(s) => s.into(), + Err(e) => { + throw_and_return::<()>(&mut env, &format!("Failed to get text string: {}", e)); + return; + } + }; + + let mut doc = tantivy::TantivyDocument::new(); + doc.add_u64(handle.row_id_field, row_id as u64); + doc.add_text(handle.text_field, &text_str); + if let Err(e) = handle.writer.add_document(doc) { + throw_and_return::<()>(&mut env, &format!("Failed to add document: {}", e)); + } +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_commitIndex( + mut env: JNIEnv, + _class: JClass, + index_ptr: jlong, +) { + let handle = unsafe { &mut *(index_ptr as *mut TantivyIndex) }; + if let Err(e) = handle.writer.commit() { + throw_and_return::<()>(&mut env, &format!("Failed to commit index: {}", e)); + } +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivyIndexWriter_freeIndex( + _env: JNIEnv, + _class: JClass, + index_ptr: jlong, +) { + unsafe { + let _ = Box::from_raw(index_ptr as *mut TantivyIndex); + } +} + +// --------------------------------------------------------------------------- +// TantivySearcher native methods +// --------------------------------------------------------------------------- + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openIndex( + env: JNIEnv, + _class: JClass, + index_path: JString, +) -> jlong { + 
open_index_internal(env, index_path, TokenizerConfig::default()) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openIndexWithTokenizer( + mut env: JNIEnv, + _class: JClass, + index_path: JString, + tokenizer_name: JString, + min_gram: jint, + max_gram: jint, + prefix_only: jboolean, + lower_case: jboolean, +) -> jlong { + let config = match tokenizer_config_from_java( + &mut env, + tokenizer_name, + min_gram, + max_gram, + prefix_only, + lower_case, + ) { + Ok(config) => config, + Err(e) => return throw_and_return(&mut env, &e), + }; + open_index_internal(env, index_path, config) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openIndexWithTokenizerConfig( + mut env: JNIEnv, + _class: JClass, + index_path: JString, + config_json: JString, +) -> jlong { + let config = match tokenizer_config_from_json(&mut env, config_json) { + Ok(config) => config, + Err(e) => return throw_and_return(&mut env, &e), + }; + open_index_internal(env, index_path, config) +} + +/// Open an index from a Java StreamFileInput callback object. 
+/// +/// fileNames: String[] — names of files in the archive +/// fileOffsets: long[] — byte offset of each file in the stream +/// fileLengths: long[] — byte length of each file +/// streamInput: StreamFileInput — Java object with seek(long) and read(byte[], int, int) methods +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromStream( + env: JNIEnv, + _class: JClass, + file_names: jni::objects::JObjectArray, + file_offsets: jni::objects::JLongArray, + file_lengths: jni::objects::JLongArray, + stream_input: JObject, +) -> jlong { + open_from_stream_internal( + env, + file_names, + file_offsets, + file_lengths, + stream_input, + TokenizerConfig::default(), + ) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromStreamWithTokenizer( + mut env: JNIEnv, + _class: JClass, + file_names: jni::objects::JObjectArray, + file_offsets: jni::objects::JLongArray, + file_lengths: jni::objects::JLongArray, + stream_input: JObject, + tokenizer_name: JString, + min_gram: jint, + max_gram: jint, + prefix_only: jboolean, + lower_case: jboolean, +) -> jlong { + let config = match tokenizer_config_from_java( + &mut env, + tokenizer_name, + min_gram, + max_gram, + prefix_only, + lower_case, + ) { + Ok(config) => config, + Err(e) => return throw_and_return(&mut env, &e), + }; + open_from_stream_internal( + env, + file_names, + file_offsets, + file_lengths, + stream_input, + config, + ) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_openFromStreamWithTokenizerConfig( + mut env: JNIEnv, + _class: JClass, + file_names: jni::objects::JObjectArray, + file_offsets: jni::objects::JLongArray, + file_lengths: jni::objects::JLongArray, + stream_input: JObject, + config_json: JString, +) -> jlong { + let config = match tokenizer_config_from_json(&mut env, config_json) { + Ok(config) => config, + Err(e) => return throw_and_return(&mut env, &e), + }; + 
open_from_stream_internal( + env, + file_names, + file_offsets, + file_lengths, + stream_input, + config, + ) +} + /// Search and return a SearchResult(long[] rowIds, float[] scores). #[no_mangle] pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_searchIndex( @@ -276,18 +687,45 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_searchInde searcher_ptr: jlong, query_string: JString, limit: jint, + query_operator: JString, ) -> jobject { let handle = unsafe { &*(searcher_ptr as *const TantivySearcherHandle) }; let query_str: String = match env.get_string(&query_string) { Ok(s) => s.into(), - Err(e) => return throw_and_return_null(&mut env, &format!("Failed to get query string: {}", e)), + Err(e) => { + return throw_and_return_null(&mut env, &format!("Failed to get query string: {}", e)) + } }; + let query_operator_str: String = match env.get_string(&query_operator) { + Ok(s) => s.into(), + Err(e) => { + return throw_and_return_null(&mut env, &format!("Failed to get query operator: {}", e)) + } + }; + let query_operator_str = query_operator_str.trim().to_lowercase(); + if query_operator_str != "or" && query_operator_str != "and" { + return throw_and_return_null( + &mut env, + &format!( + "Query operator must be 'or' or 'and', got: {}", + query_operator_str + ), + ); + } let searcher = handle.reader.searcher(); - let query_parser = QueryParser::for_index(&searcher.index(), vec![handle.text_field]); + let mut query_parser = QueryParser::for_index(&searcher.index(), vec![handle.text_field]); + if query_operator_str == "and" { + query_parser.set_conjunction_by_default(); + } let query = match query_parser.parse_query(&query_str) { Ok(q) => q, - Err(e) => return throw_and_return_null(&mut env, &format!("Failed to parse query '{}': {}", query_str, e)), + Err(e) => { + return throw_and_return_null( + &mut env, + &format!("Failed to parse query '{}': {}", query_str, e), + ) + } }; let top_docs = match searcher.search(&query, 
&TopDocs::with_limit(limit as usize)) { Ok(d) => d, @@ -299,11 +737,15 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_searchInde // Build Java long[] and float[] let row_id_array = match env.new_long_array(count as i32) { Ok(a) => a, - Err(e) => return throw_and_return_null(&mut env, &format!("Failed to create long array: {}", e)), + Err(e) => { + return throw_and_return_null(&mut env, &format!("Failed to create long array: {}", e)) + } }; let score_array = match env.new_float_array(count as i32) { Ok(a) => a, - Err(e) => return throw_and_return_null(&mut env, &format!("Failed to create float array: {}", e)), + Err(e) => { + return throw_and_return_null(&mut env, &format!("Failed to create float array: {}", e)) + } }; let mut row_ids: Vec = Vec::with_capacity(count); @@ -314,7 +756,9 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_searchInde let segment_reader = searcher.segment_reader(doc_address.segment_ord); let fast_fields = match segment_reader.fast_fields().u64("row_id") { Ok(f) => f, - Err(e) => return throw_and_return_null(&mut env, &format!("Failed to get fast field: {}", e)), + Err(e) => { + return throw_and_return_null(&mut env, &format!("Failed to get fast field: {}", e)) + } }; let row_id = fast_fields.first(doc_address.doc_id).unwrap_or(0) as jlong; row_ids.push(row_id); @@ -331,7 +775,12 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_searchInde // Construct SearchResult object let class = match env.find_class("org/apache/paimon/tantivy/SearchResult") { Ok(c) => c, - Err(e) => return throw_and_return_null(&mut env, &format!("Failed to find SearchResult class: {}", e)), + Err(e) => { + return throw_and_return_null( + &mut env, + &format!("Failed to find SearchResult class: {}", e), + ) + } }; let obj = match env.new_object( class, @@ -342,7 +791,12 @@ pub extern "system" fn Java_org_apache_paimon_tantivy_TantivySearcher_searchInde ], ) { Ok(o) => o, - Err(e) => return 
throw_and_return_null(&mut env, &format!("Failed to create SearchResult: {}", e)), + Err(e) => { + return throw_and_return_null( + &mut env, + &format!("Failed to create SearchResult: {}", e), + ) + } }; obj.into_raw() diff --git a/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivyIndexWriter.java b/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivyIndexWriter.java index 51ccabf3a06b..ba9e834094a7 100644 --- a/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivyIndexWriter.java +++ b/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivyIndexWriter.java @@ -33,6 +33,25 @@ public TantivyIndexWriter(String indexPath) { this.closed = false; } + public TantivyIndexWriter( + String indexPath, + String tokenizerName, + int minGram, + int maxGram, + boolean prefixOnly, + boolean lowerCase) { + this.indexPtr = + createIndexWithTokenizer( + indexPath, tokenizerName, minGram, maxGram, prefixOnly, lowerCase); + this.closed = false; + } + + public TantivyIndexWriter(String indexPath, String configJson) { + this.indexPtr = + createIndexWithTokenizerConfig(indexPath, configJson); + this.closed = false; + } + public void addDocument(long rowId, String text) { checkNotClosed(); writeDocument(indexPtr, rowId, text); @@ -62,6 +81,16 @@ private void checkNotClosed() { static native long createIndex(String indexPath); + static native long createIndexWithTokenizer( + String indexPath, + String tokenizerName, + int minGram, + int maxGram, + boolean prefixOnly, + boolean lowerCase); + + static native long createIndexWithTokenizerConfig(String indexPath, String configJson); + static native void writeDocument(long indexPtr, long rowId, String text); static native void commitIndex(long indexPtr); diff --git a/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivySearcher.java 
b/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivySearcher.java index 798af2be6260..bc04632bb084 100644 --- a/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivySearcher.java +++ b/paimon-tantivy/paimon-tantivy-jni/src/main/java/org/apache/paimon/tantivy/TantivySearcher.java @@ -34,6 +34,24 @@ public TantivySearcher(String indexPath) { this.closed = false; } + public TantivySearcher( + String indexPath, + String tokenizerName, + int minGram, + int maxGram, + boolean prefixOnly, + boolean lowerCase) { + this.searcherPtr = + openIndexWithTokenizer( + indexPath, tokenizerName, minGram, maxGram, prefixOnly, lowerCase); + this.closed = false; + } + + public TantivySearcher(String indexPath, String configJson) { + this.searcherPtr = openIndexWithTokenizerConfig(indexPath, configJson); + this.closed = false; + } + /** * Open a searcher from a stream-backed archive. * @@ -51,6 +69,42 @@ public TantivySearcher( this.closed = false; } + public TantivySearcher( + String[] fileNames, + long[] fileOffsets, + long[] fileLengths, + StreamFileInput streamInput, + String tokenizerName, + int minGram, + int maxGram, + boolean prefixOnly, + boolean lowerCase) { + this.searcherPtr = + openFromStreamWithTokenizer( + fileNames, + fileOffsets, + fileLengths, + streamInput, + tokenizerName, + minGram, + maxGram, + prefixOnly, + lowerCase); + this.closed = false; + } + + public TantivySearcher( + String[] fileNames, + long[] fileOffsets, + long[] fileLengths, + StreamFileInput streamInput, + String configJson) { + this.searcherPtr = + openFromStreamWithTokenizerConfig( + fileNames, fileOffsets, fileLengths, streamInput, configJson); + this.closed = false; + } + /** * Search the index with a query string, returning top N results ranked by score. 
* @@ -59,8 +113,12 @@ public TantivySearcher( * @return search results containing rowIds and scores */ public SearchResult search(String queryString, int limit) { + return search(queryString, limit, "or"); + } + + public SearchResult search(String queryString, int limit, String queryOperator) { checkNotClosed(); - return searchIndex(searcherPtr, queryString, limit); + return searchIndex(searcherPtr, queryString, limit, queryOperator); } @Override @@ -86,10 +144,40 @@ private void checkNotClosed() { static native long openIndex(String indexPath); + static native long openIndexWithTokenizer( + String indexPath, + String tokenizerName, + int minGram, + int maxGram, + boolean prefixOnly, + boolean lowerCase); + + static native long openIndexWithTokenizerConfig(String indexPath, String configJson); + static native long openFromStream( String[] fileNames, long[] fileOffsets, long[] fileLengths, StreamFileInput streamInput); - static native SearchResult searchIndex(long searcherPtr, String queryString, int limit); + static native long openFromStreamWithTokenizer( + String[] fileNames, + long[] fileOffsets, + long[] fileLengths, + StreamFileInput streamInput, + String tokenizerName, + int minGram, + int maxGram, + boolean prefixOnly, + boolean lowerCase); + + static native long openFromStreamWithTokenizerConfig( + String[] fileNames, + long[] fileOffsets, + long[] fileLengths, + StreamFileInput streamInput, + String configJson); + + static native SearchResult searchIndex( + long searcherPtr, String queryString, int limit, String queryOperator); static native void freeSearcher(long searcherPtr); + } diff --git a/paimon-tantivy/paimon-tantivy-jni/src/test/java/org/apache/paimon/tantivy/TantivyJniTest.java b/paimon-tantivy/paimon-tantivy-jni/src/test/java/org/apache/paimon/tantivy/TantivyJniTest.java index 4313f5768d16..80300c24a8a9 100644 --- a/paimon-tantivy/paimon-tantivy-jni/src/test/java/org/apache/paimon/tantivy/TantivyJniTest.java +++ 
b/paimon-tantivy/paimon-tantivy-jni/src/test/java/org/apache/paimon/tantivy/TantivyJniTest.java @@ -18,7 +18,6 @@ package org.apache.paimon.tantivy; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -28,10 +27,9 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue; /** Smoke test for Tantivy JNI. */ -class TantivyJniTest { +public class TantivyJniTest { - @BeforeAll - static void checkNativeLibrary() { + private static void assumeNativeAvailable() { assumeTrue(isNativeAvailable(), "Tantivy native library not available, skipping tests"); } @@ -45,7 +43,8 @@ private static boolean isNativeAvailable() { } @Test - void testWriteAndSearch(@TempDir Path tempDir) { + public void testWriteAndSearch(@TempDir Path tempDir) { + assumeNativeAvailable(); String indexPath = tempDir.resolve("test_index").toString(); try (TantivyIndexWriter writer = new TantivyIndexWriter(indexPath)) { @@ -74,4 +73,69 @@ void testWriteAndSearch(@TempDir Path tempDir) { } } } + + @Test + public void testNgramTokenizerFindsChineseFragment(@TempDir Path tempDir) { + assumeNativeAvailable(); + String indexPath = tempDir.resolve("ngram_index").toString(); + + try (TantivyIndexWriter writer = + new TantivyIndexWriter(indexPath, "ngram", 2, 2, false, true)) { + writer.addDocument(1L, "Apache Paimon 支持中文全文检索"); + writer.addDocument(2L, "Tantivy full text search engine"); + writer.commit(); + } + + try (TantivySearcher searcher = + new TantivySearcher(indexPath, "ngram", 2, 2, false, true)) { + SearchResult result = searcher.search("中文", 10); + assertEquals(1, result.size()); + assertEquals(1L, result.getRowIds()[0]); + } + } + + @Test + public void testJiebaTokenizerFindsChineseWord(@TempDir Path tempDir) { + assumeNativeAvailable(); + String indexPath = tempDir.resolve("jieba_index").toString(); + + try (TantivyIndexWriter writer = + new TantivyIndexWriter(indexPath, "jieba", 2, 2, false, true)) { + writer.addDocument(1L, 
"张华在百货公司当售货员"); + writer.addDocument(2L, "Apache Paimon supports full text search"); + writer.commit(); + } + + try (TantivySearcher searcher = + new TantivySearcher(indexPath, "jieba", 2, 2, false, true)) { + SearchResult result = searcher.search("售货员", 10); + assertEquals(1, result.size()); + assertEquals(1L, result.getRowIds()[0]); + } + } + + @Test + public void testTokenizerConfigAndQueryOperator(@TempDir Path tempDir) { + assumeNativeAvailable(); + String indexPath = tempDir.resolve("config_index").toString(); + String configJson = + "{\"tokenizer\":\"simple\",\"stem\":true,\"remove-stop-words\":true," + + "\"language\":\"english\"}"; + + try (TantivyIndexWriter writer = new TantivyIndexWriter(indexPath, configJson)) { + writer.addDocument(1L, "Apache Paimon runs streaming jobs"); + writer.addDocument(2L, "Apache Spark runs batch jobs"); + writer.addDocument(3L, "Paimon lake stores data"); + writer.commit(); + } + + try (TantivySearcher searcher = new TantivySearcher(indexPath, configJson)) { + SearchResult orResult = searcher.search("paimon spark", 10); + assertEquals(3, orResult.size()); + + SearchResult andResult = searcher.search("paimon run", 10, "and"); + assertEquals(1, andResult.size()); + assertEquals(1L, andResult.getRowIds()[0]); + } + } } diff --git a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexFileFormat.java b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexFileFormat.java index 18878250065c..46e13d6cb454 100644 --- a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexFileFormat.java +++ b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexFileFormat.java @@ -18,7 +18,7 @@ package org.apache.paimon.format.vortex; -import org.apache.paimon.arrow.vector.ArrowFormatWriter; +import org.apache.paimon.arrow.vector.ArrowFormatCWriter; import org.apache.paimon.format.FileFormat; import 
org.apache.paimon.format.FileFormatFactory; import org.apache.paimon.format.FormatReaderFactory; @@ -79,13 +79,7 @@ public FormatReaderFactory createReaderFactory( @Override public FormatWriterFactory createWriterFactory(RowType type) { return new VortexWriterFactory( - type, - () -> - new ArrowFormatWriter( - type, - formatContext.writeBatchSize(), - true, - formatContext.writeBatchMemory().getBytes())); + () -> new ArrowFormatCWriter(type, formatContext.writeBatchSize(), true)); } @Override diff --git a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexPredicateConverter.java b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexPredicateConverter.java index ff1e1607a23d..e66f6cbecee1 100644 --- a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexPredicateConverter.java +++ b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexPredicateConverter.java @@ -42,11 +42,6 @@ import org.apache.paimon.types.TimestampType; import dev.vortex.api.Expression; -import dev.vortex.api.expressions.Binary; -import dev.vortex.api.expressions.GetItem; -import dev.vortex.api.expressions.Literal; -import dev.vortex.api.expressions.Not; -import dev.vortex.api.expressions.Root; import javax.annotation.Nullable; @@ -74,13 +69,13 @@ public Expression visit(LeafPredicate predicate) { return null; } FieldRef fieldRef = fieldRefOpt.get(); - Expression field = GetItem.of(Root.INSTANCE, fieldRef.name()); + Expression field = Expression.column(fieldRef.name()); if (predicate.function() instanceof IsNull) { - return Not.of(Binary.notEq(field, Literal.nullLit())); + return Expression.isNull(field); } if (predicate.function() instanceof IsNotNull) { - return Binary.notEq(field, Literal.nullLit()); + return Expression.isNotNull(field); } List literals = predicate.literals(); @@ -88,26 +83,25 @@ public Expression visit(LeafPredicate predicate) { return 
null; } - Literal vortexLiteral = toLiteral(fieldRef.type(), literals.get(0)); + Expression vortexLiteral = toLiteral(fieldRef.type(), literals.get(0)); if (vortexLiteral == null) { return null; } if (predicate.function() instanceof Equal) { - return Binary.eq(field, vortexLiteral); + return Expression.binary(Expression.BinaryOp.EQ, field, vortexLiteral); } else if (predicate.function() instanceof NotEqual) { - return Binary.notEq(field, vortexLiteral); + return Expression.binary(Expression.BinaryOp.NOT_EQ, field, vortexLiteral); } else if (predicate.function() instanceof GreaterThan) { - return Binary.gt(field, vortexLiteral); + return Expression.binary(Expression.BinaryOp.GT, field, vortexLiteral); } else if (predicate.function() instanceof GreaterOrEqual) { - return Binary.gtEq(field, vortexLiteral); + return Expression.binary(Expression.BinaryOp.GTE, field, vortexLiteral); } else if (predicate.function() instanceof LessThan) { - return Binary.lt(field, vortexLiteral); + return Expression.binary(Expression.BinaryOp.LT, field, vortexLiteral); } else if (predicate.function() instanceof LessOrEqual) { - return Binary.ltEq(field, vortexLiteral); + return Expression.binary(Expression.BinaryOp.LTE, field, vortexLiteral); } - // unsupported function (e.g. 
In, Between) return null; } @@ -124,11 +118,7 @@ public Expression visit(CompoundPredicate predicate) { if (children.isEmpty()) { return null; } - Expression result = children.get(0); - for (int i = 1; i < children.size(); i++) { - result = Binary.and(result, children.get(i)); - } - return result; + return Expression.and(children.toArray(new Expression[0])); } else if (predicate.function() instanceof Or) { List children = new ArrayList<>(); for (Predicate child : predicate.children()) { @@ -138,74 +128,70 @@ public Expression visit(CompoundPredicate predicate) { } children.add(expr); } - Expression result = children.get(0); - for (int i = 1; i < children.size(); i++) { - result = Binary.or(result, children.get(i)); - } - return result; + return Expression.or(children.toArray(new Expression[0])); } return null; } @Nullable - private static Literal toLiteral(DataType type, Object value) { + private static Expression toLiteral(DataType type, Object value) { if (value == null) { - return Literal.nullLit(); + return Expression.nullLiteral(Expression.DType.I32); } switch (type.getTypeRoot()) { case BOOLEAN: - return Literal.bool((Boolean) value); + return Expression.literal((Boolean) value); case TINYINT: - return Literal.int8((Byte) value); + return Expression.literal((Byte) value); case SMALLINT: - return Literal.int16((Short) value); + return Expression.literal((Short) value); case INTEGER: case DATE: - return Literal.int32((Integer) value); + return Expression.literal((Integer) value); case BIGINT: - return Literal.int64((Long) value); + return Expression.literal((Long) value); case FLOAT: - return Literal.float32((Float) value); + return Expression.literal((Float) value); case DOUBLE: - return Literal.float64((Double) value); + return Expression.literal((Double) value); case CHAR: case VARCHAR: - return Literal.string( + return Expression.literal( value instanceof BinaryString ? 
value.toString() : (String) value); case DECIMAL: Decimal decimal = (Decimal) value; - return Literal.decimal( - decimal.toBigDecimal(), decimal.precision(), decimal.scale()); + return Expression.literalDecimal( + decimal.toBigDecimal().unscaledValue(), + decimal.precision(), + decimal.scale()); case TIMESTAMP_WITHOUT_TIME_ZONE: Timestamp ts = (Timestamp) value; - return toTimestampLiteral( - ts, ((TimestampType) type).getPrecision(), Optional.empty()); + return toTimestampLiteral(ts, ((TimestampType) type).getPrecision(), null); case TIMESTAMP_WITH_LOCAL_TIME_ZONE: Timestamp lzTs = (Timestamp) value; return toTimestampLiteral( - lzTs, ((LocalZonedTimestampType) type).getPrecision(), Optional.of("UTC")); + lzTs, ((LocalZonedTimestampType) type).getPrecision(), "UTC"); default: return null; } } - private static Literal toTimestampLiteral( - Timestamp ts, int precision, Optional timeZone) { - if (precision <= 0) { - // seconds - return Literal.timestampMillis(ts.getMillisecond(), timeZone); - } else if (precision <= 3) { - // millis - return Literal.timestampMillis(ts.getMillisecond(), timeZone); + private static Expression toTimestampLiteral( + Timestamp ts, int precision, @Nullable String timeZone) { + if (precision <= 3) { + return Expression.literalTimestamp( + ts.getMillisecond(), Expression.TimeUnit.MILLISECONDS, timeZone); } else if (precision <= 6) { - // micros - return Literal.timestampMicros( - ts.getMillisecond() * 1000 + ts.getNanoOfMillisecond() / 1000, timeZone); + return Expression.literalTimestamp( + ts.getMillisecond() * 1000 + ts.getNanoOfMillisecond() / 1000, + Expression.TimeUnit.MICROSECONDS, + timeZone); } else { - // nanos - return Literal.timestampNanos( - ts.getMillisecond() * 1_000_000 + ts.getNanoOfMillisecond(), timeZone); + return Expression.literalTimestamp( + ts.getMillisecond() * 1_000_000 + ts.getNanoOfMillisecond(), + Expression.TimeUnit.NANOSECONDS, + timeZone); } } } diff --git 
a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsReader.java b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsReader.java index 69b7e2f7ba8f..571c4edf044e 100644 --- a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsReader.java +++ b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsReader.java @@ -30,15 +30,17 @@ import org.apache.paimon.utils.LongIterator; import org.apache.paimon.utils.ProjectedRow; -import dev.vortex.api.Array; -import dev.vortex.api.ArrayIterator; +import dev.vortex.api.DataSource; import dev.vortex.api.Expression; -import dev.vortex.api.File; -import dev.vortex.api.Files; import dev.vortex.api.ImmutableScanOptions; +import dev.vortex.api.Partition; +import dev.vortex.api.Scan; +import dev.vortex.api.ScanOptions; +import dev.vortex.api.Session; import dev.vortex.arrow.ArrowAllocation; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; import javax.annotation.Nullable; @@ -56,12 +58,13 @@ public class VortexRecordsReader implements FileRecordReader { private final ArrowBatchReader arrowBatchReader; private final Path filePath; private final BufferAllocator allocator; - private final ArrayIterator arrayIterator; - private final File vortexFile; + private final Session session; + private final DataSource dataSource; + private final Scan scan; private final LongIterator positionIterator; @Nullable private final int[] physicalFieldMapping; - private VectorSchemaRoot reuse; - private Array currentArray; + private ArrowReader currentArrowReader; + private Partition currentPartition; private long returnedPosition = -1; public VortexRecordsReader( @@ -77,45 +80,72 @@ public VortexRecordsReader( this.allocator = ArrowAllocation.rootAllocator() .newChildAllocator("vortex-reader", 
0, Long.MAX_VALUE); + try { - this.vortexFile = Files.open(path.toUri().toString(), storageOptions); + this.session = Session.create(); try { - ImmutableScanOptions.Builder scanBuilder = ImmutableScanOptions.builder(); - scanBuilder.addAllColumns(physicalReadRowType.getFieldNames()); - if (rowIndices != null) { - scanBuilder.rowIndices(rowIndices); - } - if (predicate != null) { - scanBuilder.predicate(predicate); + this.dataSource = DataSource.open(session, path.toUri().toString(), storageOptions); + try { + ImmutableScanOptions.Builder scanBuilder = ImmutableScanOptions.builder(); + + java.util.List columns = physicalReadRowType.getFieldNames(); + scanBuilder.projection( + Expression.select(columns.toArray(new String[0]), Expression.root())); + + if (rowIndices != null) { + scanBuilder.selectionIndices(rowIndices); + scanBuilder.selectionMode(ScanOptions.SelectionMode.INCLUDE); + } + if (predicate != null) { + scanBuilder.filter(predicate); + } + this.scan = dataSource.scan(scanBuilder.build()); + } catch (Exception e) { + dataSource.close(); + throw e; } - this.arrayIterator = vortexFile.newScan(scanBuilder.build()); } catch (Exception e) { - vortexFile.close(); + session.close(); throw e; } } catch (Exception e) { allocator.close(); - throw e; + throw new RuntimeException(e); } + this.arrowBatchReader = new ArrowBatchReader(physicalReadRowType, true); + long totalRowCount = dataSource.rowCount(); this.positionIterator = rowIndices != null ? LongIterator.fromArray(rowIndices) - : LongIterator.fromRange(0, vortexFile.rowCount()); + : LongIterator.fromRange(0, totalRowCount > 0 ? 
totalRowCount : 0); } @Nullable @Override public FileRecordIterator readBatch() throws IOException { - if (!arrayIterator.hasNext()) { - return null; + // Try to load the next batch from the current ArrowReader + if (currentArrowReader != null && currentArrowReader.loadNextBatch()) { + return toBatchIterator(currentArrowReader.getVectorSchemaRoot()); + } + + // Close current reader and move to the next partition + closeCurrentArrowReader(); + + while (scan.hasNext()) { + closeCurrentPartition(); + currentPartition = scan.next(); + currentArrowReader = currentPartition.scanArrow(allocator); + if (currentArrowReader.loadNextBatch()) { + return toBatchIterator(currentArrowReader.getVectorSchemaRoot()); + } + closeCurrentArrowReader(); } - releaseCurrentArray(); - Array array = arrayIterator.next(); - this.currentArray = array; - VectorSchemaRoot vsr = array.exportToArrow(allocator, reuse); - this.reuse = vsr; + return null; + } + + private FileRecordIterator toBatchIterator(VectorSchemaRoot vsr) { Iterator rows = arrowBatchReader.readBatch(vsr).iterator(); ProjectedRow projectedRow = physicalFieldMapping == null ? 
null : ProjectedRow.from(physicalFieldMapping); @@ -143,27 +173,35 @@ public InternalRow next() { } @Override - public void releaseBatch() { - releaseCurrentArray(); - } + public void releaseBatch() {} }; } - private void releaseCurrentArray() { - if (currentArray != null) { - currentArray.close(); - currentArray = null; + private void closeCurrentArrowReader() { + if (currentArrowReader != null) { + try { + currentArrowReader.close(); + } catch (IOException e) { + // ignore + } + currentArrowReader = null; + } + } + + private void closeCurrentPartition() { + if (currentPartition != null) { + currentPartition.close(); + currentPartition = null; } } @Override public void close() { - releaseCurrentArray(); - if (reuse != null) { - reuse.close(); - } - arrayIterator.close(); - vortexFile.close(); + closeCurrentArrowReader(); + closeCurrentPartition(); + scan.close(); + dataSource.close(); + session.close(); allocator.close(); } diff --git a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsWriter.java b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsWriter.java index 6d8de8fcccae..7798fa09d419 100644 --- a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsWriter.java +++ b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexRecordsWriter.java @@ -19,54 +19,81 @@ package org.apache.paimon.format.vortex; import org.apache.paimon.arrow.ArrowBundleRecords; -import org.apache.paimon.arrow.ArrowUtils; -import org.apache.paimon.arrow.vector.ArrowFormatWriter; +import org.apache.paimon.arrow.vector.ArrowCStruct; +import org.apache.paimon.arrow.vector.ArrowFormatCWriter; import org.apache.paimon.data.InternalRow; import org.apache.paimon.format.BundleFormatWriter; import org.apache.paimon.fs.Path; import org.apache.paimon.io.BundleRecords; -import org.apache.paimon.types.RowType; -import dev.vortex.api.DType; 
+import dev.vortex.api.Session; import dev.vortex.api.VortexWriter; +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; -/** Vortex records writer. */ +/** Vortex records writer using the Arrow C Data Interface. */ public class VortexRecordsWriter implements BundleFormatWriter { private static final Logger LOG = LoggerFactory.getLogger(VortexRecordsWriter.class); private static final double COMPRESSION_RATIO = 0.25; - private final ArrowFormatWriter arrowFormatWriter; + private final Supplier cWriterSupplier; + private final Session session; private final VortexWriter nativeWriter; private final String path; + + // Vortex's writeBatch is semi-async: Rust takes zero-copy ownership of buffers + // via Arc and releases them after the background write completes. + // Each flush creates a new ArrowFormatCWriter (with its own RootAllocator) so + // buffers are never reused across batches. The Rust release callback frees most + // memory; only a small residual (~148 bytes per batch from Arrow 15's incomplete + // release) remains in each retained resource until nativeWriter.close(). 
+ private final List retainedResources; + private ArrowFormatCWriter currentWriter; + private long jniCost = 0; - private long ipcBytes = 0; + private long ffiBytes = 0; public VortexRecordsWriter( - RowType rowType, - ArrowFormatWriter arrowFormatWriter, + Supplier cWriterSupplier, Path path, Map storageOptions) throws IOException { - this.arrowFormatWriter = arrowFormatWriter; + this.cWriterSupplier = cWriterSupplier; this.path = path.toUri().toString(); + this.retainedResources = new ArrayList<>(); + this.currentWriter = cWriterSupplier.get(); - DType dtype = VortexTypeUtils.toDType(rowType); - this.nativeWriter = VortexWriter.create(this.path, dtype, storageOptions); + this.session = Session.create(); + try { + Schema arrowSchema = currentWriter.getVectorSchemaRoot().getSchema(); + this.nativeWriter = + VortexWriter.create(session, this.path, arrowSchema, storageOptions); + } catch (Exception e) { + session.close(); + throw e; + } } @Override public void addElement(InternalRow internalRow) throws IOException { - if (!arrowFormatWriter.write(internalRow)) { + if (!currentWriter.write(internalRow)) { flush(); - if (!arrowFormatWriter.write(internalRow)) { + if (!currentWriter.write(internalRow)) { throw new RuntimeException("Exception happens while write to vortex file"); } } @@ -76,7 +103,7 @@ public void addElement(InternalRow internalRow) throws IOException { public void writeBundle(BundleRecords bundleRecords) throws IOException { if (bundleRecords instanceof ArrowBundleRecords) { flush(); - writeVsr(((ArrowBundleRecords) bundleRecords).getVectorSchemaRoot()); + writeBundleVsr(((ArrowBundleRecords) bundleRecords).getVectorSchemaRoot()); } else { for (InternalRow row : bundleRecords) { addElement(row); @@ -86,7 +113,7 @@ public void writeBundle(BundleRecords bundleRecords) throws IOException { @Override public boolean reachTargetSize(boolean suggestedCheck, long targetSize) { - return suggestedCheck && (long) (ipcBytes * COMPRESSION_RATIO) >= targetSize; + 
return suggestedCheck && (long) (ffiBytes * COMPRESSION_RATIO) >= targetSize; } @Override @@ -100,14 +127,22 @@ public void close() throws IOException { throwable = t; } + // nativeWriter.close() blocks until all async background writes complete. try { nativeWriter.close(); } catch (Throwable t) { throwable = addSuppressed(throwable, t); } + // Release all retained resources now that async writes are done. + for (AutoCloseable res : retainedResources) { + closeQuietly(res); + } + retainedResources.clear(); + closeQuietly(currentWriter); + try { - arrowFormatWriter.close(); + session.close(); } catch (Throwable t) { throwable = addSuppressed(throwable, t); } @@ -121,22 +156,66 @@ public void close() throws IOException { } private void flush() throws IOException { + currentWriter.flush(); + if (!currentWriter.empty()) { + ffiBytes += bufferBytes(currentWriter.getVectorSchemaRoot()); + ArrowCStruct cStruct = currentWriter.toCStruct(); + long t1 = System.currentTimeMillis(); + nativeWriter.writeBatch(cStruct.arrayAddress(), cStruct.schemaAddress()); + jniCost += (System.currentTimeMillis() - t1); + // Each ArrowFormatCWriter has its own RootAllocator and buffers. + // Retain it so buffer memory stays alive for async Rust reads. + retainedResources.add(currentWriter); + currentWriter = cWriterSupplier.get(); + } + } + + /** Write an external VSR (from writeBundle) via IPC copy into an independent allocator. 
*/ + private void writeBundleVsr(VectorSchemaRoot vsr) throws IOException { + ffiBytes += bufferBytes(vsr); + byte[] ipc = org.apache.paimon.arrow.ArrowUtils.serializeToIpc(vsr); + RootAllocator bundleAllocator = new RootAllocator(Long.MAX_VALUE); try { - arrowFormatWriter.flush(); - if (!arrowFormatWriter.empty()) { - writeVsr(arrowFormatWriter.getVectorSchemaRoot()); - } - } finally { - arrowFormatWriter.reset(); + ArrowStreamReader ipcReader = + new ArrowStreamReader(new java.io.ByteArrayInputStream(ipc), bundleAllocator); + ipcReader.loadNextBatch(); + VectorSchemaRoot copy = ipcReader.getVectorSchemaRoot(); + + ArrowArray arrowArray = ArrowArray.allocateNew(bundleAllocator); + ArrowSchema arrowSchema = ArrowSchema.allocateNew(bundleAllocator); + Data.exportVectorSchemaRoot(bundleAllocator, copy, null, arrowArray, arrowSchema); + long t1 = System.currentTimeMillis(); + nativeWriter.writeBatch(arrowArray.memoryAddress(), arrowSchema.memoryAddress()); + jniCost += (System.currentTimeMillis() - t1); + // Retain all resources that own the exported C Data buffers and release + // callbacks. Rust holds async zero-copy references via Arc. + // Order matters: close ipcReader (owns VectorSchemaRoot) before allocator. + retainedResources.add(ipcReader); + retainedResources.add(bundleAllocator); + } catch (Exception e) { + closeQuietly(bundleAllocator); + throw e instanceof IOException ? 
(IOException) e : new IOException(e); } } - private void writeVsr(VectorSchemaRoot vsr) throws IOException { - byte[] bytes = ArrowUtils.serializeToIpc(vsr); - ipcBytes += bytes.length; - long t1 = System.currentTimeMillis(); - nativeWriter.writeBatch(bytes); - jniCost += (System.currentTimeMillis() - t1); + private static long bufferBytes(VectorSchemaRoot vsr) { + long bytes = 0; + for (int i = 0; i < vsr.getFieldVectors().size(); i++) { + bytes += vsr.getFieldVectors().get(i).getBufferSize(); + } + return bytes; + } + + private static void closeQuietly(AutoCloseable closeable) { + try { + closeable.close(); + } catch (IllegalStateException e) { + if (e.getMessage() == null || !e.getMessage().contains("Memory was leaked")) { + throw e; + } + } catch (Exception e) { + throw new RuntimeException(e); + } } private static Throwable addSuppressed(Throwable throwable, Throwable suppressed) { diff --git a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexTypeUtils.java b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexTypeUtils.java deleted file mode 100644 index 271c14d19647..000000000000 --- a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexTypeUtils.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.paimon.format.vortex; - -import org.apache.paimon.types.ArrayType; -import org.apache.paimon.types.BigIntType; -import org.apache.paimon.types.BinaryType; -import org.apache.paimon.types.BooleanType; -import org.apache.paimon.types.CharType; -import org.apache.paimon.types.DataField; -import org.apache.paimon.types.DataType; -import org.apache.paimon.types.DataTypeDefaultVisitor; -import org.apache.paimon.types.DateType; -import org.apache.paimon.types.DecimalType; -import org.apache.paimon.types.DoubleType; -import org.apache.paimon.types.FloatType; -import org.apache.paimon.types.IntType; -import org.apache.paimon.types.LocalZonedTimestampType; -import org.apache.paimon.types.RowType; -import org.apache.paimon.types.SmallIntType; -import org.apache.paimon.types.TimeType; -import org.apache.paimon.types.TimestampType; -import org.apache.paimon.types.TinyIntType; -import org.apache.paimon.types.VarBinaryType; -import org.apache.paimon.types.VarCharType; -import org.apache.paimon.types.VectorType; - -import dev.vortex.api.DType; - -import java.util.List; -import java.util.Optional; - -/** Utilities for converting Paimon types to Vortex DType. 
*/ -public class VortexTypeUtils { - - public static DType toDType(RowType rowType) { - // Vortex does not support nullable top-level structs - return toStructDType(rowType, false); - } - - private static DType toStructDType(RowType rowType, boolean isNullable) { - List fields = rowType.getFields(); - String[] fieldNames = new String[fields.size()]; - DType[] fieldTypes = new DType[fields.size()]; - for (int i = 0; i < fields.size(); i++) { - DataField field = fields.get(i); - fieldNames[i] = field.name(); - fieldTypes[i] = field.type().accept(new VortexTypeVisitor()); - } - return DType.newStruct(fieldNames, fieldTypes, isNullable); - } - - private static class VortexTypeVisitor extends DataTypeDefaultVisitor { - - @Override - public DType visit(TinyIntType tinyIntType) { - return DType.newByte(tinyIntType.isNullable()); - } - - @Override - public DType visit(SmallIntType smallIntType) { - return DType.newShort(smallIntType.isNullable()); - } - - @Override - public DType visit(IntType intType) { - return DType.newInt(intType.isNullable()); - } - - @Override - public DType visit(BigIntType bigIntType) { - return DType.newLong(bigIntType.isNullable()); - } - - @Override - public DType visit(FloatType floatType) { - return DType.newFloat(floatType.isNullable()); - } - - @Override - public DType visit(DoubleType doubleType) { - return DType.newDouble(doubleType.isNullable()); - } - - @Override - public DType visit(CharType charType) { - return DType.newUtf8(charType.isNullable()); - } - - @Override - public DType visit(VarCharType varCharType) { - return DType.newUtf8(varCharType.isNullable()); - } - - @Override - public DType visit(BooleanType booleanType) { - return DType.newBool(booleanType.isNullable()); - } - - @Override - public DType visit(BinaryType binaryType) { - return DType.newBinary(binaryType.isNullable()); - } - - @Override - public DType visit(VarBinaryType varBinaryType) { - return DType.newBinary(varBinaryType.isNullable()); - } - - @Override - 
public DType visit(DecimalType decimalType) { - return DType.newDecimal( - decimalType.getPrecision(), decimalType.getScale(), decimalType.isNullable()); - } - - @Override - public DType visit(DateType dateType) { - return DType.newDate(DType.TimeUnit.DAYS, dateType.isNullable()); - } - - @Override - public DType visit(TimeType timeType) { - return DType.newTime(DType.TimeUnit.MILLISECONDS, timeType.isNullable()); - } - - @Override - public DType visit(TimestampType timestampType) { - DType.TimeUnit unit; - int precision = timestampType.getPrecision(); - if (precision <= 0) { - unit = DType.TimeUnit.SECONDS; - } else if (precision <= 3) { - unit = DType.TimeUnit.MILLISECONDS; - } else if (precision <= 6) { - unit = DType.TimeUnit.MICROSECONDS; - } else { - unit = DType.TimeUnit.NANOSECONDS; - } - return DType.newTimestamp(unit, Optional.empty(), timestampType.isNullable()); - } - - @Override - public DType visit(LocalZonedTimestampType lzTimestampType) { - DType.TimeUnit unit; - int precision = lzTimestampType.getPrecision(); - if (precision <= 0) { - unit = DType.TimeUnit.SECONDS; - } else if (precision <= 3) { - unit = DType.TimeUnit.MILLISECONDS; - } else if (precision <= 6) { - unit = DType.TimeUnit.MICROSECONDS; - } else { - unit = DType.TimeUnit.NANOSECONDS; - } - return DType.newTimestamp(unit, Optional.of("UTC"), lzTimestampType.isNullable()); - } - - @Override - public DType visit(ArrayType arrayType) { - DType elementType = arrayType.getElementType().accept(this); - return DType.newList(elementType, arrayType.isNullable()); - } - - @Override - public DType visit(VectorType vectorType) { - DType elementType = vectorType.getElementType().accept(this); - return DType.newFixedSizeList( - elementType, vectorType.getLength(), vectorType.isNullable()); - } - - @Override - public DType visit(RowType rowType) { - return VortexTypeUtils.toStructDType(rowType, rowType.isNullable()); - } - - @Override - protected DType defaultMethod(DataType dataType) { - throw new 
UnsupportedOperationException( - "Vortex does not support type: " + dataType.asSQLString()); - } - } -} diff --git a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexWriterFactory.java b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexWriterFactory.java index f3de8fca7435..e3840acd366c 100644 --- a/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexWriterFactory.java +++ b/paimon-vortex/paimon-vortex-format/src/main/java/org/apache/paimon/format/vortex/VortexWriterFactory.java @@ -18,14 +18,13 @@ package org.apache.paimon.format.vortex; -import org.apache.paimon.arrow.vector.ArrowFormatWriter; +import org.apache.paimon.arrow.vector.ArrowFormatCWriter; import org.apache.paimon.format.FormatWriter; import org.apache.paimon.format.FormatWriterFactory; import org.apache.paimon.format.SupportsDirectWrite; import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.PositionOutputStream; -import org.apache.paimon.types.RowType; import org.apache.paimon.utils.Pair; import java.io.IOException; @@ -37,13 +36,10 @@ /** A factory to create Vortex {@link FormatWriter}. 
*/ public class VortexWriterFactory implements FormatWriterFactory, SupportsDirectWrite { - private final RowType rowType; - private final Supplier arrowFormatWriterSupplier; + private final Supplier cWriterSupplier; - public VortexWriterFactory( - RowType rowType, Supplier arrowFormatWriterSupplier) { - this.rowType = rowType; - this.arrowFormatWriterSupplier = arrowFormatWriterSupplier; + public VortexWriterFactory(Supplier cWriterSupplier) { + this.cWriterSupplier = cWriterSupplier; } @Override @@ -57,9 +53,6 @@ public FormatWriter create(PositionOutputStream positionOutputStream, String com public FormatWriter create(FileIO fileIO, Path path, String compression) throws IOException { Pair> vortexSpecified = toVortexSpecifiedForWriter(fileIO, path); return new VortexRecordsWriter( - rowType, - arrowFormatWriterSupplier.get(), - vortexSpecified.getLeft(), - vortexSpecified.getRight()); + cWriterSupplier, vortexSpecified.getLeft(), vortexSpecified.getRight()); } } diff --git a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexFileFormatReadWriteTest.java b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexFileFormatReadWriteTest.java index a270d2e92f0d..a65916bb6cb2 100644 --- a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexFileFormatReadWriteTest.java +++ b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexFileFormatReadWriteTest.java @@ -30,14 +30,10 @@ import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; -import dev.vortex.api.DType; -import org.junit.jupiter.api.BeforeAll; - import java.math.BigDecimal; import java.util.concurrent.ThreadLocalRandom; import static org.apache.paimon.data.BinaryString.fromString; -import static org.junit.jupiter.api.Assumptions.assumeTrue; /** Test for Vortex file format read/write using the base test framework. 
*/ public class VortexFileFormatReadWriteTest extends FormatReadWriteTest { @@ -46,29 +42,6 @@ protected VortexFileFormatReadWriteTest() { super("vortex"); } - @BeforeAll - static void checkNativeLibrary() { - assumeTrue(isNativeAvailable(), "Vortex native library not available, skipping tests"); - } - - private static boolean isNativeAvailable() { - try { - dev.vortex.jni.NativeLoader.loadJni(); - return true; - } catch (Throwable t) { - return false; - } - } - - private static boolean isFixedSizeListSupported() { - try { - DType.newFixedSizeList(DType.newInt(false), 2, false); - return true; - } catch (Throwable t) { - return false; - } - } - @Override public boolean supportNestedReadPruning() { return false; diff --git a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexPredicateConverterTest.java b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexPredicateConverterTest.java index 1b8965911047..b2f51ef20f8d 100644 --- a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexPredicateConverterTest.java +++ b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexPredicateConverterTest.java @@ -18,28 +18,50 @@ package org.apache.paimon.format.vortex; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; import org.apache.paimon.data.Timestamp; +import org.apache.paimon.data.serializer.InternalRowSerializer; +import org.apache.paimon.format.FileFormatFactory; +import org.apache.paimon.format.FormatReaderContext; +import org.apache.paimon.format.FormatReaderFactory; +import org.apache.paimon.format.FormatWriter; +import org.apache.paimon.format.SupportsDirectWrite; +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import 
org.apache.paimon.options.Options; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.reader.RecordReaderIterator; import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; import dev.vortex.api.Expression; -import dev.vortex.api.expressions.Binary; -import dev.vortex.api.expressions.GetItem; -import dev.vortex.api.expressions.Literal; -import dev.vortex.api.expressions.Not; -import dev.vortex.api.expressions.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.Optional; +import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; -/** Tests for {@link VortexPredicateConverter}. */ +/** + * Tests for {@link VortexPredicateConverter}. + * + *

    Each test verifies two things: (1) the converted expression has a valid native pointer, and + * (2) the predicate produces correct results when used in a write-read round-trip through the + * Vortex format. + */ public class VortexPredicateConverterTest { private static final RowType ROW_TYPE = @@ -51,112 +73,99 @@ public class VortexPredicateConverterTest { private static final PredicateBuilder BUILDER = new PredicateBuilder(ROW_TYPE); - private static Expression field(String name) { - return GetItem.of(Root.INSTANCE, name); + private static void assertValidExpression(Expression expr) { + assertNotNull(expr, "Expression should not be null"); + assertTrue(expr.nativePointer() != 0, "Expression native pointer should be non-zero"); } + // -- Predicate conversion + native pointer validity tests -- + @Test public void testEqual() { - Predicate predicate = BUILDER.equal(0, 42); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.eq(field("f_int"), Literal.int32(42)), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.equal(0, 42))); + assertValidExpression(result); } @Test public void testNotEqual() { - Predicate predicate = BUILDER.notEqual(1, 100L); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.notEq(field("f_bigint"), Literal.int64(100L)), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.notEqual(1, 100L))); + assertValidExpression(result); } @Test public void testGreaterThan() { - Predicate predicate = BUILDER.greaterThan(0, 10); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.gt(field("f_int"), Literal.int32(10)), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.greaterThan(0, 10))); + 
assertValidExpression(result); } @Test public void testGreaterOrEqual() { - Predicate predicate = BUILDER.greaterOrEqual(0, 10); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.gtEq(field("f_int"), Literal.int32(10)), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.greaterOrEqual(0, 10))); + assertValidExpression(result); } @Test public void testLessThan() { - Predicate predicate = BUILDER.lessThan(1, 50L); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.lt(field("f_bigint"), Literal.int64(50L)), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.lessThan(1, 50L))); + assertValidExpression(result); } @Test public void testLessOrEqual() { - Predicate predicate = BUILDER.lessOrEqual(1, 50L); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.ltEq(field("f_bigint"), Literal.int64(50L)), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.lessOrEqual(1, 50L))); + assertValidExpression(result); } @Test public void testIsNull() { - Predicate predicate = BUILDER.isNull(0); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Not.of(Binary.notEq(field("f_int"), Literal.nullLit())), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList(BUILDER.isNull(0))); + assertValidExpression(result); } @Test public void testIsNotNull() { - Predicate predicate = BUILDER.isNotNull(0); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals(Binary.notEq(field("f_int"), Literal.nullLit()), result); + VortexPredicateConverter.toVortexExpression( + 
Collections.singletonList(BUILDER.isNotNull(0))); + assertValidExpression(result); } @Test public void testAnd() { - Predicate p1 = BUILDER.greaterThan(0, 10); - Predicate p2 = BUILDER.lessThan(0, 100); - Predicate and = PredicateBuilder.and(p1, p2); + Predicate and = PredicateBuilder.and(BUILDER.greaterThan(0, 10), BUILDER.lessThan(0, 100)); Expression result = VortexPredicateConverter.toVortexExpression(Collections.singletonList(and)); - assertEquals( - Binary.and( - Binary.gt(field("f_int"), Literal.int32(10)), - Binary.lt(field("f_int"), Literal.int32(100))), - result); + assertValidExpression(result); } @Test public void testOr() { - Predicate p1 = BUILDER.equal(0, 1); - Predicate p2 = BUILDER.equal(0, 2); - Predicate or = PredicateBuilder.or(p1, p2); + Predicate or = PredicateBuilder.or(BUILDER.equal(0, 1), BUILDER.equal(0, 2)); Expression result = VortexPredicateConverter.toVortexExpression(Collections.singletonList(or)); - assertEquals( - Binary.or( - Binary.eq(field("f_int"), Literal.int32(1)), - Binary.eq(field("f_int"), Literal.int32(2))), - result); + assertValidExpression(result); } @Test public void testMultiplePredicatesAsAnd() { - Predicate p1 = BUILDER.greaterThan(0, 5); - Predicate p2 = BUILDER.lessThan(1, 200L); - Expression result = VortexPredicateConverter.toVortexExpression(Arrays.asList(p1, p2)); - assertEquals( - Binary.and( - Binary.gt(field("f_int"), Literal.int32(5)), - Binary.lt(field("f_bigint"), Literal.int64(200L))), - result); + Expression result = + VortexPredicateConverter.toVortexExpression( + Arrays.asList(BUILDER.greaterThan(0, 5), BUILDER.lessThan(1, 200L))); + assertValidExpression(result); } @Test @@ -171,74 +180,221 @@ public void testEmptyPredicates() { @Test public void testStringLiteral() { - Predicate predicate = - BUILDER.equal(2, org.apache.paimon.data.BinaryString.fromString("hello")); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - 
assertEquals(Binary.eq(field("f_string"), Literal.string("hello")), result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList( + BUILDER.equal(2, BinaryString.fromString("hello")))); + assertValidExpression(result); + } + + @Test + public void testDecimalLiteral() { + RowType decRowType = RowType.builder().field("f_dec", DataTypes.DECIMAL(10, 2)).build(); + PredicateBuilder decBuilder = new PredicateBuilder(decRowType); + Expression result = + VortexPredicateConverter.toVortexExpression( + Collections.singletonList( + decBuilder.equal( + 0, + Decimal.fromBigDecimal(new BigDecimal("123.45"), 10, 2)))); + assertValidExpression(result); } @Test public void testTimestampMillisPrecision() { - // TIMESTAMP(3) should produce timestampMillis RowType tsRowType = RowType.builder().field("f_ts", DataTypes.TIMESTAMP(3)).build(); PredicateBuilder tsBuilder = new PredicateBuilder(tsRowType); - Timestamp ts = Timestamp.fromEpochMillis(123456789L); - Predicate predicate = tsBuilder.equal(0, ts); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals( - Binary.eq(field("f_ts"), Literal.timestampMillis(123456789L, Optional.empty())), - result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList( + tsBuilder.equal(0, Timestamp.fromEpochMillis(123456789L)))); + assertValidExpression(result); } @Test public void testTimestampMicrosPrecision() { - // TIMESTAMP(6) should produce timestampMicros RowType tsRowType = RowType.builder().field("f_ts", DataTypes.TIMESTAMP(6)).build(); PredicateBuilder tsBuilder = new PredicateBuilder(tsRowType); - // 123456789 millis + 123000 nanos = 123456789_123 micros - Timestamp ts = Timestamp.fromMicros(123456789123L); - Predicate predicate = tsBuilder.equal(0, ts); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals( - Binary.eq(field("f_ts"), 
Literal.timestampMicros(123456789123L, Optional.empty())), - result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList( + tsBuilder.equal(0, Timestamp.fromMicros(123456789123L)))); + assertValidExpression(result); } @Test public void testTimestampNanosPrecision() { - // TIMESTAMP(9) should produce timestampNanos RowType tsRowType = RowType.builder().field("f_ts", DataTypes.TIMESTAMP(9)).build(); PredicateBuilder tsBuilder = new PredicateBuilder(tsRowType); - Timestamp ts = Timestamp.fromEpochMillis(123456L, 789012); - Predicate predicate = tsBuilder.equal(0, ts); Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - // 123456 ms * 1_000_000 + 789012 nanos = 123456000789012 nanos - assertEquals( - Binary.eq( - field("f_ts"), - Literal.timestampNanos(123456L * 1_000_000 + 789012, Optional.empty())), - result); + VortexPredicateConverter.toVortexExpression( + Collections.singletonList( + tsBuilder.equal(0, Timestamp.fromEpochMillis(123456L, 789012)))); + assertValidExpression(result); } @Test public void testTimestampWithLocalTimeZone() { - // TIMESTAMP_LTZ(3) should produce timestampMillis with UTC RowType tsRowType = RowType.builder() .field("f_ts_ltz", DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)) .build(); PredicateBuilder tsBuilder = new PredicateBuilder(tsRowType); - Timestamp ts = Timestamp.fromEpochMillis(123456789L); - Predicate predicate = tsBuilder.equal(0, ts); - Expression result = - VortexPredicateConverter.toVortexExpression(Collections.singletonList(predicate)); - assertEquals( - Binary.eq( - field("f_ts_ltz"), Literal.timestampMillis(123456789L, Optional.of("UTC"))), - result); + Expression result = + VortexPredicateConverter.toVortexExpression( + Collections.singletonList( + tsBuilder.equal(0, Timestamp.fromEpochMillis(123456789L)))); + assertValidExpression(result); + } + + // -- Semantic round-trip tests: write data, read with predicate, verify filtering -- + + @Test + 
public void testEqualSemantic(@TempDir java.nio.file.Path tempDir) throws Exception { + // f_int == 2 should return only the row with f_int=2 + List rows = + roundTrip( + tempDir, + ROW_TYPE, + new GenericRow[] { + GenericRow.of(1, 10L, BinaryString.fromString("a")), + GenericRow.of(2, 20L, BinaryString.fromString("b")), + GenericRow.of(3, 30L, BinaryString.fromString("c")) + }, + Collections.singletonList(BUILDER.equal(0, 2))); + assertEquals(1, rows.size()); + assertEquals(2, rows.get(0).getInt(0)); + } + + @Test + public void testGreaterThanSemantic(@TempDir java.nio.file.Path tempDir) throws Exception { + // f_int > 2 should return rows with f_int=3,4 + List rows = + roundTrip( + tempDir, + ROW_TYPE, + new GenericRow[] { + GenericRow.of(1, 10L, BinaryString.fromString("a")), + GenericRow.of(2, 20L, BinaryString.fromString("b")), + GenericRow.of(3, 30L, BinaryString.fromString("c")), + GenericRow.of(4, 40L, BinaryString.fromString("d")) + }, + Collections.singletonList(BUILDER.greaterThan(0, 2))); + assertEquals(2, rows.size()); + assertEquals(3, rows.get(0).getInt(0)); + assertEquals(4, rows.get(1).getInt(0)); + } + + @Test + public void testGreaterOrEqualSemantic(@TempDir java.nio.file.Path tempDir) throws Exception { + // f_int >= 3 should return rows with f_int=3,4 + List rows = + roundTrip( + tempDir, + ROW_TYPE, + new GenericRow[] { + GenericRow.of(1, 10L, BinaryString.fromString("a")), + GenericRow.of(2, 20L, BinaryString.fromString("b")), + GenericRow.of(3, 30L, BinaryString.fromString("c")), + GenericRow.of(4, 40L, BinaryString.fromString("d")) + }, + Collections.singletonList(BUILDER.greaterOrEqual(0, 3))); + assertEquals(2, rows.size()); + assertEquals(3, rows.get(0).getInt(0)); + assertEquals(4, rows.get(1).getInt(0)); + } + + @Test + public void testAndSemantic(@TempDir java.nio.file.Path tempDir) throws Exception { + // f_int > 1 AND f_int < 4 should return rows 2,3 + Predicate and = PredicateBuilder.and(BUILDER.greaterThan(0, 1), 
BUILDER.lessThan(0, 4)); + List rows = + roundTrip( + tempDir, + ROW_TYPE, + new GenericRow[] { + GenericRow.of(1, 10L, BinaryString.fromString("a")), + GenericRow.of(2, 20L, BinaryString.fromString("b")), + GenericRow.of(3, 30L, BinaryString.fromString("c")), + GenericRow.of(4, 40L, BinaryString.fromString("d")) + }, + Collections.singletonList(and)); + assertEquals(2, rows.size()); + assertEquals(2, rows.get(0).getInt(0)); + assertEquals(3, rows.get(1).getInt(0)); + } + + @Test + public void testOrSemantic(@TempDir java.nio.file.Path tempDir) throws Exception { + // f_int == 1 OR f_int == 4 should return rows 1,4 + Predicate or = PredicateBuilder.or(BUILDER.equal(0, 1), BUILDER.equal(0, 4)); + List rows = + roundTrip( + tempDir, + ROW_TYPE, + new GenericRow[] { + GenericRow.of(1, 10L, BinaryString.fromString("a")), + GenericRow.of(2, 20L, BinaryString.fromString("b")), + GenericRow.of(3, 30L, BinaryString.fromString("c")), + GenericRow.of(4, 40L, BinaryString.fromString("d")) + }, + Collections.singletonList(or)); + assertEquals(2, rows.size()); + assertEquals(1, rows.get(0).getInt(0)); + assertEquals(4, rows.get(1).getInt(0)); + } + + @Test + public void testStringEqualSemantic(@TempDir java.nio.file.Path tempDir) throws Exception { + // f_string == "hello" + List rows = + roundTrip( + tempDir, + ROW_TYPE, + new GenericRow[] { + GenericRow.of(1, 10L, BinaryString.fromString("hello")), + GenericRow.of(2, 20L, BinaryString.fromString("world")), + }, + Collections.singletonList( + BUILDER.equal(2, BinaryString.fromString("hello")))); + assertEquals(1, rows.size()); + assertEquals(BinaryString.fromString("hello"), rows.get(0).getString(2)); + } + + private List roundTrip( + java.nio.file.Path tempDir, + RowType rowType, + GenericRow[] data, + List predicates) + throws Exception { + Options options = new Options(); + VortexFileFormat format = + new VortexFileFormatFactory() + .create(new FileFormatFactory.FormatContext(options, 1024, 1024)); + FileIO fileIO = new 
LocalFileIO(); + Path testFile = new Path(new Path(tempDir.toUri()), "predicate_test_" + System.nanoTime()); + + try (FormatWriter writer = + ((SupportsDirectWrite) format.createWriterFactory(rowType)) + .create(fileIO, testFile, "")) { + for (GenericRow row : data) { + writer.addElement(row); + } + } + + InternalRowSerializer serializer = new InternalRowSerializer(rowType); + FormatReaderFactory readerFactory = + format.createReaderFactory(rowType, rowType, predicates); + try (RecordReader reader = + readerFactory.createReader( + new FormatReaderContext( + fileIO, testFile, fileIO.getFileSize(testFile), null)); + RecordReaderIterator iterator = new RecordReaderIterator<>(reader)) { + List result = new ArrayList<>(); + while (iterator.hasNext()) { + result.add(serializer.copy(iterator.next())); + } + return result; + } } } diff --git a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexReaderWriterTest.java b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexReaderWriterTest.java index 85230da5ff09..7ccca17c09e6 100644 --- a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexReaderWriterTest.java +++ b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexReaderWriterTest.java @@ -44,7 +44,6 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.utils.RoaringBitmap32; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -58,25 +57,10 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assumptions.assumeTrue; /** Test read write for Vortex file format. 
*/ public class VortexReaderWriterTest { - @BeforeAll - static void checkNativeLibrary() { - assumeTrue(isNativeAvailable(), "Vortex native library not available, skipping tests"); - } - - private static boolean isNativeAvailable() { - try { - dev.vortex.jni.NativeLoader.loadJni(); - return true; - } catch (Throwable t) { - return false; - } - } - @Test public void testWriteAndRead(@TempDir java.nio.file.Path tempDir) throws Exception { RowType rowType = RowType.of(DataTypes.INT(), DataTypes.STRING()); @@ -122,6 +106,56 @@ public void testWriteAndRead(@TempDir java.nio.file.Path tempDir) throws Excepti } } + @Test + public void testReadWithColumnProjection(@TempDir java.nio.file.Path tempDir) throws Exception { + RowType fullRowType = + RowType.builder() + .field("f_int", DataTypes.INT()) + .field("f_string", DataTypes.STRING()) + .field("f_double", DataTypes.DOUBLE()) + .build(); + + Options options = new Options(); + VortexFileFormat format = + new VortexFileFormatFactory() + .create(new FileFormatFactory.FormatContext(options, 1024, 1024)); + + FileIO fileIO = new LocalFileIO(); + Path testFile = new Path(new Path(tempDir.toUri()), "test_projection_" + UUID.randomUUID()); + + // Write 3 columns + try (FormatWriter writer = + ((SupportsDirectWrite) format.createWriterFactory(fullRowType)) + .create(fileIO, testFile, "")) { + writer.addElement(GenericRow.of(1, BinaryString.fromString("hello"), 1.5D)); + writer.addElement(GenericRow.of(2, BinaryString.fromString("world"), 2.5D)); + } + + InternalRowSerializer serializer; + + // Read only f_string column + RowType projectedRowType = RowType.builder().field("f_string", DataTypes.STRING()).build(); + serializer = new InternalRowSerializer(projectedRowType); + FormatReaderFactory readerFactory = + format.createReaderFactory(fullRowType, projectedRowType, null); + try (RecordReader reader = + readerFactory.createReader( + new FormatReaderContext( + fileIO, testFile, fileIO.getFileSize(testFile), null)); + 
RecordReaderIterator iterator = new RecordReaderIterator<>(reader)) { + + List actualRows = new ArrayList<>(); + while (iterator.hasNext()) { + actualRows.add(serializer.copy(iterator.next())); + } + + assertEquals(2, actualRows.size()); + assertEquals(1, actualRows.get(0).getFieldCount()); + assertEquals(BinaryString.fromString("hello"), actualRows.get(0).getString(0)); + assertEquals(BinaryString.fromString("world"), actualRows.get(1).getString(0)); + } + } + @Test public void testWriteAndReadMultipleTypes(@TempDir java.nio.file.Path tempDir) throws Exception { diff --git a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexTypeUtilsTest.java b/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexTypeUtilsTest.java deleted file mode 100644 index 8a06c273e545..000000000000 --- a/paimon-vortex/paimon-vortex-format/src/test/java/org/apache/paimon/format/vortex/VortexTypeUtilsTest.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.paimon.format.vortex; - -import org.apache.paimon.types.DataTypes; -import org.apache.paimon.types.RowType; - -import dev.vortex.api.DType; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assumptions.assumeTrue; - -/** Test for {@link VortexTypeUtils}. */ -public class VortexTypeUtilsTest { - - @BeforeAll - static void checkNativeLibrary() { - assumeTrue(isNativeAvailable(), "Vortex native library not available, skipping tests"); - } - - private static boolean isNativeAvailable() { - try { - dev.vortex.jni.NativeLoader.loadJni(); - return true; - } catch (Throwable t) { - return false; - } - } - - @Test - public void testSimpleTypes() { - RowType rowType = - RowType.builder() - .field("f_int", DataTypes.INT()) - .field("f_bigint", DataTypes.BIGINT()) - .field("f_string", DataTypes.STRING()) - .field("f_boolean", DataTypes.BOOLEAN()) - .field("f_double", DataTypes.DOUBLE()) - .field("f_float", DataTypes.FLOAT()) - .build(); - - DType dtype = VortexTypeUtils.toDType(rowType); - assertEquals(DType.Variant.STRUCT, dtype.getVariant()); - assertEquals(6, dtype.getFieldNames().size()); - assertEquals("f_int", dtype.getFieldNames().get(0)); - assertEquals("f_bigint", dtype.getFieldNames().get(1)); - assertEquals("f_string", dtype.getFieldNames().get(2)); - assertEquals("f_boolean", dtype.getFieldNames().get(3)); - assertEquals("f_double", dtype.getFieldNames().get(4)); - assertEquals("f_float", dtype.getFieldNames().get(5)); - } - - @Test - public void testDecimalType() { - RowType rowType = RowType.builder().field("f_decimal", DataTypes.DECIMAL(10, 2)).build(); - - DType dtype = VortexTypeUtils.toDType(rowType); - DType decimalField = dtype.getFieldTypes().get(0); - assertEquals(DType.Variant.DECIMAL, decimalField.getVariant()); - assertEquals(10, 
decimalField.getPrecision()); - assertEquals(2, decimalField.getScale()); - } - - @Test - public void testArrayType() { - RowType rowType = - RowType.builder().field("f_array", DataTypes.ARRAY(DataTypes.INT())).build(); - - DType dtype = VortexTypeUtils.toDType(rowType); - DType arrayField = dtype.getFieldTypes().get(0); - assertEquals(DType.Variant.LIST, arrayField.getVariant()); - } - - @Test - public void testNestedRowType() { - RowType innerType = - RowType.builder() - .field("inner_int", DataTypes.INT()) - .field("inner_str", DataTypes.STRING()) - .build(); - RowType rowType = RowType.builder().field("f_row", innerType).build(); - - DType dtype = VortexTypeUtils.toDType(rowType); - DType rowField = dtype.getFieldTypes().get(0); - assertEquals(DType.Variant.STRUCT, rowField.getVariant()); - assertEquals(2, rowField.getFieldNames().size()); - assertEquals("inner_int", rowField.getFieldNames().get(0)); - assertEquals("inner_str", rowField.getFieldNames().get(1)); - } - - @Test - public void testUnsupportedMapType() { - RowType rowType = - RowType.builder() - .field("f_map", DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())) - .build(); - - assertThrows(UnsupportedOperationException.class, () -> VortexTypeUtils.toDType(rowType)); - } - - @Test - public void testNullability() { - RowType rowType = - RowType.builder() - .field("f_nullable", DataTypes.INT().nullable()) - .field("f_not_null", DataTypes.INT().notNull()) - .build(); - - DType dtype = VortexTypeUtils.toDType(rowType); - assertEquals(true, dtype.getFieldTypes().get(0).isNullable()); - assertEquals(false, dtype.getFieldTypes().get(1).isNullable()); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/README.md b/paimon-vortex/paimon-vortex-jni/README.md index a8ba77a471f4..f9c3e5017407 100644 --- a/paimon-vortex/paimon-vortex-jni/README.md +++ b/paimon-vortex/paimon-vortex-jni/README.md @@ -4,76 +4,40 @@ This module provides Java JNI bindings for the [Vortex](https://github.com/spira ## Source Code Origin 
-The Java source code under `src/main/java/dev/vortex/` is copied from the upstream +The Java source code under `src/main/java/dev/vortex/` is adapted from the upstream [vortex-jni](https://github.com/spiraldb/vortex/tree/develop/java/vortex-jni) module (Apache License 2.0). The `dev.vortex.jni` package name is preserved to match the JNI native method signatures in the pre-compiled Rust library. Key adaptations from upstream: - Build system changed from Gradle to Maven. -- Protobuf `.proto` files copied to `src/main/proto/` and compiled via `protobuf-maven-plugin`. -- Target Java version set to 1.8 for Paimon compatibility. +- Target Java version set to 1.8 for Paimon compatibility (upstream requires JDK 17). +- `java.lang.ref.Cleaner` (JDK 9+) replaced with explicit `close()` via `AutoCloseable`. +- Protobuf-based expression serialization replaced with native pointer-based `Expression` API. +- Arrow memory allocation uses `Unsafe` allocator (via `arrow-memory-unsafe`) to guarantee + buffer alignment required by Vortex's Rust FFI. -## Building the Native Library +## Native Library -The module requires a platform-specific native library (`libvortex_jni`) built from the -Vortex Rust project. This library is **not** published to Maven Central and must be built -from source. +The native library (`libvortex_jni`) is automatically extracted from the published +[`dev.vortex:vortex-jni`](https://mvnrepository.com/artifact/dev.vortex/vortex-jni) +Maven artifact during the `generate-resources` phase via `maven-dependency-plugin:unpack`. -### Prerequisites - -- **Rust toolchain**: Install via [rustup](https://rustup.rs/) - ```bash - curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh - ``` -- **C compiler**: `gcc` (Linux) or Xcode Command Line Tools (macOS). - -### Build Steps - -```bash -# 1. Clone the Vortex repository -git clone https://github.com/spiraldb/vortex.git -cd vortex - -# 2. 
Build the native library -cargo build --package vortex-jni # debug build -cargo build --package vortex-jni --release # release build (optimized) - -# 3. Copy to resources (example for macOS Apple Silicon) -mkdir -p /path/to/paimon-vortex-jni/src/main/resources/native/darwin-aarch64 -cp target/debug/libvortex_jni.dylib \ - /path/to/paimon-vortex-jni/src/main/resources/native/darwin-aarch64/ -``` - -Platform-specific paths: +The published jar includes pre-built binaries for: | Platform | Directory | Library file | |-----------------------|----------------------|-------------------------| | macOS Apple Silicon | `darwin-aarch64` | `libvortex_jni.dylib` | -| macOS Intel | `darwin-x86_64` | `libvortex_jni.dylib` | | Linux x86_64 | `linux-amd64` | `libvortex_jni.so` | | Linux aarch64 | `linux-aarch64` | `libvortex_jni.so` | -You only need to provide the library for your current platform. - -### Expected Resources Layout - -``` -src/main/resources/native/ -├── darwin-aarch64/ -│ └── libvortex_jni.dylib -├── darwin-x86_64/ -│ └── libvortex_jni.dylib -├── linux-amd64/ -│ └── libvortex_jni.so -└── linux-aarch64/ - └── libvortex_jni.so -``` +No Rust toolchain or manual build is required. ## Verification ```bash -mvn test -pl paimon-vortex/paimon-vortex-jni -Dcheckstyle.skip=true -Dspotless.check.skip=true +mvn test -pl paimon-vortex/paimon-vortex-jni,paimon-vortex/paimon-vortex-format \ + -Dcheckstyle.skip=true -Dspotless.check.skip=true ``` Tests that require the native library use `assumeTrue` to check availability at runtime. @@ -83,7 +47,6 @@ If the library is not found, those tests will be skipped rather than fail. | Problem | Solution | |---------|----------| -| `UnsatisfiedLinkError` | Ensure the `.dylib`/`.so` is in the correct `native/{os}-{arch}/` directory. | -| `FileNotFoundException: Library not found` | Check `os.name` and `os.arch` system properties match the directory name. | -| Rust compilation fails | Ensure Rust toolchain is up-to-date: `rustup update`. 
| -| Cross-compilation needed | Use `cargo build --target `, e.g. `x86_64-unknown-linux-gnu`. | +| `UnsatisfiedLinkError` | Ensure `maven-dependency-plugin:unpack` ran (check `target/classes/native/`). | +| `FileNotFoundException: Library not found` | Check `os.name` and `os.arch` system properties match a supported platform. | +| `Memory pointer ... not aligned` | Ensure `arrow.allocation.manager.type=Unsafe` is set (done automatically by `NativeLoader`). | diff --git a/paimon-vortex/paimon-vortex-jni/pom.xml b/paimon-vortex/paimon-vortex-jni/pom.xml index 913741aa30e8..c6ca7a8d6e8e 100644 --- a/paimon-vortex/paimon-vortex-jni/pom.xml +++ b/paimon-vortex/paimon-vortex-jni/pom.xml @@ -33,8 +33,7 @@ under the License. 1.8 - 0.69.0 - 3.25.5 + 0.73.0 2.10.1 true @@ -64,15 +63,14 @@ under the License. org.apache.arrow - arrow-vector + arrow-memory-unsafe ${arrow.version} - - com.google.protobuf - protobuf-java - ${protobuf.version} + org.apache.arrow + arrow-vector + ${arrow.version} @@ -101,35 +99,30 @@ under the License. - - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.6.1 - - com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} - ${project.basedir}/src/main/proto - - - - - compile - - - - - - + - kr.motd.maven - os-maven-plugin - 1.7.1 + org.apache.maven.plugins + maven-dependency-plugin - initialize + unpack-vortex-native + generate-resources - detect + unpack + + + + dev.vortex + vortex-jni + ${vortex.jni.version} + jar + true + + + native/** + ${project.build.directory}/classes + diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Array.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Array.java deleted file mode 100644 index 369339215cd0..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Array.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api; - -import java.math.BigDecimal; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -/** Vortex array interface for accessing columnar data. */ -public interface Array extends AutoCloseable { - long getLen(); - - long nbytes(); - - VectorSchemaRoot exportToArrow(BufferAllocator allocator, VectorSchemaRoot reuse); - - DType getDataType(); - - Array getField(int index); - - Array slice(int start, int stop); - - boolean getNull(int index); - - int getNullCount(); - - byte getByte(int index); - - short getShort(int index); - - int getInt(int index); - - long getLong(int index); - - boolean getBool(int index); - - float getFloat(int index); - - double getDouble(int index); - - BigDecimal getBigDecimal(int index); - - String getUTF8(int index); - - void getUTF8_ptr_len(int index, long[] ptr, int[] len); - - byte[] getBinary(int index); - - @Override - void close(); -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DType.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DType.java deleted file mode 100644 index 1530595b440f..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DType.java +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Licensed 
to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api; - -import dev.vortex.jni.JNIDType; -import dev.vortex.jni.NativeDTypeMethods; -import java.util.List; -import java.util.Optional; - -/** Vortex logical type interface representing the schema and metadata for array data. 
*/ -public interface DType extends AutoCloseable { - - Variant getVariant(); - - boolean isNullable(); - - List getFieldNames(); - - List getFieldTypes(); - - DType getElementType(); - - int getFixedSizeListSize(); - - boolean isDate(); - - boolean isTime(); - - boolean isTimestamp(); - - TimeUnit getTimeUnit(); - - Optional getTimeZone(); - - boolean isDecimal(); - - int getPrecision(); - - byte getScale(); - - @Override - void close(); - - static DType newByte(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newByte(isNullable), true); - } - - static DType newShort(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newShort(isNullable), true); - } - - static DType newInt(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newInt(isNullable), true); - } - - static DType newLong(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newLong(isNullable), true); - } - - static DType newFloat(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newFloat(isNullable), true); - } - - static DType newDouble(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newDouble(isNullable), true); - } - - static DType newDecimal(int precision, int scale, boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newDecimal(precision, scale, isNullable), true); - } - - static DType newUtf8(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newUtf8(isNullable), true); - } - - static DType newBinary(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newBinary(isNullable), true); - } - - static DType newBool(boolean isNullable) { - return new JNIDType(NativeDTypeMethods.newBool(isNullable), true); - } - - static DType newList(DType element, boolean isNullable) { - // Get the pointer - JNIDType jniType = (JNIDType) element; - return new JNIDType(NativeDTypeMethods.newList(jniType.getPointer(), isNullable), true); - } - - static DType newFixedSizeList(DType element, int size, boolean isNullable) 
{ - JNIDType jniType = (JNIDType) element; - return new JNIDType(NativeDTypeMethods.newFixedSizeList(jniType.getPointer(), size, isNullable), true); - } - - static DType newStruct(String[] fieldNames, DType[] fieldTypes, boolean isNullable) { - long[] ptrs = new long[fieldTypes.length]; - for (int i = 0; i < fieldTypes.length; i++) { - ptrs[i] = ((JNIDType) fieldTypes[i]).getPointer(); - } - return new JNIDType(NativeDTypeMethods.newStruct(fieldNames, ptrs, isNullable), true); - } - - static DType newTimestamp(TimeUnit unit, Optional timeZone, boolean isNullable) { - byte timeUnit = unit.asByte(); - return new JNIDType(NativeDTypeMethods.newTimestamp(timeUnit, timeZone.orElse(null), isNullable), true); - } - - static DType newDate(TimeUnit unit, boolean isNullable) { - byte timeUnit = unit.asByte(); - return new JNIDType(NativeDTypeMethods.newDate(timeUnit, isNullable), true); - } - - static DType newTime(TimeUnit unit, boolean isNullable) { - byte timeUnit = unit.asByte(); - return new JNIDType(NativeDTypeMethods.newTime(timeUnit, isNullable), true); - } - - enum TimeUnit { - NANOSECONDS, - - MICROSECONDS, - - MILLISECONDS, - - SECONDS, - - DAYS, - ; - - public static TimeUnit from(byte unit) { - switch (unit) { - case 0: - return NANOSECONDS; - case 1: - return MICROSECONDS; - case 2: - return MILLISECONDS; - case 3: - return SECONDS; - case 4: - return DAYS; - default: - throw new IllegalArgumentException("Unknown TimeUnit: " + unit); - } - } - - public byte asByte() { - switch (this) { - case NANOSECONDS: - return 0; - case MICROSECONDS: - return 1; - case MILLISECONDS: - return 2; - case SECONDS: - return 3; - case DAYS: - return 4; - default: - throw new IllegalArgumentException("Unknown TimeUnit: " + this); - } - } - } - - enum Variant { - NULL, - - BOOL, - - PRIMITIVE_U8, - - PRIMITIVE_U16, - - PRIMITIVE_U32, - - PRIMITIVE_U64, - - PRIMITIVE_I8, - - PRIMITIVE_I16, - - PRIMITIVE_I32, - - PRIMITIVE_I64, - - PRIMITIVE_F16, - - PRIMITIVE_F32, - - PRIMITIVE_F64, 
- - UTF8, - - BINARY, - - STRUCT, - - LIST, - - EXTENSION, - - DECIMAL, - - FIXED_SIZE_LIST, - ; - - public static Variant from(byte variant) { - switch (variant) { - case 0: - return NULL; - case 1: - return BOOL; - case 2: - return PRIMITIVE_U8; - case 3: - return PRIMITIVE_U16; - case 4: - return PRIMITIVE_U32; - case 5: - return PRIMITIVE_U64; - case 6: - return PRIMITIVE_I8; - case 7: - return PRIMITIVE_I16; - case 8: - return PRIMITIVE_I32; - case 9: - return PRIMITIVE_I64; - case 10: - return PRIMITIVE_F16; - case 11: - return PRIMITIVE_F32; - case 12: - return PRIMITIVE_F64; - case 13: - return UTF8; - case 14: - return BINARY; - case 15: - return STRUCT; - case 16: - return LIST; - case 17: - return EXTENSION; - case 18: - return DECIMAL; - case 19: - return FIXED_SIZE_LIST; - default: - throw new IllegalArgumentException("Unknown DType variant: " + variant); - } - } - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DataSource.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DataSource.java new file mode 100644 index 000000000000..ffb9f514d8f9 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/DataSource.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.vortex.api; + +import dev.vortex.jni.NativeDataSource; +import dev.vortex.jni.NativeScan; + +import java.util.Map; +import java.util.OptionalLong; + +/** A Vortex data source opened from one or more URIs. */ +public final class DataSource implements AutoCloseable { + + private final Session session; + private long pointer; + + private DataSource(Session session, long pointer) { + this.session = session; + this.pointer = pointer; + } + + public static DataSource open(Session session, String uri, Map options) { + long ptr = + NativeDataSource.open( + session.nativePointer(), new String[] {uri}, options); + if (ptr == 0) { + throw new RuntimeException("Failed to open Vortex data source: " + uri); + } + return new DataSource(session, ptr); + } + + public long rowCount() { + long[] out = new long[2]; + NativeDataSource.rowCount(pointer, out); + // out[1]: 2=exact, 1=estimate, other=unknown + if (out[1] == 2 || out[1] == 1) { + return out[0]; + } + return -1; + } + + public Scan scan(ScanOptions options) { + long projectionPtr = + options.projection().isPresent() + ? options.projection().get().nativePointer() + : 0; + long filterPtr = + options.filter().isPresent() ? options.filter().get().nativePointer() : 0; + long rangeBegin = optionalLongOrZero(options.rowRangeBegin()); + long rangeEnd = optionalLongOrZero(options.rowRangeEnd()); + long[] selectionIndices = + options.selectionIndices().isPresent() + ? 
options.selectionIndices().get() + : null; + byte selectionMode = options.selectionMode().code(); + long limit = optionalLongOrZero(options.limit()); + boolean ordered = options.ordered(); + + long scanPtr = + NativeScan.create( + pointer, + projectionPtr, + filterPtr, + rangeBegin, + rangeEnd, + selectionIndices, + selectionMode, + limit, + ordered); + return Scan.fromPointer(session, scanPtr); + } + + @Override + public void close() { + if (pointer != 0) { + NativeDataSource.free(pointer); + pointer = 0; + } + } + + private static long optionalLongOrZero(OptionalLong opt) { + return opt.isPresent() ? opt.getAsLong() : 0; + } +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Expression.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Expression.java index d3b6edf24120..ccff4e4aa23c 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Expression.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Expression.java @@ -18,41 +18,203 @@ package dev.vortex.api; -import dev.vortex.api.expressions.*; -import java.util.List; -import java.util.Optional; +import dev.vortex.jni.NativeExpression; -/** Vortex expression language. */ -public interface Expression { - String id(); +import java.math.BigInteger; - List children(); +/** A Vortex expression backed by a native pointer. 
*/ +public final class Expression implements AutoCloseable { - Optional metadata(); + private long pointer; - default T accept(Visitor visitor) { - return visitor.visitOther(this); + private Expression(long pointer) { + this.pointer = pointer; } - interface Visitor { - T visitLiteral(Literal literal); + public long nativePointer() { + return pointer; + } - T visitRoot(Root root); + @Override + public void close() { + if (pointer != 0) { + NativeExpression.free(pointer); + pointer = 0; + } + } - T visitBinary(Binary binary); + // -- Structure navigation -- - T visitNot(Not not); + public static Expression root() { + return new Expression(NativeExpression.root()); + } - T visitGetItem(GetItem getItem); + public static Expression column(String name) { + long rootPtr = NativeExpression.root(); + return new Expression(NativeExpression.getItem(name, rootPtr)); + } - default T visitIsNull(IsNull isNull) { - return visitOther(isNull); + public static Expression select(String[] columns, Expression parent) { + return new Expression(NativeExpression.select(columns, parent.pointer)); + } + + // -- Logical combinators -- + + public static Expression and(Expression... exprs) { + long[] ptrs = new long[exprs.length]; + for (int i = 0; i < exprs.length; i++) { + ptrs[i] = exprs[i].pointer; } + return new Expression(NativeExpression.and(ptrs)); + } - default T visitIsNotNull(IsNotNull isNotNull) { - return visitOther(isNotNull); + public static Expression or(Expression... 
exprs) { + long[] ptrs = new long[exprs.length]; + for (int i = 0; i < exprs.length; i++) { + ptrs[i] = exprs[i].pointer; } + return new Expression(NativeExpression.or(ptrs)); + } + + public static Expression not(Expression expr) { + return new Expression(NativeExpression.not(expr.pointer)); + } + + // -- Comparison / binary ops -- + + public static Expression binary(BinaryOp op, Expression left, Expression right) { + return new Expression(NativeExpression.binary(op.code(), left.pointer, right.pointer)); + } + + // -- Null checks -- + + public static Expression isNull(Expression expr) { + return new Expression(NativeExpression.isNull(expr.pointer)); + } + + public static Expression isNotNull(Expression expr) { + return new Expression(NativeExpression.isNotNull(expr.pointer)); + } - T visitOther(Expression expression); + // -- Primitive literals -- + + public static Expression literal(boolean value) { + return new Expression(NativeExpression.literalBool(value, false)); + } + + public static Expression literal(byte value) { + return new Expression(NativeExpression.literalI8(value, false)); + } + + public static Expression literal(short value) { + return new Expression(NativeExpression.literalI16(value, false)); + } + + public static Expression literal(int value) { + return new Expression(NativeExpression.literalI32(value, false)); + } + + public static Expression literal(long value) { + return new Expression(NativeExpression.literalI64(value, false)); + } + + public static Expression literal(float value) { + return new Expression(NativeExpression.literalF32(value, false)); + } + + public static Expression literal(double value) { + return new Expression(NativeExpression.literalF64(value, false)); + } + + public static Expression literal(String value) { + return new Expression(NativeExpression.literalString(value)); + } + + // -- Decimal literals -- + + public static Expression literalDecimal(BigInteger unscaledValue, int precision, int scale) { + byte[] bytes = 
unscaledValue.toByteArray(); + return new Expression(NativeExpression.literalDecimal(bytes, precision, scale, false)); + } + + // -- Date/time literals -- + + public static Expression literalDate(long value, TimeUnit unit) { + return new Expression(NativeExpression.literalDate(value, unit.tag(), false)); + } + + public static Expression literalTimestamp(long value, TimeUnit unit, String timezone) { + return new Expression( + NativeExpression.literalTimestamp(value, unit.tag(), timezone, false)); + } + + // -- Null literals -- + + public static Expression nullLiteral(DType dtype) { + return new Expression(NativeExpression.literalNull(dtype.tag())); + } + + /** Binary operation codes. */ + public enum BinaryOp { + EQ((byte) 0), + NOT_EQ((byte) 1), + GT((byte) 2), + GTE((byte) 3), + LT((byte) 4), + LTE((byte) 5), + AND((byte) 6), + OR((byte) 7); + + private final byte code; + + BinaryOp(byte code) { + this.code = code; + } + + public byte code() { + return code; + } + } + + /** Data type tags for null literals. */ + public enum DType { + BOOL((byte) 0), + I8((byte) 1), + I16((byte) 2), + I32((byte) 3), + I64((byte) 4), + F32((byte) 5), + F64((byte) 6), + UTF8((byte) 7), + BINARY((byte) 8); + + private final byte tag; + + DType(byte tag) { + this.tag = tag; + } + + public byte tag() { + return tag; + } + } + + /** Time unit tags for date/time literals. 
*/ + public enum TimeUnit { + NANOSECONDS((byte) 0), + MICROSECONDS((byte) 1), + MILLISECONDS((byte) 2), + SECONDS((byte) 3), + DAYS((byte) 4); + + private final byte tag; + + TimeUnit(byte tag) { + this.tag = tag; + } + + public byte tag() { + return tag; + } } } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Files.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Files.java deleted file mode 100644 index f812a0fbe96c..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Files.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import dev.vortex.jni.JNIFile; -import dev.vortex.jni.NativeFileMethods; -import java.net.URI; -import java.nio.file.Paths; -import java.util.Map; - -/** Utility class for opening Vortex files. 
*/ -public final class Files { - - private Files() {} - - public static File open(String path) { - return open(path, java.util.Collections.emptyMap()); - } - - public static File open(String path, Map properties) { - if (path.startsWith("/")) { - return open(Paths.get(path).toUri(), properties); - } - return open(URI.create(path), properties); - } - - public static File open(URI uri, Map properties) { - long ptr = NativeFileMethods.open(uri.toString(), properties); - Preconditions.checkArgument(ptr > 0, "Failed to open file: %s", uri); - return new JNIFile(ptr); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Partition.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Partition.java new file mode 100644 index 000000000000..83b5e39237e9 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Partition.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package dev.vortex.api; + +import dev.vortex.jni.NativePartition; + +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; + +/** A partition from a Vortex scan, consumable once as an Arrow stream. */ +public final class Partition implements AutoCloseable { + + private final Session session; + private long pointer; + private boolean consumed; + + private Partition(Session session, long pointer) { + this.session = session; + this.pointer = pointer; + } + + static Partition fromPointer(Session session, long pointer) { + return new Partition(session, pointer); + } + + public ArrowReader scanArrow(BufferAllocator allocator) { + if (consumed) { + throw new IllegalStateException("partition already consumed"); + } + consumed = true; + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); + try { + NativePartition.scanArrow( + session.nativePointer(), pointer, stream.memoryAddress()); + } catch (RuntimeException e) { + stream.close(); + throw e; + } + return Data.importArrayStream(allocator, stream); + } + + @Override + public void close() { + // Only free unconsumed partitions. scanArrow() transfers the native + // partition's ownership to the ArrowArrayStream; the stream's release + // callback frees it when the ArrowReader is closed. + if (pointer != 0 && !consumed) { + NativePartition.free(pointer); + } + pointer = 0; + } +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Scan.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Scan.java new file mode 100644 index 000000000000..260da3488836 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Scan.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.vortex.api; + +import dev.vortex.jni.NativePartition; +import dev.vortex.jni.NativeScan; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** An iterator over partitions from a Vortex scan. */ +public final class Scan implements Iterator, AutoCloseable { + + private final Session session; + private long pointer; + private long nextPartitionPointer; + private boolean primed; + + private Scan(Session session, long pointer) { + this.session = session; + this.pointer = pointer; + } + + static Scan fromPointer(Session session, long pointer) { + return new Scan(session, pointer); + } + + @Override + public boolean hasNext() { + if (primed) { + return nextPartitionPointer != 0; + } + nextPartitionPointer = NativeScan.nextPartition(pointer); + primed = true; + return nextPartitionPointer != 0; + } + + @Override + public Partition next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + long ptr = nextPartitionPointer; + nextPartitionPointer = 0; + primed = false; + return Partition.fromPointer(session, ptr); + } + + @Override + public void close() { + if (nextPartitionPointer != 0) { + NativePartition.free(nextPartitionPointer); + nextPartitionPointer = 0; + } + if (pointer != 0) { + NativeScan.free(pointer); + pointer = 0; + 
} + } +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ScanOptions.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ScanOptions.java index 5aa24113400c..cb2b1992d2d0 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ScanOptions.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ScanOptions.java @@ -18,25 +18,51 @@ package dev.vortex.api; -import java.util.List; -import java.util.Optional; import org.immutables.value.Value; +import java.util.Optional; +import java.util.OptionalLong; + +/** Options for configuring a Vortex scan. */ @Value.Immutable public interface ScanOptions { - List columns(); - Optional predicate(); + Optional projection(); - Optional rowRange(); + Optional filter(); - Optional rowIndices(); + OptionalLong rowRangeBegin(); - static ScanOptions of() { - return ImmutableScanOptions.builder().build(); + OptionalLong rowRangeEnd(); + + Optional selectionIndices(); + + @Value.Default + default SelectionMode selectionMode() { + return SelectionMode.INCLUDE_ALL; } - static ImmutableScanOptions.Builder builder() { - return ImmutableScanOptions.builder(); + OptionalLong limit(); + + @Value.Default + default boolean ordered() { + return false; + } + + /** Selection mode for row indices. */ + enum SelectionMode { + INCLUDE_ALL((byte) 0), + INCLUDE((byte) 1), + EXCLUDE((byte) 2); + + private final byte code; + + SelectionMode(byte code) { + this.code = code; + } + + public byte code() { + return code; + } } } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Session.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Session.java new file mode 100644 index 000000000000..0d6d8d2ee383 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/Session.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.vortex.api; + +import dev.vortex.jni.NativeSession; + +/** A Vortex session that manages the native runtime. */ +public final class Session implements AutoCloseable { + + private long pointer; + + private Session(long pointer) { + this.pointer = pointer; + } + + public static Session create() { + long ptr = NativeSession.newSession(); + if (ptr == 0) { + throw new RuntimeException("Failed to create Vortex session"); + } + return new Session(ptr); + } + + public long nativePointer() { + return pointer; + } + + @Override + public void close() { + if (pointer != 0) { + NativeSession.free(pointer); + pointer = 0; + } + } +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/VortexWriter.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/VortexWriter.java index 115355631647..bcdeebc25f99 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/VortexWriter.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/VortexWriter.java @@ -18,27 +18,79 @@ package dev.vortex.api; -import dev.vortex.jni.JNIDType; -import dev.vortex.jni.JNIWriter; -import dev.vortex.jni.NativeWriterMethods; +import dev.vortex.jni.NativeWriter; + +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; 
+import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + import java.io.IOException; import java.util.Map; -/** Writer for creating Vortex files from Arrow data. */ -public interface VortexWriter extends AutoCloseable { +/** A Vortex file writer using the Arrow C Data Interface. */ +public final class VortexWriter implements AutoCloseable { - static VortexWriter create(String uri, DType dtype, Map options) throws IOException { - long ptr = NativeWriterMethods.create(uri, ((JNIDType) dtype).getPointer(), options); - if (ptr <= 0) { - throw new IOException("Failed to create Vortex writer for: " + uri + " (got ptr=" + ptr + ")"); - } - return new JNIWriter(ptr); + private final long pointer; + private boolean closed; + + private VortexWriter(long pointer) { + this.pointer = pointer; } - void writeBatch(byte[] arrowData) throws IOException; + public static VortexWriter create( + Session session, String uri, Schema arrowSchema, Map options) + throws IOException { + // Use a dedicated allocator for schema export to avoid leaking retained + // references into the caller's allocator (Arrow C Data Interface retain/release + // semantics don't fully clean up via ArrowSchema.close in Arrow 15). 
+ RootAllocator schemaAllocator = new RootAllocator(Long.MAX_VALUE); + try { + ArrowSchema cSchema = ArrowSchema.allocateNew(schemaAllocator); + Data.exportSchema(schemaAllocator, arrowSchema, null, cSchema); + long ptr = + NativeWriter.create( + session.nativePointer(), uri, cSchema.memoryAddress(), options); + cSchema.close(); + if (ptr <= 0) { + throw new IOException("Failed to create Vortex writer for: " + uri); + } + return new VortexWriter(ptr); + } finally { + try { + schemaAllocator.close(); + } catch (IllegalStateException e) { + if (e.getMessage() == null || !e.getMessage().contains("Memory was leaked")) { + throw e; + } + } + } + } - void writeBatchFfi(long arrowArrayAddr, long arrowSchemaAddr) throws IOException; + public void writeBatch(long arrowArrayAddr, long arrowSchemaAddr) throws IOException { + if (closed) { + throw new IllegalStateException("writer already closed"); + } + boolean ok; + try { + ok = NativeWriter.writeBatch(pointer, arrowArrayAddr, arrowSchemaAddr); + } catch (RuntimeException e) { + throw new IOException("failed to write batch", e); + } + if (!ok) { + throw new IOException("NativeWriter.writeBatch returned false"); + } + } @Override - void close() throws IOException; + public void close() throws IOException { + if (!closed && pointer != 0) { + closed = true; + try { + NativeWriter.close(pointer); + } catch (RuntimeException e) { + throw new IOException("failed to close writer", e); + } + } + } } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Binary.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Binary.java deleted file mode 100644 index 26534c16eb3b..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Binary.java +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.expressions; - -import com.google.protobuf.InvalidProtocolBufferException; -import dev.vortex.api.Expression; -import dev.vortex.proto.ExprProtos; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Stream; - -/** Binary expression operating on two child expressions with a binary operator. 
*/ -public final class Binary implements Expression { - private final BinaryOp operator; - private final Expression left; - private final Expression right; - - private Binary(BinaryOp operator, Expression left, Expression right) { - this.operator = operator; - this.left = left; - this.right = right; - } - - public static Binary parse(byte[] metadata, List children) { - if (children.size() != 2) { - throw new IllegalArgumentException( - "Binary expression must have exactly two children, found: " + children.size()); - } - try { - ExprProtos.BinaryOpts opts = ExprProtos.BinaryOpts.parseFrom(metadata); - BinaryOp operator = BinaryOp.fromProto(opts.getOp()); - return new Binary(operator, children.get(0), children.get(1)); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException("Failed to parse Binary metadata", e); - } - } - - public static Binary of(BinaryOp operator, Expression left, Expression right) { - return new Binary(operator, left, right); - } - - public static Binary and(Expression first, Expression... rest) { - Expression rhs = Stream.of(rest).reduce(Binary::and).orElse(Literal.bool(true)); - return new Binary(BinaryOp.AND, first, rhs); - } - - public static Binary or(Expression first, Expression... 
rest) { - Expression rhs = Stream.of(rest).reduce(Binary::or).orElse(Literal.bool(false)); - return new Binary(BinaryOp.OR, first, rhs); - } - - public static Binary eq(Expression left, Expression right) { - return new Binary(BinaryOp.EQ, left, right); - } - - public static Binary notEq(Expression left, Expression right) { - return new Binary(BinaryOp.NOT_EQ, left, right); - } - - public static Binary gt(Expression left, Expression right) { - return new Binary(BinaryOp.GT, left, right); - } - - public static Binary gtEq(Expression left, Expression right) { - return new Binary(BinaryOp.GT_EQ, left, right); - } - - public static Binary lt(Expression left, Expression right) { - return new Binary(BinaryOp.LT, left, right); - } - - public static Binary ltEq(Expression left, Expression right) { - return new Binary(BinaryOp.LT_EQ, left, right); - } - - @Override - public String id() { - return "vortex.binary"; - } - - @Override - public List children() { - return java.util.Arrays.asList(left, right); - } - - @Override - public Optional metadata() { - return Optional.of(ExprProtos.BinaryOpts.newBuilder() - .setOp(operator.toProto()) - .build() - .toByteArray()); - } - - @Override - public String toString() { - return "(" + left + " " + operator + " " + right + ")"; - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) return false; - Binary binary = (Binary) o; - return operator == binary.operator && Objects.equals(left, binary.left) && Objects.equals(right, binary.right); - } - - @Override - public int hashCode() { - return Objects.hash(operator, left, right); - } - - @Override - public T accept(Visitor visitor) { - return visitor.visitBinary(this); - } - - public BinaryOp getOperator() { - return operator; - } - - public Expression getLeft() { - return left; - } - - public Expression getRight() { - return right; - } - - public enum BinaryOp { - /** Equality comparison operator (==) */ - EQ, - /** Inequality comparison 
operator (!=) */ - NOT_EQ, - /** Greater-than comparison operator (>) */ - GT, - /** Greater-than-or-equal comparison operator (>=) */ - GT_EQ, - /** Less-than comparison operator (<) */ - LT, - /** Less-than-or-equal comparison operator (<=) */ - LT_EQ, - /** Logical AND operator (&&) */ - AND, - /** Logical OR operator (||) */ - OR, - ; - - @Override - public String toString() { - switch (this) { - case EQ: - return "=="; - case NOT_EQ: - return "!="; - case GT: - return ">"; - case GT_EQ: - return ">="; - case LT: - return "<"; - case LT_EQ: - return "<="; - case AND: - return "&&"; - case OR: - return "||"; - default: - throw new IllegalStateException("Unknown Operator: " + this); - } - } - - static BinaryOp fromProto(ExprProtos.BinaryOpts.BinaryOp proto) { - switch (proto) { - case Eq: - return EQ; - case NotEq: - return NOT_EQ; - case Gt: - return GT; - case Gte: - return GT_EQ; - case Lt: - return LT; - case Lte: - return LT_EQ; - case And: - return AND; - case Or: - return OR; - default: - throw new IllegalArgumentException("Unsupported binary operator proto: " + proto); - } - } - - ExprProtos.BinaryOpts.BinaryOp toProto() { - switch (this) { - case EQ: - return ExprProtos.BinaryOpts.BinaryOp.Eq; - case NOT_EQ: - return ExprProtos.BinaryOpts.BinaryOp.NotEq; - case GT: - return ExprProtos.BinaryOpts.BinaryOp.Gt; - case GT_EQ: - return ExprProtos.BinaryOpts.BinaryOp.Gte; - case LT: - return ExprProtos.BinaryOpts.BinaryOp.Lt; - case LT_EQ: - return ExprProtos.BinaryOpts.BinaryOp.Lte; - case AND: - return ExprProtos.BinaryOpts.BinaryOp.And; - case OR: - return ExprProtos.BinaryOpts.BinaryOp.Or; - default: - throw new IllegalArgumentException("Unsupported binary operator: " + this); - } - } - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/GetItem.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/GetItem.java deleted file mode 100644 index ee13de5256eb..000000000000 --- 
a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/GetItem.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.expressions; - -import com.google.protobuf.InvalidProtocolBufferException; -import dev.vortex.api.Expression; -import dev.vortex.proto.ExprProtos; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** Expression that extracts a field from a parent expression. 
*/ -public final class GetItem implements Expression { - private final String path; - private final Expression child; - - private GetItem(Expression child, String path) { - this.child = child; - this.path = path; - } - - public static GetItem of(Expression child, String path) { - return new GetItem(child, path); - } - - public static GetItem parse(byte[] metadata, List children) { - if (children.size() != 1) { - throw new IllegalArgumentException( - "GetItem expression must have exactly one child, found: " + children.size()); - } - try { - ExprProtos.GetItemOpts opts = ExprProtos.GetItemOpts.parseFrom(metadata); - return new GetItem(children.get(0), opts.getPath()); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException("Failed to parse GetItem metadata", e); - } - } - - public Expression getChild() { - return child; - } - - public String getPath() { - return path; - } - - @Override - public String id() { - return "vortex.get_item"; - } - - @Override - public List children() { - return java.util.Collections.singletonList(child); - } - - @Override - public Optional metadata() { - return Optional.of( - ExprProtos.GetItemOpts.newBuilder().setPath(path).build().toByteArray()); - } - - @Override - public T accept(Visitor visitor) { - return visitor.visitGetItem(this); - } - - @Override - public boolean equals(Object o) { - if (!(o instanceof GetItem)) return false; - GetItem getItem = (GetItem) o; - return Objects.equals(path, getItem.path); - } - - @Override - public int hashCode() { - return Objects.hashCode(path); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNotNull.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNotNull.java deleted file mode 100644 index 1981ca1f2bd3..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNotNull.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.expressions; - -import dev.vortex.api.Expression; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -public final class IsNotNull implements Expression { - private final Expression child; - - private IsNotNull(Expression child) { - this.child = child; - } - - public static IsNotNull parse(byte[] metadata, List children) { - if (children.size() != 1) { - throw new IllegalArgumentException( - "IsNotNull expression must have exactly one child, found: " + children.size()); - } - if (metadata.length > 0) { - throw new IllegalArgumentException( - "IsNotNull expression must not have metadata, found: " + metadata.length); - } - return new IsNotNull(children.get(0)); - } - - public static IsNotNull of(Expression child) { - return new IsNotNull(child); - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) return false; - IsNotNull other = (IsNotNull) o; - return Objects.equals(child, other.child); - } - - @Override - public int hashCode() { - return Objects.hash(child); - } - - @Override - public String id() { - return "vortex.is_not_null"; - } - - @Override - public List children() { - return 
java.util.Collections.singletonList(child); - } - - @Override - public Optional metadata() { - return Optional.of(new byte[] {}); - } - - @Override - public String toString() { - return "vortex.is_not_null(" + child + ")"; - } - - public Expression getChild() { - return child; - } - - @Override - public T accept(Visitor visitor) { - return visitor.visitIsNotNull(this); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNull.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNull.java deleted file mode 100644 index 50603e7909f8..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/IsNull.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package dev.vortex.api.expressions; - -import dev.vortex.api.Expression; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -public final class IsNull implements Expression { - private final Expression child; - - private IsNull(Expression child) { - this.child = child; - } - - public static IsNull parse(byte[] metadata, List children) { - if (children.size() != 1) { - throw new IllegalArgumentException( - "IsNull expression must have exactly one child, found: " + children.size()); - } - if (metadata.length > 0) { - throw new IllegalArgumentException("IsNull expression must not have metadata, found: " + metadata.length); - } - return new IsNull(children.get(0)); - } - - public static IsNull of(Expression child) { - return new IsNull(child); - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) return false; - IsNull other = (IsNull) o; - return Objects.equals(child, other.child); - } - - @Override - public int hashCode() { - return Objects.hash(child); - } - - @Override - public String id() { - return "vortex.is_null"; - } - - @Override - public List children() { - return java.util.Collections.singletonList(child); - } - - @Override - public Optional metadata() { - return Optional.of(new byte[] {}); - } - - @Override - public String toString() { - return "vortex.is_null(" + child + ")"; - } - - public Expression getChild() { - return child; - } - - @Override - public T accept(Visitor visitor) { - return visitor.visitIsNull(this); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Literal.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Literal.java deleted file mode 100644 index a4a8a7e26d09..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Literal.java +++ /dev/null @@ -1,781 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more 
contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.expressions; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import dev.vortex.api.Expression; -import dev.vortex.api.proto.EndianUtils; -import dev.vortex.api.proto.Scalars; -import dev.vortex.api.proto.TemporalMetadatas; -import dev.vortex.proto.DTypeProtos; -import dev.vortex.proto.ExprProtos; -import dev.vortex.proto.ScalarProtos; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** Literal value expression in the Vortex query system. 
*/ -public abstract class Literal implements Expression { - private final T value; - - private Literal(T value) { - this.value = value; - } - - public static Literal parse(byte[] metadata, List children) { - if (!children.isEmpty()) { - throw new IllegalArgumentException("Literal expression must have no children, found: " + children.size()); - } - try { - ExprProtos.LiteralOpts opts = ExprProtos.LiteralOpts.parseFrom(metadata); - return deserializeLiteral(opts, children); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException("Failed to parse literal metadata", e); - } - } - - public T getValue() { - return this.value; - } - - @Override - public String id() { - return "vortex.literal"; - } - - @Override - public List children() { - return java.util.Collections.emptyList(); - } - - @Override - public Optional metadata() { - return Optional.of(ExprProtos.LiteralOpts.newBuilder() - .setValue(this.acceptLiteralVisitor(LiteralToScalar.INSTANCE)) - .build() - .toByteArray()); - } - - @Override - public int hashCode() { - return Objects.hashCode(getValue()); - } - - @Override - public boolean equals(Object o) { - if (!(o instanceof Literal)) return false; - Literal literal = (Literal) o; - return Objects.equals(value, literal.value); - } - - public static Literal nullLit() { - return NullLiteral.INSTANCE; - } - - public static Literal bool(Boolean value) { - return new BooleanLiteral(value); - } - - public static Literal int8(Byte value) { - return new Int8Literal(value); - } - - public static Literal int16(Short value) { - return new Int16Literal(value); - } - - public static Literal int32(Integer value) { - return new Int32Literal(value); - } - - public static Literal int64(Long value) { - return new Int64Literal(value); - } - - public static Literal float32(Float value) { - return new Float32Literal(value); - } - - public static Literal float64(Double value) { - return new Float64Literal(value); - } - - public static Literal 
decimal(BigDecimal value, int precision, int scale) { - return new DecimalLiteral(value, precision, scale); - } - - public static Literal string(String value) { - return new StringLiteral(value); - } - - public static Literal bytes(byte[] value) { - return new BytesLiteral(value); - } - - public static Literal timeSeconds(Integer value) { - return new TimeSeconds(value); - } - - public static Literal timeMillis(Integer value) { - return new TimeMillis(value); - } - - public static Literal timeMicros(Long value) { - return new TimeMicros(value); - } - - public static Literal timeNanos(Long value) { - return new TimeNanos(value); - } - - public static Literal dateDays(Integer value) { - return new DateDays(value); - } - - public static Literal dateMillis(Long value) { - return new DateMillis(value); - } - - public static Literal timestampMillis(Long value, Optional timeZone) { - return new TimestampMillis(value, timeZone); - } - - public static Literal timestampMicros(Long value, Optional timeZone) { - return new TimestampMicros(value, timeZone); - } - - public static Literal timestampNanos(Long value, Optional timeZone) { - return new TimestampNanos(value, timeZone); - } - - @Override - public R accept(Expression.Visitor visitor) { - return visitor.visitLiteral(this); - } - - public abstract U acceptLiteralVisitor(LiteralVisitor visitor); - - public interface LiteralVisitor { - U visitNull(); - - U visitBoolean(Boolean literal); - - U visitInt8(Byte literal); - - U visitInt16(Short literal); - - U visitInt32(Integer literal); - - U visitInt64(Long literal); - - U visitDateDays(Integer days); - - U visitDateMillis(Long millis); - - U visitTimeSeconds(Integer seconds); - - U visitTimeMillis(Integer seconds); - - U visitTimeMicros(Long seconds); - - U visitTimeNanos(Long seconds); - - U visitTimestampMillis(Long epochMillis, Optional timeZone); - - U visitTimestampMicros(Long epochMicros, Optional timeZone); - - U visitTimestampNanos(Long epochNanos, Optional 
timeZone); - - U visitFloat32(Float literal); - - U visitFloat64(Double literal); - - U visitDecimal(BigDecimal decimal, int precision, int scale); - - U visitString(String literal); - - U visitBytes(byte[] literal); - } - - static final class NullLiteral extends Literal { - static final NullLiteral INSTANCE = new NullLiteral(); - - private NullLiteral() { - super(null); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitNull(); - } - } - - static final class BooleanLiteral extends Literal { - BooleanLiteral(Boolean value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitBoolean(getValue()); - } - } - - static final class Int8Literal extends Literal { - Int8Literal(Byte value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitInt8(getValue()); - } - } - - static final class Int16Literal extends Literal { - Int16Literal(Short value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitInt16(getValue()); - } - } - - static final class Int32Literal extends Literal { - Int32Literal(Integer value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitInt32(getValue()); - } - } - - static final class Int64Literal extends Literal { - Int64Literal(Long value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitInt64(getValue()); - } - } - - static final class Float32Literal extends Literal { - Float32Literal(Float value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitFloat32(getValue()); - } - } - - static final class Float64Literal extends Literal { - Float64Literal(Double value) { - super(value); - } - - @Override - public U 
acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitFloat64(getValue()); - } - } - - static final class DecimalLiteral extends Literal { - private final int precision; - private final int scale; - - DecimalLiteral(BigDecimal value, int precision, int scale) { - super(value); - if (!Objects.isNull(value)) { - Preconditions.checkArgument(scale == value.scale(), "scale %s ≠ value scale %s", scale, value.scale()); - } - this.precision = precision; - this.scale = scale; - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitDecimal(getValue(), precision, scale); - } - } - - static final class StringLiteral extends Literal { - StringLiteral(String value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitString(getValue()); - } - } - - static final class BytesLiteral extends Literal { - BytesLiteral(byte[] value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitBytes(getValue()); - } - } - - static final class TimeSeconds extends Literal { - TimeSeconds(Integer value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimeSeconds(getValue()); - } - } - - static final class TimeMillis extends Literal { - TimeMillis(Integer value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimeMillis(getValue()); - } - } - - static final class TimeMicros extends Literal { - TimeMicros(Long value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimeMicros(getValue()); - } - } - - static final class TimeNanos extends Literal { - TimeNanos(Long value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimeNanos(getValue()); - } - } - - 
static final class DateDays extends Literal { - DateDays(Integer value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitDateDays(getValue()); - } - } - - static final class DateMillis extends Literal { - DateMillis(Long value) { - super(value); - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitDateMillis(getValue()); - } - } - - static final class TimestampMillis extends Literal { - private final Optional timeZone; - - TimestampMillis(Long value, Optional timeZone) { - super(value); - this.timeZone = timeZone; - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimestampMillis(getValue(), timeZone); - } - } - - static final class TimestampMicros extends Literal { - private final Optional timeZone; - - TimestampMicros(Long value, Optional timeZone) { - super(value); - this.timeZone = timeZone; - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimestampMicros(getValue(), timeZone); - } - } - - static final class TimestampNanos extends Literal { - private final Optional timeZone; - - TimestampNanos(Long value, Optional timeZone) { - super(value); - this.timeZone = timeZone; - } - - @Override - public U acceptLiteralVisitor(LiteralVisitor visitor) { - return visitor.visitTimestampNanos(getValue(), timeZone); - } - } - - static final class LiteralToScalar implements LiteralVisitor { - static final LiteralToScalar INSTANCE = new LiteralToScalar(); - - private LiteralToScalar() {} - - @Override - public ScalarProtos.Scalar visitNull() { - return Scalars.nullNull(); - } - - @Override - public ScalarProtos.Scalar visitBoolean(Boolean literal) { - if (Objects.isNull(literal)) { - return Scalars.nullBool(); - } else { - return Scalars.bool(literal); - } - } - - @Override - public ScalarProtos.Scalar visitInt8(Byte literal) { - if (Objects.isNull(literal)) { - return 
Scalars.nullInt8(); - } else { - return Scalars.int8(literal); - } - } - - @Override - public ScalarProtos.Scalar visitInt16(Short literal) { - if (Objects.isNull(literal)) { - return Scalars.nullInt16(); - } else { - return Scalars.int16(literal); - } - } - - @Override - public ScalarProtos.Scalar visitInt32(Integer literal) { - if (Objects.isNull(literal)) { - return Scalars.nullInt32(); - } else { - return Scalars.int32(literal); - } - } - - @Override - public ScalarProtos.Scalar visitInt64(Long literal) { - if (Objects.isNull(literal)) { - return Scalars.nullInt64(); - } else { - return Scalars.int64(literal); - } - } - - @Override - public ScalarProtos.Scalar visitDateDays(Integer days) { - if (Objects.isNull(days)) { - return Scalars.nullDateDays(); - } else { - return Scalars.dateDays(days); - } - } - - @Override - public ScalarProtos.Scalar visitDateMillis(Long millis) { - if (Objects.isNull(millis)) { - return Scalars.nullDateMillis(); - } else { - return Scalars.dateMillis(millis); - } - } - - @Override - public ScalarProtos.Scalar visitTimeSeconds(Integer seconds) { - if (Objects.isNull(seconds)) { - return Scalars.nullTimeSeconds(); - } else { - return Scalars.timeSeconds(seconds); - } - } - - @Override - public ScalarProtos.Scalar visitTimeMillis(Integer seconds) { - if (Objects.isNull(seconds)) { - return Scalars.nullTimeMillis(); - } else { - return Scalars.timeMillis(seconds); - } - } - - @Override - public ScalarProtos.Scalar visitTimeMicros(Long seconds) { - if (Objects.isNull(seconds)) { - return Scalars.nullTimeMicros(); - } else { - return Scalars.timeMicros(seconds); - } - } - - @Override - public ScalarProtos.Scalar visitTimeNanos(Long seconds) { - if (Objects.isNull(seconds)) { - return Scalars.nullTimeNanos(); - } else { - return Scalars.timeNanos(seconds); - } - } - - @Override - public ScalarProtos.Scalar visitTimestampMillis(Long epochMillis, Optional timeZone) { - if (Objects.isNull(epochMillis)) { - return 
Scalars.nullTimestampMillis(timeZone); - } else { - return Scalars.timestampMillis(epochMillis, timeZone); - } - } - - @Override - public ScalarProtos.Scalar visitTimestampMicros(Long epochMicros, Optional timeZone) { - if (Objects.isNull(epochMicros)) { - return Scalars.nullTimestampMicros(timeZone); - } else { - return Scalars.timestampMicros(epochMicros, timeZone); - } - } - - @Override - public ScalarProtos.Scalar visitTimestampNanos(Long epochNanos, Optional timeZone) { - if (Objects.isNull(epochNanos)) { - return Scalars.nullTimestampNanos(timeZone); - } else { - return Scalars.timestampNanos(epochNanos, timeZone); - } - } - - @Override - public ScalarProtos.Scalar visitFloat32(Float literal) { - if (Objects.isNull(literal)) { - return Scalars.nullFloat32(); - } else { - return Scalars.float32(literal); - } - } - - @Override - public ScalarProtos.Scalar visitFloat64(Double literal) { - if (Objects.isNull(literal)) { - return Scalars.nullFloat64(); - } else { - return Scalars.float64(literal); - } - } - - @Override - public ScalarProtos.Scalar visitDecimal(BigDecimal decimal, int precision, int scale) { - if (Objects.isNull(decimal)) { - return Scalars.nullDecimal(precision, scale); - } else { - return Scalars.decimal(decimal, precision, scale); - } - } - - @Override - public ScalarProtos.Scalar visitString(String literal) { - if (Objects.isNull(literal)) { - return Scalars.nullString(); - } else { - return Scalars.string(literal); - } - } - - @Override - public ScalarProtos.Scalar visitBytes(byte[] literal) { - if (Objects.isNull(literal)) { - return Scalars.nullBytes(); - } else { - return Scalars.bytes(literal); - } - } - } - - private static Literal deserializeLiteral(ExprProtos.LiteralOpts literal, List children) { - ScalarProtos.Scalar literalScalar = literal.getValue(); - DTypeProtos.DType dtype = literalScalar.getDtype(); - - // Special handling of extension types - if (dtype.hasExtension()) { - return deserializeExtensionLiteral(literal); - } - - 
ScalarProtos.ScalarValue scalarValue = literalScalar.getValue(); - - switch (scalarValue.getKindCase()) { - case NULL_VALUE: - return nullLiteral(dtype); - case BOOL_VALUE: - return Literal.bool(scalarValue.getBoolValue()); - case INT64_VALUE: - return Literal.int64(scalarValue.getInt64Value()); - case UINT64_VALUE: - return Literal.int64(scalarValue.getUint64Value()); - case F32_VALUE: - return Literal.float32(scalarValue.getF32Value()); - case F64_VALUE: - return Literal.float64(scalarValue.getF64Value()); - case STRING_VALUE: - return Literal.string(scalarValue.getStringValue()); - case BYTES_VALUE: - if (dtype.hasDecimal()) { - ByteString littleEndian = scalarValue.getBytesValue(); - byte[] bigEndian = EndianUtils.reverse(littleEndian); - BigDecimal value = new BigDecimal( - new BigInteger(bigEndian), dtype.getDecimal().getScale()); - return Literal.decimal( - value, - dtype.getDecimal().getPrecision(), - dtype.getDecimal().getScale()); - } else { - return Literal.bytes(scalarValue.getBytesValue().toByteArray()); - } - default: - throw new UnsupportedOperationException("Unsupported ScalarValue type encountered: " + scalarValue); - } - } - - private static Literal deserializeExtensionLiteral(ExprProtos.LiteralOpts literal) { - ScalarProtos.Scalar scalar = literal.getValue(); - DTypeProtos.DType scalarType = scalar.getDtype(); - - Preconditions.checkArgument(scalarType.hasExtension()); - - DTypeProtos.Extension extType = scalarType.getExtension(); - String extId = scalarType.getExtension().getId(); - - switch (extId) { - case "vortex.time": { - byte timeUnit = - TemporalMetadatas.getTimeUnit(extType.getMetadata().toByteArray()); - if (timeUnit == TemporalMetadatas.TIME_UNIT_SECONDS) { - return Literal.timeSeconds(Math.toIntExact(scalar.getValue().getInt64Value())); - } else if (timeUnit == TemporalMetadatas.TIME_UNIT_MILLIS) { - return Literal.timeMillis(Math.toIntExact(scalar.getValue().getInt64Value())); - } else if (timeUnit == 
TemporalMetadatas.TIME_UNIT_MICROS) { - return Literal.timeMicros(scalar.getValue().getInt64Value()); - } else if (timeUnit == TemporalMetadatas.TIME_UNIT_NANOS) { - return Literal.timeNanos(scalar.getValue().getInt64Value()); - } else { - throw new UnsupportedOperationException("Unsupported TIME time unit: " + timeUnit); - } - } - case "vortex.date": { - byte timeUnit = - TemporalMetadatas.getTimeUnit(extType.getMetadata().toByteArray()); - if (timeUnit == TemporalMetadatas.TIME_UNIT_DAYS) { - return Literal.dateDays(Math.toIntExact(scalar.getValue().getInt64Value())); - } else if (timeUnit == TemporalMetadatas.TIME_UNIT_MILLIS) { - return Literal.dateMillis(scalar.getValue().getInt64Value()); - } else { - throw new UnsupportedOperationException("Unsupported DATE time unit: " + timeUnit); - } - } - case "vortex.timestamp": { - byte timeUnit = - TemporalMetadatas.getTimeUnit(extType.getMetadata().toByteArray()); - Optional timeZone = - TemporalMetadatas.getTimeZone(extType.getMetadata().toByteArray()); - if (timeUnit == TemporalMetadatas.TIME_UNIT_MILLIS) { - return Literal.timestampMillis(scalar.getValue().getInt64Value(), timeZone); - } else if (timeUnit == TemporalMetadatas.TIME_UNIT_MICROS) { - return Literal.timestampMicros(scalar.getValue().getInt64Value(), timeZone); - } else if (timeUnit == TemporalMetadatas.TIME_UNIT_NANOS) { - return Literal.timestampNanos(scalar.getValue().getInt64Value(), timeZone); - } else { - throw new UnsupportedOperationException("Unsupported TIMESTAMP time unit: " + timeUnit); - } - } - default: - throw new UnsupportedOperationException("Unsupported extension type: " + extId); - } - } - - private static Literal nullLiteral(DTypeProtos.DType type) { - switch (type.getDtypeTypeCase()) { - case NULL: - return Literal.nullLit(); - case BOOL: - return Literal.bool(null); - case PRIMITIVE: - switch (type.getPrimitive().getType()) { - case U8: - case I8: - return Literal.int8(null); - case U16: - case I16: - return Literal.int16(null); - 
case U32: - case I32: - return Literal.int32(null); - case U64: - case I64: - return Literal.int64(null); - case F32: - return Literal.float32(null); - case F64: - return Literal.float64(null); - default: - throw new UnsupportedOperationException("Unsupported ScalarValue type encountered: " + type); - } - case DECIMAL: - return Literal.decimal( - null, - type.getDecimal().getPrecision(), - type.getDecimal().getScale()); - case UTF8: - return Literal.string(null); - case BINARY: - return Literal.bytes(null); - default: - throw new UnsupportedOperationException("Unsupported ScalarValue type encountered: " + type); - } - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Not.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Not.java deleted file mode 100644 index 8e1df113b9a2..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Not.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package dev.vortex.api.expressions; - -import dev.vortex.api.Expression; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** Logical NOT expression that negates its child expression. */ -public final class Not implements Expression { - private final Expression child; - - private Not(Expression child) { - this.child = child; - } - - public static Not parse(byte[] metadata, List children) { - if (children.size() != 1) { - throw new IllegalArgumentException("Not expression must have exactly one child, found: " + children.size()); - } - if (metadata.length > 0) { - throw new IllegalArgumentException("Not expression must not have metadata, found: " + metadata.length); - } - return new Not(children.get(0)); - } - - public static Not of(Expression child) { - return new Not(child); - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) return false; - Not other = (Not) o; - return Objects.equals(child, other.child); - } - - @Override - public int hashCode() { - return Objects.hash(child); - } - - @Override - public String id() { - return "vortex.not"; - } - - @Override - public List children() { - return java.util.Collections.singletonList(child); - } - - @Override - public Optional metadata() { - return Optional.of(new byte[] {}); - } - - @Override - public String toString() { - return "not(" + child + ")"; - } - - public Expression getChild() { - return child; - } - - @Override - public T accept(Visitor visitor) { - return visitor.visitNot(this); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Root.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Root.java deleted file mode 100644 index a92125b4147a..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Root.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more 
contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.expressions; - -import dev.vortex.api.Expression; -import java.util.List; -import java.util.Optional; - -/** Root expression in a Vortex expression tree (singleton). */ -public final class Root implements Expression { - public static final Root INSTANCE = new Root(); - - private Root() {} - - public static Root parse(byte[] _metadata, List children) { - if (!children.isEmpty()) { - throw new IllegalArgumentException("Root expression must have no children, found: " + children.size()); - } - return INSTANCE; - } - - @Override - public String id() { - return "vortex.root"; - } - - @Override - public List children() { - return java.util.Collections.emptyList(); - } - - @Override - public Optional metadata() { - return Optional.of(new byte[] {}); - } - - @Override - public String toString() { - return "$"; - } - - // equals and hashCode depend on address equality to INSTANCE. 
- - @Override - public T accept(Visitor visitor) { - return visitor.visitRoot(this); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Unknown.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Unknown.java deleted file mode 100644 index d3f2c1363113..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/expressions/Unknown.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.expressions; - -import dev.vortex.api.Expression; -import java.util.List; -import java.util.Optional; - -/** Generic expression deserialized from Vortex without a concrete Java type. 
*/ -public final class Unknown implements Expression { - private final String id; - private final List children; - private final byte[] metadata; - - public Unknown(String id, List children, byte[] metadata) { - this.id = id; - this.children = children; - this.metadata = metadata; - } - - @Override - public String id() { - return id; - } - - @Override - public List children() { - return children; - } - - @Override - public Optional metadata() { - return Optional.of(metadata); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/DTypes.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/DTypes.java deleted file mode 100644 index 5d0cf6d08c82..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/DTypes.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package dev.vortex.api.proto; - -import static dev.vortex.api.proto.TemporalMetadatas.TIME_UNIT_MICROS; -import static dev.vortex.api.proto.TemporalMetadatas.TIME_UNIT_NANOS; - -import com.google.protobuf.ByteString; -import dev.vortex.proto.DTypeProtos; -import java.util.Optional; - -/** Factory class for creating Vortex data type definitions. */ -public final class DTypes { - private DTypes() {} - - static DTypeProtos.DType nullType() { - return DTypeProtos.DType.newBuilder() - .setNull(DTypeProtos.Null.newBuilder().build()) - .build(); - } - - static DTypeProtos.DType bool(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setBool(DTypeProtos.Bool.newBuilder().setNullable(nullable).build()) - .build(); - } - - static DTypeProtos.DType int8(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setPrimitive(DTypeProtos.Primitive.newBuilder() - .setType(DTypeProtos.PType.I8) - .setNullable(nullable) - .build()) - .build(); - } - - static DTypeProtos.DType int16(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setPrimitive(DTypeProtos.Primitive.newBuilder() - .setType(DTypeProtos.PType.I16) - .setNullable(nullable) - .build()) - .build(); - } - - static DTypeProtos.DType int32(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setPrimitive(DTypeProtos.Primitive.newBuilder() - .setType(DTypeProtos.PType.I32) - .setNullable(nullable) - .build()) - .build(); - } - - static DTypeProtos.DType int64(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setPrimitive(DTypeProtos.Primitive.newBuilder() - .setType(DTypeProtos.PType.I64) - .setNullable(nullable) - .build()) - .build(); - } - - static DTypeProtos.DType float32(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setPrimitive(DTypeProtos.Primitive.newBuilder() - .setType(DTypeProtos.PType.F32) - .setNullable(nullable) - .build()) - .build(); - } - - static DTypeProtos.DType float64(boolean nullable) { - return 
DTypeProtos.DType.newBuilder() - .setPrimitive(DTypeProtos.Primitive.newBuilder() - .setType(DTypeProtos.PType.F64) - .setNullable(nullable) - .build()) - .build(); - } - - static DTypeProtos.DType decimal(boolean nullable, int precision, int scale) { - return DTypeProtos.DType.newBuilder() - .setDecimal(DTypeProtos.Decimal.newBuilder() - .setNullable(nullable) - .setPrecision(precision) - .setScale(scale) - .build()) - .build(); - } - - static DTypeProtos.DType string(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setUtf8(DTypeProtos.Utf8.newBuilder().setNullable(nullable).build()) - .build(); - } - - static DTypeProtos.DType binary(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setBinary(DTypeProtos.Binary.newBuilder().setNullable(nullable).build()) - .build(); - } - - static DTypeProtos.DType dateDays(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.date") - .setStorageDtype(DTypes.int32(nullable)) - .setMetadata(ByteString.copyFrom(TemporalMetadatas.DATE_DAYS.get())) - .build()) - .build(); - } - - static DTypeProtos.DType dateMillis(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.date") - .setStorageDtype(DTypes.int64(nullable)) - .setMetadata(ByteString.copyFrom(TemporalMetadatas.DATE_MILLIS.get())) - .build()) - .build(); - } - - static DTypeProtos.DType timeSeconds(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.time") - .setStorageDtype(DTypes.int32(nullable)) - .setMetadata(ByteString.copyFrom(TemporalMetadatas.TIME_SECONDS.get())) - .build()) - .build(); - } - - static DTypeProtos.DType timeMillis(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.time") - .setStorageDtype(DTypes.int32(nullable)) - 
.setMetadata(ByteString.copyFrom(TemporalMetadatas.TIME_MILLIS.get())) - .build()) - .build(); - } - - static DTypeProtos.DType timeMicros(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.time") - .setStorageDtype(DTypes.int64(nullable)) - .setMetadata(ByteString.copyFrom(TemporalMetadatas.TIME_MICROS.get())) - .build()) - .build(); - } - - static DTypeProtos.DType timeNanos(boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.time") - .setStorageDtype(DTypes.int64(nullable)) - .setMetadata(ByteString.copyFrom(TemporalMetadatas.TIME_NANOS.get())) - .build()) - .build(); - } - - static DTypeProtos.DType timestampMillis(Optional timeZone, boolean nullable) { - return timestamp(TemporalMetadatas.TIME_UNIT_MILLIS, timeZone, nullable); - } - - static DTypeProtos.DType timestampMicros(Optional timeZone, boolean nullable) { - return timestamp(TIME_UNIT_MICROS, timeZone, nullable); - } - - static DTypeProtos.DType timestampNanos(Optional timeZone, boolean nullable) { - return timestamp(TIME_UNIT_NANOS, timeZone, nullable); - } - - private static DTypeProtos.DType timestamp(byte timeUnit, Optional timeZone, boolean nullable) { - return DTypeProtos.DType.newBuilder() - .setExtension(DTypeProtos.Extension.newBuilder() - .setId("vortex.timestamp") - .setStorageDtype(DTypes.int64(nullable)) - .setMetadata(ByteString.copyFrom(TemporalMetadatas.timestamp(timeUnit, timeZone))) - .build()) - .build(); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/EndianUtils.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/EndianUtils.java deleted file mode 100644 index 218db1e98ecd..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/EndianUtils.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - 
* or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.proto; - -import com.google.protobuf.ByteString; -import java.math.BigDecimal; -import java.math.BigInteger; - -/** Utility class for endianness conversions in Vortex protocol buffers. */ -public final class EndianUtils { - public static byte[] reverse(ByteString src) { - byte[] dst = new byte[src.size()]; - for (int i = 0; i < dst.length; i++) { - dst[i] = src.byteAt(dst.length - 1 - i); - } - return dst; - } - - public static byte[] littleEndianDecimal(BigDecimal decimal) { - BigInteger unscaled = decimal.unscaledValue(); - byte[] bigEndianBytes = unscaled.toByteArray(); - - // Determine target size (1, 2, 4, 8, 16, or 32 bytes) - int targetSize; - if (bigEndianBytes.length <= 1) { - targetSize = 1; - } else if (bigEndianBytes.length <= 2) { - targetSize = 2; - } else if (bigEndianBytes.length <= 4) { - targetSize = 4; - } else if (bigEndianBytes.length <= 8) { - targetSize = 8; - } else if (bigEndianBytes.length <= 16) { - targetSize = 16; - } else if (bigEndianBytes.length <= 32) { - targetSize = 32; - } else { - throw new IllegalArgumentException( - "BigDecimal with " + bigEndianBytes.length + " bytes overflows maximum Vortex decimal size"); - } - - byte[] result = new byte[targetSize]; - - // Copy 
bytes in reverse order (big endian to little endian) - for (int i = 0; i < bigEndianBytes.length; i++) { - result[i] = bigEndianBytes[bigEndianBytes.length - 1 - i]; - } - - // Sign extend if negative - if (unscaled.signum() < 0) { - for (int i = bigEndianBytes.length; i < targetSize; i++) { - result[i] = (byte) 0xFF; - } - } - - return result; - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Expressions.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Expressions.java deleted file mode 100644 index 108ea35976fa..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Expressions.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.proto; - -import com.google.protobuf.ByteString; -import dev.vortex.api.Expression; -import dev.vortex.api.expressions.*; -import dev.vortex.proto.ExprProtos; -import java.util.List; -import java.util.stream.Collectors; - -/** Serialize/deserialize Vortex expressions to/from protocol buffers. 
*/ -public final class Expressions { - public static ExprProtos.Expr serialize(Expression expression) { - ByteString metadata = ByteString.copyFrom(expression - .metadata() - .orElseThrow(() -> new IllegalArgumentException("Expression is not serializable: " + expression.id()))); - - return ExprProtos.Expr.newBuilder() - .setId(expression.id()) - .addAllChildren(expression.children().stream() - .map(Expressions::serialize) - .collect(Collectors.toList())) - .setMetadata(metadata) - .build(); - } - - public static Expression deserialize(ExprProtos.Expr expr) { - byte[] metadata = expr.getMetadata().toByteArray(); - List children = - expr.getChildrenList().stream().map(Expressions::deserialize).collect(Collectors.toList()); - - switch (expr.getId()) { - case "vortex.binary": - return Binary.parse(metadata, children); - case "vortex.get_item": - return GetItem.parse(metadata, children); - case "vortex.root": - return Root.parse(metadata, children); - case "vortex.literal": - return Literal.parse(metadata, children); - case "vortex.not": - return Not.parse(metadata, children); - case "vortex.is_null": - return IsNull.parse(metadata, children); - case "vortex.is_not_null": - return IsNotNull.parse(metadata, children); - default: - return new Unknown(expr.getId(), children, expr.getMetadata().toByteArray()); - } - } - - private Expressions() {} -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Scalars.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Scalars.java deleted file mode 100644 index bc8df78b731d..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/Scalars.java +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.proto; - -import com.google.protobuf.ByteString; -import com.google.protobuf.NullValue; -import dev.vortex.proto.ScalarProtos; -import java.math.BigDecimal; -import java.util.Optional; - -/** Factory class for creating Vortex scalar values with their associated data types. */ -public final class Scalars { - private Scalars() {} - - public static ScalarProtos.Scalar nullNull() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.nullType()) - .build(); - } - - public static ScalarProtos.Scalar bool(boolean value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setBoolValue(value) - .build()) - .setDtype(DTypes.bool(false)) - .build(); - } - - public static ScalarProtos.Scalar nullBool() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.bool(true)) - .build(); - } - - public static ScalarProtos.Scalar int8(byte value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.int8(false)) - .build(); - } - - public static ScalarProtos.Scalar nullInt8() { - return ScalarProtos.Scalar.newBuilder() - 
.setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.int8(true)) - .build(); - } - - public static ScalarProtos.Scalar int16(short value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.int16(false)) - .build(); - } - - public static ScalarProtos.Scalar nullInt16() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.int16(true)) - .build(); - } - - public static ScalarProtos.Scalar int32(int value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.int32(false)) - .build(); - } - - public static ScalarProtos.Scalar nullInt32() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.int32(true)) - .build(); - } - - public static ScalarProtos.Scalar int64(long value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.int64(false)) - .build(); - } - - public static ScalarProtos.Scalar nullInt64() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.int64(true)) - .build(); - } - - public static ScalarProtos.Scalar float32(float value) { - return ScalarProtos.Scalar.newBuilder() - .setValue( - ScalarProtos.ScalarValue.newBuilder().setF32Value(value).build()) - .setDtype(DTypes.float32(false)) - .build(); - } - - public static ScalarProtos.Scalar nullFloat32() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - 
.build()) - .setDtype(DTypes.float32(true)) - .build(); - } - - public static ScalarProtos.Scalar float64(double value) { - return ScalarProtos.Scalar.newBuilder() - .setValue( - ScalarProtos.ScalarValue.newBuilder().setF64Value(value).build()) - .setDtype(DTypes.float64(false)) - .build(); - } - - public static ScalarProtos.Scalar nullFloat64() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.float64(true)) - .build(); - } - - public static ScalarProtos.Scalar string(String value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setStringValue(value) - .build()) - .setDtype(DTypes.string(false)) - .build(); - } - - public static ScalarProtos.Scalar nullString() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.string(true)) - .build(); - } - - public static ScalarProtos.Scalar decimal(BigDecimal decimal, int precision, int scale) { - byte[] littleEndian = EndianUtils.littleEndianDecimal(decimal); - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setBytesValue(ByteString.copyFrom(littleEndian)) - .build()) - .setDtype(DTypes.decimal(false, precision, scale)) - .build(); - } - - public static ScalarProtos.Scalar nullDecimal(int precision, int scale) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder().setNullValue(NullValue.NULL_VALUE)) - .setDtype(DTypes.decimal(true, precision, scale)) - .build(); - } - - public static ScalarProtos.Scalar bytes(byte[] value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setBytesValue(ByteString.copyFrom(value)) - .build()) - .setDtype(DTypes.binary(false)) - .build(); - } - - public static ScalarProtos.Scalar nullBytes() { - 
return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.binary(true)) - .build(); - } - - public static ScalarProtos.Scalar dateDays(int value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.dateDays(false)) - .build(); - } - - public static ScalarProtos.Scalar nullDateDays() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.dateDays(true)) - .build(); - } - - public static ScalarProtos.Scalar dateMillis(long value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.dateMillis(false)) - .build(); - } - - public static ScalarProtos.Scalar nullDateMillis() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.dateMillis(true)) - .build(); - } - - public static ScalarProtos.Scalar timeSeconds(int value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.timeSeconds(false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimeSeconds() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timeSeconds(true)) - .build(); - } - - public static ScalarProtos.Scalar timeMillis(int value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.timeMillis(false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimeMillis() { - return 
ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timeMillis(true)) - .build(); - } - - public static ScalarProtos.Scalar timeMicros(long value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.timeMicros(false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimeMicros() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timeMicros(true)) - .build(); - } - - public static ScalarProtos.Scalar timeNanos(long value) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.timeNanos(false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimeNanos() { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timeNanos(true)) - .build(); - } - - public static ScalarProtos.Scalar timestampMillis(long value, Optional timeZone) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.timestampMillis(timeZone, false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimestampMillis(Optional timeZone) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timestampMillis(timeZone, true)) - .build(); - } - - public static ScalarProtos.Scalar timestampMicros(long value, Optional timeZone) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - 
.setDtype(DTypes.timestampMicros(timeZone, false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimestampMicros(Optional timeZone) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timestampMicros(timeZone, true)) - .build(); - } - - public static ScalarProtos.Scalar timestampNanos(long value, Optional timeZone) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setInt64Value(value) - .build()) - .setDtype(DTypes.timestampNanos(timeZone, false)) - .build(); - } - - public static ScalarProtos.Scalar nullTimestampNanos(Optional timeZone) { - return ScalarProtos.Scalar.newBuilder() - .setValue(ScalarProtos.ScalarValue.newBuilder() - .setNullValue(NullValue.NULL_VALUE) - .build()) - .setDtype(DTypes.timestampNanos(timeZone, true)) - .build(); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/TemporalMetadatas.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/TemporalMetadatas.java deleted file mode 100644 index fb76419229eb..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/proto/TemporalMetadatas.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.api.proto; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import org.apache.paimon.shade.guava30.com.google.common.base.Supplier; -import org.apache.paimon.shade.guava30.com.google.common.base.Suppliers; -import java.io.ByteArrayOutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Optional; - -/** Utility class for creating and parsing temporal metadata in Vortex protocol buffers. */ -public final class TemporalMetadatas { - private TemporalMetadatas() {} - - /** Time unit constant representing nanoseconds precision. */ - public static byte TIME_UNIT_NANOS = 0; - /** Time unit constant representing microseconds precision. */ - public static byte TIME_UNIT_MICROS = 1; - /** Time unit constant representing milliseconds precision. */ - public static byte TIME_UNIT_MILLIS = 2; - /** Time unit constant representing seconds precision. */ - public static byte TIME_UNIT_SECONDS = 3; - /** Time unit constant representing days precision. */ - public static byte TIME_UNIT_DAYS = 4; - - /** Supplier for date metadata with days precision. */ - public static final Supplier DATE_DAYS = Suppliers.memoize(() -> new byte[] {TIME_UNIT_DAYS}); - /** Supplier for date metadata with milliseconds precision. */ - public static final Supplier DATE_MILLIS = Suppliers.memoize(() -> new byte[] {TIME_UNIT_MILLIS}); - /** Supplier for time metadata with seconds precision. 
*/ - public static final Supplier TIME_SECONDS = Suppliers.memoize(() -> new byte[] {TIME_UNIT_SECONDS}); - /** Supplier for time metadata with milliseconds precision. */ - public static final Supplier TIME_MILLIS = Suppliers.memoize(() -> new byte[] {TIME_UNIT_MILLIS}); - /** Supplier for time metadata with microseconds precision. */ - public static final Supplier TIME_MICROS = Suppliers.memoize(() -> new byte[] {TIME_UNIT_MICROS}); - /** Supplier for time metadata with nanoseconds precision. */ - public static final Supplier TIME_NANOS = Suppliers.memoize(() -> new byte[] {TIME_UNIT_NANOS}); - - public static byte[] timestamp(byte timeUnit, Optional timeZone) { - Preconditions.checkArgument( - timeUnit >= TIME_UNIT_NANOS && timeUnit < TIME_UNIT_DAYS, "invalid timeUnit for Timestamp:" + timeUnit); - - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - baos.write(timeUnit); - if (timeZone.isPresent()) { - byte[] timeZoneBytes = timeZone.get().getBytes(StandardCharsets.UTF_8); - // Write length as little-endian uint16. 
- int lenLow = timeZoneBytes.length & 0xFF; - int lenHigh = (timeZoneBytes.length >> 8) & 0xFF; - baos.write(lenLow); - baos.write(lenHigh); - baos.write(timeZoneBytes, 0, timeZoneBytes.length); - } else { - // write uint16 zero value - baos.write(0); - baos.write(0); - } - return baos.toByteArray(); - } - - public static byte getTimeUnit(byte[] serializedMetadata) { - byte timeUnit = serializedMetadata[0]; - Preconditions.checkArgument( - timeUnit >= TIME_UNIT_NANOS && timeUnit <= TIME_UNIT_DAYS, "invalid timeUnit byte: " + timeUnit); - - return timeUnit; - } - - public static Optional getTimeZone(byte[] serializedMetadata) { - byte lenLow = serializedMetadata[1]; - byte lenHigh = serializedMetadata[2]; - int len = ((lenHigh & 0xFF) << 8) | (lenLow & 0xFF); - if (len == 0) { - return Optional.empty(); - } else { - byte[] timeZoneBytes = new byte[len]; - System.arraycopy(serializedMetadata, 3, timeZoneBytes, 0, len); - return Optional.of(new String(timeZoneBytes, StandardCharsets.UTF_8)); - } - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArray.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArray.java deleted file mode 100644 index e28f2fefa756..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArray.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.jni; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import dev.vortex.api.Array; -import dev.vortex.api.DType; -import java.math.BigDecimal; -import java.util.OptionalLong; -import org.apache.arrow.c.ArrowArray; -import org.apache.arrow.c.ArrowSchema; -import org.apache.arrow.c.CDataDictionaryProvider; -import org.apache.arrow.c.Data; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -/** JNI implementation of the Array interface. */ -public final class JNIArray implements Array { - static { - NativeLoader.loadJni(); - } - - private final ThreadLocal schemaPtr = ThreadLocal.withInitial(() -> new long[1]); - private final ThreadLocal arrayPtr = ThreadLocal.withInitial(() -> new long[1]); - - private OptionalLong pointer; - - public JNIArray(long pointer) { - Preconditions.checkArgument(pointer > 0, "Invalid pointer address: " + pointer); - this.pointer = OptionalLong.of(pointer); - } - - @Override - public long getLen() { - return NativeArrayMethods.getLen(pointer.getAsLong()); - } - - @Override - public long nbytes() { - return NativeArrayMethods.nbytes(pointer.getAsLong()); - } - - @Override - public VectorSchemaRoot exportToArrow(BufferAllocator allocator, VectorSchemaRoot reuse) { - // Export the dataset to Arrow over C Data Interface. 
- NativeArrayMethods.exportToArrow(pointer.getAsLong(), schemaPtr.get(), arrayPtr.get()); - try (ArrowSchema arrowSchema = ArrowSchema.wrap(schemaPtr.get()[0]); - ArrowArray arrowArray = ArrowArray.wrap(arrayPtr.get()[0]); - CDataDictionaryProvider provider = new CDataDictionaryProvider()) { - if (reuse != null) { - Data.importIntoVectorSchemaRoot(allocator, arrowArray, reuse, provider); - return reuse; - } else { - return Data.importVectorSchemaRoot(allocator, arrowArray, arrowSchema, new CDataDictionaryProvider()); - } - } finally { - NativeArrayMethods.dropArrowSchema(schemaPtr.get()[0]); - NativeArrayMethods.dropArrowArray(arrayPtr.get()[0]); - } - } - - @Override - public DType getDataType() { - return new JNIDType(NativeArrayMethods.getDataType(pointer.getAsLong())); - } - - @Override - public Array getField(int index) { - return new JNIArray(NativeArrayMethods.getField(pointer.getAsLong(), index)); - } - - @Override - public Array slice(int start, int stop) { - return new JNIArray(NativeArrayMethods.slice(pointer.getAsLong(), start, stop)); - } - - @Override - public boolean getNull(int index) { - return NativeArrayMethods.getNull(pointer.getAsLong(), index); - } - - @Override - public int getNullCount() { - return NativeArrayMethods.getNullCount(pointer.getAsLong()); - } - - @Override - public byte getByte(int index) { - return NativeArrayMethods.getByte(pointer.getAsLong(), index); - } - - @Override - public short getShort(int index) { - return NativeArrayMethods.getShort(pointer.getAsLong(), index); - } - - @Override - public int getInt(int index) { - return NativeArrayMethods.getInt(pointer.getAsLong(), index); - } - - @Override - public long getLong(int index) { - return NativeArrayMethods.getLong(pointer.getAsLong(), index); - } - - @Override - public boolean getBool(int index) { - return NativeArrayMethods.getBool(pointer.getAsLong(), index); - } - - @Override - public float getFloat(int index) { - return 
NativeArrayMethods.getFloat(pointer.getAsLong(), index); - } - - @Override - public double getDouble(int index) { - return NativeArrayMethods.getDouble(pointer.getAsLong(), index); - } - - @Override - public BigDecimal getBigDecimal(int index) { - return NativeArrayMethods.getBigDecimal(pointer.getAsLong(), index); - } - - @Override - public String getUTF8(int index) { - return NativeArrayMethods.getUTF8(pointer.getAsLong(), index); - } - - @Override - public void getUTF8_ptr_len(int index, long[] ptr, int[] len) { - NativeArrayMethods.getUTF8_ptr_len(pointer.getAsLong(), index, ptr, len); - } - - @Override - public byte[] getBinary(int index) { - return NativeArrayMethods.getBinary(pointer.getAsLong(), index); - } - - @Override - public void close() { - if (!pointer.isPresent()) { - return; - } - - NativeArrayMethods.free(pointer.getAsLong()); - pointer = OptionalLong.empty(); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArrayIterator.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArrayIterator.java deleted file mode 100644 index c27c2d648162..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIArrayIterator.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.jni; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import dev.vortex.api.Array; -import dev.vortex.api.ArrayIterator; -import dev.vortex.api.DType; -import java.util.Optional; -import java.util.OptionalLong; - -/** JNI implementation of the ArrayIterator interface. */ -public final class JNIArrayIterator implements ArrayIterator { - private OptionalLong pointer; - private Optional next; - - public JNIArrayIterator(long pointer) { - Preconditions.checkArgument(pointer > 0, "Invalid pointer address: " + pointer); - this.pointer = OptionalLong.of(pointer); - advance(); - } - - @Override - public boolean hasNext() { - return next.isPresent(); - } - - @Override - public Array next() { - Array array = this.next.get(); - advance(); - return array; - } - - @Override - public DType getDataType() { - return new JNIDType(NativeArrayIteratorMethods.getDType(pointer.getAsLong())); - } - - @Override - public void close() { - if (!pointer.isPresent()) { - return; - } - - NativeArrayIteratorMethods.free(pointer.getAsLong()); - pointer = OptionalLong.empty(); - next = Optional.empty(); - } - - private void advance() { - long next = NativeArrayIteratorMethods.take(pointer.getAsLong()); - if (next <= 0) { - this.next = Optional.empty(); - } else { - this.next = Optional.of(new JNIArray(next)); - } - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIDType.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIDType.java deleted file mode 100644 index 06e2a906641c..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIDType.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.jni; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; -import dev.vortex.api.DType; -import java.util.List; -import java.util.Optional; -import java.util.OptionalLong; - -/** JNI implementation of the DType interface. 
*/ -public final class JNIDType implements DType { - OptionalLong pointer; - final boolean isOwned; // True if this object owns the native memory - - public JNIDType(long pointer) { - this(pointer, false); - } - - public long getPointer() { - return pointer.getAsLong(); - } - - public JNIDType(long pointer, boolean isOwned) { - Preconditions.checkArgument(pointer > 0, "Invalid pointer address: " + pointer); - this.pointer = OptionalLong.of(pointer); - this.isOwned = isOwned; - } - - @Override - public Variant getVariant() { - return Variant.from(NativeDTypeMethods.getVariant(pointer.getAsLong())); - } - - @Override - public boolean isNullable() { - return NativeDTypeMethods.isNullable(pointer.getAsLong()); - } - - @Override - public List getFieldNames() { - return NativeDTypeMethods.getFieldNames(pointer.getAsLong()); - } - - @Override - public List getFieldTypes() { - return Lists.transform(NativeDTypeMethods.getFieldTypes(pointer.getAsLong()), JNIDType::new); - } - - @Override - public DType getElementType() { - // Returns a borrowed reference - the parent DType owns this memory - return new JNIDType(NativeDTypeMethods.getElementType(pointer.getAsLong())); - } - - @Override - public int getFixedSizeListSize() { - return NativeDTypeMethods.getFixedSizeListSize(pointer.getAsLong()); - } - - @Override - public boolean isDate() { - return NativeDTypeMethods.isDate(pointer.getAsLong()); - } - - @Override - public boolean isTime() { - return NativeDTypeMethods.isTime(pointer.getAsLong()); - } - - @Override - public boolean isTimestamp() { - return NativeDTypeMethods.isTimestamp(pointer.getAsLong()); - } - - @Override - public TimeUnit getTimeUnit() { - return TimeUnit.from(NativeDTypeMethods.getTimeUnit(pointer.getAsLong())); - } - - @Override - public Optional getTimeZone() { - return Optional.ofNullable(NativeDTypeMethods.getTimeZone(pointer.getAsLong())); - } - - @Override - public boolean isDecimal() { - return NativeDTypeMethods.isDecimal(pointer.getAsLong()); - } 
- - @Override - public int getPrecision() { - return NativeDTypeMethods.getDecimalPrecision(pointer.getAsLong()); - } - - @Override - public byte getScale() { - return NativeDTypeMethods.getDecimalScale(pointer.getAsLong()); - } - - @Override - public void close() { - if (isOwned && pointer.isPresent()) { - NativeDTypeMethods.free(pointer.getAsLong()); - pointer = OptionalLong.empty(); - } - } - - public static JNIDType ownedDType(long pointer) { - return new JNIDType(pointer, true); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIFile.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIFile.java deleted file mode 100644 index 8823e4e46aaa..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIFile.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package dev.vortex.jni; - -import org.apache.paimon.shade.guava30.com.google.common.base.Preconditions; -import dev.vortex.api.ArrayIterator; -import dev.vortex.api.DType; -import dev.vortex.api.File; -import dev.vortex.api.ScanOptions; -import dev.vortex.api.proto.Expressions; -import java.util.OptionalLong; - -/** JNI implementation of the File interface. */ -public final class JNIFile implements File { - private OptionalLong pointer; - - public JNIFile(long pointer) { - Preconditions.checkArgument(pointer > 0, "Invalid pointer address: " + pointer); - this.pointer = OptionalLong.of(pointer); - } - - @Override - public DType getDType() { - return new JNIDType(NativeFileMethods.dtype(pointer.getAsLong())); - } - - @Override - public long rowCount() { - return NativeFileMethods.rowCount(pointer.getAsLong()); - } - - @Override - public ArrayIterator newScan(ScanOptions options) { - byte[] predicateProto = null; - - if (options.predicate().isPresent()) { - predicateProto = Expressions.serialize(options.predicate().get()).toByteArray(); - } - - long[] rowIndices = options.rowIndices().orElse(null); - long[] rowRange = options.rowRange().orElse(null); - - return new JNIArrayIterator( - NativeFileMethods.scan(pointer.getAsLong(), options.columns(), predicateProto, rowRange, rowIndices)); - } - - @Override - public void close() { - if (!pointer.isPresent()) { - return; - } - NativeFileMethods.close(pointer.getAsLong()); - pointer = OptionalLong.empty(); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIWriter.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIWriter.java deleted file mode 100644 index c3345a9e3ee0..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/JNIWriter.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.jni; - -import dev.vortex.api.VortexWriter; -import java.io.IOException; -import java.util.OptionalLong; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** JNI implementation of VortexWriter. */ -public final class JNIWriter implements VortexWriter, AutoCloseable { - private static final Logger logger = LoggerFactory.getLogger(JNIWriter.class); - - private OptionalLong ptr; - - public JNIWriter(long ptr) { - this.ptr = OptionalLong.of(ptr); - logger.debug("Created JNIWriter with ptr={}", ptr); - } - - @Override - public void writeBatch(byte[] arrowData) throws IOException { - logger.trace("Writing batch with {} bytes", arrowData.length); - - // Write the Arrow data to Vortex through JNI - boolean success = NativeWriterMethods.writeBatch(ptr.getAsLong(), arrowData); - if (!success) { - logger.error("Failed to write batch to Vortex file"); - throw new IOException("Failed to write batch to Vortex file"); - } - } - - @Override - public void writeBatchFfi(long arrowArrayAddr, long arrowSchemaAddr) throws IOException { - logger.trace("Writing batch via FFI (arrayAddr={}, schemaAddr={})", arrowArrayAddr, arrowSchemaAddr); - - boolean success = NativeWriterMethods.writeBatchFfi(ptr.getAsLong(), arrowArrayAddr, arrowSchemaAddr); - if (!success) { 
- logger.error("Failed to write FFI batch to Vortex file"); - throw new IOException("Failed to write FFI batch to Vortex file"); - } - } - - @Override - public void close() { - if (!this.ptr.isPresent()) { - logger.debug("Attempted to close already closed JNIWriter, skipping"); - return; - } - - long ptr = this.ptr.getAsLong(); - - logger.debug("Closing JNIWriter with ptr={}", ptr); - NativeWriterMethods.close(ptr); - this.ptr = OptionalLong.empty(); - } -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayMethods.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayMethods.java deleted file mode 100644 index f9d693cc890c..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayMethods.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.jni; - -import java.math.BigDecimal; - -/** Native JNI methods for array operations. 
*/ -public final class NativeArrayMethods { - static { - NativeLoader.loadJni(); - } - - private NativeArrayMethods() {} - - public static native long nbytes(long pointer); - - public static native void exportToArrow(long pointer, long[] schemaPointer, long[] arrayPointer); - - public static native void dropArrowSchema(long arrowSchemaPtr); - - public static native void dropArrowArray(long arrowArrayPtr); - - public static native void free(long pointer); - - public static native long getLen(long pointer); - - public static native long getDataType(long pointer); - - public static native long getField(long pointer, int index); - - public static native long slice(long pointer, int start, int stop); - - public static native boolean getNull(long pointer, int index); - - public static native int getNullCount(long pointer); - - public static native byte getByte(long pointer, int index); - - public static native short getShort(long pointer, int index); - - public static native int getInt(long pointer, int index); - - public static native long getLong(long pointer, int index); - - public static native boolean getBool(long pointer, int index); - - public static native float getFloat(long pointer, int index); - - public static native double getDouble(long pointer, int index); - - public static native BigDecimal getBigDecimal(long pointer, int index); - - public static native String getUTF8(long pointer, int index); - - public static native void getUTF8_ptr_len(long pointer, int index, long[] outPtr, int[] outLen); - - public static native byte[] getBinary(long pointer, int index); -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDTypeMethods.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDTypeMethods.java deleted file mode 100644 index 1126310498f1..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDTypeMethods.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache 
Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dev.vortex.jni; - -import java.util.List; - -/** Native JNI methods for DType operations. */ -public final class NativeDTypeMethods { - static { - NativeLoader.loadJni(); - } - - private NativeDTypeMethods() {} - - public static native long newByte(boolean isNullable); - - public static native long newShort(boolean isNullable); - - public static native long newInt(boolean isNullable); - - public static native long newLong(boolean isNullable); - - public static native long newFloat(boolean isNullable); - - public static native long newDouble(boolean isNullable); - - public static native long newDecimal(int precision, int scale, boolean isNullable); - - public static native long newUtf8(boolean isNullable); - - public static native long newBinary(boolean isNullable); - - public static native long newBool(boolean isNullable); - - public static native long newList(long elementTypePtr, boolean isNullable); - - public static native long newFixedSizeList(long elementTypePtr, int size, boolean isNullable); - - public static native long newStruct(String[] fieldNames, long[] fieldTypes, boolean isNullable); - - public static native long newTimestamp(byte timeUnit, String zone, boolean 
isNullable); - - public static native long newDate(byte timeUnit, boolean isNullable); - - public static native long newTime(byte timeUnit, boolean isNullable); - - public static native void free(long pointer); - - public static native byte getVariant(long pointer); - - public static native boolean isNullable(long pointer); - - public static native List getFieldNames(long pointer); - - // Returns a list of DType pointers. - public static native List getFieldTypes(long pointer); - - public static native long getElementType(long pointer); - - public static native int getFixedSizeListSize(long pointer); - - public static native boolean isDate(long pointer); - - public static native boolean isTime(long pointer); - - public static native boolean isTimestamp(long pointer); - - public static native byte getTimeUnit(long pointer); - - public static native String getTimeZone(long pointer); - - public static native boolean isDecimal(long pointer); - - public static native int getDecimalPrecision(long pointer); - - public static native byte getDecimalScale(long pointer); -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeFileMethods.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDataSource.java similarity index 56% rename from paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeFileMethods.java rename to paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDataSource.java index 07b4c6224146..e0c93d8fc966 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeFileMethods.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeDataSource.java @@ -18,29 +18,23 @@ package dev.vortex.jni; -import java.util.List; import java.util.Map; -/** Native JNI methods for file operations. */ -public final class NativeFileMethods { +/** Native methods for Vortex data source operations. 
*/ +public final class NativeDataSource { + static { NativeLoader.loadJni(); } - private NativeFileMethods() {} - - public static native List listVortexFiles(String uri, Map options); - - public static native void delete(String[] uris, Map options); - - public static native long open(String uri, Map options); + private NativeDataSource() {} - public static native long rowCount(long pointer); + public static native long open(long sessionPtr, String[] uris, Map options); - public static native long dtype(long pointer); + public static native void free(long dataSourcePtr); - public static native void close(long pointer); + public static native void arrowSchema(long dataSourcePtr, long arrowSchemaOutAddr); - public static native long scan( - long pointer, List columns, byte[] predicateProto, long[] rowRange, long[] rowIndices); + /** Result: out[0] = row count value, out[1] = type (1=estimate, 2=exact, other=unknown). */ + public static native void rowCount(long dataSourcePtr, long[] resultOut); } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java new file mode 100644 index 000000000000..a1377b7890f1 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.vortex.jni; + +/** Native methods for Vortex expression construction. */ +public final class NativeExpression { + + static { + NativeLoader.loadJni(); + } + + private NativeExpression() {} + + public static native long root(); + + public static native long getItem(String path, long parentPtr); + + public static native long select(String[] columns, long parentPtr); + + public static native long and(long[] exprPtrs); + + public static native long or(long[] exprPtrs); + + public static native long binary(byte opCode, long leftPtr, long rightPtr); + + public static native long not(long exprPtr); + + public static native long isNull(long exprPtr); + + public static native long isNotNull(long exprPtr); + + public static native long like( + long exprPtr, long patternPtr, boolean caseSensitive, boolean negated); + + public static native long between( + long exprPtr, long lowPtr, long highPtr, boolean lowInclusive, boolean highInclusive); + + public static native long literalBool(boolean value, boolean isNull); + + public static native long literalI8(byte value, boolean isNull); + + public static native long literalI16(short value, boolean isNull); + + public static native long literalI32(int value, boolean isNull); + + public static native long literalI64(long value, boolean isNull); + + public static native long literalF32(float value, boolean isNull); + + public static native long literalF64(double value, boolean isNull); + + public static native long literalString(String value); + + public static native long literalBinary(byte[] value); 
+ + public static native long literalDecimal( + byte[] bigIntBytes, int precision, int scale, boolean isNull); + + public static native long literalDate(long value, byte timeUnitTag, boolean isNull); + + public static native long literalTimestamp( + long value, byte timeUnitTag, String timezone, boolean isNull); + + public static native long literalNull(byte dTypeTag); + + public static native void free(long exprPtr); +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ArrayIterator.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeFiles.java similarity index 62% rename from paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ArrayIterator.java rename to paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeFiles.java index 41df75e53a5a..fb2611c9d837 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/ArrayIterator.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeFiles.java @@ -16,14 +16,23 @@ * limitations under the License. */ -package dev.vortex.api; +package dev.vortex.jni; -import java.util.Iterator; +import java.util.List; +import java.util.Map; -/** An iterator over Vortex arrays with type information and resource management. */ -public interface ArrayIterator extends AutoCloseable, Iterator { - DType getDataType(); +/** Native methods for Vortex file listing and deletion. 
*/ +public final class NativeFiles { - @Override - void close(); + static { + NativeLoader.loadJni(); + } + + private NativeFiles() {} + + private static native List listFiles( + long sessionPtr, String uri, Map options); + + private static native void delete( + long sessionPtr, String[] uris, Map options); } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLoader.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLoader.java index 32a9311f2455..71e5dbb9ab6e 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLoader.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLoader.java @@ -33,6 +33,15 @@ public static synchronized void loadJni() { return; } + // Use Unsafe allocator for Arrow memory to guarantee buffer alignment. + // Vortex's Rust FFI requires buffers aligned to the scalar type's natural + // alignment (e.g. 16 bytes for Decimal128). Netty's pooled allocator may + // return sub-aligned addresses. Unsafe.allocateMemory provides at least + // 16-byte alignment on 64-bit systems. 
+ if (System.getProperty("arrow.allocation.manager.type") == null) { + System.setProperty("arrow.allocation.manager.type", "Unsafe"); + } + // Load the native library String osName = System.getProperty("os.name").toLowerCase(Locale.ROOT); String osArch = System.getProperty("os.arch").toLowerCase(Locale.ROOT); diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLogging.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLogging.java index c69989295da4..cd5ce3a47491 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLogging.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeLogging.java @@ -27,19 +27,15 @@ public final class NativeLogging { private NativeLogging() {} /** Logging level constant for error messages only */ - public static final int ERROR = 0; + public static final int ERROR = 1; - /** Logging level constant for warning and error messages */ - public static final int WARN = 1; + public static final int WARN = 2; - /** Logging level constant for informational, warning, and error messages */ - public static final int INFO = 2; + public static final int INFO = 3; - /** Logging level constant for debug, informational, warning, and error messages */ - public static final int DEBUG = 3; + public static final int DEBUG = 4; - /** Logging level constant for all messages including trace-level debugging */ - public static final int TRACE = 4; + public static final int TRACE = 5; public static native void initLogging(int level); } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativePartition.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativePartition.java new file mode 100644 index 000000000000..0b54de553971 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativePartition.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.vortex.jni; + +/** Native methods for Vortex partition operations. */ +public final class NativePartition { + + static { + NativeLoader.loadJni(); + } + + private NativePartition() {} + + public static native void free(long partitionPtr); + + /** Result: out[0] = row count value, out[1] = has value (0=unknown, nonzero=known). */ + public static native void rowCount(long partitionPtr, long[] resultOut); + + public static native void scanArrow( + long sessionPtr, long partitionPtr, long arrowArrayStreamAddr); +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayIteratorMethods.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeRuntime.java similarity index 74% rename from paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayIteratorMethods.java rename to paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeRuntime.java index 08c6fc7dd248..e1a4e46b75a8 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeArrayIteratorMethods.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeRuntime.java @@ -18,17 +18,18 @@ package dev.vortex.jni; -/** Native JNI methods for array iterator operations. 
*/ -public final class NativeArrayIteratorMethods { +/** Native methods for Vortex runtime configuration. */ +public final class NativeRuntime { + static { NativeLoader.loadJni(); } - private NativeArrayIteratorMethods() {} + private NativeRuntime() {} - public static native void free(long pointer); + public static native void setWorkerThreads(int count); - public static native long take(long pointer); + public static native void setWorkerThreadsToAvailableParallelism(); - public static native long getDType(long pointer); + public static native int workerCount(); } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeScan.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeScan.java new file mode 100644 index 000000000000..1ee835398632 --- /dev/null +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeScan.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.vortex.jni; + +/** Native methods for Vortex scan operations. 
*/ +public final class NativeScan { + + static { + NativeLoader.loadJni(); + } + + private NativeScan() {} + + public static native long create( + long dataSourcePtr, + long projectionExprPtr, + long filterExprPtr, + long rowRangeBegin, + long rowRangeEnd, + long[] selectionIndices, + byte selectionModeCode, + long limit, + boolean ordered); + + public static native void free(long scanPtr); + + public static native void arrowSchema(long scanPtr, long arrowSchemaOutAddr); + + public static native void partitionCount(long scanPtr, long[] resultOut); + + /** Returns the next partition pointer, or 0 when exhausted. */ + public static native long nextPartition(long scanPtr); +} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/File.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeSession.java similarity index 73% rename from paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/File.java rename to paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeSession.java index beeb8f754822..bca55dac45ca 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/api/File.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeSession.java @@ -16,16 +16,18 @@ * limitations under the License. */ -package dev.vortex.api; +package dev.vortex.jni; -/** Interface for reading Vortex format files. */ -public interface File extends AutoCloseable { - DType getDType(); +/** Native methods for Vortex session lifecycle. 
*/ +public final class NativeSession { - long rowCount(); + static { + NativeLoader.loadJni(); + } - ArrayIterator newScan(ScanOptions options); + private NativeSession() {} - @Override - void close(); + public static native long newSession(); + + public static native void free(long sessionPtr); } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeWriterMethods.java b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeWriter.java similarity index 70% rename from paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeWriterMethods.java rename to paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeWriter.java index 37a2bffc24ca..dca107b4075a 100644 --- a/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeWriterMethods.java +++ b/paimon-vortex/paimon-vortex-jni/src/main/java/dev/vortex/jni/NativeWriter.java @@ -20,20 +20,20 @@ import java.util.Map; -/** Native JNI methods for writing Vortex files. */ -public final class NativeWriterMethods { +/** Native methods for Vortex writer operations. 
*/ +public final class NativeWriter { static { NativeLoader.loadJni(); } - private NativeWriterMethods() {} + private NativeWriter() {} - public static native long create(String uri, long dtype, Map options); + public static native long create( + long sessionPtr, String uri, long arrowSchemaAddr, Map options); - public static native boolean writeBatch(long writerPtr, byte[] arrowData); - - public static native boolean writeBatchFfi(long writerPtr, long arrowArrayAddr, long arrowSchemaAddr); + public static native boolean writeBatch( + long writerPtr, long arrowArrayAddr, long arrowSchemaAddr); public static native void close(long writerPtr); } diff --git a/paimon-vortex/paimon-vortex-jni/src/main/proto/dtype.proto b/paimon-vortex/paimon-vortex-jni/src/main/proto/dtype.proto deleted file mode 100644 index 3a735636cf40..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/proto/dtype.proto +++ /dev/null @@ -1,101 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -syntax = "proto3"; - -package vortex.dtype; - -option java_package = "dev.vortex.proto"; -option java_outer_classname = "DTypeProtos"; - -enum PType { - U8 = 0; - U16 = 1; - U32 = 2; - U64 = 3; - I8 = 4; - I16 = 5; - I32 = 6; - I64 = 7; - F16 = 8; - F32 = 9; - F64 = 10; -} - -message Null {} - -message Bool { - bool nullable = 1; -} - -message Primitive { - PType type = 1; - bool nullable = 2; -} - -message Decimal { - uint32 precision = 1; - int32 scale = 2; - bool nullable = 3; -} - -message Utf8 { - bool nullable = 1; -} - -message Binary { - bool nullable = 1; -} - -message Struct { - repeated string names = 1; - repeated DType dtypes = 2; - bool nullable = 3; -} - -message List { - DType element_type = 1; - bool nullable = 2; -} - -message FixedSizeList { - DType element_type = 1; - uint32 size = 2; - bool nullable = 3; -} - -message Extension { - string id = 1; - DType storage_dtype = 2; - optional bytes metadata = 3; -} - -message 
Variant { - bool nullable = 1; -} - -message DType { - oneof dtype_type { - Null null = 1; - Bool bool = 2; - Primitive primitive = 3; - Decimal decimal = 4; - Utf8 utf8 = 5; - Binary binary = 6; - Struct struct = 7; - List list = 8; - Extension extension = 9; - FixedSizeList fixed_size_list = 10; // This is after `Extension` for backwards compatibility. - Variant variant = 11; - } -} - -message Field { - oneof field_type { - string name = 1; - } -} - -message FieldPath { - repeated Field path = 1; -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/proto/expr.proto b/paimon-vortex/paimon-vortex-jni/src/main/proto/expr.proto deleted file mode 100644 index 73ba7209a159..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/proto/expr.proto +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -syntax = "proto3"; - -package vortex.expr; - -import "scalar.proto"; -import "dtype.proto"; - -option java_package = "dev.vortex.proto"; -option java_outer_classname = "ExprProtos"; - -// Captures a generic representation of expressions in Vortex. -// Expression deserializers can be registered with a Vortex session to handle parsing this into -// an in-memory expression for execution. -message Expr { - string id = 1; - repeated Expr children = 2; - optional bytes metadata = 3; -} - -// Captures a serialized aggregate function with its ID and options metadata. 
-message AggregateFn { - string id = 1; - optional bytes metadata = 2; -} - -// Options for `vortex.literal` -message LiteralOpts { - vortex.scalar.Scalar value = 1; -} - -// Options for `vortex.pack` -message PackOpts { - repeated string paths = 1; - bool nullable = 2; -} - -// Options for `vortex.getitem` -message GetItemOpts { - string path = 1; -} - -// Options for `vortex.binary` -message BinaryOpts { - BinaryOp op = 1; - - enum BinaryOp { - Eq = 0; - NotEq = 1; - Gt = 2; - Gte = 3; - Lt = 4; - Lte = 5; - And = 6; - Or = 7; - Add = 8; - Sub = 9; - Mul = 10; - Div = 11; - } -} - -message BetweenOpts { - bool lower_strict = 1; - bool upper_strict = 2; -} - -message LikeOpts { - bool negated = 1; - bool case_insensitive = 2; -} - -message CastOpts { - vortex.dtype.DType target = 1; -} - -message FieldNames { - repeated string names = 1; -} - -message SelectOpts { - oneof opts { - FieldNames include = 1; - FieldNames exclude = 2; - } -} - -// Options for `vortex.case_when` -// Encodes num_when_then_pairs and has_else into a single u32 (num_children). -// num_children = num_when_then_pairs * 2 + (has_else ? 
1 : 0) -// has_else = num_children % 2 == 1 -// num_when_then_pairs = num_children / 2 -message CaseWhenOpts { - uint32 num_children = 1; -} diff --git a/paimon-vortex/paimon-vortex-jni/src/main/proto/scalar.proto b/paimon-vortex/paimon-vortex-jni/src/main/proto/scalar.proto deleted file mode 100644 index 251863dc3a3e..000000000000 --- a/paimon-vortex/paimon-vortex-jni/src/main/proto/scalar.proto +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -syntax = "proto3"; - -package vortex.scalar; - -option java_package = "dev.vortex.proto"; -option java_outer_classname = "ScalarProtos"; - -import "dtype.proto"; -import "google/protobuf/struct.proto"; - -message Scalar { - vortex.dtype.DType dtype = 1; - ScalarValue value = 2; -} - -message ScalarValue { - oneof kind { - google.protobuf.NullValue null_value = 1; - bool bool_value = 2; - sint64 int64_value = 3; - uint64 uint64_value = 4; - float f32_value = 5; - double f64_value = 6; - string string_value = 7; - bytes bytes_value = 8; - ListValue list_value = 9; - uint64 f16_value = 10; - // Variant scalars carry a row-specific nested scalar. - // See RFC 0015: https://github.com/vortex-data/rfcs/blob/develop/accepted/0015-variant-type.md - Scalar variant_value = 11; - } -} - -message ListValue { - repeated ScalarValue values = 1; -} diff --git a/pom.xml b/pom.xml index 5a336fb76c2d..d2c02ae4d60c 100644 --- a/pom.xml +++ b/pom.xml @@ -74,6 +74,7 @@ under the License. paimon-api paimon-lumina paimon-vortex + paimon-mosaic paimon-tantivy From df4d29ec735c01bd42cc52b585829acfcecc2fe1 Mon Sep 17 00:00:00 2001 From: jkukreja Date: Sat, 6 Jun 2026 07:29:15 -0400 Subject: [PATCH 4/4] Re-Trigger CI