diff --git a/amber/src/main/python/core/architecture/handlers/control/assign_port_handler.py b/amber/src/main/python/core/architecture/handlers/control/assign_port_handler.py index 73ebad26b31..71f1e7f96ba 100644 --- a/amber/src/main/python/core/architecture/handlers/control/assign_port_handler.py +++ b/amber/src/main/python/core/architecture/handlers/control/assign_port_handler.py @@ -44,10 +44,10 @@ async def assign_port(self, req: AssignPortRequest) -> EmptyReturn: channel_id=channel_id, port_id=req.port_id ) else: - storage_uri = None + storage_uri_base = None if len(req.storage_uris) > 0 and req.storage_uris[0]: - storage_uri = req.storage_uris[0] + storage_uri_base = req.storage_uris[0] self.context.output_manager.add_output_port( - req.port_id, Schema(raw_schema=req.schema), storage_uri + req.port_id, Schema(raw_schema=req.schema), storage_uri_base ) return EmptyReturn() diff --git a/amber/src/main/python/core/architecture/packaging/output_manager.py b/amber/src/main/python/core/architecture/packaging/output_manager.py index bf4afbf396f..b85e3e39bf1 100644 --- a/amber/src/main/python/core/architecture/packaging/output_manager.py +++ b/amber/src/main/python/core/architecture/packaging/output_manager.py @@ -45,6 +45,7 @@ from core.models.payload import DataPayload, DataFrame from core.models.state import State from core.storage.document_factory import DocumentFactory +from core.storage.vfs_uri_factory import VFSURIFactory from core.storage.runnables.port_storage_writer import ( PortStorageWriter, PortStorageWriterElement, @@ -87,6 +88,10 @@ def __init__(self, worker_id: str): PortIdentity, typing.Tuple[Queue, PortStorageWriter, Thread] ] = dict() + self._port_state_writers: typing.Dict[ + PortIdentity, typing.Tuple[Queue, PortStorageWriter, Thread] + ] = dict() + def is_missing_output_ports(self): """ This method is only used for ensuring correct region execution. @@ -107,26 +112,30 @@ def add_output_port( self, port_id: PortIdentity, schema: Schema, - storage_uri: typing.Optional[str] = None, + storage_uri_base: typing.Optional[str] = None, ) -> None: if port_id.id is None: port_id.id = 0 if port_id.internal is None: port_id.internal = False - if storage_uri is not None: - self.set_up_port_storage_writer(port_id, storage_uri) + if storage_uri_base is not None: + self.set_up_port_storage_writer(port_id, storage_uri_base) # each port can only be added and initialized once. if port_id not in self._ports: self._ports[port_id] = WorkerPort(schema) - def set_up_port_storage_writer(self, port_id: PortIdentity, storage_uri: str): + def set_up_port_storage_writer(self, port_id: PortIdentity, storage_uri_base: str): """ Create a separate thread for saving output tuples of a port - to storage in batch. + to storage in batch, and open a long-lived buffered writer for + state materialization on the same port. `storage_uri_base` is the + port's base URI; the result and state URIs are derived from it. """ - document, _ = DocumentFactory.open_document(storage_uri) + document, _ = DocumentFactory.open_document( + VFSURIFactory.result_uri(storage_uri_base) + ) buffered_item_writer = document.writer(str(get_worker_index(self.worker_id))) writer_queue = Queue() port_storage_writer = PortStorageWriter( @@ -144,6 +153,29 @@ def set_up_port_storage_writer(self, port_id: PortIdentity, storage_uri: str): writer_thread, ) + state_document, _ = DocumentFactory.open_document( + VFSURIFactory.state_uri(storage_uri_base) + ) + state_buffered_item_writer = state_document.writer( + str(get_worker_index(self.worker_id)) + ) + state_writer_queue = Queue() + state_port_writer = PortStorageWriter( + buffered_item_writer=state_buffered_item_writer, + queue=state_writer_queue, + ) + state_writer_thread = threading.Thread( + target=state_port_writer.run, + daemon=True, + name=f"port_state_writer_thread_{port_id}", + ) + state_writer_thread.start() + self._port_state_writers[port_id] = ( + state_writer_queue, + state_port_writer, + state_writer_thread, + ) + def get_port(self, port_id=None) -> WorkerPort: return list(self._ports.values())[0] @@ -171,6 +203,20 @@ def save_tuple_to_storage_if_needed(self, tuple_: Tuple, port_id=None) -> None: PortStorageWriterElement(data_tuple=tuple_) ) + def save_state_to_storage_if_needed(self, state: State, port_id=None) -> None: + # When port_id is omitted the same state row is fanned out to + # every output port's state table. This mirrors the + # broadcast-to-all-workers behavior on the emit side: state is + # shared context, not per-key data, so every downstream operator + # (and every worker reading the materialization) needs the full + # set. + element = PortStorageWriterElement(data_tuple=state.to_tuple()) + if port_id is None: + for writer_queue, _, _ in self._port_state_writers.values(): + writer_queue.put(element) + elif port_id in self._port_state_writers: + self._port_state_writers[port_id][0].put(element) + def close_port_storage_writers(self) -> None: """ Flush the buffers of port storage writers and wait for all the @@ -184,6 +230,11 @@ def close_port_storage_writers(self) -> None: for _, _, writer_thread in self._port_storage_writers.values(): # This blocking call will wait for all the writer to finish commit writer_thread.join() + for _, state_writer, _ in self._port_state_writers.values(): + state_writer.stop() + for _, _, state_writer_thread in self._port_state_writers.values(): + state_writer_thread.join() + self._port_state_writers.clear() def add_partitioning(self, tag: PhysicalLink, partitioning: Partitioning) -> None: """ diff --git a/amber/src/main/python/core/runnables/main_loop.py b/amber/src/main/python/core/runnables/main_loop.py index ab35cda81b9..1334af12bfe 100644 --- a/amber/src/main/python/core/runnables/main_loop.py +++ b/amber/src/main/python/core/runnables/main_loop.py @@ -202,6 +202,7 @@ def process_input_state(self) -> None: payload=batch, ) ) + self.context.output_manager.save_state_to_storage_if_needed(output_state) def process_tuple_with_udf(self) -> Iterator[Optional[Tuple]]: """ diff --git a/amber/src/main/python/core/storage/document_factory.py b/amber/src/main/python/core/storage/document_factory.py index 9b686ab66b6..bd690ceb592 100644 --- a/amber/src/main/python/core/storage/document_factory.py +++ b/amber/src/main/python/core/storage/document_factory.py @@ -61,30 +61,35 @@ def create_document(uri: str, schema: Schema) -> VirtualDocument: if parsed_uri.scheme == VFSURIFactory.VFS_FILE_URI_SCHEME: _, _, _, resource_type = VFSURIFactory.decode_uri(uri) - if resource_type in {VFSResourceType.RESULT}: - storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) - - # Convert Amber Schema to Iceberg Schema with LARGE_BINARY - # field name encoding - iceberg_schema = amber_schema_to_iceberg_schema(schema) - - create_table( - IcebergCatalogInstance.get_instance(), - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - iceberg_schema, - override_if_exists=True, - ) - - return IcebergDocument[Tuple]( - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - iceberg_schema, - amber_tuples_to_arrow_table, - arrow_table_to_amber_tuples, - ) - else: - raise ValueError(f"Resource type {resource_type} is not supported") + match resource_type: + case VFSResourceType.RESULT: + namespace = StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE + case VFSResourceType.STATE: + namespace = StorageConfig.ICEBERG_TABLE_STATE_NAMESPACE + case _: + raise ValueError(f"Resource type {resource_type} is not supported") + + storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) + # Convert Amber Schema to Iceberg Schema with LARGE_BINARY + # field name encoding + iceberg_schema = amber_schema_to_iceberg_schema(schema) + + create_table( + IcebergCatalogInstance.get_instance(), + namespace, + storage_key, + iceberg_schema, + override_if_exists=True, + ) + + return IcebergDocument[Tuple]( + namespace, + storage_key, + iceberg_schema, + amber_tuples_to_arrow_table, + arrow_table_to_amber_tuples, + ) + else: raise NotImplementedError( f"Unsupported URI scheme: {parsed_uri.scheme} for creating the document" @@ -96,30 +101,36 @@ def open_document(uri: str) -> typing.Tuple[VirtualDocument, Optional[Schema]]: if parsed_uri.scheme == "vfs": _, _, _, resource_type = VFSURIFactory.decode_uri(uri) - if resource_type in {VFSResourceType.RESULT}: - storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) - - table = load_table_metadata( - IcebergCatalogInstance.get_instance(), - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - ) - - if table is None: - raise ValueError("No storage is found for the given URI") - - amber_schema = Schema(table.schema().as_arrow()) - - document = IcebergDocument( - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - table.schema(), - amber_tuples_to_arrow_table, - arrow_table_to_amber_tuples, - ) - return document, amber_schema - else: - raise ValueError(f"Resource type {resource_type} is not supported") + match resource_type: + case VFSResourceType.RESULT: + namespace = StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE + case VFSResourceType.STATE: + namespace = StorageConfig.ICEBERG_TABLE_STATE_NAMESPACE + case _: + raise ValueError(f"Resource type {resource_type} is not supported") + + storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) + + table = load_table_metadata( + IcebergCatalogInstance.get_instance(), + namespace, + storage_key, + ) + + if table is None: + raise ValueError("No storage is found for the given URI") + + amber_schema = Schema(table.schema().as_arrow()) + + document = IcebergDocument( + namespace, + storage_key, + table.schema(), + amber_tuples_to_arrow_table, + arrow_table_to_amber_tuples, + ) + return document, amber_schema + else: raise NotImplementedError( f"Unsupported URI scheme: {parsed_uri.scheme} for opening the document" diff --git a/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py b/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py index 6122bbb8b98..3e0e2d48ab5 100644 --- a/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py +++ b/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py @@ -34,9 +34,10 @@ from core.architecture.sendsemantics.round_robin_partitioner import ( RoundRobinPartitioner, ) -from core.models import Tuple, InternalQueue, DataFrame, DataPayload +from core.models import Tuple, InternalQueue, DataFrame, DataPayload, State, StateFrame from core.models.internal_queue import DataElement, ECMElement from core.storage.document_factory import DocumentFactory +from core.storage.vfs_uri_factory import VFSURIFactory from core.util import Stoppable, get_one_of from core.util.runnable import Runnable from core.util.virtual_identity import get_from_actor_id_for_input_port_storage @@ -132,14 +133,28 @@ def run(self) -> None: emits an EndChannel ECM. Use the same partitioner implementation as that in output manager, where a tuple is batched by the partitioner and only selected as the input of this worker according to the partitioner. + + States and tuples are persisted to separate tables, so the original + interleaving is lost and replay has to pick an order: we replay states + first because downstream operators typically need their state set up + before they process the incoming tuples. Every state is broadcast to + every downstream worker -- no partitioner filtering, unlike the tuple + loop. State is shared context (e.g. config / counters), not per-key + data, so each worker needs the full set. """ try: self.materialization, self.tuple_schema = DocumentFactory.open_document( - self.uri + VFSURIFactory.result_uri(self.uri) ) self.emit_ecm("StartChannel", EmbeddedControlMessageType.NO_ALIGNMENT) - storage_iterator = self.materialization.get() + state_document, _ = DocumentFactory.open_document( + VFSURIFactory.state_uri(self.uri) + ) + for state_row in state_document.get(): + self.emit_payload(StateFrame(State.from_tuple(state_row))) + + storage_iterator = self.materialization.get() # Iterate and process tuples. for tup in storage_iterator: if self._stopped: diff --git a/amber/src/main/python/core/storage/storage_config.py b/amber/src/main/python/core/storage/storage_config.py index 0e47bdb71ae..82335909874 100644 --- a/amber/src/main/python/core/storage/storage_config.py +++ b/amber/src/main/python/core/storage/storage_config.py @@ -32,6 +32,7 @@ class StorageConfig: ICEBERG_REST_CATALOG_URI = None ICEBERG_REST_CATALOG_WAREHOUSE_NAME = None ICEBERG_TABLE_RESULT_NAMESPACE = None + ICEBERG_TABLE_STATE_NAMESPACE = None ICEBERG_FILE_STORAGE_DIRECTORY_PATH = None ICEBERG_TABLE_COMMIT_BATCH_SIZE = None @@ -51,6 +52,7 @@ def initialize( rest_catalog_uri, rest_catalog_warehouse_name, table_result_namespace, + table_state_namespace, directory_path, commit_batch_size, s3_endpoint, @@ -71,6 +73,7 @@ def initialize( cls.ICEBERG_REST_CATALOG_WAREHOUSE_NAME = rest_catalog_warehouse_name cls.ICEBERG_TABLE_RESULT_NAMESPACE = table_result_namespace + cls.ICEBERG_TABLE_STATE_NAMESPACE = table_state_namespace cls.ICEBERG_FILE_STORAGE_DIRECTORY_PATH = directory_path cls.ICEBERG_TABLE_COMMIT_BATCH_SIZE = int(commit_batch_size) diff --git a/amber/src/main/python/core/storage/vfs_uri_factory.py b/amber/src/main/python/core/storage/vfs_uri_factory.py index de0c5db56ec..883450abf2b 100644 --- a/amber/src/main/python/core/storage/vfs_uri_factory.py +++ b/amber/src/main/python/core/storage/vfs_uri_factory.py @@ -34,6 +34,7 @@ class VFSResourceType(str, Enum): RESULT = "result" RUNTIME_STATISTICS = "runtimeStatistics" CONSOLE_MESSAGES = "consoleMessages" + STATE = "state" class VFSURIFactory: @@ -88,12 +89,22 @@ def extract_value(key: str) -> str: ) @staticmethod - def create_result_uri(workflow_id, execution_id, global_port_id) -> str: - """Creates a URI pointing to a result storage.""" - base_uri = ( + def create_port_base_uri(workflow_id, execution_id, global_port_id) -> str: + """Base URI for a port. Result and state URIs derive from it via + `result_uri` / `state_uri`. + """ + return ( f"{VFSURIFactory.VFS_FILE_URI_SCHEME}:///wid/{workflow_id.id}" f"/eid/{execution_id.id}/globalportid/" f"{serialize_global_port_identity(global_port_id)}" ) + @staticmethod + def result_uri(base_uri: str) -> str: + """The result-resource URI under a port base URI.""" return f"{base_uri}/{VFSResourceType.RESULT.value}" + + @staticmethod + def state_uri(base_uri: str) -> str: + """The state-resource URI under a port base URI.""" + return f"{base_uri}/{VFSResourceType.STATE.value}" diff --git a/amber/src/main/python/texera_run_python_worker.py b/amber/src/main/python/texera_run_python_worker.py index 8687298f819..9b21fa53343 100644 --- a/amber/src/main/python/texera_run_python_worker.py +++ b/amber/src/main/python/texera_run_python_worker.py @@ -52,6 +52,7 @@ def init_loguru_logger(stream_log_level) -> None: iceberg_rest_catalog_uri, iceberg_rest_catalog_warehouse_name, iceberg_table_namespace, + iceberg_table_state_namespace, iceberg_file_storage_directory_path, iceberg_table_commit_batch_size, s3_endpoint, @@ -68,6 +69,7 @@ def init_loguru_logger(stream_log_level) -> None: iceberg_rest_catalog_uri, iceberg_rest_catalog_warehouse_name, iceberg_table_namespace, + iceberg_table_state_namespace, iceberg_file_storage_directory_path, iceberg_table_commit_batch_size, s3_endpoint, diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala index affbd786f9b..2862714ffb7 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala @@ -20,7 +20,7 @@ package org.apache.texera.amber.engine.architecture.messaginglayer import org.apache.texera.amber.core.state.State -import org.apache.texera.amber.core.storage.DocumentFactory +import org.apache.texera.amber.core.storage.{DocumentFactory, VFSURIFactory} import org.apache.texera.amber.core.storage.model.BufferedItemWriter import org.apache.texera.amber.core.tuple._ import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity} @@ -33,7 +33,7 @@ import org.apache.texera.amber.engine.architecture.messaginglayer.OutputManager. import org.apache.texera.amber.engine.architecture.sendsemantics.partitioners._ import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings._ import org.apache.texera.amber.engine.architecture.worker.managers.{ - OutputPortResultWriterThread, + OutputPortStorageWriterThread, PortStorageWriterTerminateSignal } import org.apache.texera.amber.engine.common.AmberLogging @@ -121,7 +121,10 @@ class OutputManager( mutable.HashMap[(PhysicalLink, ActorVirtualIdentity), NetworkOutputBuffer]() private val outputPortResultWriterThreads - : mutable.HashMap[PortIdentity, OutputPortResultWriterThread] = + : mutable.HashMap[PortIdentity, OutputPortStorageWriterThread] = + mutable.HashMap() + + private val stateWriterThreads: mutable.HashMap[PortIdentity, OutputPortStorageWriterThread] = mutable.HashMap() /** @@ -191,19 +194,20 @@ class OutputManager( def emitState(state: State): Unit = { networkOutputBuffers.foreach(kv => kv._2.sendState(state)) + saveStateToStorageIfNeeded(state) } - def addPort(portId: PortIdentity, schema: Schema, storageURIOption: Option[URI]): Unit = { + def addPort(portId: PortIdentity, schema: Schema, storageURIBaseOption: Option[URI]): Unit = { // each port can only be added and initialized once. if (this.ports.contains(portId)) { return } this.ports(portId) = WorkerPort(schema) - // if a storage URI is provided, set up a storage writer thread - storageURIOption match { - case Some(storageUri) => setupOutputStorageWriterThread(portId, storageUri) - case None => // No need to add a writer + // if a storage URI base is provided, set up storage writer threads + storageURIBaseOption match { + case Some(portBaseURI) => setupOutputStorageWriterThread(portId, portBaseURI) + case None => // No need to add a writer } } @@ -232,6 +236,15 @@ class OutputManager( }) } + private def saveStateToStorageIfNeeded(state: State): Unit = { + // The same state row is fanned out to every output port's state + // table. This mirrors the broadcast-to-all-workers behavior on the + // emit side: state is shared context, not per-key data, so every + // downstream operator (and every worker reading the materialization) + // needs the full set. + stateWriterThreads.values.foreach(_.queue.put(Left(state.toTuple))) + } + /** * Singal the port storage writer to flush the remaining buffer and wait for commits to finish so that * the output port is properly completed. If the output port does not need storage, no action will be done. @@ -251,7 +264,11 @@ class OutputManager( writerThread.getFailure.foreach(throw _) case None => } - + this.stateWriterThreads.remove(outputPortId).foreach { writerThread => + writerThread.queue.put(Right(PortStorageWriterTerminateSignal)) + writerThread.join() + writerThread.getFailure.foreach(throw _) + } } def getPort(portId: PortIdentity): WorkerPort = ports(portId) @@ -285,15 +302,26 @@ class OutputManager( ports.head._1 } - private def setupOutputStorageWriterThread(portId: PortIdentity, storageUri: URI): Unit = { + private def setupOutputStorageWriterThread(portId: PortIdentity, portBaseURI: URI): Unit = { val bufferedItemWriter = DocumentFactory - .openDocument(storageUri) + .openDocument(VFSURIFactory.resultURI(portBaseURI)) ._1 .writer(VirtualIdentityUtils.getWorkerIndex(actorId).toString) .asInstanceOf[BufferedItemWriter[Tuple]] - val writerThread = new OutputPortResultWriterThread(bufferedItemWriter) + val writerThread = new OutputPortStorageWriterThread(bufferedItemWriter) this.outputPortResultWriterThreads(portId) = writerThread writerThread.start() + + // The state document is provisioned alongside the result document + // by RegionExecutionCoordinator, so it is always present. + val stateWriter = DocumentFactory + .openDocument(VFSURIFactory.stateURI(portBaseURI)) + ._1 + .writer(VirtualIdentityUtils.getWorkerIndex(actorId).toString) + .asInstanceOf[BufferedItemWriter[Tuple]] + val stateWriterThread = new OutputPortStorageWriterThread(stateWriter) + this.stateWriterThreads(portId) = stateWriterThread + stateWriterThread.start() } } diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonWorkflowWorker.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonWorkflowWorker.scala index 4ff5ff15ae3..3358e31e65f 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonWorkflowWorker.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonWorkflowWorker.scala @@ -187,6 +187,7 @@ class PythonWorkflowWorker( if (isRest) StorageConfig.icebergRESTCatalogUri else "", if (isRest) StorageConfig.icebergRESTCatalogWarehouseName else "", StorageConfig.icebergTableResultNamespace, + StorageConfig.icebergTableStateNamespace, StorageConfig.fileStorageDirectoryPath.toString, StorageConfig.icebergTableCommitBatchSize.toString, StorageConfig.s3Endpoint, diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/CostBasedScheduleGenerator.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/CostBasedScheduleGenerator.scala index 401ccddc0a4..43e8d281ce3 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/CostBasedScheduleGenerator.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/CostBasedScheduleGenerator.scala @@ -20,7 +20,7 @@ package org.apache.texera.amber.engine.architecture.scheduling import org.apache.texera.amber.config.ApplicationConfig -import org.apache.texera.amber.core.storage.VFSURIFactory.createResultURI +import org.apache.texera.amber.core.storage.VFSURIFactory.createPortBaseURI import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, PhysicalOpIdentity} import org.apache.texera.amber.core.workflow._ import org.apache.texera.amber.engine.architecture.scheduling.SchedulingUtils.replaceVertex @@ -174,12 +174,12 @@ class CostBasedScheduleGenerator( // Allocate an URI for each of these output ports val outputPortConfigs: Map[GlobalPortIdentity, OutputPortConfig] = outputPortIdsNeedingStorage.map { gpid => - val outputWriterURI = createResultURI( + val portBaseURI = createPortBaseURI( workflowId = workflowContext.workflowId, executionId = workflowContext.executionId, globalPortId = gpid ) - gpid -> OutputPortConfig(outputWriterURI) + gpid -> OutputPortConfig(portBaseURI) }.toMap val resourceConfig = ResourceConfig(portConfigs = outputPortConfigs) @@ -237,7 +237,7 @@ class CostBasedScheduleGenerator( s"the outout port $globalOutputPortId has not been assigned a URI yet." ) ) - .storageURI + .storageURIBase // Group all available URIs of this input port together acc.updated( diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/ExpansionGreedyScheduleGenerator.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/ExpansionGreedyScheduleGenerator.scala index 4bb89338967..304e1496f8a 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/ExpansionGreedyScheduleGenerator.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/ExpansionGreedyScheduleGenerator.scala @@ -21,7 +21,7 @@ package org.apache.texera.amber.engine.architecture.scheduling import com.typesafe.scalalogging.LazyLogging import org.apache.texera.amber.core.WorkflowRuntimeException -import org.apache.texera.amber.core.storage.VFSURIFactory.createResultURI +import org.apache.texera.amber.core.storage.VFSURIFactory.createPortBaseURI import org.apache.texera.amber.core.virtualidentity.PhysicalOpIdentity import org.apache.texera.amber.core.workflow.{ GlobalPortIdentity, @@ -331,7 +331,7 @@ class ExpansionGreedyScheduleGenerator( private def getStorageURIFromGlobalOutputPortId(outputPortId: GlobalPortIdentity) = { assert(!outputPortId.input) - createResultURI( + createPortBaseURI( workflowId = workflowContext.workflowId, executionId = workflowContext.executionId, globalPortId = outputPortId diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala index 254c16bf34b..f72487f268b 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala @@ -21,7 +21,8 @@ package org.apache.texera.amber.engine.architecture.scheduling import org.apache.pekko.pattern.gracefulStop import com.twitter.util.{Duration => TwitterDuration, Future, JavaTimer, Return, Throw, Timer} -import org.apache.texera.amber.core.storage.DocumentFactory +import org.apache.texera.amber.core.state.State +import org.apache.texera.amber.core.storage.{DocumentFactory, VFSURIFactory} import org.apache.texera.amber.core.storage.VFSURIFactory.decodeURI import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity import org.apache.texera.amber.core.workflow.{GlobalPortIdentity, PhysicalLink, PhysicalOp} @@ -465,7 +466,7 @@ class RegionExecutionCoordinator( opId = physicalOp.id, portId = outputPortId ) => - cfg.storageURI.toString + cfg.storageURIBase.toString } .getOrElse("") Some( @@ -568,18 +569,21 @@ class RegionExecutionCoordinator( ): Unit = { portConfigs.foreach { case (outputPortId, portConfig) => - val storageUriToAdd = portConfig.storageURI - val (_, eid, _, _) = decodeURI(storageUriToAdd) + val portBaseURI = portConfig.storageURIBase + val resultURI = VFSURIFactory.resultURI(portBaseURI) + val stateURI = VFSURIFactory.stateURI(portBaseURI) val schemaOptional = region.getOperator(outputPortId.opId).outputPorts(outputPortId.portId)._3 val schema = schemaOptional.getOrElse(throw new IllegalStateException("Schema is missing")) - DocumentFactory.createDocument(storageUriToAdd, schema) + DocumentFactory.createDocument(resultURI, schema) + DocumentFactory.createDocument(stateURI, State.schema) if (!isRestart) { + val (_, eid, _, _) = decodeURI(resultURI) WorkflowExecutionsResource.insertOperatorPortResultUri( eid = eid, globalPortId = outputPortId, - uri = storageUriToAdd + uri = resultURI ) } } diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/config/PortConfig.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/config/PortConfig.scala index b4a1e058b44..56743ae0956 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/config/PortConfig.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/config/PortConfig.scala @@ -31,9 +31,13 @@ sealed trait PortConfig { def storageURIs: List[URI] } -/** An output port requires exactly one materialization URI. */ -final case class OutputPortConfig(storageURI: URI) extends PortConfig { - override val storageURIs: List[URI] = List(storageURI) +/** + * An output port requires exactly one materialization base URI. Result and + * state URIs hang off it via `VFSURIFactory.resultURI` / `stateURI`; this + * field is *not* a URI you can pass straight to `DocumentFactory`. + */ +final case class OutputPortConfig(storageURIBase: URI) extends PortConfig { + override val storageURIs: List[URI] = List(storageURIBase) } /** diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala index 10fbbc44a2c..428d9fb48cb 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala @@ -21,7 +21,8 @@ package org.apache.texera.amber.engine.architecture.worker.managers import io.grpc.MethodDescriptor import org.apache.texera.amber.config.ApplicationConfig -import org.apache.texera.amber.core.storage.DocumentFactory +import org.apache.texera.amber.core.state.State +import org.apache.texera.amber.core.storage.{DocumentFactory, VFSURIFactory} import org.apache.texera.amber.core.storage.model.VirtualDocument import org.apache.texera.amber.core.tuple.Tuple import org.apache.texera.amber.core.virtualidentity.{ @@ -45,7 +46,11 @@ import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{ DPInputQueueElement, FIFOMessageElement } -import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage} +import org.apache.texera.amber.engine.common.ambermessage.{ + DataFrame, + StateFrame, + WorkflowFIFOMessage +} import org.apache.texera.amber.util.VirtualIdentityUtils.getFromActorIdForInputPortStorage import java.net.URI @@ -78,14 +83,35 @@ class InputPortMaterializationReaderThread( def finished: Boolean = isFinished.get() /** - * Read from the materialization stoage, and mimcs the behavior of an upstream worker's output manager. + * Read from the materialization storage, and mimics the behavior of an upstream worker's output manager. + * + * States and tuples are persisted to separate tables, so the original + * interleaving is lost and replay has to pick an order: we replay states + * first because downstream operators typically need their state set up + * before they process the incoming tuples. Every state is broadcast to + * every downstream worker -- no partitioner filtering, unlike the tuple + * loop. State is shared context (e.g. config / counters), not per-key + * data, so each worker needs the full set. */ override def run(): Unit = { // Notify the input port of start of input channel emitECM(METHOD_START_CHANNEL, NO_ALIGNMENT) try { + val stateDocument = + DocumentFactory + .openDocument(VFSURIFactory.stateURI(uri)) + ._1 + .asInstanceOf[VirtualDocument[Tuple]] + val stateReadIterator = stateDocument.get() + while (stateReadIterator.hasNext) { + val state = State.fromTuple(stateReadIterator.next()) + inputMessageQueue.put( + FIFOMessageElement(WorkflowFIFOMessage(channelId, getSequenceNumber, StateFrame(state))) + ) + } + val materialization: VirtualDocument[Tuple] = DocumentFactory - .openDocument(uri) + .openDocument(VFSURIFactory.resultURI(uri)) ._1 .asInstanceOf[VirtualDocument[Tuple]] val storageReadIterator = materialization.get() diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThread.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortStorageWriterThread.scala similarity index 98% rename from amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThread.scala rename to amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortStorageWriterThread.scala index 4223d920da5..bcabde00894 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThread.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortStorageWriterThread.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal sealed trait TerminateSignal case object PortStorageWriterTerminateSignal extends TerminateSignal -class OutputPortResultWriterThread( +class OutputPortStorageWriterThread( bufferedItemWriter: BufferedItemWriter[Tuple] ) extends Thread { diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/promisehandlers/AssignPortHandler.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/promisehandlers/AssignPortHandler.scala index fe959733abb..9e2c7f7c2b6 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/promisehandlers/AssignPortHandler.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/promisehandlers/AssignPortHandler.scala @@ -58,11 +58,11 @@ trait AssignPortHandler { dp.stateManager.assertState(READY, RUNNING, PAUSED) } } else { - val storageURIOption: Option[URI] = msg.storageUris.head match { + val storageURIBaseOption: Option[URI] = msg.storageUris.head match { case "" => None case uriString => Some(URI.create(uriString)) } - dp.outputManager.addPort(msg.portId, schema, storageURIOption) + dp.outputManager.addPort(msg.portId, schema, storageURIBaseOption) } EmptyReturn() } diff --git a/amber/src/test/python/core/architecture/packaging/test_output_manager.py b/amber/src/test/python/core/architecture/packaging/test_output_manager.py new file mode 100644 index 00000000000..dcf7ccde673 --- /dev/null +++ b/amber/src/test/python/core/architecture/packaging/test_output_manager.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock + +import pytest + +from core.architecture.packaging.output_manager import OutputManager +from core.models.state import State +from core.storage.runnables.port_storage_writer import PortStorageWriterElement +from proto.org.apache.texera.amber.core import PortIdentity + + +def _stub_state_writer(output_manager, port_id): + """Inject a (queue, writer, thread) triple as if a port were set up.""" + queue = MagicMock() + writer = MagicMock() + thread = MagicMock() + output_manager._port_state_writers[port_id] = (queue, writer, thread) + return queue, writer, thread + + +class TestSaveStateToStorageIfNeeded: + @pytest.fixture + def output_manager(self): + return OutputManager(worker_id="Worker:WF0-test-main-0") + + @pytest.fixture + def port_a(self): + return PortIdentity(id=0, internal=False) + + @pytest.fixture + def port_b(self): + return PortIdentity(id=1, internal=False) + + @pytest.fixture + def state(self): + return State({"loop_counter": 1, "i": 2}) + + def test_no_state_writers_is_a_noop(self, output_manager, state): + # With no port set up, save_state_to_storage_if_needed must not + # touch any writer. + output_manager.save_state_to_storage_if_needed(state) # no-op + + def test_unknown_port_id_is_a_noop(self, output_manager, state, port_a): + output_manager.save_state_to_storage_if_needed(state, port_id=port_a) + # No assertion needed -- the absence of any writer means nothing + # was attempted. + + def test_enqueues_to_every_port_when_port_id_omitted( + self, output_manager, state, port_a, port_b + ): + queue_a, _, _ = _stub_state_writer(output_manager, port_a) + queue_b, _, _ = _stub_state_writer(output_manager, port_b) + + output_manager.save_state_to_storage_if_needed(state) + + # Each port's writer queue receives one PortStorageWriterElement. + # Critically, save is non-blocking -- the call must not invoke + # put_one / close on the buffered writer directly (those happen + # off-thread). + assert queue_a.put.call_count == 1 + assert queue_b.put.call_count == 1 + assert isinstance(queue_a.put.call_args.args[0], PortStorageWriterElement) + assert isinstance(queue_b.put.call_args.args[0], PortStorageWriterElement) + + def test_enqueues_only_to_selected_port_when_port_id_specified( + self, output_manager, state, port_a, port_b + ): + queue_a, _, _ = _stub_state_writer(output_manager, port_a) + queue_b, _, _ = _stub_state_writer(output_manager, port_b) + + output_manager.save_state_to_storage_if_needed(state, port_id=port_a) + + assert queue_a.put.call_count == 1 + queue_b.put.assert_not_called() + + def test_close_port_storage_writers_stops_state_threads( + self, output_manager, port_a, port_b + ): + # After the port completes, every state-writer thread must be + # stopped and joined so the buffered writer's close() (which + # flushes the final Iceberg commit) actually runs. + _, writer_a, thread_a = _stub_state_writer(output_manager, port_a) + _, writer_b, thread_b = _stub_state_writer(output_manager, port_b) + + output_manager.close_port_storage_writers() + + writer_a.stop.assert_called_once() + writer_b.stop.assert_called_once() + thread_a.join.assert_called_once() + thread_b.join.assert_called_once() + assert output_manager._port_state_writers == {} diff --git a/amber/src/test/python/core/runnables/test_main_loop.py b/amber/src/test/python/core/runnables/test_main_loop.py index 400a7f2a907..78b8ab76aa2 100644 --- a/amber/src/test/python/core/runnables/test_main_loop.py +++ b/amber/src/test/python/core/runnables/test_main_loop.py @@ -1393,6 +1393,82 @@ def fake_switch_context(): assert second_output.payload.frame["value"] == 42 assert second_output.payload.frame["port"] == 0 + @pytest.mark.timeout(2) + def test_process_input_state_persists_output_state_to_storage( + self, + main_loop, + mock_data_output_channel, + monkeypatch, + ): + # process_input_state must invoke save_state_to_storage_if_needed + # with the freshly emitted output state, so every state that flows + # downstream is also durable on the upstream output port. + class DummyExecutor: + @staticmethod + def process_state(state: State, port: int) -> State: + return State({"value": state["value"] + 1, "port": port}) + + saved_states: list[State] = [] + main_loop.context.executor_manager.executor = DummyExecutor() + monkeypatch.setattr(main_loop, "_check_and_process_control", lambda: None) + monkeypatch.setattr( + main_loop.context.output_manager, + "emit_state", + lambda state: [(mock_data_output_channel.to_worker_id, StateFrame(state))], + ) + monkeypatch.setattr( + main_loop.context.output_manager, + "save_state_to_storage_if_needed", + lambda state: saved_states.append(state), + ) + + def fake_switch_context(): + current_input_state = ( + main_loop.context.state_processing_manager.current_input_state + ) + if current_input_state is not None: + main_loop.context.state_processing_manager.current_output_state = ( + DummyExecutor.process_state(current_input_state, 0) + ) + + monkeypatch.setattr(main_loop, "_switch_context", fake_switch_context) + + main_loop._process_state(State({"value": 1})) + main_loop._process_state(State({"value": 41})) + + # Each input state produced one output state, so both must have + # been persisted in order. + assert [s["value"] for s in saved_states] == [2, 42] + assert all(s["port"] == 0 for s in saved_states) + + @pytest.mark.timeout(2) + def test_process_input_state_does_not_save_when_no_output( + self, + main_loop, + monkeypatch, + ): + # When the executor returns no output state (process_state returned + # None), save_state_to_storage_if_needed must not be called -- no + # state means nothing to materialize. + save_calls: list[State] = [] + monkeypatch.setattr(main_loop, "_check_and_process_control", lambda: None) + monkeypatch.setattr( + main_loop.context.output_manager, + "emit_state", + lambda state: [], + ) + monkeypatch.setattr( + main_loop.context.output_manager, + "save_state_to_storage_if_needed", + lambda state: save_calls.append(state), + ) + # Pretend DataProc consumed the input but produced no output. + monkeypatch.setattr(main_loop, "_switch_context", lambda: None) + + main_loop._process_state(State({"value": 1})) + + assert save_calls == [] + @pytest.mark.timeout(2) def test_main_loop_thread_can_process_state( self, diff --git a/amber/src/test/python/core/storage/iceberg/test_iceberg_document.py b/amber/src/test/python/core/storage/iceberg/test_iceberg_document.py index a218c64a2d8..381f8e5ff64 100644 --- a/amber/src/test/python/core/storage/iceberg/test_iceberg_document.py +++ b/amber/src/test/python/core/storage/iceberg/test_iceberg_document.py @@ -24,6 +24,7 @@ from concurrent.futures.thread import ThreadPoolExecutor from core.models import Schema, Tuple +from core.models.state import State from core.storage.document_factory import DocumentFactory from core.storage.storage_config import StorageConfig from core.storage.vfs_uri_factory import VFSURIFactory @@ -49,6 +50,7 @@ rest_catalog_uri="http://localhost:8181/catalog/", rest_catalog_warehouse_name="texera", table_result_namespace="operator-port-result", + table_state_namespace="operator-port-state", directory_path=tempfile.mkdtemp(prefix="texera-iceberg-warehouse-"), commit_batch_size=4096, s3_endpoint="http://localhost:9000", @@ -81,17 +83,21 @@ def iceberg_document(self, amber_schema): with a random operator id """ operator_uuid = str(uuid.uuid4()).replace("-", "") - uri = VFSURIFactory.create_result_uri( - WorkflowIdentity(id=0), - ExecutionIdentity(id=0), - GlobalPortIdentity( - op_id=PhysicalOpIdentity( - logical_op_id=OperatorIdentity(id=f"test_table_{operator_uuid}"), - layer_name="main", + uri = VFSURIFactory.result_uri( + VFSURIFactory.create_port_base_uri( + WorkflowIdentity(id=0), + ExecutionIdentity(id=0), + GlobalPortIdentity( + op_id=PhysicalOpIdentity( + logical_op_id=OperatorIdentity( + id=f"test_table_{operator_uuid}" + ), + layer_name="main", + ), + port_id=PortIdentity(id=0), + input=False, ), - port_id=PortIdentity(id=0), - input=False, - ), + ) ) DocumentFactory.create_document(uri, amber_schema) document, _ = DocumentFactory.open_document(uri) @@ -322,3 +328,85 @@ def test_get_counts(self, iceberg_document, sample_items): assert iceberg_document.get_count() == len(sample_items), ( "get_count should return the same number as the length of sample_items" ) + + def test_state_materialization_round_trip(self): + operator_uuid = str(uuid.uuid4()).replace("-", "") + base_uri = VFSURIFactory.create_port_base_uri( + WorkflowIdentity(id=0), + ExecutionIdentity(id=0), + GlobalPortIdentity( + op_id=PhysicalOpIdentity( + logical_op_id=OperatorIdentity(id=f"test_state_{operator_uuid}"), + layer_name="main", + ), + port_id=PortIdentity(id=0), + input=False, + ), + ) + state_uri = VFSURIFactory.state_uri(base_uri) + DocumentFactory.create_document(state_uri, State.SCHEMA) + document, _ = DocumentFactory.open_document(state_uri) + + state = State( + { + "loop_counter": 3, + "name": "outer-loop", + "payload": b"\x00\x01state-bytes", + "nested": {"enabled": True, "values": [1, 2, 3]}, + } + ) + + writer = document.writer(str(uuid.uuid4())) + writer.open() + writer.put_one(state.to_tuple()) + writer.close() + + stored_rows = list(document.get()) + assert len(stored_rows) == 1 + assert State.from_tuple(stored_rows[0]) == state + + def test_multiple_states_materialize_as_rows_in_one_table(self): + operator_uuid = str(uuid.uuid4()).replace("-", "") + base_uri = VFSURIFactory.create_port_base_uri( + WorkflowIdentity(id=0), + ExecutionIdentity(id=0), + GlobalPortIdentity( + op_id=PhysicalOpIdentity( + logical_op_id=OperatorIdentity( + id=f"test_multiple_states_{operator_uuid}" + ), + layer_name="main", + ), + port_id=PortIdentity(id=0), + input=False, + ), + ) + state_uri = VFSURIFactory.state_uri(base_uri) + DocumentFactory.create_document(state_uri, State.SCHEMA) + document, _ = DocumentFactory.open_document(state_uri) + + states = [ + State({"loop_counter": 0, "i": 1, "payload": b"first"}), + State( + { + "loop_counter": 1, + "i": 2, + "payload": b"second", + "nested": {"values": [3, 4]}, + } + ), + ] + + writer = document.writer(str(uuid.uuid4())) + writer.open() + for state in states: + writer.put_one(state.to_tuple()) + writer.close() + + stored_rows = list(document.get()) + assert len(stored_rows) == len(states) + actual_states = sorted( + [State.from_tuple(row) for row in stored_rows], + key=lambda state: state["loop_counter"], + ) + assert actual_states == states diff --git a/amber/src/test/python/core/storage/runnables/test_input_port_materialization_reader_runnable.py b/amber/src/test/python/core/storage/runnables/test_input_port_materialization_reader_runnable.py new file mode 100644 index 00000000000..5016c2df2f1 --- /dev/null +++ b/amber/src/test/python/core/storage/runnables/test_input_port_materialization_reader_runnable.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock, patch + +import pytest + +from core.models import State, StateFrame +from core.models.internal_queue import DataElement +from core.models.schema import Schema +from core.storage.runnables.input_port_materialization_reader_runnable import ( + InputPortMaterializationReaderRunnable, +) +from proto.org.apache.texera.amber.core import ( + ActorVirtualIdentity, + ChannelIdentity, +) + + +class TestRunStateReadingBlock: + """Cover the state-reading block in run() that opens the state + document and emits its rows as StateFrames directly to the input + queue (no partitioner filtering -- state is broadcast to every + worker). + """ + + @pytest.fixture + def me(self): + return ActorVirtualIdentity(name="me") + + @pytest.fixture + def runnable(self, me): + instance = InputPortMaterializationReaderRunnable.__new__( + InputPortMaterializationReaderRunnable + ) + instance.uri = "vfs:///wf/0/exec/0/result/op-a" + instance.worker_actor_id = me + instance.tuple_schema = Schema(raw_schema={"x": "INTEGER"}) + instance._stopped = False + instance._finished = False + instance.channel_id = ChannelIdentity(me, me, is_control=False) + instance.queue = MagicMock() + instance.partitioner = MagicMock() + # No tuple-batches and no ECM-flush payloads in these tests. + instance.partitioner.flush.return_value = [] + return instance + + def test_state_rows_are_emitted_as_state_frames(self, runnable): + state_a = State({"loop_counter": 0}) + state_b = State({"loop_counter": 1}) + + # The state document yields opaque tuples; from_tuple deserializes + # them. Patch from_tuple so we don't have to wire a real + # serialization. + result_doc = MagicMock() + result_doc.get.return_value = iter([]) # No materialized tuples. + state_doc = MagicMock() + state_doc.get.return_value = iter(["row-a", "row-b"]) + + with ( + patch( + "core.storage.runnables.input_port_materialization_reader_runnable.DocumentFactory" + ) as mock_factory, + patch.object(State, "from_tuple") as mock_from_tuple, + ): + mock_factory.open_document.side_effect = [ + (result_doc, runnable.tuple_schema), + (state_doc, None), + ] + mock_from_tuple.side_effect = [state_a, state_b] + + runnable.run() + + # Two StateFrames must have been put on the queue, in order. + # The state replay must NOT route through the partitioner -- + # state is shared context, broadcast to every worker. + runnable.partitioner.flush_state.assert_not_called() + state_frames = [ + call.args[0] + for call in runnable.queue.put.call_args_list + if isinstance(call.args[0], DataElement) + and isinstance(call.args[0].payload, StateFrame) + ] + assert [sf.payload.frame for sf in state_frames] == [state_a, state_b] + assert runnable._finished is True diff --git a/amber/src/test/python/core/storage/test_document_factory.py b/amber/src/test/python/core/storage/test_document_factory.py new file mode 100644 index 00000000000..859c0040246 --- /dev/null +++ b/amber/src/test/python/core/storage/test_document_factory.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock, patch + +import pytest + +from core.models import Schema +from core.storage.document_factory import DocumentFactory +from core.storage.storage_config import StorageConfig +from core.storage.vfs_uri_factory import VFSResourceType + + +# Avoid initializing the real config (only initializable once per process). +StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE = "test-result-ns" +StorageConfig.ICEBERG_TABLE_STATE_NAMESPACE = "test-state-ns" + +VFS_URI = "vfs:///wid/0/eid/0/opid/test/main/0/0/result" + + +@pytest.fixture +def schema(): + return Schema(raw_schema={"x": "INTEGER"}) + + +def _decode_returning(resource_type): + """Helper: build a VFSURIFactory.decode_uri side_effect.""" + return lambda _uri: (None, None, None, resource_type) + + +@patch("core.storage.document_factory.IcebergDocument") +@patch("core.storage.document_factory.amber_schema_to_iceberg_schema") +@patch("core.storage.document_factory.create_table") +@patch("core.storage.document_factory.IcebergCatalogInstance") +@patch("core.storage.document_factory.VFSURIFactory") +class TestCreateDocumentNamespaceRouting: + def test_state_resource_type_uses_state_namespace( + self, mock_vfs, _icb, mock_create_table, _amber_schema, _doc, schema + ): + mock_vfs.VFS_FILE_URI_SCHEME = "vfs" + mock_vfs.decode_uri.side_effect = _decode_returning(VFSResourceType.STATE) + + DocumentFactory.create_document(VFS_URI, schema) + + args, _ = mock_create_table.call_args + assert args[1] == StorageConfig.ICEBERG_TABLE_STATE_NAMESPACE + + def test_result_resource_type_uses_result_namespace( + self, mock_vfs, _icb, mock_create_table, _amber_schema, _doc, schema + ): + mock_vfs.VFS_FILE_URI_SCHEME = "vfs" + mock_vfs.decode_uri.side_effect = _decode_returning(VFSResourceType.RESULT) + + DocumentFactory.create_document(VFS_URI, schema) + + args, _ = mock_create_table.call_args + assert args[1] == StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE + + def test_unsupported_resource_type_raises_value_error( + self, mock_vfs, _icb, _create_table, _amber_schema, _doc, schema + ): + mock_vfs.VFS_FILE_URI_SCHEME = "vfs" + # CONSOLE_MESSAGES has no namespace mapping in the Python factory. + mock_vfs.decode_uri.side_effect = _decode_returning( + VFSResourceType.CONSOLE_MESSAGES + ) + + with pytest.raises(ValueError, match="not supported"): + DocumentFactory.create_document(VFS_URI, schema) + + +def test_create_document_rejects_non_vfs_scheme(schema): + with pytest.raises(NotImplementedError, match="Unsupported URI scheme"): + DocumentFactory.create_document("file:///tmp/x", schema) + + +@patch("core.storage.document_factory.IcebergDocument") +@patch("core.storage.document_factory.Schema") +@patch("core.storage.document_factory.load_table_metadata") +@patch("core.storage.document_factory.IcebergCatalogInstance") +@patch("core.storage.document_factory.VFSURIFactory") +class TestOpenDocumentNamespaceRouting: + @staticmethod + def _stub_table(): + table = MagicMock() + table.schema.return_value.as_arrow.return_value = MagicMock() + return table + + def test_state_resource_type_uses_state_namespace( + self, mock_vfs, _icb, mock_load, _schema_cls, _doc + ): + mock_vfs.VFS_FILE_URI_SCHEME = "vfs" + mock_vfs.decode_uri.side_effect = _decode_returning(VFSResourceType.STATE) + mock_load.return_value = self._stub_table() + + DocumentFactory.open_document(VFS_URI) + + args, _ = mock_load.call_args + assert args[1] == StorageConfig.ICEBERG_TABLE_STATE_NAMESPACE + + def test_unsupported_resource_type_raises_value_error( + self, mock_vfs, _icb, _load, _schema_cls, _doc + ): + mock_vfs.VFS_FILE_URI_SCHEME = "vfs" + mock_vfs.decode_uri.side_effect = _decode_returning( + VFSResourceType.CONSOLE_MESSAGES + ) + + with pytest.raises(ValueError, match="not supported"): + DocumentFactory.open_document(VFS_URI) + + def test_missing_table_raises_value_error( + self, mock_vfs, _icb, mock_load, _schema_cls, _doc + ): + mock_vfs.VFS_FILE_URI_SCHEME = "vfs" + mock_vfs.decode_uri.side_effect = _decode_returning(VFSResourceType.STATE) + mock_load.return_value = None + + with pytest.raises(ValueError, match="No storage is found"): + DocumentFactory.open_document(VFS_URI) diff --git a/amber/src/test/python/pytexera/storage/test_large_binary_manager.py b/amber/src/test/python/pytexera/storage/test_large_binary_manager.py index 64c7080e520..1942e91f8bc 100644 --- a/amber/src/test/python/pytexera/storage/test_large_binary_manager.py +++ b/amber/src/test/python/pytexera/storage/test_large_binary_manager.py @@ -34,6 +34,7 @@ def setup_storage_config(self): rest_catalog_uri="http://localhost:8181/catalog/", rest_catalog_warehouse_name="texera", table_result_namespace="test", + table_state_namespace="test-state", directory_path="/tmp/test", commit_batch_size=1000, s3_endpoint="http://localhost:9000", diff --git a/amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThreadSpec.scala b/amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortStorageWriterThreadSpec.scala similarity index 90% rename from amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThreadSpec.scala rename to amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortStorageWriterThreadSpec.scala index 31d8c41611d..d7ab0c18314 100644 --- a/amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThreadSpec.scala +++ b/amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortStorageWriterThreadSpec.scala @@ -32,7 +32,7 @@ import org.scalatest.flatspec.AnyFlatSpec import scala.collection.mutable -class OutputPortResultWriterThreadSpec extends AnyFlatSpec { +class OutputPortStorageWriterThreadSpec extends AnyFlatSpec { private class StubWriter( onPutOne: () => Unit = () => (), @@ -51,9 +51,9 @@ class OutputPortResultWriterThreadSpec extends AnyFlatSpec { private def throwing(msg: String): () => Unit = () => throw new RuntimeException(msg) - "OutputPortResultWriterThread" should "leave getFailure empty on a clean run" in { + "OutputPortStorageWriterThread" should "leave getFailure empty on a clean run" in { val writer = new StubWriter() - val thread = new OutputPortResultWriterThread(writer) + val thread = new OutputPortStorageWriterThread(writer) thread.start() thread.queue.put(Right(PortStorageWriterTerminateSignal)) thread.join() @@ -63,7 +63,7 @@ class OutputPortResultWriterThreadSpec extends AnyFlatSpec { it should "capture a close() exception in getFailure so the worker can re-throw" in { val writer = new StubWriter(onClose = throwing("test close failure")) - val thread = new OutputPortResultWriterThread(writer) + val thread = new OutputPortStorageWriterThread(writer) thread.start() thread.queue.put(Right(PortStorageWriterTerminateSignal)) thread.join() @@ -73,7 +73,7 @@ class OutputPortResultWriterThreadSpec extends AnyFlatSpec { it should "capture a putOne exception and still call close()" in { val writer = new StubWriter(onPutOne = throwing("test putOne failure")) - val thread = new OutputPortResultWriterThread(writer) + val thread = new OutputPortStorageWriterThread(writer) thread.start() thread.queue.put(Left(null.asInstanceOf[Tuple])) thread.queue.put(Right(PortStorageWriterTerminateSignal)) @@ -89,7 +89,7 @@ class OutputPortResultWriterThreadSpec extends AnyFlatSpec { onPutOne = throwing("test putOne failure"), onClose = throwing("test close failure") ) - val thread = new OutputPortResultWriterThread(writer) + val thread = new OutputPortStorageWriterThread(writer) thread.start() thread.queue.put(Left(null.asInstanceOf[Tuple])) thread.queue.put(Right(PortStorageWriterTerminateSignal)) @@ -110,14 +110,14 @@ class OutputPortResultWriterThreadSpec extends AnyFlatSpec { private def installWriterThread( manager: OutputManager, portId: PortIdentity, - thread: OutputPortResultWriterThread + thread: OutputPortStorageWriterThread ): Unit = { val field = classOf[OutputManager] .getDeclaredField("outputPortResultWriterThreads") field.setAccessible(true) field .get(manager) - .asInstanceOf[mutable.HashMap[PortIdentity, OutputPortResultWriterThread]] + .asInstanceOf[mutable.HashMap[PortIdentity, OutputPortStorageWriterThread]] .put(portId, thread) } @@ -130,7 +130,7 @@ class OutputPortResultWriterThreadSpec extends AnyFlatSpec { ) val portId = PortIdentity() val failingWriter = new StubWriter(onClose = throwing("test close failure")) - val failingThread = new OutputPortResultWriterThread(failingWriter) + val failingThread = new OutputPortStorageWriterThread(failingWriter) failingThread.start() installWriterThread(outputManager, portId, failingThread) val ex = intercept[RuntimeException] { diff --git a/common/config/src/main/resources/storage.conf b/common/config/src/main/resources/storage.conf index 1f39359155c..da2f7ccc198 100644 --- a/common/config/src/main/resources/storage.conf +++ b/common/config/src/main/resources/storage.conf @@ -61,6 +61,9 @@ storage { runtime-statistics-namespace = "workflow-runtime-statistics" runtime-statistics-namespace = ${?STORAGE_ICEBERG_TABLE_RUNTIME_STATISTICS_NAMESPACE} + state-namespace = "operator-port-state" + state-namespace = ${?STORAGE_ICEBERG_TABLE_STATE_NAMESPACE} + commit { batch-size = 4096 # decide the buffer size of our IcebergTableWriter batch-size = ${?STORAGE_ICEBERG_TABLE_COMMIT_BATCH_SIZE} diff --git a/common/config/src/main/scala/org/apache/texera/amber/config/EnvironmentalVariable.scala b/common/config/src/main/scala/org/apache/texera/amber/config/EnvironmentalVariable.scala index 9ec52bba653..123c56505ee 100644 --- a/common/config/src/main/scala/org/apache/texera/amber/config/EnvironmentalVariable.scala +++ b/common/config/src/main/scala/org/apache/texera/amber/config/EnvironmentalVariable.scala @@ -67,6 +67,7 @@ object EnvironmentalVariable { "STORAGE_ICEBERG_TABLE_CONSOLE_MESSAGES_NAMESPACE" val ENV_ICEBERG_TABLE_RUNTIME_STATISTICS_NAMESPACE = "STORAGE_ICEBERG_TABLE_RUNTIME_STATISTICS_NAMESPACE" + val ENV_ICEBERG_TABLE_STATE_NAMESPACE = "STORAGE_ICEBERG_TABLE_STATE_NAMESPACE" val ENV_ICEBERG_TABLE_COMMIT_BATCH_SIZE = "STORAGE_ICEBERG_TABLE_COMMIT_BATCH_SIZE" val ENV_ICEBERG_TABLE_COMMIT_NUM_RETRIES = "STORAGE_ICEBERG_TABLE_COMMIT_NUM_RETRIES" val ENV_ICEBERG_TABLE_COMMIT_MIN_WAIT_MS = "STORAGE_ICEBERG_TABLE_COMMIT_MIN_WAIT_MS" diff --git a/common/config/src/main/scala/org/apache/texera/amber/config/StorageConfig.scala b/common/config/src/main/scala/org/apache/texera/amber/config/StorageConfig.scala index 728e3c0c2de..07447cfdbee 100644 --- a/common/config/src/main/scala/org/apache/texera/amber/config/StorageConfig.scala +++ b/common/config/src/main/scala/org/apache/texera/amber/config/StorageConfig.scala @@ -54,6 +54,8 @@ object StorageConfig { conf.getString("storage.iceberg.table.console-messages-namespace") val icebergTableRuntimeStatisticsNamespace: String = conf.getString("storage.iceberg.table.runtime-statistics-namespace") + val icebergTableStateNamespace: String = + conf.getString("storage.iceberg.table.state-namespace") val icebergTableCommitBatchSize: Int = conf.getInt("storage.iceberg.table.commit.batch-size") val icebergTableCommitNumRetries: Int = @@ -111,6 +113,7 @@ object StorageConfig { "STORAGE_ICEBERG_TABLE_CONSOLE_MESSAGES_NAMESPACE" val ENV_ICEBERG_TABLE_RUNTIME_STATISTICS_NAMESPACE = "STORAGE_ICEBERG_TABLE_RUNTIME_STATISTICS_NAMESPACE" + val ENV_ICEBERG_TABLE_STATE_NAMESPACE = "STORAGE_ICEBERG_TABLE_STATE_NAMESPACE" val ENV_ICEBERG_TABLE_COMMIT_BATCH_SIZE = "STORAGE_ICEBERG_TABLE_COMMIT_BATCH_SIZE" val ENV_ICEBERG_TABLE_COMMIT_NUM_RETRIES = "STORAGE_ICEBERG_TABLE_COMMIT_NUM_RETRIES" val ENV_ICEBERG_TABLE_COMMIT_MIN_WAIT_MS = "STORAGE_ICEBERG_TABLE_COMMIT_MIN_WAIT_MS" diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala index 15949ef4717..00f6c70ba73 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala @@ -72,6 +72,7 @@ object DocumentFactory { case RESULT => StorageConfig.icebergTableResultNamespace case CONSOLE_MESSAGES => StorageConfig.icebergTableConsoleMessagesNamespace case RUNTIME_STATISTICS => StorageConfig.icebergTableRuntimeStatisticsNamespace + case STATE => StorageConfig.icebergTableStateNamespace case _ => throw new IllegalArgumentException(s"Resource type $resourceType is not supported") } @@ -119,6 +120,7 @@ object DocumentFactory { case RESULT => StorageConfig.icebergTableResultNamespace case CONSOLE_MESSAGES => StorageConfig.icebergTableConsoleMessagesNamespace case RUNTIME_STATISTICS => StorageConfig.icebergTableRuntimeStatisticsNamespace + case STATE => StorageConfig.icebergTableStateNamespace case _ => throw new IllegalArgumentException(s"Resource type $resourceType is not supported") } diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala index 0fbee64457d..291c31896b0 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala @@ -34,6 +34,7 @@ object VFSResourceType extends Enumeration { val RESULT: Value = Value("result") val RUNTIME_STATISTICS: Value = Value("runtimeStatistics") val CONSOLE_MESSAGES: Value = Value("consoleMessages") + val STATE: Value = Value("state") } object VFSURIFactory { @@ -83,18 +84,25 @@ object VFSURIFactory { } /** - * Create a URI pointing to a result storage + * Create the base URI for a port. Result and state URIs are derived + * from this base via `resultURI` / `stateURI`. */ - def createResultURI( + def createPortBaseURI( workflowId: WorkflowIdentity, executionId: ExecutionIdentity, globalPortId: GlobalPortIdentity - ): URI = { - val baseUri = - s"$VFS_FILE_URI_SCHEME:///wid/${workflowId.id}/eid/${executionId.id}/globalportid/${globalPortId.serializeAsString}" + ): URI = + new URI( + s"$VFS_FILE_URI_SCHEME:///wid/${workflowId.id}/eid/${executionId.id}" + + s"/globalportid/${globalPortId.serializeAsString}" + ) - new URI(s"$baseUri/${VFSResourceType.RESULT.toString.toLowerCase}") - } + def resultURI(baseURI: URI): URI = appendResource(baseURI, VFSResourceType.RESULT) + + def stateURI(baseURI: URI): URI = appendResource(baseURI, VFSResourceType.STATE) + + private def appendResource(baseURI: URI, resourceType: VFSResourceType.Value): URI = + new URI(s"$baseURI/${resourceType.toString.toLowerCase}") /** * Create a URI pointing to runtime statistics diff --git a/common/workflow-core/src/test/scala/org/apache/texera/amber/core/storage/VFSURIFactorySpec.scala b/common/workflow-core/src/test/scala/org/apache/texera/amber/core/storage/VFSURIFactorySpec.scala index 6fbe35873a7..0b8ae4a19c0 100644 --- a/common/workflow-core/src/test/scala/org/apache/texera/amber/core/storage/VFSURIFactorySpec.scala +++ b/common/workflow-core/src/test/scala/org/apache/texera/amber/core/storage/VFSURIFactorySpec.scala @@ -42,23 +42,30 @@ class VFSURIFactorySpec extends AnyFlatSpec { input = true ) - "VFSURIFactory.createResultURI" should "include workflow, execution, port, and the result resource type" in { - val uri = VFSURIFactory.createResultURI(workflowId, executionId, portId) - assert(uri.getScheme == VFSURIFactory.VFS_FILE_URI_SCHEME) - val path = uri.getPath + "VFSURIFactory.createPortBaseURI" should "include workflow, execution, and port segments without a resource type" in { + val baseURI = VFSURIFactory.createPortBaseURI(workflowId, executionId, portId) + assert(baseURI.getScheme == VFSURIFactory.VFS_FILE_URI_SCHEME) + val path = baseURI.getPath assert(path.contains("/wid/7")) assert(path.contains("/eid/11")) assert(path.contains("/globalportid/")) - assert(path.endsWith("/result")) + assert(!path.endsWith("/result")) + assert(!path.endsWith("/state")) } - it should "round-trip through decodeURI" in { - val uri = VFSURIFactory.createResultURI(workflowId, executionId, portId) - val (wid, eid, globalPortIdOpt, resourceType) = VFSURIFactory.decodeURI(uri) + "VFSURIFactory.resultURI / stateURI" should "append the resource segment and round-trip through decodeURI" in { + val baseURI = VFSURIFactory.createPortBaseURI(workflowId, executionId, portId) + val resultURI = VFSURIFactory.resultURI(baseURI) + val stateURI = VFSURIFactory.stateURI(baseURI) + assert(resultURI.getPath.endsWith("/result")) + assert(stateURI.getPath.endsWith("/state")) + + val (wid, eid, globalPortIdOpt, resourceType) = VFSURIFactory.decodeURI(resultURI) assert(wid == workflowId) assert(eid == executionId) assert(globalPortIdOpt.contains(portId)) assert(resourceType == VFSResourceType.RESULT) + assert(VFSURIFactory.decodeURI(stateURI)._4 == VFSResourceType.STATE) } "VFSURIFactory.createRuntimeStatisticsURI" should "produce a runtimeStatistics URI without an opid segment" in { diff --git a/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergDocumentSpec.scala b/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergDocumentSpec.scala index 8fdf039f3ea..b865fff94de 100644 --- a/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergDocumentSpec.scala +++ b/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergDocumentSpec.scala @@ -20,6 +20,7 @@ package org.apache.texera.amber.storage.result.iceberg import org.apache.texera.amber.config.StorageConfig +import org.apache.texera.amber.core.state.State import org.apache.texera.amber.core.storage.model.{VirtualDocument, VirtualDocumentSpec} import org.apache.texera.amber.core.storage.{DocumentFactory, IcebergCatalogInstance, VFSURIFactory} import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple} @@ -51,6 +52,7 @@ class IcebergDocumentSpec extends VirtualDocumentSpec[Tuple] with BeforeAndAfter var deserde: (IcebergSchema, Record) => Tuple = _ var catalog: Catalog = _ val tableNamespace = "test_namespace" + var baseURI: URI = _ var uri: URI = _ override def beforeAll(): Unit = { @@ -79,7 +81,7 @@ class IcebergDocumentSpec extends VirtualDocumentSpec[Tuple] with BeforeAndAfter override def beforeEach(): Unit = { // Generate a unique table name for each test - uri = VFSURIFactory.createResultURI( + baseURI = VFSURIFactory.createPortBaseURI( WorkflowIdentity(0), ExecutionIdentity(0), GlobalPortIdentity( @@ -91,6 +93,7 @@ class IcebergDocumentSpec extends VirtualDocumentSpec[Tuple] with BeforeAndAfter PortIdentity() ) ) + uri = VFSURIFactory.resultURI(baseURI) DocumentFactory.createDocument(uri, amberSchema) super.beforeEach() } @@ -141,6 +144,84 @@ class IcebergDocumentSpec extends VirtualDocumentSpec[Tuple] with BeforeAndAfter } } + it should "round trip materialized state documents" in { + val stateUri = VFSURIFactory.stateURI(baseURI) + DocumentFactory.createDocument(stateUri, State.schema) + val stateDocument = + DocumentFactory.openDocument(stateUri)._1.asInstanceOf[VirtualDocument[Tuple]] + val state = State( + Map( + "loop_counter" -> 3, + "name" -> "outer-loop", + "payload" -> Array[Byte](0, 1, 2, 3), + "nested" -> Map("enabled" -> true, "values" -> List(1, 2, 3)) + ) + ) + + val writer = stateDocument.writer(UUID.randomUUID().toString) + writer.open() + writer.putOne(state.toTuple) + writer.close() + + val storedRows = stateDocument.get().toList + assert(storedRows.length == 1) + val deserialized = State.fromTuple(storedRows.head).values + assert(deserialized("loop_counter") == 3L) + assert(deserialized("name") == "outer-loop") + assert(deserialized("payload").asInstanceOf[Array[Byte]].sameElements(Array[Byte](0, 1, 2, 3))) + assert(deserialized("nested").asInstanceOf[Map[String, Any]]("enabled") == true) + assert(deserialized("nested").asInstanceOf[Map[String, Any]]("values") == List(1L, 2L, 3L)) + } + + it should "materialize multiple states as rows in one state table" in { + val stateUri = VFSURIFactory.stateURI(baseURI) + DocumentFactory.createDocument(stateUri, State.schema) + val stateDocument = + DocumentFactory.openDocument(stateUri)._1.asInstanceOf[VirtualDocument[Tuple]] + val states: List[State] = List( + State(Map("loop_counter" -> 0, "i" -> 1, "payload" -> Array[Byte](1, 2, 3))), + State( + Map( + "loop_counter" -> 1, + "i" -> 2, + "payload" -> Array[Byte](4, 5, 6), + "nested" -> Map("values" -> List(3, 4)) + ) + ) + ) + + val writer = stateDocument.writer(UUID.randomUUID().toString) + writer.open() + states.foreach(state => writer.putOne(state.toTuple)) + writer.close() + + val deserializedStates = + stateDocument + .get() + .toList + .map(State.fromTuple) + .sortBy(_.values("loop_counter").asInstanceOf[Long]) + assert(deserializedStates.length == states.length) + deserializedStates.zip(states).foreach { + case (actual, expected) => + assert( + actual.values("loop_counter") == expected.values("loop_counter").asInstanceOf[Int].toLong + ) + assert(actual.values("i") == expected.values("i").asInstanceOf[Int].toLong) + assert( + actual + .values("payload") + .asInstanceOf[Array[Byte]] + .sameElements(expected.values("payload").asInstanceOf[Array[Byte]]) + ) + } + assert( + deserializedStates(1) + .values("nested") + .asInstanceOf[Map[String, Any]]("values") == List(3L, 4L) + ) + } + /** Returns a dynamic proxy for `realTable` that increments `counter` on every `refresh()` call. */ private def tableWithRefreshSpy(realTable: Table, counter: AtomicInteger): Table = Proxy diff --git a/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergTableStatsSpec.scala b/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergTableStatsSpec.scala index 175ebc2c01b..b7611f6f772 100644 --- a/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergTableStatsSpec.scala +++ b/common/workflow-core/src/test/scala/org/apache/texera/amber/storage/result/iceberg/IcebergTableStatsSpec.scala @@ -50,16 +50,18 @@ class IcebergTableStatsSpec extends AnyFlatSpec with BeforeAndAfterAll with Suit var deserde: (IcebergSchema, Record) => Tuple = _ var catalog: Catalog = _ val tableNamespace = "test_namespace" - var uri: URI = VFSURIFactory.createResultURI( - WorkflowIdentity(0), - ExecutionIdentity(0), - GlobalPortIdentity( - PhysicalOpIdentity( - logicalOpId = - OperatorIdentity(s"test_table_${UUID.randomUUID().toString.replace("-", "")}"), - layerName = "main" - ), - PortIdentity() + var uri: URI = VFSURIFactory.resultURI( + VFSURIFactory.createPortBaseURI( + WorkflowIdentity(0), + ExecutionIdentity(0), + GlobalPortIdentity( + PhysicalOpIdentity( + logicalOpId = + OperatorIdentity(s"test_table_${UUID.randomUUID().toString.replace("-", "")}"), + layerName = "main" + ), + PortIdentity() + ) ) )