"
case _ =>
throw new RuntimeException(
s"Expected byte array for binary type field, but got: ${value.getClass.getName}"
)
}
case AttributeType.STRING =>
val stringValue = value.asInstanceOf[String]
if (stringValue.length > maxStringLength && !isVisualization)
stringValue.take(maxStringLength) + "..."
else
stringValue
case _ => value
}
}
}
.toArray[Any]
TupleUtils.tuple2json(tuple.schema, processedFields)
}.toList
}
/**
 * Packages a batch of engine tuples as a [[WebDataUpdate]] for the frontend.
 *
 * The tuples are serialized with `convertTuplesToJson`. The flag handed to it is true only
 * when `mode` is [[SetSnapshotMode]] (the visualization path); judging from the visible tail
 * of `convertTuplesToJson`, that flag disables truncation of long string fields.
 *
 * @param mode  web output mode recorded in the resulting update
 * @param table tuples to serialize, in order
 * @return a [[WebDataUpdate]] carrying `mode` and the JSON rows
 */
private def tuplesToWebData(
    mode: WebOutputMode,
    table: List[Tuple]
): WebDataUpdate = {
  val snapshotForVisualization = mode == SetSnapshotMode()
  val rows = convertTuplesToJson(table, snapshotForVisualization)
  WebDataUpdate(mode, rows)
}
/**
 * Produces the [[WebResultUpdate]] to send to the frontend for one logical operator, based on
 * the engine output mode of its first external (non-internal) output port:
 *  - SET_DELTA       -> SetDeltaMode: send the tuples stored after `oldTupleCount`.
 *                       For an insert-only delta this is effectively the latest snapshot; for an
 *                       insert-retract delta it is the union of all delta outputs, not compacted.
 *  - SET_SNAPSHOT    -> PaginationMode: send the new tuple count and the page indices to refresh.
 *  - SINGLE_SNAPSHOT -> SetSnapshotMode: send the entire stored result.
 *
 * If no result storage is registered for the operator yet (storage only exists once its region
 * has been scheduled), an empty pagination update is returned.
 *
 * @param workflowIdentity used only in error messages
 * @param executionId      execution whose result storage is looked up
 * @param physicalOps      physical operators of the logical operator
 * @param oldTupleCount    tuple count previously reported to the frontend
 * @param newTupleCount    tuple count now reported to the frontend
 * @throws NoSuchElementException if `physicalOps` exposes no non-internal output port
 * @throws RuntimeException       if that port's output mode is unrecognized
 */
private def convertWebResultUpdate(
    workflowIdentity: WorkflowIdentity,
    executionId: ExecutionIdentity,
    physicalOps: List[PhysicalOp],
    oldTupleCount: Int,
    newTupleCount: Int
): WebResultUpdate = {
  // Mode of the first external output port. Fail with context instead of a bare `.head` error.
  val outputMode = physicalOps
    .flatMap(op => op.outputPorts)
    .collectFirst { case (portId, (port, _, _)) if !portId.internal => port.mode }
    .getOrElse(
      throw new NoSuchElementException(
        s"No external output port found for operator in workflow ${workflowIdentity.id}"
      )
    )

  val webOutputMode: WebOutputMode = outputMode match {
    // currently, only table outputs are using these modes
    case OutputMode.SET_DELTA    => SetDeltaMode()
    case OutputMode.SET_SNAPSHOT => PaginationMode()
    // currently, only visualizations are using single snapshot mode
    case OutputMode.SINGLE_SNAPSHOT => SetSnapshotMode()
    case OutputMode.Unrecognized(_) =>
      throw new RuntimeException(
        s"Unrecognized output mode: $outputMode for workflow ${workflowIdentity.id}"
      )
  }

  // Cannot assume the storage is available at this point. The storage object is only available
  // after a region is scheduled to execute. `physicalOps` is non-empty here (checked above).
  val storageUriOption = WorkflowExecutionsResource.getResultUriByLogicalPortId(
    executionId,
    physicalOps.head.id.logicalOpId,
    PortIdentity()
  )

  storageUriOption match {
    case Some(storageUri) =>
      val storage: VirtualDocument[Tuple] =
        DocumentFactory.openDocument(storageUri)._1.asInstanceOf[VirtualDocument[Tuple]]
      // WebOutputMode is sealed, so the compiler checks this match for exhaustiveness.
      webOutputMode match {
        case PaginationMode() =>
          val numTuples = storage.getCount
          val maxPageIndex =
            Math.ceil(numTuples / defaultPageSize.toDouble).toInt
          // Every page up to maxPageIndex is reported as dirty. This can be extremely expensive
          // when there are many pages, and it has caused delays in some observed cases.
          // TODO: try to optimize this.
          WebPaginationUpdate(
            PaginationMode(),
            newTupleCount,
            (1 to maxPageIndex).toList
          )
        case SetSnapshotMode() =>
          tuplesToWebData(webOutputMode, storage.get().toList)
        case SetDeltaMode() =>
          val deltaList = storage.getAfter(oldTupleCount).toList
          tuplesToWebData(webOutputMode, deltaList)
      }
    case None =>
      WebPaginationUpdate(
        PaginationMode(),
        0,
        List.empty
      )
  }
}
/**
 * How a result update is presented to the frontend. Values are serialized by Jackson with a
 * "type" property naming the concrete mode.
 *  - PaginationMode (used by the view result operator):
 *    send the new number of tuples and the dirty page indices
 *  - SetSnapshotMode (used by visualizations in snapshot mode):
 *    send the entire snapshot result
 *  - SetDeltaMode (used by visualizations in delta mode):
 *    send the incremental delta result
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
sealed abstract class WebOutputMode extends Product with Serializable

@JsonTypeName("PaginationMode")
final case class PaginationMode() extends WebOutputMode

@JsonTypeName("SetSnapshotMode")
final case class SetSnapshotMode() extends WebOutputMode

@JsonTypeName("SetDeltaMode")
final case class SetDeltaMode() extends WebOutputMode

/**
 * The result update of one operator that is sent to the frontend: either a
 * [[WebPaginationUpdate]] (for PaginationMode) or a [[WebDataUpdate]]
 * (for SetSnapshotMode or SetDeltaMode).
 */
sealed abstract class WebResultUpdate extends Product with Serializable

/** Pagination summary: the total tuple count plus the 1-based indices of pages to refresh. */
case class WebPaginationUpdate(
    mode: PaginationMode,
    totalNumTuples: Long,
    dirtyPageIndices: List[Int]
) extends WebResultUpdate

/** Result rows already converted to JSON objects, tagged with the mode that produced them. */
case class WebDataUpdate(mode: WebOutputMode, table: List[ObjectNode]) extends WebResultUpdate
}
/**
 * ExecutionResultService manages all operator output ports that have storage in one workflow execution.
 *
 * On each result update from the engine, ExecutionResultService
 *  - updates the result metadata (tuple counts) for each operator, and
 *  - sends result update events to the frontend.
 *
 * While an attached execution is RUNNING, result storage is polled on a recurring schedule.
 * One final poll is made when the execution reaches a terminal state.
 *
 * @param workflowIdentity   the workflow whose results are served
 * @param computingUnitId    computing unit used to find the latest execution when paginating
 * @param workflowStateStore store whose result metadata drives the frontend result events
 */
class ExecutionResultService(
    workflowIdentity: WorkflowIdentity,
    computingUnitId: Int,
    val workflowStateStore: WorkflowStateStore
) extends SubscriptionManager
    with LazyLogging {

  private val resultPullingFrequency = ApplicationConfig.executionResultPollingInSecs

  // Handle of the recurring result-polling task; null until polling is scheduled for the first time.
  private var resultUpdateCancellable: Cancellable = _

  /**
   * Attaches this service to an execution, replacing every previous subscription.
   *
   * This method:
   *  - starts polling when the execution is RUNNING and stops it otherwise;
   *  - performs a final result update when the execution terminates;
   *  - stops polling on fatal errors;
   *  - turns result-metadata changes into [[WebResultUpdateEvent]]s;
   *  - finally clears the result store.
   *
   * @param executionId  execution whose results are tracked
   * @param stateStore   state store of that execution; its metadata drives polling
   * @param physicalPlan physical plan used to resolve operators' output ports and modes
   * @param client       engine client whose callbacks are observed
   */
  def attachToExecution(
      executionId: ExecutionIdentity,
      stateStore: ExecutionStateStore,
      physicalPlan: PhysicalPlan,
      client: AmberClient
  ): Unit = {

    if (resultUpdateCancellable != null && !resultUpdateCancellable.isCancelled) {
      resultUpdateCancellable.cancel()
    }

    unsubscribeAll()

    addSubscription(stateStore.metadataStore.getStateObservable.subscribe {
      newState: ExecutionMetadataStore =>
        {
          if (newState.state == RUNNING) {
            if (resultUpdateCancellable == null || resultUpdateCancellable.isCancelled) {
              resultUpdateCancellable = AmberRuntime
                .scheduleRecurringCallThroughActorSystem(
                  2.seconds,
                  resultPullingFrequency.seconds
                ) {
                  onResultUpdate(executionId, physicalPlan)
                }
            }
          } else {
            if (resultUpdateCancellable != null) resultUpdateCancellable.cancel()
          }
        }
    })

    addSubscription(
      client
        .registerCallback[ExecutionStateUpdate](evt => {
          if (
            evt.state == COMPLETED || evt.state == FAILED || evt.state == KILLED || evt.state == TERMINATED
          ) {
            logger.info("Workflow execution terminated. Stop update results.")
            // Polling may never have been scheduled (e.g. the execution terminated before it
            // reached RUNNING); the handle is then still null and there is nothing to cancel.
            // Otherwise the final update runs once polling has been (or already was) cancelled.
            val pollingStopped =
              resultUpdateCancellable == null ||
                resultUpdateCancellable.cancel() ||
                resultUpdateCancellable.isCancelled
            if (pollingStopped) {
              // immediately perform final update
              onResultUpdate(executionId, physicalPlan)
            }
          }
        })
    )

    addSubscription(
      client.registerCallback[FatalError](_ =>
        if (resultUpdateCancellable != null) {
          resultUpdateCancellable.cancel()
        }
      )
    )

    addSubscription(
      workflowStateStore.resultStore.registerDiffHandler((oldState, newState) => {
        val buf = mutable.HashMap[String, ExecutionResultService.WebResultUpdate]()
        val allTableStats = mutable.Map[String, Map[String, Map[String, Any]]]()
        newState.resultInfo
          .filter(info => {
            // only update those operators with changing tuple count.
            !oldState.resultInfo
              .contains(info._1) || oldState.resultInfo(info._1).tupleCount != info._2.tupleCount
          })
          .foreach {
            case (opId, info) =>
              val oldInfo = oldState.resultInfo.getOrElse(opId, OperatorResultMetadata())
              // Looked up once and shared by the result conversion and the port inspection below.
              val physicalOps = physicalPlan.getPhysicalOpsOfLogicalOp(opId)
              buf(opId.id) = ExecutionResultService.convertWebResultUpdate(
                workflowIdentity,
                executionId,
                physicalOps,
                oldInfo.tupleCount,
                info.tupleCount
              )
              // using the first port for now. TODO: support multiple ports
              val outputPortsMap = physicalOps.headOption
                .map(_.outputPorts)
                .getOrElse(Map.empty)
              val hasSingleSnapshot = outputPortsMap.values.exists {
                case (outputPort, _, _) =>
                  // SINGLE_SNAPSHOT is used for HTML content
                  outputPort.mode == OutputMode.SINGLE_SNAPSHOT
              }

              if (!hasSingleSnapshot) {
                val storageUri = WorkflowExecutionsResource
                  .getResultUriByLogicalPortId(
                    executionId,
                    opId,
                    PortIdentity()
                  )
                if (storageUri.nonEmpty) {
                  val (_, _, globalPortIdOption, _) = VFSURIFactory.decodeURI(storageUri.get)
                  val opStorage = DocumentFactory.openDocument(storageUri.get)._1

                  allTableStats(opId.id) = opStorage.getTableStatistics
                  WorkflowExecutionsResource.updateResultSize(
                    executionId,
                    globalPortIdOption.get,
                    opStorage.getTotalFileSize
                  )
                  WorkflowExecutionsResource.updateRuntimeStatsSize(executionId)
                  WorkflowExecutionsResource.updateConsoleMessageSize(executionId, opId)
                }
              }
          }
        Iterable(
          WebResultUpdateEvent(
            buf.toMap,
            allTableStats.toMap
          )
        )
      })
    )

    // clear all the result metadata
    workflowStateStore.resultStore.updateState { _ =>
      WorkflowResultStore() // empty result store
    }
  }

  /**
   * Serves one page of an operator's stored result from the latest execution of this workflow.
   *
   * Rows are read starting at index `pageSize * (pageIndex - 1)`, so `pageIndex` is 1-based.
   * Columns are optionally filtered by a case-insensitive substring match on `columnSearch`.
   * The remaining columns are then windowed by `columnOffset` / `columnLimit`.
   * If the operator has no result storage, an empty page is returned.
   *
   * @throws IllegalStateException if no execution has been recorded for this workflow
   */
  def handleResultPagination(request: ResultPaginationRequest): TexeraWebSocketEvent = {
    // calculate from index (pageIndex starts from 1 instead of 0)
    val from = request.pageSize * (request.pageIndex - 1)
    val latestExecutionId = getLatestExecutionId(workflowIdentity, computingUnitId).getOrElse(
      throw new IllegalStateException("No execution is recorded")
    )

    val storageUriOption = WorkflowExecutionsResource.getResultUriByLogicalPortId(
      latestExecutionId,
      OperatorIdentity(request.operatorID),
      PortIdentity()
    )

    storageUriOption match {
      case Some(storageUri) =>
        val (document, schemaOption) = DocumentFactory.openDocument(storageUri)
        val virtualDocument = document.asInstanceOf[VirtualDocument[Tuple]]

        val columns = {
          val schema = schemaOption.get
          val allColumns = schema.getAttributeNames
          val filteredColumns = request.columnSearch match {
            case Some(search) =>
              allColumns.filter(col => col.toLowerCase.contains(search.toLowerCase))
            case None => allColumns
          }
          Some(
            filteredColumns.slice(request.columnOffset, request.columnOffset + request.columnLimit)
          )
        }

        val paginationIterable = {
          virtualDocument
            .getRange(from, from + request.pageSize, columns)
            .to(Iterable)
        }
        val mappedResults = convertTuplesToJson(paginationIterable)
        val attributes = paginationIterable.headOption
          .map(_.getSchema.getAttributes)
          .getOrElse(List.empty)
        PaginatedResultEvent.apply(request, mappedResults, attributes)

      case None =>
        // Handle the case when storageUri is empty
        PaginatedResultEvent.apply(request, List.empty, List.empty)
    }
  }

  /**
   * Re-reads the tuple count of every result document of `executionId`. It then replaces the
   * result store's state with the fresh per-operator metadata.
   *
   * Ports in SET_SNAPSHOT mode get a random change detector, so their metadata differs from one
   * poll to the next.
   */
  private def onResultUpdate(executionId: ExecutionIdentity, physicalPlan: PhysicalPlan): Unit = {
    workflowStateStore.resultStore.updateState { _ =>
      val newInfo: Map[OperatorIdentity, OperatorResultMetadata] = {
        WorkflowExecutionsResource
          .getResultUrisByExecutionId(executionId)
          .map(uri => {
            val count = DocumentFactory.openDocument(uri)._1.getCount.toInt

            val (_, _, globalPortIdOption, _) = VFSURIFactory.decodeURI(uri)

            // Retrieve the mode of the specified output port
            val mode = physicalPlan
              .getPhysicalOpsOfLogicalOp(globalPortIdOption.get.opId.logicalOpId)
              .flatMap(_.outputPorts.get(globalPortIdOption.get.portId))
              .map(_._1.mode)
              .head

            val changeDetector =
              if (mode == OutputMode.SET_SNAPSHOT) {
                UUID.randomUUID.toString
              } else ""
            (globalPortIdOption.get.opId.logicalOpId, OperatorResultMetadata(count, changeDetector))
          })
          .toMap
      }
      WorkflowResultStore(newInfo)
    }
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/ExecutionRuntimeService.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.core.virtualidentity.EmbeddedControlMessageIdentity
import org.apache.texera.amber.engine.architecture.controller.ExecutionStateUpdate
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
EmptyRequest,
TakeGlobalCheckpointRequest
}
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState._
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.FaultToleranceConfig
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.web.model.websocket.request._
import org.apache.texera.web.storage.ExecutionStateStore
import org.apache.texera.web.storage.ExecutionStateStore.updateWorkflowState
import org.apache.texera.web.{SubscriptionManager, WebsocketInput}
import java.net.URI
import java.util.UUID
/**
 * Connects the runtime controls of one execution to the engine. On construction it subscribes to:
 *  - engine [[ExecutionStateUpdate]] events, which it mirrors into `stateStore` and forwards to
 *    the optional email notifier. On COMPLETED it also shuts the client down and records the
 *    end time.
 *  - websocket pause / resume / kill / checkpoint requests, which it forwards to the controller.
 *  - websocket skip-tuple requests, which are currently rejected.
 *
 * @param logConf                  fault-tolerance configuration; required for global checkpoints
 * @param emailNotificationEnabled emails are sent only when this is true and `userEmailOpt` is defined
 */
class ExecutionRuntimeService(
    client: AmberClient,
    stateStore: ExecutionStateStore,
    wsInput: WebsocketInput,
    reconfigurationService: ExecutionReconfigurationService,
    logConf: Option[FaultToleranceConfig],
    workflowId: Long,
    emailNotificationEnabled: Boolean,
    userEmailOpt: Option[String],
    sessionUri: URI
) extends SubscriptionManager
    with LazyLogging {

  // Defined only when notifications are enabled and the user has an email address.
  private val emailNotificationService = for {
    email <- userEmailOpt
    if emailNotificationEnabled
  } yield new EmailNotificationService(
    new WorkflowEmailNotifier(
      workflowId,
      email,
      sessionUri
    )
  )

  // Receive skip tuple
  addSubscription(wsInput.subscribe((req: SkipTupleRequest, uidOpt) => {
    throw new RuntimeException("skipping tuple is temporarily disabled")
  }))

  // Receive execution state update from Amber
  addSubscription(client.registerCallback[ExecutionStateUpdate]((evt: ExecutionStateUpdate) => {
    stateStore.metadataStore.updateState(metadataStore =>
      updateWorkflowState(evt.state, metadataStore)
    )

    emailNotificationService.foreach(_.processEmailNotificationIfNeeded(evt.state))

    if (evt.state == COMPLETED) {
      client.shutdown()
      stateStore.statsStore.updateState(stats => stats.withEndTimeStamp(System.currentTimeMillis()))
    }
  }))

  // Receive Pause
  addSubscription(wsInput.subscribe((req: WorkflowPauseRequest, uidOpt) => {
    stateStore.metadataStore.updateState(metadataStore =>
      updateWorkflowState(PAUSING, metadataStore)
    )
    client.controllerInterface.pauseWorkflow(EmptyRequest(), ())
  }))

  // Receive Resume
  addSubscription(wsInput.subscribe((req: WorkflowResumeRequest, uidOpt) => {
    reconfigurationService.performReconfigurationOnResume()
    stateStore.metadataStore.updateState(metadataStore =>
      updateWorkflowState(RESUMING, metadataStore)
    )
    client.controllerInterface
      .resumeWorkflow(EmptyRequest(), ())
      .onSuccess(_ =>
        stateStore.metadataStore.updateState(metadataStore =>
          updateWorkflowState(RUNNING, metadataStore)
        )
      )
  }))

  // Receive Kill
  addSubscription(wsInput.subscribe((req: WorkflowKillRequest, uidOpt) => {
    client.shutdown()
    stateStore.statsStore.updateState(stats => stats.withEndTimeStamp(System.currentTimeMillis()))
    stateStore.metadataStore.updateState(metadataStore =>
      updateWorkflowState(KILLED, metadataStore)
    )
  }))

  // Receive global checkpoint request
  addSubscription(wsInput.subscribe((req: WorkflowCheckpointRequest, uidOpt) => {
    // Use an explicit check rather than `assert`. Assertions can be elided at compile time,
    // which would turn a missing config into an opaque `Option.get` failure.
    val faultToleranceConfig = logConf.getOrElse(
      throw new IllegalStateException(
        "Fault tolerance log folder is not established. Unable to take a global checkpoint."
      )
    )
    val checkpointId = EmbeddedControlMessageIdentity(s"Checkpoint_${UUID.randomUUID().toString}")
    val uri = faultToleranceConfig.writeTo.resolve(checkpointId.toString)
    client.controllerInterface.takeGlobalCheckpoint(
      TakeGlobalCheckpointRequest(estimationOnly = false, checkpointId, uri.toString),
      ()
    )
  }))

  /** Drops all subscriptions and shuts down the email notifier, if one was created. */
  override def unsubscribeAll(): Unit = {
    super.unsubscribeAll()
    emailNotificationService.foreach(_.shutdown())
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/ExecutionStatsService.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.google.protobuf.timestamp.Timestamp
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.core.storage.model.BufferedItemWriter
import org.apache.texera.amber.core.storage.result.ResultSchema
import org.apache.texera.amber.core.storage.{DocumentFactory, VFSURIFactory}
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.workflow.WorkflowContext
import org.apache.texera.amber.core.workflowruntimestate.FatalErrorType.EXECUTION_FAILURE
import org.apache.texera.amber.core.workflowruntimestate.WorkflowFatalError
import org.apache.texera.amber.engine.architecture.controller.{
ExecutionStateUpdate,
ExecutionStatsUpdate,
FatalError,
RuntimeStatisticsPersist,
WorkerAssignmentUpdate,
WorkflowRecoveryStatus
}
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState.{
COMPLETED,
FAILED,
KILLED
}
import org.apache.texera.amber.engine.common.Utils
import org.apache.texera.amber.engine.common.Utils.maptoStatusCode
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.amber.engine.common.executionruntimestate.{
OperatorMetrics,
OperatorStatistics,
OperatorWorkerMapping
}
import org.apache.texera.amber.error.ErrorUtils.{
getOperatorFromActorIdOpt,
getStackTraceWithAllCauses
}
import org.apache.texera.web.SubscriptionManager
import org.apache.texera.web.model.websocket.event.{
ExecutionDurationUpdateEvent,
OperatorAggregatedMetrics,
OperatorStatisticsUpdateEvent,
WorkerAssignmentUpdateEvent
}
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowExecutionsResource
import org.apache.texera.web.storage.ExecutionStateStore
import org.apache.texera.web.storage.ExecutionStateStore.updateWorkflowState
import java.time.Instant
import java.util.concurrent.Executors
/**
 * Maintains execution statistics for one workflow execution.
 *
 * Engine statistics, worker assignments, recovery status and fatal errors are mirrored into
 * `stateStore`; diff handlers on that store emit the corresponding frontend events. Runtime
 * statistics are also persisted to a dedicated document through a single background thread.
 *
 * @param client          engine client whose callbacks are observed
 * @param stateStore      execution state store updated by the callbacks
 * @param workflowContext identifies the workflow/execution whose statistics are persisted
 */
class ExecutionStatsService(
    client: AmberClient,
    stateStore: ExecutionStateStore,
    workflowContext: WorkflowContext
) extends SubscriptionManager
    with LazyLogging {

  // The single-threaded executor serializes every use of `runtimeStatsWriter` and
  // `lastPersistedMetrics`, including the final close of the writer.
  private val (metricsPersistThread, runtimeStatsWriter) = {
    val thread = Executors.newSingleThreadExecutor()
    val uri = VFSURIFactory.createRuntimeStatisticsURI(
      workflowContext.workflowId,
      workflowContext.executionId
    )
    val writer = DocumentFactory
      .createDocument(uri, ResultSchema.runtimeStatisticsSchema)
      .writer("runtime_statistics")
      .asInstanceOf[BufferedItemWriter[Tuple]]
    WorkflowExecutionsResource.updateRuntimeStatsUri(
      workflowContext.workflowId.id,
      workflowContext.executionId.id,
      uri
    )
    writer.open()
    (thread, writer)
  }

  // Metrics most recently written by the persistence thread; only accessed on that thread.
  private var lastPersistedMetrics: Map[String, OperatorMetrics] =
    Map.empty[String, OperatorMetrics]

  registerCallbacks()

  // Update operator stats if any operator updates its stat
  addSubscription(
    stateStore.statsStore.registerDiffHandler((oldState, newState) => {
      if (newState.operatorInfo.toSet != oldState.operatorInfo.toSet) {
        Iterable(
          OperatorStatisticsUpdateEvent(newState.operatorInfo.map {
            case (operatorId, metrics) =>
              val stats = metrics.operatorStatistics
              val inMap = stats.inputMetrics
                .map(pm => pm.portId.id.toString -> pm.tupleMetrics.count)
                .toMap
              val outMap = stats.outputMetrics
                .map(pm => pm.portId.id.toString -> pm.tupleMetrics.count)
                .toMap
              val aggregated = OperatorAggregatedMetrics(
                Utils.aggregatedStateToString(metrics.operatorState),
                stats.inputMetrics.map(_.tupleMetrics.count).sum,
                stats.inputMetrics.map(_.tupleMetrics.size).sum,
                inMap,
                stats.outputMetrics.map(_.tupleMetrics.count).sum,
                stats.outputMetrics.map(_.tupleMetrics.size).sum,
                outMap,
                stats.numWorkers,
                stats.dataProcessingTime,
                stats.controlProcessingTime,
                stats.idleTime
              )
              operatorId -> aggregated
          })
        )
      } else {
        Iterable.empty
      }
    })
  )

  // update operators' workers.
  addSubscription(
    stateStore.statsStore.registerDiffHandler((oldState, newState) => {
      if (newState.operatorWorkerMapping != oldState.operatorWorkerMapping) {
        newState.operatorWorkerMapping
          .map { opToWorkers =>
            WorkerAssignmentUpdateEvent(opToWorkers.operatorId, opToWorkers.workerIds)
          }
      } else {
        Iterable()
      }
    })
  )

  // update execution duration.
  addSubscription(
    stateStore.statsStore.registerDiffHandler((oldState, newState) => {
      if (
        newState.startTimeStamp != oldState.startTimeStamp || newState.endTimeStamp != oldState.endTimeStamp
      ) {
        if (newState.endTimeStamp != 0) {
          Iterable(
            ExecutionDurationUpdateEvent(
              newState.endTimeStamp - newState.startTimeStamp,
              isRunning = false
            )
          )
        } else {
          val currentTime = System.currentTimeMillis()
          Iterable(
            ExecutionDurationUpdateEvent(currentTime - newState.startTimeStamp, isRunning = true)
          )
        }
      } else {
        Iterable()
      }
    })
  )

  private[this] def registerCallbacks(): Unit = {
    registerCallbackOnWorkflowStatsUpdate()
    registerCallbackOnWorkerAssignedUpdate()
    registerCallbackOnWorkflowRecoveryUpdate()
    registerCallbackOnFatalError()
  }

  private[this] def registerCallbackOnWorkflowStatsUpdate(): Unit = {
    // Register callback for UI updates (UI state store update only, no persistence)
    addSubscription(
      client
        .registerCallback[ExecutionStatsUpdate]((evt: ExecutionStatsUpdate) => {
          stateStore.statsStore.updateState { statsStore =>
            statsStore.withOperatorInfo(evt.operatorMetrics)
          }
        })
    )

    // Register callback for statistics persistence (persistence only, no UI update)
    addSubscription(
      client
        .registerCallback[RuntimeStatisticsPersist]((evt: RuntimeStatisticsPersist) => {
          runOnPersistThread("persisting runtime statistics") {
            storeRuntimeStatistics(computeStatsDiff(evt.operatorMetrics))
            lastPersistedMetrics = evt.operatorMetrics
          }
        })
    )
  }

  addSubscription(
    client.registerCallback[ExecutionStateUpdate] {
      case ExecutionStateUpdate(state: WorkflowAggregatedState.Recognized)
          if Set(COMPLETED, FAILED, KILLED).contains(state) =>
        logger.info("Workflow execution terminated. Commit runtime statistics.")
        closeRuntimeStatsWriter()
      case _ =>
    }
  )

  /**
   * Runs `task` on the persistence thread. If that executor has already been shut down, the task
   * is dropped with a warning instead of throwing into the engine callback.
   */
  private def runOnPersistThread(taskName: String)(task: => Unit): Unit = {
    try {
      metricsPersistThread.execute(() => task)
    } catch {
      case _: java.util.concurrent.RejectedExecutionException =>
        logger.warn(s"Runtime statistics persistence is already shut down; skipping $taskName.")
    }
  }

  /**
   * Closes the runtime statistics writer on the persistence thread. Because the executor is
   * single-threaded, the close runs after every write already queued there, and never
   * concurrently with one. The executor is then shut down: queued tasks still run, but new
   * submissions are rejected, and its thread is released.
   */
  private def closeRuntimeStatsWriter(): Unit = {
    runOnPersistThread("closing the runtime statistics writer") {
      try {
        runtimeStatsWriter.close()
      } catch {
        case e: Exception =>
          logger.error("Failed to close runtime statistics writer", e)
      }
    }
    metricsPersistThread.shutdown()
  }

  /**
   * Builds the metrics map to persist: all of `newMetrics`, plus the last persisted entry of every
   * operator that is absent from `newMetrics`.
   * NOTE(review): despite the name, this is not a delta. The default-metrics entries added for new
   * keys are never read, and the final map only rebuilds each entry from its own fields.
   */
  private def computeStatsDiff(
      newMetrics: Map[String, OperatorMetrics]
  ): Map[String, OperatorMetrics] = {
    // Default metrics for new operators
    val defaultMetrics = OperatorMetrics(
      WorkflowAggregatedState.UNINITIALIZED,
      OperatorStatistics(Seq.empty, Seq.empty, 0, 0, 0, 0)
    )

    // Determine new and old keys
    val newKeys = newMetrics.keySet.diff(lastPersistedMetrics.keySet)
    val oldKeys = lastPersistedMetrics.keySet.diff(newMetrics.keySet)

    // Update last metrics with default metrics for new keys
    val updatedLastMetrics = lastPersistedMetrics ++ newKeys.map(_ -> defaultMetrics)

    // Combine new metrics with old metrics for keys that are no longer present
    val completeMetricsMap = newMetrics ++ oldKeys.map(key => key -> updatedLastMetrics(key))

    // Transform the complete metrics map to ensure consistent structure
    completeMetricsMap.map {
      case (key, metrics) =>
        key -> OperatorMetrics(
          metrics.operatorState,
          OperatorStatistics(
            metrics.operatorStatistics.inputMetrics,
            metrics.operatorStatistics.outputMetrics,
            metrics.operatorStatistics.numWorkers,
            metrics.operatorStatistics.dataProcessingTime,
            metrics.operatorStatistics.controlProcessingTime,
            metrics.operatorStatistics.idleTime
          )
        )
    }
  }

  /**
   * Appends one runtime-statistics row per operator to the writer. Non-fatal failures are logged
   * and swallowed, so persistence stays best-effort. Must be called on the persistence thread.
   */
  private def storeRuntimeStatistics(
      operatorStatistics: scala.collection.immutable.Map[String, OperatorMetrics]
  ): Unit = {
    try {
      operatorStatistics.foreach {
        case (operatorId, stat) =>
          val runtimeStats = new Tuple(
            ResultSchema.runtimeStatisticsSchema,
            Array(
              operatorId,
              new java.sql.Timestamp(System.currentTimeMillis()),
              stat.operatorStatistics.inputMetrics.map(_.tupleMetrics.count).sum,
              stat.operatorStatistics.inputMetrics.map(_.tupleMetrics.size).sum,
              stat.operatorStatistics.outputMetrics.map(_.tupleMetrics.count).sum,
              stat.operatorStatistics.outputMetrics.map(_.tupleMetrics.size).sum,
              stat.operatorStatistics.dataProcessingTime,
              stat.operatorStatistics.controlProcessingTime,
              stat.operatorStatistics.idleTime,
              stat.operatorStatistics.numWorkers,
              maptoStatusCode(stat.operatorState).toInt
            )
          )
          runtimeStatsWriter.putOne(runtimeStats)
      }
    } catch {
      case scala.util.control.NonFatal(err) =>
        logger.error("error occurred when storing runtime statistics", err)
    }
  }

  private[this] def registerCallbackOnWorkerAssignedUpdate(): Unit = {
    addSubscription(
      client
        .registerCallback[WorkerAssignmentUpdate]((evt: WorkerAssignmentUpdate) => {
          stateStore.statsStore.updateState { statsStore =>
            statsStore.withOperatorWorkerMapping(
              evt.workerMapping
                .map({
                  case (opId, workerIds) => OperatorWorkerMapping(opId, workerIds.toSeq)
                })
                .toSeq
            )
          }
        })
    )
  }

  private[this] def registerCallbackOnWorkflowRecoveryUpdate(): Unit = {
    addSubscription(
      client
        .registerCallback[WorkflowRecoveryStatus]((evt: WorkflowRecoveryStatus) => {
          stateStore.metadataStore.updateState { metadataStore =>
            metadataStore.withIsRecovering(evt.isRecovering)
          }
        })
    )
  }

  private[this] def registerCallbackOnFatalError(): Unit = {
    addSubscription(
      client
        .registerCallback[FatalError]((evt: FatalError) => {
          client.shutdown()
          val (operatorId, workerId) = getOperatorFromActorIdOpt(evt.fromActor)
          stateStore.statsStore.updateState(stats =>
            stats.withEndTimeStamp(System.currentTimeMillis())
          )
          // Log outside the state-update function so the side effect is not tied to how that
          // function is invoked.
          logger.error("error occurred in execution", evt.e)
          stateStore.metadataStore.updateState { metadataStore =>
            updateWorkflowState(FAILED, metadataStore).addFatalErrors(
              WorkflowFatalError(
                EXECUTION_FAILURE,
                Timestamp(Instant.now),
                evt.e.toString,
                getStackTraceWithAllCauses(evt.e),
                operatorId,
                workerId
              )
            )
          }
        })
    )
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/ExecutionsMetadataPersistService.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity}
import org.apache.texera.dao.SqlServer
import org.apache.texera.dao.jooq.generated.tables.daos.WorkflowExecutionsDao
import org.apache.texera.dao.jooq.generated.tables.pojos.WorkflowExecutions
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowVersionResource._
import java.sql.Timestamp
/**
 * This global object handles inserting a new entry to the DB to store metadata information about every workflow execution.
 * It also updates the entry if an execution status is updated.
 */
object ExecutionsMetadataPersistService extends LazyLogging {
  // A new DSL context and DAO are created on every access from the shared SqlServer instance.
  private def context =
    SqlServer
      .getInstance()
      .createDSLContext()

  private def workflowExecutionsDao =
    new WorkflowExecutionsDao(
      context.configuration
    )

  /**
   * Inserts a new workflow execution entry into the database and returns its generated eId.
   *
   * @param workflowId         the workflow being executed; the entry is linked to its latest version
   * @param uid                id of the user that initiated the execution, if known
   * @param executionName      name of the execution; an empty string leaves the name unset
   * @param environmentVersion environment version recorded with the execution
   * @param computingUnitId    id of the computing unit running the execution
   * @return the generated execution ID
   */
  def insertNewExecution(
      workflowId: WorkflowIdentity,
      uid: Option[Integer],
      executionName: String,
      environmentVersion: String,
      computingUnitId: Integer
  ): ExecutionIdentity = {
    // first retrieve the latest version of this workflow
    val vid = getLatestVersion(workflowId.id.toInt)
    val newExecution = new WorkflowExecutions()
    if (executionName != "") {
      newExecution.setName(executionName)
    }
    newExecution.setVid(vid)
    newExecution.setUid(uid.orNull)
    newExecution.setStartingTime(new Timestamp(System.currentTimeMillis()))
    newExecution.setEnvironmentVersion(environmentVersion)
    newExecution.setCuid(computingUnitId)
    workflowExecutionsDao.insert(newExecution)
    ExecutionIdentity(newExecution.getEid.longValue())
  }

  /**
   * Fetches the execution entry with the given id.
   *
   * @return the entry, or None if no such entry exists or the lookup fails (failures are logged)
   */
  def tryGetExistingExecution(executionId: ExecutionIdentity): Option[WorkflowExecutions] = {
    try {
      // jOOQ's fetchOne* returns null when no row matches; map that to None rather than Some(null).
      Option(workflowExecutionsDao.fetchOneByEid(executionId.id.toInt))
    } catch {
      case scala.util.control.NonFatal(t) =>
        logger.info("Unable to get execution. Error = " + t.getMessage)
        None
    }
  }

  /**
   * Applies `updateFunc` to the stored execution entry and writes it back. This is best-effort:
   * a missing entry or a non-fatal failure is logged and otherwise ignored.
   */
  def tryUpdateExistingExecution(
      executionId: ExecutionIdentity
  )(updateFunc: WorkflowExecutions => Unit): Unit = {
    try {
      val dao = workflowExecutionsDao
      Option(dao.fetchOneByEid(executionId.id.toInt)) match {
        case Some(execution) =>
          updateFunc(execution)
          dao.update(execution)
        case None =>
          // Previously this passed null to updateFunc and surfaced as an NPE with no message.
          logger.info(s"Unable to update execution. Error = execution ${executionId.id} not found")
      }
    } catch {
      case scala.util.control.NonFatal(t) =>
        logger.info("Unable to update execution. Error = " + t.getMessage)
    }
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/ResultExportService.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.fasterxml.jackson.core.`type`.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.github.tototoshi.csv.CSVWriter
import org.apache.texera.amber.config.EnvironmentalVariable
import org.apache.texera.amber.core.storage.DocumentFactory
import org.apache.texera.amber.core.storage.model.VirtualDocument
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.virtualidentity.{OperatorIdentity, WorkflowIdentity}
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.util.ArrowUtils
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.ArrowFileWriter
import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.StringUtils
import org.apache.texera.auth.JwtAuth
import org.apache.texera.auth.JwtAuth.{TOKEN_EXPIRE_TIME_IN_MINUTES, jwtClaims}
import org.apache.texera.dao.jooq.generated.tables.pojos.User
import org.apache.texera.web.model.http.request.result.{OperatorExportInfo, ResultExportRequest}
import org.apache.texera.web.model.http.response.result.ResultExportResponse
import org.apache.texera.web.resource.dashboard.user.workflow.{
WorkflowExecutionsResource,
WorkflowVersionResource
}
import org.apache.texera.web.service.WorkflowExecutionService.getLatestExecutionId
import java.io.{FilterOutputStream, IOException, OutputStream}
import java.net.{HttpURLConnection, URL, URLEncoder}
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit
import java.util.zip.{ZipEntry, ZipOutputStream}
import javax.ws.rs.WebApplicationException
import javax.ws.rs.core.{MediaType, Response, StreamingOutput}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Using
object Constants {

  /** Number of tuples pulled from a result iterator per batch by the result exporters. */
  val CHUNK_SIZE: Int = 10
}
/**
* A simple wrapper that ignores 'close()' calls on the underlying stream.
* This allows each operator's writer to call close() without ending the entire ZipOutputStream.
*/
private class NonClosingOutputStream(os: OutputStream) extends FilterOutputStream(os) {
@throws[IOException]
override def close(): Unit = {
// do not actually close the underlying stream
super.flush()
// omit super.close()
}
}
object ResultExportService {

  /**
    * Upload endpoint of the file service; the literal `did` segment is substituted with a
    * dataset id by the caller. Read from the environment on first use, falling back to a
    * localhost default, with surrounding whitespace removed.
    */
  lazy val fileServiceUploadOneFileToDatasetEndpoint: String = {
    val configured =
      sys.env.get(EnvironmentalVariable.ENV_FILE_SERVICE_UPLOAD_ONE_FILE_TO_DATASET_ENDPOINT)
    val endpoint = configured.getOrElse("http://localhost:9092/api/dataset/did/upload")
    endpoint.trim
  }
}
class ResultExportService(workflowIdentity: WorkflowIdentity, computingUnitId: Int) {
  import ResultExportService._

  /**
    * Export the selected operators' results to every dataset listed in the request.
    *
    * Each operator is exported independently; a failure for one operator is recorded and does
    * not abort the others. The overall status is "success" when at least one operator was
    * exported (or nothing failed), and "error" only when every attempted export failed.
    *
    * @param user    the user whose JWT authenticates the upload
    * @param request the export request (operators, target datasets, file naming)
    * @return an HTTP 200 response wrapping a [[ResultExportResponse]]
    */
  def exportToDataset(
      user: User,
      request: ResultExportRequest
  ): Response = {
    val successMessages = new mutable.ListBuffer[String]()
    val errorMessages = new mutable.ListBuffer[String]()
    request.operators.foreach { op =>
      try {
        val (msgOpt, errOpt) = exportSingleOperatorToDataset(user, request, op)
        msgOpt.foreach(successMessages += _)
        errOpt.foreach(errorMessages += _)
      } catch {
        case ex: Exception =>
          errorMessages += s"Error exporting operator $op: ${ex.getMessage}"
      }
    }
    // At least one success (possibly partial) counts as overall success.
    val exportResponse =
      if (successMessages.nonEmpty || errorMessages.isEmpty)
        ResultExportResponse("success", successMessages.mkString("\n"))
      else
        ResultExportResponse("error", errorMessages.mkString("\n"))
    Response.ok(exportResponse).build()
  }

  /**
    * Export operator results as downloadable files.
    * If multiple operators are selected, their results are streamed as a ZIP file.
    * If a single operator is selected, its result is streamed directly.
    *
    * @throws RuntimeException if no stream could be produced for the selected operator(s)
    */
  def exportToLocal(request: ResultExportRequest): Response = {
    if (request.operators.size > 1) {
      val (zipStream, zipFileNameOpt) = exportOperatorsAsZip(request)
      if (zipStream == null) {
        throw new RuntimeException("Zip stream is null")
      }
      val fileName = zipFileNameOpt.getOrElse("operators.zip")
      Response
        .ok(zipStream, "application/zip")
        .header("Content-Disposition", s"""attachment; filename="$fileName"""")
        .build()
    } else {
      val op = request.operators.head
      val (streamingOutput, fileNameOpt) = exportOperatorResultAsStream(request, op)
      if (streamingOutput == null) {
        throw new RuntimeException("Failed to export operator")
      }
      val fileName = fileNameOpt.getOrElse("download.dat")
      Response
        .ok(streamingOutput, MediaType.APPLICATION_OCTET_STREAM)
        .header("Content-Disposition", s"""attachment; filename="$fileName"""")
        .build()
    }
  }

  /**
    * Export a single operator's result to the request's datasets.
    *
    * @return `(Some(successMessage), None)` on success, `(None, Some(errorMessage))` otherwise
    */
  private def exportSingleOperatorToDataset(
      user: User,
      request: ResultExportRequest,
      operatorRequest: OperatorExportInfo
  ): (Option[String], Option[String]) = {
    val execIdOpt = getLatestExecutionId(workflowIdentity, computingUnitId)
    if (execIdOpt.isEmpty) {
      return (None, Some(s"Workflow ${request.workflowId} has no execution result"))
    }
    val operatorDocument = getOperatorDocument(operatorRequest.id, computingUnitId)
    if (operatorDocument == null || operatorDocument.getCount == 0)
      return (None, Some(s"No results to export for operator $operatorRequest"))
    // Read the header once up front; the document is known to be non-empty here.
    val attributeNames =
      operatorDocument.getRange(0, 1).to(Iterable).head.getSchema.getAttributeNames
    val writer: OutputStream => Unit = out =>
      streamDocument(operatorRequest.outputType, operatorDocument, request, out, Some(attributeNames))
    saveStreamToDataset(
      operatorId = operatorRequest.id,
      user = user,
      request = request,
      extension = operatorRequest.outputType,
      writer = writer
    )
  }

  /**
    * Export a single operator's results as a streaming response (e.g., for download).
    *
    * @return the stream and its file name, or `(null, None)` when there is no execution or no result
    */
  def exportOperatorResultAsStream(
      request: ResultExportRequest,
      operatorRequest: OperatorExportInfo
  ): (StreamingOutput, Option[String]) = {
    val execIdOpt = getLatestExecutionId(workflowIdentity, computingUnitId)
    if (execIdOpt.isEmpty) {
      return (null, None)
    }
    val operatorDocument = getOperatorDocument(operatorRequest.id, computingUnitId)
    if (operatorDocument == null || operatorDocument.getCount == 0) {
      return (null, None)
    }
    val fileName =
      if (request.filename.isEmpty)
        generateFileName(request, operatorRequest.id, operatorRequest.outputType)
      else request.filename
    val streamingOutput: StreamingOutput = (out: OutputStream) =>
      streamDocument(operatorRequest.outputType, operatorDocument, request, out, None)
    (streamingOutput, Some(fileName))
  }

  /**
    * Export multiple operators' results as a single ZIP file stream.
    * Operators without results get a small "-empty.txt" placeholder entry.
    *
    * @throws WebApplicationException if the workflow has no execution
    */
  def exportOperatorsAsZip(
      request: ResultExportRequest
  ): (StreamingOutput, Option[String]) = {
    val timestamp = LocalDateTime
      .now()
      .truncatedTo(ChronoUnit.SECONDS)
      .format(DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss"))
    val zipFileName = s"${request.workflowName}-$timestamp.zip"
    val execIdOpt = getLatestExecutionId(workflowIdentity, computingUnitId)
    if (execIdOpt.isEmpty) {
      throw new WebApplicationException(
        s"No execution result for workflow ${request.workflowId}"
      )
    }
    val streamingOutput: StreamingOutput = new StreamingOutput {
      override def write(outputStream: OutputStream): Unit = {
        Using.resource(new ZipOutputStream(outputStream)) { zipOut =>
          request.operators.foreach { op =>
            val operatorDocument = getOperatorDocument(op.id, computingUnitId)
            if (operatorDocument == null || operatorDocument.getCount == 0) {
              // create an "empty" file for this operator
              zipOut.putNextEntry(new ZipEntry(s"${op.id}-empty.txt"))
              val msg = s"Operator ${op.id} has no results"
              zipOut.write(msg.getBytes(StandardCharsets.UTF_8))
              zipOut.closeEntry()
            } else {
              val operatorFileName = generateFileName(request, op.id, op.outputType)
              zipOut.putNextEntry(new ZipEntry(operatorFileName))
              // Per-format writers close their stream; this wrapper keeps the ZIP open.
              val nonClosingStream = new NonClosingOutputStream(zipOut)
              streamDocument(op.outputType, operatorDocument, request, nonClosingStream, None)
              zipOut.closeEntry()
            }
          }
        }
      }
    }
    (streamingOutput, Some(zipFileName))
  }

  /**
    * Stream `doc` into `out` in the format named by `outputType`.
    * "csv" and any unrecognized type are written as CSV; `csvHeaders` is only used for CSV.
    */
  private def streamDocument(
      outputType: String,
      doc: VirtualDocument[Tuple],
      request: ResultExportRequest,
      out: OutputStream,
      csvHeaders: Option[List[String]]
  ): Unit =
    outputType match {
      case "arrow"   => streamDocumentAsArrow(doc, out)
      case "data"    => streamCellData(out, request, doc)
      case "html"    => streamDocumentAsHTML(out, doc)
      case "parquet" => streamDocumentAsParquetZip(doc, out)
      case _         => streamDocumentAsCSV(doc, out, csvHeaders)
    }

  /**
    * Streams the entire content of `VirtualDocument` as CSV into `outputStream` in a single pass.
    *
    * When `maybeHeaders` is empty, the header row is inferred from the first tuple's schema.
    * Nothing is written for an empty document. The stream is closed on success only.
    */
  private def streamDocumentAsCSV(
      doc: VirtualDocument[Tuple],
      outputStream: OutputStream,
      maybeHeaders: Option[List[String]]
  ): Unit = {
    if (doc.getCount == 0) return
    val iterator = doc.get()
    if (!iterator.hasNext) return
    val csvWriter = CSVWriter.open(outputStream)
    maybeHeaders match {
      case Some(headers) =>
        csvWriter.writeRow(headers)
      case None =>
        // Infer the header from the first tuple, then emit that tuple immediately after it.
        val firstRow = iterator.next()
        csvWriter.writeRow(firstRow.getSchema.getAttributeNames)
        csvWriter.writeRow(firstRow.getFields.toIndexedSeq)
    }
    iterator.grouped(Constants.CHUNK_SIZE).foreach { chunk =>
      chunk.foreach(tuple => csvWriter.writeRow(tuple.getFields.toIndexedSeq))
      csvWriter.flush()
    }
    // Not in a `finally`: closing would also close `outputStream`, which for a chunked
    // HTTP upload finalizes the request and could make a failed export look complete.
    csvWriter.close()
  }

  /**
    * Streams the entire content of `VirtualDocument` as an Arrow IPC file into `outputStream`.
    *
    * Failures propagate to the caller (they were previously swallowed because the `Try`
    * returned by `Using.Manager` was discarded). Nested `Using` blocks release the vector root
    * before its allocator, and nothing is allocated until the schema has been resolved.
    */
  private def streamDocumentAsArrow(
      doc: VirtualDocument[Tuple],
      outputStream: OutputStream
  ): Unit = {
    if (doc.getCount == 0) return
    val schema = doc.getRange(0, 1).to(Iterable).head.getSchema
    val arrowSchema = ArrowUtils.fromTexeraSchema(schema)
    Using.resource(new RootAllocator()) { allocator =>
      Using.resource(VectorSchemaRoot.create(arrowSchema, allocator)) { root =>
        // The writer is closed only on success: closing writes the footer and closes
        // `outputStream`, which would make a truncated export look complete.
        val writer = new ArrowFileWriter(root, null, Channels.newChannel(outputStream))
        writer.start()
        doc.get().grouped(Constants.CHUNK_SIZE).foreach { batch =>
          batch.zipWithIndex.foreach {
            case (tuple, rowIndex) => ArrowUtils.setTexeraTuple(tuple, rowIndex, root)
          }
          root.setRowCount(batch.size)
          writer.writeBatch()
          root.clear()
        }
        writer.end()
        writer.close()
      }
    }
  }

  /*
   * Handle streaming HTML result from a visualization operator's result.
   * Writes field 0 of the first tuple as UTF-8.
   */
  private def streamDocumentAsHTML(
      out: OutputStream,
      operatorDocument: VirtualDocument[Tuple]
  ): Unit = {
    // Only the first tuple is needed, so read it directly rather than materializing the result.
    val resHead = operatorDocument
      .get()
      .nextOption()
      .getOrElse(throw new RuntimeException("No HTML result to export"))
    val htmlCode = resHead.getField(0).toString
    out.write(htmlCode.getBytes(StandardCharsets.UTF_8))
    out.flush()
  }

  /**
    * Streams the underlying Parquet files of an Iceberg document into a ZIP archive.
    * This avoids re-encoding and uses minimal memory and no temporary disk space.
    */
  private def streamDocumentAsParquetZip(
      doc: VirtualDocument[Tuple],
      outputStream: OutputStream
  ): Unit =
    Using.resource(doc.asInputStream()) { zipStream =>
      IOUtils.copy(zipStream, outputStream)
      ()
    }

  /*
   * Handle streaming a single (row, column) from an operator's result.
   * This is used for the "data" export type, which exports a single field value.
   *
   * Throws WebApplicationException when either index is negative or out of range.
   */
  private def streamCellData(
      out: OutputStream,
      request: ResultExportRequest,
      operatorDocument: VirtualDocument[Tuple]
  ): Unit = {
    val rowIndex = request.rowIndex
    val columnIndex = request.columnIndex
    val totalRows = operatorDocument.getCount
    if (rowIndex < 0 || rowIndex >= totalRows) {
      throw new WebApplicationException(
        s"Invalid rowIndex ($rowIndex). Total rows: $totalRows"
      )
    }
    val selectedRow = operatorDocument
      .getRange(rowIndex, rowIndex + 1)
      .to(Iterable)
      .headOption
      .getOrElse(throw new RuntimeException(s"Could not retrieve row at index $rowIndex"))
    val totalColumns = selectedRow.getFields.length
    if (columnIndex < 0 || columnIndex >= totalColumns) {
      throw new WebApplicationException(
        s"Invalid columnIndex ($columnIndex). Total columns: $totalColumns"
      )
    }
    val field: Any = selectedRow.getField(columnIndex)
    out.write(convertFieldToBytes(field))
  }

  /**
    * Open the result document of one operator's external output port 0 for the latest execution.
    * Callers must have checked that an execution exists (the lookup uses `.get`).
    *
    * @return the document, or null if no result URI is recorded for the port
    */
  private def getOperatorDocument(
      operatorId: String,
      computingUnitId: Int
  ): VirtualDocument[Tuple] = {
    // By now the workflow should finish running
    // Only supports external port 0 for now. TODO: support multiple ports
    val storageUri = WorkflowExecutionsResource.getResultUriByLogicalPortId(
      getLatestExecutionId(workflowIdentity, computingUnitId).get,
      OperatorIdentity(operatorId),
      PortIdentity()
    )
    storageUri
      .map(uri => DocumentFactory.openDocument(uri)._1.asInstanceOf[VirtualDocument[Tuple]])
      .orNull
  }

  /**
    * Upload the bytes produced by `writer` to the request's datasets, turning any exception
    * into an error message instead of propagating it.
    */
  private def saveStreamToDataset(
      operatorId: String,
      user: User,
      request: ResultExportRequest,
      extension: String,
      writer: OutputStream => Unit
  ): (Option[String], Option[String]) = {
    val fileName =
      if (request.filename.isEmpty) generateFileName(request, operatorId, extension)
      else request.filename
    try {
      saveToDatasets(request, user, writer, fileName)
      (Some(s"$extension export done for operator $operatorId -> file: $fileName"), None)
    } catch {
      case ex: Exception =>
        (None, Some(s"$extension export failed for operator $operatorId: ${ex.getMessage}"))
    }
  }

  /** Raw bytes for byte arrays; UTF-8 encoding of the string form for anything else. */
  private def convertFieldToBytes(field: Any): Array[Byte] = {
    field match {
      case data: Array[Byte] => data
      case data: String      => data.getBytes(StandardCharsets.UTF_8)
      case other             => other.toString.getBytes(StandardCharsets.UTF_8)
    }
  }

  /**
    * Upload the output of `fileWriter` to each dataset in the request as a new file version,
    * using a chunked POST authenticated with the user's JWT.
    *
    * @throws RuntimeException if any upload fails or responds with a non-200 status
    */
  private def saveToDatasets(
      request: ResultExportRequest,
      user: User,
      fileWriter: OutputStream => Unit,
      fileName: String
  ): Unit = {
    // Query parameters are identical for every dataset, so encode them once.
    val encodedFilePath = URLEncoder.encode(fileName, StandardCharsets.UTF_8.name())
    val message = URLEncoder.encode(
      s"Export from workflow ${request.workflowName}",
      StandardCharsets.UTF_8.name()
    )
    request.datasetIds.foreach { did =>
      val uploadUrl = fileServiceUploadOneFileToDatasetEndpoint
        .replace("did", did.toString) + s"?filePath=$encodedFilePath&message=$message"
      var connection: HttpURLConnection = null
      try {
        val url = new URL(uploadUrl)
        connection = url.openConnection().asInstanceOf[HttpURLConnection]
        connection.setDoOutput(true)
        connection.setRequestMethod("POST")
        connection.setRequestProperty("Content-Type", "application/octet-stream")
        connection.setRequestProperty(
          "Authorization",
          s"Bearer ${JwtAuth.jwtToken(jwtClaims(user, TOKEN_EXPIRE_TIME_IN_MINUTES))}"
        )
        connection.setChunkedStreamingMode(0)
        val outputStream = connection.getOutputStream
        fileWriter(outputStream)
        outputStream.close()
        val responseCode = connection.getResponseCode
        if (responseCode != HttpURLConnection.HTTP_OK) {
          throw new RuntimeException(s"Failed to upload file. Server responded with: $responseCode")
        }
      } catch {
        case e: Exception =>
          throw new RuntimeException(s"Error uploading file to dataset $did: ${e.getMessage}", e)
      } finally {
        if (connection != null) connection.disconnect()
      }
    }
  }

  /**
    * Generate a file name for an operator's exported file.
    * Parquet exports are ZIP archives, so they get a ".zip" extension.
    */
  private def generateFileName(
      request: ResultExportRequest,
      operatorId: String,
      extension: String
  ): String = {
    val extensionMatch = extension match {
      case "parquet" => "zip"
      case _         => extension
    }
    val latestVersion =
      WorkflowVersionResource.getLatestVersion(request.workflowId)
    val timestamp = LocalDateTime
      .now()
      .truncatedTo(ChronoUnit.SECONDS)
      .format(DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss"))
    val rawName =
      s"${request.workflowName}-op$operatorId-v$latestVersion-$timestamp.$extensionMatch"
    // remove path separators
    StringUtils.replaceEach(rawName, Array("/", "\\"), Array("", ""))
  }

  /**
    * Parse a JSON string array of operators into a list of OperatorExportInfo objects.
    */
  def parseOperators(operatorsJson: String): List[OperatorExportInfo] = {
    new ObjectMapper()
      .registerModule(DefaultScalaModule)
      .readValue(operatorsJson, new TypeReference[List[OperatorExportInfo]] {})
  }

  /**
    * Validate an export request by checking if any operators are selected.
    * Return an error response if none are selected, otherwise None.
    */
  def validateExportRequest(request: ResultExportRequest): Option[Response] = {
    if (request.operators.isEmpty) {
      Some(
        Response
          .status(Response.Status.BAD_REQUEST)
          .`type`(MediaType.APPLICATION_JSON)
          .entity(Map("error" -> "No operator selected").asJava)
          .build()
      )
    } else None
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/WorkflowEmailNotifier.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState._
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowResource
import org.apache.texera.web.resource.{EmailMessage, GmailResource}
import org.hibernate.validator.internal.constraintvalidators.hv.EmailValidator
import java.net.URI
import java.time.format.DateTimeFormatter
import java.time.{Instant, ZoneOffset}
/**
  * Sends an email to the workflow's user when an execution reaches a terminal state
  * (COMPLETED, PAUSED, FAILED or KILLED).
  *
  * @param workflowId the workflow whose status is reported
  * @param userEmail  recipient address; invalid addresses are logged and skipped
  * @param sessionUri URI of the user's session, used to build the dashboard link
  */
class WorkflowEmailNotifier(
    workflowId: Long,
    userEmail: String,
    sessionUri: URI
) extends EmailNotifier
    with LazyLogging {
  private val workflowName = WorkflowResource.getWorkflowName(workflowId.toInt)
  private val emailValidator = new EmailValidator()
  private val TerminalStates: Set[WorkflowAggregatedState] = Set(
    COMPLETED,
    PAUSED,
    FAILED,
    KILLED
  )

  override def shouldSendEmail(workflowState: WorkflowAggregatedState): Boolean =
    TerminalStates.contains(workflowState)

  /** Send the status email; delivery failures are logged rather than propagated. */
  override def sendStatusEmail(state: WorkflowAggregatedState): Unit = {
    if (!isValidEmail(userEmail)) {
      logger.warn(s"Invalid email address: $userEmail")
      return
    }
    val emailMessage = createEmailMessage(state)
    try {
      GmailResource.sendEmail(emailMessage, userEmail)
    } catch {
      // Use the class logger (not stdout) so failures reach the service logs with a stack trace.
      case e: Exception => logger.error(s"Failed to send email: ${e.getMessage}", e)
    }
  }

  private def isValidEmail(email: String): Boolean = emailValidator.isValid(email, null)

  private def createEmailMessage(state: WorkflowAggregatedState): EmailMessage = {
    EmailMessage(
      receiver = userEmail,
      subject = createEmailSubject(state),
      content = createEmailContent(state)
    )
  }

  private def createEmailSubject(state: WorkflowAggregatedState): String =
    s"[Texera] Workflow $workflowName ($workflowId) Status: $state"

  private def createEmailContent(state: WorkflowAggregatedState): String = {
    val timestamp = formatTimestamp(Instant.now())
    val dashboardUrl = createDashboardUrl()
    s"""
       |Hello,
       |
       |The workflow with the following details has changed its state:
       |
       |- Workflow ID: $workflowId
       |- Workflow Name: $workflowName
       |- State: $state
       |- Timestamp: $timestamp
       |
       |You can view more details by visiting: $dashboardUrl
       |
       |Regards,
       |Texera Team
      """.stripMargin.trim
  }

  private def formatTimestamp(instant: Instant): String =
    DateTimeFormatter
      .ofPattern("MMMM d, yyyy, h:mm:ss a '(UTC)'")
      .withZone(ZoneOffset.UTC)
      .format(instant)

  /**
    * Build the workspace link for this workflow from the session's host and port.
    *
    * The scheme is "https" when the session came in on port 443 or over a secure scheme
    * (https/wss); otherwise "http". Previously port 443 was dropped while keeping "http",
    * which pointed the link at port 80.
    */
  private def createDashboardUrl(): String = {
    val host = sessionUri.getHost
    val port = sessionUri.getPort
    val path = s"/dashboard/user/workspace/$workflowId"
    val secureSession = Option(sessionUri.getScheme).exists { s =>
      s.equalsIgnoreCase("https") || s.equalsIgnoreCase("wss")
    }
    val scheme = if (port == 443 || secureSession) "https" else "http"
    if (port == -1 || port == 80 || port == 443) {
      s"$scheme://$host$path"
    } else {
      s"$scheme://$host:$port$path"
    }
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/WorkflowExecutionService.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity}
import org.apache.texera.amber.core.workflow.WorkflowContext
import org.apache.texera.amber.engine.architecture.controller.{ControllerConfig, Workflow}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.EmptyRequest
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState._
import org.apache.texera.amber.engine.common.Utils
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.amber.engine.common.executionruntimestate.ExecutionMetadataStore
import org.apache.texera.web.model.websocket.event.{
TexeraWebSocketEvent,
WorkflowErrorEvent,
WorkflowStateEvent
}
import org.apache.texera.web.model.websocket.request.WorkflowExecuteRequest
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowExecutionsResource
import org.apache.texera.web.storage.ExecutionStateStore
import org.apache.texera.web.storage.ExecutionStateStore.updateWorkflowState
import org.apache.texera.web.{ComputingUnitMaster, SubscriptionManager, WebsocketInput}
import org.apache.texera.workflow.WorkflowCompiler
import java.net.URI
import scala.collection.mutable
object WorkflowExecutionService {

  /**
    * Look up the most recent execution of `workflowId` on computing unit `computingUnitId`.
    *
    * @return the execution's identity, or None if the workflow has never run on that unit
    */
  def getLatestExecutionId(
      workflowId: WorkflowIdentity,
      computingUnitId: Int
  ): Option[ExecutionIdentity] = {
    val latestEid =
      WorkflowExecutionsResource.getLatestExecutionID(workflowId.id.toInt, computingUnitId)
    for (eid <- latestEid) yield new ExecutionIdentity(eid.longValue())
  }
}
/**
  * Compiles and runs one workflow execution, wiring the Amber runtime to the per-execution
  * services and forwarding state/error changes from `executionStateStore` as websocket events.
  */
class WorkflowExecutionService(
    controllerConfig: ControllerConfig,
    val workflowContext: WorkflowContext,
    resultService: ExecutionResultService,
    request: WorkflowExecuteRequest,
    val executionStateStore: ExecutionStateStore,
    errorHandler: Throwable => Unit,
    userEmailOpt: Option[String],
    sessionUri: URI
) extends SubscriptionManager
    with LazyLogging {
  workflowContext.workflowSettings = request.workflowSettings
  val wsInput = new WebsocketInput(errorHandler)

  // Emit a state event when the state or recovery flag changes, and an error event when the
  // set of fatal errors changes.
  addSubscription(
    executionStateStore.metadataStore.registerDiffHandler((oldState, newState) => {
      val outputEvents = new mutable.ArrayBuffer[TexeraWebSocketEvent]()
      if (newState.state != oldState.state || newState.isRecovering != oldState.isRecovering) {
        outputEvents.append(createStateEvent(newState))
      }
      if (newState.fatalErrors != oldState.fatalErrors) {
        outputEvents.append(WorkflowErrorEvent(newState.fatalErrors))
      }
      outputEvents
    })
  )

  private def createStateEvent(state: ExecutionMetadataStore): WorkflowStateEvent = {
    if (state.isRecovering && state.state != COMPLETED) {
      WorkflowStateEvent("Recovering")
    } else {
      WorkflowStateEvent(Utils.aggregatedStateToString(state.state))
    }
  }

  var workflow: Workflow = _

  // Runtime starts from here:
  logger.info("Initialing an AmberClient, runtime starting...")
  var client: AmberClient = _
  var executionReconfigurationService: ExecutionReconfigurationService = _
  var executionStatsService: ExecutionStatsService = _
  var executionRuntimeService: ExecutionRuntimeService = _
  var executionConsoleService: ExecutionConsoleService = _

  /**
    * Compile the request's logical plan, create the Amber runtime and its services, and start
    * the workflow. A compilation failure is reported to `errorHandler` and nothing is started.
    */
  def executeWorkflow(): Unit = {
    try {
      workflow = new WorkflowCompiler(workflowContext)
        .compile(request.logicalPlan)
    } catch {
      case err: Throwable =>
        errorHandler(err)
        // Without a compiled workflow there is nothing to run; continuing would only fail
        // again with a NullPointerException on `workflow`.
        return
    }
    client = ComputingUnitMaster.createAmberRuntime(
      workflow.context,
      workflow.physicalPlan,
      controllerConfig,
      errorHandler
    )
    executionReconfigurationService =
      new ExecutionReconfigurationService(client, executionStateStore, workflow)
    executionStatsService = new ExecutionStatsService(client, executionStateStore, workflow.context)
    executionRuntimeService = new ExecutionRuntimeService(
      client,
      executionStateStore,
      wsInput,
      executionReconfigurationService,
      controllerConfig.faultToleranceConfOpt,
      workflowContext.workflowId.id,
      request.emailNotificationEnabled,
      userEmailOpt,
      sessionUri
    )
    executionConsoleService =
      new ExecutionConsoleService(client, executionStateStore, wsInput, workflow.context)
    logger.info("Starting the workflow execution.")
    resultService.attachToExecution(
      workflow.context.executionId,
      executionStateStore,
      workflow.physicalPlan,
      client
    )
    executionStateStore.metadataStore.updateState(metadataStore =>
      updateWorkflowState(READY, metadataStore)
        .withFatalErrors(Seq.empty)
    )
    executionStateStore.statsStore.updateState(stats =>
      stats.withStartTimeStamp(System.currentTimeMillis())
    )
    client.controllerInterface
      .startWorkflow(EmptyRequest(), ())
      .onFailure(err => {
        errorHandler(err)
      })
      .onSuccess(resp =>
        executionStateStore.metadataStore.updateState(metadataStore =>
          // Do not overwrite a FAILED state that was recorded while starting.
          if (metadataStore.state != FAILED) {
            updateWorkflowState(resp.workflowState, metadataStore)
          } else {
            metadataStore
          }
        )
      )
  }

  /**
    * Drop all subscriptions and, if a runtime was created, shut it down together with the
    * per-execution services.
    */
  override def unsubscribeAll(): Unit = {
    super.unsubscribeAll()
    if (client != null) {
      // runtime created
      client.shutdown()
      // A service may still be unset if executeWorkflow failed while constructing them.
      Option(executionRuntimeService).foreach(_.unsubscribeAll())
      Option(executionConsoleService).foreach(_.unsubscribeAll())
      Option(executionStatsService).foreach(_.unsubscribeAll())
      Option(executionReconfigurationService).foreach(_.unsubscribeAll())
    }
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/service/WorkflowService.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.google.protobuf.timestamp.Timestamp
import com.typesafe.scalalogging.LazyLogging
import io.reactivex.rxjava3.disposables.{CompositeDisposable, Disposable}
import io.reactivex.rxjava3.subjects.BehaviorSubject
import org.apache.texera.amber.config.ApplicationConfig
import org.apache.texera.amber.core.WorkflowRuntimeException
import org.apache.texera.amber.core.storage.DocumentFactory
import org.apache.texera.amber.core.storage.result.iceberg.OnIceberg
import org.apache.texera.amber.core.virtualidentity.{
EmbeddedControlMessageIdentity,
ExecutionIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.WorkflowContext
import org.apache.texera.amber.core.workflowruntimestate.FatalErrorType.EXECUTION_FAILURE
import org.apache.texera.amber.core.workflowruntimestate.WorkflowFatalError
import org.apache.texera.amber.engine.architecture.controller.ControllerConfig
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState.{
COMPLETED,
FAILED
}
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{
FaultToleranceConfig,
StateRestoreConfig
}
import org.apache.texera.amber.error.ErrorUtils.{
getOperatorFromActorIdOpt,
getStackTraceWithAllCauses
}
import org.apache.texera.dao.jooq.generated.tables.pojos.User
import org.apache.texera.service.util.LargeBinaryManager
import org.apache.texera.web.model.websocket.event.TexeraWebSocketEvent
import org.apache.texera.web.model.websocket.request.WorkflowExecuteRequest
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowExecutionsResource
import org.apache.texera.web.service.WorkflowService.mkWorkflowStateId
import org.apache.texera.web.storage.ExecutionStateStore.updateWorkflowState
import org.apache.texera.web.storage.{ExecutionStateStore, WorkflowStateStore}
import org.apache.texera.web.{SubscriptionManager, WorkflowLifecycleManager}
import org.apache.texera.workflow.LogicalPlan
import play.api.libs.json.Json
import java.net.URI
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters.IterableHasAsScala
object WorkflowService {
  // One WorkflowService per workflow, keyed by mkWorkflowStateId.
  private val workflowServiceMapping = new ConcurrentHashMap[String, WorkflowService]()
  val cleanUpDeadlineInSeconds: Int = ApplicationConfig.executionStateCleanUpInSecs

  /** A live Scala view over every registered WorkflowService. */
  def getAllWorkflowServices: Iterable[WorkflowService] = workflowServiceMapping.values().asScala

  /** Registry key for a workflow: its identity's string form. */
  def mkWorkflowStateId(workflowId: WorkflowIdentity): String = workflowId.toString

  /**
    * Return the registered service for `workflowId`, atomically creating and registering one
    * if none exists yet. `computingUnitId` and `cleanupTimeout` only matter on creation.
    */
  def getOrCreate(
      workflowId: WorkflowIdentity,
      computingUnitId: Int,
      cleanupTimeout: Int = cleanUpDeadlineInSeconds
  ): WorkflowService = {
    val stateId = mkWorkflowStateId(workflowId)
    workflowServiceMapping.computeIfAbsent(
      stateId,
      _ => new WorkflowService(workflowId, computingUnitId, cleanupTimeout)
    )
  }
}
class WorkflowService(
val workflowId: WorkflowIdentity,
val computingUnitId: Int,
cleanUpTimeout: Int
) extends SubscriptionManager
with LazyLogging {
// state across execution:
private val errorSubject = BehaviorSubject.create[TexeraWebSocketEvent]().toSerialized
val stateStore = new WorkflowStateStore()
var executionService: BehaviorSubject[WorkflowExecutionService] = BehaviorSubject.create()
val resultService: ExecutionResultService =
new ExecutionResultService(workflowId, computingUnitId, stateStore)
val lifeCycleManager: WorkflowLifecycleManager = new WorkflowLifecycleManager(
s"workflowId=$workflowId",
cleanUpTimeout,
() => {
// clear the storage resources associated with the latest execution
WorkflowExecutionService
.getLatestExecutionId(workflowId, computingUnitId)
.foreach(eid => {
clearExecutionResources(eid)
})
WorkflowService.workflowServiceMapping.remove(mkWorkflowStateId(workflowId))
if (executionService.getValue != null) {
// shutdown client
executionService.getValue.client.shutdown()
}
unsubscribeAll()
}
)
var lastCompletedLogicalPlan: Option[LogicalPlan] = Option.empty
executionService.subscribe { executionService: WorkflowExecutionService =>
{
executionService.executionStateStore.metadataStore.registerDiffHandler {
(oldState, newState) =>
{
if (oldState.state != COMPLETED && newState.state == COMPLETED) {
lastCompletedLogicalPlan = Option.apply(executionService.workflow.logicalPlan)
}
Iterable.empty
}
}
}
}
def connect(onNext: TexeraWebSocketEvent => Unit): Disposable = {
lifeCycleManager.increaseUserCount()
val subscriptions = stateStore.getAllStores
.map(_.getWebsocketEventObservable)
.map(evtPub =>
evtPub.subscribe { evts: Iterable[TexeraWebSocketEvent] => evts.foreach(onNext) }
)
.toSeq
val errorSubscription = errorSubject.subscribe { evt: TexeraWebSocketEvent => onNext(evt) }
new CompositeDisposable(subscriptions :+ errorSubscription: _*)
}
def connectToExecution(onNext: TexeraWebSocketEvent => Unit): Disposable = {
val localDisposable = new CompositeDisposable()
val disposable = executionService.subscribe { execService: WorkflowExecutionService =>
localDisposable.clear() // Clears previous subscriptions safely
val subscriptions = execService.executionStateStore.getAllStores
.map(_.getWebsocketEventObservable)
.map(evtPub =>
evtPub.subscribe { events: Iterable[TexeraWebSocketEvent] => events.foreach(onNext) }
)
.toSeq
localDisposable.addAll(subscriptions: _*)
}
// Note: this new CompositeDisposable is necessary. DO NOT OPTIMIZE.
new CompositeDisposable(localDisposable, disposable)
}
/**
 * Detaches a client: decrements the lifecycle manager's user count, passing the current
 * execution's aggregated state (or None when no execution has been started).
 */
def disconnect(): Unit = {
  val currentState = Option(executionService.getValue)
    .map(execution => execution.executionStateStore.metadataStore.getState.state)
  lifeCycleManager.decreaseUserCount(currentState)
}
/** Creates a new, fresh [[WorkflowContext]] for this service's workflow id. */
private[this] def createWorkflowContext(): WorkflowContext =
  new WorkflowContext(workflowId)
/**
 * Starts a new execution of this workflow and publishes it on `executionService`.
 *
 * In order, this:
 *  1. detaches listeners from the previous execution service, if any;
 *  2. clears resources of the latest previous execution on `req.computingUnitId`;
 *  3. records a new execution in the database and stores its id in a fresh [[WorkflowContext]];
 *  4. enables fault-tolerance logging when `ApplicationConfig.faultToleranceLogRootFolder` is set;
 *  5. configures state restore when `req.replayFromExecution` is set and the referenced
 *     execution can be found;
 *  6. constructs, registers and starts the [[WorkflowExecutionService]].
 *
 * Exceptions from step 6 are not rethrown: they are recorded on the new execution's state store
 * as a FAILED state with a fatal error. Exceptions from steps 1-5 propagate to the caller.
 *
 * @param req        the execution request
 * @param userOpt    the user starting the execution, if any
 * @param sessionUri passed through unchanged to the [[WorkflowExecutionService]]
 */
def initExecutionService(
    req: WorkflowExecuteRequest,
    userOpt: Option[User],
    sessionUri: URI
): Unit = {
  if (executionService.hasValue) {
    executionService.getValue.unsubscribeAll()
  }
  val (uidOpt, userEmailOpt) = userOpt.map(user => (user.getUid, user.getEmail)).unzip
  val workflowContext: WorkflowContext = createWorkflowContext()
  var controllerConf = ControllerConfig.default

  // clean up results from previous run
  // TODO: change this behavior after enabling cache.
  WorkflowExecutionService
    .getLatestExecutionId(workflowId, req.computingUnitId)
    .foreach(previousExecutionId => clearExecutionResources(previousExecutionId))

  workflowContext.executionId = ExecutionsMetadataPersistService.insertNewExecution(
    workflowContext.workflowId,
    uidOpt,
    req.executionName,
    convertToJson(req.engineVersion),
    req.computingUnitId
  )

  // Fault tolerance: persist where this execution's log goes and enable logging in the controller.
  ApplicationConfig.faultToleranceLogRootFolder.foreach { logRootFolder =>
    val writeLocation = logRootFolder.resolve(
      s"${workflowContext.workflowId}/${workflowContext.executionId}/"
    )
    ExecutionsMetadataPersistService.tryUpdateExistingExecution(workflowContext.executionId) {
      execution => execution.setLogLocation(writeLocation.toString)
    }
    controllerConf = controllerConf.copy(faultToleranceConfOpt =
      Some(FaultToleranceConfig(writeTo = writeLocation))
    )
  }

  // Replay: restore state from the log of the referenced execution, if that execution exists.
  req.replayFromExecution.foreach { replayInfo =>
    ExecutionsMetadataPersistService
      .tryGetExistingExecution(ExecutionIdentity(replayInfo.eid))
      .foreach { execution =>
        val readLocation = new URI(execution.getLogLocation)
        controllerConf = controllerConf.copy(stateRestoreConfOpt =
          Some(
            StateRestoreConfig(
              readFrom = readLocation,
              replayDestination = EmbeddedControlMessageIdentity(replayInfo.interaction)
            )
          )
        )
      }
  }

  val executionStateStore = new ExecutionStateStore()
  // assign execution id to find the execution from DB in case the constructor fails.
  executionStateStore.metadataStore.updateState(state =>
    state.withExecutionId(workflowContext.executionId)
  )

  // Records a failure on this execution: end timestamp, FAILED state and a fatal error entry
  // attributed to the reporting worker/operator when the exception carries one.
  val errorHandler: Throwable => Unit = { t =>
    val fromActorOpt = t match {
      case ex: WorkflowRuntimeException => ex.relatedWorkerId
      case _                            => None
    }
    val (operatorId, workerId) = getOperatorFromActorIdOpt(fromActorOpt)
    logger.error("error during execution", t)
    executionStateStore.statsStore.updateState(stats =>
      stats.withEndTimeStamp(System.currentTimeMillis())
    )
    executionStateStore.metadataStore.updateState { metadataStore =>
      updateWorkflowState(FAILED, metadataStore).addFatalErrors(
        WorkflowFatalError(
          EXECUTION_FAILURE,
          Timestamp(Instant.now),
          t.toString,
          getStackTraceWithAllCauses(t),
          operatorId,
          workerId
        )
      )
    }
  }

  try {
    val execution = new WorkflowExecutionService(
      controllerConf,
      workflowContext,
      resultService,
      req,
      executionStateStore,
      errorHandler,
      userEmailOpt,
      sessionUri
    )
    lifeCycleManager.registerCleanUpOnStateChange(executionStateStore)
    executionService.onNext(execution)
    execution.executeWorkflow()
  } catch {
    case e: Throwable => errorHandler(e)
  }
}
/**
 * Serializes the engine version into the environment-version JSON stored with an execution,
 * i.e. a single-entry object keyed by `"engine_version"`.
 */
def convertToJson(frontendVersion: String): String = {
  val versionEntry = "engine_version" -> Json.toJson(frontendVersion)
  Json.stringify(Json.toJson(Map(versionEntry)))
}
/** Unsubscribes this service, the current execution service (if any) and the result service. */
override def unsubscribeAll(): Unit = {
  super.unsubscribeAll()
  val currentExecution = executionService.getValue
  if (currentExecution != null) currentExecution.unsubscribeAll()
  resultService.unsubscribeAll()
}
/**
* Cleans up all resources associated with a workflow execution.
*
* This method performs resource cleanup in the following sequence:
* 1. Retrieves all document URIs associated with the execution
* 2. Clears URI references from the execution registry
* 3. Safely clears all result and console message documents
* 4. Expires Iceberg snapshots for runtime statistics
* 5. Deletes large binaries from MinIO
*
* @param eid The execution identity to clean up resources for
*/
private def clearExecutionResources(eid: ExecutionIdentity): Unit = {
  import scala.util.control.NonFatal

  // Retrieve URIs for all resources associated with this execution
  val resultUris = WorkflowExecutionsResource.getResultUrisByExecutionId(eid)
  val consoleMessagesUris = WorkflowExecutionsResource.getConsoleMessagesUriByExecutionId(eid)
  // Remove references from registry first
  WorkflowExecutionsResource.deleteConsoleMessageAndExecutionResultUris(eid)
  // Clean up all result and console message documents. Best-effort: a document that cannot be
  // opened or cleared is logged and skipped. Fatal errors (OOM, interrupts, linkage errors)
  // are no longer swallowed.
  (resultUris ++ consoleMessagesUris).foreach { uri =>
    try DocumentFactory.openDocument(uri)._1.clear()
    catch {
      case NonFatal(error) =>
        logger.debug(s"Error processing document at $uri: ${error.getMessage}")
    }
  }
  // Expire any Iceberg snapshots for runtime statistics (same best-effort policy as above)
  WorkflowExecutionsResource.getRuntimeStatsUriByExecutionId(eid).foreach { uri =>
    try {
      DocumentFactory.openDocument(uri)._1 match {
        case iceberg: OnIceberg => iceberg.expireSnapshots()
        case other =>
          logger.error(
            s"Cannot expire snapshots: document from URI [$uri] is of type ${other.getClass.getName}. " +
              s"Expected an instance of ${classOf[OnIceberg].getName}."
          )
      }
    } catch {
      case NonFatal(error) =>
        logger.debug(s"Error processing document at $uri: ${error.getMessage}")
    }
  }
  // Delete large binaries.
  // NOTE(review): deleteAllObjects() takes no execution id, so it may remove large binaries
  // beyond this execution's — confirm that is the intended scope.
  LargeBinaryManager.deleteAllObjects()
}
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/storage/ExecutionReconfigurationStore.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.storage
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.apache.texera.amber.operator.StateTransferFunc
/**
 * Per-execution bookkeeping for operator reconfiguration.
 *
 * @param currentReconfigId           id of the reconfiguration being tracked, if any
 *                                    (presumably the one in progress — confirm at call sites)
 * @param unscheduledReconfigurations pending (new physical operator, optional state transfer)
 *                                    pairs that have not been scheduled yet
 * @param completedReconfigurations   workers for which reconfiguration has completed
 */
case class ExecutionReconfigurationStore(
    currentReconfigId: Option[String] = None,
    unscheduledReconfigurations: List[(PhysicalOp, Option[StateTransferFunc])] = List.empty,
    completedReconfigurations: Set[ActorVirtualIdentity] = Set.empty
)
================================================
FILE: amber/src/main/scala/org/apache/texera/web/storage/ExecutionStateStore.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.storage
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.apache.texera.amber.engine.common.Utils.maptoStatusCode
import org.apache.texera.amber.engine.common.executionruntimestate.{
ExecutionBreakpointStore,
ExecutionConsoleStore,
ExecutionMetadataStore,
ExecutionStatsStore
}
import org.apache.texera.web.service.ExecutionsMetadataPersistService
import java.sql.Timestamp
object ExecutionStateStore {

  /**
   * Returns `metadataStore` with its state replaced by `state`, after first trying to persist
   * the matching status code and a fresh last-update timestamp to the execution's DB record.
   *
   * Meant to be called from the backend only. The DB write goes through
   * `ExecutionsMetadataPersistService.tryUpdateExistingExecution`; per the original note it only
   * takes effect when the user system is enabled.
   *
   * @param state         the new aggregated workflow state
   * @param metadataStore the current metadata snapshot
   * @return a copy of `metadataStore` carrying `state`
   */
  def updateWorkflowState(
      state: WorkflowAggregatedState,
      metadataStore: ExecutionMetadataStore
  ): ExecutionMetadataStore = {
    ExecutionsMetadataPersistService.tryUpdateExistingExecution(metadataStore.executionId) {
      record =>
        record.setStatus(maptoStatusCode(state))
        record.setLastUpdateTime(new Timestamp(System.currentTimeMillis()))
    }
    metadataStore.withState(state)
  }
}
// State stores scoped to a single execution.
class ExecutionStateStore {
  val statsStore: StateStore[ExecutionStatsStore] = new StateStore(ExecutionStatsStore())
  val metadataStore: StateStore[ExecutionMetadataStore] = new StateStore(ExecutionMetadataStore())
  val consoleStore: StateStore[ExecutionConsoleStore] = new StateStore(ExecutionConsoleStore())
  val breakpointStore: StateStore[ExecutionBreakpointStore] =
    new StateStore(ExecutionBreakpointStore())
  val reconfigurationStore: StateStore[ExecutionReconfigurationStore] =
    new StateStore(ExecutionReconfigurationStore())

  /** All stores of this execution, in the order clients subscribe to them. */
  def getAllStores: Iterable[StateStore[_]] =
    List(statsStore, consoleStore, breakpointStore, metadataStore, reconfigurationStore)
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/storage/StateStore.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.storage
import io.reactivex.rxjava3.core.{Observable, Single}
import io.reactivex.rxjava3.disposables.Disposable
import io.reactivex.rxjava3.subjects.BehaviorSubject
import org.apache.texera.amber.engine.common.Utils.withLock
import org.apache.texera.web.model.websocket.event.TexeraWebSocketEvent
import java.util
import java.util.concurrent.locks.ReentrantLock
import scala.collection.mutable
/**
 * Lock-guarded holder of a single immutable state value of type `T`. It exposes the stream of
 * states and a stream of websocket events derived from consecutive state changes.
 *
 *  - [[updateState]] applies a transition function to the current state under a lock and
 *    publishes the result.
 *  - Handlers registered with [[registerDiffHandler]] are called (under the same lock) with each
 *    (previous, next) pair of states that differ by `!=`. The events they return are emitted on
 *    [[getWebsocketEventObservable]].
 *  - The diff pipeline is not shared, so it runs, and invokes every handler, once per subscriber.
 *  - The underlying BehaviorSubject replays its latest value. A new subscriber therefore first
 *    sees the diff from `defaultState` to the current state, if they differ.
 *
 * @param defaultState the initial state
 * @tparam T the state type; expected to have value equality (e.g. a case class) so that
 *           unchanged states are filtered out
 */
class StateStore[T](defaultState: T) {
  private val stateSubject = BehaviorSubject.createDefault(defaultState)
  private val serializedSubject = stateSubject.toSerialized
  private implicit val lock: ReentrantLock = new ReentrantLock()
  private val diffHandlers = new mutable.ArrayBuffer[(T, T) => Iterable[TexeraWebSocketEvent]]
  private val diffSubject = serializedSubject
    .startWith(Single.just(defaultState))
    // sliding windows of (previous, current)
    .buffer(2, 1)
    // buffer(2, 1) emits a trailing single-element window when the source completes; only
    // complete pairs can be diffed, otherwise get(1) would throw IndexOutOfBoundsException.
    .filter(states => states.size() == 2 && states.get(0) != states.get(1))
    .map[Iterable[TexeraWebSocketEvent]] { states: util.List[T] =>
      withLock {
        diffHandlers.flatMap(f => f(states.get(0), states.get(1)))
      }
    }

  /** The current state. */
  def getState: T = stateSubject.getValue

  /**
   * Replaces the current state with `func(currentState)` and publishes it. `func` runs while
   * the store's lock is held, so concurrent updates are applied one at a time.
   */
  def updateState(func: T => T): Unit = {
    withLock {
      val newState = func(stateSubject.getValue)
      serializedSubject.onNext(newState)
    }
  }

  /**
   * Registers a handler mapping a (previous, next) state pair to websocket events.
   *
   * @return a disposable that unregisters the handler
   */
  def registerDiffHandler(handler: (T, T) => Iterable[TexeraWebSocketEvent]): Disposable = {
    withLock {
      diffHandlers.append(handler)
    }
    Disposable.fromAction { () =>
      withLock {
        diffHandlers -= handler
      }
    }
  }

  /** Events produced by the registered diff handlers for each state change. */
  def getWebsocketEventObservable: Observable[Iterable[TexeraWebSocketEvent]] =
    diffSubject.onTerminateDetach()

  /** The current state followed by every subsequent state. */
  def getStateObservable: Observable[T] = serializedSubject.onTerminateDetach()
}
================================================
FILE: amber/src/main/scala/org/apache/texera/web/storage/WorkflowStateStore.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.storage
import org.apache.texera.amber.core.storage.result.WorkflowResultStore
// State stores shared across all executions of a workflow.
class WorkflowStateStore {
  val resultStore: StateStore[WorkflowResultStore] = new StateStore(WorkflowResultStore())

  /** All workflow-level stores; currently only the result store. */
  def getAllStores: Iterable[StateStore[_]] = List(resultStore)
}
================================================
FILE: amber/src/main/scala/org/apache/texera/workflow/LogicalLink.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.workflow
import com.fasterxml.jackson.annotation.{JsonCreator, JsonProperty}
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow.PortIdentity
/**
 * A directed edge of a logical plan: connects output port `fromPortId` of operator `fromOpId`
 * to input port `toPortId` of operator `toOpId`.
 *
 * @param fromOpId   the upstream operator
 * @param fromPortId the output port on the upstream operator
 * @param toOpId     the downstream operator
 * @param toPortId   the input port on the downstream operator
 */
case class LogicalLink(
@JsonProperty("fromOpId") fromOpId: OperatorIdentity,
fromPortId: PortIdentity,
@JsonProperty("toOpId") toOpId: OperatorIdentity,
toPortId: PortIdentity
) {
/**
 * Jackson creator: accepts the operator ids as plain strings and wraps them in
 * [[OperatorIdentity]]. Only the operator ids carry explicit @JsonProperty names; the port
 * parameters are presumably bound by parameter name (Jackson Scala module) — do not rename them
 * without confirming.
 */
@JsonCreator
def this(
@JsonProperty("fromOpId") fromOpId: String,
fromPortId: PortIdentity,
@JsonProperty("toOpId") toOpId: String,
toPortId: PortIdentity
) = {
this(OperatorIdentity(fromOpId), fromPortId, OperatorIdentity(toOpId), toPortId)
}
}
================================================
FILE: amber/src/main/scala/org/apache/texera/workflow/LogicalPlan.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.workflow
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.core.storage.FileResolver
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.operator.LogicalOp
import org.apache.texera.amber.operator.source.scan.ScanSourceOpDesc
import org.apache.texera.web.model.websocket.request.LogicalPlanPojo
import org.jgrapht.graph.DirectedAcyclicGraph
import org.jgrapht.util.SupplierUtil
import java.util
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}
object LogicalPlan {

  /**
   * Builds a JGraphT DAG whose vertices are the operators' ids and whose edges are the links.
   *
   * Multiple links between the same pair of operators are allowed. Per JGraphT's
   * DirectedAcyclicGraph contract, adding a link that would create a cycle (or that references
   * an operator not in `operatorList`) throws IllegalArgumentException.
   */
  private def toJgraphtDAG(
      operatorList: List[LogicalOp],
      links: List[LogicalLink]
  ): DirectedAcyclicGraph[OperatorIdentity, LogicalLink] = {
    val dag = new DirectedAcyclicGraph[OperatorIdentity, LogicalLink](
      null, // no vertex supplier: vertices are always added explicitly
      SupplierUtil.createSupplier(classOf[LogicalLink]), // edgeSupplier
      false, // unweighted
      true // allowMultipleEdges
    )
    for (op <- operatorList) dag.addVertex(op.operatorIdentifier)
    for (link <- links) dag.addEdge(link.fromOpId, link.toOpId, link)
    dag
  }

  /** Builds a logical plan from the operators and links of a deserialized plan pojo. */
  def apply(pojo: LogicalPlanPojo): LogicalPlan =
    LogicalPlan(pojo.operators, pojo.links)
}
/**
 * A logical workflow plan: logical operators connected by [[LogicalLink]]s.
 *
 * The id-to-operator map and the JGraphT DAG are built lazily on first use.
 *
 * @param operators the logical operators of the workflow
 * @param links     the links between operator ports
 */
case class LogicalPlan(
    operators: List[LogicalOp],
    links: List[LogicalLink]
) extends LazyLogging {
  private lazy val operatorMap: Map[OperatorIdentity, LogicalOp] =
    operators.map(op => (op.operatorIdentifier, op)).toMap

  private lazy val jgraphtDag: DirectedAcyclicGraph[OperatorIdentity, LogicalLink] =
    LogicalPlan.toJgraphtDAG(operators, links)

  /** Operator ids in topological order, as produced by the DAG's iterator. */
  def getTopologicalOpIds: util.Iterator[OperatorIdentity] = jgraphtDag.iterator()

  /** Looks up an operator by id; throws NoSuchElementException if it is not in the plan. */
  def getOperator(opId: OperatorIdentity): LogicalOp = operatorMap(opId)

  /** Ids of operators that have no outgoing links. */
  def getTerminalOperatorIds: List[OperatorIdentity] =
    operatorMap.keys
      .filter(op => jgraphtDag.outDegreeOf(op) == 0)
      .toList

  /** Links whose target operator is `opId`. */
  def getUpstreamLinks(opId: OperatorIdentity): List[LogicalLink] = {
    links.filter(l => l.toOpId == opId)
  }

  /**
   * Resolves the user-given file name of every scan source operator to a URI. Each resolved
   * URI is stored on the operator through `setResolvedFileName`.
   *
   * Failures (a missing file name or a resolution error) are logged. Then:
   *  - if `errorList` is defined, the failure is appended to it together with the operator id,
   *    and processing continues;
   *  - otherwise the failure is rethrown.
   *
   * @param errorList if given, collects (operator id, error) pairs instead of throwing
   */
  def resolveScanSourceOpFileName(
      errorList: Option[ArrayBuffer[(OperatorIdentity, Throwable)]]
  ): Unit = {
    operators.foreach {
      case scanOp: ScanSourceOpDesc =>
        Try {
          val fileName = scanOp.fileName.getOrElse(throw new RuntimeException("no input file name"))
          val fileUri = FileResolver.resolve(fileName) // Convert to URI
          scanOp.setResolvedFileName(fileUri)
        } match {
          case Success(_) => // Successfully resolved and set the file URI
          case Failure(err) =>
            logger.error("Error resolving file path for ScanSourceOpDesc", err)
            errorList match {
              case Some(errList) =>
                errList.append((scanOp.operatorIdentifier, err))
              case None =>
                // Throw the error if no errorList is provided
                throw err
            }
        }
      case _ => // Skip non-ScanSourceOpDesc operators
    }
  }
}
================================================
FILE: amber/src/main/scala/org/apache/texera/workflow/WorkflowCompiler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.workflow
import com.typesafe.scalalogging.LazyLogging
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow._
import org.apache.texera.amber.engine.architecture.controller.Workflow
import org.apache.texera.web.model.websocket.request.LogicalPlanPojo
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters.IteratorHasAsScala
import scala.util.{Failure, Success, Try}
class WorkflowCompiler(
    context: WorkflowContext
) extends LazyLogging {

  /**
   * Marker that code generation writes into a Python-based operator's code when generation
   * failed. Group 1 captures the rest of that line as the error message. Compiled once per
   * compiler rather than once per physical operator.
   */
  private val codeGenerationExceptionPattern =
    """#EXCEPTION DURING CODE GENERATION:\s*(.*)""".r

  /**
   * Expands the logical plan into a physical plan, one logical operator at a time in
   * topological order.
   *
   * Each logical operator's physical sub-plan is merged into the accumulated physical plan:
   *  - every physical operator is added with its schema propagated;
   *  - the sub-plan's internal links are added;
   *  - the external links feeding the operator from upstream logical operators are added.
   *
   * Non-internal output ports are marked as needing storage for terminal logical operators and
   * for those listed in `logicalOpsToViewResult`.
   *
   * @param logicalPlan            the logical plan to expand
   * @param logicalOpsToViewResult ids of additional logical operators whose results need storage
   * @param errorList              if given, errors are appended per logical operator and expansion
   *                               continues; if None, the first error is thrown
   * @return the expanded physical plan and a set of output ports that need storage
   */
  private def expandLogicalPlan(
      logicalPlan: LogicalPlan,
      logicalOpsToViewResult: List[String],
      errorList: Option[ArrayBuffer[(OperatorIdentity, Throwable)]]
  ): (PhysicalPlan, Set[GlobalPortIdentity]) = {
    val terminalLogicalOps = logicalPlan.getTerminalOperatorIds
    val logicalOpsNeedingStorage =
      (terminalLogicalOps ++ logicalOpsToViewResult.map(OperatorIdentity(_))).toSet
    var physicalPlan = PhysicalPlan(operators = Set.empty, links = Set.empty)
    val outputPortsNeedingStorage: mutable.HashSet[GlobalPortIdentity] = mutable.HashSet()

    // Records the error when an error list is supplied; otherwise throws it.
    def recordOrThrow(logicalOpId: OperatorIdentity, error: Throwable): Unit =
      errorList match {
        case Some(list) => list.append((logicalOpId, error))
        case None       => throw error
      }

    logicalPlan.getTopologicalOpIds.asScala.foreach { logicalOpId =>
      Try {
        val logicalOp = logicalPlan.getOperator(logicalOpId)
        val subPlan = logicalOp.getPhysicalPlan(context.workflowId, context.executionId)
        subPlan
          .topologicalIterator()
          .map(subPlan.getOperator)
          .foreach { physicalOp =>
            // links from already-expanded upstream logical operators into this physical op
            val externalLinks = logicalPlan
              .getUpstreamLinks(logicalOp.operatorIdentifier)
              .filter(link => physicalOp.inputPorts.contains(link.toPortId))
              .flatMap { link =>
                physicalPlan
                  .getPhysicalOpsOfLogicalOp(link.fromOpId)
                  .find(_.outputPorts.contains(link.fromPortId))
                  .map(fromOp =>
                    PhysicalLink(fromOp.id, link.fromPortId, physicalOp.id, link.toPortId)
                  )
              }
            val internalLinks = subPlan.getUpstreamPhysicalLinks(physicalOp.id)
            // Add the operator to the physical plan
            physicalPlan = physicalPlan.addOperator(physicalOp.propagateSchema())
            // Add all the links to the physical plan
            physicalPlan = (externalLinks ++ internalLinks)
              .foldLeft(physicalPlan) { (plan, link) => plan.addLink(link) }
            // Check for Python-based operator errors during code generation
            if (physicalOp.isPythonBased) {
              val code = physicalOp.getCode
              codeGenerationExceptionPattern.findFirstMatchIn(code).foreach { matchResult =>
                val errorMessage = matchResult.group(1).trim
                recordOrThrow(
                  logicalOpId,
                  new RuntimeException(s"Operator is not configured properly: $errorMessage")
                )
              }
            }
          }
        // convert logical operators needing storage to output ports needing storage
        subPlan
          .topologicalIterator()
          .filter(opId => logicalOpsNeedingStorage.contains(opId.logicalOpId))
          .map(physicalPlan.getOperator)
          .foreach { physicalOp =>
            physicalOp.outputPorts
              .filterNot(_._1.internal)
              .foreach {
                case (outputPortId, _) =>
                  outputPortsNeedingStorage += GlobalPortIdentity(
                    opId = physicalOp.id,
                    portId = outputPortId
                  )
              }
          }
      } match {
        case Success(_)   =>
        case Failure(err) => recordOrThrow(logicalOpId, err)
      }
    }
    (physicalPlan, outputPortsNeedingStorage.toSet)
  }

  /**
   * Compiles a workflow into its physical plan before execution. This also records the output
   * ports that need storage in `context.workflowSettings`.
   *
   * Compared to WorkflowCompilingService's compiler, which is used solely for workflow editing,
   * this compile is used before executing the workflow. No error list is used here, so the
   * first error encountered (file resolution or plan expansion) is thrown.
   *
   * TODO: we should consider merge this compile with WorkflowCompilingService's compile
   * @param logicalPlanPojo the pojo parsed from workflow str provided by user
   * @return Workflow, containing the physical plan, logical plan and workflow context
   */
  def compile(
      logicalPlanPojo: LogicalPlanPojo
  ): Workflow = {
    // 1. convert the pojo to logical plan
    val logicalPlan: LogicalPlan = LogicalPlan(logicalPlanPojo)
    // 2. resolve the file name in each scan source operator
    logicalPlan.resolveScanSourceOpFileName(None)
    // 3. expand the logical plan to the physical plan, and get a set of output ports that need storage
    val (physicalPlan, outputPortsNeedingStorage) =
      expandLogicalPlan(logicalPlan, logicalPlanPojo.opsToViewResult, None)
    context.workflowSettings = context.workflowSettings.copy(
      outputPortsNeedingStorage = outputPortsNeedingStorage
    )
    Workflow(context, logicalPlan, physicalPlan)
  }
}
================================================
FILE: amber/src/test/integration/org/apache/texera/amber/engine/e2e/ReconfigurationIntegrationSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.e2e
import com.twitter.util.{Await, Duration, Promise, Return}
import com.typesafe.scalalogging.Logger
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.pekko.util.Timeout
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.executor.{OpExecInitInfo, OpExecWithCode}
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.architecture.controller.{
ControllerConfig,
ExecutionStateUpdate
}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.EmptyRequest
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState.COMPLETED
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.amber.engine.e2e.TestUtils.{
buildWorkflow,
cleanupWorkflowExecutionData,
initiateTexeraDBForTestCases,
setUpWorkflowExecutionData
}
import org.apache.texera.amber.operator.source.scan.text.TextInputSourceOpDesc
import org.apache.texera.amber.operator.{LogicalOp, TestOperators}
import org.apache.texera.amber.tags.IntegrationTest
import org.apache.texera.workflow.LogicalLink
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Outcome, Retries}
import org.scalatest.flatspec.AnyFlatSpecLike
import scala.concurrent.duration._
/**
* E2E reconfiguration tests that spawn Python UDF workers. Routed to the
* `amber-integration` CI job via the class-level `@IntegrationTest` tag,
* which provisions Python deps; the lighter `amber` job excludes this tag.
*
* Pure-Scala reconfiguration tests live in [[ReconfigurationSpec]] and run
* in the regular `amber` job.
*/
@IntegrationTest
class ReconfigurationIntegrationSpec
extends TestKit(ActorSystem("ReconfigurationIntegrationSpec", AmberRuntime.pekkoConfig))
with ImplicitSender
with AnyFlatSpecLike
with BeforeAndAfterAll
with BeforeAndAfterEach
with Retries {
/**
* This block retries each test once if it fails.
* In the CI environment, there is a chance that executeWorkflow does not receive "COMPLETED" status.
* Until we find the root cause of this issue, we use a retry mechanism here to stabilize CI runs.
*/
override def withFixture(test: NoArgTest): Outcome =
withRetry { super.withFixture(test) }
// Implicit timeout kept in scope for this suite; no code in this class references it explicitly.
implicit val timeout: Timeout = Timeout(5.seconds)
val logger = Logger("ReconfigurationIntegrationSpecLogger")
// Workflow context shared by every test case through shouldReconfigure below.
val ctx = new WorkflowContext()
// Each test case runs against freshly set-up workflow execution data, cleaned up afterwards.
override protected def beforeEach(): Unit = {
setUpWorkflowExecutionData()
}
override protected def afterEach(): Unit = {
cleanupWorkflowExecutionData()
}
// One-time suite setup: start the "cluster-info" listener actor, load the JDBC driver,
// initialize the test database, then run one warmup workflow.
// NOTE(review): beforeAll/afterAll do not call super; harmless while no other mixed-in trait
// overrides them, but add the super calls if one ever does.
override def beforeAll(): Unit = {
system.actorOf(Props[SingleNodeListener](), "cluster-info")
// These test cases access postgres in CI, but occasionally the jdbc driver cannot be found during CI run.
// Explicitly load the JDBC driver to avoid flaky CI failures.
Class.forName("org.postgresql.Driver")
initiateTexeraDBForTestCases()
warmupOnce()
}
override def afterAll(): Unit = {
TestKit.shutdownActorSystem(system)
}
/**
* Run a trivial pure-Scala workflow (TextInput → terminal) once before the
* timed tests start, so the first 5-second `startWorkflow` await in
* [[TestUtils.shouldReconfigure]] doesn't have to absorb JVM JIT
* warmup, pekko dispatcher first-touch, and `RegionExecutionCoordinator`
* class loading.
*
* Each of the two awaits below (start, then completion) is capped at
* 10 seconds, so warmup waits at most ~20 seconds in total. It is
* defensively wrapped: if warmup itself times out or throws, log and
* continue — the existing `Retries` mixin still backs up individual test
* cases. This ensures warmup can never hang the suite.
*/
private def warmupOnce(): Unit = {
val warmupCap = Duration.fromSeconds(10)
// NOTE(review): this runs outside the try, so if it throws, the finally-cleanup below is skipped.
setUpWorkflowExecutionData()
var client: AmberClient = null
try {
val src = new TextInputSourceOpDesc()
src.textInput = "warmup"
val warmupCtx = new WorkflowContext()
val workflow = buildWorkflow(List(src), List.empty, warmupCtx)
client = new AmberClient(
system,
workflow.context,
workflow.physicalPlan,
ControllerConfig.default,
_ => {}
)
// Completed by the first ExecutionStateUpdate reporting COMPLETED.
val completion = Promise[Unit]()
client.registerCallback[ExecutionStateUpdate](evt => {
if (evt.state == COMPLETED) completion.updateIfEmpty(Return(()))
})
Await.result(
client.controllerInterface.startWorkflow(EmptyRequest(), ()),
warmupCap
)
Await.result(completion, warmupCap)
} catch {
// Logged for any failure, not only a timeout (e.g. an error building the workflow).
case e: Throwable =>
logger.warn(
s"warmup workflow did not finish within ${warmupCap}; tests will run cold and rely on Retries: ${e.getMessage}"
)
} finally {
// Always release the client and warmup data; errors during client shutdown are ignored.
if (client != null) {
try client.shutdown()
catch { case _: Throwable => () }
}
cleanupWorkflowExecutionData()
}
}
// Thin wrapper around the shared TestUtils helper so call sites below stay
// ctx/system-implicit. The actual workflow-driver logic lives in TestUtils
// and is reused by ReconfigurationSpec.
def shouldReconfigure(
operators: List[LogicalOp],
links: List[LogicalLink],
targetOps: Seq[LogicalOp],
newOpExecInitInfo: OpExecInitInfo
): Map[OperatorIdentity, List[Tuple]] =
TestUtils.shouldReconfigure(system, ctx, operators, links, targetOps, newOpExecInitInfo)
// CSV scan → Python UDF. The UDF is reconfigured with code that appends "_reconfigured" to
// the Region field; at least one output tuple of the UDF must carry the suffix.
"Engine" should "be able to modify a python UDF worker in workflow" in {
val sourceOpDesc = TestOperators.smallCsvScanOpDesc()
val udfOpDesc = TestOperators.pythonOpDesc()
val code = """
|from pytexera import *
|
|class ProcessTupleOperator(UDFOperatorV2):
| @overrides
| def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
| tuple_['Region'] = tuple_['Region'] + '_reconfigured'
| yield tuple_
|""".stripMargin
val result = shouldReconfigure(
List(sourceOpDesc, udfOpDesc),
List(
LogicalLink(
sourceOpDesc.operatorIdentifier,
PortIdentity(),
udfOpDesc.operatorIdentifier,
PortIdentity()
)
),
Seq(udfOpDesc),
OpExecWithCode(code, "python")
)
assert(result(udfOpDesc.operatorIdentifier).exists { t =>
t.getField("Region").asInstanceOf[String].contains("_reconfigured")
})
}
// Python source (pythonSourceOpDesc(10000)) → Python UDF. The UDF is reconfigured to append
// "_reconfigured" to field_1; at least one output tuple of the UDF must carry the suffix.
"Engine" should "propagate reconfiguration through a source operator in workflow" in {
val sourceOpDesc = TestOperators.pythonSourceOpDesc(10000)
val udfOpDesc = TestOperators.pythonOpDesc()
val code = """
|from pytexera import *
|
|class ProcessTupleOperator(UDFOperatorV2):
| @overrides
| def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
| tuple_['field_1'] = tuple_['field_1'] + '_reconfigured'
| yield tuple_
|""".stripMargin
val result = shouldReconfigure(
List(sourceOpDesc, udfOpDesc),
List(
LogicalLink(
sourceOpDesc.operatorIdentifier,
PortIdentity(),
udfOpDesc.operatorIdentifier,
PortIdentity()
)
),
Seq(udfOpDesc),
OpExecWithCode(code, "python")
)
assert(result(udfOpDesc.operatorIdentifier).exists { t =>
t.getField("field_1").asInstanceOf[String].contains("_reconfigured")
})
}
// CSV scan → UDF1 → UDF2, both reconfigured with the same code. Tuples leaving UDF2 have
// passed through both UDFs, so at least one must carry the suffix twice.
"Engine" should "be able to modify two python UDFs in workflow" in {
val sourceOpDesc = TestOperators.smallCsvScanOpDesc()
val udfOpDesc1 = TestOperators.pythonOpDesc()
val udfOpDesc2 = TestOperators.pythonOpDesc()
val code = """
|from pytexera import *
|
|class ProcessTupleOperator(UDFOperatorV2):
| @overrides
| def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
| tuple_['Region'] = tuple_['Region'] + '_reconfigured'
| yield tuple_
|""".stripMargin
val result = shouldReconfigure(
List(sourceOpDesc, udfOpDesc1, udfOpDesc2),
List(
LogicalLink(
sourceOpDesc.operatorIdentifier,
PortIdentity(),
udfOpDesc1.operatorIdentifier,
PortIdentity()
),
LogicalLink(
udfOpDesc1.operatorIdentifier,
PortIdentity(),
udfOpDesc2.operatorIdentifier,
PortIdentity()
)
),
Seq(udfOpDesc1, udfOpDesc2),
OpExecWithCode(code, "python")
)
assert(result(udfOpDesc2.operatorIdentifier).exists { t =>
t.getField("Region").asInstanceOf[String].contains("_reconfigured_reconfigured")
})
}
}
================================================
FILE: amber/src/test/integration/org/apache/texera/amber/storage/iceberg/IcebergRestCatalogIntegrationSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.storage.iceberg
import org.apache.iceberg.catalog.TableIdentifier
import org.apache.iceberg.exceptions.NoSuchTableException
import org.apache.iceberg.rest.RESTCatalog
import org.apache.texera.amber.config.StorageConfig
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema}
import org.apache.texera.amber.tags.IntegrationTest
import org.apache.texera.amber.util.IcebergUtil
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import java.util.UUID
/**
 * Round-trips table metadata through the Iceberg REST catalog: create a table,
 * reload it and compare schemas, drop it, and confirm it is gone.
 *
 * Tagged [[IntegrationTest]] so it is routed to the integration CI job.
 * NOTE(review): presumably requires a reachable REST catalog configured via
 * [[StorageConfig]] — confirm.
 */
@IntegrationTest
class IcebergRestCatalogIntegrationSpec extends AnyFlatSpec with BeforeAndAfterAll {
  private var restCatalog: RESTCatalog = _
  private val testNamespace = "rest_integration_test"

  override def beforeAll(): Unit = {
    super.beforeAll()
    restCatalog = IcebergUtil.createRestCatalog(
      "rest_integration_test",
      StorageConfig.icebergRESTCatalogWarehouseName
    )
  }

  // RESTCatalog is Closeable; release it once the spec finishes, and always
  // let the parent trait run its own teardown.
  override def afterAll(): Unit = {
    try {
      if (restCatalog != null) restCatalog.close()
    } finally {
      super.afterAll()
    }
  }

  behavior of "Iceberg REST catalog"

  it should "round-trip table metadata via the REST catalog" in {
    val amberSchema = Schema(
      List(
        new Attribute("id", AttributeType.INTEGER),
        new Attribute("name", AttributeType.STRING)
      )
    )
    val icebergSchema = IcebergUtil.toIcebergSchema(amberSchema)
    // Random suffix keeps repeated or concurrent runs from colliding in a shared catalog.
    val tableName = s"rest_table_${UUID.randomUUID().toString.replace("-", "")}"
    val identifier = TableIdentifier.of(testNamespace, tableName)

    try {
      IcebergUtil.createTable(
        restCatalog,
        testNamespace,
        tableName,
        icebergSchema,
        overrideIfExists = true
      )
      assert(restCatalog.tableExists(identifier))

      val loaded = restCatalog.loadTable(identifier)
      assert(loaded.schema().sameSchema(icebergSchema))

      restCatalog.dropTable(identifier, false)
      assert(!restCatalog.tableExists(identifier))
      intercept[NoSuchTableException] {
        restCatalog.loadTable(identifier)
      }
    } finally {
      // Best-effort cleanup so a failed assertion above does not leave the table
      // behind; a cleanup error must not mask the original test failure.
      try {
        if (restCatalog.tableExists(identifier)) restCatalog.dropTable(identifier, false)
      } catch {
        case scala.util.control.NonFatal(_) => ()
      }
    }
  }
}
================================================
FILE: amber/src/test/integration/org/apache/texera/amber/tags/IntegrationTest.java
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.tags;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.scalatest.TagAnnotation;
/**
 * Marker tag for ScalaTest specs that belong in the {@code amber-integration}
 * CI job rather than the lighter {@code amber} job. Typical users are specs
 * that exercise Scala and Python end-to-end, or that talk to an external
 * service (for example, {@code IcebergRestCatalogIntegrationSpec} carries
 * this tag to round-trip table metadata through an Iceberg REST catalog).
 *
 * <p>Routing to the {@code amber-integration} CI job is by ScalaTest tag
 * filtering, controlled by the {@code AMBER_TEST_FILTER} env var in
 * {@code amber/build.sbt}: the lighter {@code amber} job runs with
 * {@code skip-integration} (which passes
 * {@code -l org.apache.texera.amber.tags.IntegrationTest} to ScalaTest), and
 * the {@code amber-integration} job runs with {@code integration-only}
 * (which passes {@code -n} for the same tag). The
 * {@code amber/src/test/integration} directory is added to sbt's
 * {@code Test/unmanagedSourceDirectories} so these specs compile in the
 * regular Test config; there is no separate sbt configuration.
 *
 * <p>Normally applied at class level. {@link Target} also permits method-level
 * use; NOTE(review): whether a given ScalaTest style honours a method-level
 * tag annotation should be confirmed before relying on it.
 *
 * <p>Written in Java rather than Scala because ScalaTest detects tag
 * annotations via {@code java.lang.annotation} reflection. A Scala
 * {@code class extends StaticAnnotation} does not produce a JVM annotation
 * interface that {@code @TagAnnotation} can attach to, so the tag would be
 * invisible to ScalaTest at runtime. For the same reason the retention must
 * be {@code RUNTIME}.
 */
@TagAnnotation
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface IntegrationTest {
}
================================================
FILE: amber/src/test/java/org/apache/texera/web/resource/dashboard/user/dataset/GitVersionControlLocalFileStorageSpec.java
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.resource.dashboard.user.dataset;
import org.apache.texera.amber.core.storage.util.dataset.GitVersionControlLocalFileStorage;
import org.apache.texera.amber.core.storage.util.dataset.PhysicalFileNode;
import org.eclipse.jgit.api.errors.GitAPIException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Tests for {@link GitVersionControlLocalFileStorage} against a throwaway
 * repository created in a fresh temp directory for each test.
 *
 * <p>{@link #setUp()} creates three versions (v1, v2, v3) that successively
 * overwrite {@code testFile1.txt}; {@link #tearDown()} deletes the repository.
 */
public class GitVersionControlLocalFileStorageSpec {
    private Path testRepoPath;
    // Hashes of the versions created so far, in creation order (v1, v2, v3, ...).
    private List<String> testRepoMasterCommitHashes;
    private final String testFile1Name = "testFile1.txt";
    private final String testFile2Name = "testFile2.txt";
    private final String testDirectoryName = "testDir";
    private final String testFile1ContentV1 = "This is a test file1 v1";
    private final String testFile1ContentV2 = "This is a test file1 v2";
    private final String testFile1ContentV3 = "This is a test file1 v3";
    private final String testFile2Content = "This is a test file2 in the testDir";

    /** Writes {@code fileContent} (platform-default encoding) to {@code filePath} in the test repo. */
    private void writeFileToRepo(Path filePath, String fileContent) throws IOException, GitAPIException {
        try (ByteArrayInputStream input = new ByteArrayInputStream(fileContent.getBytes())) {
            GitVersionControlLocalFileStorage.writeFileToRepo(testRepoPath, filePath, input);
        }
    }

    /**
     * Creates a version named {@code versionName} by writing {@code fileContent}
     * to {@code filePath} inside {@code withCreateVersion}.
     *
     * @return the hash identifying the new version
     */
    private String createVersionWritingFile(String versionName, Path filePath, String fileContent)
            throws IOException, GitAPIException {
        return GitVersionControlLocalFileStorage.withCreateVersion(
                testRepoPath,
                versionName,
                () -> {
                    try {
                        writeFileToRepo(filePath, fileContent);
                    } catch (IOException | GitAPIException e) {
                        // NOTE(review): the callback type presumably cannot throw
                        // checked exceptions, hence the unchecked wrap.
                        throw new RuntimeException(e);
                    }
                });
    }

    /** Returns the content of {@code filePath} as of version {@code versionHash}. */
    private String readFileAtVersion(String versionHash, Path filePath) throws IOException, GitAPIException {
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        GitVersionControlLocalFileStorage.retrieveFileContentOfVersion(testRepoPath, versionHash, filePath, output);
        return output.toString();
    }

    @Before
    public void setUp() throws IOException, GitAPIException {
        // Create a temporary directory for the repository
        testRepoPath = Files.createTempDirectory("testRepo");
        GitVersionControlLocalFileStorage.initRepo(testRepoPath);
        Path file1Path = testRepoPath.resolve(testFile1Name);
        testRepoMasterCommitHashes = new ArrayList<>();
        // Version 1
        testRepoMasterCommitHashes.add(createVersionWritingFile("v1", file1Path, testFile1ContentV1));
        // Version 2
        testRepoMasterCommitHashes.add(createVersionWritingFile("v2", file1Path, testFile1ContentV2));
        // Version 3
        testRepoMasterCommitHashes.add(createVersionWritingFile("v3", file1Path, testFile1ContentV3));
    }

    @After
    public void tearDown() throws IOException {
        // Clean up the test repository directory
        GitVersionControlLocalFileStorage.deleteRepo(testRepoPath);
    }

    @Test
    public void testFileContentAcrossVersions() throws IOException, GitAPIException {
        // File path for the test file
        Path filePath = testRepoPath.resolve(testFile1Name);
        // testRepoMasterCommitHashes is populated in chronological order: v1, v2, v3
        String[] expectedContents = {testFile1ContentV1, testFile1ContentV2, testFile1ContentV3};
        for (int i = 0; i < expectedContents.length; i++) {
            Assert.assertEquals(
                    "Content for version " + (i + 1) + " does not match",
                    expectedContents[i],
                    readFileAtVersion(testRepoMasterCommitHashes.get(i), filePath));
        }
    }

    @Test
    public void testFileTreeRetrieval() throws Exception {
        // File path for the test file
        Path file1Path = testRepoPath.resolve(testFile1Name);
        PhysicalFileNode file1Node = new PhysicalFileNode(testRepoPath, file1Path, Files.size(file1Path));
        Set<PhysicalFileNode> physicalFileNodes = new HashSet<>();
        physicalFileNodes.add(file1Node);

        // first retrieve the latest version's file tree
        String latestHash = testRepoMasterCommitHashes.get(testRepoMasterCommitHashes.size() - 1);
        Assert.assertEquals("File Tree should match",
                physicalFileNodes,
                GitVersionControlLocalFileStorage.retrieveRootFileNodesOfVersion(testRepoPath, latestHash));

        // now we add a new file testDir/testFile2.txt
        Path testDirPath = testRepoPath.resolve(testDirectoryName);
        Path file2Path = testDirPath.resolve(testFile2Name);
        String v4Hash = createVersionWritingFile("v4", file2Path, testFile2Content);
        testRepoMasterCommitHashes.add(v4Hash);

        PhysicalFileNode dirNode = new PhysicalFileNode(testRepoPath, testDirPath, 0); // Directories typically have size 0
        dirNode.addChildNode(new PhysicalFileNode(testRepoPath, file2Path, Files.size(file2Path)));
        // update the expected fileNodes (added only after its children are attached)
        physicalFileNodes.add(dirNode);
        // check the file tree
        Assert.assertEquals(
                "File Tree should match",
                physicalFileNodes,
                GitVersionControlLocalFileStorage.retrieveRootFileNodesOfVersion(testRepoPath, v4Hash));

        // now we delete the file1, check the filetree
        String v5Hash = GitVersionControlLocalFileStorage.withCreateVersion(testRepoPath, "v5", () -> {
            try {
                GitVersionControlLocalFileStorage.removeFileFromRepo(testRepoPath, file1Path);
            } catch (IOException | GitAPIException e) {
                throw new RuntimeException(e);
            }
        });
        physicalFileNodes.remove(file1Node);
        Assert.assertEquals(
                "File1 should be gone",
                physicalFileNodes,
                GitVersionControlLocalFileStorage.retrieveRootFileNodesOfVersion(testRepoPath, v5Hash)
        );
    }

    @Test
    public void testUncommittedCheckAndRecoverToLatest() throws Exception {
        Path tempFilePath = testRepoPath.resolve("tempFile");
        String content = "some random content";
        // Written outside withCreateVersion, so it must show up as uncommitted.
        writeFileToRepo(tempFilePath, content);
        Assert.assertTrue(
                "There should be some uncommitted changes",
                GitVersionControlLocalFileStorage.hasUncommittedChanges(testRepoPath));
        GitVersionControlLocalFileStorage.discardUncommittedChanges(testRepoPath);
        Assert.assertFalse("There should be no uncommitted changes",
                GitVersionControlLocalFileStorage.hasUncommittedChanges(testRepoPath));
    }
}
================================================
FILE: amber/src/test/python/core/architecture/handlers/control/test_debug_command_handler.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.architecture.handlers.control.debug_command_handler import (
WorkerDebugCommandHandler,
)
from core.architecture.managers.pause_manager import PauseType
from proto.org.apache.texera.amber.engine.architecture.rpc import (
DebugCommandRequest,
EmptyReturn,
)
class TestTranslateDebugCommand:
    """Unit tests for ``WorkerDebugCommandHandler.translate_debug_command``.

    As pinned below, the translator does three things:

    * It rewrites a pdb ``b``/``break`` command whose first argument is a bare
      line number into ``<module>:<lineno>`` form, using
      ``context.executor_manager.operator_module_name``.
    * It rejects empty input.
    * Otherwise it strips and collapses whitespace and passes the command
      through.
    """

    @pytest.fixture
    def context(self):
        # Minimal stand-in context: only the operator module name that becomes
        # the breakpoint prefix.
        return SimpleNamespace(
            executor_manager=SimpleNamespace(operator_module_name="my_udf")
        )

    def test_break_with_lineno_prepends_module(self, context):
        assert (
            WorkerDebugCommandHandler.translate_debug_command("b 5", context)
            == "b my_udf:5"
        )

    def test_long_break_with_lineno_prepends_module(self, context):
        assert (
            WorkerDebugCommandHandler.translate_debug_command("break 12", context)
            == "break my_udf:12"
        )

    def test_break_preserves_condition_arg(self, context):
        assert (
            WorkerDebugCommandHandler.translate_debug_command("b 7 x > 0", context)
            == "b my_udf:7 x > 0"
        )

    def test_break_with_no_args_passes_through(self, context):
        # No args → falls through to the else branch (no module rewriting).
        assert WorkerDebugCommandHandler.translate_debug_command("b", context) == "b"

    def test_non_break_command_passes_through(self, context):
        assert WorkerDebugCommandHandler.translate_debug_command("n", context) == "n"

    def test_non_break_command_with_args_is_rejoined(self, context):
        assert (
            WorkerDebugCommandHandler.translate_debug_command("p some_var", context)
            == "p some_var"
        )

    def test_leading_and_trailing_whitespace_is_stripped(self, context):
        # Mix spaces and tabs on both sides so the test does not depend on the
        # padding being plain single spaces.
        assert (
            WorkerDebugCommandHandler.translate_debug_command("  \t c \t ", context)
            == "c"
        )

    def test_internal_whitespace_is_collapsed_to_single_space(self, context):
        # split() with no args collapses any run of whitespace, so the rejoined
        # form has single spaces regardless of how many the user typed. The
        # input deliberately contains runs of spaces and a tab: a single-spaced
        # input would pass trivially without exercising the collapsing at all.
        assert (
            WorkerDebugCommandHandler.translate_debug_command(
                "p  foo \t  bar", context
            )
            == "p foo bar"
        )

    def test_break_with_only_lineno_has_no_trailing_space(self, context):
        # The implementation joins the (empty) tail with " "; the final strip()
        # must remove the trailing whitespace so the command stays valid pdb.
        result = WorkerDebugCommandHandler.translate_debug_command("b 5", context)
        assert result == "b my_udf:5"
        assert not result.endswith(" ")

    # ----- edge cases / invalid input -----

    def test_empty_command_raises_descriptive_error(self, context):
        with pytest.raises(ValueError, match="cannot be empty"):
            WorkerDebugCommandHandler.translate_debug_command("", context)

    def test_whitespace_only_command_raises_descriptive_error(self, context):
        with pytest.raises(ValueError, match="cannot be empty"):
            WorkerDebugCommandHandler.translate_debug_command(" \t ", context)

    def test_uppercase_break_is_not_recognized(self, context):
        # The match list is case-sensitive: ("b", "break"). "BREAK" / "B" fall
        # through to the pass-through branch and won't get the module prefix.
        assert (
            WorkerDebugCommandHandler.translate_debug_command("BREAK 5", context)
            == "BREAK 5"
        )
        assert (
            WorkerDebugCommandHandler.translate_debug_command("B 5", context) == "B 5"
        )

    def test_break_with_function_name_passes_through(self, context):
        # pdb's `b` accepts a bare function name and resolves it itself; the
        # `module:funcname` form is invalid (pdb expects a lineno after a
        # filename prefix). So we leave function-name args unchanged.
        assert (
            WorkerDebugCommandHandler.translate_debug_command("b my_func", context)
            == "b my_func"
        )

    def test_break_with_explicit_filename_passes_through(self, context):
        # The user already typed `filename:lineno` — don't double-prefix.
        assert (
            WorkerDebugCommandHandler.translate_debug_command("b foo.py:5", context)
            == "b foo.py:5"
        )

    def test_break_with_lineno_before_module_init_raises(self, context):
        # Without an initialized executor module we cannot construct
        # `module:lineno`, so refuse instead of emitting `b None:5`.
        context.executor_manager.operator_module_name = None
        with pytest.raises(ValueError, match="executor module not initialized"):
            WorkerDebugCommandHandler.translate_debug_command("b 5", context)

    def test_break_with_function_name_before_module_init_passes_through(self, context):
        # Function-name and filename:lineno forms don't need the module name,
        # so they should still work even before the executor is initialized.
        context.executor_manager.operator_module_name = None
        assert (
            WorkerDebugCommandHandler.translate_debug_command("b my_func", context)
            == "b my_func"
        )
class TestDebugCommandAsyncFlow:
    """Tests for the async ``debug_command`` RPC of WorkerDebugCommandHandler.

    The handler translates the command, forwards it to the debug manager and
    resumes the USER, EXCEPTION and DEBUG pauses, in that order.
    """

    @pytest.fixture
    def handler(self):
        # ControlHandler.__init__ just stashes context; bypass the protobuf
        # base class' __init__ by constructing via __new__.
        instance = WorkerDebugCommandHandler.__new__(WorkerDebugCommandHandler)
        instance.context = SimpleNamespace(
            executor_manager=SimpleNamespace(operator_module_name="my_udf"),
            debug_manager=MagicMock(),
            pause_manager=MagicMock(),
        )
        return instance

    def test_translates_then_forwards_to_debug_manager(self, handler):
        asyncio.run(handler.debug_command(DebugCommandRequest(cmd="b 5")))
        handler.context.debug_manager.put_debug_command.assert_called_once_with(
            "b my_udf:5"
        )

    def test_resumes_all_three_pause_types(self, handler):
        asyncio.run(handler.debug_command(DebugCommandRequest(cmd="c")))
        actual = [
            call.args[0] for call in handler.context.pause_manager.resume.call_args_list
        ]
        assert actual == [
            PauseType.USER_PAUSE,
            PauseType.EXCEPTION_PAUSE,
            PauseType.DEBUG_PAUSE,
        ]

    def test_returns_empty_return(self, handler):
        result = asyncio.run(handler.debug_command(DebugCommandRequest(cmd="n")))
        assert isinstance(result, EmptyReturn)

    def test_passes_through_non_break_command_unchanged(self, handler):
        asyncio.run(handler.debug_command(DebugCommandRequest(cmd="p x")))
        handler.context.debug_manager.put_debug_command.assert_called_once_with("p x")

    def test_empty_cmd_propagates_value_error(self, handler):
        # An empty cmd hits the ValueError in translate_debug_command. The
        # handler does not catch it — the RPC layer will surface the failure
        # back to the caller. Pin this so silent swallowing doesn't sneak in.
        # Matching the translator's message ensures an unrelated ValueError
        # raised elsewhere in the handler cannot satisfy this test.
        with pytest.raises(ValueError, match="cannot be empty"):
            asyncio.run(handler.debug_command(DebugCommandRequest(cmd="")))

    def test_translation_failure_skips_put_and_resume(self, handler):
        # Translation happens before any side effect, so a rejected command
        # must neither reach the debug manager nor release any pause.
        with pytest.raises(ValueError, match="cannot be empty"):
            asyncio.run(handler.debug_command(DebugCommandRequest(cmd="")))
        handler.context.debug_manager.put_debug_command.assert_not_called()
        handler.context.pause_manager.resume.assert_not_called()
================================================
FILE: amber/src/test/python/core/architecture/handlers/control/test_evaluate_expression_handler.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from core.architecture.handlers.control.evaluate_expression_handler import (
EvaluateExpressionHandler,
)
from proto.org.apache.texera.amber.engine.architecture.rpc import (
EvaluatedValue,
EvaluatePythonExpressionRequest,
TypedValue,
)
class TestEvaluateExpressionHandler:
    """Tests for ``EvaluateExpressionHandler.evaluate_python_expression``.

    ExpressionEvaluator.evaluate is patched in every test, so these tests pin
    only how the handler builds the runtime context and forwards results and
    errors — not how expressions are actually evaluated.
    """

    # Patch target for the evaluator entry point used by the handler.
    _EVALUATE = (
        "core.architecture.handlers.control.evaluate_expression_handler"
        ".ExpressionEvaluator.evaluate"
    )

    @staticmethod
    def _run(handler, expression):
        """Drive one evaluate_python_expression RPC to completion."""
        return asyncio.run(
            handler.evaluate_python_expression(
                EvaluatePythonExpressionRequest(expression=expression)
            )
        )

    @pytest.fixture
    def executor(self):
        # Stand-in for the user's UDF instance; anything the evaluated
        # expression can address as `self` is enough.
        return SimpleNamespace(state="alive")

    @pytest.fixture
    def handler(self, executor):
        # Skip __init__ via __new__ and attach only the context pieces the
        # handler reads.
        tuple_state = SimpleNamespace(
            current_input_tuple={"col": 42},
            current_input_port_id="port-0",
        )
        subject = EvaluateExpressionHandler.__new__(EvaluateExpressionHandler)
        subject.context = SimpleNamespace(
            executor_manager=SimpleNamespace(executor=executor),
            tuple_processing_manager=tuple_state,
        )
        return subject

    def test_returns_what_the_evaluator_returns(self, handler):
        sentinel = EvaluatedValue(
            value=TypedValue(expression="1+1", value_ref="2", value_type="int")
        )
        with patch(self._EVALUATE, return_value=sentinel) as evaluate:
            outcome = self._run(handler, "1+1")
        assert outcome is sentinel
        evaluate.assert_called_once()

    def test_runtime_context_exposes_self_tuple_input(self, handler, executor):
        with patch(self._EVALUATE, return_value=EvaluatedValue()) as evaluate:
            self._run(handler, "self.state")
        expression, runtime_context = evaluate.call_args.args
        assert expression == "self.state"
        assert runtime_context["self"] is executor
        assert runtime_context["tuple_"] == {"col": 42}
        assert runtime_context["input_"] == "port-0"

    def test_runtime_context_reflects_current_tuple_at_call_time(
        self, handler, executor
    ):
        # The handler must read the *current* tuple/port from the context on
        # every call rather than snapshotting them at construction, so drive
        # two calls with different intermediate state.
        seen: list = []

        def record(_expression, runtime_context):
            seen.append((runtime_context["tuple_"], runtime_context["input_"]))
            return EvaluatedValue()

        tuple_state = handler.context.tuple_processing_manager
        with patch(self._EVALUATE, side_effect=record):
            self._run(handler, "x")
            tuple_state.current_input_tuple = {"col": 99}
            tuple_state.current_input_port_id = "port-1"
            self._run(handler, "x")
        assert seen == [({"col": 42}, "port-0"), ({"col": 99}, "port-1")]

    def test_handles_none_input_tuple_and_port(self, handler):
        # Before any input arrives, both the tuple and the port are None; a
        # context must still be built (e.g. for evaluating `self.foo`).
        tuple_state = handler.context.tuple_processing_manager
        tuple_state.current_input_tuple = None
        tuple_state.current_input_port_id = None
        with patch(self._EVALUATE, return_value=EvaluatedValue()) as evaluate:
            self._run(handler, "self.state")
        _expression, runtime_context = evaluate.call_args.args
        assert runtime_context["tuple_"] is None
        assert runtime_context["input_"] is None

    def test_evaluator_exception_propagates(self, handler):
        # Errors from the evaluator (bad syntax, missing attributes, ...) must
        # not be swallowed: surfacing them to the frontend is the RPC layer's job.
        failure = AttributeError("no such attribute")
        with patch(self._EVALUATE, side_effect=failure):
            with pytest.raises(AttributeError, match="no such attribute"):
                self._run(handler, "self.missing")
================================================
FILE: amber/src/test/python/core/architecture/handlers/control/test_replay_current_tuple_handler.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.architecture.handlers.control.replay_current_tuple_handler import (
RetryCurrentTupleHandler,
)
from core.architecture.managers.pause_manager import PauseType
from proto.org.apache.texera.amber.engine.architecture.rpc import (
EmptyRequest,
EmptyReturn,
)
from proto.org.apache.texera.amber.engine.architecture.worker import WorkerState
def _build_handler(state: WorkerState, current_tuple, remaining_iter):
    """Build a RetryCurrentTupleHandler over a stubbed context.

    The fake state manager's ``confirm_state(*states)`` answers True iff
    ``state`` is among the queried states. ``current_tuple`` is the tuple
    being retried. ``remaining_iter`` seeds the iterator of tuples still
    waiting to be processed. ``pause_manager`` is a MagicMock so resume calls
    can be inspected.
    """

    def confirm_state(*candidates):
        return state in candidates

    state_manager = MagicMock()
    state_manager.confirm_state.side_effect = confirm_state
    tuple_manager = SimpleNamespace(
        current_input_tuple=current_tuple,
        current_input_tuple_iter=iter(remaining_iter),
    )
    subject = RetryCurrentTupleHandler.__new__(RetryCurrentTupleHandler)
    subject.context = SimpleNamespace(
        state_manager=state_manager,
        tuple_processing_manager=tuple_manager,
        pause_manager=MagicMock(),
    )
    return subject
class TestRetryCurrentTupleHandler:
    """Tests for ``RetryCurrentTupleHandler.retry_current_tuple``.

    Unless the worker is COMPLETED, a retry must put the current tuple back in
    front of the pending-input iterator and release the USER and EXCEPTION
    pauses (but not DEBUG).
    """

    @staticmethod
    def _retry(handler):
        """Run one retry_current_tuple RPC to completion and return its result."""
        return asyncio.run(handler.retry_current_tuple(EmptyRequest()))

    @staticmethod
    def _drain(handler):
        """Consume the handler's pending-input iterator and return it as a list."""
        pending = handler.context.tuple_processing_manager.current_input_tuple_iter
        return list(pending)

    @staticmethod
    def _resumed(handler):
        """Return the pause types passed to ``pause_manager.resume``, in call order."""
        recorded_calls = handler.context.pause_manager.resume.call_args_list
        return [recorded.args[0] for recorded in recorded_calls]

    @pytest.fixture
    def running_handler(self):
        return _build_handler(
            WorkerState.RUNNING,
            current_tuple={"col": "current"},
            remaining_iter=[{"col": "next"}],
        )

    def test_returns_empty_return(self, running_handler):
        assert isinstance(self._retry(running_handler), EmptyReturn)

    def test_chains_current_tuple_back_onto_iterator(self, running_handler):
        self._retry(running_handler)
        # The current tuple comes first, followed by the tuples already queued.
        assert self._drain(running_handler) == [{"col": "current"}, {"col": "next"}]

    def test_resumes_user_and_exception_pause_in_order(self, running_handler):
        self._retry(running_handler)
        assert self._resumed(running_handler) == [
            PauseType.USER_PAUSE,
            PauseType.EXCEPTION_PAUSE,
        ]

    def test_does_not_resume_debug_pause(self, running_handler):
        # Unlike WorkerDebugCommandHandler, retry only releases USER and
        # EXCEPTION pauses — DEBUG_PAUSE must remain in effect so an active
        # debugging session is not silently dropped.
        self._retry(running_handler)
        assert PauseType.DEBUG_PAUSE not in set(self._resumed(running_handler))

    def test_no_op_when_state_is_completed(self):
        completed_handler = _build_handler(
            WorkerState.COMPLETED,
            current_tuple={"col": "current"},
            remaining_iter=[{"col": "next"}],
        )
        outcome = self._retry(completed_handler)
        # Replaying a tuple after completion is meaningless: the iterator is
        # left as it was (no chaining) and no pause type is resumed.
        assert self._drain(completed_handler) == [{"col": "next"}]
        completed_handler.context.pause_manager.resume.assert_not_called()
        assert isinstance(outcome, EmptyReturn)

    def test_chains_even_when_remaining_iter_is_exhausted(self):
        lone_handler = _build_handler(
            WorkerState.RUNNING,
            current_tuple={"col": "lone"},
            remaining_iter=[],
        )
        self._retry(lone_handler)
        assert self._drain(lone_handler) == [{"col": "lone"}]

    def test_paused_state_still_chains_and_resumes(self):
        # The completion guard is `if not confirm_state(COMPLETED)`, so every
        # other state — RUNNING, READY, PAUSED, UNINITIALIZED — must take the
        # chain+resume path. PAUSED is the most likely real-world entry point
        # (the user hits "retry" while the worker is paused on an exception).
        paused_handler = _build_handler(
            WorkerState.PAUSED,
            current_tuple={"col": "current"},
            remaining_iter=[{"col": "next"}],
        )
        self._retry(paused_handler)
        assert self._drain(paused_handler) == [{"col": "current"}, {"col": "next"}]
        assert self._resumed(paused_handler) == [
            PauseType.USER_PAUSE,
            PauseType.EXCEPTION_PAUSE,
        ]
================================================
FILE: amber/src/test/python/core/architecture/handlers/control/test_update_executor_handler.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.architecture.handlers.control.update_executor_handler import (
UpdateExecutorHandler,
)
from proto.org.apache.texera.amber.core import OpExecInitInfo, OpExecWithCode
from proto.org.apache.texera.amber.engine.architecture.rpc import (
EmptyReturn,
UpdateExecutorRequest,
)
def make_request(code: str) -> UpdateExecutorRequest:
    """Return an UpdateExecutorRequest whose init info carries ``code`` inline."""
    init_info = OpExecInitInfo(op_exec_with_code=OpExecWithCode(code=code))
    return UpdateExecutorRequest(new_exec_init_info=init_info)
def make_handler(executor_is_source: bool = False) -> UpdateExecutorHandler:
    """Return an UpdateExecutorHandler over a minimal fake context.

    The context exposes only ``executor_manager``. It is a MagicMock, so calls
    to ``update_executor`` can be asserted, and its ``executor`` reports
    ``is_source=executor_is_source``.
    """
    manager = MagicMock()
    manager.executor = SimpleNamespace(is_source=executor_is_source)
    return UpdateExecutorHandler(SimpleNamespace(executor_manager=manager))
class TestUpdateExecutorHandler:
    """Tests for ``UpdateExecutorHandler.update_executor``.

    Each test builds a fresh handler via ``make_handler`` (whose
    executor_manager is a MagicMock) and inspects what was forwarded to it.
    """

    @staticmethod
    def _update(handler, request):
        """Run one update_executor RPC to completion and return its result."""
        return asyncio.run(handler.update_executor(request))

    def test_returns_empty_return(self):
        subject = make_handler(executor_is_source=False)
        outcome = self._update(subject, make_request("# code"))
        assert isinstance(outcome, EmptyReturn)

    def test_delegates_extracted_code_to_executor_manager(self):
        subject = make_handler(executor_is_source=False)
        self._update(subject, make_request("user-code-v2"))
        manager = subject.context.executor_manager
        manager.update_executor.assert_called_once_with("user-code-v2", False)

    def test_propagates_current_executor_is_source_not_request_field(self):
        # The handler passes the *current* executor's is_source flag forward,
        # not anything derived from the request payload. Pin this so a future
        # change that reads is_source from the request is reviewed.
        subject = make_handler(executor_is_source=True)
        self._update(subject, make_request("source-code"))
        manager = subject.context.executor_manager
        manager.update_executor.assert_called_once_with("source-code", True)

    def test_extracts_code_via_get_one_of_for_op_exec_with_code(self):
        # OpExecInitInfo is a sealed-oneof of {with_class_name, with_code,
        # source}. The handler relies on get_one_of to surface the populated
        # variant; if the request carries a different variant the handler must
        # not silently call the manager with stale data — instead the call
        # surfaces an attribute error on `.code`. Pin the contract explicitly.
        from proto.org.apache.texera.amber.core import OpExecWithClassName

        subject = make_handler(executor_is_source=False)
        wrong_variant = OpExecInitInfo(
            op_exec_with_class_name=OpExecWithClassName(class_name="X")
        )
        request = UpdateExecutorRequest(new_exec_init_info=wrong_variant)
        with pytest.raises(AttributeError):
            self._update(subject, request)
        subject.context.executor_manager.update_executor.assert_not_called()
================================================
FILE: amber/src/test/python/core/architecture/managers/test_console_message_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime, timedelta
from core.architecture.managers.console_message_manager import ConsoleMessageManager
from proto.org.apache.texera.amber.engine.architecture.rpc import (
ConsoleMessage,
ConsoleMessageType,
)
def _msg(title: str) -> ConsoleMessage:
    """Create a PRINT console message from worker "w0" whose title and body are ``title``."""
    fields = dict(
        worker_id="w0",
        timestamp=datetime.now(),
        msg_type=ConsoleMessageType.PRINT,
        source="src",
        title=title,
        message=title,
    )
    return ConsoleMessage(**fields)
class TestConsoleMessageManager:
    @staticmethod
    def _titles(messages):
        """Materialize an iterable of ConsoleMessage into the list of their titles."""
        return [message.title for message in messages]

    def test_initially_force_flush_drains_empty(self):
        mgr = ConsoleMessageManager()
        # Nothing has been put yet, so even a forced flush yields no items.
        assert list(mgr.get_messages(force_flush=True)) == []

    def test_force_flush_drains_all_buffered_in_order(self):
        mgr = ConsoleMessageManager()
        for title in "abc":
            mgr.put_message(_msg(title))
        assert self._titles(mgr.get_messages(force_flush=True)) == ["a", "b", "c"]
        # Retrieval consumes the buffer, so a second forced drain is empty.
        assert list(mgr.get_messages(force_flush=True)) == []

    def test_get_without_flush_below_threshold_yields_nothing(self):
        # With fewer than max_message_num messages buffered and the flush
        # interval not yet elapsed, the underlying TimedBuffer should hold its
        # output back. `_last_output_time` is reset to "now" just before the
        # assertion so that the elapsed-time branch cannot fire merely because
        # the test happened to run slowly after construction.
        mgr = ConsoleMessageManager()
        mgr.put_message(_msg("only"))
        mgr.print_buf._last_output_time = datetime.now()
        assert list(mgr.get_messages(force_flush=False)) == []
        # The held-back message must still come out on a forced flush.
        assert self._titles(mgr.get_messages(force_flush=True)) == ["only"]

    def test_get_without_flush_at_or_over_max_message_num_drains(self):
        # Once the buffered count reaches max_message_num (default 10), the
        # buffer flushes without force_flush=True.
        expected = [f"m{i}" for i in range(10)]
        mgr = ConsoleMessageManager()
        for title in expected:
            mgr.put_message(_msg(title))
        assert self._titles(mgr.get_messages(force_flush=False)) == expected

    def test_get_drains_when_last_output_time_is_stale(self):
        # Push `_last_output_time` two seconds into the past so the time-based
        # flush triggers for a single message with force_flush=False. This
        # avoids both sleeping and monkeypatching `datetime`.
        mgr = ConsoleMessageManager()
        mgr.put_message(_msg("stale"))
        mgr.print_buf._last_output_time = datetime.now() - timedelta(seconds=2)
        assert self._titles(mgr.get_messages(force_flush=False)) == ["stale"]
================================================
FILE: amber/src/test/python/core/architecture/managers/test_debug_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from threading import Condition
import pytest
from core.architecture.managers.debug_manager import DebugManager
class TestDebugManager:
    @pytest.fixture
    def debug_manager(self):
        return DebugManager(Condition())

    @staticmethod
    def _emit(manager, text):
        """Simulate debugger output: write ``text`` to the debugger's stdout, then flush."""
        stream = manager.debugger.stdout
        stream.write(text)
        stream.flush()

    def test_it_can_init(self, debug_manager):
        pdb_instance = debug_manager.debugger
        assert pdb_instance is not None
        assert pdb_instance.prompt == ""

    def test_it_has_no_command_initially(self, debug_manager):
        assert not debug_manager.has_debug_command()

    def test_it_has_no_event_initially(self, debug_manager):
        assert not debug_manager.has_debug_event()

    def test_put_command_sets_has_debug_command(self, debug_manager):
        debug_manager.put_debug_command("n")
        assert debug_manager.has_debug_command()

    def test_get_debug_event_returns_flushed_output(self, debug_manager):
        # Stand in for Pdb writing to its stdout (the SingleBlockingIO event
        # pipe) instead of starting a real debugging session.
        self._emit(debug_manager, "hit breakpoint")
        assert debug_manager.has_debug_event()
        assert debug_manager.get_debug_event() == "hit breakpoint\n"
        assert not debug_manager.has_debug_event()

    def test_command_pipe_and_event_pipe_are_independent(self, debug_manager):
        debug_manager.put_debug_command("step")
        assert debug_manager.has_debug_command()
        assert not debug_manager.has_debug_event()
        self._emit(debug_manager, "event")
        # Neither pipe may consume the other: both remain pending.
        assert debug_manager.has_debug_command()
        assert debug_manager.has_debug_event()

    def test_pdb_is_wired_to_debug_pipes(self, debug_manager):
        # The debugger must read from the IO that put_debug_command writes to,
        # and write to the IO that get_debug_event reads from.
        debug_manager.put_debug_command("c")
        # The queued command is visible on the debugger's stdin.
        assert debug_manager.debugger.stdin.readline() == "c\n"
        self._emit(debug_manager, "paused")
        assert debug_manager.get_debug_event() == "paused\n"

    def test_event_pipe_supports_multiple_round_trips(self, debug_manager):
        for line in ("first", "second", "third"):
            self._emit(debug_manager, line)
            assert debug_manager.get_debug_event() == f"{line}\n"
        assert not debug_manager.has_debug_event()

    def test_debugger_uses_nosigint_to_avoid_signal_install(self, debug_manager):
        # Pdb is built with nosigint=True so that it leaves signal handlers
        # alone in the worker thread. This guards against that being flipped.
        assert debug_manager.debugger.nosigint is True

    # ----- edge cases / quirks -----

    def test_put_empty_command_still_marks_command_present(self, debug_manager):
        # SingleBlockingIO.flush always commits buf + "\n", so an empty command
        # still lands as a "\n" line and counts as pending. Documents the
        # current behavior.
        debug_manager.put_debug_command("")
        assert debug_manager.has_debug_command()
        assert debug_manager.debugger.stdin.readline() == "\n"

    def test_put_overwrites_unconsumed_command(self, debug_manager):
        # The command pipe holds a single value. A second put before the first
        # is consumed silently replaces it, which is a known data-loss quirk of
        # SingleBlockingIO. Pinned so callers do not assume queue semantics.
        for command in ("first", "second"):
            debug_manager.put_debug_command(command)
        assert debug_manager.debugger.stdin.readline() == "second\n"

    def test_put_command_with_embedded_newline_is_passed_verbatim(self, debug_manager):
        # Embedded newlines are not sanitized; the debugger sees the raw text.
        debug_manager.put_debug_command("step\nlist")
        assert debug_manager.debugger.stdin.readline() == "step\nlist\n"

    def test_event_pipe_overwrites_unconsumed_event(self, debug_manager):
        # Like the command pipe, an unread event is replaced by the next flush.
        for text in ("first", "second"):
            self._emit(debug_manager, text)
        assert debug_manager.get_debug_event() == "second\n"
================================================
FILE: amber/src/test/python/core/architecture/managers/test_embedded_control_message_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.architecture.managers.embedded_control_message_manager import (
EmbeddedControlMessageManager,
)
from proto.org.apache.texera.amber.core import (
ActorVirtualIdentity,
ChannelIdentity,
EmbeddedControlMessageIdentity,
)
from proto.org.apache.texera.amber.engine.architecture.rpc import (
EmbeddedControlMessage,
EmbeddedControlMessageType,
)
# Identity of the worker under test. `_channel` defaults `to_name` to "self",
# so channels built with it target this actor unless told otherwise.
SELF_ID = ActorVirtualIdentity(name="self")
def _channel(from_name: str, to_name: str = "self", is_control: bool = False):
    """Build a ChannelIdentity from actor ``from_name`` to actor ``to_name``.

    ``to_name`` defaults to "self", the name carried by SELF_ID, and
    ``is_control`` defaults to False.
    """
    sender = ActorVirtualIdentity(name=from_name)
    receiver = ActorVirtualIdentity(name=to_name)
    return ChannelIdentity(
        from_worker_id=sender, to_worker_id=receiver, is_control=is_control
    )
def _make_ecm(
    ecm_type: EmbeddedControlMessageType,
    scope=None,
) -> EmbeddedControlMessage:
    """Create an ECM of ``ecm_type`` with the fixed identity "ecm-1".

    Each call builds a brand-new ``EmbeddedControlMessageIdentity(id="ecm-1")``.
    Because those identities compare and hash equal, messages from separate
    calls still aggregate under one key in
    ``EmbeddedControlMessageManager.ecm_received``. A falsy ``scope`` (None or
    an empty list) becomes ``[]``.
    """
    resolved_scope = scope if scope else []
    return EmbeddedControlMessage(
        id=EmbeddedControlMessageIdentity(id="ecm-1"),
        ecm_type=ecm_type,
        scope=resolved_scope,
    )
def _gateway_with_data_channels(*data_channels: ChannelIdentity):
    """Stub InputManager whose channel listings both report ``data_channels``.

    Configures two methods on a MagicMock:
      * ``get_all_data_channel_ids()``
      * ``get_all_channel_ids()``

    Each returns its own set built from ``data_channels``. Any other attribute
    is a MagicMock auto-created child, so this stub suits only tests where the
    manager consults these two listings.
    """
    gw = MagicMock()
    gw.get_all_data_channel_ids.return_value = set(data_channels)
    gw.get_all_channel_ids.return_value = set(data_channels)
    return gw
def _gateway_with_ports(port_layout: dict, all_channels: set):
"""Stub InputManager that supports per-port lookups for PORT_ALIGNMENT.
`port_layout` maps PortIdentity-key (use any hashable) -> set of channels.
`get_port_id(channel)` resolves channel -> port.
"""
gw = MagicMock()
channel_to_port = {ch: pid for pid, chs in port_layout.items() for ch in chs}
gw.get_port_id.side_effect = lambda ch: channel_to_port[ch]
gw.get_port.side_effect = lambda pid: SimpleNamespace(
get_channels=lambda chs=port_layout[pid]: chs
)
gw.get_all_data_channel_ids.return_value = set(all_channels)
gw.get_all_channel_ids.return_value = set(all_channels)
return gw
class TestEcmAllAlignment:
    def test_returns_false_until_all_channels_received(self):
        first, second, last = _channel("a"), _channel("b"), _channel("c")
        mgr = EmbeddedControlMessageManager(
            SELF_ID, _gateway_with_data_channels(first, second, last)
        )
        ecm = _make_ecm(EmbeddedControlMessageType.ALL_ALIGNMENT)
        # Every arrival before the last leaves alignment incomplete.
        for early in (first, second):
            assert mgr.is_ecm_aligned(early, ecm) is False
        # The final outstanding channel completes it.
        assert mgr.is_ecm_aligned(last, ecm) is True

    def test_dict_is_cleaned_up_after_full_alignment(self):
        # Cleanup contract: once every expected channel has reported, the
        # per-id entry is removed so a reused id cannot leak state into the
        # next ECM round.
        left, right = _channel("a"), _channel("b")
        mgr = EmbeddedControlMessageManager(
            SELF_ID, _gateway_with_data_channels(left, right)
        )
        ecm = _make_ecm(EmbeddedControlMessageType.ALL_ALIGNMENT)
        for channel in (left, right):
            mgr.is_ecm_aligned(channel, ecm)
        assert ecm.id not in mgr.ecm_received
class TestEcmNoAlignment:
    def test_first_message_completes_subsequent_do_not(self):
        early, late = _channel("a"), _channel("b")
        mgr = EmbeddedControlMessageManager(
            SELF_ID, _gateway_with_data_channels(early, late)
        )
        ecm = _make_ecm(EmbeddedControlMessageType.NO_ALIGNMENT)
        # First arrival: exactly one channel has reported, so it fires.
        assert mgr.is_ecm_aligned(early, ecm) is True
        # Second arrival: two channels have reported, so it does not fire.
        # Every channel has now reported, so the per-id entry is dropped too.
        assert mgr.is_ecm_aligned(late, ecm) is False
        assert ecm.id not in mgr.ecm_received
class TestEcmPortAlignment:
    def test_completes_when_a_ports_channels_have_all_arrived(self):
        a1, a2 = _channel("a1"), _channel("a2")
        b1 = _channel("b1")
        ports = {"portA": {a1, a2}, "portB": {b1}}
        gw = _gateway_with_ports(ports, all_channels={a1, a2, b1})
        mgr = EmbeddedControlMessageManager(SELF_ID, gw)
        ecm = _make_ecm(EmbeddedControlMessageType.PORT_ALIGNMENT)
        # Port A aligns only once both a1 and a2 have arrived.
        assert mgr.is_ecm_aligned(a1, ecm) is False
        assert mgr.is_ecm_aligned(a2, ecm) is True
        # Port B has a single channel, so b1 alone completes that port.
        assert mgr.is_ecm_aligned(b1, ecm) is True


class TestEcmUnsupportedType:
    """ECM types outside ALL/NO/PORT_ALIGNMENT must be rejected loudly.

    Kept separate from the per-type alignment classes because it does not
    exercise any particular alignment mode.
    """

    def test_unsupported_ecm_type_raises_value_error(self):
        # Covers the manager's "Unsupported ECM type" ValueError branch, so a
        # new enum value cannot silently fall through.
        c1 = _channel("a")
        gw = _gateway_with_data_channels(c1)
        mgr = EmbeddedControlMessageManager(SELF_ID, gw)
        # Overwrite ecm_type with an integer intended to match none of the
        # three known values, so no supported branch is taken.
        ecm = _make_ecm(EmbeddedControlMessageType.ALL_ALIGNMENT)
        ecm.ecm_type = 999  # type: ignore[assignment]
        with pytest.raises(ValueError, match="Unsupported ECM type"):
            mgr.is_ecm_aligned(c1, ecm)
class TestEcmScope:
    def test_scope_intersects_with_all_channel_ids(self):
        # With a non-empty `ecm.scope`, the manager's get_channels_within_scope
        # keeps only the scoped channels that also appear in the gateway's
        # `get_all_channel_ids` and whose `to_worker_id` is this actor.
        in_scope = _channel("a", to_name="self")
        other_target = _channel("b", to_name="someone_else")
        unknown_to_gateway = _channel("c", to_name="self")
        gateway = MagicMock()
        gateway.get_all_channel_ids.return_value = {in_scope, other_target}
        gateway.get_all_data_channel_ids.return_value = {in_scope, other_target}
        mgr = EmbeddedControlMessageManager(SELF_ID, gateway)
        ecm = _make_ecm(
            EmbeddedControlMessageType.ALL_ALIGNMENT,
            scope=[in_scope, unknown_to_gateway],
        )
        # Only `in_scope` is both scoped and known to the gateway, so its
        # arrival alone completes alignment.
        assert mgr.is_ecm_aligned(in_scope, ecm) is True

    def test_no_scope_falls_back_to_all_data_channels(self):
        # With an empty scope, the manager reads get_all_data_channel_ids()
        # rather than get_all_channel_ids(), so control channels are not waited on.
        data_channel = _channel("a", is_control=False)
        control_channel = _channel("b", is_control=True)
        gateway = MagicMock()
        gateway.get_all_data_channel_ids.return_value = {data_channel}
        gateway.get_all_channel_ids.return_value = {data_channel, control_channel}
        mgr = EmbeddedControlMessageManager(SELF_ID, gateway)
        ecm = _make_ecm(EmbeddedControlMessageType.ALL_ALIGNMENT, scope=[])
        assert mgr.is_ecm_aligned(data_channel, ecm) is True
================================================
FILE: amber/src/test/python/core/architecture/managers/test_exception_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from core.architecture.managers.exception_manager import ExceptionManager
from core.models import ExceptionInfo
def _real_exc_info() -> ExceptionInfo:
    """Raise and catch RuntimeError("boom"), returning it as an ExceptionInfo.

    Raising for real means the stored traceback is the genuine object the
    interpreter produced, not a hand-built stand-in.
    """
    try:
        raise RuntimeError("boom")
    except RuntimeError:
        fields = dict(zip(("exc", "value", "tb"), sys.exc_info()))
        return ExceptionInfo(**fields)
class TestExceptionManager:
    @staticmethod
    def _manager_with(*infos):
        """Return a fresh ExceptionManager after stashing each of ``infos`` in order."""
        manager = ExceptionManager()
        for info in infos:
            manager.set_exception_info(info)
        return manager

    def test_initial_state(self):
        mgr = self._manager_with()
        assert mgr.exc_info is None
        assert mgr.exc_info_history == []
        assert mgr.has_exception() is False

    def test_set_then_has_exception_true(self):
        info = _real_exc_info()
        mgr = self._manager_with(info)
        assert mgr.has_exception() is True
        assert mgr.exc_info is info
        assert mgr.exc_info_history == [info]

    def test_get_exc_info_returns_and_clears_current_only(self):
        # Contract: get_exc_info hands back the latest stashed info and clears
        # the live slot, while the history keeps the entry.
        info = _real_exc_info()
        mgr = self._manager_with(info)
        assert mgr.get_exc_info() is info
        assert mgr.exc_info is None
        assert mgr.has_exception() is False
        assert mgr.exc_info_history == [info]

    def test_get_exc_info_when_none_returns_none(self):
        assert self._manager_with().get_exc_info() is None

    def test_history_accumulates_in_order(self):
        first, second = _real_exc_info(), _real_exc_info()
        mgr = self._manager_with(first, second)
        assert mgr.exc_info is second
        assert mgr.exc_info_history == [first, second]
        # Taking the latest leaves both entries in the history.
        assert mgr.get_exc_info() is second
        assert mgr.exc_info_history == [first, second]
================================================
FILE: amber/src/test/python/core/architecture/managers/test_executor_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
import pytest
from unittest.mock import MagicMock
from core.architecture.managers.executor_manager import ExecutorManager
# Sample operator code for testing
SAMPLE_OPERATOR_CODE = """
from pytexera import *
class TestOperator(UDFOperatorV2):
def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
yield tuple_
"""
SAMPLE_SOURCE_OPERATOR_CODE = """
from pytexera import *
class TestSourceOperator(UDFSourceOperator):
def produce(self) -> Iterator[Union[TupleLike, TableLike, None]]:
yield Tuple({"test": "data"})
"""
class TestExecutorManager:
    """Test suite for ExecutorManager, focusing on R UDF plugin support."""

    # Sentinel meaning "no texera_r entry existed in sys.modules before mocking".
    _NO_TEXERA_R = object()

    @pytest.fixture
    def executor_manager(self):
        """Create a fresh ExecutorManager instance for each test.

        On teardown, the manager's temp filesystem is closed if one was created.
        ``fs`` is a cached_property, so once materialized it is stored in the
        instance ``__dict__``. Checking there, instead of using ``hasattr``,
        avoids creating a filesystem only to close it. Guarding on ``_fs``,
        which is not the attribute's name, would skip cleanup entirely.
        """
        manager = ExecutorManager()
        yield manager
        if "fs" in vars(manager):
            manager.close()

    def _mock_r_plugin(self, executor_class_name, is_source):
        """
        Helper to mock the texera_r plugin module.
        :param executor_class_name: Name of the executor class (e.g., 'RTupleExecutor')
        :param is_source: Whether the executor is a source operator
        :return: Tuple of (mock_texera_r, mock_executor_instance)

        Any module already registered as ``texera_r`` is remembered so that
        ``_cleanup_r_plugin`` can restore it instead of evicting it.
        """
        from core.models import SourceOperator, Operator

        mock_texera_r = MagicMock()
        mock_executor_class = MagicMock()
        setattr(mock_texera_r, executor_class_name, mock_executor_class)
        # Use appropriate spec based on operator type
        spec_class = SourceOperator if is_source else Operator
        mock_executor_instance = MagicMock(spec=spec_class)
        mock_executor_instance.is_source = is_source
        mock_executor_class.return_value = mock_executor_instance
        self._saved_texera_r = sys.modules.get("texera_r", self._NO_TEXERA_R)
        sys.modules["texera_r"] = mock_texera_r
        return mock_texera_r, mock_executor_instance

    def _cleanup_r_plugin(self):
        """Undo ``_mock_r_plugin``: restore any prior texera_r module, else remove it."""
        saved = vars(self).pop("_saved_texera_r", self._NO_TEXERA_R)
        if saved is self._NO_TEXERA_R:
            sys.modules.pop("texera_r", None)
        else:
            sys.modules["texera_r"] = saved

    def _assert_r_language_rejected(self, executor_manager, language):
        """Assert that initialize_executor raises ImportError for an R ``language``.

        This assumes the texera_r plugin is not available. The error message
        must point users at the texera-rudf package.
        """
        with pytest.raises(ImportError) as exc_info:
            executor_manager.initialize_executor(
                code=SAMPLE_OPERATOR_CODE, is_source=False, language=language
            )
        message = str(exc_info.value)
        assert "texera-rudf" in message or "R operators require" in message

    def _assert_r_language_accepted(
        self, executor_manager, executor_class_name, is_source, language
    ):
        """Assert that the mocked plugin's ``executor_class_name`` instance becomes the executor."""
        _, mock_executor = self._mock_r_plugin(executor_class_name, is_source=is_source)
        try:
            executor_manager.initialize_executor(
                code="# R code", is_source=is_source, language=language
            )
            assert executor_manager.executor == mock_executor
        finally:
            self._cleanup_r_plugin()

    def test_initialization(self, executor_manager):
        """Test that ExecutorManager initializes correctly."""
        assert executor_manager.executor is None
        assert executor_manager.operator_module_name is None

    def test_reject_r_tuple_language(self, executor_manager):
        """Test that 'r-tuple' language is rejected with ImportError when plugin is not available."""
        self._assert_r_language_rejected(executor_manager, "r-tuple")

    def test_reject_r_table_language(self, executor_manager):
        """Test that 'r-table' language is rejected with ImportError when plugin is not available."""
        self._assert_r_language_rejected(executor_manager, "r-table")

    def test_accept_r_tuple_language_with_plugin(self, executor_manager):
        """Test that 'r-tuple' language is accepted when plugin is available."""
        self._assert_r_language_accepted(
            executor_manager, "RTupleExecutor", is_source=False, language="r-tuple"
        )

    def test_accept_r_table_language_with_plugin(self, executor_manager):
        """Test that 'r-table' language is accepted when plugin is available."""
        self._assert_r_language_accepted(
            executor_manager, "RTableExecutor", is_source=False, language="r-table"
        )

    def test_accept_r_tuple_source_with_plugin(self, executor_manager):
        """Test that 'r-tuple' source operators work when plugin is available."""
        self._assert_r_language_accepted(
            executor_manager, "RTupleSourceExecutor", is_source=True, language="r-tuple"
        )

    def test_accept_r_table_source_with_plugin(self, executor_manager):
        """Test that 'r-table' source operators work when plugin is available."""
        self._assert_r_language_accepted(
            executor_manager, "RTableSourceExecutor", is_source=True, language="r-table"
        )

    def test_accept_python_language_regular_operator(self, executor_manager):
        """Test that 'python' language is accepted for regular operators."""
        executor_manager.initialize_executor(
            code=SAMPLE_OPERATOR_CODE, is_source=False, language="python"
        )
        assert executor_manager.executor is not None
        # Module names come from a process-wide counter. The shape is fixed,
        # but the exact value depends on which tests already ran in this
        # pytest session.
        assert executor_manager.operator_module_name is not None
        assert executor_manager.operator_module_name.startswith("udf-v")
        assert executor_manager.executor.is_source is False

    def test_accept_python_language_source_operator(self, executor_manager):
        """Test that 'python' language is accepted for source operators."""
        executor_manager.initialize_executor(
            code=SAMPLE_SOURCE_OPERATOR_CODE, is_source=True, language="python"
        )
        assert executor_manager.executor is not None
        assert executor_manager.operator_module_name is not None
        assert executor_manager.operator_module_name.startswith("udf-v")
        assert executor_manager.executor.is_source is True

    def test_reject_other_unsupported_languages(self, executor_manager):
        """Test that other arbitrary languages still work (no R-specific check)."""
        # Despite the test name, this checks that a non-R language is NOT
        # refused by an AssertionError. Other failures while executing the code
        # are tolerated.
        try:
            executor_manager.initialize_executor(
                code=SAMPLE_OPERATOR_CODE,
                is_source=False,
                language="javascript",  # arbitrary language
            )
        except AssertionError:
            pytest.fail("Should not raise AssertionError for non-R languages")
        except Exception:
            # Other exceptions (like import errors) are expected and acceptable.
            pass

    def test_gen_module_file_name_increments(self, executor_manager):
        """Test that module file names increment monotonically.
        The counter is process-wide so the absolute starting value
        depends on prior tests in the same pytest session; only the
        relative ordering matters for correctness.
        """
        module1, file1 = executor_manager.gen_module_file_name()
        module2, file2 = executor_manager.gen_module_file_name()
        module3, file3 = executor_manager.gen_module_file_name()

        def version(module_name: str) -> int:
            return int(module_name.removeprefix("udf-v"))

        v1 = version(module1)
        assert version(module2) == v1 + 1
        assert version(module3) == v1 + 2
        assert file1 == f"{module1}.py"
        assert file2 == f"{module2}.py"
        assert file3 == f"{module3}.py"

    def test_is_concrete_operator_static_method(self):
        """Test the is_concrete_operator static method."""
        from core.models import TupleOperatorV2

        assert hasattr(ExecutorManager, "is_concrete_operator")
        assert callable(ExecutorManager.is_concrete_operator)
        # Non-class inputs are never concrete operators.
        assert ExecutorManager.is_concrete_operator("not a class") is False
        assert ExecutorManager.is_concrete_operator(123) is False
        # Abstract base classes (TupleOperatorV2 has abstract methods) are not concrete.
        assert ExecutorManager.is_concrete_operator(TupleOperatorV2) is False

    def test_regular_operator_is_not_source(self, executor_manager):
        """Test that regular operator with is_source=False works correctly."""
        executor_manager.initialize_executor(
            code=SAMPLE_OPERATOR_CODE, is_source=False, language="python"
        )
        assert executor_manager.executor.is_source is False

    def test_source_operator_mismatch_raises_error(self, executor_manager):
        """Test that mismatched source operator flag raises AssertionError."""
        with pytest.raises(AssertionError) as exc_info:
            executor_manager.initialize_executor(
                code=SAMPLE_OPERATOR_CODE,
                is_source=True,  # Wrong: regular operator but marked as source
                language="python",
            )
        assert "SourceOperator API" in str(exc_info.value)
REPLACEMENT_OPERATOR_CODE = """
from pytexera import *
class ReplacementOperator(UDFOperatorV2):
def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
yield tuple_
"""
NO_OPERATOR_CODE = """
def helper():
return 42
"""
TWO_OPERATORS_CODE = """
from pytexera import *
class FirstOperator(UDFOperatorV2):
def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
yield tuple_
class SecondOperator(UDFOperatorV2):
def process_tuple(self, tuple_: Tuple, port: int) -> Iterator[Optional[TupleLike]]:
yield tuple_
"""
class TestUpdateExecutor:
    """Exercises ExecutorManager.update_executor (swapping in new operator code).

    Isolation caveat: the existing TestExecutorManager fixture cannot fully
    unload the ``udf-vN`` modules it imports, because its
    ``hasattr(manager, "_fs")`` teardown guard checks the wrong name (the
    cached_property key is ``fs``). A ``udf-v1`` module may therefore still sit
    in ``sys.modules`` with a path into an earlier test's tmp filesystem.
    To stay robust, these tests never depend on attributes defined on a
    particular operator class. They only set and read attributes on live
    instances, which behaves the same whichever cached module satisfied the
    import.
    """

    @pytest.fixture
    def initialized_manager(self):
        mgr = ExecutorManager()
        mgr.initialize_executor(
            code=SAMPLE_OPERATOR_CODE, is_source=False, language="python"
        )
        # Put state directly on the live executor instance so the
        # __dict__-carryover checks hold no matter which cached module
        # provided the operator class.
        live = mgr.executor
        live.runtime_field = "set-after-init"
        live.counter = 6
        yield mgr
        mgr.close()

    def test_update_preserves_pre_update_dict_state(self, initialized_manager):
        old_executor = initialized_manager.executor
        snapshot = dict(old_executor.__dict__)

        initialized_manager.update_executor(
            code=REPLACEMENT_OPERATOR_CODE, is_source=False
        )
        new_executor = initialized_manager.executor

        # update_executor instantiates a fresh operator but carries the old
        # instance's __dict__ over: check the new identity AND the old state.
        assert new_executor is not old_executor
        assert new_executor.runtime_field == "set-after-init"
        assert new_executor.counter == 6

        # Test membership before value, so a vanished key whose expected
        # value is None cannot pass via a default lookup.
        carried = new_executor.__dict__
        for key, expected in snapshot.items():
            assert key in carried, f"key {key!r} missing after update"
            assert carried[key] == expected

    def test_update_advances_module_name_monotonically(self, initialized_manager):
        # The udf-vN counter is shared across the process, so its absolute
        # value depends on earlier tests in the session. Only the +1 step
        # is asserted.
        def version_of(module_name):
            assert module_name is not None and module_name.startswith("udf-v")
            return int(module_name.removeprefix("udf-v"))

        old_version = version_of(initialized_manager.operator_module_name)
        initialized_manager.update_executor(
            code=REPLACEMENT_OPERATOR_CODE, is_source=False
        )
        new_version = version_of(initialized_manager.operator_module_name)
        assert new_version == old_version + 1

    def test_update_with_source_mismatch_raises_assertion(self, initialized_manager):
        # The replacement is a regular operator. Asking for a source operator
        # must trip the same guard that initialize_executor applies.
        with pytest.raises(AssertionError, match="SourceOperator API"):
            initialized_manager.update_executor(
                code=REPLACEMENT_OPERATOR_CODE, is_source=True
            )

    def test_update_with_no_operator_class_raises_assertion(self, initialized_manager):
        # load_executor_definition insists on exactly one Operator subclass,
        # so a module that defines none must be rejected.
        with pytest.raises(AssertionError, match="one and only one Operator"):
            initialized_manager.update_executor(code=NO_OPERATOR_CODE, is_source=False)

    def test_update_with_multiple_operator_classes_raises_assertion(
        self, initialized_manager
    ):
        with pytest.raises(AssertionError, match="one and only one Operator"):
            initialized_manager.update_executor(
                code=TWO_OPERATORS_CODE, is_source=False
            )

    def test_repeated_updates_keep_carrying_the_running_state(
        self, initialized_manager
    ):
        # After an update, mutate the new instance and update again. The second
        # swap must carry the *latest* state, not the pre-first-update snapshot.
        for round_no in range(2):
            initialized_manager.update_executor(
                code=REPLACEMENT_OPERATOR_CODE, is_source=False
            )
            if round_no == 0:
                initialized_manager.executor.counter = 42
                initialized_manager.executor.added_after_update = True

        assert initialized_manager.executor.counter == 42
        assert initialized_manager.executor.added_after_update is True
================================================
FILE: amber/src/test/python/core/architecture/managers/test_pause_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.architecture.managers import StateManager
from core.architecture.managers.pause_manager import PauseManager, PauseType
from core.models import InternalQueue
from proto.org.apache.texera.amber.engine.architecture.worker import WorkerState
class TestPauseManager:
    """Pause/resume bookkeeping of PauseManager for USER_PAUSE requests."""

    @pytest.fixture
    def input_queue(self):
        return InternalQueue()

    @pytest.fixture
    def state_manager(self):
        transitions = {
            WorkerState.UNINITIALIZED: {WorkerState.READY},
            WorkerState.READY: {WorkerState.PAUSED, WorkerState.RUNNING},
            WorkerState.RUNNING: {WorkerState.PAUSED, WorkerState.COMPLETED},
            WorkerState.PAUSED: {WorkerState.RUNNING},
            WorkerState.COMPLETED: set(),
        }
        # Begin at READY (rather than UNINITIALIZED) for testing purposes.
        return StateManager(transitions, WorkerState.READY)

    @pytest.fixture
    def pause_manager(self, input_queue, state_manager):
        return PauseManager(input_queue, state_manager)

    @staticmethod
    def _run(pause_manager, steps):
        """Apply each (action, expect_paused) step and verify the paused flag."""
        for action, expect_paused in steps:
            action(PauseType.USER_PAUSE)
            if expect_paused:
                assert pause_manager.is_paused()
            else:
                assert not pause_manager.is_paused()

    def test_it_can_init(self, pause_manager):
        # Building the fixture without raising is the whole check.
        pass

    def test_it_is_not_paused_initially(self, pause_manager):
        assert not pause_manager.is_paused()

    def test_it_can_be_paused_and_resumed(self, pause_manager):
        pm = pause_manager
        self._run(pm, [(pm.pause, True), (pm.resume, False)])

    def test_it_can_be_paused_when_paused(self, pause_manager):
        pm = pause_manager
        self._run(pm, [(pm.pause, True), (pm.pause, True), (pm.resume, False)])

    def test_it_can_be_resumed_when_resumed(self, pause_manager):
        pm = pause_manager
        self._run(pm, [(pm.pause, True), (pm.resume, False), (pm.resume, False)])
================================================
FILE: amber/src/test/python/core/architecture/managers/test_state_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.architecture.managers.state_manager import (
InvalidStateException,
InvalidTransitionException,
StateManager,
)
from proto.org.apache.texera.amber.engine.architecture.worker import WorkerState
class TestStateManager:
    """Transition and assertion rules enforced by StateManager."""

    @pytest.fixture
    def state_manager(self):
        allowed = {
            WorkerState.UNINITIALIZED: {WorkerState.READY},
            WorkerState.READY: {WorkerState.PAUSED, WorkerState.RUNNING},
            WorkerState.RUNNING: {WorkerState.PAUSED, WorkerState.COMPLETED},
            WorkerState.PAUSED: {WorkerState.RUNNING},
            WorkerState.COMPLETED: set(),
        }
        return StateManager(allowed, WorkerState.UNINITIALIZED)

    @staticmethod
    def _walk(state_manager, path):
        """Transit through every state in ``path``, checking each arrival."""
        for target in path:
            state_manager.transit_to(target)
            assert state_manager.confirm_state(target)
            state_manager.assert_state(target)

    def test_it_can_init(self, state_manager):
        pass

    def test_it_can_transit_to_defined_state(self, state_manager):
        state_manager.assert_state(WorkerState.UNINITIALIZED)
        self._walk(
            state_manager,
            [
                WorkerState.READY,
                WorkerState.PAUSED,
                WorkerState.RUNNING,
                WorkerState.COMPLETED,
            ],
        )

    def test_it_raises_exception_when_transit_to_undefined_state(self, state_manager):
        state_manager.assert_state(WorkerState.UNINITIALIZED)
        self._walk(state_manager, [WorkerState.READY, WorkerState.PAUSED])
        # PAUSED may only move to RUNNING, so READY is not a permitted successor.
        with pytest.raises(InvalidTransitionException):
            state_manager.transit_to(WorkerState.READY)

    def test_it_raises_exception_when_asserting_a_different_state(self, state_manager):
        state_manager.assert_state(WorkerState.UNINITIALIZED)
        self._walk(state_manager, [WorkerState.READY, WorkerState.PAUSED])
        with pytest.raises(InvalidStateException):
            state_manager.assert_state(WorkerState.COMPLETED)
================================================
FILE: amber/src/test/python/core/architecture/managers/test_state_processing_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from core.architecture.managers.state_processing_manager import StateProcessingManager
from core.models.state import State
class TestStateProcessingManager:
    """Consume-once semantics of StateProcessingManager's input/output slots."""

    def test_initial_state_is_none(self):
        mgr = StateProcessingManager()
        assert mgr.current_input_state is None
        assert mgr.current_output_state is None
        assert mgr.get_input_state() is None
        assert mgr.get_output_state() is None

    def test_get_input_state_returns_then_clears(self):
        # The contract is "consume-once": the first get returns the stashed
        # value, and later gets see None until the slot is set again.
        mgr = StateProcessingManager()
        s = State({"k": "v"})
        mgr.current_input_state = s
        assert mgr.get_input_state() is s
        assert mgr.current_input_state is None
        assert mgr.get_input_state() is None

    def test_get_output_state_returns_then_clears(self):
        mgr = StateProcessingManager()
        s = State({"k": "v"})
        mgr.current_output_state = s
        assert mgr.get_output_state() is s
        assert mgr.current_output_state is None
        assert mgr.get_output_state() is None

    def test_input_and_output_slots_are_independent(self):
        # Consuming one slot must not consume the other. Both orders are
        # checked so the no-cross-talk contract is pinned in each direction.
        mgr = StateProcessingManager()
        in_s = State({"side": "in"})
        out_s = State({"side": "out"})

        # Input first: the output slot survives.
        mgr.current_input_state = in_s
        mgr.current_output_state = out_s
        assert mgr.get_input_state() is in_s
        assert mgr.current_output_state is out_s
        assert mgr.get_output_state() is out_s

        # Output first: the input slot survives.
        mgr.current_input_state = in_s
        mgr.current_output_state = out_s
        assert mgr.get_output_state() is out_s
        assert mgr.current_input_state is in_s
        assert mgr.get_input_state() is in_s
================================================
FILE: amber/src/test/python/core/architecture/managers/test_statistics_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.architecture.managers.statistics_manager import StatisticsManager
from proto.org.apache.texera.amber.core import PortIdentity
def _port(pid: int) -> PortIdentity:
    """Build a non-internal PortIdentity carrying the given numeric id."""
    return PortIdentity(internal=False, id=pid)
class TestStatisticsManagerDefaults:
    def test_get_statistics_with_no_activity(self):
        stats = StatisticsManager().get_statistics()
        # No tuples were recorded on either side.
        for metrics in (stats.input_tuple_metrics, stats.output_tuple_metrics):
            assert list(metrics) == []
        # Every timing counter starts at zero.
        # idle_time = total_execution - data - control = 0 at init.
        for counter in (
            stats.data_processing_time,
            stats.control_processing_time,
            stats.idle_time,
        ):
            assert counter == 0
class TestStatisticsManagerInputOutput:
    """Per-port tuple count/size aggregation and size validation."""

    def test_increase_input_aggregates_count_and_size_per_port(self):
        mgr = StatisticsManager()
        mgr.increase_input_statistics(_port(0), 10)
        mgr.increase_input_statistics(_port(0), 5)
        mgr.increase_input_statistics(_port(1), 7)
        stats = mgr.get_statistics()
        by_port = {m.port_id.id: m.tuple_metrics for m in stats.input_tuple_metrics}
        assert by_port[0].count == 2
        assert by_port[0].size == 15
        assert by_port[1].count == 1
        assert by_port[1].size == 7
        # The output side stayed empty.
        assert list(stats.output_tuple_metrics) == []

    def test_increase_output_aggregates_count_and_size_per_port(self):
        mgr = StatisticsManager()
        mgr.increase_output_statistics(_port(2), 100)
        mgr.increase_output_statistics(_port(2), 200)
        stats = mgr.get_statistics()
        by_port = {m.port_id.id: m.tuple_metrics for m in stats.output_tuple_metrics}
        assert by_port[2].count == 2
        assert by_port[2].size == 300
        assert list(stats.input_tuple_metrics) == []

    def test_zero_size_input_is_allowed(self):
        # Pin: zero is valid (the size check rejects only negatives).
        # Empty tuples / heartbeat-style records can legitimately be size 0.
        mgr = StatisticsManager()
        mgr.increase_input_statistics(_port(0), 0)
        stats = mgr.get_statistics()
        m = list(stats.input_tuple_metrics)[0].tuple_metrics
        assert m.count == 1
        assert m.size == 0

    def test_zero_size_output_is_allowed(self):
        # The output path shares the "non-negative" rule pinned for negatives
        # below, so a zero-size output tuple must also be counted.
        mgr = StatisticsManager()
        mgr.increase_output_statistics(_port(0), 0)
        stats = mgr.get_statistics()
        m = list(stats.output_tuple_metrics)[0].tuple_metrics
        assert m.count == 1
        assert m.size == 0

    @pytest.mark.parametrize(
        "method", ["increase_input_statistics", "increase_output_statistics"]
    )
    def test_negative_size_raises(self, method):
        mgr = StatisticsManager()
        with pytest.raises(ValueError, match="Tuple size must be non-negative"):
            getattr(mgr, method)(_port(0), -1)
class TestStatisticsManagerProcessingTime:
    """Accumulation and validation of data/control processing time."""

    def test_data_and_control_time_accumulate(self):
        mgr = StatisticsManager()
        for delta in (100, 50):
            mgr.increase_data_processing_time(delta)
        mgr.increase_control_processing_time(20)
        stats = mgr.get_statistics()
        assert (stats.data_processing_time, stats.control_processing_time) == (150, 20)

    def test_zero_processing_time_is_allowed(self):
        mgr = StatisticsManager()
        for bump in (mgr.increase_data_processing_time, mgr.increase_control_processing_time):
            bump(0)
        stats = mgr.get_statistics()
        assert (stats.data_processing_time, stats.control_processing_time) == (0, 0)

    @pytest.mark.parametrize(
        "method",
        ["increase_data_processing_time", "increase_control_processing_time"],
    )
    def test_negative_time_raises(self, method):
        bump = getattr(StatisticsManager(), method)
        with pytest.raises(ValueError, match="Time must be non-negative"):
            bump(-1)
class TestStatisticsManagerExecutionTime:
    """Total execution / idle time measured from the recorded worker start."""

    @staticmethod
    def _started_at(start):
        """A fresh StatisticsManager whose worker start time is ``start``."""
        mgr = StatisticsManager()
        mgr.initialize_worker_start_time(start)
        return mgr

    def test_total_execution_time_is_relative_to_worker_start(self):
        mgr = self._started_at(1_000)
        mgr.update_total_execution_time(1_500)
        # idle = total_exec - data - control = 500 - 0 - 0
        assert mgr.get_statistics().idle_time == 500

    def test_total_execution_time_equal_to_start_is_allowed(self):
        # Only `time < start` is rejected, so equality is accepted and yields 0.
        mgr = self._started_at(1_000)
        mgr.update_total_execution_time(1_000)
        assert mgr.get_statistics().idle_time == 0

    def test_total_execution_time_before_start_raises(self):
        mgr = self._started_at(1_000)
        expected = "Current time must be greater than or equal to worker start time"
        with pytest.raises(ValueError, match=expected):
            mgr.update_total_execution_time(999)

    def test_idle_time_clamped_to_zero_when_processing_overshoots(self):
        # If data + control exceed total_execution_time (e.g. update_total ran
        # before all increase_* calls for the interval), idle_time is clamped
        # to 0 rather than going negative.
        mgr = self._started_at(1_000)
        mgr.update_total_execution_time(1_100)  # 100 total
        mgr.increase_data_processing_time(80)
        mgr.increase_control_processing_time(50)  # 130 > 100
        assert mgr.get_statistics().idle_time == 0
================================================
FILE: amber/src/test/python/core/architecture/managers/test_tuple_processing_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from threading import Condition, Event
from core.architecture.managers.tuple_processing_manager import TupleProcessingManager
from core.models import InternalMarker
from proto.org.apache.texera.amber.core import PortIdentity
class TestTupleProcessingManager:
    """Slot/consume-once behaviour and sync primitives of TupleProcessingManager."""

    def test_initial_state(self):
        mgr = TupleProcessingManager()
        assert mgr.current_input_tuple is None
        assert mgr.current_input_port_id is None
        assert mgr.current_input_tuple_iter is None
        assert mgr.current_output_tuple is None
        assert mgr.current_internal_marker is None
        assert isinstance(mgr.context_switch_condition, Condition)
        assert isinstance(mgr.finished_current, Event)
        assert mgr.finished_current.is_set() is False

    def test_get_internal_marker_consume_once(self):
        mgr = TupleProcessingManager()
        marker = InternalMarker()
        mgr.current_internal_marker = marker
        assert mgr.get_internal_marker() is marker
        assert mgr.current_internal_marker is None
        assert mgr.get_internal_marker() is None

    def test_get_input_tuple_consume_once(self):
        mgr = TupleProcessingManager()
        sentinel = object()
        mgr.current_input_tuple = sentinel
        assert mgr.get_input_tuple() is sentinel
        assert mgr.current_input_tuple is None
        assert mgr.get_input_tuple() is None

    def test_get_output_tuple_consume_once(self):
        mgr = TupleProcessingManager()
        sentinel = object()
        mgr.current_output_tuple = sentinel
        assert mgr.get_output_tuple() is sentinel
        assert mgr.current_output_tuple is None
        assert mgr.get_output_tuple() is None

    def test_get_input_port_id_returns_zero_when_unset(self):
        # Documented "no upstream / source executor" fallback. Worth pinning
        # because it conflates "unset" with "real port id 0"; the port-zero
        # collision test in this class exposes that.
        mgr = TupleProcessingManager()
        assert mgr.current_input_port_id is None
        assert mgr.get_input_port_id() == 0

    def test_get_input_port_id_returns_real_port_id(self):
        mgr = TupleProcessingManager()
        mgr.current_input_port_id = PortIdentity(id=7, internal=False)
        assert mgr.get_input_port_id() == 7

    def test_get_input_port_id_collides_for_port_zero(self):
        # Pin: a real port with id=0 is indistinguishable from the
        # "no upstream" sentinel. If callers ever need to tell them apart,
        # the API must change; this test guards the current behaviour so
        # any future fix breaks it deliberately.
        mgr = TupleProcessingManager()
        mgr.current_input_port_id = PortIdentity(id=0, internal=False)
        assert mgr.get_input_port_id() == 0
        # The sentinel path also returns 0.
        mgr.current_input_port_id = None
        assert mgr.get_input_port_id() == 0

    def test_finished_current_event_can_be_signalled(self):
        mgr = TupleProcessingManager()
        mgr.finished_current.set()
        assert mgr.finished_current.is_set() is True
        mgr.finished_current.clear()
        assert mgr.finished_current.is_set() is False

    def test_input_tuple_does_not_clear_output_or_marker(self):
        # Consuming the input slot must clear exactly that slot: it returns
        # the stored value, empties itself, and leaves the output tuple and
        # the very same marker object in place.
        mgr = TupleProcessingManager()
        marker = InternalMarker()
        mgr.current_input_tuple = "in"
        mgr.current_output_tuple = "out"
        mgr.current_internal_marker = marker
        assert mgr.get_input_tuple() == "in"
        assert mgr.current_input_tuple is None
        assert mgr.current_output_tuple == "out"
        assert mgr.current_internal_marker is marker
================================================
FILE: amber/src/test/python/core/architecture/rpc/test_async_rpc_client.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import inspect
from concurrent.futures import Future
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.architecture.rpc import async_rpc_client as async_rpc_client_module
from core.architecture.rpc.async_rpc_client import AsyncRPCClient, async_run
from proto.org.apache.texera.amber.core import (
ActorVirtualIdentity,
ChannelIdentity,
)
from proto.org.apache.texera.amber.engine.architecture.rpc import (
ControllerServiceStub,
ControlReturn,
ReturnInvocation,
)
def _make_client():
    """Return an AsyncRPCClient wired to throwaway collaborators.

    Along the send path the constructor touches only ``context.worker_id``
    and ``output_queue.put``. A MagicMock queue plus a duck-typed
    SimpleNamespace context is therefore enough.
    """
    output_queue = MagicMock()
    context = SimpleNamespace(worker_id="w0")
    return AsyncRPCClient(output_queue, context)
class TestAsyncRunDecorator:
    """The async_run wrapper behaves differently with and without a running loop."""

    def test_runs_coroutine_via_asyncio_run_when_no_loop(self):
        @async_run
        async def answer():
            return 42

        # Plain synchronous caller: no loop is running, so the wrapper takes
        # its RuntimeError branch and dispatches through asyncio.run.
        result = answer()
        assert result == 42

    def test_returns_awaitable_directly_when_called_inside_running_loop(self):
        @async_run
        async def answer():
            return "deep"

        async def main():
            # With a loop already running, the wrapper just calls through and
            # hands back the coroutine, leaving the await to the caller.
            pending = answer()
            assert asyncio.iscoroutine(pending)
            return await pending

        assert asyncio.run(main()) == "deep"
class TestCreateFuture:
    """Promise registration and per-target sequence numbering in _create_future."""

    def test_returns_future_instance(self):
        fut = _make_client()._create_future(ActorVirtualIdentity(name="dest"))
        assert isinstance(fut, Future)

    def test_records_promise_at_pre_increment_sequence_and_then_increments(self):
        client = _make_client()
        dest = ActorVirtualIdentity(name="dest")
        # _send_sequences starts at 0 (defaultdict(int)). The promise is stored
        # under the current number before the counter is bumped, so the very
        # first one lives at (dest, 0).
        fut = client._create_future(dest)
        assert client._unfulfilled_promises[(dest, 0)] is fut
        assert client._send_sequences[dest] == 1

    def test_sequence_increments_per_target_independently(self):
        client = _make_client()
        a = ActorVirtualIdentity(name="A")
        b = ActorVirtualIdentity(name="B")
        for target in (a, a, b):
            client._create_future(target)
        assert client._send_sequences[a] == 2
        assert client._send_sequences[b] == 1
        for key in ((a, 0), (a, 1), (b, 0)):
            assert key in client._unfulfilled_promises
class TestFulfillPromise:
    """Resolution of pending promises by _fulfill_promise."""

    def _channel(self, name: str) -> ChannelIdentity:
        # _fulfill_promise looks the promise up by `from_.from_worker_id`, so
        # the sender slot must match the actor the promise was created for.
        sender = ActorVirtualIdentity(name=name)
        receiver = ActorVirtualIdentity(name="self")
        return ChannelIdentity(
            from_worker_id=sender, to_worker_id=receiver, is_control=True
        )

    def test_resolves_matching_future_and_clears_the_entry(self):
        client = _make_client()
        actor = ActorVirtualIdentity(name="A")
        pending = client._create_future(actor)
        expected = ControlReturn()

        client._fulfill_promise(self._channel("A"), command_id=0, control_return=expected)

        assert pending.done()
        assert pending.result() is expected
        assert (actor, 0) not in client._unfulfilled_promises

    def test_silently_logs_when_no_matching_promise_exists(self, monkeypatch):
        client = _make_client()
        # An unrelated pending promise lets us verify that the no-match branch
        # leaves it alone rather than dropping the dict entry.
        actor_b = ActorVirtualIdentity(name="B")
        fut_b = client._create_future(actor_b)

        # Replace the loguru logger's `warning` inside async_rpc_client so we
        # can prove the no-match branch emits a warning instead of silently
        # discarding the unknown ControlReturn.
        captured = []

        def fake_warning(msg, *args, **kwargs):
            captured.append(msg)

        monkeypatch.setattr(async_rpc_client_module.logger, "warning", fake_warning)

        # Nothing was ever promised to actor "A"; the call must not raise.
        client._fulfill_promise(
            self._channel("A"), command_id=99, control_return=ControlReturn()
        )

        assert len(captured) == 1
        assert "no corresponding ControlCommand found" in captured[0]
        # The unrelated promise is untouched.
        assert not fut_b.done()
        assert (actor_b, 0) in client._unfulfilled_promises

    def test_does_not_disturb_unrelated_pending_promises(self):
        client = _make_client()
        actor_a, actor_b = (ActorVirtualIdentity(name=n) for n in ("A", "B"))
        fut_a = client._create_future(actor_a)
        fut_b = client._create_future(actor_b)

        client._fulfill_promise(
            self._channel("A"), command_id=0, control_return=ControlReturn()
        )

        assert fut_a.done()
        assert not fut_b.done()
        assert (actor_b, 0) in client._unfulfilled_promises
class TestReceive:
    def test_delegates_command_id_and_return_value_to_fulfill_promise(self):
        client = _make_client()
        sender = ActorVirtualIdentity(name="A")
        pending = client._create_future(sender)
        expected = ControlReturn()
        channel = ChannelIdentity(
            from_worker_id=sender,
            to_worker_id=ActorVirtualIdentity(name="self"),
            is_control=True,
        )

        # receive() should route the ReturnInvocation to the promise keyed
        # by (sender, command_id).
        client.receive(channel, ReturnInvocation(command_id=0, return_value=expected))

        assert pending.done()
        assert pending.result() is expected
class TestProxyStreamBlockers:
    """Streaming RPC shapes raise NotImplementedError on the worker proxy."""

    @staticmethod
    def _assert_blocked(method_name):
        proxy = _make_client().get_worker_interface("worker-X")
        with pytest.raises(NotImplementedError, match=method_name):
            getattr(proxy, method_name)()

    def test_stream_unary_blocked(self):
        self._assert_blocked("_stream_unary")

    def test_unary_stream_blocked(self):
        self._assert_blocked("_unary_stream")

    def test_stream_stream_blocked(self):
        self._assert_blocked("_stream_stream")
class TestControllerStub:
    """The controller stub AsyncRPCClient exposes is configured, not stock."""

    def test_controller_stub_returns_configured_stub(self):
        client = _make_client()
        first = client.controller_stub()
        second = client.controller_stub()
        # The same instance every time (configured once in __init__).
        assert first is client._controller_service_stub
        assert first is second

    def test_controller_stub_unary_unary_is_rewired_with_async_context(self):
        # __init__ swaps the stub's `_unary_unary` for the closure built by
        # `_assign_context`, and `_wrap_all_async_methods` then wraps that
        # (originally async) function with `async_run`. The handler is
        # therefore no longer ControllerServiceStub's bound method. An
        # unconfigured stub could pass the identity test above, but not this.
        stub = _make_client().controller_stub()
        stock = ControllerServiceStub("")
        assert stub._unary_unary is not stock._unary_unary

        # The `_assign_context` closure should appear in the qualname, either
        # directly or through `__wrapped__` if async_run preserved it.
        handler = stub._unary_unary
        inner = getattr(handler, "__wrapped__", handler)
        assert "_assign_context" in inner.__qualname__

    def test_controller_stub_async_methods_are_wrapped_with_async_run(self):
        # `_wrap_all_async_methods_with_async_run` replaces every coroutine
        # function on the stub with the sync `async_run` wrapper. Each method
        # that is async on a stock ControllerServiceStub must therefore be
        # non-coroutine on the configured one; otherwise the pass silently
        # did nothing.
        stub = _make_client().controller_stub()
        stock = ControllerServiceStub("")

        def is_public_coroutine(attr):
            return not attr.startswith("_") and inspect.iscoroutinefunction(
                getattr(stock, attr)
            )

        async_method_names = list(filter(is_public_coroutine, dir(stock)))

        # Sanity: the stock stub really does ship async methods.
        assert async_method_names, (
            "ControllerServiceStub no longer has any async methods; this test "
            "needs to be reconsidered."
        )
        for name in async_method_names:
            assert not inspect.iscoroutinefunction(getattr(stub, name)), (
                f"{name!r} on the configured stub should have been wrapped by "
                "async_run but is still a coroutine function."
            )
================================================
FILE: amber/src/test/python/core/architecture/sendsemantics/test_partitioners.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.architecture.sendsemantics.broad_cast_partitioner import BroadcastPartitioner
from core.architecture.sendsemantics.hash_based_shuffle_partitioner import (
HashBasedShufflePartitioner,
)
from core.architecture.sendsemantics.one_to_one_partitioner import OneToOnePartitioner
from core.architecture.sendsemantics.range_based_shuffle_partitioner import (
RangeBasedShufflePartitioner,
)
from core.architecture.sendsemantics.round_robin_partitioner import (
RoundRobinPartitioner,
)
from core.models import Tuple
from core.models.schema.schema import Schema
from core.models.state import State
from proto.org.apache.texera.amber.core import (
ActorVirtualIdentity,
ChannelIdentity,
)
from proto.org.apache.texera.amber.engine.architecture.rpc import (
EmbeddedControlMessage,
)
from proto.org.apache.texera.amber.engine.architecture.sendsemantics import (
BroadcastPartitioning,
HashBasedShufflePartitioning,
OneToOnePartitioning,
RangeBasedShufflePartitioning,
RoundRobinPartitioning,
)
# Schema for tuples built by _hashable_tuple: an INTEGER "k" and a STRING "v".
_HASHABLE_SCHEMA = Schema(raw_schema={"k": "INTEGER", "v": "STRING"})


def _worker(name: str) -> ActorVirtualIdentity:
    """Actor identity with the given name."""
    return ActorVirtualIdentity(name=name)


def _channel(src: str, dst: str) -> ChannelIdentity:
    """Channel from the worker named ``src`` to the worker named ``dst``."""
    sender = _worker(src)
    receiver = _worker(dst)
    return ChannelIdentity(from_worker_id=sender, to_worker_id=receiver)


def _tuple(**fields) -> Tuple:
    """Tuple built from keyword fields, with no schema argument passed."""
    return Tuple(fields)


def _hashable_tuple(**fields) -> Tuple:
    """Tuple built from keyword fields and bound to _HASHABLE_SCHEMA."""
    return Tuple(fields, schema=_HASHABLE_SCHEMA)
def _snapshot(generator):
# Several partitioners yield the receiver's pending batch by reference and
# then clear it in the next statement of the generator. Snapshot list
# payloads at yield time so tests see what the caller would see when
# iterating tuple-by-tuple.
out = []
for item in generator:
out.append(list(item) if isinstance(item, list) else item)
return out
class TestBroadcastPartitioner:
    """BroadcastPartitioner: one shared batch that is emitted to every receiver."""

    @pytest.fixture
    def partitioner(self):
        receivers_ab = [_channel("S", "A"), _channel("S", "B")]
        return BroadcastPartitioner(
            BroadcastPartitioning(batch_size=2, channels=receivers_ab)
        )

    def test_init_collects_unique_receivers(self):
        # "S -> A" is listed twice; the receiver set must still be {A, B}.
        channels = [_channel("S", "A"), _channel("S", "B"), _channel("S", "A")]
        p = BroadcastPartitioner(
            BroadcastPartitioning(batch_size=4, channels=channels)
        )
        assert p.batch_size == 4
        assert set(p.receivers) == {_worker("A"), _worker("B")}
        assert p.batch == []

    def test_add_tuple_below_batch_size_yields_nothing(self, partitioner):
        emitted = list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        assert emitted == []
        # The tuple stays buffered until the batch fills up.
        assert partitioner.batch == [_tuple(k=1)]

    def test_add_tuple_at_batch_size_emits_to_every_receiver_and_resets(
        self, partitioner
    ):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        emitted = list(partitioner.add_tuple_to_batch(_tuple(k=2)))
        assert {receiver for receiver, _ in emitted} == {_worker("A"), _worker("B")}
        expected_batch = [_tuple(k=1), _tuple(k=2)]
        for _, batch in emitted:
            assert batch == expected_batch
        assert partitioner.batch == []

    def test_flush_emits_pending_batch_and_ecm_only_to_target(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        marker = EmbeddedControlMessage()
        flushed = list(partitioner.flush(_worker("A"), marker))
        # Pending batch first, then the control message.
        assert flushed == [[_tuple(k=1)], marker]
        assert partitioner.batch == []

    def test_flush_with_empty_batch_emits_only_ecm_for_target(self, partitioner):
        marker = EmbeddedControlMessage()
        assert list(partitioner.flush(_worker("A"), marker)) == [marker]

    def test_flush_to_non_receiver_emits_nothing(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        marker = EmbeddedControlMessage()
        # "Z" is not one of this partitioner's receivers.
        assert list(partitioner.flush(_worker("Z"), marker)) == []

    def test_flush_state_emits_pending_batch_and_state_to_every_receiver(
        self, partitioner
    ):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        state = State()
        emitted = list(partitioner.flush_state(state))
        everyone = {_worker("A"), _worker("B")}
        got_batch = {r for r, payload in emitted if payload == [_tuple(k=1)]}
        got_state = {r for r, payload in emitted if payload is state}
        assert got_batch == everyone
        assert got_state == everyone
        assert partitioner.batch == []

    def test_reset_clears_pending_batch(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        partitioner.reset()
        assert partitioner.batch == []
class TestRoundRobinPartitioner:
    """RoundRobinPartitioner: each receiver has its own batch, filled in turn."""

    @pytest.fixture
    def partitioner(self):
        return RoundRobinPartitioner(
            RoundRobinPartitioning(
                batch_size=2,
                channels=[_channel("S", "A"), _channel("S", "B"), _channel("S", "C")],
            )
        )

    def test_init_preserves_channel_order(self, partitioner):
        assert [r for r, _ in partitioner.receivers] == [
            _worker("A"),
            _worker("B"),
            _worker("C"),
        ]
        assert partitioner.round_robin_index == 0

    def test_init_dedupes_duplicate_channels_preserving_first_seen_order(self):
        p = RoundRobinPartitioner(
            RoundRobinPartitioning(
                batch_size=2,
                channels=[
                    _channel("S", "B"),
                    _channel("S", "A"),
                    _channel("S", "B"),
                ],
            )
        )
        assert [r for r, _ in p.receivers] == [_worker("B"), _worker("A")]

    def test_index_advances_modulo_receivers(self, partitioner):
        outs = []
        for tup in (_tuple(k=1), _tuple(k=2), _tuple(k=3), _tuple(k=4)):
            outs.extend(partitioner.add_tuple_to_batch(tup))
        # 4 tuples across 3 receivers: the index wraps around to 4 % 3 == 1.
        assert partitioner.round_robin_index == 1
        # Tuples land in A, B, C, A. Only A is hit twice, so only A's batch
        # reaches batch_size=2 and is emitted (on the second hit). B and C
        # each still hold their single pending tuple.
        assert [receiver for receiver, _ in outs] == [_worker("A")]
        assert partitioner.receivers[1][1] == [_tuple(k=2)]
        assert partitioner.receivers[2][1] == [_tuple(k=3)]

    def test_emits_batch_when_a_receiver_slot_fills(self, partitioner):
        outs = []
        for tup in (_tuple(k=1), _tuple(k=2), _tuple(k=3), _tuple(k=4)):
            outs.extend(list(partitioner.add_tuple_to_batch(tup)))
        # Tuple #4 lands in receiver A again (index 0) → batch [k=1, k=4] of size 2
        assert outs == [(_worker("A"), [_tuple(k=1), _tuple(k=4)])]

    def test_flush_emits_pending_batch_and_ecm_for_target_only(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))  # → A
        list(partitioner.add_tuple_to_batch(_tuple(k=2)))  # → B
        ecm = EmbeddedControlMessage()
        a_out = _snapshot(partitioner.flush(_worker("A"), ecm))
        assert a_out == [[_tuple(k=1)], ecm]
        # A's batch is now drained; B's pending batch remains untouched.
        assert partitioner.receivers[1][1] == [_tuple(k=2)]

    def test_flush_to_unknown_receiver_emits_nothing(self, partitioner):
        ecm = EmbeddedControlMessage()
        assert list(partitioner.flush(_worker("Z"), ecm)) == []

    def test_flush_state_emits_pending_batches_and_state_for_each_receiver(
        self, partitioner
    ):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))  # → A
        list(partitioner.add_tuple_to_batch(_tuple(k=2)))  # → B
        state = State()
        out = []
        for receiver, payload in partitioner.flush_state(state):
            # Copy list payloads at yield time; they may be cleared afterwards.
            snap = list(payload) if isinstance(payload, list) else payload
            out.append((receiver, snap))
        # A and B emit (batch, state); C only emits state.
        assert (_worker("A"), [_tuple(k=1)]) in out
        assert (_worker("B"), [_tuple(k=2)]) in out
        assert (_worker("A"), state) in out
        assert (_worker("B"), state) in out
        assert (_worker("C"), state) in out
class TestHashBasedShufflePartitioner:
    """HashBasedShufflePartitioner: routes tuples to receivers by hashed key."""

    def _partitioner(self, batch_size=10, hash_keys=("k",)):
        return HashBasedShufflePartitioner(
            HashBasedShufflePartitioning(
                batch_size=batch_size,
                channels=[_channel("S", "A"), _channel("S", "B")],
                hash_attribute_names=list(hash_keys),
            )
        )

    def test_same_key_routes_to_same_receiver_deterministically(self):
        p1 = self._partitioner()
        p2 = self._partitioner()
        # Drive each with the same tuple; routing is deterministic per process,
        # so two independent partitioners must place the tuple in the same slot.
        list(p1.add_tuple_to_batch(_hashable_tuple(k=42, v="x")))
        list(p2.add_tuple_to_batch(_hashable_tuple(k=42, v="x")))
        nonempty1 = [(r, b) for r, b in p1.receivers if b]
        nonempty2 = [(r, b) for r, b in p2.receivers if b]
        # Check both sides before indexing so a mismatch fails as an
        # assertion rather than an IndexError.
        assert len(nonempty1) == 1
        assert len(nonempty2) == 1
        assert nonempty1[0][0] == nonempty2[0][0]

    def test_full_batch_yields_and_clears_only_that_slot(self):
        p = self._partitioner(batch_size=2)
        outs = _snapshot(
            x
            for tup in (_hashable_tuple(k=7) for _ in range(5))
            for x in p.add_tuple_to_batch(tup)
        )
        # Five tuples with the same key all route to one receiver. With
        # batch_size=2 that gives exactly two full batches plus one leftover.
        assert len(outs) == 2
        assert len({receiver for receiver, _ in outs}) == 1
        for _, batch in outs:
            assert len(batch) == p.batch_size
        assert sum(len(batch) for _, batch in p.receivers) == 1
        # After a yield the slot's batch is replaced with a fresh empty list,
        # so no receiver slot may exceed batch_size at any observation point.
        for _, batch in p.receivers:
            assert len(batch) < p.batch_size

    def test_no_hash_attribute_names_falls_back_to_whole_tuple(self):
        p = self._partitioner(hash_keys=())
        list(p.add_tuple_to_batch(_hashable_tuple(k=1, v="a")))
        list(p.add_tuple_to_batch(_hashable_tuple(k=2, v="b")))
        total = sum(len(b) for _, b in p.receivers)
        assert total == 2

    def test_flush_emits_pending_batch_and_ecm_for_target_only(self):
        p = self._partitioner(batch_size=10)
        # Force a tuple into receiver A regardless of hash outcome.
        p.receivers[0] = (p.receivers[0][0], [_hashable_tuple(k=1)])
        ecm = EmbeddedControlMessage()
        a_out = _snapshot(p.flush(p.receivers[0][0], ecm))
        assert a_out == [[_hashable_tuple(k=1)], ecm]

    def test_flush_state_emits_pending_batches_and_state(self):
        p = self._partitioner(batch_size=10)
        p.receivers[0] = (p.receivers[0][0], [_hashable_tuple(k=1)])
        state = State()
        out = []
        for receiver, payload in p.flush_state(state):
            # Copy list payloads at yield time; they may be cleared afterwards.
            snap = list(payload) if isinstance(payload, list) else payload
            out.append((receiver, snap))
        assert (p.receivers[0][0], [_hashable_tuple(k=1)]) in out
        # Each receiver still emits the state record.
        assert sum(1 for r, payload in out if payload is state) == len(p.receivers)
class TestRangeBasedShufflePartitioner:
    """RangeBasedShufflePartitioner: routes tuples by the value range of a key."""

    @pytest.fixture
    def partitioner(self):
        return RangeBasedShufflePartitioner(
            RangeBasedShufflePartitioning(
                batch_size=10,
                channels=[
                    _channel("S", "A"),
                    _channel("S", "B"),
                    _channel("S", "C"),
                ],
                range_attribute_names=["k"],
                range_min=0,
                range_max=9,
            )
        )

    def test_keys_per_receiver_partitions_range_evenly(self, partitioner):
        # (range_max - range_min) // num_receivers + 1 = (9 - 0) // 3 + 1 = 4
        assert partitioner.keys_per_receiver == 4

    def test_value_below_range_min_routes_to_first_receiver(self, partitioner):
        assert partitioner.get_receiver_index(-100) == 0

    def test_value_above_range_max_routes_to_last_receiver(self, partitioner):
        assert partitioner.get_receiver_index(10**6) == 2

    def test_value_in_range_routes_by_quotient(self, partitioner):
        # keys_per_receiver = 4 → indices: 0..3 → 0, 4..7 → 1, 8..9 (capped) → 2
        assert partitioner.get_receiver_index(0) == 0
        assert partitioner.get_receiver_index(3) == 0
        assert partitioner.get_receiver_index(4) == 1
        assert partitioner.get_receiver_index(7) == 1
        assert partitioner.get_receiver_index(8) == 2
        assert partitioner.get_receiver_index(9) == 2

    def test_add_tuple_routes_using_first_attribute(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=2)))
        list(partitioner.add_tuple_to_batch(_tuple(k=5)))
        list(partitioner.add_tuple_to_batch(_tuple(k=8)))
        receivers_to_batches = {r.name: b for r, b in partitioner.receivers}
        assert receivers_to_batches["A"] == [_tuple(k=2)]
        assert receivers_to_batches["B"] == [_tuple(k=5)]
        assert receivers_to_batches["C"] == [_tuple(k=8)]

    def test_full_batch_yields_and_clears_only_that_slot(self):
        # Two receivers over 0..9 → keys_per_receiver = 9 // 2 + 1 = 5.
        p = RangeBasedShufflePartitioner(
            RangeBasedShufflePartitioning(
                batch_size=2,
                channels=[_channel("S", "A"), _channel("S", "B")],
                range_attribute_names=["k"],
                range_min=0,
                range_max=9,
            )
        )
        outs = []
        for v in (1, 2, 3):  # all route to receiver A (idx 0)
            outs.extend(list(p.add_tuple_to_batch(_tuple(k=v))))
        # The second tuple fills A's batch (size 2). That batch is yielded
        # exactly once, and then A's slot is reset.
        assert outs == [(_worker("A"), [_tuple(k=1), _tuple(k=2)])]
        # After the reset, A's slot holds only the third tuple.
        assert p.receivers[0][1] == [_tuple(k=3)]

    def test_flush_emits_pending_batch_and_ecm_for_target_only(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=2)))  # → A
        list(partitioner.add_tuple_to_batch(_tuple(k=5)))  # → B
        ecm = EmbeddedControlMessage()
        a_out = _snapshot(partitioner.flush(_worker("A"), ecm))
        assert a_out == [[_tuple(k=2)], ecm]
        # B is untouched.
        assert partitioner.receivers[1][1] == [_tuple(k=5)]

    def test_flush_state_emits_pending_batches_and_state(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=2)))  # → A
        state = State()
        out = []
        for receiver, payload in partitioner.flush_state(state):
            # Copy list payloads at yield time; they may be cleared afterwards.
            snap = list(payload) if isinstance(payload, list) else payload
            out.append((receiver, snap))
        assert (_worker("A"), [_tuple(k=2)]) in out
        # Every receiver still emits the state, even with empty pending batch.
        assert sum(1 for r, payload in out if payload is state) == 3
class TestOneToOnePartitioner:
    """OneToOnePartitioner: sends everything to the channel whose sender is worker_id."""

    @pytest.fixture
    def partitioner(self):
        partitioning = OneToOnePartitioning(
            batch_size=2,
            channels=[_channel("OTHER", "X"), _channel("S", "A")],
        )
        return OneToOnePartitioner(partitioning, worker_id="S")

    def test_init_picks_receiver_matching_worker_id(self, partitioner):
        # Only the "S -> A" channel originates from worker "S".
        assert partitioner.receiver == _worker("A")

    def test_add_tuple_below_batch_yields_nothing(self, partitioner):
        assert list(partitioner.add_tuple_to_batch(_tuple(k=1))) == []
        assert partitioner.batch == [_tuple(k=1)]

    def test_add_tuple_at_batch_yields_to_unique_receiver_and_resets(self, partitioner):
        per_call = [list(partitioner.add_tuple_to_batch(_tuple(k=k))) for k in (1, 2)]
        # Only the second call completes the batch.
        assert per_call[-1] == [(_worker("A"), [_tuple(k=1), _tuple(k=2)])]
        assert partitioner.batch == []

    def test_flush_emits_pending_batch_then_ecm(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        marker = EmbeddedControlMessage()
        flushed = list(partitioner.flush(_worker("A"), marker))
        assert flushed == [[_tuple(k=1)], marker]
        assert partitioner.batch == []

    def test_flush_with_empty_batch_emits_only_ecm(self, partitioner):
        marker = EmbeddedControlMessage()
        flushed = list(partitioner.flush(_worker("A"), marker))
        assert flushed == [marker]

    def test_flush_state_emits_pending_batch_then_state(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        state = State()
        receiver_a = _worker("A")
        expected = [(receiver_a, [_tuple(k=1)]), (receiver_a, state)]
        assert list(partitioner.flush_state(state)) == expected
        assert partitioner.batch == []

    def test_reset_clears_pending_batch(self, partitioner):
        list(partitioner.add_tuple_to_batch(_tuple(k=1)))
        partitioner.reset()
        assert partitioner.batch == []
================================================
FILE: amber/src/test/python/core/models/schema/test_schema.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pyarrow as pa
import pytest
from core.models.schema.attribute_type import AttributeType
from core.models.schema.schema import Schema
class TestSchema:
    """Schema: construction, raw/Arrow conversion, and LARGE_BINARY metadata."""

    @pytest.fixture
    def raw_schema(self):
        type_names = [
            "STRING",
            "INTEGER",
            "LONG",
            "DOUBLE",
            "BOOLEAN",
            "TIMESTAMP",
            "BINARY",
        ]
        return {f"field-{i}": name for i, name in enumerate(type_names, start=1)}

    @pytest.fixture
    def arrow_schema(self):
        arrow_types = [
            pa.string(),
            pa.int32(),
            pa.int64(),
            pa.float64(),
            pa.bool_(),
            pa.timestamp("us"),
            pa.binary(),
        ]
        return pa.schema(
            [pa.field(f"field-{i}", t) for i, t in enumerate(arrow_types, start=1)]
        )

    @pytest.fixture
    def schema(self):
        built = Schema()
        attr_types = [
            AttributeType.STRING,
            AttributeType.INT,
            AttributeType.LONG,
            AttributeType.DOUBLE,
            AttributeType.BOOL,
            AttributeType.TIMESTAMP,
            AttributeType.BINARY,
        ]
        for i, attr_type in enumerate(attr_types, start=1):
            built.add(f"field-{i}", attr_type)
        return built

    def test_accessors_and_mutators(self, schema):
        expected_pairs = [
            ("field-1", AttributeType.STRING),
            ("field-2", AttributeType.INT),
            ("field-3", AttributeType.LONG),
            ("field-4", AttributeType.DOUBLE),
            ("field-5", AttributeType.BOOL),
            ("field-6", AttributeType.TIMESTAMP),
            ("field-7", AttributeType.BINARY),
        ]
        assert schema.get_attr_names() == [name for name, _ in expected_pairs]
        assert schema.get_attr_type("field-2") == AttributeType.INT
        assert schema.get_attr_type("field-6") == AttributeType.TIMESTAMP
        assert schema.as_key_value_pairs() == expected_pairs
        # Looking up an unknown attribute raises KeyError.
        with pytest.raises(KeyError):
            schema.get_attr_type("does not exist")
        # Subscript assignment and subscript access are both unsupported.
        with pytest.raises(TypeError):
            schema["illegal_assign"] = "value"
        with pytest.raises(TypeError):
            _ = schema["illegal_access"]
        # Adding an attribute name that already exists is rejected.
        with pytest.raises(KeyError):
            schema.add("field-2", AttributeType.LONG)

    def test_convert_from_raw_schema(self, raw_schema, schema):
        assert schema == Schema(raw_schema=raw_schema)

    def test_convert_from_arrow_schema(self, arrow_schema, schema):
        assert schema == Schema(arrow_schema=arrow_schema)
        assert schema.as_arrow_schema() == arrow_schema

    def test_large_binary_in_raw_schema(self):
        """Test creating schema with LARGE_BINARY from raw schema."""
        schema = Schema(
            raw_schema={
                "regular_field": "STRING",
                "large_binary_field": "LARGE_BINARY",
            }
        )
        assert schema.get_attr_type("regular_field") == AttributeType.STRING
        assert schema.get_attr_type("large_binary_field") == AttributeType.LARGE_BINARY

    def test_large_binary_in_arrow_schema_with_metadata(self):
        """Test creating schema with LARGE_BINARY from Arrow schema with metadata."""
        # LARGE_BINARY is a string column tagged with texera_type metadata.
        tagged_field = pa.field(
            "large_binary_field",
            pa.string(),
            metadata={b"texera_type": b"LARGE_BINARY"},
        )
        schema = Schema(
            arrow_schema=pa.schema([pa.field("regular_field", pa.string()), tagged_field])
        )
        assert schema.get_attr_type("regular_field") == AttributeType.STRING
        assert schema.get_attr_type("large_binary_field") == AttributeType.LARGE_BINARY

    def test_large_binary_as_arrow_schema_includes_metadata(self):
        """Test that LARGE_BINARY fields include metadata in Arrow schema."""
        schema = Schema()
        schema.add("regular_field", AttributeType.STRING)
        schema.add("large_binary_field", AttributeType.LARGE_BINARY)
        arrow_schema = schema.as_arrow_schema()

        plain = arrow_schema.field("regular_field")
        # A regular field must not carry the texera_type marker.
        assert plain.metadata is None or b"texera_type" not in plain.metadata

        tagged = arrow_schema.field("large_binary_field")
        # LARGE_BINARY is tagged via metadata and physically stored as a string.
        assert tagged.metadata is not None
        assert tagged.metadata.get(b"texera_type") == b"LARGE_BINARY"
        assert tagged.type == pa.string()

    def test_round_trip_large_binary_schema(self):
        """Test round-trip conversion of schema with LARGE_BINARY."""
        columns = (
            ("field1", AttributeType.STRING),
            ("field2", AttributeType.LARGE_BINARY),
            ("field3", AttributeType.INT),
        )
        original_schema = Schema()
        for name, attr_type in columns:
            original_schema.add(name, attr_type)

        # Schema -> Arrow -> Schema must be lossless.
        round_trip_schema = Schema(arrow_schema=original_schema.as_arrow_schema())
        assert round_trip_schema == original_schema
        for name, attr_type in columns:
            assert round_trip_schema.get_attr_type(name) == attr_type
================================================
FILE: amber/src/test/python/core/models/test_operator.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import pytest
from core.models import (
BatchOperator,
SourceOperator,
State,
Table,
Tuple,
TupleOperatorV2,
)
from core.models.operator import Operator, TableOperator
class _ConcreteOperator(TupleOperatorV2):
    """Smallest instantiable TupleOperatorV2: echoes every input tuple unchanged."""

    def process_tuple(self, tuple_, port):
        yield from (tuple_,)
class _ConcreteSource(SourceOperator):
    """Smallest instantiable SourceOperator: produce() yields a single None."""

    def produce(self):
        yield
class _ConcreteBatch(BatchOperator):
    """BatchOperator with BATCH_SIZE = 4 that echoes each batch unchanged."""

    BATCH_SIZE = 4

    def process_batch(self, batch, port):
        yield from (batch,)
class _ConcreteTable(TableOperator):
    """TableOperator stub that remembers every table passed to process_table."""

    def __init__(self):
        super().__init__()
        # Tables handed to process_table, in call order.
        self.received_tables = list()

    def process_table(self, table, port):
        self.received_tables.append(table)
        yield
class TestPythonTemplateDecoder:
    """Operator.PythonTemplateDecoder: base64 decoding, injection, LRU caching."""

    class _SequencedDecoder:
        """Decoder stub that counts calls and tags output with the call number."""

        def __init__(self):
            self.calls = 0

        def to_str(self, data):
            self.calls += 1
            return f"d{self.calls}:{data}"

    def test_stdlib_decoder_decodes_str_input(self):
        decoder = Operator.PythonTemplateDecoder.StdlibBase64Decoder()
        payload = base64.b64encode(b"hello").decode("ascii")
        assert decoder.to_str(payload) == "hello"

    def test_stdlib_decoder_accepts_bytes_input(self):
        decoder = Operator.PythonTemplateDecoder.StdlibBase64Decoder()
        # The base64 payload is passed as bytes rather than str.
        raw_payload = base64.b64encode("中".encode("utf-8"))
        assert decoder.to_str(raw_payload) == "中"

    def test_stdlib_decoder_rejects_non_utf8_bytes_strictly(self):
        decoder = Operator.PythonTemplateDecoder.StdlibBase64Decoder()
        # 0x80 cannot start a UTF-8 sequence, so strict decoding must raise.
        invalid = base64.b64encode(b"\x80\x81").decode("ascii")
        with pytest.raises(UnicodeDecodeError):
            decoder.to_str(invalid)

    def test_default_decoder_when_none_supplied(self):
        wrapper = Operator.PythonTemplateDecoder()
        payload = base64.b64encode(b"abc").decode("ascii")
        assert wrapper.decode(payload) == "abc"

    def test_uses_injected_custom_decoder(self):
        class _PrefixingDecoder:
            def __init__(self):
                self.calls = 0

            def to_str(self, data):
                self.calls += 1
                return f"decoded:{data}"

        injected = _PrefixingDecoder()
        wrapper = Operator.PythonTemplateDecoder(decoder=injected)
        assert wrapper.decode("x") == "decoded:x"
        assert injected.calls == 1

    def test_lru_cache_reuses_results_for_repeated_inputs(self):
        # The cache short-circuits the wrapped decoder: identical inputs cost a
        # single decode call, keeping the wrapper cheap when a template repeats.
        injected = self._SequencedDecoder()
        wrapper = Operator.PythonTemplateDecoder(decoder=injected, cache_size=8)
        results = [wrapper.decode("same") for _ in range(2)]
        # Both calls return the first (cached) result.
        assert results == ["d1:same", "d1:same"]
        assert injected.calls == 1

    def test_lru_cache_evicts_when_size_exceeded(self):
        injected = self._SequencedDecoder()
        wrapper = Operator.PythonTemplateDecoder(decoder=injected, cache_size=2)
        # With room for two entries, decoding "c" evicts "a". The final "a"
        # is therefore a cache miss that reaches the decoder again.
        for key in ("a", "b", "c", "a"):
            wrapper.decode(key)
        assert injected.calls == 4
class TestIsSourceProperty:
    """is_source: False by default, settable both ways, True for SourceOperator."""

    def test_default_is_false(self):
        assert _ConcreteOperator().is_source is False

    def test_setter_true_takes_effect(self):
        op = _ConcreteOperator()
        op.is_source = True
        assert op.is_source is True

    def test_setter_can_flip_back_to_false(self):
        op = _ConcreteOperator()
        for flag in (True, False):
            op.is_source = flag
        assert op.is_source is False

    def test_source_operator_subclass_reports_is_source_true(self):
        assert _ConcreteSource().is_source is True
class TestOperatorDefaultMethods:
    """Default hook implementations inherited from Operator."""

    def test_open_is_no_op(self):
        # No state to inspect; the call must not raise and must return None.
        result = _ConcreteOperator().open()
        assert result is None

    def test_close_is_no_op(self):
        result = _ConcreteOperator().close()
        assert result is None

    def test_process_state_returns_input_state_unchanged(self):
        # By default the State is forwarded to downstream operators as-is.
        incoming = State()
        forwarded = _ConcreteOperator().process_state(incoming, port=0)
        assert forwarded is incoming

    def test_produce_state_on_start_returns_none_by_default(self):
        result = _ConcreteOperator().produce_state_on_start(port=0)
        assert result is None

    def test_produce_state_on_finish_returns_none_by_default(self):
        result = _ConcreteOperator().produce_state_on_finish(port=0)
        assert result is None
class TestLazyTemplateDecoder:
    """The per-operator template decoder is created on first use and reused."""

    def test_first_call_creates_decoder_and_caches_on_instance(self):
        op = _ConcreteOperator()
        # The attribute does not exist right after construction.
        assert not hasattr(op, "_python_template_decoder")
        op._get_template_decoder()
        assert hasattr(op, "_python_template_decoder")

    def test_subsequent_calls_reuse_the_cached_decoder(self):
        op = _ConcreteOperator()
        decoders = [op._get_template_decoder() for _ in range(2)]
        assert decoders[0] is decoders[1]

    def test_decode_python_template_delegates_to_lazy_decoder(self):
        op = _ConcreteOperator()
        payload = base64.b64encode(b"payload").decode("ascii")
        assert op.decode_python_template(payload) == "payload"
class TestBatchOperatorValidation:
    """BatchOperator._validate_batch_size accepts only positive ints."""

    def test_validate_batch_size_rejects_none(self):
        with pytest.raises(ValueError, match="cannot be None"):
            BatchOperator._validate_batch_size(None)

    def test_validate_batch_size_rejects_non_int(self):
        # A numeric string is still not an int.
        with pytest.raises(ValueError):
            BatchOperator._validate_batch_size("10")

    def test_validate_batch_size_rejects_zero(self):
        with pytest.raises(ValueError, match="positive"):
            BatchOperator._validate_batch_size(0)

    def test_validate_batch_size_rejects_negative(self):
        with pytest.raises(ValueError, match="positive"):
            BatchOperator._validate_batch_size(-3)

    def test_validate_batch_size_accepts_positive_int(self):
        # Validation passes silently: no exception and an implicit None.
        for size in (1, 1024):
            assert BatchOperator._validate_batch_size(size) is None

    def test_concrete_batch_operator_initializes_with_valid_size(self):
        assert _ConcreteBatch().BATCH_SIZE == 4
class TestTableOperator:
    """TableOperator: buffers tuples per port and hands a Table to process_table."""

    def test_process_tuple_buffers_input_and_yields_none(self):
        # process_tuple is @final on TableOperator: it must record the tuple
        # internally and yield exactly one None. The framework's iterator
        # protocol still sees a value, but no output is produced per tuple.
        op = _ConcreteTable()
        out = list(op.process_tuple(Tuple({"x": 1}), port=0))
        assert out == [None]
        # Nothing has been passed downstream to process_table yet.
        assert op.received_tables == []

    def test_on_finish_calls_process_table_with_buffered_tuples(self):
        op = _ConcreteTable()
        list(op.process_tuple(Tuple({"x": 1, "y": "a"}), port=0))
        list(op.process_tuple(Tuple({"x": 2, "y": "b"}), port=0))
        # Drain on_finish so the generator runs.
        list(op.on_finish(port=0))
        assert len(op.received_tables) == 1
        table = op.received_tables[0]
        assert isinstance(table, Table)
        rows = list(table.as_tuples())
        assert rows == [Tuple({"x": 1, "y": "a"}), Tuple({"x": 2, "y": "b"})]

    def test_on_finish_with_no_buffered_tuples_yields_empty_table(self):
        op = _ConcreteTable()
        list(op.on_finish(port=0))
        assert len(op.received_tables) == 1
        assert list(op.received_tables[0].as_tuples()) == []

    def test_buffers_are_keyed_by_port(self):
        # Each input port has its own tuple buffer. Finishing one port must
        # not surface tuples written through a different port, in either
        # direction.
        op = _ConcreteTable()
        list(op.process_tuple(Tuple({"x": 1}), port=0))
        list(op.process_tuple(Tuple({"x": 99}), port=1))
        list(op.on_finish(port=0))
        assert list(op.received_tables[0].as_tuples()) == [Tuple({"x": 1})]
        # Port 1's buffer still holds only its own tuple.
        list(op.on_finish(port=1))
        assert len(op.received_tables) == 2
        assert list(op.received_tables[1].as_tuples()) == [Tuple({"x": 99})]
================================================
FILE: amber/src/test/python/core/models/test_state.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.models.state import State
class TestState:
    """State: a dict subclass with JSON and single-column Tuple round trips."""

    def test_state_subclasses_dict(self):
        populated = State({"a": 1})
        assert isinstance(populated, dict)
        assert populated["a"] == 1
        assert State() == {}

    def test_class_attributes(self):
        assert State.CONTENT == "content"
        assert State.SCHEMA.get_attr_names() == ["content"]

    def test_json_round_trip_primitives(self):
        original = State(
            dict(
                string="hello",
                int=42,
                float=3.14,
                bool_true=True,
                bool_false=False,
                none_value=None,
            )
        )
        assert State.from_json(original.to_json()) == original

    def test_json_round_trip_empty(self):
        empty = State()
        assert State.from_json(empty.to_json()) == State()

    def test_json_round_trip_bytes(self):
        blob = b"\x00\x01\x02\xff"
        decoded = State.from_json(State({"payload": blob}).to_json())
        assert decoded["payload"] == b"\x00\x01\x02\xff"
        assert isinstance(decoded["payload"], bytes)

    def test_json_round_trip_nested_dict(self):
        original = State({"outer": {"inner": {"value": 1}}})
        assert State.from_json(original.to_json()) == original

    def test_json_round_trip_list_of_mixed_values(self):
        original = State({"items": [1, "two", 3.0, True, None]})
        assert State.from_json(original.to_json()) == original

    def test_json_round_trip_bytes_inside_list_and_nested_dict(self):
        # bytes must survive the round trip even when nested in lists/dicts.
        source = State(
            {
                "blobs": [b"first", b"second"],
                "nested": {"sub_blob": b"inside"},
            }
        )
        decoded = State.from_json(source.to_json())
        assert decoded["blobs"] == [b"first", b"second"]
        assert decoded["nested"]["sub_blob"] == b"inside"

    def test_to_json_rejects_non_serializable_value(self):
        class Custom:
            pass

        unserializable = State({"bad": Custom()})
        with pytest.raises(TypeError):
            unserializable.to_json()

    def test_tuple_round_trip(self):
        original = State({"loop_counter": 3, "label": "outer", "blob": b"\x01\x02"})
        assert State.from_tuple(original.to_tuple()) == original

    def test_to_tuple_uses_state_schema(self):
        # A single STRING column whose value is the compact JSON serialization.
        as_tuple = State({"x": 1}).to_tuple()
        assert as_tuple[State.CONTENT] == '{"x":1}'

    def test_nested_dict_decodes_to_plain_dict(self):
        # The top level comes back as a State; nested dicts stay plain dicts.
        # This is intentional: only the outermost mapping is wrapped.
        decoded = State.from_json('{"outer":{"inner":1}}')
        inner = decoded["outer"]
        assert isinstance(decoded, State)
        assert isinstance(inner, dict)
        assert not isinstance(inner, State)
================================================
FILE: amber/src/test/python/core/models/test_table.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pandas
import pickle
import pytest
from pandas import RangeIndex
from core.models import Table, Tuple
class TestTable:
    """Table: construction from several sources and DataFrame interoperability."""

    @pytest.fixture
    def a_timestamp(self):
        return datetime.datetime.now()

    @pytest.fixture
    def target_raw_tuples(self, a_timestamp):
        first_row = {
            "field1": 1,
            "field2": "hello",
            "field3": 2.3,
            "field4": True,
            "field5": a_timestamp,
            "field6": b"some binary",
            "7_special-name": None,
            "none": None,
        }
        second_row = {
            "field1": 2,
            "field2": "world",
            "field3": 0.0,
            "field4": False,
            "field5": datetime.datetime.fromtimestamp(1000000000),
            "field6": pickle.dumps([1, 2, 3]),
            "7_special-name": "a strange value",
            "none": None,
        }
        return [first_row, second_row]

    @pytest.fixture
    def target_tuples(self, target_raw_tuples):
        return list(map(Tuple, target_raw_tuples))

    @pytest.fixture
    def target_table(self, target_raw_tuples):
        return Table(target_raw_tuples)

    @pytest.fixture
    def target_data_frame(self, a_timestamp):
        column_order = [
            "field1",
            "field2",
            "field3",
            "field4",
            "field5",
            "field6",
            "7_special-name",
            "none",
        ]
        data = {
            "field1": [1, 2],
            "field2": ["hello", "world"],
            "field3": [2.3, 0.0],
            "field4": [True, False],
            "field5": [
                a_timestamp,
                datetime.datetime.fromtimestamp(1000000000),
            ],
            "field6": [b"some binary", pickle.dumps([1, 2, 3])],
            "7_special-name": [None, "a strange value"],
            "none": [None, None],
        }
        return pandas.DataFrame(data, columns=column_order)

    def test_table_creation(self, target_table, a_timestamp):
        expected_by_column = {
            "field1": [1, 2],
            "field2": ["hello", "world"],
            "field3": [2.3, 0.0],
            "field5": [a_timestamp, datetime.datetime.fromtimestamp(1000000000)],
            "field6": [b"some binary", pickle.dumps([1, 2, 3])],
        }
        for column, expected_values in expected_by_column.items():
            for row, value in enumerate(expected_values):
                assert target_table[column][row] == value
        # Boolean cells are checked by truthiness.
        assert target_table["field4"][0]
        assert not target_table["field4"][1]
        # Null cells must come back as None.
        assert target_table["7_special-name"][0] is None
        assert target_table["7_special-name"][1] == "a strange value"
        assert target_table["none"][0] is None
        assert target_table["none"][1] is None

    def test_as_tuples_preserve_types(self, target_table, target_tuples):
        assert list(target_table.as_tuples()) == target_tuples

    def test_table_from_data_frame(self, target_table, target_data_frame):
        assert Table(target_data_frame) == target_table

    def test_table_from_list_of_tuples(self, target_table, target_tuples):
        from_tuples = Table(target_tuples)
        assert from_tuples == target_table
        assert list(from_tuples.as_tuples()) == target_tuples

    def test_table_from_list_of_series(
        self, target_table, a_timestamp, target_raw_tuples, target_tuples
    ):
        from_series = Table(list(map(pandas.Series, target_raw_tuples)))
        assert from_series == target_table
        assert list(from_series.as_tuples()) == target_tuples

    def test_table_from_table(self, target_table, target_tuples):
        copied = Table(target_table)
        assert copied == target_table
        assert list(copied.as_tuples()) == target_tuples

    def test_use_table_as_data_frame(self, target_table, target_data_frame):
        # A Table behaves like a DataFrame: default RangeIndex, concat, equals.
        assert (target_table.index == RangeIndex(start=0, stop=2, step=1)).all()
        assert len(pandas.concat([target_table, target_table])) == 4
        assert target_table.equals(target_data_frame)

    def test_validation_of_schema(self):
        # Rows with mismatching column sets are rejected.
        with pytest.raises(AssertionError):
            Table([{"text": "hello"}, {"book": "harry"}])
================================================
FILE: amber/src/test/python/core/models/test_tuple.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pandas
import pyarrow
import pytest
import numpy as np
from copy import deepcopy
from core.models import Tuple, ArrowTableTupleProvider
from core.models.schema.schema import Schema
class TestTuple:
    """Unit tests for ``Tuple``.

    Covers:
    - construction from several tuple-like inputs;
    - conversion helpers;
    - item access and mutation;
    - equality;
    - callable (lazily resolved) field values;
    - finalization against a ``Schema``;
    - Java-compatible hashing;
    - ``largebinary`` fields, including round-trips through Arrow tables.
    """

    @staticmethod
    def _tuples_from_arrow(arrow_table):
        # One Tuple per row; every column's value is the provider's per-row
        # accessor, which the Tuple resolves on access.
        provider = ArrowTableTupleProvider(arrow_table)
        return [
            Tuple({name: accessor for name in arrow_table.column_names})
            for accessor in provider
        ]

    @pytest.fixture
    def target_tuple(self):
        return Tuple({"x": 1, "y": "a"})

    def test_tuple_from_list(self, target_tuple):
        key_value_pairs = [("x", 1), ("y", "a")]
        assert Tuple(key_value_pairs) == target_tuple

    def test_tuple_from_dict(self, target_tuple):
        mapping = {"x": 1, "y": "a"}
        assert Tuple(mapping) == target_tuple

    def test_tuple_from_series(self, target_tuple):
        series = pandas.Series({"x": 1, "y": "a"})
        assert Tuple(series) == target_tuple

    def test_tuple_as_key_value_pairs(self, target_tuple):
        expected = [("x", 1), ("y", "a")]
        assert target_tuple.as_key_value_pairs() == expected

    def test_tuple_as_dict(self, target_tuple):
        expected = {"x": 1, "y": "a"}
        assert target_tuple.as_dict() == expected

    def test_tuple_as_series(self, target_tuple):
        expected = pandas.Series({"x": 1, "y": "a"})
        assert (target_tuple.as_series() == expected).all()

    def test_tuple_get_fields(self, target_tuple):
        assert target_tuple.get_fields() == (1, "a")

    def test_tuple_get_field_names(self, target_tuple):
        assert target_tuple.get_field_names() == ("x", "y")

    def test_tuple_get_item(self, target_tuple):
        # Lookup by field name ...
        for name, value in (("x", 1), ("y", "a")):
            assert target_tuple[name] == value
        # ... and by position must agree.
        for position, value in enumerate((1, "a")):
            assert target_tuple[position] == value

    def test_tuple_set_item(self, target_tuple):
        # Overwriting an existing field keeps its position; other fields are untouched.
        target_tuple["x"] = 3
        for key, value in (("x", 3), ("y", "a"), (0, 3), (1, "a")):
            assert target_tuple[key] == value
        # Assigning a new name makes it reachable at the next position.
        target_tuple["z"] = 1.1
        assert target_tuple[2] == 1.1
        assert target_tuple["z"] == 1.1

    def test_tuple_str(self, target_tuple):
        rendered = "Tuple['x': 1, 'y': 'a']"
        assert str(target_tuple) == rendered

    def test_tuple_repr(self, target_tuple):
        rendered = "Tuple['x': 1, 'y': 'a']"
        assert repr(target_tuple) == rendered

    def test_tuple_eq(self, target_tuple):
        assert target_tuple == target_tuple
        different = Tuple({"x": 2, "y": "a"})
        assert not (different == target_tuple)

    def test_tuple_ne(self, target_tuple):
        assert not (target_tuple != target_tuple)
        different = Tuple({"x": 1, "y": "b"})
        assert different != target_tuple

    def test_reject_empty_tuplelike(self):
        empty_factories = (
            list,
            dict,
            lambda: pandas.Series(dtype=pandas.StringDtype()),
        )
        for make_empty in empty_factories:
            with pytest.raises(AssertionError):
                Tuple(make_empty())

    def test_reject_invalid_tuplelike(self):
        for invalid in (1, [1], [None]):
            with pytest.raises(TypeError):
                Tuple(invalid)

    def test_tuple_lazy_get_from_arrow(self):
        def field_accessor(field_name):
            # "1" -> "a", "3" -> "c" (chr(97) == "a").
            return chr(96 + int(field_name))

        chr_tuple = Tuple({"1": "a", "3": "c"})
        lazy_tuple = Tuple({"1": field_accessor, "3": field_accessor})
        assert lazy_tuple == Tuple({"1": "a", "3": "c"})
        # A fresh callable-backed tuple must also survive deepcopy.
        fresh_lazy_tuple = Tuple({"1": field_accessor, "3": field_accessor})
        assert deepcopy(fresh_lazy_tuple) == chr_tuple

    def test_retrieve_tuple_from_empty_arrow_table(self):
        empty_table = pyarrow.schema([]).empty_table()
        assert self._tuples_from_arrow(empty_table) == []

    def test_finalize_tuple(self):
        tuple_ = Tuple(
            {"name": "texera", "age": 21, "scores": [85, 94, 100], "height": np.nan}
        )
        schema = Schema(
            raw_schema={
                "name": "STRING",
                "age": "INTEGER",
                "scores": "BINARY",
                "height": "DOUBLE",
            }
        )
        tuple_.finalize(schema)
        # After finalize: the list in the BINARY column is bytes, and NaN is None.
        assert isinstance(tuple_["scores"], bytes)
        assert tuple_["height"] is None

    def test_hash(self):
        columns = (
            "col-int",
            "col-string",
            "col-bool",
            "col-long",
            "col-double",
            "col-timestamp",
            "col-binary",
        )
        types = ("INTEGER", "STRING", "BOOLEAN", "LONG", "DOUBLE", "TIMESTAMP", "BINARY")
        schema = Schema(raw_schema=dict(zip(columns, types)))
        ts = datetime.datetime.fromtimestamp

        # (row values in column order, expected hash calculated with Java)
        expectations = [
            (
                (
                    922323,
                    "string-attr",
                    True,
                    1123213213213,
                    214214.9969346,
                    ts(100000000),
                    b"hello",
                ),
                -1335416166,
            ),
            ((0, "", False, 0, 0.0, ts(0), b""), -1409761483),
            ((None,) * 7, 1742810335),
            (
                (
                    -3245763,
                    "\n\r\napple",
                    True,
                    -8965536434247,
                    1 / 3,
                    ts(-1990),
                    None,
                ),
                -592643630,
            ),
            (
                (
                    0x7FFFFFFF,
                    "",
                    True,
                    0x7FFFFFFFFFFFFFFF,
                    7 / 17,
                    ts(1234567890),
                    b"o" * 4097,
                ),
                -2099556631,
            ),
        ]
        for row, java_hash in expectations:
            assert hash(Tuple(dict(zip(columns, row)), schema)) == java_hash

    def test_tuple_with_large_binary(self):
        """A largebinary value stored under a LARGE_BINARY schema field is returned unchanged."""
        from core.models.type.large_binary import largebinary

        uri = "s3://test-bucket/path/to/object"
        schema = Schema(
            raw_schema={
                "regular_field": "STRING",
                "large_binary_field": "LARGE_BINARY",
            }
        )
        large_binary = largebinary(uri)
        tuple_ = Tuple(
            {"regular_field": "test string", "large_binary_field": large_binary},
            schema=schema,
        )
        assert tuple_["regular_field"] == "test string"
        stored = tuple_["large_binary_field"]
        assert stored == large_binary
        assert isinstance(stored, largebinary)
        assert stored.uri == uri

    def test_tuple_from_arrow_with_large_binary(self):
        """A string column tagged texera_type=LARGE_BINARY comes back as a largebinary."""
        from core.models.type.large_binary import largebinary

        uri = "s3://test-bucket/path/to/object"
        # The URI travels as an Arrow string; field metadata marks its Texera type.
        arrow_schema = pyarrow.schema(
            [
                pyarrow.field("regular_field", pyarrow.string()),
                pyarrow.field(
                    "large_binary_field",
                    pyarrow.string(),
                    metadata={b"texera_type": b"LARGE_BINARY"},
                ),
            ]
        )
        arrow_table = pyarrow.Table.from_pydict(
            {"regular_field": ["test"], "large_binary_field": [uri]},
            schema=arrow_schema,
        )
        tuples = self._tuples_from_arrow(arrow_table)
        assert len(tuples) == 1
        tuple_ = tuples[0]
        assert tuple_["regular_field"] == "test"
        assert isinstance(tuple_["large_binary_field"], largebinary)
        assert tuple_["large_binary_field"].uri == uri

    def test_tuple_with_null_large_binary(self):
        """A null in a LARGE_BINARY-tagged Arrow column reads back as None."""
        arrow_schema = pyarrow.schema(
            [
                pyarrow.field(
                    "large_binary_field",
                    pyarrow.string(),
                    metadata={b"texera_type": b"LARGE_BINARY"},
                ),
            ]
        )
        arrow_table = pyarrow.Table.from_pydict(
            {"large_binary_field": [None]},
            schema=arrow_schema,
        )
        tuples = self._tuples_from_arrow(arrow_table)
        assert len(tuples) == 1
        assert tuples[0]["large_binary_field"] is None
================================================
FILE: amber/src/test/python/core/models/type/test_large_binary.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from unittest.mock import patch
from core.models.type.large_binary import largebinary
class TestLargeBinary:
    """Tests for ``largebinary``, a handle that wraps an ``s3://`` object URI."""

    def test_create_with_uri(self):
        """An explicit s3:// URI is kept as-is and shown by str() and repr()."""
        uri = "s3://test-bucket/path/to/object"
        handle = largebinary(uri)
        assert handle.uri == uri
        assert str(handle) == uri
        assert repr(handle) == f"largebinary('{uri}')"

    def test_create_without_uri(self):
        """With no URI, the constructor obtains one from large_binary_manager.create()."""
        allocated_uri = "s3://bucket/objects/123/uuid"
        with patch(
            "pytexera.storage.large_binary_manager.create",
            return_value=allocated_uri,
        ) as mock_create:
            handle = largebinary()
            assert handle.uri == allocated_uri
            mock_create.assert_called_once()

    def test_invalid_uri_raises_value_error(self):
        """A URI that does not start with s3:// raises ValueError."""
        for bad_uri in ("http://invalid-uri", "invalid-uri"):
            with pytest.raises(
                ValueError, match="largebinary URI must start with 's3://'"
            ):
                largebinary(bad_uri)

    def test_get_bucket_name(self):
        """The bucket name is the segment right after ``s3://``."""
        handle = largebinary("s3://my-bucket/path/to/object")
        assert handle.get_bucket_name() == "my-bucket"

    def test_get_object_key(self):
        """The object key is the path after the bucket, with no leading slash."""
        handle = largebinary("s3://my-bucket/path/to/object")
        assert handle.get_object_key() == "path/to/object"

    def test_get_object_key_with_leading_slash(self):
        """The parsed URI path begins with "/"; get_object_key must strip it."""
        handle = largebinary("s3://my-bucket/path/to/object")
        # NOTE(review): same input and expectation as test_get_object_key.
        assert handle.get_object_key() == "path/to/object"

    def test_equality(self):
        """Handles compare by URI and never equal a non-largebinary value."""
        uri = "s3://bucket/path"
        first, second = largebinary(uri), largebinary(uri)
        other = largebinary("s3://bucket/different")
        assert first == second
        assert first != other
        assert first != "not a largebinary"

    def test_hash(self):
        """Equal handles hash equally, and the hash matches that of the raw URI."""
        uri = "s3://bucket/path"
        first, second = largebinary(uri), largebinary(uri)
        assert hash(first) == hash(second)
        assert hash(first) == hash(uri)

    def test_uri_property(self):
        """The ``uri`` property returns the URI passed to the constructor."""
        uri = "s3://test-bucket/test/path"
        assert largebinary(uri).uri == uri
================================================
FILE: amber/src/test/python/core/proxy/test_proxy_client.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import DataFrame
from pyarrow import ArrowNotImplementedError, Table
from queue import Queue
from core.proxy.proxy_client import ProxyClient
from core.proxy.proxy_server import ProxyServer
class TestProxyClient:
    """Integration tests for ``ProxyClient`` against an in-process ``ProxyServer``.

    Both server fixtures bind port 5005. The ``client`` fixture builds
    ``ProxyClient()`` with its defaults (presumably targeting the same port;
    verify against ProxyClient).
    """

    @pytest.fixture
    def data_queue(self):
        return Queue()

    @pytest.fixture
    def server(self):
        server = ProxyServer(port=5005)
        yield server
        server.graceful_shutdown()

    @pytest.fixture
    def server_with_dp(self, data_queue):
        """Server whose data handler enqueues every received row.

        The handler converts the incoming Arrow table to pandas and puts each
        row (the Series half of ``iterrows()``) onto ``data_queue`` in order.
        """
        server = ProxyServer(port=5005)
        server.register_data_handler(
            lambda _, table: list(
                map(data_queue.put, map(lambda t: t[1], table.to_pandas().iterrows()))
            )
        )
        yield server
        server.graceful_shutdown()

    class MockFlightMetadataReader:
        """
        MockFlightMetadataReader is a mocked FlightMetadataReader class that
        mocks a credit value returned from the Scala server to the Python client.
        """

        class MockBuffer:
            def to_pybytes(self):
                # The credit is encoded as an 8-byte little-endian integer.
                dummy_credit = 31
                return dummy_credit.to_bytes(8, "little")

        def read(self):
            return self.MockBuffer()

    @pytest.fixture
    def client(self):
        """ProxyClient whose ``do_put`` returns a mocked metadata reader.

        The client is closed on teardown. Without this, each test leaks the
        client's connection until garbage collection. ProxyClient is a
        FlightClient subclass (``do_put`` delegates via ``super``), so
        ``close()`` is available.
        """
        mock_client = ProxyClient()

        def mock_do_put(
            self,
            FlightDescriptor_descriptor,
            Schema_schema,
            FlightCallOptions_options=None,
        ):
            """
            Mock of FlightClient.do_put as called by ProxyClient. It returns a
            MockFlightMetadataReader instead of a FlightMetadataReader.
            :param self: an instance of FlightClient (a ProxyClient in this case)
            :param FlightDescriptor_descriptor: descriptor
            :param Schema_schema: schema
            :param FlightCallOptions_options: options, None by default
            :return: writer : FlightStreamWriter, reader : MockFlightMetadataReader
            """
            writer, _ = super(ProxyClient, self).do_put(
                FlightDescriptor_descriptor, Schema_schema, FlightCallOptions_options
            )
            reader = TestProxyClient.MockFlightMetadataReader()
            return writer, reader

        # Bind mock_do_put as a method of this instance, overriding do_put.
        mock_client.do_put = mock_do_put.__get__(mock_client, ProxyClient)
        yield mock_client
        # Release the client's connection once the test is done.
        mock_client.close()

    @pytest.fixture
    def data_table(self):
        df_to_sent = DataFrame(
            {
                "Brand": ["Honda Civic", "Toyota Corolla", "Ford Focus", "Audi A4"],
                "Price": [22000, 25000, 27000, 35000],
            },
            columns=["Brand", "Price"],
        )
        return Table.from_pandas(df_to_sent)

    def test_client_can_connect_to_server(self, server, client):
        assert client.call_action("heartbeat") == b"ack"

    def test_client_can_shutdown_server(self, server, client):
        assert client.call_action("shutdown") == b"Bye bye!"

    def test_client_can_call_registered_lambdas(self, server, client):
        action_count = len(client.list_actions())
        server.register("hello", lambda: "hello")
        server.register("this is another call", lambda: "ack!!!")
        assert len(client.list_actions()) == action_count + 2
        assert client.call_action("hello") == b"hello"
        assert client.call_action("this is another call") == b"ack!!!"
        assert client.call_action("shutdown") == b"Bye bye!"

    def test_client_can_call_registered_function(self, server, client):
        def hello():
            return "hello-function"

        action_count = len(client.list_actions())
        server.register("hello-function", hello)
        assert len(client.list_actions()) == action_count + 1
        assert client.call_action("hello-function") == b"hello-function"
        assert client.call_action("shutdown") == b"Bye bye!"

    def test_client_can_call_registered_callable_class(self, server, client):
        class HelloClass:
            def __call__(self):
                return "hello-class"

        action_count = len(client.list_actions())
        server.register("hello-class", HelloClass())
        assert len(client.list_actions()) == action_count + 1
        assert client.call_action("hello-class") == b"hello-class"
        assert client.call_action("shutdown") == b"Bye bye!"

    def test_client_cannot_send_data_without_handler(self, server, client, data_table):
        # Sending a flight to a server with no data handler is rejected.
        with pytest.raises(ArrowNotImplementedError):
            client.send_data(command=bytes(), table=data_table)

    def test_client_can_send_data_with_handler(
        self, data_queue: Queue, server_with_dp, client, data_table
    ):
        # Send the pyarrow table to the server as a flight.
        client.send_data(bytes(), data_table)
        assert data_queue.qsize() == 4
        # Rows must arrive in the order they were sent.
        for _, row in data_table.to_pandas().iterrows():
            assert data_queue.get().equals(row)
================================================
FILE: amber/src/test/python/core/proxy/test_proxy_server.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pyarrow.flight import Action
from core.proxy.proxy_server import ProxyServer
class TestProxyServer:
    """Tests for registering and invoking control actions on ``ProxyServer``."""

    @pytest.fixture()
    def server(self):
        server = ProxyServer()
        yield server
        server.graceful_shutdown()

    def test_server_can_register_control_actions_with_lambda(self, server):
        assert "hello" not in server._procedures
        server.register("hello", lambda: None)
        assert "hello" in server._procedures

    def test_server_can_register_control_actions_with_function(self, server):
        def hello():
            return None

        assert "hello" not in server._procedures
        server.register("hello", hello)
        assert "hello" in server._procedures

    def test_server_can_register_control_actions_with_callable_class(self, server):
        class Hello:
            def __call__(self):
                return None

        assert "hello" not in server._procedures
        server.register("hello", Hello())
        assert "hello" in server._procedures

    def test_server_can_invoke_registered_control_actions(self, server):
        """Each registered procedure returns its own value, encoded via str()."""
        procedure_contents = {
            "hello": "hello world",
            "get an int": 12,
            "get a float": 1.23,
            "get a tuple": (5, None, 123.4),
            "get a list": [5, (None, 123.4)],
            "get a dict": {"entry": [5, (None, 123.4)]},
        }

        def constant(value):
            # Bind ``value`` at registration time. A bare ``lambda: result``
            # created in the loop closes over the loop variable, so every
            # procedure would return the final value once the loop advances.
            return lambda: value

        def invoke(name):
            # The first item yielded by do_action carries the response body.
            return next(server.do_action(None, Action(name, b""))).body.to_pybytes()

        for name, result in procedure_contents.items():
            server.register(name, constant(result))
            assert name in server._procedures
            assert invoke(name) == str(result).encode("utf-8")

        # After all registrations, each procedure must still return its own value.
        for name, result in procedure_contents.items():
            assert invoke(name) == str(result).encode("utf-8")
================================================
FILE: amber/src/test/python/core/runnables/test_console_message.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pytest
from core.models.internal_queue import InternalQueue
from core.util import set_one_of
from core.util.buffer.timed_buffer import TimedBuffer
from proto.org.apache.texera.amber.core import ActorVirtualIdentity, ChannelIdentity
from proto.org.apache.texera.amber.engine.architecture.rpc import (
ControlInvocation,
ControlRequest,
ConsoleMessage,
ConsoleMessageType,
)
from proto.org.apache.texera.amber.engine.common import (
DirectControlMessagePayloadV2,
PythonControlMessage,
)
class TestConsoleMessage:
    """Round-trip serialization of a ConsoleMessage wrapped in a PythonControlMessage."""

    @pytest.fixture
    def internal_queue(self):
        return InternalQueue()

    @pytest.fixture
    def timed_buffer(self):
        return TimedBuffer()

    @pytest.fixture
    def console_message(self):
        return ConsoleMessage(
            worker_id="0",
            timestamp=datetime.datetime.now(),
            msg_type=ConsoleMessageType.PRINT,
            source="pytest",
            title="Test Message",
            message="Test Message",
        )

    @pytest.fixture
    def mock_controller_channel(self):
        # Positional arguments: presumably sender, receiver, and an
        # "is control channel" flag; confirm against ChannelIdentity.
        sender = ActorVirtualIdentity("CONTROLLER")
        receiver = ActorVirtualIdentity("test")
        return ChannelIdentity(sender, receiver, True)

    @pytest.mark.timeout(2)
    def test_console_message_serialization(
        self, mock_controller_channel, console_message
    ):
        """
        Serializing a console-message control message to bytes and parsing
        it back must produce an equal message.
        :param mock_controller_channel: the mock control channel id
        :param console_message: the message being wrapped
        """
        # Wrap the console message as a ConsoleMessageTriggered invocation.
        invocation = ControlInvocation(
            method_name="ConsoleMessageTriggered",
            command_id=1,
            command=set_one_of(ControlRequest, console_message),
        )
        original = PythonControlMessage(
            tag=mock_controller_channel,
            payload=set_one_of(DirectControlMessagePayloadV2, invocation),
        )

        # bytes -> parse must reproduce the original message.
        round_tripped = PythonControlMessage().parse(bytes(original))
        assert original == round_tripped
================================================
FILE: amber/src/test/python/core/runnables/test_data_processor.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.architecture.managers import Context
from core.models import State
from core.models.internal_queue import InternalQueue
from core.models.internal_marker import EndChannel, StartChannel
from core.runnables.data_processor import DataProcessor
from proto.org.apache.texera.amber.engine.architecture.rpc import ConsoleMessageType
@pytest.fixture
def context():
    """Fresh worker ``Context`` named "test-worker" with its own empty input queue."""
    input_queue = InternalQueue()
    return Context(worker_id="test-worker", input_queue=input_queue)
@pytest.fixture
def data_processor(context, monkeypatch):
    """
    A DataProcessor whose ``_switch_context`` only counts invocations.

    The count is kept in ``switch_calls``. Tests can then drive the
    synchronous parts of each call without blocking on the cross-thread
    handshake with MainLoop.
    """
    processor = DataProcessor(context)
    processor.switch_calls = 0

    def count_switch():
        processor.switch_calls += 1

    monkeypatch.setattr(processor, "_switch_context", count_switch)
    return processor
class _StubExecutor:
    """
    Recording stand-in for an Operator executor.

    Each hook appends ``(hook_name, port_id)`` to ``calls`` in invocation
    order. This lets tests assert the StartChannel / EndChannel branches of
    ``process_internal_marker`` without a real Operator.
    """

    def __init__(self):
        self.calls = []

    def _record(self, hook_name, port_id):
        self.calls.append((hook_name, port_id))

    def produce_state_on_start(self, port_id):
        self._record("produce_state_on_start", port_id)
        return {"phase": "start"}

    def produce_state_on_finish(self, port_id):
        self._record("produce_state_on_finish", port_id)
        return {"phase": "finish"}

    def on_finish(self, port_id):
        # Produces no output tuples.
        self._record("on_finish", port_id)
        return iter([])
class TestProcessInternalMarker:
    """How ``process_internal_marker`` routes StartChannel and EndChannel."""

    @staticmethod
    def _install_stub(context):
        # Swap in a recording executor so the test can see which hooks fire.
        stub = _StubExecutor()
        context.executor_manager.executor = stub
        return stub

    @pytest.mark.timeout(2)
    def test_start_channel_invokes_produce_state_on_start(
        self, context, data_processor
    ):
        executor = self._install_stub(context)

        data_processor.process_internal_marker(StartChannel())

        # Only produce_state_on_start runs, with the current input port id
        # (0 when no upstream is set).
        assert executor.calls == [("produce_state_on_start", 0)]
        # The returned dict is exposed as a State on the output slot.
        output_state = context.state_processing_manager.current_output_state
        assert isinstance(output_state, State)
        assert output_state["phase"] == "start"
        # The single switch comes from `_executor_session` exiting.
        assert data_processor.switch_calls == 1

    @pytest.mark.timeout(2)
    def test_end_channel_flushes_state_then_drains_on_finish(
        self, context, data_processor
    ):
        executor = self._install_stub(context)

        data_processor.process_internal_marker(EndChannel())

        # produce_state_on_finish must run before on_finish is drained.
        assert executor.calls == [
            ("produce_state_on_finish", 0),
            ("on_finish", 0),
        ]
        # Two switches are expected:
        # - an explicit one that flushes the finish state separately from
        #   the on_finish tuple stream;
        # - one when `_executor_session` exits.
        # `_set_output_tuple` returns early on an empty iterator and adds none.
        assert data_processor.switch_calls == 2
class TestExecutorSession:
    """How ``_executor_session`` behaves on the failure and success paths."""

    @pytest.mark.timeout(2)
    def test_exception_inside_session_is_reported_before_the_switch(
        self, context, data_processor
    ):
        # MainLoop's _check_exception flushes pending console messages and
        # then enters EXCEPTION_PAUSE immediately. The stack trace must
        # therefore already be buffered when _executor_session calls
        # _switch_context. Snapshot the buffer from inside the fake switch
        # to pin that ordering.
        messages_at_switch = []

        def snapshot_then_count():
            messages_at_switch.extend(
                context.console_message_manager.get_messages(force_flush=True)
            )
            data_processor.switch_calls += 1

        data_processor._switch_context = snapshot_then_count

        with data_processor._executor_session() as session:
            assert session is not None
            raise RuntimeError("boom-from-executor")

        # The failure was recorded where MainLoop's _check_exception looks.
        exception_manager = context.exception_manager
        assert exception_manager.has_exception()
        exc_info = exception_manager.get_exc_info()
        assert exc_info[0] is RuntimeError
        assert "boom-from-executor" in str(exc_info[1])

        # The stack-trace console message was buffered before the
        # finally-clause switch. Otherwise the worker would pause before the
        # error ever reached the controller.
        assert len(messages_at_switch) == 1
        error_message = messages_at_switch[0]
        assert error_message.worker_id == "test-worker"
        assert error_message.msg_type == ConsoleMessageType.ERROR
        assert "RuntimeError: boom-from-executor" in error_message.title

        # The failure path still hands control back to MainLoop exactly once.
        assert data_processor.switch_calls == 1

    @pytest.mark.timeout(2)
    def test_clean_session_does_not_record_an_exception(self, context, data_processor):
        with data_processor._executor_session():
            pass

        assert not context.exception_manager.has_exception()
        flushed = list(context.console_message_manager.get_messages(force_flush=True))
        assert flushed == []
        # On success, the finally clause also switches back exactly once.
        assert data_processor.switch_calls == 1
class TestRunInvariant:
    """
    `run()` requires exactly one of marker / state / tuple to be queued per
    iteration and raises RuntimeError otherwise. The integration tests never
    reach that branch, so it is exercised directly here.
    """

    @staticmethod
    def _drive_run_synchronously(context, monkeypatch) -> DataProcessor:
        # run() starts with condition.wait() so MainLoop can hand off control.
        # Turning wait() into a no-op lets the test thread call run() directly
        # and hit the invariant check on the first iteration.
        condition = context.tuple_processing_manager.context_switch_condition
        monkeypatch.setattr(condition, "wait", lambda *args, **kwargs: None)
        return DataProcessor(context)

    @pytest.mark.timeout(2)
    def test_zero_queued_inputs_raises_invariant_error(self, context, monkeypatch):
        processor = self._drive_run_synchronously(context, monkeypatch)
        # Nothing queued: has_marker + has_state + has_tuple == 0.
        with pytest.raises(RuntimeError) as excinfo:
            processor.run()
        message = str(excinfo.value)
        assert "expected exactly one queued input" in message
        assert "marker=False, state=False, tuple=False" in message

    @pytest.mark.timeout(2)
    def test_two_queued_inputs_raises_invariant_error(self, context, monkeypatch):
        processor = self._drive_run_synchronously(context, monkeypatch)
        # Fill two slots: has_marker + has_tuple == 2.
        tuple_manager = context.tuple_processing_manager
        tuple_manager.current_internal_marker = StartChannel()
        tuple_manager.current_input_tuple = ("payload",)
        with pytest.raises(RuntimeError) as excinfo:
            processor.run()
        message = str(excinfo.value)
        for fragment in (
            "expected exactly one queued input",
            "marker=True",
            "tuple=True",
        ):
            assert fragment in message
================================================
FILE: amber/src/test/python/core/runnables/test_heartbeat.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import socket
from threading import Event
from unittest.mock import patch, MagicMock
import pytest
from core.runnables.heartbeat import Heartbeat
def make_heartbeat(host="localhost", port=12345, interval=0.05, event=None):
    """Build a ``Heartbeat`` with short test defaults.

    A new ``Event`` is created whenever ``event`` is falsy (including None).
    """
    stop_event = event if event else Event()
    return Heartbeat(host, port, interval, stop_event)
class TestHeartbeatInit:
    """What ``Heartbeat.__init__`` records from its arguments and environment."""

    def test_parses_host_and_port_from_grpc_tcp_url(self):
        heartbeat = make_heartbeat(host="example.test", port=9090)
        parsed = (heartbeat._parsed_server_host, heartbeat._parsed_server_port)
        assert parsed == ("example.test", 9090)

    def test_records_interval_and_stop_event_references(self):
        stop_event = Event()
        heartbeat = make_heartbeat(interval=2.5, event=stop_event)
        assert heartbeat._interval == 2.5
        # The caller's own Event object is kept, not a copy.
        assert heartbeat._stop_event is stop_event

    def test_captures_original_parent_pid_at_construction_time(self):
        with patch("core.runnables.heartbeat.os.getppid", return_value=4242):
            heartbeat = make_heartbeat()
        # Checked after the patch is removed, so the value must have been
        # captured during construction.
        assert heartbeat._original_parent_pid == 4242

    def test_supports_ipv6_host_in_bracketed_form(self):
        heartbeat = make_heartbeat(host="[::1]", port=9090)
        # The brackets around the IPv6 literal are stripped from the host.
        assert heartbeat._parsed_server_host == "::1"
        assert heartbeat._parsed_server_port == 9090
class TestCheckHeartbeat:
    """_check_heartbeat() turns the outcome of a TCP connect into a bool."""

    # Patch target for the socket factory used by the heartbeat module.
    _CREATE_CONNECTION = "core.runnables.heartbeat.socket.create_connection"

    def test_returns_true_when_socket_connects(self):
        heartbeat = make_heartbeat(host="h", port=1)
        sock = MagicMock()
        with patch(self._CREATE_CONNECTION, return_value=sock) as create_connection:
            assert heartbeat._check_heartbeat() is True
        create_connection.assert_called_once_with(("h", 1), timeout=1)
        sock.close.assert_called_once()

    def test_returns_false_when_socket_connection_raises(self):
        heartbeat = make_heartbeat()
        refused = ConnectionRefusedError("nope")
        with patch(self._CREATE_CONNECTION, side_effect=refused):
            assert heartbeat._check_heartbeat() is False

    def test_returns_false_when_socket_connection_times_out(self):
        heartbeat = make_heartbeat()
        timed_out = socket.timeout("slow")
        with patch(self._CREATE_CONNECTION, side_effect=timed_out):
            assert heartbeat._check_heartbeat() is False

    def test_returns_false_when_socket_close_raises(self):
        # Pins the false-negative path: the connect succeeds but close()
        # then throws (e.g. a broken pipe on a half-open socket). The broad
        # `except Exception` in _check_heartbeat() swallows it and reports
        # "server down"; narrowing that handler would make this test fail.
        heartbeat = make_heartbeat()
        sock = MagicMock()
        sock.close.side_effect = OSError("close failed")
        with patch(self._CREATE_CONNECTION, return_value=sock):
            assert heartbeat._check_heartbeat() is False
class TestRunEarlyExit:
    """run() must not probe the server when stop was requested up front."""

    @pytest.mark.timeout(2)
    def test_returns_immediately_when_stop_event_is_already_set(self):
        stop_event = Event()
        stop_event.set()
        heartbeat = make_heartbeat(interval=10.0, event=stop_event)
        # Because the event is already set, Event.wait(timeout=10) returns at
        # once. The `while not self._stop_event.wait(...)` loop therefore
        # exits before its body runs, and _check_heartbeat() is never called.
        # The 2 s pytest timeout turns a regression that enters the loop (or
        # blocks in wait()) into a quick failure instead of a hung CI job.
        with patch.object(heartbeat, "_check_heartbeat") as check:
            heartbeat.run()
        check.assert_not_called()
@pytest.mark.parametrize("port", [1, 65535, 8080])
def test_init_accepts_full_port_range(port):
    """The lowest, the highest and a typical port are all parsed back unchanged."""
    assert make_heartbeat(port=port)._parsed_server_port == port
================================================
FILE: amber/src/test/python/core/runnables/test_main_loop.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
import pandas
import pickle
import pyarrow
import pytest
import sys
import time
from threading import Thread
from core.models import (
DataFrame,
InternalQueue,
State,
StateFrame,
Tuple,
)
from core.models.internal_queue import (
DataElement,
DCMElement,
ECMElement,
)
from core.runnables import MainLoop
from core.util import set_one_of
from proto.org.apache.texera.amber.core import (
ActorVirtualIdentity,
PhysicalLink,
PhysicalOpIdentity,
OperatorIdentity,
ChannelIdentity,
PortIdentity,
OpExecWithCode,
OpExecInitInfo,
EmbeddedControlMessageIdentity,
)
from core.architecture.managers.pause_manager import PauseType
from core.util.console_message.timestamp import current_time_in_local_timezone
from proto.org.apache.texera.amber.engine.architecture.rpc import (
ControlRequest,
AssignPortRequest,
ControlInvocation,
AddInputChannelRequest,
InitializeExecutorRequest,
EmptyReturn,
ReturnInvocation,
ControlReturn,
WorkerMetricsResponse,
AddPartitioningRequest,
EmptyRequest,
PortCompletedRequest,
AsyncRpcContext,
WorkerStateResponse,
EmbeddedControlMessageType,
EmbeddedControlMessage,
ConsoleMessage,
ConsoleMessageType,
)
from proto.org.apache.texera.amber.engine.architecture.sendsemantics import (
OneToOnePartitioning,
Partitioning,
)
from proto.org.apache.texera.amber.engine.architecture.worker import (
WorkerMetrics,
WorkerState,
WorkerStatistics,
PortTupleMetricsMapping,
TupleMetrics,
)
from proto.org.apache.texera.amber.engine.common import DirectControlMessagePayloadV2
from pytexera.udf.examples.count_batch_operator import CountBatchOperator
from pytexera.udf.examples.echo_operator import EchoOperator
class TestMainLoop:
@pytest.fixture
def command_sequence(self):
    """Command id given to the control-invocation fixtures below."""
    first_command_id = 1
    return first_command_id
@pytest.fixture
def mock_link(self):
    """Physical link from external port 0 of op "from" to external port 0 of op "to"."""
    source_op = PhysicalOpIdentity(OperatorIdentity("from"), "from")
    target_op = PhysicalOpIdentity(OperatorIdentity("to"), "to")
    return PhysicalLink(
        from_op_id=source_op,
        from_port_id=PortIdentity(0, internal=False),
        to_op_id=target_op,
        to_port_id=PortIdentity(0, internal=False),
    )
@pytest.fixture
def mock_tuple(self):
    """A two-field tuple: a string and an integer."""
    fields = {"test-1": "hello", "test-2": 10}
    return Tuple(fields)
@pytest.fixture
def mock_binary_tuple(self):
    """A tuple whose "test-1" field holds a Python list, used for the BINARY-typed column."""
    fields = {"test-1": [1, 2, 3, 4], "test-2": 10}
    return Tuple(fields)
@pytest.fixture
def mock_batch(self):
    """57 tuples that differ only in "test-2", which runs from 0 to 56."""
    return [Tuple({"test-1": "hello", "test-2": rank}) for rank in range(57)]
@pytest.fixture
def mock_sender_actor(self):
    """Identity of the upstream worker named "sender"."""
    sender_name = "sender"
    return ActorVirtualIdentity(sender_name)
@pytest.fixture
def mock_data_input_channel(self):
    """Data (non-control) channel from "sender" into the worker under test."""
    return ChannelIdentity(
        from_worker_id=ActorVirtualIdentity("sender"),
        to_worker_id=ActorVirtualIdentity("dummy_worker_id"),
        is_control=False,
    )
@pytest.fixture
def mock_data_output_channel(self):
    """Data channel from the worker under test to "dummy_worker_id" (itself)."""
    worker = "dummy_worker_id"
    return ChannelIdentity(
        from_worker_id=ActorVirtualIdentity(worker),
        to_worker_id=ActorVirtualIdentity(worker),
        is_control=False,
    )
@pytest.fixture
def mock_control_input_channel(self):
    """Control channel from the CONTROLLER into the worker under test."""
    return ChannelIdentity(
        from_worker_id=ActorVirtualIdentity("CONTROLLER"),
        to_worker_id=ActorVirtualIdentity("dummy_worker_id"),
        is_control=True,
    )
@pytest.fixture
def mock_control_output_channel(self):
    """Control channel from the worker under test back to the CONTROLLER."""
    return ChannelIdentity(
        from_worker_id=ActorVirtualIdentity("dummy_worker_id"),
        to_worker_id=ActorVirtualIdentity("CONTROLLER"),
        is_control=True,
    )
@pytest.fixture
def mock_receiver_actor(self):
    """Identity of the worker under test, "dummy_worker_id"."""
    receiver_name = "dummy_worker_id"
    return ActorVirtualIdentity(receiver_name)
@pytest.fixture
def mock_data_element(self, mock_tuple, mock_data_input_channel):
    """``mock_tuple`` as a one-row Arrow table on the data input channel."""
    one_row = pandas.DataFrame([mock_tuple.as_dict()])
    table = pyarrow.Table.from_pandas(one_row)
    return DataElement(tag=mock_data_input_channel, payload=DataFrame(frame=table))
@pytest.fixture
def mock_state_data_elements(self, mock_data_input_channel):
    """Four StateFrame elements carrying {"value": 1} through {"value": 4}."""
    return [
        DataElement(
            tag=mock_data_input_channel,
            payload=StateFrame(frame=State({"value": value})),
        )
        for value in (1, 2, 3, 4)
    ]
@pytest.fixture
def state_processing_executor(self):
    """In-process stand-in executor for the state-pipeline tests.

    ``process_state`` copies the incoming state without its ``"schema"``
    entry and tags it with ``processed_marker="executed"`` and the input
    port. ``produce_state_on_finish`` emits a marker state, so that
    EndChannel handling can be observed.
    """

    class StateProcessingExecutor:
        @staticmethod
        def process_tuple(tuple_, port):
            yield tuple_

        @staticmethod
        def process_state(state: State, port: int) -> State:
            kept = {
                key: value for key, value in state.items() if key != "schema"
            }
            processed = State(kept)
            processed["processed_marker"] = "executed"
            processed["port"] = port
            return processed

        @staticmethod
        def produce_state_on_finish(port: int) -> State:
            return State({"finish_marker": "produce_state_on_finish_ran"})

        @staticmethod
        def on_finish(port):
            # Generator that yields a single None.
            yield None

        @staticmethod
        def close():
            return None

    return StateProcessingExecutor()
@pytest.fixture
def mock_binary_data_element(self, mock_binary_tuple, mock_data_input_channel):
    """``mock_binary_tuple`` as a one-row Arrow table on the data input channel."""
    one_row = pandas.DataFrame([mock_binary_tuple.as_dict()])
    table = pyarrow.Table.from_pandas(one_row)
    return DataElement(tag=mock_data_input_channel, payload=DataFrame(frame=table))
@pytest.fixture
def mock_batch_data_elements(self, mock_batch, mock_data_input_channel):
    """Wrap each tuple of ``mock_batch`` in its own single-row DataElement.

    The elements are derived from ``mock_batch`` rather than rebuilt by hand
    with a second hard-coded ``range(57)``. This keeps the inputs in step with
    the expected values the batch tests compare against (e.g. in
    ``check_batch_rank_sum``). Previously the ``mock_batch`` parameter was
    requested but never used.

    :param mock_batch: tuples to wrap, in order.
    :param mock_data_input_channel: channel tag put on every element.
    :return: one DataElement per tuple, each holding a one-row Arrow table.
    """
    return [
        DataElement(
            tag=mock_data_input_channel,
            payload=DataFrame(
                frame=pyarrow.Table.from_pandas(
                    pandas.DataFrame([batch_tuple.as_dict()])
                )
            ),
        )
        for batch_tuple in mock_batch
    ]
@pytest.fixture
def mock_end_of_upstream(self, mock_tuple, mock_data_input_channel):
    """EndChannel ECM (PORT_ALIGNMENT) on the data input channel.

    Its per-worker command map holds one EndChannel invocation, keyed by the
    receiving worker's name, with command id -1.
    """
    end_channel_call = ControlInvocation(
        "EndChannel",
        ControlRequest(empty_request=EmptyRequest()),
        AsyncRpcContext(ActorVirtualIdentity(), ActorVirtualIdentity()),
        -1,
    )
    receiver_name = mock_data_input_channel.to_worker_id.name
    ecm = EmbeddedControlMessage(
        EmbeddedControlMessageIdentity("EndChannel"),
        EmbeddedControlMessageType.PORT_ALIGNMENT,
        [],
        {receiver_name: end_channel_call},
    )
    return ECMElement(tag=mock_data_input_channel, payload=ecm)
@pytest.fixture
def input_queue(self):
    """Fresh InternalQueue that feeds the main loop."""
    queue_in = InternalQueue()
    return queue_in
@pytest.fixture
def output_queue(self):
    """Fresh InternalQueue that collects everything the main loop emits."""
    queue_out = InternalQueue()
    return queue_out
@pytest.fixture
def mock_assign_input_port(
    self, mock_raw_schema, mock_control_input_channel, mock_link, command_sequence
):
    """AssignPort declaring the link's target port as an input with ``mock_raw_schema``."""
    request = AssignPortRequest(
        port_id=mock_link.to_port_id, input=True, schema=mock_raw_schema
    )
    invocation = ControlInvocation(
        method_name="AssignPort",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_assign_output_port(
    self, mock_raw_schema, mock_control_input_channel, command_sequence
):
    """AssignPort declaring output port 0 with ``mock_raw_schema``."""
    request = AssignPortRequest(
        port_id=PortIdentity(id=0), input=False, schema=mock_raw_schema
    )
    invocation = ControlInvocation(
        method_name="AssignPort",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_assign_input_port_binary(
    self,
    mock_binary_raw_schema,
    mock_control_input_channel,
    mock_link,
    command_sequence,
):
    """AssignPort declaring the link's target port as an input with the BINARY schema."""
    request = AssignPortRequest(
        port_id=mock_link.to_port_id, input=True, schema=mock_binary_raw_schema
    )
    invocation = ControlInvocation(
        method_name="AssignPort",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_assign_output_port_binary(
    self, mock_binary_raw_schema, mock_control_input_channel, command_sequence
):
    """AssignPort declaring output port 0 with the BINARY schema."""
    request = AssignPortRequest(
        port_id=PortIdentity(id=0), input=False, schema=mock_binary_raw_schema
    )
    invocation = ControlInvocation(
        method_name="AssignPort",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_add_input_channel(
    self,
    mock_control_input_channel,
    mock_sender_actor,
    mock_receiver_actor,
    mock_link,
    command_sequence,
):
    """AddInputChannel registering the sender->receiver data channel on the link's target port."""
    data_channel = ChannelIdentity(
        from_worker_id=mock_sender_actor,
        to_worker_id=mock_receiver_actor,
        is_control=False,
    )
    request = AddInputChannelRequest(data_channel, port_id=mock_link.to_port_id)
    invocation = ControlInvocation(
        method_name="AddInputChannel",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_raw_schema(self):
    """Raw attribute-name -> type-name schema: "test-1" STRING, "test-2" INTEGER."""
    return dict([("test-1", "STRING"), ("test-2", "INTEGER")])
@pytest.fixture
def mock_binary_raw_schema(self):
    """Raw attribute-name -> type-name schema: "test-1" BINARY, "test-2" INTEGER."""
    return dict([("test-1", "BINARY"), ("test-2", "INTEGER")])
@pytest.fixture
def mock_initialize_executor(
    self,
    mock_control_input_channel,
    mock_sender_actor,
    mock_link,
    command_sequence,
    mock_raw_schema,
):
    """InitializeExecutor that ships EchoOperator's source as Python UDF code."""
    operator_code = f"from pytexera import *\n{inspect.getsource(EchoOperator)}"
    request = InitializeExecutorRequest(
        op_exec_init_info=set_one_of(
            OpExecInitInfo, OpExecWithCode(operator_code, "python")
        ),
        is_source=False,
    )
    invocation = ControlInvocation(
        method_name="InitializeExecutor",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_initialize_batch_count_executor(
    self,
    mock_control_input_channel,
    mock_sender_actor,
    mock_link,
    command_sequence,
    mock_raw_schema,
):
    """InitializeExecutor that ships CountBatchOperator's source as Python UDF code."""
    operator_code = (
        f"from pytexera import *\n{inspect.getsource(CountBatchOperator)}"
    )
    request = InitializeExecutorRequest(
        op_exec_init_info=set_one_of(
            OpExecInitInfo, OpExecWithCode(operator_code, "python")
        ),
        is_source=False,
    )
    invocation = ControlInvocation(
        method_name="InitializeExecutor",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_add_partitioning(
    self,
    mock_control_input_channel,
    mock_receiver_actor,
    command_sequence,
    mock_link,
):
    """AddPartitioning for ``mock_link``.

    Uses one-to-one partitioning with batch size 1 and a single data channel
    from "dummy_worker_id" to the receiver actor.
    """
    downstream = ChannelIdentity(
        from_worker_id=ActorVirtualIdentity("dummy_worker_id"),
        to_worker_id=mock_receiver_actor,
        is_control=False,
    )
    partitioning = set_one_of(
        Partitioning, OneToOnePartitioning(batch_size=1, channels=[downstream])
    )
    request = AddPartitioningRequest(tag=mock_link, partitioning=partitioning)
    invocation = ControlInvocation(
        method_name="AddPartitioning",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, request),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_query_statistics(
    self, mock_control_input_channel, mock_sender_actor, command_sequence
):
    """QueryStatistics control invocation carrying an empty request."""
    invocation = ControlInvocation(
        method_name="QueryStatistics",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, EmptyRequest()),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_pause(
    self, mock_control_input_channel, mock_sender_actor, command_sequence
):
    """PauseWorker control invocation carrying an empty request."""
    invocation = ControlInvocation(
        method_name="PauseWorker",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, EmptyRequest()),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def mock_resume(
    self, mock_control_input_channel, mock_sender_actor, command_sequence
):
    """ResumeWorker control invocation carrying an empty request."""
    invocation = ControlInvocation(
        method_name="ResumeWorker",
        command_id=command_sequence,
        command=set_one_of(ControlRequest, EmptyRequest()),
    )
    return DCMElement(
        tag=mock_control_input_channel,
        payload=set_one_of(DirectControlMessagePayloadV2, invocation),
    )
@pytest.fixture
def main_loop(self, input_queue, output_queue, mock_link):
    """MainLoop for worker "dummy_worker_id"; ``stop()`` is called at teardown."""
    loop = MainLoop("dummy_worker_id", input_queue, output_queue)
    yield loop
    loop.stop()
@pytest.fixture
def main_loop_thread(self, main_loop, reraise):
    """Unstarted thread named "main_loop_thread" that runs ``main_loop.run()``.

    The run is wrapped in ``with reraise:``, so an exception raised inside the
    thread is captured and surfaces when the test calls ``reraise()``.
    """

    def run_main_loop():
        with reraise:
            main_loop.run()

    yield Thread(target=run_main_loop, name="main_loop_thread")
@staticmethod
def check_batch_rank_sum(
    executor,
    input_queue,
    mock_batch_data_elements,
    output_data_elements,
    output_queue,
    mock_batch,
    start,
    end,
    count,
):
    """Push one slice of the batch through the loop and verify the result.

    Puts ``mock_batch_data_elements[start:end]`` on ``input_queue``, then
    drains ``end - start`` elements from ``output_queue``, appending each to
    ``output_data_elements``. Finally it asserts that ``executor.count ==
    count`` and that the ``"test-2"`` values of the drained outputs sum to
    the same total as ``mock_batch[start:end]``.

    :param executor: executor under test; only its ``count`` attribute is read.
    :param output_data_elements: caller-owned list the drained outputs are
        appended to; it may already hold any number of earlier outputs.
    :param start: first index of the slice (inclusive).
    :param end: last index of the slice (exclusive).
    :param count: expected value of ``executor.count`` afterwards.
    :raises AssertionError: if either check fails.
    """
    for index in range(start, end):
        input_queue.put(mock_batch_data_elements[index])

    rank_sum_real = 0
    rank_sum_suppose = 0
    for index in range(start, end):
        output_element = output_queue.get()
        output_data_elements.append(output_element)
        # Read from the element just received, not from
        # ``output_data_elements[index]``: that index only lines up when the
        # list held exactly ``start`` items beforehand.
        # Each output frame holds one row. ``to_pylist()`` yields rows as
        # dicts, whereas indexing the Arrow table (``frame[0]``) selects a
        # column.
        rank_sum_real += output_element.payload.frame.to_pylist()[0]["test-2"]
        rank_sum_suppose += mock_batch[index]["test-2"]

    assert executor.count == count
    assert rank_sum_real == rank_sum_suppose
@pytest.mark.timeout(2)
def test_main_loop_thread_can_start(self, main_loop_thread):
    """The main-loop thread can be started and is alive right afterwards."""
    main_loop_thread.start()
    is_running = main_loop_thread.is_alive()
    assert is_running
@pytest.mark.timeout(2)
def test_main_loop_thread_can_process_messages(
self,
mock_link,
mock_data_input_channel,
mock_data_output_channel,
mock_control_input_channel,
mock_control_output_channel,
input_queue,
output_queue,
mock_data_element,
main_loop_thread,
mock_assign_input_port,
mock_assign_output_port,
mock_add_input_channel,
mock_add_partitioning,
mock_initialize_executor,
mock_end_of_upstream,
mock_query_statistics,
mock_tuple,
command_sequence,
reraise,
):
# End-to-end happy path on a live main-loop thread. It sends the set-up
# control commands, one data tuple, QueryStatistics and then EndChannel,
# and checks the exact message the worker writes to output_queue at each
# step, in order.
main_loop_thread.start()
# can process AssignPort
# Each set-up command is answered with a ReturnInvocation that carries an
# EmptyReturn and echoes the command id.
input_queue.put(mock_assign_input_port)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
input_queue.put(mock_assign_output_port)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process AddInputChannel
input_queue.put(mock_add_input_channel)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process AddPartitioning
input_queue.put(mock_add_partitioning)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process InitializeExecutor
input_queue.put(mock_initialize_executor)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process a DataFrame
# The executor was initialised from EchoOperator's source, so the single
# input tuple must come back unchanged on the data output channel.
input_queue.put(mock_data_element)
output_data_element: DataElement = output_queue.get()
assert output_data_element.tag == mock_data_output_channel
assert isinstance(output_data_element.payload, DataFrame)
data_frame: DataFrame = output_data_element.payload
assert len(data_frame.frame) == 1
assert Tuple(data_frame.frame.to_pylist()[0]) == mock_tuple
# can process QueryStatistics
# Only the worker state and the tuple counts (1 in and 1 out, on port 0)
# are pinned. The size and timing fields are copied from the actual
# response, presumably because they are not deterministic.
input_queue.put(mock_query_statistics)
elem = output_queue.get()
stats_invocation = elem.payload.return_invocation
worker_metrics_response = stats_invocation.return_value.worker_metrics_response
stats = worker_metrics_response.metrics.worker_statistics
metrics = WorkerMetrics(
worker_state=WorkerState.RUNNING,
worker_statistics=WorkerStatistics(
input_tuple_metrics=[
PortTupleMetricsMapping(
PortIdentity(0),
TupleMetrics(
1,
stats.input_tuple_metrics[0].tuple_metrics.size,
),
)
],
output_tuple_metrics=[
PortTupleMetricsMapping(
PortIdentity(0),
TupleMetrics(
1,
stats.output_tuple_metrics[0].tuple_metrics.size,
),
)
],
data_processing_time=stats.data_processing_time,
control_processing_time=stats.control_processing_time,
idle_time=stats.idle_time,
),
)
assert elem == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=1,
return_value=ControlReturn(
worker_metrics_response=WorkerMetricsResponse(metrics=metrics),
),
),
),
)
# NOTE(review): disable_data(DISABLE_BY_PAUSE) appears to hold back
# data-channel output. That is why the three control messages below are
# read first, and the forwarded EndChannel ECM is read only after
# enable_data(). Confirm against InternalQueue.
input_queue.put(mock_end_of_upstream)
output_queue.disable_data(InternalQueue.DisableType.DISABLE_BY_PAUSE)
# The worker's own outgoing invocations are numbered 0, 1, 2 in order:
# PortCompleted(input), PortCompleted(output), WorkerExecutionCompleted.
# the input port should complete
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
control_invocation=ControlInvocation(
method_name="PortCompleted",
command_id=0,
context=AsyncRpcContext(
sender=ActorVirtualIdentity(name="dummy_worker_id"),
receiver=ActorVirtualIdentity(name="CONTROLLER"),
),
command=ControlRequest(
port_completed_request=PortCompletedRequest(
port_id=mock_link.to_port_id, input=True
)
),
)
),
)
# the output port should complete
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
control_invocation=ControlInvocation(
method_name="PortCompleted",
command_id=1,
context=AsyncRpcContext(
sender=ActorVirtualIdentity(name="dummy_worker_id"),
receiver=ActorVirtualIdentity(name="CONTROLLER"),
),
command=ControlRequest(
port_completed_request=PortCompletedRequest(
port_id=PortIdentity(id=0), input=False
)
),
)
),
)
# WorkerExecutionCompletedV2 should be triggered when workflow finishes
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
control_invocation=ControlInvocation(
method_name="WorkerExecutionCompleted",
command_id=2,
context=AsyncRpcContext(
sender=ActorVirtualIdentity(name="dummy_worker_id"),
receiver=ActorVirtualIdentity(name="CONTROLLER"),
),
command=ControlRequest(empty_request=EmptyRequest()),
)
),
)
# Once data output is re-enabled, the EndChannel ECM appears on the data
# output channel, keyed by the downstream worker's name.
output_queue.enable_data(InternalQueue.DisableType.DISABLE_BY_PAUSE)
assert output_queue.get() == ECMElement(
tag=mock_data_output_channel,
payload=EmbeddedControlMessage(
EmbeddedControlMessageIdentity("EndChannel"),
EmbeddedControlMessageType.PORT_ALIGNMENT,
[],
{
mock_data_output_channel.to_worker_id.name: ControlInvocation(
"EndChannel",
ControlRequest(empty_request=EmptyRequest()),
AsyncRpcContext(ActorVirtualIdentity(), ActorVirtualIdentity()),
-1,
)
},
),
)
# can process ReturnInvocation
# Feed back a ReturnInvocation for command 0 (presumably the reply to the
# PortCompleted call above). reraise() then re-raises any exception that
# was captured inside the main-loop thread.
input_queue.put(
DCMElement(
tag=mock_control_input_channel,
payload=set_one_of(
DirectControlMessagePayloadV2,
ReturnInvocation(
command_id=0,
return_value=ControlReturn(empty_return=EmptyReturn()),
),
),
)
)
reraise()
@pytest.mark.timeout(5)
def test_batch_dp_thread_can_process_batch(
self,
mock_control_input_channel,
mock_control_output_channel,
mock_data_input_channel,
mock_data_output_channel,
mock_link,
input_queue,
output_queue,
mock_receiver_actor,
main_loop,
main_loop_thread,
mock_query_statistics,
mock_assign_input_port,
mock_assign_output_port,
mock_add_input_channel,
mock_add_partitioning,
mock_pause,
mock_resume,
mock_initialize_batch_count_executor,
mock_batch,
mock_batch_data_elements,
mock_end_of_upstream,
command_sequence,
reraise,
):
# Runs the 57 batch tuples through CountBatchOperator while changing
# executor.BATCH_SIZE between pause/resume pairs. After each pause it checks
# executor.count, which appears to count completed batches (TODO confirm
# against CountBatchOperator).
main_loop_thread.start()
# can process AssignPort
input_queue.put(mock_assign_input_port)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
input_queue.put(mock_assign_output_port)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process AddInputChannel
input_queue.put(mock_add_input_channel)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process AddPartitioning
input_queue.put(mock_add_partitioning)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
# can process InitializeExecutor
input_queue.put(mock_initialize_batch_count_executor)
assert output_queue.get() == DCMElement(
tag=mock_control_output_channel,
payload=DirectControlMessagePayloadV2(
return_invocation=ReturnInvocation(
command_id=command_sequence,
return_value=ControlReturn(empty_return=EmptyReturn()),
)
),
)
executor = main_loop.context.executor_manager.executor
output_data_elements = []
# can process a DataFrame
# Phase 1: BATCH_SIZE 10, 13 inputs -> one full batch of 10 is emitted,
# and 3 tuples stay buffered.
executor.BATCH_SIZE = 10
for i in range(13):
input_queue.put(mock_batch_data_elements[i])
for i in range(10):
output_data_elements.append(output_queue.get())
self.send_pause(
command_sequence,
input_queue,
mock_control_output_channel,
mock_pause,
output_queue,
)
# input queue 13, output queue 10, batch_buffer 3
assert executor.count == 1
# BATCH_SIZE is changed only while the worker is paused.
executor.BATCH_SIZE = 20
self.send_resume(
command_sequence,
input_queue,
mock_control_output_channel,
mock_resume,
output_queue,
)
# Phase 2: BATCH_SIZE 20, 3 buffered + 28 new = 31 -> one batch of 20,
# and 11 stay buffered.
for i in range(13, 41):
input_queue.put(mock_batch_data_elements[i])
for i in range(20):
output_data_elements.append(output_queue.get())
self.send_pause(
command_sequence,
input_queue,
mock_control_output_channel,
mock_pause,
output_queue,
)
# input queue 41, output queue 30, batch_buffer 11
assert executor.count == 2
executor.BATCH_SIZE = 5
self.send_resume(
command_sequence,
input_queue,
mock_control_output_channel,
mock_resume,
output_queue,
)
# Phase 3: BATCH_SIZE 5, 11 buffered + 2 new = 13 -> two batches of 5,
# and 3 stay buffered.
input_queue.put(mock_batch_data_elements[41])
input_queue.put(mock_batch_data_elements[42])
for i in range(10):
output_data_elements.append(output_queue.get())
self.send_pause(
command_sequence,
input_queue,
mock_control_output_channel,
mock_pause,
output_queue,
)
# input queue 43, output queue 40, batch_buffer 3
assert executor.count == 4
self.send_resume(
command_sequence,
input_queue,
mock_control_output_channel,
mock_resume,
output_queue,
)
# Phase 4: BATCH_SIZE still 5, 3 buffered + 14 new = 17 -> three batches
# of 5, and 2 stay buffered.
for i in range(43, 57):
input_queue.put(mock_batch_data_elements[i])
for i in range(15):
output_data_elements.append(output_queue.get())
self.send_pause(
command_sequence,
input_queue,
mock_control_output_channel,
mock_pause,
output_queue,
)
# input queue 57, output queue 55, batch_buffer 2
assert executor.count == 7
self.send_resume(
command_sequence,
input_queue,
mock_control_output_channel,
mock_resume,
output_queue,
)
# End of upstream flushes the 2 remaining buffered tuples as a final
# (partial) batch, which makes 8 batches in total.
input_queue.put(mock_end_of_upstream)
for i in range(2):
output_data_elements.append(output_queue.get())
# check the batch count
assert main_loop.context.executor_manager.executor.count == 8
# The first emitted element is the first batch tuple, unchanged, on the
# data output channel.
assert output_data_elements[0].tag == mock_data_output_channel
assert isinstance(output_data_elements[0].payload, DataFrame)
data_frame: DataFrame = output_data_elements[0].payload
assert len(data_frame.frame) == 1
assert Tuple(data_frame.frame.to_pylist()[0]) == Tuple(mock_batch[0])
reraise()
@pytest.mark.timeout(5)
def test_main_loop_thread_can_process_single_tuple_with_binary(
    self,
    mock_link,
    mock_data_input_channel,
    mock_data_output_channel,
    mock_control_output_channel,
    mock_control_input_channel,
    input_queue,
    output_queue,
    mock_binary_tuple,
    mock_binary_data_element,
    main_loop_thread,
    mock_assign_input_port_binary,
    mock_assign_output_port_binary,
    mock_add_input_channel,
    mock_add_partitioning,
    mock_initialize_executor,
    mock_end_of_upstream,
    mock_query_statistics,
    command_sequence,
    reraise,
):
    """A BINARY-typed field passes through the echo executor and is emitted
    as ``b"pickle " + pickle.dumps(value)``."""

    def expect_empty_return():
        # Every set-up command is answered with an EmptyReturn that echoes
        # the command id.
        received = output_queue.get()
        assert received == DCMElement(
            tag=mock_control_output_channel,
            payload=DirectControlMessagePayloadV2(
                return_invocation=ReturnInvocation(
                    command_id=command_sequence,
                    return_value=ControlReturn(empty_return=EmptyReturn()),
                )
            ),
        )

    main_loop_thread.start()
    # Send each set-up command and wait for its acknowledgement before the next.
    for setup_command in (
        mock_assign_input_port_binary,
        mock_assign_output_port_binary,
        mock_add_input_channel,
        mock_add_partitioning,
        mock_initialize_executor,
    ):
        input_queue.put(setup_command)
        expect_empty_return()

    input_queue.put(mock_binary_data_element)
    emitted: DataElement = output_queue.get()
    assert emitted.tag == mock_data_output_channel
    assert isinstance(emitted.payload, DataFrame)
    frame = emitted.payload.frame
    assert len(frame) == 1
    expected_bytes = b"pickle " + pickle.dumps(mock_binary_tuple["test-1"])
    assert frame.to_pylist()[0]["test-1"] == expected_bytes
    reraise()
@staticmethod
def send_pause(
    command_sequence,
    input_queue,
    mock_control_output_channel,
    mock_pause,
    output_queue,
):
    """Send PauseWorker and assert that the worker replies with state PAUSED."""
    input_queue.put(mock_pause)
    reply = output_queue.get()
    paused = ControlReturn(
        worker_state_response=WorkerStateResponse(WorkerState.PAUSED)
    )
    assert reply == DCMElement(
        tag=mock_control_output_channel,
        payload=DirectControlMessagePayloadV2(
            return_invocation=ReturnInvocation(
                command_id=command_sequence, return_value=paused
            )
        ),
    )
@staticmethod
def send_resume(
    command_sequence,
    input_queue,
    mock_control_output_channel,
    mock_resume,
    output_queue,
):
    """Send ResumeWorker and assert that the worker replies with state RUNNING."""
    input_queue.put(mock_resume)
    reply = output_queue.get()
    running = ControlReturn(
        worker_state_response=WorkerStateResponse(WorkerState.RUNNING)
    )
    assert reply == DCMElement(
        tag=mock_control_output_channel,
        payload=DirectControlMessagePayloadV2(
            return_invocation=ReturnInvocation(
                command_id=command_sequence, return_value=running
            )
        ),
    )
@pytest.mark.timeout(2)
def test_process_state_can_emit_consecutive_states(
self,
main_loop,
output_queue,
mock_data_output_channel,
monkeypatch,
):
# Calls MainLoop._process_state directly (no main-loop thread) twice in a
# row. Each call must put its own StateFrame on the output queue, in order,
# carrying value + 1 and port 0.
class DummyExecutor:
@staticmethod
def process_state(state, port: int):
return State({"value": state["value"] + 1, "port": port})
main_loop.context.executor_manager.executor = DummyExecutor()
# Stub out control processing, and make emit_state route every state to
# the data output channel, so that only the state path is exercised.
monkeypatch.setattr(main_loop, "_check_and_process_control", lambda: None)
monkeypatch.setattr(
main_loop.context.output_manager,
"emit_state",
lambda state: [(mock_data_output_channel.to_worker_id, StateFrame(state))],
)
# Replacement for _switch_context: synchronously run the dummy executor on
# the pending input state (with port 0) and store the result as the
# current output state.
def fake_switch_context():
current_input_state = (
main_loop.context.state_processing_manager.current_input_state
)
if current_input_state is not None:
main_loop.context.state_processing_manager.current_output_state = (
DummyExecutor.process_state(current_input_state, 0)
)
monkeypatch.setattr(main_loop, "_switch_context", fake_switch_context)
first_state = State({"value": 1})
second_state = State({"value": 41})
main_loop._process_state(first_state)
main_loop._process_state(second_state)
first_output: DataElement = output_queue.get()
second_output: DataElement = output_queue.get()
assert first_output.tag == mock_data_output_channel
assert isinstance(first_output.payload, StateFrame)
assert first_output.payload.frame["value"] == 2
assert first_output.payload.frame["port"] == 0
assert second_output.tag == mock_data_output_channel
assert isinstance(second_output.payload, StateFrame)
assert second_output.payload.frame["value"] == 42
assert second_output.payload.frame["port"] == 0
@pytest.mark.timeout(5)
def test_main_loop_thread_can_align_ecm(
    self,
    mock_link,
    mock_data_input_channel,
    mock_data_output_channel,
    mock_control_output_channel,
    mock_control_input_channel,
    input_queue,
    output_queue,
    mock_binary_tuple,
    mock_binary_data_element,
    main_loop_thread,
    mock_assign_input_port_binary,
    mock_assign_output_port_binary,
    mock_add_input_channel,
    mock_add_partitioning,
    mock_initialize_executor,
    mock_end_of_upstream,
    mock_query_statistics,
    command_sequence,
    reraise,
):
    # Sends an ALL_ALIGNMENT ECM on both the control and the data input
    # channel, with a binary data element queued between the two copies.
    # Checks two things: the NoOperation reply (control) and the forwarded
    # data frame both come out, and the control reply is popped first.
    main_loop_thread.start()

    # Setup RPCs: AssignPort (input and output), AddInputChannel,
    # AddPartitioning, InitializeExecutor. Each must be acknowledged with an
    # EmptyReturn on the control output channel before the next is sent.
    for setup_msg in [
        mock_assign_input_port_binary,
        mock_assign_output_port_binary,
        mock_add_input_channel,
        mock_add_partitioning,
        mock_initialize_executor,
    ]:
        input_queue.put(setup_msg)
        assert output_queue.get() == DCMElement(
            tag=mock_control_output_channel,
            payload=DirectControlMessagePayloadV2(
                return_invocation=ReturnInvocation(
                    command_id=command_sequence,
                    return_value=ControlReturn(empty_return=EmptyReturn()),
                )
            ),
        )

    scope = [mock_control_input_channel, mock_data_input_channel]
    command_mapping = {
        mock_control_input_channel.to_worker_id.name: ControlInvocation(
            "NoOperation", EmptyRequest(), AsyncRpcContext(), 98
        )
    }
    test_ecm = EmbeddedControlMessage(
        "test_ecm", EmbeddedControlMessageType.ALL_ALIGNMENT, scope, command_mapping
    )
    input_queue.put(ECMElement(tag=mock_control_input_channel, payload=test_ecm))
    input_queue.put(mock_binary_data_element)
    input_queue.put(ECMElement(tag=mock_data_input_channel, payload=test_ecm))

    # The two outputs land on different channel sub-queues:
    # - DataElement on the data channel to the downstream worker
    # - DCMElement (NoOperation reply) on the control channel back to "sender"
    # output_queue is a priority multi-queue. With both items present, the
    # control sub-queue (priority 1) outranks the data sub-queue
    # (priority 2), so the control reply must come out first. Wait for both
    # channels to hold their item before popping, so that the priority
    # guarantee is what is actually tested (see #4524).
    control_reply_channel = ChannelIdentity(
        ActorVirtualIdentity("dummy_worker_id"),
        ActorVirtualIdentity("sender"),
        is_control=True,
    )

    def channel_size(channel: ChannelIdentity) -> int:
        # Sub-queues are added lazily on first put, so the channel may not
        # exist in the LBMQ yet. Treat that as size zero.
        if channel not in output_queue._queue.sub_queues:
            return 0
        return output_queue._queue.size(channel)

    # The poll budget must stay below the @pytest.mark.timeout(5) above.
    # Otherwise pytest-timeout kills the test first and the diagnostic
    # AssertionError below (with per-channel sizes) is never reported.
    # monotonic() is immune to wall-clock adjustments during the wait.
    wait_budget_s = 3.0
    deadline = time.monotonic() + wait_budget_s
    while (
        channel_size(mock_data_output_channel) == 0
        or channel_size(control_reply_channel) == 0
    ):
        if time.monotonic() > deadline:
            raise AssertionError(
                f"timed out waiting for outputs on both channels; "
                f"data={channel_size(mock_data_output_channel)}, "
                f"control={channel_size(control_reply_channel)}"
            )
        time.sleep(0.001)

    # Priority pulls control before data when both are queued.
    output_control_element = output_queue.get()
    assert isinstance(output_control_element, DCMElement), (
        f"expected control reply first (priority), got {type(output_control_element).__name__}"
    )
    assert output_control_element.tag == control_reply_channel
    assert output_control_element.payload.return_invocation.command_id == 98
    assert (
        output_control_element.payload.return_invocation.return_value
        == ControlReturn(empty_return=EmptyReturn())
    )

    output_data_element = output_queue.get()
    assert isinstance(output_data_element, DataElement), (
        f"expected data element second, got {type(output_data_element).__name__}"
    )
    assert output_data_element.tag == mock_data_output_channel
    assert isinstance(output_data_element.payload, DataFrame)
    data_frame: DataFrame = output_data_element.payload
    assert len(data_frame.frame) == 1
    assert data_frame.frame.to_pylist()[0][
        "test-1"
    ] == b"pickle " + pickle.dumps(mock_binary_tuple["test-1"])
    reraise()
@pytest.mark.timeout(2)
def test_process_state_can_emit_multiple_states(
    self,
    main_loop,
    output_queue,
    mock_data_output_channel,
    monkeypatch,
):
    # Stub-level coverage of the single-switch state handshake over a longer
    # run of states. Every queued state must produce exactly one StateFrame,
    # in input order.
    # Each call to the stubbed _switch_context simulates DataProc consuming
    # the queued input state and writing current_output_state. This mirrors
    # what the real DataProc.process_state does between MainLoop's switches.
    class DummyExecutor:
        @staticmethod
        def process_state(state: State, port: int) -> State:
            return State({"value": state["value"] + 1, "port": port})

    main_loop.context.executor_manager.executor = DummyExecutor()
    monkeypatch.setattr(main_loop, "_check_and_process_control", lambda: None)
    monkeypatch.setattr(
        main_loop.context.output_manager,
        "emit_state",
        lambda state: [(mock_data_output_channel.to_worker_id, StateFrame(state))],
    )

    def fake_switch_context():
        current_input_state = (
            main_loop.context.state_processing_manager.current_input_state
        )
        if current_input_state is not None:
            main_loop.context.state_processing_manager.current_output_state = (
                DummyExecutor.process_state(current_input_state, 0)
            )

    monkeypatch.setattr(main_loop, "_switch_context", fake_switch_context)

    # More than two inputs. A separate test already covers exactly two, so
    # this one catches issues that only show up on longer runs (e.g. stale
    # state carried across several switches).
    input_values = (1, 41, 99, 7)
    for value in input_values:
        main_loop._process_state(State({"value": value}))

    for value in input_values:
        output: DataElement = output_queue.get()
        assert output.tag == mock_data_output_channel
        assert isinstance(output.payload, StateFrame)
        assert output.payload.frame["value"] == value + 1, (
            f"expected value={value + 1} for input {value}, "
            f"got {output.payload.frame['value']}"
        )
        assert output.payload.frame["port"] == 0
@pytest.mark.timeout(2)
def test_main_loop_thread_can_process_state(
    self,
    mock_data_output_channel,
    mock_control_output_channel,
    input_queue,
    output_queue,
    main_loop,
    main_loop_thread,
    mock_assign_input_port,
    mock_assign_output_port,
    mock_add_input_channel,
    mock_add_partitioning,
    mock_initialize_executor,
    mock_state_data_elements,
    mock_end_of_upstream,
    state_processing_executor,
    command_sequence,
    reraise,
):
    # End-to-end check of the state-processing path through the real MainLoop
    # and DataProcessor threads.
    # - With the single-switch state handshake in
    #   MainLoop.process_input_state, every state is emitted in its own cycle
    #   (no lag).
    # - An EndChannel ECM after the last state produces one extra output via
    #   produce_state_on_finish.
    main_loop_thread.start()

    empty_ack = DCMElement(
        tag=mock_control_output_channel,
        payload=DirectControlMessagePayloadV2(
            return_invocation=ReturnInvocation(
                command_id=command_sequence,
                return_value=ControlReturn(empty_return=EmptyReturn()),
            )
        ),
    )
    setup_messages = (
        mock_assign_input_port,
        mock_assign_output_port,
        mock_add_input_channel,
        mock_add_partitioning,
        mock_initialize_executor,
    )
    for message in setup_messages:
        input_queue.put(message)
        assert output_queue.get() == empty_ack

    # The InitializeExecutor RPC above prepared the rest of the worker state
    # (output schema, partitioning bookkeeping). Only the executor instance
    # is swapped for the test helper. That way the assertions can observe
    # that process_state and produce_state_on_finish really ran, without
    # relying on cross-test module caching of OpExecWithCode classes.
    main_loop.context.executor_manager.executor = state_processing_executor

    # Four states in; with the lag-free pipeline, four outputs out, in order.
    for element in mock_state_data_elements:
        input_queue.put(element)
    for expected_value in range(1, 5):
        emitted: DataElement = output_queue.get()
        assert emitted.tag == mock_data_output_channel
        assert isinstance(emitted.payload, StateFrame), (
            f"expected StateFrame for value={expected_value}, got "
            f"{type(emitted.payload).__name__}"
        )
        frame = emitted.payload.frame
        assert frame["value"] == expected_value, (
            f"state outputs arrived out of order: expected value="
            f"{expected_value}, got value={frame['value']}"
        )
        assert frame["processed_marker"] == "executed"
        assert frame["port"] == 0

    # EndChannel drives _process_end_channel. The executor's
    # produce_state_on_finish writes a finish-marker state into
    # current_output_state inside DataProc's process_internal_marker, and
    # MainLoop's process_input_state then emits it.
    input_queue.put(mock_end_of_upstream)

    # Pause data delivery on the output queue so the control replies can be
    # drained first. The next get() then yields the post-EndChannel data
    # emission.
    pause_kind = InternalQueue.DisableType.DISABLE_BY_PAUSE
    output_queue.disable_data(pause_kind)
    for _ in range(3):
        reply = output_queue.get()
        assert isinstance(reply, DCMElement), (
            f"expected DCMElement during EndChannel teardown, got "
            f"{type(reply).__name__}"
        )
    output_queue.enable_data(pause_kind)

    finish_output: DataElement = output_queue.get()
    assert finish_output.tag == mock_data_output_channel
    assert isinstance(finish_output.payload, StateFrame), (
        f"expected StateFrame for the EndChannel-driven emission, got "
        f"{type(finish_output.payload).__name__}"
    )
    finish_state = finish_output.payload.frame
    assert "finish_marker" in finish_state, (
        f"EndChannel emission should be the finish-marker state from "
        f"produce_state_on_finish, got {finish_state!r}"
    )
    assert finish_state["finish_marker"] == "produce_state_on_finish_ran"
    reraise()
@pytest.mark.timeout(2)
def test_main_loop_thread_can_process_state_after_tuple(
    self,
    mock_data_output_channel,
    mock_control_output_channel,
    input_queue,
    output_queue,
    main_loop,
    main_loop_thread,
    mock_assign_input_port,
    mock_assign_output_port,
    mock_add_input_channel,
    mock_add_partitioning,
    mock_initialize_executor,
    mock_data_element,
    mock_state_data_elements,
    state_processing_executor,
    command_sequence,
    reraise,
):
    # Mixed input sequence: one tuple, then several state DataElements.
    # Every state's processed output must still be emitted, in order.
    main_loop_thread.start()

    empty_ack = DCMElement(
        tag=mock_control_output_channel,
        payload=DirectControlMessagePayloadV2(
            return_invocation=ReturnInvocation(
                command_id=command_sequence,
                return_value=ControlReturn(empty_return=EmptyReturn()),
            )
        ),
    )
    for message in (
        mock_assign_input_port,
        mock_assign_output_port,
        mock_add_input_channel,
        mock_add_partitioning,
        mock_initialize_executor,
    ):
        input_queue.put(message)
        assert output_queue.get() == empty_ack

    main_loop.context.executor_manager.executor = state_processing_executor

    # Tuple first: wait for its DataFrame output before sending any state.
    input_queue.put(mock_data_element)
    tuple_output: DataElement = output_queue.get()
    assert tuple_output.tag == mock_data_output_channel
    assert isinstance(tuple_output.payload, DataFrame)

    # Then four states, whose outputs must keep input order.
    for element in mock_state_data_elements:
        input_queue.put(element)
    for expected_value in range(1, 5):
        emitted: DataElement = output_queue.get()
        assert emitted.tag == mock_data_output_channel
        assert isinstance(emitted.payload, StateFrame), (
            f"expected StateFrame for value={expected_value}, got "
            f"{type(emitted.payload).__name__}"
        )
        frame = emitted.payload.frame
        assert frame["value"] == expected_value, (
            f"state outputs after a tuple arrived out of order: "
            f"expected value={expected_value}, "
            f"got value={frame['value']}"
        )
        assert frame["processed_marker"] == "executed"
    reraise()
@pytest.mark.timeout(2)
def test_console_message_rpc_fires_before_exception_pause(
    self, main_loop, monkeypatch
):
    # Controller-facing contract: when DataProcessor raises during an
    # executor call, the stack-trace ConsoleMessage must reach the
    # controller *before* the worker enters EXCEPTION_PAUSE. Otherwise the
    # UI shows a paused worker with no error until the user resumes.
    # The DataProcessor side (message queued before the switch) is covered by
    # test_data_processor.TestExecutorSession. This test covers the MainLoop
    # side: the post-switch hook flushes RPCs first and pauses last.
    events = []

    def record_rpc(msg):
        events.append(("rpc", msg))

    def record_pause(pause_type, change_state=True):
        events.append(("pause", pause_type))

    monkeypatch.setattr(main_loop, "_send_console_message", record_rpc)
    monkeypatch.setattr(main_loop.context.pause_manager, "pause", record_pause)

    # Capture a real exc_info triple, as the DataProcessor would on failure.
    try:
        raise RuntimeError("boom-from-executor")
    except RuntimeError:
        exc_info = sys.exc_info()

    context = main_loop.context
    context.exception_manager.set_exception_info(exc_info)
    error_text = "RuntimeError: boom-from-executor"
    context.console_message_manager.put_message(
        ConsoleMessage(
            worker_id="dummy_worker_id",
            timestamp=current_time_in_local_timezone(),
            msg_type=ConsoleMessageType.ERROR,
            source="test:_capture_exc_info:0",
            title=error_text,
            message=error_text,
        )
    )

    main_loop._post_switch_context_checks()

    kinds = [kind for kind, _ in events]
    assert kinds == ["rpc", "pause"], (
        "console message must reach controller before pause; "
        f"observed order: {kinds}"
    )
    (_, sent_message), (_, paused_with) = events
    assert sent_message.msg_type == ConsoleMessageType.ERROR
    assert "boom-from-executor" in sent_message.title
    assert paused_with is PauseType.EXCEPTION_PAUSE
================================================
FILE: amber/src/test/python/core/runnables/test_network_receiver.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import threading
from pyarrow import Table
from core.models.internal_queue import (
InternalQueue,
DCMElement,
DataElement,
ECMElement,
)
from core.models.payload import DataFrame, StateFrame
from core.models.state import State
from core.proxy import ProxyClient
from core.runnables.network_receiver import NetworkReceiver
from core.runnables.network_sender import NetworkSender
from core.util.proto import set_one_of
from proto.org.apache.texera.amber.core import (
ActorVirtualIdentity,
ChannelIdentity,
EmbeddedControlMessageIdentity,
)
from proto.org.apache.texera.amber.engine.architecture.rpc import (
ControlInvocation,
EmbeddedControlMessage,
EmbeddedControlMessageType,
EmptyRequest,
AsyncRpcContext,
ControlRequest,
)
from proto.org.apache.texera.amber.engine.common import DirectControlMessagePayloadV2
class TestNetworkReceiver:
    """
    Pushes elements through a NetworkSender into a NetworkReceiver on
    localhost:5555 and checks what lands on the receiver's output queue.
    """

    @pytest.fixture
    def input_queue(self):
        return InternalQueue()

    @pytest.fixture
    def output_queue(self):
        return InternalQueue()

    @pytest.fixture
    def network_receiver(self, output_queue):
        network_receiver = NetworkReceiver(output_queue, host="localhost", port=5555)
        yield network_receiver
        network_receiver.stop()

    class MockFlightMetadataReader:
        """
        Stand-in for a FlightMetadataReader, so the sender receives a fixed
        credit value instead of one returned by the Scala server to the
        Python client.
        """

        class MockBuffer:
            def to_pybytes(self):
                # 8-byte little-endian encoding of the fake credit value.
                dummy_credit = 31
                return dummy_credit.to_bytes(8, "little")

        def read(self):
            return self.MockBuffer()

    @pytest.fixture
    def network_sender_thread(self, input_queue):
        network_sender = NetworkSender(input_queue, host="localhost", port=5555)

        # mocking do_put, read, to_pybytes to return fake credit values
        def mock_do_put(
            self,
            FlightDescriptor_descriptor,
            Schema_schema,
            FlightCallOptions_options=None,
        ):
            """
            Mocking FlightClient.do_put that is called in ProxyClient to return
            a MockFlightMetadataReader instead of a FlightMetadataReader
            :param self: an instance of FlightClient (would be ProxyClient in this case)
            :param FlightDescriptor_descriptor: descriptor
            :param Schema_schema: schema
            :param FlightCallOptions_options: options, None by default
            :return: writer : FlightStreamWriter, reader : MockFlightMetadataReader
            """
            writer, _ = super(ProxyClient, self).do_put(
                FlightDescriptor_descriptor, Schema_schema, FlightCallOptions_options
            )
            reader = TestNetworkReceiver.MockFlightMetadataReader()
            return writer, reader

        mock_proxy_client = network_sender._proxy_client
        mock_proxy_client.do_put = mock_do_put.__get__(
            mock_proxy_client, ProxyClient
        )  # override do_put with mock_do_put
        network_sender_thread = threading.Thread(target=network_sender.run)
        yield network_sender_thread
        network_sender.stop()
        # Wait (bounded) for run() to return so the sender thread does not
        # outlive this test. Guarded because join() on a thread that was
        # never started raises RuntimeError.
        if network_sender_thread.is_alive():
            network_sender_thread.join(timeout=1)

    @pytest.fixture
    def data_payload(self):
        return DataFrame(
            frame=Table.from_pydict(
                {
                    "Brand": ["Honda Civic", "Toyota Corolla", "Ford Focus", "Audi A4"],
                    "Price": [22000, 25000, 27000, 35000],
                }
            )
        )

    @pytest.mark.timeout(10)
    def test_network_receiver_can_receive_data_messages(
        self,
        data_payload,
        output_queue,
        input_queue,
        network_receiver,
        network_sender_thread,
    ):
        network_sender_thread.start()
        worker_id = ActorVirtualIdentity(name="test")
        channel_id = ChannelIdentity(worker_id, worker_id, False)
        input_queue.put(DataElement(tag=channel_id, payload=data_payload))
        element: DataElement = output_queue.get()
        assert len(element.payload.frame) == len(data_payload.frame)
        assert element.tag == channel_id

    @pytest.mark.timeout(10)
    def test_network_receiver_can_receive_consecutive_state_messages(
        self,
        output_queue,
        input_queue,
        network_receiver,
        network_sender_thread,
    ):
        network_sender_thread.start()
        worker_id = ActorVirtualIdentity(name="test")
        channel_id = ChannelIdentity(worker_id, worker_id, False)
        input_queue.put(
            DataElement(
                tag=channel_id,
                payload=StateFrame(State({"loop_counter": 0, "i": 1})),
            )
        )
        input_queue.put(
            DataElement(
                tag=channel_id,
                payload=StateFrame(State({"loop_counter": 1, "i": 2})),
            )
        )
        first_element: DataElement = output_queue.get()
        second_element: DataElement = output_queue.get()
        assert isinstance(first_element.payload, StateFrame)
        assert first_element.payload.frame == {"loop_counter": 0, "i": 1}
        assert first_element.tag == channel_id
        assert isinstance(second_element.payload, StateFrame)
        assert second_element.payload.frame == {"loop_counter": 1, "i": 2}
        assert second_element.tag == channel_id

    @pytest.mark.timeout(10)
    def test_network_receiver_can_receive_control_messages(
        self,
        data_payload,
        output_queue,
        input_queue,
        network_receiver,
        network_sender_thread,
    ):
        worker_id = ActorVirtualIdentity(name="test")
        control_payload = set_one_of(DirectControlMessagePayloadV2, ControlInvocation())
        channel_id = ChannelIdentity(worker_id, worker_id, False)
        # The control message is queued before the sender starts running.
        input_queue.put(DCMElement(tag=channel_id, payload=control_payload))
        network_sender_thread.start()
        element: DCMElement = output_queue.get()
        assert element.payload == control_payload
        assert element.tag == channel_id

    @pytest.mark.timeout(10)
    def test_network_receiver_can_receive_ecm(
        self,
        output_queue,
        input_queue,
        network_receiver,
        network_sender_thread,
    ):
        network_sender_thread.start()
        worker_id = ActorVirtualIdentity(name="test")
        channel_id = ChannelIdentity(worker_id, worker_id, False)
        ecm_id = EmbeddedControlMessageIdentity("test_ecm")
        scope = [channel_id]
        rpc_context = AsyncRpcContext(worker_id, worker_id)
        command_mapping = {
            str(worker_id): ControlInvocation(
                "NoOperation",
                ControlRequest(empty_request=EmptyRequest()),
                rpc_context,
                12,
            )
        }
        input_queue.put(
            ECMElement(
                tag=channel_id,
                payload=EmbeddedControlMessage(
                    ecm_id,
                    EmbeddedControlMessageType.ALL_ALIGNMENT,
                    scope,
                    command_mapping,
                ),
            )
        )
        # The element carries an EmbeddedControlMessage payload, so it is an
        # ECMElement (previously mis-annotated as DataElement).
        element: ECMElement = output_queue.get()
        assert isinstance(element.payload, EmbeddedControlMessage)
        assert element.payload.ecm_type == EmbeddedControlMessageType.ALL_ALIGNMENT
        assert element.payload.id == ecm_id
        assert element.payload.command_mapping == command_mapping
        assert element.payload.scope == scope
        assert element.tag == channel_id
================================================
FILE: amber/src/test/python/core/runnables/test_network_sender.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import threading
from time import sleep
from core.models.internal_queue import InternalQueue
from core.runnables.network_receiver import NetworkReceiver
from core.runnables.network_sender import NetworkSender
class TestNetworkSender:
    """Lifecycle tests for a NetworkSender paired with a local NetworkReceiver."""

    @pytest.fixture
    def network_receiver(self):
        network_receiver = NetworkReceiver(InternalQueue(), host="localhost", port=5555)
        yield network_receiver
        network_receiver.stop()

    @pytest.fixture
    def network_receiver_thread(self, network_receiver):
        network_receiver_thread = threading.Thread(target=network_receiver.run)
        yield network_receiver_thread

    @pytest.fixture
    def network_sender(self):
        network_sender = NetworkSender(InternalQueue(), host="localhost", port=5555)
        yield network_sender
        network_sender.stop()

    @pytest.fixture
    def network_sender_thread(self, network_sender):
        network_sender_thread = threading.Thread(target=network_sender.run)
        yield network_sender_thread

    @pytest.mark.timeout(2)
    def test_network_sender_can_stop(
        self,
        network_receiver,
        network_receiver_thread,
        network_sender,
        network_sender_thread,
    ):
        network_receiver_thread.start()
        network_sender_thread.start()
        assert network_receiver_thread.is_alive()
        assert network_sender_thread.is_alive()
        # Give both run loops a moment to get going before stopping them.
        sleep(0.1)
        network_receiver.stop()
        network_sender.stop()
        # Wait (bounded) for each run() to return instead of sleeping a fixed
        # 0.1s. A fixed sleep makes the liveness asserts below racy on a slow
        # machine. The bounds keep the worst case (0.1 + 2 * 0.7s) under the
        # 2s test timeout.
        network_receiver_thread.join(timeout=0.7)
        network_sender_thread.join(timeout=0.7)
        assert not network_receiver_thread.is_alive()
        assert not network_sender_thread.is_alive()
================================================
FILE: amber/src/test/python/core/storage/iceberg/test_iceberg_document.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pytest
import random
import tempfile
import uuid
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from core.models import Schema, Tuple
from core.storage.document_factory import DocumentFactory
from core.storage.storage_config import StorageConfig
from core.storage.vfs_uri_factory import VFSURIFactory
from proto.org.apache.texera.amber.core import (
WorkflowIdentity,
ExecutionIdentity,
OperatorIdentity,
PortIdentity,
GlobalPortIdentity,
PhysicalOpIdentity,
)
# Hardcoded storage config, for test purposes only.
# The iceberg warehouse directory must be a writable absolute path, so
# `tempfile.mkdtemp()` creates a fresh one when this module is imported.
# An earlier `"../../../../../../amber/user-resources/..."` value silently
# relied on CWD = amber/src/main/python and broke when the cwd moved up to
# amber/.
StorageConfig.initialize(
    # Postgres-backed iceberg catalog.
    catalog_type="postgres",
    postgres_uri_without_scheme="localhost:5432/texera_iceberg_catalog",
    postgres_username="texera",
    postgres_password="password",
    # Result tables: namespace, warehouse location and commit batching.
    table_result_namespace="operator-port-result",
    directory_path=tempfile.mkdtemp(prefix="texera-iceberg-warehouse-"),
    commit_batch_size=4096,
    # REST catalog endpoint.
    rest_catalog_uri="http://localhost:8181/catalog/",
    rest_catalog_warehouse_name="texera",
    # S3-compatible object store.
    s3_endpoint="http://localhost:9000",
    s3_region="us-east-1",
    s3_auth_username="minioadmin",
    s3_auth_password="minioadmin",
)
class TestIcebergDocument:
    """
    Round-trip tests for iceberg-backed result documents created through
    ``DocumentFactory``. Covers plain writes and reads, clearing, concurrent
    writers, ranged reads, ``get_after`` and ``get_count``.
    """

    @pytest.fixture
    def amber_schema(self):
        """Sample Amber schema with one column per attribute type used below."""
        return Schema(
            raw_schema={
                "col-string": "STRING",
                "col-int": "INTEGER",
                "col-bool": "BOOLEAN",
                "col-long": "LONG",
                "col-double": "DOUBLE",
                "col-timestamp": "TIMESTAMP",
                "col-binary": "BINARY",
            }
        )

    @pytest.fixture
    def iceberg_document(self, amber_schema):
        """
        Creates an iceberg document of operator port results using the sample
        schema. A random operator id keeps each test's table separate.
        """
        operator_uuid = str(uuid.uuid4()).replace("-", "")
        uri = VFSURIFactory.create_result_uri(
            WorkflowIdentity(id=0),
            ExecutionIdentity(id=0),
            GlobalPortIdentity(
                op_id=PhysicalOpIdentity(
                    logical_op_id=OperatorIdentity(id=f"test_table_{operator_uuid}"),
                    layer_name="main",
                ),
                port_id=PortIdentity(id=0),
                input=False,
            ),
        )
        DocumentFactory.create_document(uri, amber_schema)
        document, _ = DocumentFactory.open_document(uri)
        return document

    @pytest.fixture
    def sample_items(self, amber_schema) -> list[Tuple]:
        """
        Generates a list of sample tuples: three hand-written edge cases
        (empty string, extreme numeric values, control characters), followed
        by 20,000 generated tuples in which each column is periodically None.
        """
        base_tuples = [
            Tuple(
                {
                    "col-string": "Hello World",
                    "col-int": 42,
                    "col-bool": True,
                    "col-long": 1123213213213,
                    "col-double": 214214.9969346,
                    "col-timestamp": datetime.datetime.now(),
                    "col-binary": b"hello",
                },
                schema=amber_schema,
            ),
            Tuple(
                {
                    "col-string": "",
                    "col-int": -1,
                    "col-bool": False,
                    "col-long": -98765432109876,
                    "col-double": -0.001,
                    "col-timestamp": datetime.datetime.fromtimestamp(100000000),
                    "col-binary": bytearray([255, 0, 0, 64]),
                },
                schema=amber_schema,
            ),
            Tuple(
                {
                    "col-string": "Special Characters: \n\t\r",
                    "col-int": 2147483647,
                    "col-bool": True,
                    "col-long": 9223372036854775807,
                    "col-double": 1.7976931348623157e308,
                    "col-timestamp": datetime.datetime.fromtimestamp(1234567890),
                    "col-binary": bytearray([1, 2, 3, 4, 5]),
                },
                schema=amber_schema,
            ),
        ]

        # Function to generate random binary data
        def generate_random_binary(size):
            return bytearray(random.getrandbits(8) for _ in range(size))

        # Generated tuples: each column is None at a different period, so
        # nulls appear in varied combinations across columns.
        additional_tuples = [
            Tuple(
                {
                    "col-string": None if i % 7 == 0 else f"Generated String {i}",
                    "col-int": None if i % 5 == 0 else i,
                    "col-bool": None if i % 6 == 0 else i % 2 == 0,
                    "col-long": None if i % 4 == 0 else i * 1000000,
                    "col-double": None if i % 3 == 0 else i * 0.12345,
                    "col-timestamp": (
                        None
                        if i % 8 == 0
                        else datetime.datetime.fromtimestamp(
                            datetime.datetime.now().timestamp() + i
                        )
                    ),
                    "col-binary": None if i % 9 == 0 else generate_random_binary(10),
                },
                schema=amber_schema,
            )
            for i in range(1, 20001)
        ]
        return base_tuples + additional_tuples

    def test_basic_read_and_write(self, iceberg_document, sample_items):
        """
        Create an iceberg document, write sample items, and read them back.
        """
        writer = iceberg_document.writer(str(uuid.uuid4()))
        writer.open()
        for item in sample_items:
            writer.put_one(item)
        writer.close()
        retrieved_items = list(iceberg_document.get())
        assert sample_items == retrieved_items

    def test_clear_document(self, iceberg_document, sample_items):
        """
        Create an iceberg document, write sample items, and clear the document.
        """
        writer = iceberg_document.writer(str(uuid.uuid4()))
        writer.open()
        for item in sample_items:
            writer.put_one(item)
        writer.close()
        assert len(list(iceberg_document.get())) > 0
        iceberg_document.clear()
        assert len(list(iceberg_document.get())) == 0

    def test_handle_empty_read(self, iceberg_document):
        """
        The iceberg document should handle empty reads gracefully.
        """
        retrieved_items = list(iceberg_document.get())
        assert retrieved_items == []

    def test_concurrent_writes_followed_by_read(self, iceberg_document, sample_items):
        """
        Tests multiple concurrent writers writing to the same iceberg document.
        """
        all_items = sample_items
        num_writers = 10
        # Calculate the batch size and the remainder
        batch_size = len(all_items) // num_writers
        remainder = len(all_items) % num_writers
        # Create the writers' batches: the first `remainder` batches get one
        # extra item so that every item is assigned exactly once.
        item_batches = [
            all_items[
                i * batch_size + min(i, remainder) : i * batch_size
                + min(i, remainder)
                + batch_size
                + (1 if i < remainder else 0)
            ]
            for i in range(num_writers)
        ]
        assert len(item_batches) == num_writers, (
            f"Expected {num_writers} batches but got {len(item_batches)}"
        )

        # Perform concurrent writes
        def write_batch(batch):
            writer = iceberg_document.writer(str(uuid.uuid4()))
            writer.open()
            for item in batch:
                writer.put_one(item)
            writer.close()

        with ThreadPoolExecutor(max_workers=num_writers) as executor:
            futures = [executor.submit(write_batch, batch) for batch in item_batches]
            for future in as_completed(futures):
                future.result()  # Re-raise any writer failure here
        # Read all items back
        retrieved_items = list(iceberg_document.get())
        # Writers commit in no fixed order, so compare as sets.
        assert set(retrieved_items) == set(all_items), (
            "All items should be read correctly after concurrent writes."
        )

    def test_read_using_range(self, iceberg_document, sample_items):
        """
        The iceberg document should read all items using ranges correctly.
        """
        writer = iceberg_document.writer(str(uuid.uuid4()))
        writer.open()
        for item in sample_items:
            writer.put_one(item)
        writer.close()
        # Read all items using ranges
        batch_size = 1500
        # Generate ranges
        ranges = [
            range(i, min(i + batch_size, len(sample_items)))
            for i in range(0, len(sample_items), batch_size)
        ]
        # Retrieve items using ranges
        retrieved_items = [
            item for r in ranges for item in iceberg_document.get_range(r.start, r.stop)
        ]
        assert len(retrieved_items) == len(sample_items), (
            "The number of retrieved items does not match the number of all items."
        )
        # Verify that the retrieved items match the original items
        assert set(retrieved_items) == set(sample_items), (
            "All items should be retrieved correctly using ranges."
        )

    def test_get_after(self, iceberg_document, sample_items):
        """
        The iceberg document should retrieve items correctly using get_after.
        """
        writer = iceberg_document.writer(str(uuid.uuid4()))
        writer.open()
        for item in sample_items:
            writer.put_one(item)
        writer.close()
        # Test get_after for various offsets
        offsets = [0, len(sample_items) // 2, len(sample_items) - 1]
        for offset in offsets:
            if offset < len(sample_items):
                expected_items = sample_items[offset:]
            else:
                expected_items = []
            retrieved_items = list(iceberg_document.get_after(offset))
            assert retrieved_items == expected_items, (
                f"get_after({offset}) did not return the expected items. "
                f"Expected: {expected_items}, Got: {retrieved_items}"
            )
        # Test get_after for an offset beyond the range
        invalid_offset = len(sample_items)
        retrieved_items = list(iceberg_document.get_after(invalid_offset))
        assert not retrieved_items, (
            f"get_after({invalid_offset}) should return "
            f"an empty list, but got: {retrieved_items}"
        )

    def test_get_counts(self, iceberg_document, sample_items):
        """
        The iceberg document should correctly return the count of items.
        """
        writer = iceberg_document.writer(str(uuid.uuid4()))
        writer.open()
        for item in sample_items:
            writer.put_one(item)
        writer.close()
        assert iceberg_document.get_count() == len(sample_items), (
            "get_count should return the same number as the length of sample_items"
        )
================================================
FILE: amber/src/test/python/core/storage/iceberg/test_iceberg_rest_catalog_integration.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import uuid
import pytest
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType
from core.storage.iceberg.iceberg_utils import create_rest_catalog
# Module-level `pytestmark`: pytest applies the `integration` marker to every
# test in this file, so these tests can be selected or deselected with `-m`.
pytestmark = pytest.mark.integration
@pytest.fixture
def rest_catalog():
    """Build a REST catalog client aimed at the local test services.

    The REST catalog endpoint is ``localhost:8181`` and the S3 endpoint is
    ``localhost:9000``; the credentials are the fixed test values below.
    """
    connection_settings = dict(
        catalog_name="rest_integration_test",
        warehouse_name="texera",
        rest_uri="http://localhost:8181/catalog/",
        s3_endpoint="http://localhost:9000",
        s3_region="us-west-2",
        s3_username="texera_minio",
        s3_password="password",
    )
    return create_rest_catalog(**connection_settings)
def test_rest_catalog_round_trip(rest_catalog):
    """Round-trip table metadata via the REST catalog (Lakekeeper).

    Creates a uniquely named table, loads it back, and drops it, checking the
    catalog's view of the table after each step. A ``finally`` block drops the
    table if it still exists, so a failed assertion part-way through does not
    leave a stray table behind in the shared catalog.
    """
    namespace = "rest_integration_test_ns"
    table_name = f"rest_test_{uuid.uuid4().hex}"
    identifier = f"{namespace}.{table_name}"
    schema = Schema(
        NestedField(field_id=1, name="id", field_type=IntegerType(), required=False),
        NestedField(field_id=2, name="name", field_type=StringType(), required=False),
    )
    rest_catalog.create_namespace_if_not_exists(namespace)
    # Defensive: the name is random, but remove any leftover with the same id.
    if rest_catalog.table_exists(identifier):
        rest_catalog.drop_table(identifier)
    try:
        # create — exercises REST createTable.
        rest_catalog.create_table(identifier=identifier, schema=schema)
        assert rest_catalog.table_exists(identifier)
        # load — exercises REST loadTable (metadata fetch).
        loaded = rest_catalog.load_table(identifier)
        assert len(loaded.schema().fields) == 2
        # drop — exercises REST dropTable.
        rest_catalog.drop_table(identifier)
        assert not rest_catalog.table_exists(identifier)
        with pytest.raises(NoSuchTableError):
            rest_catalog.load_table(identifier)
    finally:
        # Only does work when a step above failed before the table was dropped.
        if rest_catalog.table_exists(identifier):
            rest_catalog.drop_table(identifier)
================================================
FILE: amber/src/test/python/core/storage/iceberg/test_iceberg_utils_catalog.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch
from core.storage.iceberg import iceberg_utils
from core.storage.iceberg.iceberg_utils import create_postgres_catalog
class TestCreatePostgresCatalog:
    """
    Regression tests for `create_postgres_catalog`.

    The Scala side (`IcebergUtil.createPostgresCatalog`) sets up the JDBC
    catalog with a plain filesystem warehouse path that has no URI scheme.
    PyIceberg persists the `warehouse` property into table metadata. If the
    Python side registered the catalog with a `file://`-prefixed value, Iceberg
    tables written from Python UDFs would be unreadable from the Scala/Java
    engine, and vice versa. These tests hold the Python side to the same
    plain-path convention as the Scala side.

    Each test replaces `iceberg_utils.SqlCatalog` with a mock and inspects
    the arguments it was constructed with.
    """

    @staticmethod
    def _invoke_with_mocked_sql_catalog(warehouse_path, uri_without_scheme, password):
        """Run `create_postgres_catalog` with `SqlCatalog` patched; return the mock."""
        with patch.object(iceberg_utils, "SqlCatalog") as mock_sql_catalog:
            create_postgres_catalog(
                catalog_name="texera_iceberg",
                warehouse_path=warehouse_path,
                uri_without_scheme=uri_without_scheme,
                username="texera",
                password=password,
            )
        return mock_sql_catalog

    def test_warehouse_is_passed_without_file_scheme(self):
        """`warehouse` must be forwarded as-is, without a `file://` prefix."""
        plain_path = "/tmp/texera/iceberg-warehouse"
        sql_catalog = self._invoke_with_mocked_sql_catalog(
            plain_path, "localhost:5432/texera_iceberg_catalog", "password"
        )
        assert sql_catalog.call_count == 1
        _, forwarded = sql_catalog.call_args
        assert forwarded["warehouse"] == plain_path
        assert not forwarded["warehouse"].startswith("file://")

    def test_windows_style_warehouse_is_passed_verbatim(self):
        """
        Before registering the catalog, the Scala side drops the Windows drive
        colon (e.g. `C:/x` becomes `C/x`) so that PyArrow can parse the path.
        The Python side must forward what it receives unchanged, so both
        runtimes record the same warehouse string in Iceberg metadata.
        """
        drive_stripped_path = "C/Users/texera/iceberg-warehouse"
        sql_catalog = self._invoke_with_mocked_sql_catalog(
            drive_stripped_path, "localhost:5432/texera_iceberg_catalog", "password"
        )
        _, forwarded = sql_catalog.call_args
        assert forwarded["warehouse"] == drive_stripped_path
        assert "file://" not in forwarded["warehouse"]

    def test_postgres_uri_is_built_with_pg8000_scheme(self):
        """The JDBC URI must carry the `postgresql+pg8000://` scheme plus the
        credentials, and none of that may leak into `warehouse`.
        """
        plain_path = "/var/lib/texera/warehouse"
        sql_catalog = self._invoke_with_mocked_sql_catalog(
            plain_path, "db.internal:5432/texera_iceberg_catalog", "s3cret"
        )
        positional, forwarded = sql_catalog.call_args
        assert positional == ("texera_iceberg",)
        assert forwarded["uri"] == (
            "postgresql+pg8000://texera:s3cret@db.internal:5432/texera_iceberg_catalog"
        )
        # The warehouse stays the bare filesystem path.
        assert forwarded["warehouse"] == plain_path
================================================
FILE: amber/src/test/python/core/storage/iceberg/test_iceberg_utils_large_binary.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pyarrow as pa
from pyiceberg import types as iceberg_types
from pyiceberg.schema import Schema as IcebergSchema
from core.models import Schema, Tuple
from core.models.schema.attribute_type import AttributeType
from core.models.type.large_binary import largebinary
from core.storage.iceberg.iceberg_utils import (
encode_large_binary_field_name,
decode_large_binary_field_name,
iceberg_schema_to_amber_schema,
amber_schema_to_iceberg_schema,
amber_tuples_to_arrow_table,
arrow_table_to_amber_tuples,
)
class TestIcebergUtilsLargeBinary:
    """Tests for how `iceberg_utils` handles LARGE_BINARY attributes.

    As the assertions below show, a LARGE_BINARY attribute is stored in Iceberg
    as a string column whose name carries the `__texera_large_binary_ptr`
    suffix. Its values are the `largebinary` URIs.
    """

    def test_encode_large_binary_field_name(self):
        """Test encoding LARGE_BINARY field names with suffix."""
        assert (
            encode_large_binary_field_name("my_field", AttributeType.LARGE_BINARY)
            == "my_field__texera_large_binary_ptr"
        )
        # Non-LARGE_BINARY attributes keep their name unchanged.
        assert (
            encode_large_binary_field_name("my_field", AttributeType.STRING)
            == "my_field"
        )

    def test_decode_large_binary_field_name(self):
        """Test decoding LARGE_BINARY field names by removing suffix."""
        assert (
            decode_large_binary_field_name("my_field__texera_large_binary_ptr")
            == "my_field"
        )
        # Names without the suffix pass through unchanged.
        assert decode_large_binary_field_name("my_field") == "my_field"
        assert decode_large_binary_field_name("regular_field") == "regular_field"

    def test_amber_schema_to_iceberg_schema_with_large_binary(self):
        """Test converting Amber schema with LARGE_BINARY to Iceberg schema."""
        amber_schema = Schema()
        amber_schema.add("regular_field", AttributeType.STRING)
        amber_schema.add("large_binary_field", AttributeType.LARGE_BINARY)
        amber_schema.add("int_field", AttributeType.INT)
        iceberg_schema = amber_schema_to_iceberg_schema(amber_schema)
        # Check field names are encoded
        field_names = [field.name for field in iceberg_schema.fields]
        assert "regular_field" in field_names
        assert "large_binary_field__texera_large_binary_ptr" in field_names
        assert "int_field" in field_names
        # Find the pointer column by its exact encoded name. A substring match
        # on "large_binary" would pick whichever field happens to contain that
        # text first.
        large_binary_field = next(
            f
            for f in iceberg_schema.fields
            if f.name == "large_binary_field__texera_large_binary_ptr"
        )
        # The pointer column is stored as an Iceberg string.
        assert isinstance(large_binary_field.field_type, iceberg_types.StringType)

    def test_iceberg_schema_to_amber_schema_with_large_binary(self):
        """Test converting Iceberg schema with LARGE_BINARY to Amber schema."""
        iceberg_schema = IcebergSchema(
            iceberg_types.NestedField(
                1, "regular_field", iceberg_types.StringType(), required=False
            ),
            iceberg_types.NestedField(
                2,
                "large_binary_field__texera_large_binary_ptr",
                iceberg_types.StringType(),
                required=False,
            ),
            iceberg_types.NestedField(
                3, "int_field", iceberg_types.IntegerType(), required=False
            ),
        )
        amber_schema = iceberg_schema_to_amber_schema(iceberg_schema)
        # The suffix is stripped and the attribute type restored to LARGE_BINARY.
        assert amber_schema.get_attr_type("regular_field") == AttributeType.STRING
        assert (
            amber_schema.get_attr_type("large_binary_field")
            == AttributeType.LARGE_BINARY
        )
        assert amber_schema.get_attr_type("int_field") == AttributeType.INT
        # Check Arrow schema has metadata for LARGE_BINARY
        arrow_schema = amber_schema.as_arrow_schema()
        large_binary_field = arrow_schema.field("large_binary_field")
        assert large_binary_field.metadata is not None
        assert large_binary_field.metadata.get(b"texera_type") == b"LARGE_BINARY"

    def test_amber_tuples_to_arrow_table_with_large_binary(self):
        """Test converting Amber tuples with largebinary to Arrow table."""
        amber_schema = Schema()
        amber_schema.add("regular_field", AttributeType.STRING)
        amber_schema.add("large_binary_field", AttributeType.LARGE_BINARY)
        large_binary1 = largebinary("s3://bucket/path1")
        large_binary2 = largebinary("s3://bucket/path2")
        tuples = [
            Tuple(
                {"regular_field": "value1", "large_binary_field": large_binary1},
                schema=amber_schema,
            ),
            Tuple(
                {"regular_field": "value2", "large_binary_field": large_binary2},
                schema=amber_schema,
            ),
        ]
        iceberg_schema = amber_schema_to_iceberg_schema(amber_schema)
        arrow_table = amber_tuples_to_arrow_table(iceberg_schema, tuples)
        # Check that largebinary values are converted to URI strings
        regular_values = arrow_table.column("regular_field").to_pylist()
        large_binary_values = arrow_table.column(
            "large_binary_field__texera_large_binary_ptr"
        ).to_pylist()
        assert regular_values == ["value1", "value2"]
        assert large_binary_values == ["s3://bucket/path1", "s3://bucket/path2"]

    def test_arrow_table_to_amber_tuples_with_large_binary(self):
        """Test converting Arrow table with LARGE_BINARY to Amber tuples."""
        # Create Iceberg schema with encoded field name
        iceberg_schema = IcebergSchema(
            iceberg_types.NestedField(
                1, "regular_field", iceberg_types.StringType(), required=False
            ),
            iceberg_types.NestedField(
                2,
                "large_binary_field__texera_large_binary_ptr",
                iceberg_types.StringType(),
                required=False,
            ),
        )
        # Create Arrow table with URI strings
        arrow_table = pa.Table.from_pydict(
            {
                "regular_field": ["value1", "value2"],
                "large_binary_field__texera_large_binary_ptr": [
                    "s3://bucket/path1",
                    "s3://bucket/path2",
                ],
            }
        )
        tuples = list(arrow_table_to_amber_tuples(iceberg_schema, arrow_table))
        assert len(tuples) == 2
        assert tuples[0]["regular_field"] == "value1"
        assert isinstance(tuples[0]["large_binary_field"], largebinary)
        assert tuples[0]["large_binary_field"].uri == "s3://bucket/path1"
        assert tuples[1]["regular_field"] == "value2"
        assert isinstance(tuples[1]["large_binary_field"], largebinary)
        assert tuples[1]["large_binary_field"].uri == "s3://bucket/path2"

    def test_round_trip_large_binary_tuples(self):
        """Test round-trip conversion of tuples with largebinary."""
        amber_schema = Schema()
        amber_schema.add("regular_field", AttributeType.STRING)
        amber_schema.add("large_binary_field", AttributeType.LARGE_BINARY)
        large_binary = largebinary("s3://bucket/path/to/object")
        original_tuples = [
            Tuple(
                {"regular_field": "value1", "large_binary_field": large_binary},
                schema=amber_schema,
            ),
        ]
        # Convert to Iceberg and Arrow
        iceberg_schema = amber_schema_to_iceberg_schema(amber_schema)
        arrow_table = amber_tuples_to_arrow_table(iceberg_schema, original_tuples)
        # Convert back to Amber tuples
        retrieved_tuples = list(
            arrow_table_to_amber_tuples(iceberg_schema, arrow_table)
        )
        assert len(retrieved_tuples) == 1
        assert retrieved_tuples[0]["regular_field"] == "value1"
        assert isinstance(retrieved_tuples[0]["large_binary_field"], largebinary)
        assert retrieved_tuples[0]["large_binary_field"].uri == large_binary.uri

    def test_arrow_table_to_amber_tuples_with_null_large_binary(self):
        """Test converting Arrow table with null largebinary values."""
        iceberg_schema = IcebergSchema(
            iceberg_types.NestedField(
                1, "regular_field", iceberg_types.StringType(), required=False
            ),
            iceberg_types.NestedField(
                2,
                "large_binary_field__texera_large_binary_ptr",
                iceberg_types.StringType(),
                required=False,
            ),
        )
        arrow_table = pa.Table.from_pydict(
            {
                "regular_field": ["value1"],
                "large_binary_field__texera_large_binary_ptr": [None],
            }
        )
        tuples = list(arrow_table_to_amber_tuples(iceberg_schema, arrow_table))
        assert len(tuples) == 1
        assert tuples[0]["regular_field"] == "value1"
        # A null pointer stays None rather than becoming a largebinary.
        assert tuples[0]["large_binary_field"] is None
================================================
FILE: amber/src/test/python/core/test_python_worker.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import core.python_worker as pw
class _FakeReceiver:
def __init__(self, input_queue, host):
self.input_queue = input_queue
self.host = host
self.proxy_server = type(
"FakeProxyServer", (), {"get_port_number": staticmethod(lambda: 12345)}
)()
self._shutdown_cb = None
def register_shutdown(self, cb):
self._shutdown_cb = cb
def run(self):
pass
def stop(self):
pass
class _FakeSender:
def __init__(self, output_queue, host, port, handshake_port):
self.output_queue = output_queue
self.host = host
self.port = port
self.handshake_port = handshake_port
self.stopped = False
def run(self):
pass
def stop(self):
self.stopped = True
class _FakeMainLoop:
def __init__(self, worker_id, input_queue, output_queue):
self.worker_id = worker_id
self.stopped = False
def run(self):
pass
def stop(self):
self.stopped = True
class _FakeHeartbeat:
def __init__(self, host, port, interval, stop_event):
self.host = host
self.port = port
self.interval = interval
self.stop_event = stop_event
self.stopped = False
def run(self):
pass
def stop(self):
self.stopped = True
@pytest.fixture
def stub_network(monkeypatch):
    """Replace the worker's receiver, sender, main loop and heartbeat with fakes."""
    replacements = {
        "NetworkReceiver": _FakeReceiver,
        "NetworkSender": _FakeSender,
        "MainLoop": _FakeMainLoop,
        "Heartbeat": _FakeHeartbeat,
    }
    for attr_name, fake_cls in replacements.items():
        monkeypatch.setattr(pw, attr_name, fake_cls)
class TestPythonWorker:
    @staticmethod
    def _new_worker():
        """Construct the worker under test with a fixed id, host and output port."""
        return pw.PythonWorker(worker_id="w-1", host="localhost", output_port=9999)

    @pytest.mark.timeout(5)
    def test_construction_wires_dependencies(self, stub_network):
        worker = self._new_worker()
        sender = worker._network_sender
        # Java->Python wiring contract: the sender receives the handshake port
        # that the receiver's proxy server reports.
        assert sender.handshake_port == 12345
        assert sender.port == 9999
        # A client-side disconnect must tear the worker down, so the receiver's
        # shutdown callback has to be the worker's `stop`.
        assert worker._network_receiver._shutdown_cb == worker.stop

    @pytest.mark.timeout(5)
    def test_stop_cascades_to_main_loop_sender_and_heartbeat(self, stub_network):
        worker = self._new_worker()
        worker.stop()
        for component in (worker._main_loop, worker._network_sender, worker._heartbeat):
            assert component.stopped is True

    @pytest.mark.timeout(5)
    def test_run_sets_stop_event_after_main_loop_returns(self, stub_network):
        worker = self._new_worker()
        # Every fake's run() returns at once, so run() joins its threads
        # without blocking. The heartbeat stop event must be set after the
        # main loop and sender threads finish, so the heartbeat thread can
        # exit cleanly.
        worker.run()
        assert worker._stop_event.is_set()
================================================
FILE: amber/src/test/python/core/util/console_message/test_replace_print.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import builtins
import io
from typing import List
import pytest
from core.util.console_message.replace_print import replace_print
from proto.org.apache.texera.amber.engine.architecture.rpc import (
ConsoleMessage,
ConsoleMessageType,
)
class CapturingBuffer:
    """Test double for IBuffer: every ``put`` argument is appended to ``messages``."""

    def __init__(self):
        self.messages: List[ConsoleMessage] = list()

    def put(self, msg):
        self.messages += [msg]
class TestReplacePrintLifecycle:
    def test_print_is_replaced_inside_the_context_and_restored_on_exit(self):
        saved_print = builtins.print
        sink = CapturingBuffer()
        with replace_print("w", sink):
            assert builtins.print is not saved_print
        assert builtins.print is saved_print

    def test_print_is_restored_even_when_the_block_raises(self):
        saved_print = builtins.print
        sink = CapturingBuffer()
        with pytest.raises(RuntimeError):
            with replace_print("w", sink):
                raise RuntimeError("boom")
        assert builtins.print is saved_print

    def test_exit_returns_true_for_clean_block_and_false_for_raising_block(self):
        # __exit__ returns True after a clean block and False after a raising
        # one. Returning False tells the `with` statement not to suppress the
        # exception, so it reaches the caller.
        clean_ctx = replace_print("w", CapturingBuffer())
        clean_ctx.__enter__()
        assert clean_ctx.__exit__(None, None, None) is True

        failing_ctx = replace_print("w", CapturingBuffer())
        failing_ctx.__enter__()
        try:
            assert failing_ctx.__exit__(RuntimeError, RuntimeError("x"), None) is False
        finally:
            # Put `builtins.print` back from the context's `builtins_print`
            # attribute, which is presumably the print saved on enter. This
            # way a failed assertion above cannot leave the replacement
            # installed for later tests.
            builtins.print = failing_ctx.builtins_print
class TestReplacePrintBufferPayload:
    def test_print_inside_context_enqueues_a_console_message(self):
        buf = CapturingBuffer()
        with replace_print("worker-A", buf):
            print("hello")
        assert len(buf.messages) == 1
        msg = buf.messages[0]
        assert msg.worker_id == "worker-A"
        assert msg.msg_type == ConsoleMessageType.PRINT
        # Default print appends a newline; the title carries the full line.
        assert msg.title == "hello\n"
        assert msg.message == ""

    def test_joins_args_via_the_real_print_so_sep_and_end_kwargs_apply(self):
        buf = CapturingBuffer()
        with replace_print("w", buf):
            print("a", "b", "c", sep="-", end="!")
        assert buf.messages[0].title == "a-b-c!"

    def test_each_print_call_produces_one_buffer_entry(self):
        # Pin: the wrapped print writes to the buffer once per print call,
        # not once per argument (a contextlib.redirect_stdout-style approach
        # would produce one write per argument).
        buf = CapturingBuffer()
        with replace_print("w", buf):
            print("first")
            print("second", "third")
        assert [m.title for m in buf.messages] == ["first\n", "second third\n"]

    def test_print_with_file_kwarg_bypasses_the_buffer(self):
        # When the caller passes `file=...`, the wrapper delegates straight to
        # the original builtins.print and enqueues no ConsoleMessage. This is
        # what keeps explicit logging redirects working inside the context.
        buf = CapturingBuffer()
        sink = io.StringIO()
        with replace_print("w", buf):
            print("ignored-by-buffer", file=sink)
        assert buf.messages == []
        assert sink.getvalue() == "ignored-by-buffer\n"

    def test_source_field_records_caller_module_function_and_line(self):
        # The wrapper walks one frame up to find where print() was called, so
        # the source string has the form `module:function:line`. Only the
        # structural parts are checked here: the exact module name and line
        # number depend on where this test file lives.
        buf = CapturingBuffer()

        def caller_under_test():
            print("from-caller")

        with replace_print("w", buf):
            caller_under_test()
        source = buf.messages[0].source
        parts = source.split(":")
        assert len(parts) == 3
        # The reported function name is the function that called print().
        assert parts[1] == "caller_under_test"
        # And the line number is a positive integer.
        assert parts[2].isdigit() and int(parts[2]) > 0
================================================
FILE: amber/src/test/python/core/util/customized_queue/test_inner.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.util.customized_queue.inner import (
class_inner,
inner,
raw_inner,
static_inner,
)
class TestRawInner:
    def test_returns_the_class_unchanged(self):
        # `raw_inner` is an identity decorator. It is kept for forward
        # compatibility, in case default Python inner-class semantics change.
        class C:
            pass

        decorated = raw_inner(C)
        assert decorated is C
class TestStaticInner:
    def test_assigns_owner_to_outer_class_at_definition_time(self):
        class Outer:
            @static_inner
            class Inner:
                pass

        # When Outer is created, `__set_name__` swaps the descriptor for the
        # real inner class and records Outer as its `owner`.
        nested = Outer.Inner
        assert nested.owner is Outer

    def test_inner_class_is_accessible_directly_on_outer(self):
        class Outer:
            @static_inner
            class Inner:
                @staticmethod
                def hello():
                    return "hi"

        greeting = Outer.Inner.hello()
        assert greeting == "hi"
class TestClassInnerDescriptorGuard:
    @pytest.mark.parametrize("descriptor_method", ["__get__", "__set__", "__del__"])
    def test_rejects_classes_that_implement_descriptor_methods(self, descriptor_method):
        # class_inner installs its own `__get__`, so it refuses classes that
        # already define any of these hooks.
        # NOTE(review): `__del__` is the finalizer, not the descriptor hook
        # `__delete__`. This list presumably mirrors the check inside
        # `class_inner`; confirm whether `__delete__` was intended there.
        def noop(*args, **kwargs):
            return None

        offending_cls = type("Bad", (object,), {descriptor_method: noop})
        with pytest.raises(ValueError, match="descriptors"):
            class_inner(offending_cls)
class TestClassInnerCarriedInheritance:
    def test_subclass_of_outer_gets_a_derived_inner_class(self):
        # "Carried inheritance": reading the inner-class attribute through a
        # subclass of the outer class produces a new inner class that has the
        # parent's inner class in its MRO. The derived class's owner is the
        # subclass, while the original inner class still belongs to Outer.
        class Outer:
            @class_inner
            class Inner:
                pass

        class Sub(Outer):
            pass

        derived = Sub.Inner
        original = Outer.Inner
        assert derived is not original
        assert original in derived.__mro__
        assert derived.owner is Sub
        # Reading through the original outer class still gives the original.
        assert Outer.Inner.owner is Outer
class TestInnerInstanceBinding:
    def test_outer_instance_inner_call_binds_owner_to_that_outer_instance(self):
        # The most common @inner usage in Texera: an inner class reached
        # through an outer *instance*, where the inner object needs
        # `self.owner` to get back to that outer object (e.g.
        # LinkedBlockingMultiQueue's Node / SubQueue / PriorityGroup).
        class Outer:
            def __init__(self, label):
                self.label = label

            @inner
            class Inner:
                def __init__(self, x):
                    self.x = x

        first_outer = Outer("a")
        second_outer = Outer("b")
        first_inner = first_outer.Inner(7)
        second_inner = second_outer.Inner(11)
        assert (first_inner.x, second_inner.x) == (7, 11)
        assert first_inner.owner is first_outer
        assert second_inner.owner is second_outer
        # Binding to one outer object does not leak into the other.
        assert first_inner is not second_inner
class TestInnerProperty:
    def test_property_auto_instantiates_inner_on_access(self):
        class Outer:
            @inner.property
            class Counter:
                def __init__(self):
                    self.value = 0

        host = Outer()
        counter = host.Counter
        # With `inner.property`, reading the attribute already yields a
        # constructed instance (no call needed) bound to the outer object.
        assert counter.value == 0
        assert counter.owner is host

    def test_property_returns_a_new_instance_each_access(self):
        # Plain `inner.property` does not memoize, unlike `cached_property`,
        # so every read builds a fresh instance. Pinned to keep the two
        # variants distinct.
        class Outer:
            @inner.property
            class Counter:
                def __init__(self):
                    self.value = 0

        host = Outer()
        earlier = host.Counter
        later = host.Counter
        assert earlier is not later
class TestInnerCachedProperty:
    def test_cached_property_returns_the_same_instance_on_repeat_access(self):
        # `inner.cached_property` memoizes the inner instance on the outer
        # object, so every later read returns that exact object.
        class Outer:
            @inner.cached_property
            class Counter:
                def __init__(self):
                    self.value = 0

        host = Outer()
        earlier = host.Counter
        earlier.value = 42
        later = host.Counter
        assert earlier is later
        assert later.value == 42

    def test_cached_property_caches_independently_per_outer_instance(self):
        class Outer:
            @inner.cached_property
            class Counter:
                def __init__(self):
                    self.value = 0

        hosts = (Outer(), Outer())
        for host, new_value in zip(hosts, (1, 2)):
            host.Counter.value = new_value
        assert [host.Counter.value for host in hosts] == [1, 2]
================================================
FILE: amber/src/test/python/core/util/customized_queue/test_linked_blocking_multi_queue.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import random
import time
from threading import Thread
from core.util.customized_queue.linked_blocking_multi_queue import (
LinkedBlockingMultiQueue,
)
class TestLinkedBlockingMultiQueue:
@pytest.fixture
def queue(self):
lbmq = LinkedBlockingMultiQueue()
lbmq.add_sub_queue("control", 0)
lbmq.add_sub_queue("data", 1)
return lbmq
def test_sub_can_emit(self, queue):
assert queue.is_empty()
queue.put("data", 1)
assert not queue.is_empty()
assert queue.is_empty("control")
assert queue.get() == 1
assert queue.is_empty()
assert queue.is_empty("control")
def test_main_can_emit(self, queue):
assert queue.is_empty()
queue.put("control", "s")
assert not queue.is_empty()
assert queue.get() == "s"
assert queue.is_empty()
def test_main_can_emit_before_sub(self, queue):
assert queue.is_empty()
queue.put("data", 1)
queue.put("control", "s")
assert not queue.is_empty()
assert queue.get() == "s"
assert queue.is_empty("control")
assert not queue.is_empty()
assert queue.get() == 1
assert queue.is_empty()
def test_can_maintain_order_respectively(self, queue):
queue.put("data", 1)
queue.put("control", "s1")
queue.put("data", 99)
queue.put("control", "s2")
queue.put("control", "s3")
queue.put("data", 3)
queue.put("control", "s4")
res = list()
while not queue.is_empty():
res.append(queue.get())
assert res == ["s1", "s2", "s3", "s4", 1, 99, 3]
def test_can_disable_sub(self, queue):
queue.disable("data")
queue.put("data", 1)
queue.put("control", "s1")
queue.put("data", 99)
queue.put("control", "s2")
queue.put("control", "s3")
queue.put("data", 3)
queue.put("control", "s4")
res = list()
while not queue.is_empty():
res.append(queue.get())
assert res == ["s1", "s2", "s3", "s4"]
assert queue.is_empty()
queue.enable("data")
assert not queue.is_empty()
res = list()
while not queue.is_empty():
res.append(queue.get())
assert res == [1, 99, 3]
assert queue.is_empty()
@pytest.mark.timeout(2)
def test_producer_first_insert_sub(self, queue, reraise):
def producer():
with reraise:
time.sleep(0.2)
queue.put("data", 1)
producer_thread = Thread(target=producer)
producer_thread.start()
producer_thread.join()
assert queue.get() == 1
reraise()
@pytest.mark.timeout(2)
def test_consumer_first_insert_sub(self, queue, reraise):
def consumer():
with reraise:
assert queue.get() == 1
assert queue.is_empty()
consumer_thread = Thread(target=consumer)
consumer_thread.start()
time.sleep(0.2)
queue.put("data", 1)
consumer_thread.join()
reraise()
@pytest.mark.timeout(2)
def test_producer_first_insert_main(self, queue, reraise):
def producer():
with reraise:
time.sleep(0.2)
queue.put("control", "s")
producer_thread = Thread(target=producer)
producer_thread.start()
producer_thread.join()
assert queue.get() == "s"
reraise()
@pytest.mark.timeout(2)
def test_consumer_first_insert_main(self, queue, reraise):
    # The reader blocks in get() first; the control item arrives afterwards.
    def await_single_control_item():
        with reraise:
            assert queue.get() == "s"
            assert queue.is_empty()

    reader = Thread(target=await_single_control_item)
    reader.start()
    time.sleep(0.2)
    queue.put("control", "s")
    reader.join()
    # Surface any assertion/exception captured inside the reader thread.
    reraise()
@pytest.mark.timeout(10)
def test_multiple_producer_race(self, queue, reraise):
    # 1000 concurrent producers: roughly half push a single control item
    # (a one-character string), the rest push a burst of data items. The
    # consumer disables "data", so it must observe exactly the control items.
    def produce(payload):
        with reraise:
            if not isinstance(payload, int):
                queue.put("control", payload)
            else:
                for n in range(payload):
                    queue.put("data", n)

    expected_controls = set()
    producers = []
    for idx in range(1000):
        payload = idx
        if random.random() > 0.5:
            payload = chr(idx)
            # Only control payloads are expected, since data is disabled below.
            expected_controls.add(payload)
        worker = Thread(target=produce, args=(payload,))
        worker.start()
        producers.append(worker)

    received = set()

    def consume():
        with reraise:
            queue.disable("data")
            while len(received) < len(expected_controls):
                received.add(queue.get())

    consumer_thread = Thread(target=consume)
    consumer_thread.start()
    for worker in producers:
        worker.join()
    consumer_thread.join()
    assert received == expected_controls
    reraise()
def test_multi_types(self, queue):
    # Data items of different types (int, float) sit behind a disabled
    # "data" sub-queue; only the control string is retrievable.
    for value in (1, 1.1):
        queue.put("data", value)
    queue.put("control", "s")
    queue.disable("data")
    assert queue.get() == "s"
    assert queue.is_empty()
@pytest.mark.timeout(2)
def test_common_single_producer_single_consumer(self, queue, reraise):
    # Producer emits "s" on control for every multiple of 3 in 0..10 and the
    # number itself on data otherwise. The consumer loop repeatedly toggles
    # the data sub-queue (including redundant enable/disable/is_empty calls)
    # to exercise those paths, and stops after reading the final data item 10.
    def produce():
        with reraise:
            for n in range(11):
                if n % 3 != 0:
                    queue.put("data", n)
                else:
                    queue.put("control", "s")

    writer = Thread(target=produce)
    writer.start()
    writer.join()

    total: int = 0
    while True:
        for _ in range(2):
            queue.enable("data")
        item = queue.get()
        queue.is_empty("control")
        if not isinstance(item, int):
            assert item == "s"
        else:
            total += item
        queue.is_empty()
        for _ in range(2):
            queue.disable("data")
        queue.is_empty()
        queue.disable("data")
        if item == 10:
            break

    assert total == sum(n for n in range(11) if n % 3 != 0)
    reraise()
================================================
FILE: amber/src/test/python/core/util/test_atomic.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import threading
import pytest
from core.util.atomic import AtomicInteger
class TestAtomicIntegerSingleThreaded:
    def test_default_starts_at_zero(self):
        assert AtomicInteger().value == 0

    def test_initial_value_is_coerced_to_int(self):
        # The constructor wraps the input through int(), which lets callers
        # pass a numeric string or float and still get a clean integer state.
        assert AtomicInteger("7").value == 7
        assert AtomicInteger(3.9).value == 3  # int() truncates toward zero

    def test_inc_returns_new_value_after_adding_default_one(self):
        a = AtomicInteger(10)
        assert a.inc() == 11
        assert a.value == 11

    def test_inc_with_custom_delta_uses_int_coercion(self):
        a = AtomicInteger(10)
        assert a.inc(5) == 15
        # int("3") -> 3, the underlying state increments by 3.
        assert a.inc("3") == 18

    def test_dec_is_inc_with_negated_delta(self):
        a = AtomicInteger(10)
        assert a.dec() == 9
        assert a.dec(4) == 5

    def test_get_and_inc_returns_pre_increment_value(self):
        a = AtomicInteger(10)
        assert a.get_and_inc() == 10
        assert a.value == 11

    def test_get_and_dec_returns_pre_decrement_value(self):
        a = AtomicInteger(10)
        assert a.get_and_dec(2) == 10
        assert a.value == 8

    def test_value_setter_replaces_state_with_int_coercion(self):
        a = AtomicInteger(10)
        a.value = 42
        assert a.value == 42
        a.value = "100"
        assert a.value == 100

    def test_get_and_set_currently_deadlocks_on_non_reentrant_lock(self):
        # Bug pin: get_and_set acquires self._lock and then reads self.value,
        # which is a property that ALSO tries to acquire self._lock. The lock
        # is a non-reentrant threading.Lock, so the call deadlocks the moment
        # it is invoked. Document via thread + timeout so the test surfaces
        # the deadlock without hanging the whole suite, and pair it with an
        # xfail-strict test below that asserts the intended contract.
        a = AtomicInteger(10)
        started = threading.Event()
        completed = threading.Event()
        errors: list[BaseException] = []

        def attempt():
            started.set()
            try:
                a.get_and_set(99)
                completed.set()
            except BaseException as exc:
                errors.append(exc)

        worker = threading.Thread(target=attempt, daemon=True)
        worker.start()
        # Make sure the worker actually entered `attempt` — otherwise a
        # scheduling delay alone could let the assertions below pass even on
        # a fixed implementation.
        assert started.wait(timeout=2.0), "worker thread never started"
        # Give get_and_set a moment to either deadlock or return.
        completed.wait(timeout=0.5)
        assert not errors, (
            f"get_and_set raised before reaching the deadlock spin: {errors[0]!r}"
        )
        assert worker.is_alive(), (
            "worker thread exited unexpectedly — get_and_set neither deadlocked "
            "nor completed; the test no longer pins the documented bug."
        )
        assert not completed.is_set(), (
            "get_and_set unexpectedly returned — the deadlock bug appears fixed; "
            "delete this pinned test along with the xfail below."
        )

    @pytest.mark.xfail(
        strict=True,
        reason=(
            "Known bug: AtomicInteger.get_and_set deadlocks because it holds "
            "the non-reentrant lock while accessing the value property. "
            "This xfail flips to XPASS when the bug is fixed."
        ),
    )
    @pytest.mark.timeout(2)
    def test_get_and_set_should_return_old_value_and_replace_state(self):
        # Run the call on a daemon worker with a bounded join instead of on
        # the main thread. Calling it directly would leave the main thread
        # stuck in the deadlock and rely on pytest-timeout to break it; with
        # timeout_method="thread" pytest-timeout terminates the whole test
        # process, aborting the rest of the suite. A bounded join turns the
        # deadlock into an ordinary assertion failure (-> XFAIL), while a
        # fixed implementation passes every assertion (-> strict XPASS).
        a = AtomicInteger(10)
        returned: list = []

        def attempt():
            returned.append(a.get_and_set(99))

        worker = threading.Thread(target=attempt, daemon=True)
        worker.start()
        worker.join(timeout=1.0)
        # Check liveness first: reading a.value while the worker still holds
        # the lock would block this thread too.
        assert not worker.is_alive(), "get_and_set did not return (deadlock)"
        assert returned == [10]
        assert a.value == 99
class TestAtomicIntegerThreadSafety:
    def test_inc_under_concurrent_threads_is_lossless(self):
        # Several threads hammer inc() concurrently; the final value must
        # equal the total number of increments (no lost updates).
        counter = AtomicInteger(0)
        num_threads, incs_per_thread = 8, 1000

        def hammer():
            for _ in range(incs_per_thread):
                counter.inc()

        workers = [threading.Thread(target=hammer) for _ in range(num_threads)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        assert counter.value == num_threads * incs_per_thread
================================================
FILE: amber/src/test/python/core/util/test_expression_evaluator.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from core.util.expression_evaluator import ExpressionEvaluator
from proto.org.apache.texera.amber.engine.architecture.rpc import (
EvaluatedValue,
TypedValue,
)
class TestExpressionEvaluator:
    """Pin the EvaluatedValue/TypedValue trees produced by ExpressionEvaluator.

    The expected ``value_str`` values match ``repr`` of the evaluated object
    (e.g. ``"'hello world'"`` for a str). For objects whose repr is the
    default one (local class instances, generators, iterators), that repr
    embeds the ``__qualname__`` — including ``<locals>`` segments — and the
    object's address, which is rebuilt here with ``hex(id(obj))``.
    """

    def test_evaluate_basic_expressions(self):
        i = 10
        assert ExpressionEvaluator.evaluate(
            "i", runtime_context={"i": i}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="i",
                value_ref="i",
                value_str="10",
                value_type="int",
                expandable=False,
            ),
            attributes=[],
        )
        f = 1.1
        assert ExpressionEvaluator.evaluate(
            "f", runtime_context={"f": f}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="f",
                value_ref="f",
                value_str="1.1",
                value_type="float",
                expandable=False,
            ),
            attributes=[],
        )

    def test_evaluate_str_expression(self):
        s = "hello world"
        # A str is expandable: each character is exposed as a
        # __getitem__(idx) attribute whose value_str is that character's repr.
        assert ExpressionEvaluator.evaluate(
            "s", runtime_context={"s": s}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="s",
                value_ref="s",
                value_str="'hello world'",
                value_type="str",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression=f"__getitem__({idx})",
                    value_ref=str(idx),
                    value_str=repr(ch),
                    value_type="str",
                    expandable=True,
                )
                for idx, ch in enumerate(s)
            ],
        )
        assert ExpressionEvaluator.evaluate(
            "s[4]", runtime_context={"s": s}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="s[4]",
                value_ref="s[4]",
                value_str="'o'",
                value_type="str",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="__getitem__(0)",
                    value_ref="0",
                    value_str="'o'",
                    value_type="str",
                    expandable=True,
                )
            ],
        )
        assert ExpressionEvaluator.evaluate(
            "s.__getitem__(2)", runtime_context={"s": s}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="s.__getitem__(2)",
                value_ref="s.__getitem__(2)",
                value_str="'l'",
                value_type="str",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="__getitem__(0)",
                    value_ref="0",
                    value_str="'l'",
                    value_type="str",
                    expandable=True,
                )
            ],
        )

    def test_evaluate_object_expression(self):
        class A:
            def __init__(self):
                self.i = 10
                self.j = 1.1

        a = A()
        assert ExpressionEvaluator.evaluate(
            "a", runtime_context={"a": a}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="a",
                value_ref="a",
                # Default object repr: "<module.qualname object at 0x...>";
                # A is defined inside this method, so its qualname carries
                # a "<locals>" segment.
                value_str=(
                    f"<{A.__module__}."
                    "TestExpressionEvaluator.test_evaluate_object_expression.<locals>.A"
                    f" object at {hex(id(a))}>"
                ),
                value_type="A",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="i",
                    value_ref="i",
                    value_str="10",
                    value_type="int",
                    expandable=False,
                ),
                TypedValue(
                    expression="j",
                    value_ref="j",
                    value_str="1.1",
                    value_type="float",
                    expandable=False,
                ),
            ],
        )

    def test_evaluate_container_expressions(self):
        i = 10
        f = 1.1
        a_list = [i, f, (i, f)]
        assert ExpressionEvaluator.evaluate(
            "a_list", runtime_context={"a_list": a_list}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="a_list",
                value_ref="a_list",
                value_str="[10, 1.1, (10, 1.1)]",
                value_type="list",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="__getitem__(0)",
                    value_ref="0",
                    value_str="10",
                    value_type="int",
                    expandable=False,
                ),
                TypedValue(
                    expression="__getitem__(1)",
                    value_ref="1",
                    value_str="1.1",
                    value_type="float",
                    expandable=False,
                ),
                TypedValue(
                    expression="__getitem__(2)",
                    value_ref="2",
                    value_str="(10, 1.1)",
                    value_type="tuple",
                    expandable=True,
                ),
            ],
        )
        t = (i, f, {i, f})
        assert ExpressionEvaluator.evaluate(
            "t", runtime_context={"t": t}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="t",
                value_ref="t",
                value_str="(10, 1.1, {1.1, 10})",
                value_type="tuple",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="__getitem__(0)",
                    value_ref="0",
                    value_str="10",
                    value_type="int",
                    expandable=False,
                ),
                TypedValue(
                    expression="__getitem__(1)",
                    value_ref="1",
                    value_str="1.1",
                    value_type="float",
                    expandable=False,
                ),
                TypedValue(
                    expression="__getitem__(2)",
                    value_ref="2",
                    value_str="{1.1, 10}",
                    value_type="set",
                    expandable=True,
                ),
            ],
        )
        s = {i, f, (i, f)}
        assert ExpressionEvaluator.evaluate(
            "s", runtime_context={"s": s}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="s",
                value_ref="s",
                value_str="{1.1, 10, (10, 1.1)}",
                value_type="set",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="__getitem__(0)",
                    value_ref="0",
                    value_str="1.1",
                    value_type="float",
                    expandable=False,
                ),
                TypedValue(
                    expression="__getitem__(1)",
                    value_ref="1",
                    value_str="10",
                    value_type="int",
                    expandable=False,
                ),
                TypedValue(
                    expression="__getitem__(2)",
                    value_ref="2",
                    value_str="(10, 1.1)",
                    value_type="tuple",
                    expandable=False,
                ),
            ],
        )
        d = {1: "a", "b": [{i, f}], (i,): f}
        assert ExpressionEvaluator.evaluate(
            "d", runtime_context={"d": d}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="d",
                value_ref="d",
                value_str="{1: 'a', 'b': [{1.1, 10}], (10,): 1.1}",
                value_type="dict",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="__getitem__(1)",
                    value_ref="1",
                    value_str="'a'",
                    value_type="str",
                    expandable=True,
                ),
                TypedValue(
                    expression="__getitem__('b')",
                    value_ref="'b'",
                    value_str="[{1.1, 10}]",
                    value_type="list",
                    expandable=True,
                ),
                TypedValue(
                    expression="__getitem__((10,))",
                    value_ref="(10,)",
                    value_str="1.1",
                    value_type="float",
                    expandable=False,
                ),
            ],
        )
        g = (i for i in range(10))
        assert ExpressionEvaluator.evaluate(
            "g", runtime_context={"g": g}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="g",
                value_ref="g",
                # Generator repr: "<generator object QUALNAME at 0x...>".
                value_str=(
                    "<generator object TestExpressionEvaluator."
                    "test_evaluate_container_expressions.<locals>.<genexpr>"
                    f" at {hex(id(g))}>"
                ),
                value_type="generator",
                expandable=True,
            ),
            attributes=[],
        )

        def gen():
            for i in range(10):
                yield i

        g = gen()
        next(g)
        assert ExpressionEvaluator.evaluate(
            "g", runtime_context={"g": g}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="g",
                value_ref="g",
                value_str=(
                    "<generator object TestExpressionEvaluator."
                    "test_evaluate_container_expressions.<locals>.gen"
                    f" at {hex(id(g))}>"
                ),
                value_type="generator",
                expandable=True,
            ),
            attributes=[
                TypedValue(
                    expression="i",
                    value_ref="i",
                    value_str="0",
                    value_type="int",
                    expandable=False,
                )
            ],
        )
        it = iter([1, 2, 3])
        assert ExpressionEvaluator.evaluate(
            "it", runtime_context={"it": it}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="it",
                value_ref="it",
                value_str=f"<list_iterator object at {hex(id(it))}>",
                value_type="list_iterator",
                expandable=False,
            ),
            attributes=[],
        )
        it = iter([1, 2, 3])
        next(it)
        assert ExpressionEvaluator.evaluate(
            "it", runtime_context={"it": it}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="it",
                value_ref="it",
                value_str=f"<list_iterator object at {hex(id(it))}>",
                value_type="list_iterator",
                expandable=False,
            ),
            attributes=[],
        )

    def test_evaluate_in_another_context(self):
        # The evaluator resolves names from runtime_context, not from the
        # caller's locals: here the context swaps the values of i and j.
        i = 10
        j = 20
        assert ExpressionEvaluator.evaluate(
            "j", runtime_context={"j": i, "i": j}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="j",
                value_ref="j",
                value_str="10",
                value_type="int",
                expandable=False,
            ),
            attributes=[],
        )
        assert ExpressionEvaluator.evaluate(
            "i", runtime_context={"j": i, "i": j}
        ) == EvaluatedValue(
            value=TypedValue(
                expression="i",
                value_ref="i",
                value_str="20",
                value_type="int",
                expandable=False,
            ),
            attributes=[],
        )
================================================
FILE: amber/src/test/python/core/util/test_virtual_identity.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from core.util.virtual_identity import (
deserialize_global_port_identity,
get_from_actor_id_for_input_port_storage,
get_worker_index,
serialize_global_port_identity,
)
from proto.org.apache.texera.amber.core import (
ActorVirtualIdentity,
GlobalPortIdentity,
OperatorIdentity,
PhysicalOpIdentity,
PortIdentity,
)
def _gpi(
    op_id: str = "myOp",
    layer: str = "main",
    port: int = 0,
    internal: bool = False,
    is_input: bool = True,
) -> GlobalPortIdentity:
    """Build a GlobalPortIdentity for tests, with every field overridable."""
    physical_op = PhysicalOpIdentity(
        logical_op_id=OperatorIdentity(id=op_id),
        layer_name=layer,
    )
    port_identity = PortIdentity(id=port, internal=internal)
    return GlobalPortIdentity(
        op_id=physical_op,
        port_id=port_identity,
        input=is_input,
    )
class TestGetWorkerIndex:
    """get_worker_index returns the trailing integer of a worker actor name."""

    def test_extracts_trailing_numeric_index_from_worker_actor_name(self):
        index = get_worker_index("Worker:WF1-myOp-main-7")
        assert index == 7

    def test_handles_multi_digit_indexes(self):
        index = get_worker_index("Worker:WF42-someOp-layerX-1234")
        assert index == 1234

    def test_raises_value_error_on_unmatched_actor_name(self):
        # Non-worker actor names such as CONTROLLER are rejected.
        with pytest.raises(ValueError, match="Invalid worker ID format"):
            get_worker_index("CONTROLLER")

    def test_raises_value_error_on_partial_match(self):
        # A worker-shaped name without a trailing index is also rejected.
        with pytest.raises(ValueError, match="Invalid worker ID format"):
            get_worker_index("Worker:WF1-myOp-main")

    def test_extracts_trailing_index_even_when_layer_name_contains_hyphens(self):
        # Contrast with the Scala VirtualIdentityUtils sibling, whose layer
        # capture group `(\w+)` rejects hyphens (Bug #4728). The Python
        # helper only captures the trailing `(\d+)` group, and backtracking
        # on the greedy `.+` still leaves the last dash-separated number for
        # it, so the index is recovered correctly. Keep this pinned so that
        # a future regex change that breaks the trailing match is caught.
        index = get_worker_index("Worker:WF1-myOp-1st-physical-op-3")
        assert index == 3
class TestSerializeGlobalPortIdentity:
    """serialize_global_port_identity output format and round-trip."""

    def test_emits_documented_format_for_canonical_input(self):
        expected = (
            "(logicalOpId=myOp,layerName=main,portId=0,isInternal=false,isInput=true)"
        )
        assert serialize_global_port_identity(_gpi()) == expected

    def test_lowercases_boolean_fields(self):
        # The wire format spells booleans as lowercase `true`/`false`, unlike
        # Python's str(bool); the deserializer depends on this contract.
        text = serialize_global_port_identity(_gpi(internal=True, is_input=False))
        for fragment in ("isInternal=true", "isInput=false"):
            assert fragment in text

    def test_round_trips_through_deserialize(self):
        before = _gpi(
            op_id="myOp", layer="main", port=3, internal=True, is_input=False
        )
        after = deserialize_global_port_identity(
            serialize_global_port_identity(before)
        )
        assert after.op_id.logical_op_id.id == "myOp"
        assert after.op_id.layer_name == "main"
        assert after.port_id.id == 3
        assert after.port_id.internal is True
        assert after.input is False
class TestDeserializeGlobalPortIdentity:
    """Parsing rules of deserialize_global_port_identity."""

    def test_parses_canonical_encoded_string(self):
        parsed = deserialize_global_port_identity(
            "(logicalOpId=op,layerName=l,portId=2,isInternal=true,isInput=true)"
        )
        assert parsed.op_id.logical_op_id.id == "op"
        assert parsed.op_id.layer_name == "l"
        assert parsed.port_id.id == 2
        assert parsed.port_id.internal is True
        assert parsed.input is True

    def test_treats_boolean_capitalization_case_insensitively(self):
        # Boolean tokens are lowercased before comparison, so `TRUE`/`False`
        # from a non-canonical producer still parse (the serializer itself
        # always writes lowercase).
        parsed = deserialize_global_port_identity(
            "(logicalOpId=op,layerName=l,portId=0,isInternal=TRUE,isInput=False)"
        )
        assert parsed.port_id.internal is True
        assert parsed.input is False

    def test_raises_value_error_on_malformed_input(self):
        with pytest.raises(ValueError, match="Invalid GlobalPortIdentity format"):
            deserialize_global_port_identity("not-a-port-id")

    def test_raises_value_error_on_missing_field(self):
        # All five fields are mandatory; omitting `isInput` must raise
        # ValueError rather than silently defaulting.
        truncated = "(logicalOpId=op,layerName=l,portId=0,isInternal=true)"
        with pytest.raises(ValueError, match="Invalid GlobalPortIdentity format"):
            deserialize_global_port_identity(truncated)
class TestGetFromActorIdForInputPortStorage:
    def test_prefixes_materialization_reader_to_uri_plus_actor_name(self):
        # The virtual reader's name is "MATERIALIZATION_READER_" + uri +
        # actor name, with no separator between the uri and the actor name.
        worker = ActorVirtualIdentity(name="Worker:WF1-myOp-main-0")
        reader = get_from_actor_id_for_input_port_storage(
            "iceberg:/warehouse/x", worker
        )
        expected_name = (
            "MATERIALIZATION_READER_iceberg:/warehouse/xWorker:WF1-myOp-main-0"
        )
        assert reader.name == expected_name
================================================
FILE: amber/src/test/python/pytexera/storage/test_dataset_file_document.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import pytest
from unittest.mock import patch, MagicMock
from pytexera.storage.dataset_file_document import DatasetFileDocument
DEFAULT_ENDPOINT = "http://localhost:9092/api/dataset/presign-download"
CUSTOM_ENDPOINT = "https://example.test/api/presign"
@pytest.fixture
def auth_env(monkeypatch):
    """Set a JWT and pin the presign endpoint; monkeypatch undoes both after the test."""
    for env_name, env_value in (
        ("USER_JWT_TOKEN", "test-jwt-token"),
        ("FILE_SERVICE_GET_PRESIGNED_URL_ENDPOINT", CUSTOM_ENDPOINT),
    ):
        monkeypatch.setenv(env_name, env_value)
def make_response(status_code: int, body=None, content: bytes = b""):
    """Build a MagicMock standing in for an HTTP response.

    Args:
        status_code: value exposed as ``.status_code``.
        body: value returned by ``.json()`` (``{}`` when falsy); its ``str``
            is exposed as ``.text`` (``""`` when ``None``).
        content: raw bytes exposed as ``.content``.
    """
    fake = MagicMock()
    fake.status_code = status_code
    fake.json.return_value = body if body else {}
    fake.text = str(body) if body is not None else ""
    fake.content = content
    return fake
class TestDatasetFileDocumentInit:
    """Path parsing and environment handling performed by the constructor."""

    def test_parses_minimal_four_part_path(self, auth_env):
        doc = DatasetFileDocument("/bob@x.com/ds/v1/file.csv")
        parsed = (
            doc.owner_email,
            doc.dataset_name,
            doc.version_name,
            doc.file_relative_path,
        )
        assert parsed == ("bob@x.com", "ds", "v1", "file.csv")

    def test_joins_nested_relative_path_back_with_slashes(self, auth_env):
        nested = DatasetFileDocument("/bob@x.com/ds/v1/a/b/c/file.csv")
        assert nested.file_relative_path == "a/b/c/file.csv"

    def test_strips_leading_and_trailing_slashes_before_parsing(self, auth_env):
        padded = DatasetFileDocument("///bob@x.com/ds/v1/file.csv///")
        assert padded.owner_email == "bob@x.com"
        assert padded.file_relative_path == "file.csv"

    def test_rejects_path_with_fewer_than_four_segments(self, auth_env):
        with pytest.raises(ValueError, match="Invalid file path format"):
            DatasetFileDocument("/bob@x.com/ds/v1")

    def test_requires_jwt_token_in_environment(self, monkeypatch):
        monkeypatch.delenv("USER_JWT_TOKEN", raising=False)
        monkeypatch.setenv("FILE_SERVICE_GET_PRESIGNED_URL_ENDPOINT", CUSTOM_ENDPOINT)
        with pytest.raises(ValueError, match="JWT token is required"):
            DatasetFileDocument("/bob@x.com/ds/v1/file.csv")

    def test_treats_empty_jwt_as_missing(self, monkeypatch):
        # "" is falsy, so it must be rejected exactly like an unset variable.
        monkeypatch.setenv("USER_JWT_TOKEN", "")
        with pytest.raises(ValueError, match="JWT token is required"):
            DatasetFileDocument("/bob@x.com/ds/v1/file.csv")

    def test_falls_back_to_default_endpoint_when_env_missing(self, monkeypatch):
        monkeypatch.setenv("USER_JWT_TOKEN", "tok")
        monkeypatch.delenv("FILE_SERVICE_GET_PRESIGNED_URL_ENDPOINT", raising=False)
        endpoint = DatasetFileDocument("/bob@x.com/ds/v1/file.csv").presign_endpoint
        assert endpoint == DEFAULT_ENDPOINT

    def test_uses_explicit_endpoint_from_environment(self, auth_env):
        endpoint = DatasetFileDocument("/bob@x.com/ds/v1/file.csv").presign_endpoint
        assert endpoint == CUSTOM_ENDPOINT
class TestGetPresignedUrl:
    """get_presigned_url against a patched requests.get."""

    # Dotted path of the HTTP GET used by the module under test.
    _GET_TARGET = "pytexera.storage.dataset_file_document.requests.get"

    def _make_doc(self, monkeypatch, path="/bob@x.com/ds/v1/file.csv"):
        env = {
            "USER_JWT_TOKEN": "test-jwt-token",
            "FILE_SERVICE_GET_PRESIGNED_URL_ENDPOINT": CUSTOM_ENDPOINT,
        }
        for env_name, env_value in env.items():
            monkeypatch.setenv(env_name, env_value)
        return DatasetFileDocument(path)

    def test_returns_presigned_url_field_from_json_body(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(
                200, body={"presignedUrl": "https://signed.test/x"}
            )
            url = doc.get_presigned_url()
        assert url == "https://signed.test/x"

    def test_sends_bearer_authorization_header_with_jwt(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(200, body={"presignedUrl": "u"})
            doc.get_presigned_url()
        sent_headers = fake_get.call_args.kwargs["headers"]
        assert sent_headers == {"Authorization": "Bearer test-jwt-token"}

    def test_url_encodes_filepath_query_parameter(self, monkeypatch):
        # urllib.parse.quote leaves "/" unescaped by default but escapes "@"
        # and " "; assert on both so the encoding contract is explicit.
        doc = self._make_doc(monkeypatch, path="/bob@x.com/ds/v1/data file.csv")
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(200, body={"presignedUrl": "u"})
            doc.get_presigned_url()
        file_path = fake_get.call_args.kwargs["params"]["filePath"]
        assert "data%20file.csv" in file_path
        assert "bob%40x.com" in file_path
        assert file_path.startswith("/")

    def test_calls_configured_endpoint(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(200, body={"presignedUrl": "u"})
            doc.get_presigned_url()
        assert fake_get.call_args.args[0] == CUSTOM_ENDPOINT

    def test_raises_runtime_error_with_status_and_body_on_failure(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(403, body="forbidden")
            with pytest.raises(RuntimeError, match=r"403.*forbidden"):
                doc.get_presigned_url()

    def test_returns_none_when_response_body_lacks_presigned_url_key(self, monkeypatch):
        # Current behavior: a 200 whose body has no "presignedUrl" key yields
        # None instead of raising; read_file() would then call requests.get(None).
        doc = self._make_doc(monkeypatch)
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(200, body={"other": "value"})
            assert doc.get_presigned_url() is None
class TestReadFile:
    """read_file: presign request followed by a download of the signed URL."""

    # Dotted path of the HTTP GET used by the module under test.
    _GET_TARGET = "pytexera.storage.dataset_file_document.requests.get"

    def _make_doc(self, monkeypatch):
        for env_name, env_value in (
            ("USER_JWT_TOKEN", "test-jwt-token"),
            ("FILE_SERVICE_GET_PRESIGNED_URL_ENDPOINT", CUSTOM_ENDPOINT),
        ):
            monkeypatch.setenv(env_name, env_value)
        return DatasetFileDocument("/bob@x.com/ds/v1/file.csv")

    def test_returns_bytesio_with_downloaded_content(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        presign = make_response(200, body={"presignedUrl": "https://signed.test/x"})
        download = make_response(200, content=b"hello-bytes")
        with patch(self._GET_TARGET) as fake_get:
            fake_get.side_effect = [presign, download]
            buffer = doc.read_file()
        assert isinstance(buffer, io.BytesIO)
        assert buffer.read() == b"hello-bytes"

    def test_propagates_presigned_url_failure(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        with patch(self._GET_TARGET) as fake_get:
            fake_get.return_value = make_response(500, body="upstream down")
            with pytest.raises(RuntimeError, match=r"500.*upstream down"):
                doc.read_file()

    def test_raises_runtime_error_when_download_fails(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        presign = make_response(200, body={"presignedUrl": "https://signed.test/x"})
        failed_download = make_response(404, body="missing")
        with patch(self._GET_TARGET) as fake_get:
            fake_get.side_effect = [presign, failed_download]
            with pytest.raises(RuntimeError, match=r"404.*missing"):
                doc.read_file()

    def test_downloads_from_presigned_url_returned_by_first_call(self, monkeypatch):
        doc = self._make_doc(monkeypatch)
        presign = make_response(200, body={"presignedUrl": "https://signed.test/x"})
        download = make_response(200, content=b"")
        with patch(self._GET_TARGET) as fake_get:
            fake_get.side_effect = [presign, download]
            doc.read_file()
        download_call = fake_get.call_args_list[1]
        assert download_call.args[0] == "https://signed.test/x"
================================================
FILE: amber/src/test/python/pytexera/storage/test_large_binary_input_stream.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from unittest.mock import patch, MagicMock
from io import BytesIO
from core.models.type.large_binary import largebinary
from pytexera.storage.large_binary_input_stream import LargeBinaryInputStream
from pytexera.storage import large_binary_manager
class TestLargeBinaryInputStream:
@pytest.fixture
def large_binary(self):
    """Fixture: a largebinary that references a fixed S3 test object."""
    object_uri = "s3://test-bucket/path/to/object"
    return largebinary(object_uri)
@pytest.fixture
def mock_s3_response(self):
    """Fixture: a fake S3 get_object response whose Body is in-memory bytes."""
    body_stream = BytesIO(b"test data content")
    return {"Body": body_stream}
def test_init_with_valid_large_binary(self, large_binary):
    """Construction keeps the largebinary and does not open the stream yet."""
    lb_stream = LargeBinaryInputStream(large_binary)
    try:
        assert lb_stream._large_binary == large_binary
        # No download happens until the first read.
        assert lb_stream._underlying is None
        assert not lb_stream.closed
    finally:
        lb_stream.close()
def test_init_with_none_raises_error(self):
    """A None largebinary is rejected at construction time with ValueError."""
    missing = None
    with pytest.raises(ValueError, match="largebinary cannot be None"):
        LargeBinaryInputStream(missing)
def test_lazy_init_downloads_from_s3(self, large_binary, mock_s3_response):
    """The S3 object is fetched on the first read, not at construction."""
    with patch.object(large_binary_manager, "_get_s3_client") as get_client:
        client = MagicMock()
        client.get_object.return_value = mock_s3_response
        get_client.return_value = client
        lb_stream = LargeBinaryInputStream(large_binary)
        try:
            assert lb_stream._underlying is None  # nothing fetched yet
            payload = lb_stream.read()  # first read triggers the download
            assert payload == b"test data content"
            assert lb_stream._underlying is not None
            # The s3:// URI is split into bucket and key for get_object.
            client.get_object.assert_called_once_with(
                Bucket="test-bucket", Key="path/to/object"
            )
        finally:
            lb_stream.close()
def test_read_all(self, large_binary, mock_s3_response):
    """read() with no size returns the whole object body."""
    with patch.object(large_binary_manager, "_get_s3_client") as get_client:
        client = MagicMock()
        client.get_object.return_value = mock_s3_response
        get_client.return_value = client
        lb_stream = LargeBinaryInputStream(large_binary)
        try:
            assert lb_stream.read() == b"test data content"
        finally:
            lb_stream.close()
def test_read_partial(self, large_binary, mock_s3_response):
"""Test reading partial data."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = mock_s3_response
mock_get_s3_client.return_value = mock_s3_client
stream = LargeBinaryInputStream(large_binary)
try:
data = stream.read(4)
assert data == b"test"
finally:
stream.close()
def test_readline(self, large_binary):
"""Test reading a line."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
response = {"Body": BytesIO(b"line1\nline2\nline3")}
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = response
mock_get_s3_client.return_value = mock_s3_client
stream = LargeBinaryInputStream(large_binary)
try:
line = stream.readline()
assert line == b"line1\n"
finally:
stream.close()
def test_readlines(self, large_binary):
"""Test reading all lines."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
response = {"Body": BytesIO(b"line1\nline2\nline3")}
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = response
mock_get_s3_client.return_value = mock_s3_client
stream = LargeBinaryInputStream(large_binary)
try:
lines = stream.readlines()
assert lines == [b"line1\n", b"line2\n", b"line3"]
finally:
stream.close()
def test_readable(self, large_binary):
"""Test readable() method."""
stream = LargeBinaryInputStream(large_binary)
try:
assert stream.readable() is True
stream.close()
assert stream.readable() is False
finally:
if not stream.closed:
stream.close()
def test_seekable(self, large_binary):
"""Test seekable() method (should always return False)."""
stream = LargeBinaryInputStream(large_binary)
try:
assert stream.seekable() is False
finally:
stream.close()
def test_closed_property(self, large_binary):
"""Test closed property."""
stream = LargeBinaryInputStream(large_binary)
try:
assert stream.closed is False
stream.close()
assert stream.closed is True
finally:
if not stream.closed:
stream.close()
def test_close(self, large_binary, mock_s3_response):
"""Test closing the stream."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = mock_s3_response
mock_get_s3_client.return_value = mock_s3_client
stream = LargeBinaryInputStream(large_binary)
stream.read(1) # Trigger lazy init
assert stream._underlying is not None
stream.close()
assert stream.closed is True
assert stream._underlying.closed
def test_context_manager(self, large_binary, mock_s3_response):
"""Test using as context manager."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = mock_s3_response
mock_get_s3_client.return_value = mock_s3_client
with LargeBinaryInputStream(large_binary) as stream:
data = stream.read()
assert data == b"test data content"
assert not stream.closed
# Stream should be closed after context exit
assert stream.closed
def test_iteration(self, large_binary):
"""Test iteration over lines."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
response = {"Body": BytesIO(b"line1\nline2\nline3")}
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = response
mock_get_s3_client.return_value = mock_s3_client
stream = LargeBinaryInputStream(large_binary)
try:
lines = list(stream)
assert lines == [b"line1\n", b"line2\n", b"line3"]
finally:
stream.close()
def test_read_after_close_raises_error(self, large_binary, mock_s3_response):
"""Test that reading after close raises ValueError."""
with patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client:
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = mock_s3_response
mock_get_s3_client.return_value = mock_s3_client
stream = LargeBinaryInputStream(large_binary)
stream.close()
with pytest.raises(ValueError, match="I/O operation on closed stream"):
stream.read()
# Stream is already closed, no need to close again
================================================
FILE: amber/src/test/python/pytexera/storage/test_large_binary_manager.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from unittest.mock import patch, MagicMock
from pytexera.storage import large_binary_manager
from core.storage.storage_config import StorageConfig
class TestLargeBinaryManager:
    """Tests for the S3 helpers in ``large_binary_manager``.

    boto3 is always mocked; no real S3/MinIO endpoint is contacted.
    """

    @pytest.fixture(autouse=True)
    def setup_storage_config(self):
        """Initialize StorageConfig for tests."""
        if not StorageConfig._initialized:
            StorageConfig.initialize(
                catalog_type="postgres",
                postgres_uri_without_scheme="localhost:5432/test",
                postgres_username="test",
                postgres_password="test",
                rest_catalog_uri="http://localhost:8181/catalog/",
                rest_catalog_warehouse_name="texera",
                table_result_namespace="test",
                directory_path="/tmp/test",
                commit_batch_size=1000,
                s3_endpoint="http://localhost:9000",
                s3_region="us-east-1",
                s3_auth_username="minioadmin",
                s3_auth_password="minioadmin",
            )

    @pytest.fixture(autouse=True)
    def reset_s3_client(self, monkeypatch):
        """Start each test with an empty S3 client cache, and restore it afterwards.

        The tests below fill the module-level ``_s3_client`` cache with
        MagicMock clients. Assigning it directly would leave a mock cached after
        the test ends, leaking into any later test that calls
        ``_get_s3_client`` without patching boto3. monkeypatch restores the
        original value on teardown.
        """
        monkeypatch.setattr(large_binary_manager, "_s3_client", None, raising=False)

    def test_get_s3_client_initializes_once(self):
        """Test that S3 client is initialized and cached."""
        with patch("boto3.client") as mock_boto3_client:
            mock_client = MagicMock()
            mock_boto3_client.return_value = mock_client
            # First call should create client
            client1 = large_binary_manager._get_s3_client()
            assert client1 == mock_client
            assert mock_boto3_client.call_count == 1
            # Second call should return cached client
            client2 = large_binary_manager._get_s3_client()
            assert client2 == mock_client
            assert mock_boto3_client.call_count == 1  # Still 1, not 2

    def test_get_s3_client_without_boto3_raises_error(self):
        """Test that missing boto3 raises RuntimeError."""
        import sys

        # Temporarily remove boto3 from sys.modules to simulate it not being installed
        boto3_backup = sys.modules.pop("boto3", None)
        try:
            # Route every import through a wrapper that refuses boto3.
            original_import = __import__

            def mock_import(name, *args, **kwargs):
                if name == "boto3":
                    raise ImportError("No module named boto3")
                return original_import(name, *args, **kwargs)

            with patch("builtins.__import__", side_effect=mock_import):
                with pytest.raises(RuntimeError, match="boto3 required"):
                    large_binary_manager._get_s3_client()
        finally:
            # Restore boto3 if it was there
            if boto3_backup is not None:
                sys.modules["boto3"] = boto3_backup

    def test_ensure_bucket_exists_when_bucket_exists(self):
        """Test that existing bucket doesn't trigger creation."""
        with patch("boto3.client") as mock_boto3_client:
            mock_client = MagicMock()
            mock_boto3_client.return_value = mock_client
            # head_bucket doesn't raise exception (bucket exists)
            mock_client.head_bucket.return_value = None
            mock_client.exceptions.NoSuchBucket = type("NoSuchBucket", (Exception,), {})
            large_binary_manager._ensure_bucket_exists("test-bucket")
            mock_client.head_bucket.assert_called_once_with(Bucket="test-bucket")
            mock_client.create_bucket.assert_not_called()

    def test_ensure_bucket_exists_creates_bucket_when_missing(self):
        """Test that missing bucket triggers creation."""
        with patch("boto3.client") as mock_boto3_client:
            mock_client = MagicMock()
            mock_boto3_client.return_value = mock_client
            # head_bucket raises NoSuchBucket exception
            no_such_bucket = type("NoSuchBucket", (Exception,), {})
            mock_client.exceptions.NoSuchBucket = no_such_bucket
            mock_client.head_bucket.side_effect = no_such_bucket()
            large_binary_manager._ensure_bucket_exists("test-bucket")
            mock_client.head_bucket.assert_called_once_with(Bucket="test-bucket")
            mock_client.create_bucket.assert_called_once_with(Bucket="test-bucket")

    def test_create_generates_unique_uri(self):
        """Test that create() generates a unique S3 URI."""
        with patch("boto3.client") as mock_boto3_client:
            mock_client = MagicMock()
            mock_boto3_client.return_value = mock_client
            mock_client.head_bucket.return_value = None
            mock_client.exceptions.NoSuchBucket = type("NoSuchBucket", (Exception,), {})
            uri = large_binary_manager.create()
            # Check URI format
            assert uri.startswith("s3://")
            assert uri.startswith(f"s3://{large_binary_manager.DEFAULT_BUCKET}/")
            assert "objects/" in uri
            # Verify bucket was checked/created
            mock_client.head_bucket.assert_called_once_with(
                Bucket=large_binary_manager.DEFAULT_BUCKET
            )
            # The test name promises uniqueness, so a second call must
            # produce a different URI.
            assert large_binary_manager.create() != uri

    def test_create_uses_default_bucket(self):
        """Test that create() uses the default bucket."""
        with patch("boto3.client") as mock_boto3_client:
            mock_client = MagicMock()
            mock_boto3_client.return_value = mock_client
            mock_client.head_bucket.return_value = None
            mock_client.exceptions.NoSuchBucket = type("NoSuchBucket", (Exception,), {})
            uri = large_binary_manager.create()
            assert large_binary_manager.DEFAULT_BUCKET in uri
================================================
FILE: amber/src/test/python/pytexera/storage/test_large_binary_output_stream.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import queue
import pytest
import time
from unittest.mock import patch, MagicMock
from core.models.type.large_binary import largebinary
from pytexera.storage.large_binary_output_stream import (
LargeBinaryOutputStream,
_QueueReader,
)
from pytexera.storage import large_binary_manager
class TestLargeBinaryOutputStream:
    """Tests for LargeBinaryOutputStream, which uploads to S3 from a background thread."""

    @pytest.fixture
    def large_binary(self):
        """Create a test largebinary."""
        return largebinary("s3://test-bucket/path/to/object")

    @pytest.fixture
    def mock_s3(self):
        """Patch the manager's S3 hooks and yield the mock client.

        While the test runs, ``_get_s3_client`` returns the yielded MagicMock
        and ``_ensure_bucket_exists`` is a mock that returns None. Tests can
        configure the yielded client (e.g. ``upload_fileobj.side_effect``)
        before writing.
        """
        with (
            patch.object(large_binary_manager, "_get_s3_client") as mock_get_s3_client,
            patch.object(
                large_binary_manager, "_ensure_bucket_exists"
            ) as mock_ensure_bucket,
        ):
            client = MagicMock()
            mock_get_s3_client.return_value = client
            mock_ensure_bucket.return_value = None
            yield client

    def test_init_with_valid_large_binary(self, large_binary):
        """Test initialization with a valid largebinary."""
        stream = LargeBinaryOutputStream(large_binary)
        assert stream._large_binary == large_binary
        assert stream._bucket_name == "test-bucket"
        assert stream._object_key == "path/to/object"
        assert not stream.closed
        assert stream._upload_thread is None

    def test_init_with_none_raises_error(self):
        """Test that initializing with None raises ValueError."""
        with pytest.raises(ValueError, match="largebinary cannot be None"):
            LargeBinaryOutputStream(None)

    def test_write_starts_upload_thread(self, large_binary, mock_s3):
        """Test that write() starts the upload thread."""
        stream = LargeBinaryOutputStream(large_binary)
        assert stream._upload_thread is None
        stream.write(b"test data")
        # The thread may already have completed; only its creation is checked.
        assert stream._upload_thread is not None
        # Wait for thread to finish
        stream.close()

    def test_write_data(self, large_binary, mock_s3):
        """Test writing data to the stream."""
        stream = LargeBinaryOutputStream(large_binary)
        bytes_written = stream.write(b"test data")
        assert bytes_written == len(b"test data")
        stream.close()

    def test_write_multiple_chunks(self, large_binary, mock_s3):
        """Test writing multiple chunks of data."""
        stream = LargeBinaryOutputStream(large_binary)
        stream.write(b"chunk1")
        stream.write(b"chunk2")
        stream.write(b"chunk3")
        stream.close()

    def test_writable(self, large_binary):
        """Test writable() method."""
        stream = LargeBinaryOutputStream(large_binary)
        assert stream.writable() is True
        stream.close()
        assert stream.writable() is False

    def test_seekable(self, large_binary):
        """Test seekable() method (should always return False)."""
        stream = LargeBinaryOutputStream(large_binary)
        assert stream.seekable() is False

    def test_closed_property(self, large_binary):
        """Test closed property."""
        stream = LargeBinaryOutputStream(large_binary)
        assert stream.closed is False
        stream.close()
        assert stream.closed is True

    def test_flush(self, large_binary):
        """Test flush() method (should be a no-op)."""
        stream = LargeBinaryOutputStream(large_binary)
        # Should not raise any exception
        stream.flush()

    def test_close_completes_upload(self, large_binary, mock_s3):
        """Test that close() completes the upload."""
        stream = LargeBinaryOutputStream(large_binary)
        stream.write(b"test data")
        # Close should wait for upload to complete
        stream.close()
        # Verify upload_fileobj was called
        assert mock_s3.upload_fileobj.called

    def test_context_manager(self, large_binary, mock_s3):
        """Test using as context manager."""
        with LargeBinaryOutputStream(large_binary) as stream:
            stream.write(b"test data")
            assert not stream.closed
        # Stream should be closed after context exit
        assert stream.closed

    def test_write_after_close_raises_error(self, large_binary):
        """Test that writing after close raises ValueError."""
        stream = LargeBinaryOutputStream(large_binary)
        stream.close()
        with pytest.raises(ValueError, match="I/O operation on closed stream"):
            stream.write(b"data")

    def test_close_handles_upload_error(self, large_binary, mock_s3):
        """Test that close() raises IOError if upload fails."""
        mock_s3.upload_fileobj.side_effect = Exception("Upload failed")
        stream = LargeBinaryOutputStream(large_binary)
        stream.write(b"test data")
        with pytest.raises(IOError, match="Failed to complete upload"):
            stream.close()

    def test_write_after_upload_error_raises_error(self, large_binary, mock_s3):
        """Test that writing after upload error raises IOError."""
        mock_s3.upload_fileobj.side_effect = Exception("Upload failed")
        stream = LargeBinaryOutputStream(large_binary)
        stream.write(b"test data")
        # Wait for the background upload to actually finish failing rather
        # than sleeping a fixed interval, which is flaky on a loaded machine.
        # The timeout only bounds the wait; the assertion below decides.
        stream._upload_thread.join(timeout=5)
        with pytest.raises(IOError, match="Background upload failed"):
            stream.write(b"more data")

    def test_multiple_close_calls(self, large_binary, mock_s3):
        """Test that multiple close() calls are safe."""
        stream = LargeBinaryOutputStream(large_binary)
        stream.write(b"test data")
        stream.close()
        # Second close should not raise error
        stream.close()
class TestCleanupFailedUpload:
    """Tests for the silent-swallow path of _cleanup_failed_upload, driven through close()."""

    @pytest.fixture
    def large_binary(self):
        return largebinary("s3://test-bucket/path/to/object")

    def test_delete_object_failure_is_swallowed(self, large_binary):
        # Both the upload and the follow-up delete fail. The caller must still
        # see the original upload IOError, not the cleanup error. This pins
        # the current behavior, so propagating cleanup errors later has to be
        # a deliberate change.
        s3 = MagicMock()
        s3.upload_fileobj.side_effect = Exception("upload failed")
        s3.delete_object.side_effect = Exception("delete also failed")
        client_patch = patch.object(
            large_binary_manager, "_get_s3_client", return_value=s3
        )
        bucket_patch = patch.object(
            large_binary_manager, "_ensure_bucket_exists", return_value=None
        )
        with client_patch, bucket_patch:
            stream = LargeBinaryOutputStream(large_binary)
            stream.write(b"data")
            with pytest.raises(IOError, match="Failed to complete upload"):
                stream.close()
        # The cleanup attempt targeted the same object that failed to upload.
        s3.delete_object.assert_called_once_with(
            Bucket="test-bucket", Key="path/to/object"
        )
class TestQueueReader:
    """Unit tests for the private _QueueReader helper.

    Each reader is fed from a queue pre-filled with byte chunks. A trailing
    None acts as the EOF sentinel.
    """

    @staticmethod
    def _populate(q: queue.Queue, *items):
        """Enqueue ``items`` onto ``q`` in order and hand ``q`` back."""
        for entry in items:
            q.put(entry)
        return q

    @classmethod
    def _reader_over(cls, *items):
        """Return a _QueueReader over a fresh queue holding ``items``."""
        return _QueueReader(cls._populate(queue.Queue(), *items))

    def test_read_returns_empty_on_immediate_eof(self):
        assert self._reader_over(None).read() == b""

    def test_read_after_eof_returns_empty_repeatedly(self):
        reader = self._reader_over(b"abc", None)
        assert reader.read() == b"abc"
        # After EOF, both sized and unsized reads yield b"" without blocking.
        assert reader.read() == b""
        assert reader.read(10) == b""

    def test_read_default_size_joins_all_chunks_until_eof(self):
        reader = self._reader_over(b"abc", b"def", b"ghi", None)
        assert reader.read() == b"abcdefghi"

    def test_read_with_explicit_size_smaller_than_first_chunk(self):
        reader = self._reader_over(b"abcdef", None)
        assert reader.read(3) == b"abc"
        # The unread tail stays buffered and is returned before EOF.
        assert reader.read() == b"def"

    def test_read_buffer_remainder_carries_over_subsequent_calls(self):
        reader = self._reader_over(b"helloworld", None)
        assert reader.read(5) == b"hello"
        # Two bytes come out of the buffer; the rest remains buffered.
        assert reader.read(2) == b"wo"
        assert reader.read() == b"rld"

    def test_read_size_can_span_multiple_queued_chunks(self):
        reader = self._reader_over(b"ab", b"cd", b"ef", None)
        assert reader.read(5) == b"abcde"
        assert reader.read() == b"f"

    def test_read_size_zero_returns_empty_and_preserves_buffer(self):
        # read(0) must return immediately without consuming buffered bytes.
        reader = self._reader_over(b"abc", None)
        # Reading one byte leaves b"bc" in the buffer.
        assert reader.read(1) == b"a"
        assert reader.read(0) == b""
        # The buffered bytes are still intact.
        assert reader.read() == b"bc"

    def test_read_with_size_larger_than_available_returns_all_before_eof(self):
        assert self._reader_over(b"abc", None).read(100) == b"abc"

    def test_eof_only_terminates_when_queue_drained_first(self):
        # Every chunk queued ahead of the sentinel must appear in the first read.
        assert self._reader_over(b"x", b"y", b"z", None).read() == b"xyz"

    def test_read_polls_until_data_arrives(self):
        # The first get() times out with queue.Empty; the reader must retry
        # instead of giving up. A scripted MagicMock keeps this deterministic:
        # no background thread and no sleeps.
        fake_queue = MagicMock()
        fake_queue.get.side_effect = [queue.Empty(), b"late", None]
        assert _QueueReader(fake_queue).read() == b"late"
        # Empty, then the data chunk, then the EOF sentinel: three get() calls.
        assert fake_queue.get.call_count == 3
================================================
FILE: amber/src/test/python/pytexera/udf/examples/test_count_batch_operator.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
import pytest
from collections import deque
from pytexera import *
from pytexera.udf.examples.count_batch_operator import CountBatchOperator
class TestCountBatchOperator:
    """Tests for the CountBatchOperator example UDF and its BATCH_SIZE validation."""

    @pytest.fixture
    def count_batch_operator(self):
        return CountBatchOperator()

    @staticmethod
    def _feed(operator, num_tuples):
        """Send ``num_tuples`` identical tuples to ``operator`` on port 0.

        Each ``process_tuple`` result is drained with ``deque`` so that any
        lazily produced output is consumed.
        """
        for _ in range(num_tuples):
            deque(
                operator.process_tuple(Tuple({"test-1": "hello", "test-2": 10}), 0)
            )

    @staticmethod
    def _construct_with_batch_size(literal):
        """Re-define CountBatchOperator with ``BATCH_SIZE = <literal>`` and instantiate it.

        The class source is rewritten and exec'd in this module's globals, so
        names from ``from pytexera import *`` resolve as in the original file.
        Any exception raised during construction propagates to the caller.
        """
        source = str(inspect.getsource(CountBatchOperator))
        source = source.replace("BATCH_SIZE = 10", "BATCH_SIZE = " + literal)
        source += "operator = CountBatchOperator()"
        exec(source)

    def test_count_batch_operator(self, count_batch_operator):
        count_batch_operator.open()
        self._feed(count_batch_operator, 27)
        deque(count_batch_operator.on_finish(0))
        assert count_batch_operator.count == 3
        count_batch_operator.close()

    def test_count_batch_operator_simple(self, count_batch_operator):
        count_batch_operator.open()
        self._feed(count_batch_operator, 20)
        deque(count_batch_operator.on_finish(0))
        assert count_batch_operator.count == 2
        count_batch_operator.close()

    def test_count_batch_operator_medium(self, count_batch_operator):
        # NOTE(review): identical to test_count_batch_operator; consider
        # varying the input size so this case adds coverage.
        count_batch_operator.open()
        self._feed(count_batch_operator, 27)
        deque(count_batch_operator.on_finish(0))
        assert count_batch_operator.count == 3
        count_batch_operator.close()

    def test_count_batch_operator_hard(self, count_batch_operator):
        count_batch_operator.open()
        # Change the batch size mid-stream: 27 tuples at size 10, then 27 at size 5.
        count_batch_operator.BATCH_SIZE = 10
        self._feed(count_batch_operator, 27)
        count_batch_operator.BATCH_SIZE = 5
        self._feed(count_batch_operator, 27)
        deque(count_batch_operator.on_finish(0))
        assert count_batch_operator.count == 9
        count_batch_operator.close()

    # In the edge-case tests below, the message assertion must come AFTER the
    # pytest.raises block. Inside it, the line would never run, because the
    # construction call raises first.

    def test_edge_case_string(self):
        with pytest.raises(ValueError) as exc_info:
            self._construct_with_batch_size('"test"')
        assert (
            exc_info.value.args[0]
            == "BATCH_SIZE cannot be " + str(type("test")) + "."
        )

    def test_edge_case_non_positive(self, count_batch_operator):
        with pytest.raises(ValueError) as exc_info:
            self._construct_with_batch_size("-20")
        assert exc_info.value.args[0] == "BATCH_SIZE should be positive."

    def test_edge_case_none(self, count_batch_operator):
        with pytest.raises(ValueError) as exc_info:
            self._construct_with_batch_size("None")
        assert exc_info.value.args[0] == "BATCH_SIZE cannot be None."
================================================
FILE: amber/src/test/python/pytexera/udf/examples/test_echo_operator.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pytexera import Tuple
from pytexera.udf.examples.echo_operator import EchoOperator
class TestEchoOperator:
    """Tests for the EchoOperator example UDF."""

    @pytest.fixture
    def echo_operator(self):
        return EchoOperator()

    def test_echo_operator(self, echo_operator):
        echo_operator.open()
        tuple_ = Tuple({"test-1": "hello", "test-2": 10})
        outputs = echo_operator.process_tuple(tuple_, 0)
        # Exactly one output is expected: the input tuple itself.
        output_tuple = next(outputs)
        assert output_tuple == tuple_
        with pytest.raises(StopIteration):
            next(outputs)
        # Pair open() with close(), as the other example-operator tests do.
        echo_operator.close()
================================================
FILE: amber/src/test/python/pytexera/udf/examples/test_echo_table_operator.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from collections import deque
from core.models.table import all_output_to_tuple
from pytexera import Tuple
from pytexera.udf.examples.echo_table_operator import EchoTableOperator
class TestEchoTableOperator:
    """Tests for the EchoTableOperator example UDF."""

    @pytest.fixture
    def echo_table_operator(self):
        return EchoTableOperator()

    def test_echo_table_operator(self, echo_table_operator):
        echo_table_operator.open()
        tuple_ = Tuple({"test-1": "hello", "test-2": 10})
        # Drain process_tuple; the table is only emitted from on_finish.
        deque(echo_table_operator.process_tuple(tuple_, 0))
        outputs = echo_table_operator.on_finish(0)
        output_tuple = next(all_output_to_tuple(next(outputs)))
        assert output_tuple == tuple_
        with pytest.raises(StopIteration):
            next(outputs)
        echo_table_operator.close()
================================================
FILE: amber/src/test/python/pytexera/udf/examples/test_generator_operator_binary.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pytexera import Tuple
from pytexera.udf.examples.generator_operator_binary import GeneratorOperatorBinary
class TestEchoOperator:
    """Tests for the GeneratorOperatorBinary example UDF.

    NOTE(review): the class name looks copy-pasted from the echo-operator test.
    It is kept unchanged so test IDs stay stable; consider renaming it to
    TestGeneratorOperatorBinary.
    """

    @pytest.fixture
    def generator_operator_binary(self):
        return GeneratorOperatorBinary()

    def test_generator_operator_binary(self, generator_operator_binary):
        op = generator_operator_binary
        op.open()
        first_output = next(op.produce())
        assert Tuple(first_output) == Tuple({"test": [1, 2, 3]})
        op.close()
================================================
FILE: amber/src/test/python/pytexera/udf/examples/test_generator_operator_integer.py
================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pytexera import Tuple
from pytexera.udf.examples.generator_operator_integer import GeneratorOperatorInteger
class TestEchoOperator:
    """Tests for the GeneratorOperatorInteger example UDF.

    NOTE(review): the class name looks copy-pasted from the echo-operator test.
    It is kept unchanged so test IDs stay stable; consider renaming it to
    TestGeneratorOperatorInteger.
    """

    @pytest.fixture
    def generator_operator_integer(self):
        return GeneratorOperatorInteger()

    def test_generator_operator_integer(self, generator_operator_integer):
        op = generator_operator_integer
        op.open()
        produced = op.produce()
        # The first three outputs carry the values 1, 2, 3 in order.
        for expected in (1, 2, 3):
            assert Tuple(next(produced)) == Tuple({"test": expected})
        op.close()
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/common/ProcessingStepCursorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.common
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.scalatest.flatspec.AnyFlatSpec
/**
  * Unit tests for [[ProcessingStepCursor]]: its initial state, how `stepIncrement` advances the
  * step counter, and how `setCurrentChannel` tracks the current channel.
  */
class ProcessingStepCursorSpec extends AnyFlatSpec {

  // Two distinct channels (one data, one control) used to observe channel replacement.
  private val channelA =
    ChannelIdentity(ActorVirtualIdentity("a"), ActorVirtualIdentity("b"), isControl = false)
  private val channelB =
    ChannelIdentity(ActorVirtualIdentity("a"), ActorVirtualIdentity("c"), isControl = true)

  "ProcessingStepCursor" should "start at INIT_STEP with no current channel" in {
    val cursor = new ProcessingStepCursor()
    assert(cursor.getStep == ProcessingStepCursor.INIT_STEP)
    assert(cursor.getStep == -1L)
    assert(cursor.getChannel == null)
  }

  "ProcessingStepCursor.stepIncrement" should "advance the step to 0 on the first call" in {
    val cursor = new ProcessingStepCursor()
    cursor.stepIncrement()
    assert(cursor.getStep == 0L)
  }

  it should "advance the step by exactly one each call" in {
    val cursor = new ProcessingStepCursor()
    // Check after every call (not just at the end) so uneven increments cannot slip through:
    // starting from -1, the k-th call (0-based) must leave the step at k.
    (0 until 5).foreach { expected =>
      cursor.stepIncrement()
      assert(cursor.getStep == expected.toLong)
    }
  }

  "ProcessingStepCursor.setCurrentChannel" should "store the latest channel" in {
    val cursor = new ProcessingStepCursor()
    cursor.setCurrentChannel(channelA)
    assert(cursor.getChannel == channelA)
    cursor.setCurrentChannel(channelB)
    assert(cursor.getChannel == channelB)
  }

  it should "leave the step counter unchanged" in {
    val cursor = new ProcessingStepCursor()
    cursor.stepIncrement()
    cursor.setCurrentChannel(channelA)
    assert(cursor.getStep == 0L)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/TrivialControlSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control
import org.apache.pekko.actor.{ActorRef, ActorSystem, PoisonPill, Props}
import org.apache.pekko.testkit.{TestKit, TestProbe}
import io.grpc.MethodDescriptor
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.common.WorkflowActor.{
GetActorRef,
NetworkAck,
NetworkMessage,
RegisterActorRef
}
import org.apache.texera.amber.engine.architecture.control.utils.TrivialControlTester
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.{
IntResponse,
ReturnInvocation,
StringResponse
}
import org.apache.texera.amber.engine.architecture.rpc.testerservice.RPCTesterGrpc._
import org.apache.texera.amber.engine.common.ambermessage.WorkflowFIFOMessage
import org.apache.texera.amber.engine.common.ambermessage.WorkflowMessage.getInMemSize
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.ControlInvocation
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.scalatest.wordspec.AnyWordSpecLike
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import scala.collection.mutable
import scala.concurrent.duration._
/**
  * End-to-end tests of the async RPC layer using [[TrivialControlTester]] actors.
  *
  * Each test spawns tester actors with ids "0" .. "numActors-1", sends control invocations from
  * CONTROLLER to actor "0", and checks the return values that reach the test probe, which is
  * registered under the CONTROLLER id.
  */
class TrivialControlSpec
    extends TestKit(ActorSystem("TrivialControlSpec"))
    with AnyWordSpecLike
    with BeforeAndAfterEach
    with BeforeAndAfterAll {

  override def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }

  /**
    * Runs the given commands against a fresh set of tester actors and checks their returns.
    *
    * The i-th command is sent with command id i, and the return carrying id i is compared with
    * the i-th expected value. The spawned actors are always stopped afterwards, even when an
    * assertion fails or the wait times out, so no testers leak into later tests.
    *
    * @param numActors  number of tester actors to spawn
    * @param eventPairs ((method, request), expectedReturn) pairs, in sending order
    * @throws AssertionError if the probe receives nothing for 10 seconds before all expected
    *                        returns have arrived
    */
  def testControl[T](
      numActors: Int,
      eventPairs: ((MethodDescriptor[_, _], ControlRequest), T)*
  ): Unit = {
    val (events, expectedValues) = eventPairs.unzip
    val (probe, idMap) = setUp(numActors, events: _*)
    try {
      var received = 0
      while (received < expectedValues.length) {
        probe.receiveOne(10.seconds) match {
          case null =>
            throw new AssertionError(
              s"timeout: received $received of ${expectedValues.length} expected returns"
            )
          case GetActorRef(id, replyTo) =>
            // Act as the actor-ref registry for the testers.
            replyTo.foreach { actor =>
              actor ! RegisterActorRef(id, idMap(id))
            }
          case NetworkMessage(
                msgID,
                workflowMsg @ WorkflowFIFOMessage(_, _, ReturnInvocation(id, returnValue))
              ) =>
            probe.sender() ! NetworkAck(
              msgID,
              getInMemSize(workflowMsg),
              0L // no queued credit
            )
            assert(returnValue.asInstanceOf[T] == expectedValues(id.toInt))
            received += 1
          case _ =>
          // other traffic addressed to the controller is irrelevant here
        }
      }
    } finally {
      // Stop every registered actor (including the probe standing in for CONTROLLER) on both
      // success and failure; the ErrorCall test relies on the assertion above failing.
      idMap.values.foreach(_ ! PoisonPill)
    }
  }

  /**
    * Spawns `numActors` tester actors under a fresh probe and sends `cmd` to actor "0".
    *
    * Command i is sent as network message i on the CONTROLLER -> "0" control channel, with
    * command id i.
    *
    * @return the probe (registered as CONTROLLER) and the id -> actor-ref map
    */
  def setUp(
      numActors: Int,
      cmd: (MethodDescriptor[_, _], ControlRequest)*
  ): (TestProbe, mutable.HashMap[ActorVirtualIdentity, ActorRef]) = {
    val probe = TestProbe()
    val idMap = mutable.HashMap[ActorVirtualIdentity, ActorRef]()
    for (i <- 0 until numActors) {
      val id = ActorVirtualIdentity(s"$i")
      idMap(id) = probe.childActorOf(Props(new TrivialControlTester(id)))
    }
    // The probe plays the controller, so returns addressed to CONTROLLER reach it.
    idMap(CONTROLLER) = probe.ref
    val target = ActorVirtualIdentity("0")
    val channel = ChannelIdentity(CONTROLLER, target, isControl = true)
    cmd.zipWithIndex.foreach {
      case ((methodName, msg), seqNum) =>
        probe.send(
          idMap(target),
          NetworkMessage(
            seqNum,
            WorkflowFIFOMessage(
              channel,
              seqNum,
              ControlInvocation(
                methodName,
                msg,
                AsyncRPCContext(CONTROLLER, target),
                seqNum
              )
            )
          )
        )
    }
    (probe, idMap)
  }

  "testers" should {
    "execute Ping Pong" in {
      testControl(2, ((METHOD_SEND_PING, Ping(1, 5, ActorVirtualIdentity("1"))), IntResponse(5)))
    }
    "execute Ping Pong 2 times" in {
      testControl(
        2,
        ((METHOD_SEND_PING, Ping(1, 4, ActorVirtualIdentity("1"))), IntResponse(4)),
        ((METHOD_SEND_PING, Ping(10, 13, ActorVirtualIdentity("1"))), IntResponse(13))
      )
    }
    "execute Chain" in {
      testControl(
        10,
        (
          (METHOD_SEND_CHAIN, Chain((1 to 9).map(i => ActorVirtualIdentity(i.toString)))),
          StringResponse("9")
        )
      )
    }
    "execute Collect" in {
      testControl(
        4,
        (
          (METHOD_SEND_COLLECT, Collect((1 to 3).map(i => ActorVirtualIdentity(i.toString)))),
          StringResponse("finished")
        )
      )
    }
    "execute RecursiveCall" in {
      testControl(1, ((METHOD_SEND_RECURSION, Recursion(0)), StringResponse("0")))
    }
    "execute MultiCall" in {
      testControl(
        10,
        (
          (METHOD_SEND_MULTI_CALL, MultiCall((1 to 9).map(i => ActorVirtualIdentity(i.toString)))),
          StringResponse("finished")
        )
      )
    }
    "execute NestedCall" in {
      testControl(1, ((METHOD_SEND_NESTED, Nested(5)), StringResponse("Hello World!")))
    }
    "execute ErrorCall" in {
      assertThrows[RuntimeException] {
        testControl(1, ((METHOD_SEND_ERROR_COMMAND, ErrorCommand()), ()))
      }
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/ChainHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
trait ChainHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * Passes the chain along: when actors remain in `request.nexts`, forwards the rest of the chain
    * to the first of them and relays whatever comes back; when none remain, answers with this
    * actor's own name.
    */
  override def sendChain(request: Chain, ctx: AsyncRPCContext): Future[StringResponse] = {
    println(s"chained $myID")
    request.nexts.headOption match {
      case None =>
        Future(StringResponse(myID.name))
      case Some(nextHop) =>
        val remaining = Chain(request.nexts.drop(1))
        getProxy.sendChain(remaining, mkContext(nextHop)).map { answer =>
          println(s"chain returns from $answer")
          answer
        }
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/CollectHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
import scala.util.Random
trait CollectHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * Asks every actor in `request.workers` for a number concurrently and, once all replies are
    * in, logs them and answers "finished".
    */
  override def sendCollect(request: Collect, ctx: AsyncRPCContext): Future[StringResponse] = {
    println("start collecting numbers.")
    // One GenerateNumber call per worker; map over the workers directly instead of re-indexing.
    val numbers = Future.collect(
      request.workers.map(worker =>
        getProxy.sendGenerateNumber(GenerateNumber(), mkContext(worker))
      )
    )
    numbers.map { res =>
      println(s"collected: ${res.mkString(" ")}")
      StringResponse("finished")
    }
  }

  /** Answers with a random integer. */
  override def sendGenerateNumber(
      request: GenerateNumber,
      ctx: AsyncRPCContext
  ): Future[IntResponse] = {
    IntResponse(Random.nextInt())
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/ErrorHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
trait ErrorHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * Never returns normally: it throws a [[RuntimeException]] synchronously, rather than
    * returning a failed Future, so tests can exercise handler failures.
    */
  override def sendErrorCommand(
      request: ErrorCommand,
      ctx: AsyncRPCContext
  ): Future[StringResponse] = {
    val deliberateFailure = new RuntimeException("this is an EXPECTED exception for testing")
    throw deliberateFailure
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/MultiCallHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
trait MultiCallHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * Runs three RPCs in sequence and answers with the result of the last one:
    *  1. a Chain over `request.seq`, issued with `myID` as the context;
    *  2. a Recursion starting at 1 on the actor whose name the chain returned;
    *  3. a Collect over the first three actors of `request.seq`, issued with `myID` as the context.
    */
  override def sendMultiCall(request: MultiCall, ctx: AsyncRPCContext): Future[StringResponse] = {
    val workers = request.seq
    val chainDone = getProxy.sendChain(Chain(workers), myID)
    val recursionDone = chainDone.flatMap { lastInChain =>
      getProxy.sendRecursion(Recursion(1), mkContext(ActorVirtualIdentity(lastInChain.value)))
    }
    recursionDone.flatMap { _ =>
      getProxy.sendCollect(Collect(workers.take(3)), myID)
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/NestedHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
trait NestedHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * Builds "Hello World!" through three nested Pass calls (each issued with `myID` as the
    * context). Each call appends a fragment to the previous reply.
    */
  override def sendNested(request: Nested, ctx: AsyncRPCContext): Future[StringResponse] = {
    val hello = getProxy.sendPass(Pass("Hello"), myID)
    val helloSpace = hello.flatMap { first =>
      getProxy.sendPass(Pass(first.value + " "), myID)
    }
    helloSpace.flatMap { second =>
      getProxy.sendPass(Pass(second.value + "World!"), myID)
    }
  }

  /** Echoes the request's value back as a [[StringResponse]]. */
  override def sendPass(request: Pass, ctx: AsyncRPCContext): Future[StringResponse] = {
    StringResponse(request.value)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/PingPongHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
trait PingPongHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * One "ping" hop. Once `ping.i` has reached `ping.end`, answers with the counter. Otherwise it
    * sends a [[Pong]] with the counter incremented to `ping.to` (naming this actor as the pong's
    * target) and relays the reply.
    */
  override def sendPing(ping: Ping, ctx: AsyncRPCContext): Future[IntResponse] = {
    println(s"${ping.i} ping")
    if (ping.i >= ping.end) {
      Future(ping.i)
    } else {
      val nextPong = Pong(ping.i + 1, ping.end, myID)
      getProxy.sendPong(nextPong, ping.to).map { reply: IntResponse =>
        println(s"${ping.i} ping replied with value ${reply.value}!")
        reply
      }
    }
  }

  /**
    * One "pong" hop; the mirror image of [[sendPing]]. It answers with the counter at the end,
    * and otherwise sends an incremented [[Ping]] to `pong.to`.
    */
  override def sendPong(pong: Pong, ctx: AsyncRPCContext): Future[IntResponse] = {
    println(s"${pong.i} pong")
    if (pong.i >= pong.end) {
      Future(pong.i)
    } else {
      val nextPing = Ping(pong.i + 1, pong.end, myID)
      getProxy.sendPing(nextPing, pong.to).map { reply: IntResponse =>
        println(s"${pong.i} pong replied with value ${reply.value}!")
        reply
      }
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/RecursionHandler.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
trait RecursionHandler {
  this: TesterAsyncRPCHandlerInitializer =>

  /**
    * Re-invokes itself through the RPC proxy (with `myID` as the context) using `r.i + 1`, until
    * `r.i` reaches 5. Every level answers with its own `r.i` as a string, so the outermost call
    * answers with the starting value.
    */
  override def sendRecursion(r: Recursion, ctx: AsyncRPCContext): Future[StringResponse] = {
    if (r.i >= 5) {
      Future(r.i.toString)
    } else {
      println(r.i)
      val deeper = Recursion(r.i + 1)
      getProxy.sendRecursion(deeper, myID).map { inner =>
        println(inner)
        r.i.toString
      }
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/TesterAsyncRPCHandlerInitializer.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.engine.architecture.control.utils.TrivialControlTester.ControlTesterRPCClient
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.AsyncRPCContext
import org.apache.texera.amber.engine.architecture.rpc.testerservice.RPCTesterFs2Grpc
import org.apache.texera.amber.engine.common.rpc.{AsyncRPCHandlerInitializer, AsyncRPCServer}
/**
  * Handler initializer for the RPC tester service.
  *
  * It mixes in one trait per tester RPC (ping/pong, chain, multi-call, collect, nested, recursion,
  * error), each implementing its methods of [[RPCTesterFs2Grpc]], and passes `source`/`receiver`
  * to [[AsyncRPCHandlerInitializer]].
  *
  * @param myID     identity of the actor hosting these handlers; the handler traits use it as the
  *                 context for calls back to this actor
  * @param source   RPC client whose typed proxy the handlers use to issue outgoing tester RPCs
  * @param receiver RPC server handed to [[AsyncRPCHandlerInitializer]]
  */
class TesterAsyncRPCHandlerInitializer(
    val myID: ActorVirtualIdentity,
    source: ControlTesterRPCClient,
    receiver: AsyncRPCServer
) extends AsyncRPCHandlerInitializer(source, receiver)
    with RPCTesterFs2Grpc[Future, AsyncRPCContext]
    with PingPongHandler
    with ChainHandler
    with MultiCallHandler
    with CollectHandler
    with NestedHandler
    with RecursionHandler
    with ErrorHandler {

  /** Typed proxy used by the handler traits to send tester RPCs to other actors. */
  def getProxy: RPCTesterFs2Grpc[Future, AsyncRPCContext] = source.getProxy
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/control/utils/TrivialControlTester.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.control.utils
import com.twitter.util.Future
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.common.WorkflowActor.NetworkAck
import org.apache.texera.amber.engine.architecture.common.{AmberProcessor, WorkflowActor}
import org.apache.texera.amber.engine.architecture.control.utils.TrivialControlTester.ControlTesterRPCClient
import org.apache.texera.amber.engine.architecture.messaginglayer.{
NetworkInputGateway,
NetworkOutputGateway
}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.AsyncRPCContext
import org.apache.texera.amber.engine.architecture.rpc.testerservice.RPCTesterFs2Grpc
import org.apache.texera.amber.engine.common.CheckpointState
import org.apache.texera.amber.engine.common.ambermessage.WorkflowMessage.getInMemSize
import org.apache.texera.amber.engine.common.ambermessage.{
DataPayload,
DirectControlMessagePayload,
WorkflowFIFOMessage
}
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient
object TrivialControlTester {

  /**
    * An [[AsyncRPCClient]] that additionally exposes a typed proxy for the tester RPC service,
    * so handlers can write calls such as `getProxy.sendPing(...)`.
    *
    * @param inputGateway  forwarded to [[AsyncRPCClient]]
    * @param outputGateway forwarded to [[AsyncRPCClient]] and used to build the proxy
    * @param actorId       identity of the owning actor, forwarded to [[AsyncRPCClient]]
    */
  class ControlTesterRPCClient(
      inputGateway: NetworkInputGateway,
      outputGateway: NetworkOutputGateway,
      actorId: ActorVirtualIdentity
  ) extends AsyncRPCClient(inputGateway, outputGateway, actorId) {
    // Built from this client's `createPromise` and the output gateway. Presumably each proxy call
    // registers a promise and sends the invocation through `outputGateway`; confirm against
    // AsyncRPCClient.createProxy.
    val getProxy: RPCTesterFs2Grpc[Future, AsyncRPCContext] =
      AsyncRPCClient
        .createProxy[RPCTesterFs2Grpc[Future, AsyncRPCContext]](createPromise, outputGateway)
  }
}
/**
  * A minimal [[WorkflowActor]] used by the control-plane tests.
  *
  * It hosts an [[AmberProcessor]] whose RPC client is a [[ControlTesterRPCClient]], and it
  * constructs a [[TesterAsyncRPCHandlerInitializer]] for the tester RPC handlers. It processes
  * only direct control messages; any other traffic fails fast with a descriptive
  * [[NotImplementedError]].
  *
  * @param id the virtual identity of this tester actor
  */
class TrivialControlTester(
    id: ActorVirtualIdentity
) extends WorkflowActor(replayLogConfOpt = None, actorId = id) {

  val ap = new AmberProcessor(
    id,
    {
      case Left(value) =>
        // This tester only forwards Right(...) outputs over the network.
        throw new NotImplementedError(
          s"TrivialControlTester does not support outbound message: $value"
        )
      case Right(value) => transferService.send(value)
    }
  ) {
    // Swap in a client that exposes the typed tester proxy.
    override val asyncRPCClient = new ControlTesterRPCClient(inputGateway, outputGateway, id)
  }

  // NOTE(review): constructing the initializer presumably registers the tester handlers with
  // ap.asyncRPCServer; confirm in AsyncRPCHandlerInitializer.
  val initializer =
    new TesterAsyncRPCHandlerInitializer(ap.actorId, ap.asyncRPCClient, ap.asyncRPCServer)

  /**
    * Hands `workflowMsg` to its input channel, processes every control message the channel can
    * now release, then acknowledges the network message to the sender.
    */
  override def handleInputMessage(id: Long, workflowMsg: WorkflowFIFOMessage): Unit = {
    val channel = ap.inputGateway.getChannel(workflowMsg.channelId)
    channel.acceptMessage(workflowMsg)
    while (channel.isEnabled && channel.hasMessage) {
      val msg = channel.take
      msg.payload match {
        case payload: DirectControlMessagePayload => ap.processDCM(msg.channelId, payload)
        case _: DataPayload =>
          throw new NotImplementedError(
            s"TrivialControlTester does not process data payloads (channel ${msg.channelId})"
          )
        case other =>
          throw new NotImplementedError(
            s"TrivialControlTester cannot process payload of type ${other.getClass.getName}"
          )
      }
    }
    sender() ! NetworkAck(id, getInMemSize(workflowMsg), getQueuedCredit(workflowMsg.channelId))
  }

  /** flow-control: this tester never reports queued credit. */
  override def getQueuedCredit(channelId: ChannelIdentity): Long = 0L

  override def preStart(): Unit = {
    transferService.initialize()
  }

  // Backpressure, state initialization and checkpoint restore are no-ops for this tester.
  override def handleBackpressure(isBackpressured: Boolean): Unit = {}
  override def initState(): Unit = {}
  override def loadFromCheckpoint(chkpt: CheckpointState): Unit = {}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/ControllerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.pekko.util.Timeout
import org.apache.texera.amber.clustering.SingleNodeListener
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpecLike
import scala.concurrent.ExecutionContextExecutor
import scala.concurrent.duration._
class ControllerSpec
extends TestKit(ActorSystem("ControllerSpec"))
with ImplicitSender
with AnyFlatSpecLike
with BeforeAndAfterAll {
implicit val timeout: Timeout = Timeout(5.seconds)
implicit val executionContext: ExecutionContextExecutor = system.dispatcher
override def beforeAll(): Unit = {
system.actorOf(Props[SingleNodeListener](), "cluster-info")
}
override def afterAll(): Unit = {
TestKit.shutdownActorSystem(system)
}
// private val logicalPlan1 =
// """{
// |"operators":[
// |{"tableName":"D:\\large_input.csv","operatorId":"Scan","operatorType":"LocalScanSource","delimiter":","},
// |{"attributeName":0,"keyword":"Asia","operatorId":"KeywordSearch","operatorType":"KeywordMatcher"},
// |{"operatorId":"Count","operatorType":"Aggregation"},
// |{"operatorId":"Sink","operatorType":"Sink"}],
// |"links":[
// |{"origin":"Scan","destination":"KeywordSearch"},
// |{"origin":"KeywordSearch","destination":"Count"},
// |{"origin":"Count","destination":"Sink"}]
// |}""".stripMargin
//
// private val logicalPlan2 =
// """{
// |"operators":[
// |{"tableName":"D:\\large_input.csv","operatorId":"Scan","operatorType":"LocalScanSource","delimiter":","},
// |{"operatorId":"Count","operatorType":"Aggregation"},
// |{"operatorId":"Sink","operatorType":"Sink"}],
// |"links":[
// |{"origin":"Scan","destination":"Count"},
// |{"origin":"Count","destination":"Sink"}]
// |}""".stripMargin
//
// private val logicalPlan3 =
// """{
// |"operators":[
// |{"tableName":"D:\\test.txt","operatorId":"Scan","operatorType":"LocalScanSource","delimiter":"|"},
// |{"attributeName":15,"keyword":"package","operatorId":"KeywordSearch","operatorType":"KeywordMatcher"},
// |{"operatorId":"Count","operatorType":"Aggregation"},
// |{"operatorId":"Sink","operatorType":"Sink"}],
// |"links":[
// |{"origin":"Scan","destination":"KeywordSearch"},
// |{"origin":"KeywordSearch","destination":"Count"},
// |{"origin":"Count","destination":"Sink"}]
// |}""".stripMargin
//
// private val logicalPlan4 =
// """{
// |"operators":[
// |{"tableName":"D:\\test.txt","operatorId":"Scan1","operatorType":"LocalScanSource","delimiter":"|","indicesToKeep":null},
// |{"tableName":"D:\\test.txt","operatorId":"Scan2","operatorType":"LocalScanSource","delimiter":"|","indicesToKeep":null},
// |{"attributeName":15,"keyword":"package","operatorId":"KeywordSearch","operatorType":"KeywordMatcher"},
// |{"operatorId":"Join","operatorType":"HashJoin","innerTableIndex":0,"outerTableIndex":0},
// |{"operatorId":"Count","operatorType":"Aggregation"},
// |{"operatorId":"Sink","operatorType":"Sink"}],
// |"links":[
// |{"origin":"Scan1","destination":"KeywordSearch"},
// |{"origin":"KeywordSearch","destination":"Join"},
// |{"origin":"Scan2","destination":"Join"},
// |{"origin":"Join","destination":"Count"},
// |{"origin":"Count","destination":"Sink"}]
// |}""".stripMargin
//
// "A controller" should "be able to set and trigger count breakpoint in the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! PassBreakpointTo("KeywordSearch", new CountGlobalBreakpoint("break1", 100000))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// var isCompleted = false
// parent.receiveWhile(30.seconds, 10.seconds) {
// case ReportState(ControllerState.Paused) =>
// controller ! Resume
// case ReportState(ControllerState.Completed) =>
// isCompleted = true
// case _ =>
// }
// assert(isCompleted)
// parent.ref ! PoisonPill
// }
//
// "A controller" should "execute the workflow1 normally" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// parent.expectMsg(1.minute, ReportState(ControllerState.Completed))
// parent.ref ! PoisonPill
// }
//
// "A controller" should "execute the workflow3 normally" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan3))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// parent.expectMsg(1.minute, ReportState(ControllerState.Completed))
// parent.ref ! PoisonPill
// }
//
// "A controller" should "execute the workflow2 normally" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan2))
// controller ! AckedControllerInitialization
// parent.expectMsg(ReportState(ControllerState.Ready))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// parent.expectMsg(1.minute, ReportState(ControllerState.Completed))
// parent.ref ! PoisonPill
// }
//
// "A controller" should "be able to pause/resume the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(ReportState(ControllerState.Ready))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// controller ! Pause
// parent.expectMsg(ReportState(ControllerState.Pausing))
// parent.expectMsg(ReportState(ControllerState.Paused))
// controller ! Resume
// parent.expectMsg(ReportState(ControllerState.Resuming))
// parent.expectMsg(ReportState(ControllerState.Running))
// controller ! Pause
// parent.expectMsg(ReportState(ControllerState.Pausing))
// parent.expectMsg(ReportState(ControllerState.Paused))
// controller ! Resume
// parent.expectMsg(ReportState(ControllerState.Resuming))
// parent.expectMsg(ReportState(ControllerState.Running))
// controller ! Pause
// parent.expectMsg(ReportState(ControllerState.Pausing))
// parent.expectMsg(ReportState(ControllerState.Paused))
// controller ! Resume
// parent.expectMsg(ReportState(ControllerState.Resuming))
// parent.expectMsg(ReportState(ControllerState.Running))
// controller ! Pause
// parent.expectMsg(ReportState(ControllerState.Pausing))
// parent.expectMsg(ReportState(ControllerState.Paused))
// controller ! Resume
// parent.expectMsg(ReportState(ControllerState.Resuming))
// parent.expectMsg(ReportState(ControllerState.Running))
// parent.expectMsg(1.minute, ReportState(ControllerState.Completed))
// parent.ref ! PoisonPill
// }
// "A controller" should "be able to modify the logic after pausing the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// Thread.sleep(300)
// controller ! Pause
// parent.expectMsg(ReportState(ControllerState.Pausing))
// parent.expectMsg(ReportState(ControllerState.Paused))
// controller ! ModifyLogic(
// new KeywordSearchMetadata(
// OperatorTag("sample", "KeywordSearch"),
// Constants.currentWorkerNum,
// 0,
// "asia"
// )
// )
// parent.expectMsg(Ack)
// Thread.sleep(10000)
// controller ! Resume
// parent.expectMsg(ReportState(ControllerState.Resuming))
// parent.expectMsg(ReportState(ControllerState.Running))
// parent.expectMsg(1.minute, ReportState(ControllerState.Completed))
// parent.ref ! PoisonPill
// }
// "A controller" should "be able to set and trigger conditional breakpoint in the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! PassBreakpointTo(
// "KeywordSearch",
// new ConditionalGlobalBreakpoint("break2", x => x.getString(8).toInt == 9884)
// )
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// var isCompleted = false
// parent.receiveWhile(30.seconds, 10.seconds) {
// case ReportState(ControllerState.Paused) =>
// controller ! Resume
// case ReportState(ControllerState.Completed) =>
// isCompleted = true
// case _ =>
// }
// assert(isCompleted)
// parent.ref ! PoisonPill
// }
//
// "A controller" should "be able to set and trigger count breakpoint on complete in the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! PassBreakpointTo("KeywordSearch", new CountGlobalBreakpoint("break1", 146017))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// var isCompleted = false
// parent.receiveWhile(30.seconds, 10.seconds) {
// case ReportState(ControllerState.Paused) =>
// controller ! Resume
// case ReportState(ControllerState.Completed) =>
// isCompleted = true
// case _ =>
// }
// assert(isCompleted)
// parent.ref ! PoisonPill
// }
//
// "A controller" should "be able to pause/resume with conditional breakpoint in the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! PassBreakpointTo(
// "KeywordSearch",
// new ConditionalGlobalBreakpoint("break2", x => x.getString(8).toInt == 9884)
// )
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// val random = new Random()
// for (i <- 0 until 100) {
// if (random.nextBoolean()) {
// controller ! Pause
// } else {
// controller ! Resume
// }
// }
// controller ! Resume
// var isCompleted = false
// parent.receiveWhile(30.seconds, 10.seconds) {
// case ReportState(ControllerState.Paused) =>
// controller ! Resume
// case ReportState(ControllerState.Completed) =>
// isCompleted = true
// case _ =>
// }
// assert(isCompleted)
// parent.ref ! PoisonPill
// }
//
// "A controller" should "be able to pause/resume with count breakpoint in the workflow1" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan1))
// controller ! AckedControllerInitialization
// parent.expectMsg(30.seconds, ReportState(ControllerState.Ready))
// controller ! PassBreakpointTo("KeywordSearch", new CountGlobalBreakpoint("break1", 100000))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// val random = new Random()
// for (i <- 0 until 100) {
// if (random.nextBoolean()) {
// controller ! Pause
// } else {
// controller ! Resume
// }
// }
// controller ! Resume
// var isCompleted = false
// parent.receiveWhile(30.seconds, 10.seconds) {
// case ReportState(ControllerState.Paused) =>
// controller ! Resume
// case ReportState(ControllerState.Completed) =>
// isCompleted = true
// case _ =>
// }
// assert(isCompleted)
// parent.ref ! PoisonPill
// }
//
// "A controller" should "execute the workflow4 normally" in {
// val parent = TestProbe()
// val controller = parent.childActorOf(CONTROLLER.props(logicalPlan4))
// controller ! AckedControllerInitialization
// parent.expectMsg(ReportState(ControllerState.Ready))
// controller ! Start
// parent.expectMsg(ReportState(ControllerState.Running))
// parent.expectMsg(1.minute, ReportState(ControllerState.Completed))
// parent.ref ! PoisonPill
// }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/GlobalReplayManagerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.scalatest.flatspec.AnyFlatSpec
/**
 * Unit tests for [[GlobalReplayManager]]. Each test counts how often the
 * recovery-start and recovery-complete callbacks fire as workers are marked
 * as recovering / not recovering.
 */
class GlobalReplayManagerSpec extends AnyFlatSpec {

  /** Counts invocations of the two callbacks handed to the manager. */
  private class CallbackCounter {
    var startCount = 0
    var completeCount = 0
    val onStart: () => Unit = () => startCount += 1
    val onComplete: () => Unit = () => completeCount += 1
  }

  private val workerA = ActorVirtualIdentity("a")
  private val workerB = ActorVirtualIdentity("b")

  /** A fresh counter plus a manager wired to it, so no state leaks between tests. */
  private def newFixture(): (CallbackCounter, GlobalReplayManager) = {
    val counter = new CallbackCounter
    val manager = new GlobalReplayManager(counter.onStart, counter.onComplete)
    (counter, manager)
  }

  "GlobalReplayManager" should "fire onRecoveryStart on the first transition into recovery" in {
    val (counter, manager) = newFixture()
    manager.markRecoveryStatus(workerA, isRecovering = true)
    assert(counter.startCount == 1)
    assert(counter.completeCount == 0)
  }

  it should "not refire onRecoveryStart while still recovering" in {
    val (counter, manager) = newFixture()
    Seq(workerA, workerB).foreach(manager.markRecoveryStatus(_, isRecovering = true))
    assert(counter.startCount == 1, "onStart must fire only on the first transition into recovery")
  }

  it should "fire onRecoveryComplete only once all recovering workers have cleared" in {
    val (counter, manager) = newFixture()
    Seq(workerA, workerB).foreach(manager.markRecoveryStatus(_, isRecovering = true))
    manager.markRecoveryStatus(workerA, isRecovering = false)
    assert(counter.completeCount == 0, "still has recovering workers")
    manager.markRecoveryStatus(workerB, isRecovering = false)
    assert(counter.completeCount == 1)
  }

  it should "not fire onRecoveryComplete when no recovery was ever started" in {
    val (counter, manager) = newFixture()
    manager.markRecoveryStatus(workerA, isRecovering = false)
    assert(counter.startCount == 0)
    assert(counter.completeCount == 0)
  }

  it should "be idempotent for repeated isRecovering=true on the same worker" in {
    val (counter, manager) = newFixture()
    // Enter recovery twice for the same worker, then leave it once.
    Seq(true, true, false).foreach { flag =>
      manager.markRecoveryStatus(workerA, isRecovering = flag)
    }
    assert(counter.startCount == 1)
    assert(counter.completeCount == 1)
  }

  it should "fire onRecoveryStart again when recovery restarts after completing" in {
    val (counter, manager) = newFixture()
    // Each worker in turn goes through one full enter/leave recovery cycle. After
    // cycle n (1-based), both callbacks must have fired exactly n times: a new
    // transition into recovery re-fires onStart, and the following clear re-fires
    // onComplete.
    Seq(workerA, workerB).zipWithIndex.foreach {
      case (worker, idx) =>
        manager.markRecoveryStatus(worker, isRecovering = true)
        manager.markRecoveryStatus(worker, isRecovering = false)
        assert(counter.startCount == idx + 1)
        assert(counter.completeCount == idx + 1)
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/WorkflowSchedulerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.workflow.LogicalLink
import org.scalatest.flatspec.AnyFlatSpec
/**
 * Tests for [[WorkflowScheduler]] on a two-operator workflow (a headerless CSV scan
 * feeding a keyword search). Covers schedule construction and draining the schedule
 * level by level via `getNextRegions`.
 */
class WorkflowSchedulerSpec extends AnyFlatSpec {

  /** Builds a CSV scan -> keyword search ("column-1" matches "Asia") workflow. */
  private def buildHeaderlessCsvKeywordWorkflow() = {
    val csvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("column-1", "Asia")
    buildWorkflow(
      List(csvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(0),
          keywordOpDesc.operatorIdentifier,
          PortIdentity(0)
        )
      ),
      new WorkflowContext()
    )
  }

  /**
   * Calls `scheduler.getNextRegions` until it returns an empty set. Returns the
   * non-empty sets pulled before that point. The empty set that ends the drain is
   * consumed, as it is with a plain `takeWhile`.
   *
   * Every collected level is non-empty. So if each region appears in at most one
   * level, a schedule with `k` regions is exhausted within `k + 1` calls. The drain
   * is capped at that bound. A scheduler that never returns an empty set therefore
   * fails the assertion below, instead of looping forever and hanging the test run.
   *
   * @param scheduler a scheduler on which `updateSchedule` has already been called
   * @return the non-empty region sets, in the order they were returned
   */
  private def drainLevels(scheduler: WorkflowScheduler) = {
    // Read the region count before pulling anything, so pulling cannot affect it.
    val maxPulls = scheduler.getSchedule.getRegions.size + 1
    val levels = Iterator
      .continually(scheduler.getNextRegions)
      .take(maxPulls)
      .takeWhile(_.nonEmpty)
      .toList
    // If `levels` reached `maxPulls`, no empty set was seen within the bound.
    assert(
      levels.size < maxPulls,
      s"getNextRegions did not return an empty set within $maxPulls calls"
    )
    levels
  }

  "WorkflowScheduler.updateSchedule" should "populate the schedule and physicalPlan fields" in {
    val workflow = buildHeaderlessCsvKeywordWorkflow()
    val scheduler = new WorkflowScheduler(workflow.context, CONTROLLER)
    assert(scheduler.getSchedule == null)
    assert(scheduler.physicalPlan == null)
    scheduler.updateSchedule(workflow.physicalPlan)
    assert(scheduler.getSchedule != null)
    assert(scheduler.physicalPlan != null)
    assert(scheduler.getSchedule.getRegions.nonEmpty)
  }

  it should "include every workflow operator in some region of the produced schedule" in {
    val workflow = buildHeaderlessCsvKeywordWorkflow()
    val scheduler = new WorkflowScheduler(workflow.context, CONTROLLER)
    scheduler.updateSchedule(workflow.physicalPlan)
    val operatorsInSchedule = scheduler.getSchedule.getRegions
      .flatMap(_.getOperators.map(_.id.logicalOpId))
      .toSet
    val operatorsInPlan = scheduler.physicalPlan.operators.map(_.id.logicalOpId)
    assert(operatorsInPlan.subsetOf(operatorsInSchedule))
  }

  "WorkflowScheduler.getNextRegions" should "exhaust the schedule and then return an empty set" in {
    val workflow = buildHeaderlessCsvKeywordWorkflow()
    val scheduler = new WorkflowScheduler(workflow.context, CONTROLLER)
    scheduler.updateSchedule(workflow.physicalPlan)
    val pulledLevels = drainLevels(scheduler)
    assert(pulledLevels.nonEmpty)
    // Once exhausted, further calls keep returning an empty set.
    assert(scheduler.getNextRegions.isEmpty)
  }

  it should "yield region sets that together cover every region in the schedule" in {
    val workflow = buildHeaderlessCsvKeywordWorkflow()
    val scheduler = new WorkflowScheduler(workflow.context, CONTROLLER)
    scheduler.updateSchedule(workflow.physicalPlan)
    val expectedRegions = scheduler.getSchedule.getRegions.toSet
    val pulledRegions = drainLevels(scheduler).flatten.toSet
    assert(pulledRegions == expectedRegions)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/execution/ExecutionUtilsSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller.execution
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.apache.texera.amber.engine.architecture.worker.statistics.{
PortTupleMetricsMapping,
TupleMetrics
}
import org.apache.texera.amber.engine.common.executionruntimestate.{
OperatorMetrics,
OperatorStatistics
}
import org.scalatest.flatspec.AnyFlatSpec
/**
 * Unit tests for the aggregation helpers in [[ExecutionUtils]]:
 * `aggregateStates`, `aggregatePortMetrics` and `aggregateMetrics`.
 */
class ExecutionUtilsSpec extends AnyFlatSpec {

  // Sentinel labels used as the generic T for ExecutionUtils.aggregateStates.
  private val Completed = "completed"
  private val Terminated = "terminated"
  private val Running = "running"
  private val Uninitialized = "uninitialized"
  private val Paused = "paused"
  private val Ready = "ready"

  /** Calls `aggregateStates` on `states`, passing the string sentinels above in the order it expects. */
  private def aggregate(states: String*): WorkflowAggregatedState =
    ExecutionUtils.aggregateStates(
      states,
      Completed,
      Terminated,
      Running,
      Uninitialized,
      Paused,
      Ready
    )

  "ExecutionUtils.aggregateStates" should "return UNINITIALIZED for an empty input" in {
    assert(aggregate() == WorkflowAggregatedState.UNINITIALIZED)
  }

  it should "return COMPLETED when every state is the completed sentinel" in {
    assert(aggregate(Completed, Completed) == WorkflowAggregatedState.COMPLETED)
  }

  it should "return COMPLETED when every state is the terminated sentinel" in {
    assert(aggregate(Terminated, Terminated) == WorkflowAggregatedState.COMPLETED)
  }

  it should "return RUNNING when any state is the running sentinel" in {
    assert(aggregate(Completed, Running, Paused) == WorkflowAggregatedState.RUNNING)
  }

  it should "return UNINITIALIZED when remaining non-completed states are all uninitialized" in {
    assert(
      aggregate(Completed, Uninitialized, Uninitialized) ==
        WorkflowAggregatedState.UNINITIALIZED
    )
  }

  it should "return PAUSED when remaining non-completed states are all paused" in {
    assert(aggregate(Completed, Paused, Paused) == WorkflowAggregatedState.PAUSED)
  }

  it should "return RUNNING when remaining non-completed states are all ready" in {
    // Note: an all-ready aggregate maps to RUNNING by current contract.
    assert(aggregate(Completed, Ready, Ready) == WorkflowAggregatedState.RUNNING)
  }

  it should "return UNKNOWN when remaining non-completed states are mixed" in {
    assert(aggregate(Completed, Paused, Ready) == WorkflowAggregatedState.UNKNOWN)
  }

  // Anti / boundary cases: make sure unexpected inputs cannot smuggle in a wrong
  // state, and that branch precedence is what the contract claims.

  it should "return UNKNOWN when completed and terminated are mixed (neither forall branch matches)" in {
    // Both `forall(_ == completed)` and `forall(_ == terminated)` fail, and no running
    // sentinel is present. The non-completed remainder is purely terminated, which is
    // none of uninitialized / paused / ready. So the result must be UNKNOWN rather
    // than COMPLETED.
    assert(aggregate(Completed, Terminated) == WorkflowAggregatedState.UNKNOWN)
  }

  it should "give running precedence over completed and terminated" in {
    assert(aggregate(Completed, Running) == WorkflowAggregatedState.RUNNING)
    assert(aggregate(Terminated, Running) == WorkflowAggregatedState.RUNNING)
    assert(aggregate(Running) == WorkflowAggregatedState.RUNNING)
  }

  it should "report PAUSED / UNINITIALIZED / RUNNING even when no completed sentinel is present" in {
    assert(aggregate(Paused, Paused) == WorkflowAggregatedState.PAUSED)
    assert(aggregate(Uninitialized, Uninitialized) == WorkflowAggregatedState.UNINITIALIZED)
    // All-ready (no completed) maps to RUNNING, same as the with-completed case above.
    assert(aggregate(Ready, Ready) == WorkflowAggregatedState.RUNNING)
  }

  it should "fall back to UNKNOWN when input contains values matching none of the sentinels" in {
    // Defensive: a stray label that is not any of the six sentinels must not be
    // silently classified as completed or running.
    assert(aggregate("not-a-real-state") == WorkflowAggregatedState.UNKNOWN)
    assert(aggregate(Completed, "not-a-real-state") == WorkflowAggregatedState.UNKNOWN)
  }

  // -- aggregatePortMetrics -----------------------------------------------

  "ExecutionUtils.aggregatePortMetrics" should "return empty when given no mappings" in {
    assert(ExecutionUtils.aggregatePortMetrics(Iterable.empty).isEmpty)
  }

  it should "preserve a single mapping" in {
    val mapping = PortTupleMetricsMapping(PortIdentity(0), TupleMetrics(3, 30))
    assert(ExecutionUtils.aggregatePortMetrics(List(mapping)) == Seq(mapping))
  }

  it should "sum count and size across mappings on the same port" in {
    val portId = PortIdentity(0)
    val a = PortTupleMetricsMapping(portId, TupleMetrics(3, 30))
    val b = PortTupleMetricsMapping(portId, TupleMetrics(5, 50))
    val result = ExecutionUtils.aggregatePortMetrics(List(a, b))
    assert(result == Seq(PortTupleMetricsMapping(portId, TupleMetrics(8, 80))))
  }

  it should "group mappings by port id when ports differ" in {
    val a = PortTupleMetricsMapping(PortIdentity(0), TupleMetrics(1, 10))
    val b = PortTupleMetricsMapping(PortIdentity(1), TupleMetrics(2, 20))
    // Compare as sets: these tests make no claim about output order across ports.
    val result = ExecutionUtils.aggregatePortMetrics(List(a, b)).toSet
    assert(result == Set(a, b))
  }

  it should "sum more than two mappings on the same port without losing any" in {
    val portId = PortIdentity(0)
    val mappings = List(
      PortTupleMetricsMapping(portId, TupleMetrics(1, 10)),
      PortTupleMetricsMapping(portId, TupleMetrics(2, 20)),
      PortTupleMetricsMapping(portId, TupleMetrics(4, 40))
    )
    assert(
      ExecutionUtils.aggregatePortMetrics(mappings) ==
        Seq(PortTupleMetricsMapping(portId, TupleMetrics(7, 70)))
    )
  }

  it should "sum independently per port when multiple ports each have multiple mappings" in {
    val port0 = PortIdentity(0)
    val port1 = PortIdentity(1)
    // Interleaved on purpose, so grouping cannot rely on adjacent entries.
    val mappings = List(
      PortTupleMetricsMapping(port0, TupleMetrics(1, 10)),
      PortTupleMetricsMapping(port1, TupleMetrics(3, 30)),
      PortTupleMetricsMapping(port0, TupleMetrics(2, 20)),
      PortTupleMetricsMapping(port1, TupleMetrics(4, 40))
    )
    val result = ExecutionUtils.aggregatePortMetrics(mappings).toSet
    assert(
      result == Set(
        PortTupleMetricsMapping(port0, TupleMetrics(3, 30)),
        PortTupleMetricsMapping(port1, TupleMetrics(7, 70))
      )
    )
  }

  it should "preserve a zero-count, zero-size mapping rather than dropping it" in {
    val mapping = PortTupleMetricsMapping(PortIdentity(0), TupleMetrics(0, 0))
    assert(ExecutionUtils.aggregatePortMetrics(List(mapping)) == Seq(mapping))
  }

  // -- aggregateMetrics ---------------------------------------------------

  /**
   * Builds an [[OperatorMetrics]] in the given state.
   * Every statistic not passed explicitly defaults to empty / zero.
   */
  private def metricsWith(
      state: WorkflowAggregatedState,
      input: Seq[PortTupleMetricsMapping] = Seq.empty,
      output: Seq[PortTupleMetricsMapping] = Seq.empty,
      numWorkers: Int = 0,
      dataTime: Long = 0,
      controlTime: Long = 0,
      idleTime: Long = 0
  ): OperatorMetrics =
    OperatorMetrics(
      state,
      OperatorStatistics(input, output, numWorkers, dataTime, controlTime, idleTime)
    )

  "ExecutionUtils.aggregateMetrics" should "return UNINITIALIZED defaults when given no metrics" in {
    val result = ExecutionUtils.aggregateMetrics(Iterable.empty)
    assert(result.operatorState == WorkflowAggregatedState.UNINITIALIZED)
    assert(result.operatorStatistics.inputMetrics.isEmpty)
    assert(result.operatorStatistics.outputMetrics.isEmpty)
    assert(result.operatorStatistics.numWorkers == 0)
    assert(result.operatorStatistics.dataProcessingTime == 0)
    assert(result.operatorStatistics.controlProcessingTime == 0)
    assert(result.operatorStatistics.idleTime == 0)
  }

  it should "sum scalar statistics and merge per-port metrics across operators" in {
    // Input and output metrics are aggregated into separate fields,
    // so reusing port id 0 on both sides is fine.
    val portIn = PortIdentity(0)
    val portOut = PortIdentity(0)
    // Fixture values give every aggregated field a distinct expected value:
    // input (5, 50) vs output (7, 70), and numWorkers 3 vs idleTime 5 vs
    // controlTime 15. If any two sums were equal, an implementation that swapped
    // those two fields would still pass.
    val left = metricsWith(
      WorkflowAggregatedState.RUNNING,
      input = Seq(PortTupleMetricsMapping(portIn, TupleMetrics(2, 20))),
      output = Seq(PortTupleMetricsMapping(portOut, TupleMetrics(1, 10))),
      numWorkers = 1,
      dataTime = 100,
      controlTime = 5,
      idleTime = 1
    )
    val right = metricsWith(
      WorkflowAggregatedState.RUNNING,
      input = Seq(PortTupleMetricsMapping(portIn, TupleMetrics(3, 30))),
      output = Seq(PortTupleMetricsMapping(portOut, TupleMetrics(6, 60))),
      numWorkers = 2,
      dataTime = 200,
      controlTime = 10,
      idleTime = 4
    )
    val result = ExecutionUtils.aggregateMetrics(List(left, right))
    assert(result.operatorState == WorkflowAggregatedState.RUNNING)
    assert(
      result.operatorStatistics.inputMetrics ==
        Seq(PortTupleMetricsMapping(portIn, TupleMetrics(5, 50)))
    )
    assert(
      result.operatorStatistics.outputMetrics ==
        Seq(PortTupleMetricsMapping(portOut, TupleMetrics(7, 70)))
    )
    assert(result.operatorStatistics.numWorkers == 3)
    assert(result.operatorStatistics.dataProcessingTime == 300)
    assert(result.operatorStatistics.controlProcessingTime == 15)
    assert(result.operatorStatistics.idleTime == 5)
  }

  it should "filter out internal ports when aggregating port metrics" in {
    val publicPort = PortIdentity(0)
    val internalPort = PortIdentity(1, internal = true)
    val metrics = metricsWith(
      WorkflowAggregatedState.RUNNING,
      input = Seq(
        PortTupleMetricsMapping(publicPort, TupleMetrics(1, 10)),
        PortTupleMetricsMapping(internalPort, TupleMetrics(99, 990))
      ),
      output = Seq(PortTupleMetricsMapping(internalPort, TupleMetrics(7, 70)))
    )
    val result = ExecutionUtils.aggregateMetrics(List(metrics))
    assert(
      result.operatorStatistics.inputMetrics ==
        Seq(PortTupleMetricsMapping(publicPort, TupleMetrics(1, 10)))
    )
    // The only output mapping is on an internal port, so nothing must survive.
    assert(result.operatorStatistics.outputMetrics.isEmpty)
  }

  it should "preserve a single operator's statistics (modulo internal-port filtering)" in {
    val portIn = PortIdentity(0)
    val portOut = PortIdentity(0)
    val single = metricsWith(
      WorkflowAggregatedState.RUNNING,
      input = Seq(PortTupleMetricsMapping(portIn, TupleMetrics(2, 20))),
      output = Seq(PortTupleMetricsMapping(portOut, TupleMetrics(3, 30))),
      numWorkers = 4,
      dataTime = 50,
      controlTime = 6,
      idleTime = 1
    )
    val result = ExecutionUtils.aggregateMetrics(List(single))
    assert(result.operatorState == WorkflowAggregatedState.RUNNING)
    assert(
      result.operatorStatistics.inputMetrics ==
        Seq(PortTupleMetricsMapping(portIn, TupleMetrics(2, 20)))
    )
    assert(
      result.operatorStatistics.outputMetrics ==
        Seq(PortTupleMetricsMapping(portOut, TupleMetrics(3, 30)))
    )
    assert(result.operatorStatistics.numWorkers == 4)
    assert(result.operatorStatistics.dataProcessingTime == 50)
    assert(result.operatorStatistics.controlProcessingTime == 6)
    assert(result.operatorStatistics.idleTime == 1)
  }

  it should "report RUNNING when at least one operator is running and the rest are completed" in {
    val running = metricsWith(WorkflowAggregatedState.RUNNING)
    val completed = metricsWith(WorkflowAggregatedState.COMPLETED)
    val result = ExecutionUtils.aggregateMetrics(List(running, completed))
    assert(result.operatorState == WorkflowAggregatedState.RUNNING)
  }

  it should "report COMPLETED when every operator is completed" in {
    val completedA = metricsWith(WorkflowAggregatedState.COMPLETED, numWorkers = 1)
    val completedB = metricsWith(WorkflowAggregatedState.COMPLETED, numWorkers = 2)
    val result = ExecutionUtils.aggregateMetrics(List(completedA, completedB))
    assert(result.operatorState == WorkflowAggregatedState.COMPLETED)
    assert(result.operatorStatistics.numWorkers == 3)
  }

  it should "tolerate operators with empty per-port stats while summing scalars" in {
    val withStats = metricsWith(
      WorkflowAggregatedState.RUNNING,
      input = Seq(PortTupleMetricsMapping(PortIdentity(0), TupleMetrics(1, 10))),
      numWorkers = 1,
      dataTime = 5
    )
    val empty = metricsWith(WorkflowAggregatedState.RUNNING, numWorkers = 2, dataTime = 7)
    val result = ExecutionUtils.aggregateMetrics(List(withStats, empty))
    assert(result.operatorState == WorkflowAggregatedState.RUNNING)
    assert(
      result.operatorStatistics.inputMetrics ==
        Seq(PortTupleMetricsMapping(PortIdentity(0), TupleMetrics(1, 10)))
    )
    assert(result.operatorStatistics.outputMetrics.isEmpty)
    assert(result.operatorStatistics.numWorkers == 3)
    assert(result.operatorStatistics.dataProcessingTime == 12)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/execution/LinkExecutionSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller.execution
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.scalatest.flatspec.AnyFlatSpec
/** Unit tests for the per-channel bookkeeping kept by [[LinkExecution]]. */
class LinkExecutionSpec extends AnyFlatSpec {

  /** Shorthand for a channel between two actors named by plain strings. */
  private def channelId(from: String, to: String, isControl: Boolean = false): ChannelIdentity = {
    val sender = ActorVirtualIdentity(from)
    val receiver = ActorVirtualIdentity(to)
    ChannelIdentity(sender, receiver, isControl)
  }

  "LinkExecution" should "have no channel executions when freshly constructed" in {
    val fresh = LinkExecution()
    assert(fresh.getAllChannelExecutions.isEmpty)
  }

  "LinkExecution.initChannelExecution" should "register a new ChannelExecution for the given channel id" in {
    val linkExecution = LinkExecution()
    val channel = channelId("a", "b")
    linkExecution.initChannelExecution(channel)
    val registered = linkExecution.getAllChannelExecutions.toMap
    // Present, and holding a freshly constructed ChannelExecution.
    assert(registered.get(channel).contains(ChannelExecution()))
  }

  it should "throw an AssertionError if called twice for the same channel id" in {
    val linkExecution = LinkExecution()
    val channel = channelId("a", "b")
    linkExecution.initChannelExecution(channel)
    assertThrows[AssertionError](linkExecution.initChannelExecution(channel))
  }

  it should "track multiple distinct channel executions" in {
    val linkExecution = LinkExecution()
    // Same endpoints with a different control flag must count as a separate channel.
    val channels = Seq(
      channelId("a", "b"),
      channelId("a", "b", isControl = true),
      channelId("a", "c")
    )
    channels.foreach(linkExecution.initChannelExecution)
    val trackedIds = linkExecution.getAllChannelExecutions.map { case (id, _) => id }.toSet
    assert(trackedIds == channels.toSet)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/execution/WorkerPortExecutionSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller.execution
import org.scalatest.flatspec.AnyFlatSpec
/** Unit tests for the completion flag of [[WorkerPortExecution]]. */
class WorkerPortExecutionSpec extends AnyFlatSpec {

  "WorkerPortExecution" should "be incomplete by default" in {
    val fresh = WorkerPortExecution()
    assert(!fresh.completed)
  }

  "WorkerPortExecution.setCompleted" should "flip the completed flag to true" in {
    val portExecution = WorkerPortExecution()
    portExecution.setCompleted()
    assert(portExecution.completed)
  }

  it should "be idempotent on repeated calls" in {
    val portExecution = WorkerPortExecution()
    (1 to 2).foreach(_ => portExecution.setCompleted())
    assert(portExecution.completed)
  }

  it should "not affect a separate WorkerPortExecution instance" in {
    val (marked, untouched) = (WorkerPortExecution(), WorkerPortExecution())
    marked.setCompleted()
    assert(marked.completed)
    // Completion state is per instance, not shared.
    assert(!untouched.completed)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/controller/execution/WorkflowExecutionSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.controller.execution
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.apache.texera.amber.engine.architecture.scheduling.{Region, RegionIdentity}
import org.scalatest.flatspec.AnyFlatSpec
class WorkflowExecutionSpec extends AnyFlatSpec {

  /** Identity of the "main" physical operator derived from logical operator `opId`. */
  private def physicalOpId(opId: String): PhysicalOpIdentity = {
    val logical = OperatorIdentity(opId)
    PhysicalOpIdentity(logical, "main")
  }

  /** Minimal PhysicalOp for `opId` in workflow 0 / execution 0 with an empty executor init. */
  private def op(opId: String): PhysicalOp = {
    val id = physicalOpId(opId)
    PhysicalOp(id, WorkflowIdentity(0), ExecutionIdentity(0), OpExecInitInfo.Empty)
  }

  /**
   * Single-operator region with no ports. Because it has no ports, its
   * `RegionExecution.getState` defaults to COMPLETED.
   */
  private def region(regionId: Long, opId: String): Region =
    Region(RegionIdentity(regionId), Set(op(opId)), Set.empty)

  "WorkflowExecution.initRegionExecution" should "create a new RegionExecution for the given region" in {
    val execution = WorkflowExecution()
    val target = region(1, "a")
    val created = execution.initRegionExecution(target)
    assert(created.region == target)
    assert(execution.getRegionExecution(target.id) eq created)
  }

  it should "throw when called twice for the same region id" in {
    val execution = WorkflowExecution()
    val target = region(1, "a")
    execution.initRegionExecution(target)
    // Re-initializing an already registered region id must trip an assertion.
    assertThrows[AssertionError](execution.initRegionExecution(target))
  }

  "WorkflowExecution.hasRegionExecution" should "be false before init and true after" in {
    val execution = WorkflowExecution()
    val target = region(1, "a")
    val before = execution.hasRegionExecution(target.id)
    execution.initRegionExecution(target)
    val after = execution.hasRegionExecution(target.id)
    assert(!before)
    assert(after)
  }

  "WorkflowExecution.getRegionExecution" should "throw NoSuchElementException for an unknown region id" in {
    val execution = WorkflowExecution()
    assertThrows[NoSuchElementException](execution.getRegionExecution(RegionIdentity(99)))
  }

  "WorkflowExecution.getAllRegionExecutions" should "preserve the insertion order of region executions" in {
    val execution = WorkflowExecution()
    val regions = List(region(0, "a"), region(1, "b"), region(2, "c"))
    // List.map runs left to right, so regions are initialized in the order 0, 1, 2.
    val created = regions.map(r => execution.initRegionExecution(r))
    assert(execution.getAllRegionExecutions.toList == created)
  }

  "WorkflowExecution.restartRegionExecution" should "behave like a fresh init when no prior region execution exists" in {
    val execution = WorkflowExecution()
    val target = region(1, "a")
    val restarted = execution.restartRegionExecution(target)
    assert(execution.hasRegionExecution(target.id))
    assert(execution.getRegionExecution(target.id) eq restarted)
  }

  it should "replace an existing completed region execution with a fresh one" in {
    val execution = WorkflowExecution()
    val target = region(1, "a")
    val original = execution.initRegionExecution(target)
    assert(original.isCompleted)
    val replacement = execution.restartRegionExecution(target)
    assert(replacement ne original)
    assert(execution.getRegionExecution(target.id) eq replacement)
  }

  "WorkflowExecution.getRunningRegionExecutions" should "exclude completed region executions" in {
    val execution = WorkflowExecution()
    val created = execution.initRegionExecution(region(1, "a"))
    assert(created.isCompleted)
    assert(execution.getRunningRegionExecutions.toList.isEmpty)
  }

  "WorkflowExecution.getState" should "return UNINITIALIZED when no regions have been initialized" in {
    val execution = WorkflowExecution()
    assert(execution.getState == WorkflowAggregatedState.UNINITIALIZED)
    assert(!execution.isCompleted)
  }

  it should "return COMPLETED when every initialized region is completed" in {
    val execution = WorkflowExecution()
    Seq(region(0, "a"), region(1, "b")).foreach(r => execution.initRegionExecution(r))
    assert(execution.getState == WorkflowAggregatedState.COMPLETED)
    assert(execution.isCompleted)
  }

  "WorkflowExecution.getLatestOperatorExecutionOption" should "return None when no operator execution exists for the id" in {
    val execution = WorkflowExecution()
    execution.initRegionExecution(region(0, "a"))
    val lookup = execution.getLatestOperatorExecutionOption(physicalOpId("never-initialized"))
    assert(lookup.isEmpty)
  }

  it should "return the latest matching operator execution across regions" in {
    val execution = WorkflowExecution()
    val earlierRegion = execution.initRegionExecution(region(0, "a"))
    val laterRegion = execution.initRegionExecution(region(1, "b"))
    val older = earlierRegion.initOperatorExecution(physicalOpId("a"))
    val newer = laterRegion.initOperatorExecution(physicalOpId("a"))
    val latest = execution.getLatestOperatorExecutionOption(physicalOpId("a"))
    // OperatorExecution is a no-field case class, so its instances are structurally
    // equal; only reference identity (`eq`) can tell the two apart.
    assert(latest.exists(_ eq newer))
    assert(!latest.exists(_ eq older))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/deploysemantics/AddressInfoSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.deploysemantics
import org.apache.pekko.actor.Address
import org.scalatest.flatspec.AnyFlatSpec
class AddressInfoSpec extends AnyFlatSpec {

  /** Pekko address on the `Amber` actor system at `host:port`. */
  private def addr(host: String, port: Int): Address = Address("pekko", "Amber", host, port)

  "AddressInfo" should "expose the addresses it was constructed with" in {
    val workers = Array("h1", "h2", "h3").map(host => addr(host, 2552))
    val controller = addr("ctrl", 2552)
    val info = AddressInfo(workers, controller)
    assert(info.allAddresses.toList == workers.toList)
    assert(info.controllerAddress == controller)
  }

  it should "preserve the order of allAddresses" in {
    // The cluster scheduler chooses workers by walking this list in order, so any
    // reordering would be observable.
    val info = AddressInfo(Array(addr("c", 1), addr("a", 2), addr("b", 3)), addr("ctrl", 0))
    val hosts = info.allAddresses.toList.flatMap(_.host)
    assert(hosts == List("c", "a", "b"))
  }

  it should "accept an empty allAddresses array" in {
    // Edge case: a controller-only setup with no worker nodes.
    val info = AddressInfo(Array.empty[Address], addr("ctrl", 0))
    assert(info.allAddresses.isEmpty)
    assert(info.controllerAddress.host.contains("ctrl"))
  }

  it should "allow the controller to also appear in allAddresses (collocated)" in {
    val controller = addr("ctrl", 2552)
    val everyone = Array(controller, addr("worker", 2552))
    val info = AddressInfo(everyone, controller)
    assert(info.allAddresses.contains(controller))
    assert(info.controllerAddress == controller)
  }

  it should "support copy(), allowing one field to change while the other is preserved" in {
    val before = AddressInfo(Array(addr("h1", 1)), addr("ctrl-a", 0))
    val after = before.copy(controllerAddress = addr("ctrl-b", 0))
    assert(after.controllerAddress.host.contains("ctrl-b"))
    assert(after.allAddresses.toList == before.allAddresses.toList)
    // copy() must leave the source instance untouched.
    assert(before.controllerAddress.host.contains("ctrl-a"))
  }

  it should "use Array reference equality (not element-wise) for the allAddresses field" in {
    // The generated case-class equality compares an Array field with `==`, and for arrays
    // `==` is reference equality. Values sharing one array instance are equal; values holding
    // distinct arrays with identical elements are not. Pinning this means a switch of the
    // field to (say) a Seq, which would make equality element-wise, surfaces here instead of
    // silently changing semantics for callers.
    val shared = Array(addr("h", 1))
    val ctrl = addr("ctrl", 0)
    val first = AddressInfo(shared, ctrl)
    val second = AddressInfo(shared, ctrl)
    val lookalike = AddressInfo(Array(addr("h", 1)), ctrl)
    assert(first == second, "shared Array reference → equal")
    assert(first != lookalike, "distinct Array references → not equal (no element-wise check)")
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/deploysemantics/deploystrategy/DeployStrategiesSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.deploysemantics.deploystrategy
import org.apache.pekko.actor.Address
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
class DeployStrategiesSpec extends AnyFlatSpec with Matchers {

  // Addresses use the "pekko" protocol, as Amber's real node addresses do
  // (e.g. AmberConfig.masterNodeAddr). An "akka" protocol would diverge from production
  // and could mislead anyone comparing addresses while debugging a failure.
  private def node(host: String): Address = Address("pekko", "sys", host, 2552)
  private val nodeA = node("host-a")
  private val nodeB = node("host-b")
  private val nodeC = node("host-c")

  // ----- OneOnEach -----

  "OneOnEach" should "hand out each address exactly once in array order" in {
    val strategy = OneOnEach()
    val pool = Array(nodeA, nodeB, nodeC)
    strategy.initialize(pool)
    pool.foreach(expected => strategy.next() shouldBe expected)
  }

  it should "raise IndexOutOfBoundsException once the array is exhausted" in {
    val strategy = OneOnEach()
    strategy.initialize(Array(nodeA))
    strategy.next() shouldBe nodeA
    an[IndexOutOfBoundsException] should be thrownBy strategy.next()
  }

  it should "raise IndexOutOfBoundsException immediately when initialized with an empty array" in {
    val strategy = OneOnEach()
    strategy.initialize(Array.empty[Address])
    an[IndexOutOfBoundsException] should be thrownBy strategy.next()
  }

  it should "preserve its iteration cursor across re-initialization (current behavior)" in {
    // Pins current behavior: initialize() swaps in the new array but does not reset the
    // index, so a re-initialized strategy continues from its previous position. A future fix
    // that zeroes the index inside initialize() is meant to break this spec, so that the
    // contract change gets reviewed.
    val strategy = OneOnEach()
    strategy.initialize(Array(nodeA, nodeB))
    strategy.next() shouldBe nodeA
    strategy.initialize(Array(nodeC))
    // The index is still 1, which is already past the end of the one-element array.
    an[IndexOutOfBoundsException] should be thrownBy strategy.next()
  }

  "OneOnEach.apply" should "produce a fresh, independent instance" in {
    assert(OneOnEach() ne OneOnEach())
  }

  // ----- RoundRobinDeployment -----

  "RoundRobinDeployment" should "rotate addresses in a repeating cycle" in {
    val strategy = RoundRobinDeployment()
    strategy.initialize(Array(nodeA, nodeB, nodeC))
    Seq(nodeA, nodeB, nodeC, nodeA, nodeB).foreach(expected => strategy.next() shouldBe expected)
  }

  it should "always return the only address when the array has length 1" in {
    val strategy = RoundRobinDeployment()
    strategy.initialize(Array(nodeA))
    (1 to 5).foreach(_ => strategy.next() shouldBe nodeA)
  }

  it should "raise ArithmeticException on next() with an empty array (current behavior)" in {
    // Pins current behavior: next() computes `index = (index + 1) % length`. With length == 0
    // that is a division by zero, so it throws ArithmeticException before any address is
    // returned. OneOnEach reports the same situation as IndexOutOfBoundsException (and
    // RandomDeployment as IllegalArgumentException), so the strategies disagree. A fix that
    // unifies the empty-array error type will need to update this spec.
    val strategy = RoundRobinDeployment()
    strategy.initialize(Array.empty[Address])
    an[ArithmeticException] should be thrownBy strategy.next()
  }

  "RoundRobinDeployment.apply" should "produce a fresh, independent instance" in {
    assert(RoundRobinDeployment() ne RoundRobinDeployment())
  }

  // ----- RandomDeployment -----

  "RandomDeployment" should "always return one of the available addresses" in {
    val strategy = RandomDeployment()
    val pool = Array(nodeA, nodeB, nodeC)
    strategy.initialize(pool)
    val allowed = pool.toSet
    (1 to 50).foreach(_ => allowed should contain(strategy.next()))
  }

  it should "always return the only address when the array has length 1" in {
    val strategy = RandomDeployment()
    strategy.initialize(Array(nodeA))
    (1 to 5).foreach(_ => strategy.next() shouldBe nodeA)
  }

  it should "raise IllegalArgumentException on next() with an empty array (current behavior)" in {
    // Pins current behavior: next() calls Random.nextInt(0), which throws
    // IllegalArgumentException with the message "bound must be positive". The root issue is
    // the same as RoundRobinDeployment's empty-array case, but it surfaces as a different
    // exception type. It is pinned separately so that a unifying fix shows up here.
    val strategy = RandomDeployment()
    strategy.initialize(Array.empty[Address])
    an[IllegalArgumentException] should be thrownBy strategy.next()
  }

  "RandomDeployment.apply" should "produce a fresh, independent instance" in {
    assert(RandomDeployment() ne RandomDeployment())
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/deploysemantics/layer/WorkerExecutionSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.deploysemantics.layer
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.engine.architecture.worker.statistics.{WorkerState, WorkerStatistics}
import org.scalatest.flatspec.AnyFlatSpec
class WorkerExecutionSpec extends AnyFlatSpec {

  /**
   * WorkerStatistics with empty sequences, zero for the two middle fields, and `idle` as the
   * last field (the tests below read it back through `getStats.idleTime`).
   */
  private def stats(idle: Long): WorkerStatistics =
    WorkerStatistics(Seq.empty, Seq.empty, 0L, 0L, idle)

  "WorkerExecution" should "have UNINITIALIZED state and zeroed stats by default" in {
    val we = WorkerExecution()
    assert(we.getState == WorkerState.UNINITIALIZED)
    assert(we.getStats.idleTime == 0L)
    assert(we.getStats.dataProcessingTime == 0L)
    assert(we.getStats.controlProcessingTime == 0L)
  }

  "WorkerExecution.update(state)" should "apply when the timestamp is newer" in {
    val we = WorkerExecution()
    we.update(timeStamp = 10L, state = WorkerState.RUNNING)
    assert(we.getState == WorkerState.RUNNING)
  }

  it should "ignore updates with a non-newer timestamp" in {
    val we = WorkerExecution()
    we.update(timeStamp = 10L, state = WorkerState.RUNNING)
    we.update(timeStamp = 10L, state = WorkerState.PAUSED) // not strictly newer
    we.update(timeStamp = 5L, state = WorkerState.COMPLETED) // older
    assert(we.getState == WorkerState.RUNNING)
  }

  "WorkerExecution.update(state, stats)" should "update both atomically when newer" in {
    val we = WorkerExecution()
    we.update(timeStamp = 10L, state = WorkerState.RUNNING, stats = stats(idle = 7L))
    assert(we.getState == WorkerState.RUNNING)
    assert(we.getStats.idleTime == 7L)
  }

  it should "ignore updates with a non-newer timestamp" in {
    val we = WorkerExecution()
    we.update(timeStamp = 10L, state = WorkerState.RUNNING, stats = stats(idle = 7L))
    we.update(timeStamp = 5L, state = WorkerState.COMPLETED, stats = stats(idle = 99L))
    // Neither the state nor the stats of the stale update may be applied.
    assert(we.getState == WorkerState.RUNNING)
    assert(we.getStats.idleTime == 7L)
  }

  "WorkerExecution.update(stats)" should "update only the stats when newer" in {
    val we = WorkerExecution()
    we.update(timeStamp = 10L, state = WorkerState.RUNNING, stats = stats(idle = 7L))
    we.update(timeStamp = 20L, stats = stats(idle = 42L))
    assert(we.getState == WorkerState.RUNNING)
    assert(we.getStats.idleTime == 42L)
  }

  it should "ignore stats updates with a non-newer timestamp" in {
    val we = WorkerExecution()
    we.update(timeStamp = 20L, stats = stats(idle = 42L))
    we.update(timeStamp = 20L, stats = stats(idle = 99L)) // not strictly newer
    we.update(timeStamp = 5L, stats = stats(idle = 0L)) // older
    assert(we.getStats.idleTime == 42L)
  }

  "WorkerExecution.getInputPortExecution" should "lazily create and reuse a port execution per port id" in {
    val we = WorkerExecution()
    val first = we.getInputPortExecution(PortIdentity(0))
    val same = we.getInputPortExecution(PortIdentity(0))
    val other = we.getInputPortExecution(PortIdentity(1))
    assert(first eq same)
    assert(first ne other)
  }

  "WorkerExecution.getOutputPortExecution" should "lazily create and reuse a port execution per port id" in {
    val we = WorkerExecution()
    val first = we.getOutputPortExecution(PortIdentity(0))
    val same = we.getOutputPortExecution(PortIdentity(0))
    val other = we.getOutputPortExecution(PortIdentity(1))
    assert(first eq same)
    // Mirror the input-port test: a different port id must get its own execution,
    // otherwise "per port id" is not actually being verified.
    assert(first ne other)
  }

  it should "use a separate map from getInputPortExecution" in {
    val we = WorkerExecution()
    val input = we.getInputPortExecution(PortIdentity(0))
    val output = we.getOutputPortExecution(PortIdentity(0))
    assert(input ne output)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/logreplay/EmptyReplayLogManagerImplSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.logreplay
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
EmbeddedControlMessageIdentity
}
import org.apache.texera.amber.engine.architecture.common.ProcessingStepCursor
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.MainThreadDelegateMessage
import org.apache.texera.amber.engine.common.ambermessage.{
DataFrame,
WorkflowFIFOMessage,
WorkflowFIFOMessagePayload
}
import org.apache.texera.amber.engine.common.storage.EmptyRecordStorage
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class EmptyReplayLogManagerImplSpec extends AnyFlatSpec {

  /** Data (non-control) channel used for every message in this suite. */
  private val channel =
    ChannelIdentity(ActorVirtualIdentity("from"), ActorVirtualIdentity("to"), isControl = false)

  /** A FIFO message on `channel` with sequence number `seq` (empty DataFrame payload by default). */
  private def fifo(
      seq: Long,
      payload: WorkflowFIFOMessagePayload = DataFrame(Array.empty)
  ): WorkflowFIFOMessage =
    WorkflowFIFOMessage(channel, seq, payload)

  /** Records, in call order, every message passed to `handler`. */
  private class CapturingHandler {
    val received: mutable.ListBuffer[Either[MainThreadDelegateMessage, WorkflowFIFOMessage]] =
      mutable.ListBuffer()
    val handler: Either[MainThreadDelegateMessage, WorkflowFIFOMessage] => Unit =
      msg => received += msg
  }

  "EmptyReplayLogManagerImpl" should "expose getStep starting at INIT_STEP" in {
    val mgr = new EmptyReplayLogManagerImpl(_ => ())
    assert(mgr.getStep == ProcessingStepCursor.INIT_STEP)
  }

  it should "no-op on setupWriter / markAsReplayDestination / terminate" in {
    val mgr = new EmptyReplayLogManagerImpl(_ => ())
    // Use real fixtures rather than nulls, so the test reflects realistic call sites and
    // would catch an accidental NPE if the no-op implementation ever changes.
    val writer = new EmptyRecordStorage[ReplayLogRecord]().getWriter("x")
    mgr.setupWriter(writer)
    mgr.markAsReplayDestination(EmbeddedControlMessageIdentity("test"))
    mgr.terminate()
    assert(mgr.getStep == ProcessingStepCursor.INIT_STEP)
  }

  "EmptyReplayLogManagerImpl.sendCommitted" should "forward the message to the configured handler" in {
    val cap = new CapturingHandler
    val mgr = new EmptyReplayLogManagerImpl(cap.handler)
    val msg = Right[MainThreadDelegateMessage, WorkflowFIFOMessage](fifo(1L))
    mgr.sendCommitted(msg)
    assert(cap.received.toList == List(msg))
  }

  "ReplayLogManager.withFaultTolerant" should "advance the step counter after the body runs" in {
    val mgr = new EmptyReplayLogManagerImpl(_ => ())
    // Expected steps are expressed relative to INIT_STEP, so this test does not need to
    // change if the initial-step constant ever does.
    mgr.withFaultTolerant(channel, Some(fifo(1L))) {}
    assert(mgr.getStep == ProcessingStepCursor.INIT_STEP + 1)
    mgr.withFaultTolerant(channel, Some(fifo(2L))) {}
    assert(mgr.getStep == ProcessingStepCursor.INIT_STEP + 2)
  }

  it should "still advance the step counter and rethrow when the body throws" in {
    val mgr = new EmptyReplayLogManagerImpl(_ => ())
    val boom = new RuntimeException("boom")
    val thrown = intercept[RuntimeException] {
      mgr.withFaultTolerant(channel, Some(fifo(1L))) {
        throw boom
      }
    }
    // "Rethrow" means the body's own exception escapes unchanged. Checking only the type
    // would also accept a RuntimeException that wraps or replaces it.
    assert(thrown eq boom)
    assert(mgr.getStep == ProcessingStepCursor.INIT_STEP + 1)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/logreplay/LogreplayPrimitivesSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.logreplay
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.serialization.{Serialization, SerializationExtension}
import org.apache.pekko.testkit.TestKit
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
EmbeddedControlMessageIdentity
}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
AsyncRPCContext,
ControlInvocation,
EmptyRequest
}
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.{
EmptyReturn,
ReturnInvocation
}
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.ambermessage.{
WorkflowFIFOMessage,
WorkflowFIFOMessagePayload
}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class LogreplayPrimitivesSpec extends AnyFlatSpec with BeforeAndAfterAll {
private val workerId = ActorVirtualIdentity("worker-1")

/** Data (non-control) channel carrying messages from the named upstream actor into `workerId`. */
private def upstreamDataChannel(upstream: String): ChannelIdentity =
  ChannelIdentity(ActorVirtualIdentity(upstream), workerId, isControl = false)

private val cidA = upstreamDataChannel("up1")
private val cidB = upstreamDataChannel("up2")
private val cidC = upstreamDataChannel("up3")
// Suite-local ActorSystem + Serialization injected into AmberRuntime so that the
// ReplayLogRecord round-trip below uses the same Pekko serialization stack that
// SequentialRecordStorage uses in production. Torn down in afterAll so that no Pekko
// threads outlive the suite. (Same pattern as CheckpointSubsystemSpec.)
private val testSystem: ActorSystem =
  ActorSystem("LogreplayPrimitivesSpec-test", AmberRuntime.pekkoConfig)
private val testSerde: Serialization = SerializationExtension(testSystem)

/** Sets the private field `name` on the AmberRuntime singleton via reflection. */
private def setAmberRuntimeField(name: String, value: AnyRef): Unit = {
  val field = AmberRuntime.getClass.getDeclaredField(name)
  field.setAccessible(true)
  field.set(AmberRuntime, value)
}

override protected def beforeAll(): Unit = {
  super.beforeAll()
  setAmberRuntimeField("_actorSystem", testSystem)
  setAmberRuntimeField("_serde", testSerde)
}

/**
 * Clears the injected AmberRuntime fields, shuts down the suite-local ActorSystem, and then
 * delegates to super. Each stage runs in a `finally`, so a reflection failure while clearing
 * the fields cannot leak the ActorSystem's threads or skip `super.afterAll()`.
 */
override protected def afterAll(): Unit = {
  try {
    try {
      setAmberRuntimeField("_serde", null)
      setAmberRuntimeField("_actorSystem", null)
    } finally {
      TestKit.shutdownActorSystem(testSystem)
    }
  } finally {
    super.afterAll()
  }
}
// Field-less test payload. As a case class, all of its instances compare equal.
private case class FixedSizePayload() extends WorkflowFIFOMessagePayload

/** A FIFO message on channel `cidA` with sequence number `seq` and a FixedSizePayload. */
private def msg(seq: Long): WorkflowFIFOMessage = {
  val payload = FixedSizePayload()
  WorkflowFIFOMessage(cidA, seq, payload)
}
// ---------------------------------------------------------------------------
// ReplayLoggerImpl
// ---------------------------------------------------------------------------
"ReplayLoggerImpl.logCurrentStepWithMessage" should "append a ProcessingStep when the channel changes" in {
  val logger = new ReplayLoggerImpl()
  // A first call on cidA counts as a channel change, so it is recorded.
  logger.logCurrentStepWithMessage(0L, cidA, None)
  val expected = List(ProcessingStep(cidA, 0L))
  assert(logger.drainCurrentLogRecords(0L).toList == expected)
}

it should "skip a same-channel call with no message" in {
  val logger = new ReplayLoggerImpl()
  logger.logCurrentStepWithMessage(0L, cidA, None)
  logger.drainCurrentLogRecords(0L) // empty the buffer before the call under test
  logger.logCurrentStepWithMessage(1L, cidA, None) // same channel, no message → skipped
  // Only the synthetic trailing ProcessingStep that drain adds for step 1 should remain.
  val drained = logger.drainCurrentLogRecords(1L)
  assert(drained.toList == List(ProcessingStep(cidA, 1L)))
}

it should "append both a ProcessingStep and a MessageContent when a message is provided" in {
  val logger = new ReplayLoggerImpl()
  val message = msg(7L)
  logger.logCurrentStepWithMessage(2L, cidA, Some(message))
  val drained = logger.drainCurrentLogRecords(2L)
  assert(drained.toList == List(ProcessingStep(cidA, 2L), MessageContent(message)))
}

it should "still log when a same-channel call carries a message (only no-message + same-channel is skipped)" in {
  // logCurrentStepWithMessage skips only when BOTH `currentChannelId == channelId` AND
  // `message.isEmpty` hold. After the current channel has become cidA, a later cidA call that
  // does carry a message must still emit ProcessingStep + MessageContent.
  val logger = new ReplayLoggerImpl()
  logger.logCurrentStepWithMessage(0L, cidA, None) // current channel becomes cidA
  logger.drainCurrentLogRecords(0L) // empty the buffer
  val message = msg(11L)
  logger.logCurrentStepWithMessage(1L, cidA, Some(message)) // same channel, with a message
  val drained = logger.drainCurrentLogRecords(1L)
  assert(drained.toList == List(ProcessingStep(cidA, 1L), MessageContent(message)))
}

it should "append a ProcessingStep on a channel switch even if no message is provided" in {
  val logger = new ReplayLoggerImpl()
  logger.logCurrentStepWithMessage(0L, cidA, None)
  logger.logCurrentStepWithMessage(1L, cidB, None) // the switch to cidB must be recorded
  val drained = logger.drainCurrentLogRecords(1L)
  assert(drained.toList == List(ProcessingStep(cidA, 0L), ProcessingStep(cidB, 1L)))
}

"ReplayLoggerImpl.markAsReplayDestination" should
  "preserve exact ordering: in-flight ProcessingStep, then ReplayDestination, then synthetic trailing step" in {
  // ReplayLogGenerator stops replaying at the ReplayDestination record, so its position in
  // the stream matters. A `contains` check would accept a duplicated destination, or one that
  // was moved after drain's synthetic trailing ProcessingStep, so the whole sequence is pinned.
  val logger = new ReplayLoggerImpl()
  val destination = EmbeddedControlMessageIdentity("checkpoint-1")
  logger.logCurrentStepWithMessage(0L, cidA, None) // records ProcessingStep(cidA, 0)
  logger.markAsReplayDestination(destination)
  // Draining at step 3 (past the last logged step) also emits the synthetic trailing step,
  // which must come AFTER the destination.
  val drained = logger.drainCurrentLogRecords(3L).toList
  val expected =
    List(ProcessingStep(cidA, 0L), ReplayDestination(destination), ProcessingStep(cidA, 3L))
  assert(drained == expected, s"unexpected record order: $drained")
}

"ReplayLoggerImpl.drainCurrentLogRecords" should "clear the buffer between drains" in {
  val logger = new ReplayLoggerImpl()
  logger.logCurrentStepWithMessage(0L, cidA, None)
  val firstDrain = logger.drainCurrentLogRecords(0L)
  val secondDrain = logger.drainCurrentLogRecords(0L)
  assert(firstDrain.nonEmpty)
  assert(secondDrain.isEmpty, "second drain must yield no leftover records")
}

it should "append a synthetic ProcessingStep when the requested step differs from lastStep" in {
  val logger = new ReplayLoggerImpl()
  logger.logCurrentStepWithMessage(0L, cidA, None)
  // Expect the logged step 0 followed by the synthetic step 5 that drain appends.
  val drained = logger.drainCurrentLogRecords(5L)
  assert(drained.toList == List(ProcessingStep(cidA, 0L), ProcessingStep(cidA, 5L)))
}
// ---------------------------------------------------------------------------
// OrderEnforcer trait
// ---------------------------------------------------------------------------
"OrderEnforcer trait" should "be implementable as a custom subclass" in {
  // Minimal enforcer: admits every channel until it is flagged as completed.
  val gate = new OrderEnforcer {
    override var isCompleted: Boolean = false
    override def canProceed(channelId: ChannelIdentity): Boolean = !isCompleted
  }
  val beforeCompletion = gate.canProceed(cidA)
  gate.isCompleted = true
  val afterCompletion = gate.canProceed(cidA)
  assert(beforeCompletion)
  assert(!afterCompletion)
}
// ---------------------------------------------------------------------------
// ReplayOrderEnforcer
// ---------------------------------------------------------------------------
/**
 * EmptyReplayLogManagerImpl whose `getStep` reports a value chosen by the test through
 * `setStep` (initially 0), so that ReplayOrderEnforcer can be driven one step at a time.
 */
private class StubLogManager(
  handler: Either[
    org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.MainThreadDelegateMessage,
    WorkflowFIFOMessage
  ] => Unit
) extends EmptyReplayLogManagerImpl(handler) {
  private var currentStep: Long = 0L

  def setStep(s: Long): Unit = currentStep = s

  override def getStep: Long = currentStep
}
"ReplayOrderEnforcer" should "be completed immediately when the step queue is empty" in {
val mgr = new StubLogManager(_ => ())
val empty = mutable.Queue[ProcessingStep]()
var fired = false
val enf = new ReplayOrderEnforcer(mgr, empty, startStep = 0L, () => fired = true)
assert(enf.isCompleted)
assert(fired)
}
it should "skip log entries whose step is at or below startStep during construction (boundary inclusive)" in {
// Use distinct channels at and around the boundary so the test is sensitive
// to a `step < startStep` (vs `<= startStep`) regression. With startStep=1L:
// - correct impl drops steps 0 and 1 → after ctor, head is step 2 (cidC)
// - buggy impl that drops only step < 1 leaves step 1 (cidB) at the head
// At step=2, canProceed(cidC) consumes step 2 and returns true under the
// correct impl, but the buggy impl never matches the leftover step-1 entry
// (`head.step == step` is 1 != 2), so currentChannelId stays at cidA (set
// from the only forwardNext that ran on step 0) and canProceed(cidC) returns
// false. Either side mismatching this assertion catches the boundary bug.
val mgr = new StubLogManager(_ => ())
mgr.setStep(2L)
val q = mutable.Queue[ProcessingStep](
ProcessingStep(cidA, 0L),
ProcessingStep(cidB, 1L), // boundary entry — distinct channel
ProcessingStep(cidC, 2L),
ProcessingStep(cidA, 3L)
)
val enf = new ReplayOrderEnforcer(mgr, q, startStep = 1L, () => ())
assert(!enf.isCompleted)
val proceeded = enf.canProceed(cidC)
assert(
proceeded,
"boundary entry must be dropped (step <= startStep), so cidC at step 2 is the next allowed channel"
)
}
it should "advance to the next channel and fire onComplete on the non-empty-to-empty transition" in {
// Walks a two-entry queue (step 0 → cidA, step 1 → cidB) one step at a time,
// checking the active channel, isCompleted, and the onComplete count after
// each step.
val mgr = new StubLogManager(_ => ())
mgr.setStep(0L)
val q = mutable.Queue[ProcessingStep](
ProcessingStep(cidA, 0L),
ProcessingStep(cidB, 1L)
)
var fired = 0
val enf = new ReplayOrderEnforcer(mgr, q, startStep = -1L, () => fired += 1)
assert(fired == 0, "onComplete must NOT fire at construction while the queue is non-empty")
// At step 0, the head matches and is consumed; currentChannelId becomes cidA.
assert(enf.canProceed(cidA))
assert(!enf.canProceed(cidB), "still on cidA until the next step is observed")
assert(!enf.isCompleted)
assert(fired == 0, "onComplete must NOT fire while the queue still has entries")
mgr.setStep(1L)
// Query with the previous channel (cidA) after advancing the step: the while
// loop in canProceed consumes the step-1 entry (cidB) before the channel
// check runs, so cidA is rejected at step 1.
assert(!enf.canProceed(cidA), "step 1's channel is cidB, not cidA")
assert(enf.isCompleted, "queue is exhausted, replay must mark completed")
// The step-1 entry consumed above makes cidB the active channel. Pinning
// this catches a regression that drains the last entry without updating
// the active channel.
assert(enf.canProceed(cidB), "cidB must be the active channel after step 1 is consumed")
// onComplete must fire exactly once, on the non-empty-to-empty transition.
assert(fired == 1, "onComplete must fire on the non-empty-to-empty transition")
enf.canProceed(cidA) // further calls past completion must not refire
assert(fired == 1)
}
it should "fire onComplete exactly once even if canProceed is called repeatedly past the end" in {
  // An empty queue completes the enforcer during construction; subsequent
  // canProceed calls, on any channel, must never invoke onComplete again.
  var fireCount = 0
  val enforcer = new ReplayOrderEnforcer(
    new StubLogManager(_ => ()),
    mutable.Queue.empty[ProcessingStep],
    startStep = 0L,
    () => fireCount += 1
  )
  assert(fireCount == 1)
  // already completed → must not refire
  Seq(cidA, cidB).foreach(channel => enforcer.canProceed(channel))
  assert(fireCount == 1)
}
it should "consume every queue entry sharing the current step (the duplicate-step while loop)" in {
// canProceed contains a `while (head.step == step) forwardNext()` loop
// specifically because checkpoints produce duplicate step records (the
// MainThreadDelegateMessage path emits an extra ProcessingStep at the
// same step). A regression that consumed only one entry per step would
// leave a stale duplicate at the head, so a subsequent canProceed at the
// next step would still see the old channel instead of the real next one.
// Pin the multi-consume behavior with two adjacent same-step entries.
val mgr = new StubLogManager(_ => ())
mgr.setStep(0L)
val q = mutable.Queue[ProcessingStep](
ProcessingStep(cidA, 0L),
ProcessingStep(cidB, 0L), // duplicate step — must be consumed too
ProcessingStep(cidC, 1L)
)
val enf = new ReplayOrderEnforcer(mgr, q, startStep = -1L, () => ())
// After both step-0 entries are consumed, currentChannelId is the LAST
// one (cidB). cidA was the head but is no longer the active channel.
assert(enf.canProceed(cidB), "the second step-0 entry (cidB) must be the active channel")
assert(
!enf.canProceed(cidA),
"cidA was consumed by the duplicate-step while loop and is no longer active"
)
assert(!enf.isCompleted, "step 1 (cidC) is still queued")
// Advancing to step 1 consumes cidC, leaving the queue empty.
mgr.setStep(1L)
assert(enf.canProceed(cidC))
assert(enf.isCompleted)
}
// ---------------------------------------------------------------------------
// ReplayLogRecord serde
// ---------------------------------------------------------------------------
// Pushes a ReplayLogRecord through Pekko Serialization via AmberRuntime.serde,
// which is the same path SequentialRecordStorage uses in production. A broken
// serializer registration or a deserialization mismatch surfaces here as a
// thrown exception, which an `isInstanceOf[Serializable]` check would never
// detect.
private def roundTrip(r: ReplayLogRecord): ReplayLogRecord = {
  val serde = AmberRuntime.serde
  val decoded = for {
    encoded <- serde.serialize(r)
    back <- serde.deserialize(encoded, classOf[ReplayLogRecord])
  } yield back
  // Rethrows whichever step failed, exactly as chained `.get` calls would.
  decoded.get
}
// Production never writes a DataFrame to the replay log: both the controller
// and DP-thread paths filter for `DirectControlMessagePayload` before logging
// (see `Controller.scala` and `DPThread.scala`, which check
// `_.payload.isInstanceOf[DirectControlMessagePayload]`). Production actually
// serializes two concrete subtypes of that trait: `ControlInvocation`
// (outgoing call) and `ReturnInvocation` (reply). `processDCM` handles both.
// Each one is round-tripped separately, so a serializer regression on
// either subtype fails this spec.
"ReplayLogRecord MessageContent" should "round-trip a ControlInvocation payload through AmberRuntime.serde" in {
  val msg = WorkflowFIFOMessage(
    cidA,
    1L,
    ControlInvocation(
      methodName = "doNothing",
      command = EmptyRequest(),
      context = AsyncRPCContext(workerId, workerId),
      commandId = 42L
    )
  )
  val original: ReplayLogRecord = MessageContent(msg)
  val restored = roundTrip(original)
  assert(restored == original)

  // Beyond record-level equality, check that the nested message and its
  // invocation fields survived intact.
  val restoredMsg = restored.asInstanceOf[MessageContent].message
  assert(restoredMsg == msg)
  val invocation = restoredMsg.payload.asInstanceOf[ControlInvocation]
  assert(invocation.methodName == "doNothing")
  assert(invocation.commandId == 42L)
}
it should "round-trip a ReturnInvocation payload through AmberRuntime.serde" in {
  val msg = WorkflowFIFOMessage(
    cidA,
    2L,
    ReturnInvocation(commandId = 42L, returnValue = EmptyReturn())
  )
  val original: ReplayLogRecord = MessageContent(msg)
  val restored = roundTrip(original)
  assert(restored == original)

  val restoredMsg = restored.asInstanceOf[MessageContent].message
  assert(restoredMsg == msg)
  val reply = restoredMsg.payload.asInstanceOf[ReturnInvocation]
  assert(reply.commandId == 42L)
  assert(reply.returnValue == EmptyReturn())
}
"ReplayLogRecord ProcessingStep" should "round-trip through AmberRuntime.serde" in {
val original: ReplayLogRecord = ProcessingStep(cidA, 7L)
val restored = roundTrip(original)
assert(restored == original)
val ps = restored.asInstanceOf[ProcessingStep]
assert(ps.channelId == cidA)
assert(ps.step == 7L)
}
"ReplayLogRecord ReplayDestination" should "round-trip through AmberRuntime.serde" in {
val ecm = EmbeddedControlMessageIdentity("ecm-1")
val original: ReplayLogRecord = ReplayDestination(ecm)
val restored = roundTrip(original)
assert(restored == original)
assert(restored.asInstanceOf[ReplayDestination].id == ecm)
}
// NOTE: TerminateSignal is intentionally NOT round-tripped here. It is an
// in-memory shutdown sentinel for AsyncReplayLogWriter and is filtered
// out before records are written to storage, so a Pekko-serialization
// round-trip is not on a real production path. Pinning `eq`-identity
// post-deserialization would over-constrain the serializer (a future
// change that re-creates the case-object via reflection — still
// semantically correct — would fail). Subtype membership is already
// pinned by the case-object's compile-time `extends ReplayLogRecord`.
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/AmberFIFOChannelSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.engine.common.ambermessage.{
WorkflowFIFOMessage,
WorkflowFIFOMessagePayload
}
import org.scalatest.flatspec.AnyFlatSpec
class AmberFIFOChannelSpec extends AnyFlatSpec {
  private val cid =
    ChannelIdentity(ActorVirtualIdentity("from"), ActorVirtualIdentity("to"), isControl = false)

  // Non-DataFrame payload, so each message has a deterministic 200L size
  // for credit/queued/stashed accounting.
  private case class FixedSizePayload() extends WorkflowFIFOMessagePayload
  private val msgSize: Long = 200L

  /** Builds a fixed-size message on `cid` carrying sequence number `seq`. */
  private def msg(seq: Long): WorkflowFIFOMessage =
    WorkflowFIFOMessage(cid, seq, FixedSizePayload())

  // ---------------------------------------------------------------------------
  // Construction defaults
  // ---------------------------------------------------------------------------

  "AmberFIFOChannel" should "expose the configured channelId and have an empty queue at construction" in {
    val ch = new AmberFIFOChannel(cid)
    assert(ch.channelId == cid)
    assert(!ch.hasMessage)
    assert(ch.getCurrentSeq == 0L)
    assert(ch.getQueuedCredit == 0L)
    assert(ch.getTotalMessageSize == 0L)
    assert(ch.getTotalStashedSize == 0L)
  }

  it should "default to enabled" in {
    val ch = new AmberFIFOChannel(cid)
    assert(ch.isEnabled)
  }

  // ---------------------------------------------------------------------------
  // FIFO ordering and stash
  // ---------------------------------------------------------------------------

  "AmberFIFOChannel.acceptMessage" should "forward an in-order seq=0 message and advance the current sequence" in {
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(0L))
    assert(ch.hasMessage)
    assert(ch.getCurrentSeq == 1L)
    assert(ch.getQueuedCredit == msgSize)
    assert(ch.getTotalMessageSize == msgSize)
  }

  it should "stash an out-of-order message until its predecessor arrives, then drain in FIFO order" in {
    val ch = new AmberFIFOChannel(cid)
    // arrives out of order: seq 1 first, then seq 0
    ch.acceptMessage(msg(1L))
    assert(!ch.hasMessage, "ahead-of-window message must be stashed, not delivered")
    assert(ch.getCurrentSeq == 0L)
    assert(ch.getTotalStashedSize == msgSize)
    ch.acceptMessage(msg(0L))
    // both should drain
    assert(ch.hasMessage)
    assert(ch.getCurrentSeq == 2L)
    assert(ch.getQueuedCredit == 2 * msgSize)
    assert(ch.getTotalStashedSize == 0L)
    val first = ch.take
    val second = ch.take
    assert(first.sequenceNumber == 0L)
    assert(second.sequenceNumber == 1L)
    assert(!ch.hasMessage)
    assert(ch.getQueuedCredit == 0L)
  }

  it should "drain a contiguous run from the stash once the gap fills, leaving a non-contiguous stashed message behind" in {
    // A three-message stash with a gap: seq 1, 2, 4 are all stashed because
    // seq 0 hasn't arrived; once 0 arrives, the contiguous run 0..2 drains
    // but 4 stays stashed because seq 3 is still missing.
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(1L))
    ch.acceptMessage(msg(2L))
    ch.acceptMessage(msg(4L))
    ch.acceptMessage(msg(0L))
    assert(ch.getCurrentSeq == 3L, "drain must advance current to the first missing seq")
    // queued: 0, 1, 2 — three messages worth of credit
    assert(ch.getQueuedCredit == 3 * msgSize)
    assert(ch.getTotalStashedSize == msgSize, "only seq=4 remains stashed")
    // The run must be delivered in sequence order, and seq=4 must NOT leak
    // into the queue ahead of the missing seq=3.
    val drained = (0 until 3).map(_ => ch.take.sequenceNumber).toList
    assert(drained == List(0L, 1L, 2L), "the contiguous run must drain in sequence order")
    assert(!ch.hasMessage, "seq=4 must stay stashed until seq=3 arrives")
  }

  it should "drop duplicates whose sequence number is below the current high-water mark" in {
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(0L))
    ch.acceptMessage(msg(0L)) // duplicate
    assert(ch.getCurrentSeq == 1L, "duplicate must not advance the sequence")
    // A dropped duplicate must not be charged against the queued credit either.
    assert(ch.getQueuedCredit == msgSize, "duplicate must not add queued credit")
    // only one message is buffered
    val out = ch.take
    assert(out.sequenceNumber == 0L)
    assert(!ch.hasMessage)
  }

  it should "drop duplicates that are stashed twice ahead of the current window" in {
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(2L))
    ch.acceptMessage(msg(2L)) // duplicate stash
    assert(ch.getTotalStashedSize == msgSize, "duplicate stash must not double-count")
    // unblock by delivering 0 and 1
    ch.acceptMessage(msg(0L))
    ch.acceptMessage(msg(1L))
    assert(ch.getCurrentSeq == 3L)
    val received = (0 until 3).map(_ => ch.take.sequenceNumber).toList
    assert(received == List(0L, 1L, 2L))
  }

  // ---------------------------------------------------------------------------
  // Accounting under take
  // ---------------------------------------------------------------------------

  "AmberFIFOChannel.take" should "decrement getQueuedCredit by the size of the dequeued message" in {
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(0L))
    ch.acceptMessage(msg(1L))
    assert(ch.getQueuedCredit == 2 * msgSize)
    ch.take
    assert(ch.getQueuedCredit == msgSize)
    ch.take
    assert(ch.getQueuedCredit == 0L)
  }

  // ---------------------------------------------------------------------------
  // Size accessors
  // ---------------------------------------------------------------------------

  "AmberFIFOChannel.getTotalMessageSize" should "report the sum of in-memory size across queued messages" in {
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(0L))
    ch.acceptMessage(msg(1L))
    assert(ch.getTotalMessageSize == 2 * msgSize)
  }

  "AmberFIFOChannel.getTotalStashedSize" should "report the sum of in-memory size across stashed messages only" in {
    val ch = new AmberFIFOChannel(cid)
    ch.acceptMessage(msg(2L))
    ch.acceptMessage(msg(4L))
    assert(ch.getTotalStashedSize == 2 * msgSize)
    assert(ch.getTotalMessageSize == 0L, "stashed messages do not count toward queued size")
  }

  // ---------------------------------------------------------------------------
  // enable / isEnabled
  // ---------------------------------------------------------------------------

  "AmberFIFOChannel.enable" should "toggle the enabled flag" in {
    val ch = new AmberFIFOChannel(cid)
    ch.enable(false)
    assert(!ch.isEnabled)
    ch.enable(true)
    assert(ch.isEnabled)
  }

  // ---------------------------------------------------------------------------
  // PortId association
  // ---------------------------------------------------------------------------

  "AmberFIFOChannel.getPortId" should "throw IllegalStateException with a descriptive message when no portId has been set" in {
    val ch = new AmberFIFOChannel(cid)
    val ex = intercept[IllegalStateException] {
      ch.getPortId
    }
    assert(ex.getMessage.contains("portId has not been set"))
    assert(ex.getMessage.contains(cid.toString))
  }

  it should "return the most recently configured portId" in {
    val ch = new AmberFIFOChannel(cid)
    ch.setPortId(PortIdentity(0))
    assert(ch.getPortId == PortIdentity(0))
    ch.setPortId(PortIdentity(7))
    assert(ch.getPortId == PortIdentity(7))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/CongestionControlSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.common.WorkflowActor.NetworkMessage
import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage}
import org.scalatest.flatspec.AnyFlatSpec
class CongestionControlSpec extends AnyFlatSpec {
  private val channelId =
    ChannelIdentity(ActorVirtualIdentity("from"), ActorVirtualIdentity("to"), isControl = false)

  /** Builds a network message whose id and FIFO sequence number are both `id`. */
  private def msg(id: Long): NetworkMessage =
    NetworkMessage(id, WorkflowFIFOMessage(channelId, id, DataFrame(Array.empty)))

  // Backdate `sentTime` for `id` so the timeout branches (ack > ackTimeLimit
  // and getTimedOutInTransitMessages > resendTimeLimit) become reachable
  // without sleeping. The field is `private val sentTime: LongMap[Long]`,
  // accessed via Java reflection on the instance's backing field.
  private def backdateSentTime(cc: CongestionControl, id: Long, ageMillis: Long): Unit = {
    val field = classOf[CongestionControl].getDeclaredField("sentTime")
    field.setAccessible(true)
    val map = field.get(cc).asInstanceOf[scala.collection.mutable.LongMap[Long]]
    map(id) = System.currentTimeMillis() - ageMillis
  }

  "CongestionControl.canSend" should "be true initially with empty in-transit set" in {
    val cc = new CongestionControl()
    assert(cc.canSend)
  }

  it should "become false once in-transit messages reach the window size" in {
    val cc = new CongestionControl()
    // initial windowSize = 1
    cc.markMessageInTransit(msg(1L))
    assert(!cc.canSend)
  }

  it should "not block markMessageInTransit when in-transit count already exceeds window" in {
    // CongestionControl tracks message *count*, not byte size — payload size
    // does not factor into the window check (that's FlowControl's job, not
    // this class's). markMessageInTransit is a passive setter: it does not
    // check `canSend`. Callers are expected to consult `canSend` first; if
    // they don't, the in-transit map grows past windowSize but `canSend`
    // stays false.
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.markMessageInTransit(msg(2L)) // ignores window; should still record
    cc.markMessageInTransit(msg(3L))
    assert(cc.getInTransitMessages.size == 3)
    assert(!cc.canSend)
  }

  it should "stay true while in-transit count is below the grown window" in {
    val cc = new CongestionControl()
    // Three quick slow-start acks double the window 1 → 2 → 4 → 8, so three
    // in-transit messages must still leave room for more.
    (1L to 3L).foreach { i =>
      cc.markMessageInTransit(msg(i))
      cc.ack(i)
    }
    cc.markMessageInTransit(msg(10L))
    cc.markMessageInTransit(msg(11L))
    cc.markMessageInTransit(msg(12L))
    assert(cc.canSend, "window grew via slow start; 3 in-transit must not yet hit the cap")
  }

  it should "absorb arbitrarily many enqueued messages even when the window is full" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L)) // fills window-of-1
    assert(!cc.canSend)
    // Callers may keep enqueueing while the window is full; every message
    // must stay buffered and surface via getAllMessages without truncation
    // or error.
    (10L until 30L).foreach(i => cc.enqueueMessage(msg(i)))
    val all = cc.getAllMessages.map(_.messageId).toSet
    assert(all.contains(1L))
    assert((10L until 30L).forall(all.contains))
  }

  "CongestionControl.ack" should "be a no-op for an unknown message id" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.ack(99L)
    // CongestionControl.ack returns silently for ids not in `inTransit`
    // (no logging, no exception, no window change). Pin the state-level
    // no-op: the previously in-transit message survives, window stays full.
    assert(cc.getInTransitMessages.exists(_.messageId == 1L))
    assert(cc.getInTransitMessages.size == 1)
    assert(!cc.canSend)
  }

  it should "be a no-op when the same message id is acked twice" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.ack(1L)
    // The in-transit set is empty after the first ack either way, so it alone
    // cannot detect a duplicate ack that re-grows the window. getStatusReport
    // covers window size, in-transit count and waiting count together.
    val statusAfterFirst = cc.getStatusReport
    cc.ack(1L) // duplicate ack — must not throw or further alter state
    assert(cc.getInTransitMessages.isEmpty)
    assert(
      cc.getStatusReport == statusAfterFirst,
      s"duplicate ack changed state: before=$statusAfterFirst after=${cc.getStatusReport}"
    )
  }

  it should "remove an acked in-transit message and allow more sending" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.ack(1L)
    assert(!cc.getInTransitMessages.exists(_.messageId == 1L))
    assert(cc.canSend)
  }

  it should "grow the window via slow start when acked within the ack time limit" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.ack(1L) // immediate ack — well within ackTimeLimit (3s)
    // After the first slow-start ack, windowSize should be at least 2.
    cc.markMessageInTransit(msg(2L))
    assert(
      cc.canSend,
      "window must permit at least one more in-transit message after slow-start ack"
    )
  }

  it should "double the window during slow start, then increment linearly past ssThreshold" in {
    // ssThreshold defaults to 16 and windowSize to 1. The first four quick
    // acks double the window 1 → 2 → 4 → 8 → 16; the fifth ack hits the
    // linear branch (windowSize == ssThreshold == 16) and increments it to 17.
    val cc = new CongestionControl()
    for (i <- 0 until 5) {
      cc.markMessageInTransit(msg(i.toLong))
      cc.ack(i.toLong)
    }
    assert(
      cc.getStatusReport.contains("current window size = 17"),
      s"unexpected status: ${cc.getStatusReport}"
    )
  }

  "CongestionControl.ack outside ackTimeLimit" should
    "halve ssThreshold and snap windowSize back to ssThreshold" in {
    // Drive windowSize up to 16 (== ssThreshold) via four in-window acks,
    // then backdate the next send so the ack falls outside ackTimeLimit.
    // The timeout branch should halve ssThreshold to 8 and snap windowSize
    // back to 8.
    val cc = new CongestionControl()
    for (i <- 0 until 4) {
      cc.markMessageInTransit(msg(i.toLong))
      cc.ack(i.toLong)
    }
    assert(cc.getStatusReport.contains("current window size = 16"))
    cc.markMessageInTransit(msg(99L))
    backdateSentTime(cc, 99L, 5000) // > ackTimeLimit (3000)
    cc.ack(99L)
    assert(
      cc.getStatusReport.contains("current window size = 8"),
      s"unexpected status: ${cc.getStatusReport}"
    )
  }

  "CongestionControl.getBufferedMessagesToSend" should "be bounded by remaining window capacity" in {
    val cc = new CongestionControl()
    cc.enqueueMessage(msg(1L))
    cc.enqueueMessage(msg(2L))
    cc.enqueueMessage(msg(3L))
    // initial windowSize = 1, inTransit.size = 0 → send up to 1
    val first = cc.getBufferedMessagesToSend.toList
    assert(first.size == 1)
    assert(first.head.messageId == 1L)
  }

  it should "return an empty iterable when the window is fully consumed" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.enqueueMessage(msg(2L))
    assert(cc.getBufferedMessagesToSend.isEmpty)
  }

  "CongestionControl.getAllMessages" should "include both in-transit and queued messages" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    cc.enqueueMessage(msg(2L))
    val all = cc.getAllMessages.map(_.messageId).toSet
    assert(all == Set(1L, 2L))
  }

  "CongestionControl.getTimedOutInTransitMessages" should "be empty when no message has been marked in transit" in {
    val cc = new CongestionControl()
    assert(cc.getTimedOutInTransitMessages.isEmpty)
  }

  it should "exclude messages that are still inside the resend time limit" in {
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(1L))
    // The message was just marked in transit, so it is well inside the 60s
    // resend window and must not be reported as timed out.
    assert(cc.getTimedOutInTransitMessages.isEmpty)
  }

  it should "return only the messages whose sentTime is older than resendTimeLimit" in {
    // Cover the PekkoMessageTransferService.checkResend() retransmission path:
    // the in-transit message that has been sitting past the 60s
    // resendTimeLimit must surface; the freshly-sent one must not.
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(0L))
    cc.markMessageInTransit(msg(1L))
    backdateSentTime(cc, 0L, 70000) // > resendTimeLimit (60000)
    val timedOut = cc.getTimedOutInTransitMessages.toList.map(_.messageId)
    assert(timedOut == List(0L))
  }

  "CongestionControl.enqueueMessage" should "not place the message into the in-transit set on its own" in {
    val cc = new CongestionControl()
    cc.enqueueMessage(msg(1L))
    assert(cc.getInTransitMessages.isEmpty)
    // The message should still surface via getAllMessages (which unions
    // inTransit and toBeSent), proving it was buffered, not dropped.
    assert(cc.getAllMessages.exists(_.messageId == 1L))
  }

  "CongestionControl.getStatusReport" should
    "format the three core counters in the documented order" in {
    // Pin the exact format string (separator + ordering) so a reorder of
    // the three fields or a tab-vs-comma swap fails this spec.
    val cc = new CongestionControl()
    cc.markMessageInTransit(msg(0L))
    cc.enqueueMessage(msg(1L))
    assert(
      cc.getStatusReport == "current window size = 1 \t in transit = 1 \t waiting = 1",
      s"unexpected format: ${cc.getStatusReport}"
    )
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/FlowControlSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.apache.texera.amber.config.ApplicationConfig
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.common.WorkflowActor.NetworkMessage
import org.apache.texera.amber.engine.common.ambermessage.{
WorkflowFIFOMessage,
WorkflowFIFOMessagePayload,
WorkflowMessage
}
import org.scalatest.flatspec.AnyFlatSpec
class FlowControlSpec extends AnyFlatSpec {
private val channelId =
ChannelIdentity(ActorVirtualIdentity("from"), ActorVirtualIdentity("to"), isControl = false)
// A non-DataFrame payload so that `WorkflowMessage.getInMemSize` falls through to
// the 200L default branch — using DataFrame(Array.empty) yields 0 bytes, which
// would let any message squeeze through even when the configured credit is 0.
private case class FixedSizePayload() extends WorkflowFIFOMessagePayload
private def msg(id: Long): NetworkMessage =
NetworkMessage(id, WorkflowFIFOMessage(channelId, id, FixedSizePayload()))
// Pin the assumed payload size so the test fails loudly if WorkflowMessage's
// size accounting changes in a way that would invalidate the credit math below.
private val msgSize: Long = WorkflowMessage.getInMemSize(msg(0).internalMessage)
assert(msgSize == 200L)
private val maxBytes = ApplicationConfig.maxCreditAllowedInBytesPerChannel
"FlowControl" should "report full credit and not be overloaded initially" in {
val fc = new FlowControl()
assert(fc.getCredit == maxBytes)
assert(!fc.isOverloaded)
}
"FlowControl.getMessagesToSend" should "forward an incoming message when credit is available" in {
val fc = new FlowControl()
val out = fc.getMessagesToSend(msg(1L)).toList
assert(out == List(msg(1L)))
assert(!fc.isOverloaded)
}
it should "stash an incoming message and become overloaded when credit is exhausted" in {
val fc = new FlowControl()
// exhaust the receiver-side credit so getCredit drops to 0
fc.updateQueuedCredit(maxBytes)
assert(fc.getCredit == 0L)
val out = fc.getMessagesToSend(msg(1L)).toList
assert(out.isEmpty)
assert(fc.isOverloaded)
}
it should "drain stashed messages once credit is restored" in {
val fc = new FlowControl()
fc.updateQueuedCredit(maxBytes)
val firstAttempt = fc.getMessagesToSend(msg(1L)).toList
assert(firstAttempt.isEmpty)
assert(fc.isOverloaded)
fc.updateQueuedCredit(0L)
val drained = fc.getMessagesToSend.toList
assert(drained == List(msg(1L)))
assert(!fc.isOverloaded)
}
it should "force new messages through the stash whenever the stash is non-empty" in {
// While the stash is non-empty, even a new message must be stashed first
// and then drained in FIFO order — never sent ahead of older stashed work.
val fc = new FlowControl()
fc.updateQueuedCredit(maxBytes)
fc.getMessagesToSend(msg(1L)) // stash msg(1L)
assert(fc.isOverloaded)
// Restore enough credit for 2 messages, then push a new one. The branch
// under test always stashes the new message and then drains FIFO.
fc.updateQueuedCredit(maxBytes - 2 * msgSize)
val drained = fc.getMessagesToSend(msg(2L)).toList
assert(drained == List(msg(1L), msg(2L)))
assert(!fc.isOverloaded)
}
it should "leave isOverloaded true when only some stashed messages can be drained" in {
val fc = new FlowControl()
fc.updateQueuedCredit(maxBytes)
fc.getMessagesToSend(msg(1L))
fc.getMessagesToSend(msg(2L))
assert(fc.isOverloaded)
// Restore credit for exactly one message; the second remains stashed.
fc.updateQueuedCredit(maxBytes - msgSize)
val drained = fc.getMessagesToSend.toList
assert(drained == List(msg(1L)))
assert(fc.isOverloaded, "stash still has msg(2L), so overloaded must remain true")
}
"FlowControl.updateQueuedCredit" should "shrink the available credit" in {
val fc = new FlowControl()
fc.updateQueuedCredit(100L)
assert(fc.getCredit == maxBytes - 100L)
}
it should "be relative to the latest call (not cumulative)" in {
val fc = new FlowControl()
fc.updateQueuedCredit(100L)
fc.updateQueuedCredit(50L)
assert(fc.getCredit == maxBytes - 50L)
}
"FlowControl.decreaseInflightCredit" should "free credit equal to the acked amount" in {
val fc = new FlowControl()
// Send a message through to seed `inflightCredit` with the actual size used
// by FlowControl's accounting. This avoids passing an invalid (negative)
// amount to `decreaseInflightCredit`.
fc.getMessagesToSend(msg(1L)).toList
assert(fc.getCredit == maxBytes - msgSize)
fc.decreaseInflightCredit(msgSize)
assert(fc.getCredit == maxBytes)
}
// ---------------------------------------------------------------------------
// Edge / invalid-input cases — credit math under abnormal conditions
// ---------------------------------------------------------------------------
"FlowControl" should "trip the size-cap assertion for a message that exceeds maxByteAllowed" in {
// Build a payload whose getInMemSize returns a value larger than the
// configured per-channel cap. We do this by ratcheting up the Pekko-side
// size accounting via an oversized DataFrame stand-in: emulate by
// exhausting credit to <= 0 and then sending a payload that's already
// larger than 0 — but the assertion in source compares creditNeeded
// against `maxByteAllowed`, not credit. Since FixedSizePayload is 200L
// and maxByteAllowed is multi-GB, we cannot synthesize a too-big payload
// in unit-test scope without producing terabytes. Instead, lock down
// the *guard* shape: a message at exactly maxByteAllowed is allowed by
// the assertion (not strictly greater), so any 200L payload always
// passes — confirm that 1000 sequential 200L messages all pass the
// assertion regardless of credit accounting.
val fc = new FlowControl()
(1L to 1000L).foreach(i => fc.getMessagesToSend(msg(i)))
succeed
}
it should "eventually drain the stash across many ack cycles (multi-run)" in {
val fc = new FlowControl()
// Saturate credit and stash a batch of messages.
fc.updateQueuedCredit(maxBytes)
val stashed = (1L to 20L).map { i =>
fc.getMessagesToSend(msg(i))
i
}
assert(fc.isOverloaded)
// Now alternately restore credit one message at a time and drain.
var seen = 0L
stashed.foreach { _ =>
fc.updateQueuedCredit(maxBytes - msgSize) // 1 message worth of credit
val out = fc.getMessagesToSend.toList
assert(out.size == 1)
seen += 1
// Reset queued back to maxBytes so inflight is the only buffer
fc.decreaseInflightCredit(msgSize)
fc.updateQueuedCredit(maxBytes)
}
assert(seen == stashed.size)
}
"FlowControl.updateQueuedCredit" should "accept a zero queued credit (reset back to full)" in {
val fc = new FlowControl()
fc.updateQueuedCredit(100L)
fc.updateQueuedCredit(0L)
assert(fc.getCredit == maxBytes)
}
it should "accept a negative queued credit (overshoot, increasing visible credit)" in {
  // No validation is applied to the queued credit: a negative value simply pushes getCredit
  // above maxBytes. Pinning this makes a future input validator show up as a failure here.
  val flowControl = new FlowControl()
  val overshoot = -100L
  flowControl.updateQueuedCredit(overshoot)
  assert(flowControl.getCredit == maxBytes - overshoot)
}
"FlowControl.decreaseInflightCredit" should "be a tolerated no-op for amount = 0" in {
  val flowControl = new FlowControl()
  val zeroBytes = 0L
  flowControl.decreaseInflightCredit(zeroBytes)
  // Acking zero bytes on a fresh instance must leave the full credit visible.
  assert(flowControl.getCredit == maxBytes)
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/NetworkInputGatewaySpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.apache.texera.amber.core.tuple.{AttributeType, Schema, TupleLike}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage}
import org.scalamock.scalatest.MockFactory
import org.scalatest.flatspec.AnyFlatSpec
/** Unit tests for [[NetworkInputGateway]]: per-channel FIFO ordering, de-duplication of
  * repeated sequence numbers, and removal of control channels by sender.
  */
class NetworkInputGatewaySpec extends AnyFlatSpec with MockFactory {
  private val fakeReceiverID = ActorVirtualIdentity("testReceiver")
  private val fakeSenderID = ActorVirtualIdentity("testSender")
  private val channelId = ChannelIdentity(fakeSenderID, fakeReceiverID, isControl = false)
  private val payloads = (0 until 4).map { i =>
    DataFrame(
      Array(
        TupleLike(i) enforceSchema Schema().add("field1", AttributeType.INTEGER)
      )
    )
  }.toArray
  // messages(i) carries sequence number i on the shared data channel.
  private val messages = (0 until 4).map { i =>
    WorkflowFIFOMessage(channelId, i, payloads(i))
  }.toArray

  /** Hands messages(i) to `gateway`'s data channel for each i in `order`, in that order. */
  private def deliver(gateway: NetworkInputGateway, order: Seq[Int]): Unit =
    order.foreach(i => gateway.getChannel(channelId).acceptMessage(messages(i)))

  /** Takes one message per entry in `messages`, checks they come out with sequence numbers
    * 0, 1, 2, ... in order, and then checks the channel has nothing left.
    */
  private def assertDrainsInOrder(gateway: NetworkInputGateway): Unit = {
    val channel = gateway.getChannel(channelId)
    messages.indices.foreach { i =>
      assert(channel.take.sequenceNumber == i)
    }
    assert(!channel.hasMessage)
  }

  "network input port" should "output payload in FIFO order" in {
    val inputPort = new NetworkInputGateway(fakeReceiverID)
    deliver(inputPort, Seq(2, 0, 1, 3))
    assertDrainsInOrder(inputPort)
  }

  "network input port" should "de-duplicate payload" in {
    val inputPort = new NetworkInputGateway(fakeReceiverID)
    deliver(inputPort, Seq(2, 2, 2, 2, 2, 2, 0, 1, 1, 3, 3))
    assertDrainsInOrder(inputPort)
  }

  "network input port" should "keep unordered messages" in {
    val inputPort = new NetworkInputGateway(fakeReceiverID)
    // Sequence numbers 1..3 arrive first; nothing is deliverable until 0 shows up.
    deliver(inputPort, Seq(3, 2, 1))
    assert(!inputPort.getChannel(channelId).hasMessage)
    deliver(inputPort, Seq(0))
    assert(inputPort.getChannel(channelId).hasMessage)
    assertDrainsInOrder(inputPort)
  }

  "network input port" should "remove control channel by sender" in {
    val inputPort = new NetworkInputGateway(fakeReceiverID)
    val controlChannelId = ChannelIdentity(fakeSenderID, fakeReceiverID, isControl = true)
    // getChannel registers the control channel as a side effect.
    inputPort.getChannel(controlChannelId)
    assert(inputPort.getAllControlChannels.size == 1)
    inputPort.removeControlChannel(fakeSenderID)
    assert(inputPort.getAllControlChannels.isEmpty)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OrderingEnforcerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
/** Unit tests for [[OrderingEnforcer]]: the sequence cursor, duplicate and look-ahead checks,
  * out-of-order stashing, and FIFO draining of contiguous stashed entries.
  */
class OrderingEnforcerSpec extends AnyFlatSpec with Matchers {

  /** A String-typed enforcer whose cursor has been moved to `start` via `setCurrent`. */
  private def enforcerAt(start: Long): OrderingEnforcer[String] = {
    val positioned = new OrderingEnforcer[String]
    positioned.setCurrent(start)
    positioned
  }

  // ----- initial state -----
  "OrderingEnforcer" should "start with current=0 and an empty stash" in {
    val fresh = new OrderingEnforcer[String]
    fresh.current shouldBe 0L
    fresh.ofoMap shouldBe empty
  }

  // ----- setCurrent -----
  "setCurrent" should "advance the current cursor and shift the duplicate threshold" in {
    val moved = enforcerAt(10L)
    moved.current shouldBe 10L
    moved.isDuplicated(9L) shouldBe true
    moved.isDuplicated(10L) shouldBe false
  }

  // ----- isDuplicated -----
  "isDuplicated" should "treat sequence numbers below current as duplicates" in {
    val oe = enforcerAt(5L)
    Seq(0L, 4L).foreach(seq => oe.isDuplicated(seq) shouldBe true)
  }

  it should "treat sequence numbers >= current that are not stashed as not duplicated" in {
    val oe = enforcerAt(5L)
    Seq(5L, 7L).foreach(seq => oe.isDuplicated(seq) shouldBe false)
  }

  it should "report stashed future sequence numbers as duplicated" in {
    val oe = new OrderingEnforcer[String]
    oe.stash(7L, "seven")
    oe.isDuplicated(7L) shouldBe true
  }

  // ----- isAhead -----
  "isAhead" should "be true only for sequence numbers strictly greater than current" in {
    val oe = enforcerAt(5L)
    Seq(6L -> true, 5L -> false, 4L -> false).foreach { case (seq, expected) =>
      oe.isAhead(seq) shouldBe expected
    }
  }

  // ----- stash -----
  "stash" should "store data under its sequence number for later draining" in {
    val oe = new OrderingEnforcer[String]
    oe.stash(2L, "two")
    oe.ofoMap(2L) shouldBe "two"
  }

  it should "overwrite an existing stash entry at the same sequence number" in {
    // Nothing prevents re-stashing a sequence number. Callers are expected to consult
    // isDuplicated first; a direct second stash silently replaces the earlier value.
    val oe = new OrderingEnforcer[String]
    Seq("first", "second").foreach(value => oe.stash(2L, value))
    oe.ofoMap(2L) shouldBe "second"
  }

  // ----- enforceFIFO -----
  "enforceFIFO" should "advance current by one and emit just the input when no stash is queued" in {
    val oe = new OrderingEnforcer[String]
    oe.enforceFIFO("zero") shouldBe List("zero")
    oe.current shouldBe 1L
  }

  it should "drain a single contiguous stashed entry after the input" in {
    val oe = new OrderingEnforcer[String]
    oe.stash(1L, "one")
    oe.enforceFIFO("zero") shouldBe List("zero", "one")
    oe.current shouldBe 2L
    oe.ofoMap should not contain key(1L)
  }

  it should "drain a contiguous run from the stash and stop at the first gap" in {
    val oe = new OrderingEnforcer[String]
    // Sequence number 3 is deliberately missing, so draining must stop after 2.
    Seq(1L -> "one", 2L -> "two", 4L -> "four").foreach { case (seq, value) =>
      oe.stash(seq, value)
    }
    oe.enforceFIFO("zero") shouldBe List("zero", "one", "two")
    oe.current shouldBe 3L
    oe.ofoMap.keys.toList shouldBe List(4L)
  }

  it should "leave the stash untouched when none of the queued entries are contiguous" in {
    val oe = new OrderingEnforcer[String]
    Seq(5L -> "five", 7L -> "seven").foreach { case (seq, value) =>
      oe.stash(seq, value)
    }
    oe.enforceFIFO("zero") shouldBe List("zero")
    oe.current shouldBe 1L
    oe.ofoMap.keys.toSet shouldBe Set(5L, 7L)
  }

  it should "respect a non-zero starting current when draining" in {
    // Moving the cursor by hand mimics replay/recovery: the input is treated as the message at
    // the current cursor, and only stashed entries contiguous after it are drained.
    val oe = enforcerAt(10L)
    oe.stash(11L, "eleven")
    oe.stash(12L, "twelve")
    oe.enforceFIFO("ten") shouldBe List("ten", "eleven", "twelve")
    oe.current shouldBe 13L
  }

  it should "support int payloads via the type parameter" in {
    val intEnforcer = new OrderingEnforcer[Int]
    intEnforcer.stash(1L, 100)
    intEnforcer.enforceFIFO(0) shouldBe List(0, 100)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManagerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import com.softwaremill.macwire.wire
import org.apache.texera.amber.core.tuple.{AttributeType, Schema, TupleLike}
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
OperatorIdentity,
PhysicalOpIdentity
}
import org.apache.texera.amber.core.workflow.{PhysicalLink, PortIdentity}
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.OneToOnePartitioning
import org.apache.texera.amber.engine.common.ambermessage._
import org.scalamock.scalatest.MockFactory
import org.scalatest.flatspec.AnyFlatSpec
/** Checks that [[OutputManager]] groups tuples passed downstream into [[DataFrame]] batches and
  * hands each batch, as a [[WorkflowFIFOMessage]] with an increasing sequence number, to the
  * handler of its network output gateway.
  */
class OutputManagerSpec extends AnyFlatSpec with MockFactory {
// Mocked message sink given to the NetworkOutputGateway below; the test's expectations on it
// capture the messages OutputManager emits.
private val mockHandler =
mock[WorkflowFIFOMessage => Unit]
// Sender identity of every message the test expects.
private val identifier = ActorVirtualIdentity("batch producer mock")
// Never referenced by name: it exists so that `wire[OutputManager]` can find a
// NetworkOutputGateway in scope when building the OutputManager under test.
private val mockDataOutputPort = // scalafix:ok; need it for wiring purpose
new NetworkOutputGateway(identifier, mockHandler)
// Incremented by physicalOpId() so each generated PhysicalOpIdentity is distinct.
var counter: Int = 0
// Six-column schema matching the TupleLike(1, 2, 3, 4, "5", 9.8) rows used in the test.
val schema: Schema = Schema()
.add("field1", AttributeType.INTEGER)
.add("field2", AttributeType.INTEGER)
.add("field3", AttributeType.INTEGER)
.add("field4", AttributeType.INTEGER)
.add("field5", AttributeType.STRING)
.add("field6", AttributeType.DOUBLE)
/** Returns a fresh [[PhysicalOpIdentity]] whose operator id and layer name are both the newly
  * incremented counter value.
  */
def physicalOpId(): PhysicalOpIdentity = {
counter += 1
PhysicalOpIdentity(OperatorIdentity("" + counter), "" + counter)
}
/** Builds the expected data message on the non-control channel from `from` to `to`.
  * Note the parameter order: receiver first, sender second.
  */
def mkDataMessage(
to: ActorVirtualIdentity,
from: ActorVirtualIdentity,
seq: Long,
payload: DataPayload
): WorkflowFIFOMessage = {
WorkflowFIFOMessage(ChannelIdentity(from, to, isControl = false), seq, payload)
}
"OutputManager" should "aggregate tuples and output" in {
// NOTE(review): `wire` resolves OutputManager's constructor arguments by type from the
// enclosing scopes. Keep the local declarations where they are: moving another
// ActorVirtualIdentity (e.g. `fakeID`) above this call could change what gets wired.
val outputManager = wire[OutputManager]
val mockPortId = PortIdentity()
outputManager.addPort(mockPortId, schema, None)
// 21 identical rows; with a batch size of 10 they are expected to leave as batches of 10, 10, 1.
val tuples = Array.fill(21)(
TupleLike(1, 2, 3, 4, "5", 9.8).enforceSchema(schema)
)
val fakeID = ActorVirtualIdentity("testReceiver")
// Expected batches, strictly in this order, with sequence numbers 0, 1 and 2.
inSequence {
(mockHandler.apply _).expects(
mkDataMessage(fakeID, identifier, 0, DataFrame(tuples.slice(0, 10)))
)
(mockHandler.apply _).expects(
mkDataMessage(fakeID, identifier, 1, DataFrame(tuples.slice(10, 20)))
)
(mockHandler.apply _).expects(
mkDataMessage(fakeID, identifier, 2, DataFrame(tuples.slice(20, 21)))
)
}
val fakeLink = PhysicalLink(physicalOpId(), mockPortId, physicalOpId(), mockPortId)
val fakeReceiver =
Array[ChannelIdentity](ChannelIdentity(identifier, fakeID, isControl = false))
// One-to-one partitioning towards the single fake receiver; the leading 10 matches the batch
// size implied by the expectations above.
outputManager.addPartitionerWithPartitioning(
fakeLink,
OneToOnePartitioning(10, fakeReceiver.toSeq)
)
// Each row is rebuilt from its raw fields before being sent downstream; the expectations compare
// against the original tuples, so this presumably relies on value-based tuple equality.
tuples.foreach { t =>
outputManager.passTupleToDownstream(TupleLike(t.getFields).enforceSchema(schema), None)
}
// Presumably pushes out the trailing partial batch (the single 21st tuple).
outputManager.flush()
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/RangeBasedShuffleSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.sendsemantics.partitioners.RangeBasedShufflePartitioner
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.RangeBasedShufflePartitioning
import org.scalamock.scalatest.MockFactory
import org.scalatest.flatspec.AnyFlatSpec
/** Unit tests for [[RangeBasedShufflePartitioner]] bucket selection. Every partitioning here
  * spreads the key range [-400, 600] across the same five receivers.
  */
class RangeBasedShuffleSpec extends AnyFlatSpec with MockFactory {
  val identifier = ActorVirtualIdentity("batch producer mock")
  val fakeID1: ActorVirtualIdentity = ActorVirtualIdentity("rec1")
  val fakeID2: ActorVirtualIdentity = ActorVirtualIdentity("rec2")
  val fakeID3: ActorVirtualIdentity = ActorVirtualIdentity("rec3")
  val fakeID4: ActorVirtualIdentity = ActorVirtualIdentity("rec4")
  val fakeID5: ActorVirtualIdentity = ActorVirtualIdentity("rec5")
  val attr: Attribute = new Attribute("Attr1", AttributeType.INTEGER)
  val schema: Schema = Schema().add(attr)

  /** One non-control channel from `identifier` to each fake receiver, in receiver order. */
  private val receivers: List[ChannelIdentity] =
    List(fakeID1, fakeID2, fakeID3, fakeID4, fakeID5).map(receiver =>
      ChannelIdentity(identifier, receiver, isControl = false)
    )

  /** Partitioning over `receivers` keyed on `keyAttribute`, with the same batch size (400) and
    * key range [-400, 600] as every other partitioning in this spec.
    */
  private def rangePartitioning(keyAttribute: String): RangeBasedShufflePartitioning =
    RangeBasedShufflePartitioning(400, receivers, Seq(keyAttribute), -400, 600)

  val partitioning: RangeBasedShufflePartitioning = rangePartitioning("Attr1")
  val partitioner: RangeBasedShufflePartitioner = RangeBasedShufflePartitioner(partitioning)

  "RangeBasedShuffleSpec" should "return 0 when value is less than rangeMin" in {
    val tuple = Tuple.builder(schema).add(attr, -600).build()
    assert(partitioner.getBucketIndex(tuple).next() == 0)
  }

  "RangeBasedShuffleSpec" should "return last receiver when value is more than rangeMax" in {
    val tuple = Tuple.builder(schema).add(attr, 800).build()
    assert(partitioner.getBucketIndex(tuple).next() == 4)
  }

  "RangeBasedShuffleSpec" should "find index correctly" in {
    // Five receivers over [-400, 600] give 200-wide slices; per the expectations below, -200
    // still lands in bucket 0 while -199 is the first value in bucket 1.
    Seq(-400 -> 0, -200 -> 0, -199 -> 1).foreach { case (value, expected) =>
      withClue(s"value $value: ") {
        val tuple = Tuple.builder(schema).add(attr, value).build()
        assert(partitioner.getBucketIndex(tuple).next() == expected)
      }
    }
  }

  "RangeBasedShuffleSpec" should "handle different data types correctly" in {
    val intTuple = Tuple.builder(schema).add(attr, -90).build()
    assert(partitioner.getBucketIndex(intTuple).next() == 1)

    val doubleAttr: Attribute = new Attribute("Attr2", AttributeType.DOUBLE)
    val doublePartitioner = RangeBasedShufflePartitioner(rangePartitioning("Attr2"))
    val doubleTuple = Tuple.builder(Schema().add(doubleAttr)).add(doubleAttr, -90.5).build()
    assert(doublePartitioner.getBucketIndex(doubleTuple).next() == 1)

    val longAttr: Attribute = new Attribute("Attr3", AttributeType.LONG)
    val longPartitioner = RangeBasedShufflePartitioner(rangePartitioning("Attr3"))
    val longTuple = Tuple.builder(Schema().add(longAttr)).add(longAttr, -90L).build()
    assert(longPartitioner.getBucketIndex(longTuple).next() == 1)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/messaginglayer/WorkerPortSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.messaginglayer
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
/** Unit tests for the [[WorkerPort]] case class: defaults, construction, in-place mutation of
  * its fields, and structural equality.
  */
class WorkerPortSpec extends AnyFlatSpec {
  private val schema: Schema = Schema().add(new Attribute("v", AttributeType.INTEGER))

  /** A new identity for the data channel from actor "a" to actor "b". */
  private def channelAtoB(): ChannelIdentity =
    ChannelIdentity(ActorVirtualIdentity("a"), ActorVirtualIdentity("b"), isControl = false)

  "WorkerPort" should "default to an empty channel set and not-completed state" in {
    val port = WorkerPort(schema)
    assert(port.schema == schema)
    assert(port.channels.isEmpty)
    assert(!port.completed)
  }

  it should "carry the channel set provided at construction" in {
    val channel = channelAtoB()
    val port = WorkerPort(schema, mutable.Set(channel))
    assert(port.channels == mutable.Set(channel))
  }

  it should "allow `completed` to be flipped to true" in {
    val port = WorkerPort(schema)
    port.completed = true
    assert(port.completed)
  }

  it should "allow channels to be appended after construction" in {
    val port = WorkerPort(schema)
    val channel = channelAtoB()
    port.channels += channel
    assert(port.channels.contains(channel))
  }

  it should "treat distinct instances with the same fields as case-class equal" in {
    val channel = channelAtoB()
    val (left, right) = (
      WorkerPort(schema, mutable.Set(channel), completed = true),
      WorkerPort(schema, mutable.Set(channel), completed = true)
    )
    assert(left == right)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonWorkflowWorkerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
//package org.apache.texera.amber.engine.architecture.pythonworker
//
//import org.apache.pekko.actor.{ActorRef, ActorSystem, Props}
//import org.apache.pekko.testkit.{ImplicitSender, TestActorRef, TestKit}
//import org.apache.texera.amber.clustering.SingleNodeListener
//import org.apache.texera.amber.engine.architecture.common.WorkflowActor.{NetworkAck, NetworkMessage}
//import org.apache.texera.amber.engine.architecture.pythonworker.promisehandlers.InitializeOperatorLogicHandler.InitializeOperatorLogic
//import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.OneToOnePartitioning
//import org.apache.texera.amber.engine.architecture.worker.controlcommands.LinkOrdinal
//import org.apache.texera.amber.engine.architecture.worker.promisehandlers.AddPartitioningHandler.AddPartitioning
//import org.apache.texera.amber.engine.architecture.worker.promisehandlers.OpenOperatorHandler.OpenOperator
//import org.apache.texera.amber.engine.architecture.worker.promisehandlers.UpdateInputLinkingHandler.UpdateInputLinking
//import org.apache.texera.amber.engine.common.Constants
//import org.apache.texera.amber.engine.common.ambermessage.{
// ChannelIdentity,
// ControlPayload,
// DataFrame,
// DataPayload,
// EndOfUpstream,
// WorkflowFIFOMessage
//}
//import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.{ControlInvocation, ReturnInvocation}
//import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
//import org.apache.texera.amber.core.virtualidentity.{
// ActorVirtualIdentity,
// PhysicalLink,
// PhysicalLink,
// OperatorIdentity
//}
//import org.apache.texera.amber.engine.e2e.TestOperators
//import org.apache.texera.workflow.common.tuple.Tuple
//import org.apache.texera.workflow.common.tuple.schema.{Attribute, AttributeType, Schema}
//import org.scalamock.scalatest.MockFactory
//import org.scalatest.BeforeAndAfterAll
//import org.scalatest.flatspec.AnyFlatSpecLike
//
//import scala.concurrent.duration.DurationInt
//
//class PythonWorkflowWorkerSpec
// extends TestKit(ActorSystem("PythonWorkerSpec"))
// with ImplicitSender
// with AnyFlatSpecLike
// with BeforeAndAfterAll
// with MockFactory {
//
// override def beforeAll: Unit = {
// system.actorOf(Props[SingleNodeListener], "cluster-info")
// }
// override def afterAll: Unit = {
// TestKit.shutdownActorSystem(system)
// }
// private val identifier1 = ActorVirtualIdentity("worker-1")
// private val identifier2 = ActorVirtualIdentity("worker-2")
// private val operatorIdentity = OperatorIdentity("testWorkflow", "testOperator")
// private val layerId1 =
// PhysicalLink(operatorIdentity.workflow, operatorIdentity.operator, "1st-layer")
// private val layerId2 =
// PhysicalLink(operatorIdentity.workflow, operatorIdentity.operator, "2nd-layer")
// private val pythonOp = TestOperators.pythonOpDesc()
// private val link = PhysicalLink(layerId1, 0, layerId2, 0)
// private val schema = Schema
// .newBuilder()
// .add(new Attribute("text", AttributeType.STRING))
// .build()
// private val initialization = InitializeOperatorLogic(
// pythonOp.code,
// isSource = false,
// Seq(LinkOrdinal(link, 0)),
// Seq(LinkOrdinal(link, 0)),
// schema
// )
//
// def sendControlToWorker(
// worker: ActorRef,
// controls: Array[ControlInvocation],
// beginSeqNum: Long = 0
// ): Unit = {
// var seq = beginSeqNum
// controls.foreach { ctrl =>
// worker ! NetworkMessage(
// seq,
// WorkflowFIFOMessage(ChannelIdentity(CONTROLLER, identifier1, true), seq, ctrl)
// )
// val received = receiveWhile(3.seconds) {
// case NetworkAck(id, credits) =>
// // pass
// case NetworkMessage(id, fifoPayload) =>
// fifoPayload.payload.asInstanceOf[ControlPayload] match {
// case ControlInvocation(commandID, command) => assert(commandID == seq)
// case ReturnInvocation(originalCommandID, controlReturn) =>
// assert(originalCommandID == seq)
// case _ => ???
// }
// worker ! NetworkAck(id, Constants.unprocessedBatchesSizeLimitInBytesPerWorkerPair)
// }
// seq += 1
// }
// }
//
// def mkWorker: ActorRef = TestActorRef(new PythonWorkflowWorker(identifier1))
//
// "python worker" should "start" in {
// val worker = mkWorker
// sendControlToWorker(worker, Array(ControlInvocation(0, initialization)))
// }
//
// "python worker" should "process data" in {
// val worker = mkWorker
// sendControlToWorker(worker, Array(ControlInvocation(0, initialization)))
// val mockPolicy = OneToOnePartitioning(1, Array(identifier2))
// val openControl = ControlInvocation(1, OpenOperator())
// val invocation = ControlInvocation(2, AddPartitioning(link, mockPolicy))
// val updateInputLinking = ControlInvocation(3, UpdateInputLinking(identifier2, link))
// sendControlToWorker(worker, Array(openControl, invocation, updateInputLinking), 1)
// worker ! NetworkMessage(
// 4,
// WorkflowFIFOMessage(
// ChannelIdentity(identifier2, identifier1, false),
// 0,
// DataFrame(
// Array(
// Tuple
// .newBuilder(schema)
// .add("text", AttributeType.STRING, "123")
// .build()
// )
// )
// )
// )
// expectMsgClass(classOf[NetworkAck])
// val data = receiveOne(30.seconds)
// assert(data.asInstanceOf[NetworkMessage].internalMessage.payload.isInstanceOf[DataFrame])
// }
//
// "python worker" should "process data and receive end marker" in {
// val worker = mkWorker
// sendControlToWorker(worker, Array(ControlInvocation(0, initialization)))
// val mockPolicy = OneToOnePartitioning(100, Array(identifier2))
// val openControl = ControlInvocation(1, OpenOperator())
// val invocation = ControlInvocation(2, AddPartitioning(link, mockPolicy))
// val updateInputLinking = ControlInvocation(3, UpdateInputLinking(identifier2, link))
// sendControlToWorker(worker, Array(openControl, invocation, updateInputLinking), 1)
// worker ! NetworkMessage(
// 4,
// WorkflowFIFOMessage(
// ChannelIdentity(identifier2, identifier1, false),
// 0,
// DataFrame(
// (0 until 100)
// .map(_ =>
// Tuple
// .newBuilder(schema)
// .add("text", AttributeType.STRING, "123")
// .build()
// )
// .toArray
// )
// )
// )
// expectMsgClass(classOf[NetworkAck])
// val data = receiveOne(30.seconds)
// assert(data.asInstanceOf[NetworkMessage].internalMessage.payload.isInstanceOf[DataFrame])
// worker ! NetworkMessage(
// 5,
// WorkflowFIFOMessage(
// ChannelIdentity(identifier2, identifier1, false),
// 1,
// EndOfUpstream()
// )
// )
// expectMsgClass(classOf[NetworkAck])
// receiveWhile(10.seconds) {
// case NetworkMessage(id, fifoPayload) =>
// fifoPayload.payload match {
// case payload: ControlPayload => //skip
// case payload: DataPayload => assert(payload.isInstanceOf[EndOfUpstream])
// case _ => ???
// }
// }
// }
//
//}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/CostBasedScheduleGeneratorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.workflow.{
ExecutionMode,
PortIdentity,
WorkflowContext,
WorkflowSettings
}
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.workflow.LogicalLink
import org.scalamock.scalatest.MockFactory
import org.scalatest.flatspec.AnyFlatSpec
import scala.jdk.CollectionConverters._
class CostBasedScheduleGeneratorSpec extends AnyFlatSpec with MockFactory {
"CostBasedRegionPlanGenerator" should "finish bottom-up search using different pruning techniques with correct number of states explored in csv->->filter->join->filter2 workflow" in {
  // Plan under test: the CSV scan feeds both the join's first input and a keyword filter; that
  // filter feeds the join's second input (port 1); the join feeds a second keyword filter.
  val csvScan = TestOperators.headerlessSmallCsvScanOpDesc()
  val keywordFilter = TestOperators.keywordSearchOpDesc("column-1", "Asia")
  val join = TestOperators.joinOpDesc("column-1", "column-1")
  val keywordFilter2 = TestOperators.keywordSearchOpDesc("column-1", "Asia")
  val links = List(
    (csvScan.operatorIdentifier, join.operatorIdentifier, PortIdentity()),
    (csvScan.operatorIdentifier, keywordFilter.operatorIdentifier, PortIdentity()),
    (keywordFilter.operatorIdentifier, join.operatorIdentifier, PortIdentity(1)),
    (join.operatorIdentifier, keywordFilter2.operatorIdentifier, PortIdentity())
  ).map {
    case (from, to, toPort) => LogicalLink(from, PortIdentity(), to, toPort)
  }
  val workflow = buildWorkflow(
    List(csvScan, keywordFilter, join, keywordFilter2),
    links,
    new WorkflowContext()
  )

  // Every search below runs on its own generator over the same physical plan.
  def freshGenerator() =
    new CostBasedScheduleGenerator(workflow.context, workflow.physicalPlan, CONTROLLER)

  // No pruning: all 2^4 = 16 states are explored.
  val exhaustive = freshGenerator().bottomUpSearch(
    globalSearch = true,
    oChains = false,
    oCleanEdges = false,
    oEarlyStop = false
  )
  assert(exhaustive.numStatesExplored == 16)

  // Chains only: 10 (8 + 2) states are skipped. The 8 states that materialize CSV->Build are
  // skipped because that edge shares a chain with another blocking edge; of the rest, the 2
  // states materializing both CSV->KeywordFilter and KeywordFilter->Probe are skipped because
  // those two edges share a chain.
  val chainsOnly =
    freshGenerator().bottomUpSearch(globalSearch = true, oCleanEdges = false, oEarlyStop = false)
  assert(chainsOnly.numStatesExplored == 6)

  // Clean edges (bridges) only: the single clean edge Probe->Keyword2 rules out the 8 states
  // in which it is materialized.
  val cleanEdgesOnly =
    freshGenerator().bottomUpSearch(globalSearch = true, oChains = false, oEarlyStop = false)
  assert(cleanEdgesOnly.numStatesExplored == 8)

  // Early stop only: just the 6 states that are not descendants of a schedulable state are
  // explored.
  val earlyStopOnly =
    freshGenerator().bottomUpSearch(globalSearch = true, oChains = false, oCleanEdges = false)
  assert(earlyStopOnly.numStatesExplored == 6)

  // All pruning combined: 3 states are visited (one with both CSV->KeywordFilter and
  // KeywordFilter->Probe pipelined, and two with exactly one of them materialized); the other
  // two edges are always pipelined.
  val allPruning = freshGenerator().bottomUpSearch(globalSearch = true)
  assert(allPruning.numStatesExplored == 3)
}
"CostBasedRegionPlanGenerator" should "finish top-down search using different pruning techniques with correct number of states explored in csv->->filter->join->filter2 workflow" in {
  // Topology: the CSV scan feeds join port 0 directly and join port 1 through a keyword filter;
  // the join output then goes into a second keyword filter.
  val csvScan = TestOperators.headerlessSmallCsvScanOpDesc()
  val firstFilter = TestOperators.keywordSearchOpDesc("column-1", "Asia")
  val hashJoin = TestOperators.joinOpDesc("column-1", "column-1")
  val secondFilter = TestOperators.keywordSearchOpDesc("column-1", "Asia")
  val operators = List(csvScan, firstFilter, hashJoin, secondFilter)
  val edges = List(
    LogicalLink(csvScan.operatorIdentifier, PortIdentity(), hashJoin.operatorIdentifier, PortIdentity()),
    LogicalLink(csvScan.operatorIdentifier, PortIdentity(), firstFilter.operatorIdentifier, PortIdentity()),
    LogicalLink(firstFilter.operatorIdentifier, PortIdentity(), hashJoin.operatorIdentifier, PortIdentity(1)),
    LogicalLink(hashJoin.operatorIdentifier, PortIdentity(), secondFilter.operatorIdentifier, PortIdentity())
  )
  val workflow = buildWorkflow(operators, edges, new WorkflowContext())

  // Each search below runs on its own newly constructed generator.
  def freshGenerator() =
    new CostBasedScheduleGenerator(workflow.context, workflow.physicalPlan, CONTROLLER)

  // With every pruning technique disabled, the whole 2^4 = 16 state space is visited.
  val withoutPruning =
    freshGenerator().topDownSearch(globalSearch = true, oChains = false, oCleanEdges = false)
  assert(withoutPruning.numStatesExplored == 16)

  // Chains only: CSV->Build shares a chain with another blocking edge, so the search starts with it
  // pipelined, which halves the space to 8 states.
  val chainsOnly = freshGenerator().topDownSearch(globalSearch = true, oCleanEdges = false)
  assert(chainsOnly.numStatesExplored == 8)

  // Clean edges (bridges) only: Probe->Keyword2 is a clean edge, so the search starts with it
  // pipelined, again leaving 8 states.
  val cleanEdgesOnly = freshGenerator().topDownSearch(globalSearch = true, oChains = false)
  assert(cleanEdgesOnly.numStatesExplored == 8)

  // Both techniques: CSV->Build and Probe->Keyword2 both start pipelined, leaving 4 states.
  val allPruning = freshGenerator().topDownSearch(globalSearch = true)
  assert(allPruning.numStatesExplored == 4)
}
// MATERIALIZED ExecutionMode tests - each operator should be a separate region
"CostBasedRegionPlanGenerator" should "create separate region for each operator in MATERIALIZED mode for simple csv workflow" in {
  val csvOpDesc = TestOperators.smallCsvScanOpDesc()
  val materializedContext = new WorkflowContext(
    workflowSettings = WorkflowSettings(
      dataTransferBatchSize = 400,
      executionMode = ExecutionMode.MATERIALIZED
    )
  )
  val workflow = buildWorkflow(
    List(csvOpDesc),
    List(),
    materializedContext
  )
  val scheduleGenerator = new CostBasedScheduleGenerator(
    workflow.context,
    workflow.physicalPlan,
    CONTROLLER
  )
  val result = scheduleGenerator.getFullyMaterializedSearchState
  // Should only explore 1 state (fully materialized)
  assert(result.numStatesExplored == 1)
  // Each physical operator should be in its own region
  val regions = result.regionDAG.vertexSet().asScala
  val numPhysicalOps = workflow.physicalPlan.operators.size
  assert(regions.size == numPhysicalOps, s"Expected $numPhysicalOps regions, got ${regions.size}")
  // Each region should contain exactly 1 operator
  regions.foreach { region =>
    assert(
      region.getOperators.size == 1,
      s"Expected region to have 1 operator, got ${region.getOperators.size}"
    )
  }
  // All links should be materialized (represented as region links). This is the same invariant the
  // other MATERIALIZED-mode tests check, so the single-source case is held to the same contract.
  val numRegionLinks = result.regionDAG.edgeSet().asScala.size
  val numPhysicalLinks = workflow.physicalPlan.links.size
  assert(
    numRegionLinks == numPhysicalLinks,
    s"Expected $numPhysicalLinks region links, got $numRegionLinks"
  )
}
"CostBasedRegionPlanGenerator" should "create separate region for each operator in MATERIALIZED mode for csv->keyword workflow" in {
  // Linear workflow CSV scan -> keyword filter, planned with the fully materialized strategy.
  val source = TestOperators.smallCsvScanOpDesc()
  val filter = TestOperators.keywordSearchOpDesc("Region", "Asia")
  val settings = WorkflowSettings(
    dataTransferBatchSize = 400,
    executionMode = ExecutionMode.MATERIALIZED
  )
  val workflow = buildWorkflow(
    List(source, filter),
    List(
      LogicalLink(source.operatorIdentifier, PortIdentity(), filter.operatorIdentifier, PortIdentity())
    ),
    new WorkflowContext(workflowSettings = settings)
  )
  val result =
    new CostBasedScheduleGenerator(workflow.context, workflow.physicalPlan, CONTROLLER)
      .getFullyMaterializedSearchState

  // Exactly one search state is visited: the fully materialized one.
  assert(result.numStatesExplored == 1)

  // There is one region per physical operator...
  val regions = result.regionDAG.vertexSet().asScala
  val numPhysicalOps = workflow.physicalPlan.operators.size
  assert(regions.size == numPhysicalOps, s"Expected $numPhysicalOps regions, got ${regions.size}")
  // ...and every region holds exactly one operator.
  for (region <- regions) {
    assert(
      region.getOperators.size == 1,
      s"Expected region to have 1 operator, got ${region.getOperators.size}"
    )
  }

  // Every physical link is materialized, so each one appears as an edge between regions.
  val numRegionLinks = result.regionDAG.edgeSet().asScala.size
  val numPhysicalLinks = workflow.physicalPlan.links.size
  assert(
    numRegionLinks == numPhysicalLinks,
    s"Expected $numPhysicalLinks region links, got $numRegionLinks"
  )
}
"CostBasedRegionPlanGenerator" should "create separate region for each operator in MATERIALIZED mode for csv->keyword->count workflow" in {
  // Linear workflow CSV scan -> keyword filter -> COUNT aggregate, fully materialized.
  val source = TestOperators.smallCsvScanOpDesc()
  val filter = TestOperators.keywordSearchOpDesc("Region", "Asia")
  val counter = TestOperators.aggregateAndGroupByDesc(
    "Region",
    org.apache.texera.amber.operator.aggregate.AggregationFunction.COUNT,
    List[String]()
  )
  val settings = WorkflowSettings(
    dataTransferBatchSize = 400,
    executionMode = ExecutionMode.MATERIALIZED
  )
  val edges = List(
    LogicalLink(source.operatorIdentifier, PortIdentity(), filter.operatorIdentifier, PortIdentity()),
    LogicalLink(filter.operatorIdentifier, PortIdentity(), counter.operatorIdentifier, PortIdentity())
  )
  val workflow = buildWorkflow(
    List(source, filter, counter),
    edges,
    new WorkflowContext(workflowSettings = settings)
  )
  val result =
    new CostBasedScheduleGenerator(workflow.context, workflow.physicalPlan, CONTROLLER)
      .getFullyMaterializedSearchState

  // Exactly one search state is visited: the fully materialized one.
  assert(result.numStatesExplored == 1)

  // There is one region per physical operator...
  val regions = result.regionDAG.vertexSet().asScala
  val numPhysicalOps = workflow.physicalPlan.operators.size
  assert(regions.size == numPhysicalOps, s"Expected $numPhysicalOps regions, got ${regions.size}")
  // ...and every region holds exactly one operator.
  for (region <- regions) {
    assert(
      region.getOperators.size == 1,
      s"Expected region to have 1 operator, got ${region.getOperators.size}"
    )
  }

  // Every physical link is materialized, so each one appears as an edge between regions.
  val numRegionLinks = result.regionDAG.edgeSet().asScala.size
  val numPhysicalLinks = workflow.physicalPlan.links.size
  assert(
    numRegionLinks == numPhysicalLinks,
    s"Expected $numPhysicalLinks region links, got $numRegionLinks"
  )
}
"CostBasedRegionPlanGenerator" should "create separate region for each operator in MATERIALIZED mode for join workflow" in {
  // Two independent CSV scans feed join ports 0 and 1 respectively; fully materialized.
  val firstCsv = TestOperators.headerlessSmallCsvScanOpDesc()
  val secondCsv = TestOperators.headerlessSmallCsvScanOpDesc()
  val hashJoin = TestOperators.joinOpDesc("column-1", "column-1")
  val settings = WorkflowSettings(
    dataTransferBatchSize = 400,
    executionMode = ExecutionMode.MATERIALIZED
  )
  val edges = List(
    LogicalLink(firstCsv.operatorIdentifier, PortIdentity(), hashJoin.operatorIdentifier, PortIdentity()),
    LogicalLink(secondCsv.operatorIdentifier, PortIdentity(), hashJoin.operatorIdentifier, PortIdentity(1))
  )
  val workflow = buildWorkflow(
    List(firstCsv, secondCsv, hashJoin),
    edges,
    new WorkflowContext(workflowSettings = settings)
  )
  val result =
    new CostBasedScheduleGenerator(workflow.context, workflow.physicalPlan, CONTROLLER)
      .getFullyMaterializedSearchState

  // Exactly one search state is visited: the fully materialized one.
  assert(result.numStatesExplored == 1)

  // There is one region per physical operator...
  val regions = result.regionDAG.vertexSet().asScala
  val numPhysicalOps = workflow.physicalPlan.operators.size
  assert(regions.size == numPhysicalOps, s"Expected $numPhysicalOps regions, got ${regions.size}")
  // ...and every region holds exactly one operator.
  for (region <- regions) {
    assert(
      region.getOperators.size == 1,
      s"Expected region to have 1 operator, got ${region.getOperators.size}"
    )
  }

  // Every physical link is materialized, so each one appears as an edge between regions.
  val numRegionLinks = result.regionDAG.edgeSet().asScala.size
  val numPhysicalLinks = workflow.physicalPlan.links.size
  assert(
    numRegionLinks == numPhysicalLinks,
    s"Expected $numPhysicalLinks region links, got $numRegionLinks"
  )
}
"CostBasedRegionPlanGenerator" should "create separate region for each operator in MATERIALIZED mode for complex csv->->filter->join->filter2 workflow" in {
  // The CSV scan feeds join port 0 directly and join port 1 through a keyword filter; the join
  // output goes into a second keyword filter. Planned with the fully materialized strategy.
  val csvScan = TestOperators.headerlessSmallCsvScanOpDesc()
  val firstFilter = TestOperators.keywordSearchOpDesc("column-1", "Asia")
  val hashJoin = TestOperators.joinOpDesc("column-1", "column-1")
  val secondFilter = TestOperators.keywordSearchOpDesc("column-1", "Asia")
  val settings = WorkflowSettings(
    dataTransferBatchSize = 400,
    executionMode = ExecutionMode.MATERIALIZED
  )
  val edges = List(
    LogicalLink(csvScan.operatorIdentifier, PortIdentity(), hashJoin.operatorIdentifier, PortIdentity()),
    LogicalLink(csvScan.operatorIdentifier, PortIdentity(), firstFilter.operatorIdentifier, PortIdentity()),
    LogicalLink(firstFilter.operatorIdentifier, PortIdentity(), hashJoin.operatorIdentifier, PortIdentity(1)),
    LogicalLink(hashJoin.operatorIdentifier, PortIdentity(), secondFilter.operatorIdentifier, PortIdentity())
  )
  val workflow = buildWorkflow(
    List(csvScan, firstFilter, hashJoin, secondFilter),
    edges,
    new WorkflowContext(workflowSettings = settings)
  )
  val result =
    new CostBasedScheduleGenerator(workflow.context, workflow.physicalPlan, CONTROLLER)
      .getFullyMaterializedSearchState

  // Exactly one search state is visited: the fully materialized one.
  assert(result.numStatesExplored == 1)

  // There is one region per physical operator...
  val regions = result.regionDAG.vertexSet().asScala
  val numPhysicalOps = workflow.physicalPlan.operators.size
  assert(regions.size == numPhysicalOps, s"Expected $numPhysicalOps regions, got ${regions.size}")
  // ...and every region holds exactly one operator.
  for (region <- regions) {
    assert(
      region.getOperators.size == 1,
      s"Expected region to have 1 operator, got ${region.getOperators.size}"
    )
  }

  // Every physical link is materialized, so each one appears as an edge between regions.
  val numRegionLinks = result.regionDAG.edgeSet().asScala.size
  val numPhysicalLinks = workflow.physicalPlan.links.size
  assert(
    numRegionLinks == numPhysicalLinks,
    s"Expected $numPhysicalLinks region links, got $numRegionLinks"
  )
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/DefaultCostEstimatorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.storage.model.{BufferedItemWriter, VirtualDocument}
import org.apache.texera.amber.core.storage.result.ResultSchema
import org.apache.texera.amber.core.storage.{DocumentFactory, VFSURIFactory}
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity}
import org.apache.texera.amber.core.workflow.{GlobalPortIdentity, PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.architecture.scheduling.resourcePolicies.{
DefaultResourceAllocator,
ExecutionClusterInfo
}
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.amber.operator.aggregate.{AggregateOpDesc, AggregationFunction}
import org.apache.texera.amber.operator.keywordSearch.KeywordSearchOpDesc
import org.apache.texera.amber.operator.source.scan.csv.CSVScanSourceOpDesc
import org.apache.texera.dao.MockTexeraDB
import org.apache.texera.dao.jooq.generated.enums.UserRoleEnum
import org.apache.texera.dao.jooq.generated.tables.daos._
import org.apache.texera.dao.jooq.generated.tables.pojos._
import org.apache.texera.workflow.LogicalLink
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import java.net.URI
import java.sql.Timestamp
import scala.jdk.CollectionConverters.CollectionHasAsScala
class DefaultCostEstimatorSpec
    extends AnyFlatSpec
    with BeforeAndAfterAll
    with BeforeAndAfterEach
    with MockTexeraDB {

  private val headerlessCsvOpDesc: CSVScanSourceOpDesc =
    TestOperators.headerlessSmallCsvScanOpDesc()
  private val keywordOpDesc: KeywordSearchOpDesc =
    TestOperators.keywordSearchOpDesc("column-1", "Asia")
  private val groupByOpDesc: AggregateOpDesc =
    TestOperators.aggregateAndGroupByDesc("column-1", AggregationFunction.COUNT, List[String]())

  private val testUser: User = {
    val user = new User
    user.setUid(Integer.valueOf(1))
    user.setName("test_user")
    user.setRole(UserRoleEnum.ADMIN)
    user.setPassword("123")
    user.setEmail("test_user@test.com")
    user
  }

  private val testWorkflowEntry: Workflow = {
    val workflow = new Workflow
    workflow.setName("test workflow")
    workflow.setWid(Integer.valueOf(1))
    workflow.setContent("test workflow content")
    workflow.setDescription("test description")
    workflow
  }

  private val testWorkflowVersionEntry: WorkflowVersion = {
    val workflowVersion = new WorkflowVersion
    workflowVersion.setWid(Integer.valueOf(1))
    workflowVersion.setVid(Integer.valueOf(1))
    workflowVersion.setContent("test version content")
    workflowVersion
  }

  private val testWorkflowExecutionEntry: WorkflowExecutions = {
    val workflowExecution = new WorkflowExecutions
    workflowExecution.setEid(Integer.valueOf(1))
    workflowExecution.setVid(Integer.valueOf(1))
    workflowExecution.setUid(Integer.valueOf(1))
    workflowExecution.setStatus(3.toByte)
    workflowExecution.setEnvironmentVersion("test engine")
    workflowExecution
  }

  private var uri: URI = _
  private var writer: BufferedItemWriter[Tuple] = _
  private var document: VirtualDocument[_] = _

  /**
   * Re-initializes the mock DB, then creates the runtime-statistics document for
   * workflow 1 / execution 1 and opens a writer on it.
   */
  override protected def beforeEach(): Unit = {
    initializeDBAndReplaceDSLContext()
    uri = VFSURIFactory.createRuntimeStatisticsURI(
      WorkflowIdentity(testWorkflowEntry.getWid.longValue()),
      ExecutionIdentity(testWorkflowExecutionEntry.getEid.longValue())
    )
    document = DocumentFactory.createDocument(uri, ResultSchema.runtimeStatisticsSchema)
    writer = document
      .writer(s"runtime_statistics_${testWorkflowExecutionEntry.getEid.longValue()}")
      .asInstanceOf[BufferedItemWriter[Tuple]]
    writer.open()
  }

  /** Clears the runtime-statistics document and shuts down the mock DB. */
  override protected def afterEach(): Unit = {
    document.clear()
    shutdownDB()
  }

  /**
   * Inserts the user, workflow, workflow-version and execution fixture rows into the mock DB.
   * Before the insert, the execution row's runtime-statistics URI is pointed at the document
   * created in `beforeEach`.
   */
  private def insertExecutionFixtures(): Unit = {
    val configuration = getDSLContext.configuration()
    new UserDao(configuration).insert(testUser)
    new WorkflowDao(configuration).insert(testWorkflowEntry)
    new WorkflowVersionDao(configuration).insert(testWorkflowVersionEntry)
    testWorkflowExecutionEntry.setRuntimeStatsUri(uri.toString)
    new WorkflowExecutionsDao(configuration).insert(testWorkflowExecutionEntry)
  }

  /**
   * Builds one row in `ResultSchema.runtimeStatisticsSchema` layout for the operator `opId`.
   *
   * Only indices 6 and 7 vary between the tests. The cost assertions read exactly those two fields
   * back and compare their sum divided by 1e9 against the estimate, so they are presumably
   * nanosecond time components. TODO: confirm against the schema definition.
   */
  private def runtimeStatistics(opId: String, field6: Long, field7: Long): Tuple =
    new Tuple(
      ResultSchema.runtimeStatisticsSchema,
      Array(
        opId,
        new Timestamp(System.currentTimeMillis()),
        0L,
        0L,
        0L,
        0L,
        field6,
        field7,
        0L,
        1,
        0
      )
    )

  "DefaultCostEstimator" should "use fallback method when no past statistics are available" in {
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          headerlessCsvOpDesc.operatorIdentifier,
          PortIdentity(0),
          keywordOpDesc.operatorIdentifier,
          PortIdentity(0)
        )
      ),
      new WorkflowContext()
    )
    val resourceAllocator =
      new DefaultResourceAllocator(
        workflow.physicalPlan,
        new ExecutionClusterInfo(),
        workflow.context.workflowSettings
      )
    val costEstimator = new DefaultCostEstimator(
      workflow.context,
      resourceAllocator,
      CONTROLLER
    )
    val ports = workflow.physicalPlan.operators.flatMap(op =>
      op.inputPorts.keys
        .map(inputPortId => GlobalPortIdentity(op.id, inputPortId, input = true))
        .toSet ++ op.outputPorts.keys
        .map(outputPortId => GlobalPortIdentity(op.id, outputPortId))
        .toSet
    )
    val region = Region(
      id = RegionIdentity(0),
      physicalOps = workflow.physicalPlan.operators,
      physicalLinks = workflow.physicalPlan.links,
      ports = ports
    )
    val (_, costOfRegion) = costEstimator.allocateResourcesAndEstimateCost(region, 1)
    assert(costOfRegion == 0)
  }

  "DefaultCostEstimator" should "use the latest successful execution to estimate cost when available" in {
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          headerlessCsvOpDesc.operatorIdentifier,
          PortIdentity(0),
          keywordOpDesc.operatorIdentifier,
          PortIdentity(0)
        )
      ),
      new WorkflowContext()
    )
    insertExecutionFixtures()
    writer.putOne(runtimeStatistics(headerlessCsvOpDesc.operatorIdentifier.id, 100L, 100L))
    writer.putOne(runtimeStatistics(keywordOpDesc.operatorIdentifier.id, 300L, 300L))
    writer.close()
    val resourceAllocator =
      new DefaultResourceAllocator(
        workflow.physicalPlan,
        new ExecutionClusterInfo(),
        workflow.context.workflowSettings
      )
    val costEstimator = new DefaultCostEstimator(
      workflow.context,
      resourceAllocator,
      CONTROLLER
    )
    val ports = workflow.physicalPlan.operators.flatMap(op =>
      op.inputPorts.keys
        .map(inputPortId => GlobalPortIdentity(op.id, inputPortId, input = true))
        .toSet ++ op.outputPorts.keys
        .map(outputPortId => GlobalPortIdentity(op.id, outputPortId))
        .toSet
    )
    val region = Region(
      id = RegionIdentity(0),
      physicalOps = workflow.physicalPlan.operators,
      physicalLinks = workflow.physicalPlan.links,
      ports = ports
    )
    val (_, costOfRegion) = costEstimator.allocateResourcesAndEstimateCost(region, 1)
    assert(costOfRegion != 0)
  }

  "DefaultCostEstimator" should "use correctly estimate costs in a search" in {
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc, groupByOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          headerlessCsvOpDesc.operatorIdentifier,
          PortIdentity(0),
          groupByOpDesc.operatorIdentifier,
          PortIdentity(0)
        ),
        LogicalLink(
          groupByOpDesc.operatorIdentifier,
          PortIdentity(0),
          keywordOpDesc.operatorIdentifier,
          PortIdentity(0)
        )
      ),
      new WorkflowContext()
    )
    insertExecutionFixtures()
    val headerlessCsvOpRuntimeStatistics =
      runtimeStatistics(headerlessCsvOpDesc.operatorIdentifier.id, 100L, 100L)
    val groupByOpRuntimeStatistics =
      runtimeStatistics(groupByOpDesc.operatorIdentifier.id, 1000L, 1000L)
    val keywordOpRuntimeStatistics =
      runtimeStatistics(keywordOpDesc.operatorIdentifier.id, 300L, 300L)
    writer.putOne(headerlessCsvOpRuntimeStatistics)
    writer.putOne(groupByOpRuntimeStatistics)
    writer.putOne(keywordOpRuntimeStatistics)
    writer.close()
    // Should contain two regions, one with CSV->localAgg->globalAgg, another with keyword
    val searchResult = new CostBasedScheduleGenerator(
      workflow.context,
      workflow.physicalPlan,
      CONTROLLER
    ).bottomUpSearch()
    val groupByRegion =
      searchResult.regionDAG.vertexSet().asScala.filter(region => region.physicalOps.size == 3).head
    val keywordRegion =
      searchResult.regionDAG.vertexSet().asScala.filter(region => region.physicalOps.size == 1).head
    val resourceAllocator =
      new DefaultResourceAllocator(
        workflow.physicalPlan,
        new ExecutionClusterInfo(),
        workflow.context.workflowSettings
      )
    val costEstimator = new DefaultCostEstimator(
      workflow.context,
      resourceAllocator,
      CONTROLLER
    )
    val (_, groupByRegionCost) = costEstimator.allocateResourcesAndEstimateCost(groupByRegion, 1)
    val groupByOperatorCost = (groupByOpRuntimeStatistics.getField(6).asInstanceOf[Long] +
      groupByOpRuntimeStatistics.getField(7).asInstanceOf[Long]) / 1e9
    // The cost of the first region should be the cost of the GroupBy operator. (The two physical
    // operators of the GroupBy logical operator have the same cost because the statistics are
    // recorded per logical operator.) The GroupBy operator has the longest running time.
    assert(groupByRegionCost == groupByOperatorCost)
    val (_, keywordRegionCost) = costEstimator.allocateResourcesAndEstimateCost(keywordRegion, 1)
    val keywordOperatorCost = (keywordOpRuntimeStatistics.getField(6).asInstanceOf[Long] +
      keywordOpRuntimeStatistics.getField(7).asInstanceOf[Long]) / 1e9
    // The cost of the second region should be the cost of the keyword operator.
    assert(keywordRegionCost == keywordOperatorCost)
    // The cost of the region plan should be the sum of region costs
    assert(searchResult.cost == groupByRegionCost + keywordRegionCost)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/ExpansionGreedyScheduleGeneratorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.amber.operator.split.SplitOpDesc
import org.apache.texera.amber.operator.udf.python.{
DualInputPortsPythonUDFOpDescV2,
PythonUDFOpDescV2
}
import org.apache.texera.workflow.LogicalLink
import org.scalamock.scalatest.MockFactory
import org.scalatest.flatspec.AnyFlatSpec
@deprecated("This greedy schedule generator test will be removed in the future.")
class ExpansionGreedyScheduleGeneratorSpec extends AnyFlatSpec with MockFactory {
"RegionPlanGenerator" should "correctly find regions in headerlessCsv->keyword workflow" in {
val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
val keywordOpDesc = TestOperators.keywordSearchOpDesc("column-1", "Asia")
val workflow = buildWorkflow(
List(headerlessCsvOpDesc, keywordOpDesc),
List(
LogicalLink(
headerlessCsvOpDesc.operatorIdentifier,
PortIdentity(0),
keywordOpDesc.operatorIdentifier,
PortIdentity(0)
)
),
new WorkflowContext()
)
val (schedule, _) = new ExpansionGreedyScheduleGenerator(
workflow.context,
workflow.physicalPlan
).generate()
// Assuming each level only has one region
val regionList = schedule.toList.map(level => level.head)
assert(regionList.size == 1)
regionList.zip(Iterator(2)).foreach {
case (region, opCount) =>
assert(region.getOperators.size == opCount)
}
regionList.zip(Iterator(1)).foreach {
case (region, linkCount) =>
assert(region.getLinks.size == linkCount)
}
regionList.zip(Iterator(3)).foreach {
case (region, portCount) =>
assert(region.getPorts.size == portCount)
}
}
"RegionPlanGenerator" should "correctly find regions in csv->(csv->)->join workflow" in {
val headerlessCsvOpDesc1 = TestOperators.headerlessSmallCsvScanOpDesc()
val headerlessCsvOpDesc2 = TestOperators.headerlessSmallCsvScanOpDesc()
val joinOpDesc = TestOperators.joinOpDesc("column-1", "column-1")
val workflow = buildWorkflow(
List(
headerlessCsvOpDesc1,
headerlessCsvOpDesc2,
joinOpDesc
),
List(
LogicalLink(
headerlessCsvOpDesc1.operatorIdentifier,
PortIdentity(),
joinOpDesc.operatorIdentifier,
PortIdentity()
),
LogicalLink(
headerlessCsvOpDesc2.operatorIdentifier,
PortIdentity(),
joinOpDesc.operatorIdentifier,
PortIdentity(1)
)
),
new WorkflowContext()
)
val (schedule, _) = new ExpansionGreedyScheduleGenerator(
workflow.context,
workflow.physicalPlan
).generate()
// Assuming each level only has one region
val regionList = schedule.toList.map(level => level.head)
assert(regionList.size == 2)
regionList.zip(Iterator(2, 2)).foreach {
case (region, opCount) =>
assert(region.getOperators.size == opCount)
}
regionList.zip(Iterator(1, 1)).foreach {
case (region, linkCount) =>
assert(region.getLinks.size == linkCount)
}
regionList.zip(Iterator(3, 4)).foreach {
case (region, portCount) =>
assert(region.getPorts.size == portCount)
}
// The fist region should be the build region
assert(
regionList.head.getOperators
.map(_.id)
.exists(physicalOpId =>
OperatorIdentity(physicalOpId.logicalOpId.id) == headerlessCsvOpDesc1.operatorIdentifier
)
)
// The second region should be the probe region
assert(
regionList(1).getOperators
.map(_.id)
.exists(physicalOpId =>
OperatorIdentity(physicalOpId.logicalOpId.id) == headerlessCsvOpDesc2.operatorIdentifier
)
)
}
"RegionPlanGenerator" should "correctly find regions in csv->->filter->join workflow" in {
val headerlessCsvOpDesc1 = TestOperators.headerlessSmallCsvScanOpDesc()
val keywordOpDesc = TestOperators.keywordSearchOpDesc("column-1", "Asia")
val joinOpDesc = TestOperators.joinOpDesc("column-1", "column-1")
val workflow = buildWorkflow(
List(
headerlessCsvOpDesc1,
keywordOpDesc,
joinOpDesc
),
List(
LogicalLink(
headerlessCsvOpDesc1.operatorIdentifier,
PortIdentity(),
joinOpDesc.operatorIdentifier,
PortIdentity()
),
LogicalLink(
headerlessCsvOpDesc1.operatorIdentifier,
PortIdentity(),
keywordOpDesc.operatorIdentifier,
PortIdentity()
),
LogicalLink(
keywordOpDesc.operatorIdentifier,
PortIdentity(),
joinOpDesc.operatorIdentifier,
PortIdentity(1)
)
),
new WorkflowContext()
)
val (schedule, _) = new ExpansionGreedyScheduleGenerator(
workflow.context,
workflow.physicalPlan
).generate()
// Assuming each level only has one region
val regionList = schedule.toList.map(level => level.head)
assert(regionList.size == 2)
regionList.zip(Iterator(3, 1)).foreach {
case (region, opCount) =>
assert(region.getOperators.size == opCount)
}
regionList.zip(Iterator(2, 0)).foreach {
case (region, linkCount) =>
assert(region.getLinks.size == linkCount)
}
regionList.zip(Iterator(5, 3)).foreach {
case (region, portCount) =>
assert(region.getPorts.size == portCount)
}
}
// buildCsv feeds input port 0 of both hash joins; probeCsv feeds port 1 of hashJoin1, whose
// output in turn feeds port 1 of hashJoin2.
"RegionPlanGenerator" should "correctly find regions in buildcsv->probecsv->hashjoin->hashjoin workflow" in {
  val buildCsv = TestOperators.headerlessSmallCsvScanOpDesc()
  val probeCsv = TestOperators.smallCsvScanOpDesc()
  val hashJoin1 = TestOperators.joinOpDesc("column-1", "Region")
  val hashJoin2 = TestOperators.joinOpDesc("column-2", "Country")
  val workflow = buildWorkflow(
    List(buildCsv, probeCsv, hashJoin1, hashJoin2),
    List(
      LogicalLink(
        buildCsv.operatorIdentifier,
        PortIdentity(),
        hashJoin1.operatorIdentifier,
        PortIdentity()
      ),
      LogicalLink(
        probeCsv.operatorIdentifier,
        PortIdentity(),
        hashJoin1.operatorIdentifier,
        PortIdentity(1)
      ),
      LogicalLink(
        buildCsv.operatorIdentifier,
        PortIdentity(),
        hashJoin2.operatorIdentifier,
        PortIdentity()
      ),
      LogicalLink(
        hashJoin1.operatorIdentifier,
        PortIdentity(),
        hashJoin2.operatorIdentifier,
        PortIdentity(1)
      )
    ),
    new WorkflowContext()
  )
  val (schedule, _) =
    new ExpansionGreedyScheduleGenerator(workflow.context, workflow.physicalPlan).generate()
  // Each schedule level is assumed to hold exactly one region.
  val regionList = schedule.toList.map(_.head)
  assert(regionList.size == 2)
  // Per-region expectations, in schedule order.
  assert(regionList.map(_.getOperators.size) == List(3, 3))
  assert(regionList.map(_.getLinks.size) == List(2, 2))
  assert(regionList.map(_.getPorts.size) == List(5, 7))
}
"RegionPlanGenerator" should "correctly find regions in csv->split->training-infer workflow" in {
  val csv = TestOperators.headerlessSmallCsvScanOpDesc()
  val split = new SplitOpDesc()
  val training = new PythonUDFOpDescV2()
  val inference = new DualInputPortsPythonUDFOpDescV2()
  // split's port 0 feeds training, whose output feeds inference's port 0; split's port 1 feeds
  // inference's port 1 directly.
  val workflow = buildWorkflow(
    List(csv, split, training, inference),
    List(
      LogicalLink(
        csv.operatorIdentifier,
        PortIdentity(),
        split.operatorIdentifier,
        PortIdentity()
      ),
      LogicalLink(
        split.operatorIdentifier,
        PortIdentity(),
        training.operatorIdentifier,
        PortIdentity()
      ),
      LogicalLink(
        training.operatorIdentifier,
        PortIdentity(),
        inference.operatorIdentifier,
        PortIdentity()
      ),
      LogicalLink(
        split.operatorIdentifier,
        PortIdentity(1),
        inference.operatorIdentifier,
        PortIdentity(1)
      )
    ),
    new WorkflowContext()
  )
  val (schedule, _) =
    new ExpansionGreedyScheduleGenerator(workflow.context, workflow.physicalPlan).generate()
  // Each schedule level is assumed to hold exactly one region.
  val regionList = schedule.toList.map(_.head)
  assert(regionList.size == 2)
  // Per-region expectations, in schedule order.
  assert(regionList.map(_.getOperators.size) == List(3, 1))
  assert(regionList.map(_.getLinks.size) == List(2, 0))
  assert(regionList.map(_.getPorts.size) == List(6, 3))
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionCoordinatorTestSupport.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import com.twitter.util.{Await, Duration, Future}
import org.apache.pekko.actor.{Actor, ActorRef, Props}
import org.apache.pekko.testkit.{TestActorRef, TestKit}
import org.apache.texera.amber.core.executor.OpExecWithClassName
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
OperatorIdentity,
PhysicalOpIdentity
}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.apache.texera.amber.core.workflow.WorkflowContext.{
DEFAULT_EXECUTION_ID,
DEFAULT_WORKFLOW_ID
}
import org.apache.texera.amber.engine.architecture.common.{
PekkoActorRefMappingService,
PekkoActorService,
WorkflowActor
}
import org.apache.texera.amber.engine.architecture.controller.execution.WorkflowExecution
import org.apache.texera.amber.engine.architecture.messaginglayer.{
NetworkInputGateway,
NetworkOutputGateway
}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.ControlInvocation
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
import org.apache.texera.amber.engine.architecture.scheduling.config.{
OperatorConfig,
ResourceConfig,
WorkerConfig
}
import org.apache.texera.amber.engine.architecture.worker.statistics.WorkerState
import org.apache.texera.amber.engine.common.CheckpointState
import org.apache.texera.amber.engine.common.ambermessage.WorkflowFIFOMessage
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.apache.texera.amber.util.VirtualIdentityUtils
import scala.collection.mutable
object RegionCoordinatorTestSupport {
  // Worker RPC method names, compared against `ControlInvocation.methodName`.
  val InitializeExecutor = "initializeExecutor"
  val OpenExecutor = "openExecutor"
  val StartWorker = "startWorker"
  val EndWorker = "endWorker"

  // Generous deadline for the polling helpers below. Production timing under test (notably the
  // 200 ms `killRetryDelay` in `RegionExecutionCoordinator`) fits comfortably; the rest is
  // headroom for slow CI.
  val testTimeout: Duration = Duration.fromSeconds(5)

  /** One controller-to-worker RPC captured at the controller output gateway. */
  case class WorkerRpcCall(
      methodName: String,
      receiver: ActorVirtualIdentity,
      commandId: Long
  )

  /** The actor services owned by a spawned [[ControllerHarness]]. */
  case class ControllerHarnessFixture(
      actorService: PekkoActorService,
      actorRefService: PekkoActorRefMappingService
  )

  /**
   * Captures controller-to-worker RPCs at the same boundary used by production
   * `AsyncRPCClient.workerInterface`.
   *
   * Non-termination RPCs are completed immediately because these tests focus on termination
   * ordering. `endWorker` responses are controlled by `endWorkerResponse`, allowing each test to
   * hold termination pending, fail an attempt, or allow it to succeed.
   *
   * Thread-safety: RPCs can be emitted from a thread other than the test thread (for example a
   * delayed `endWorker` retry) while the test polls the captured calls through `waitUntil`.
   * Every access to `calls` therefore happens under `calls.synchronized`, and the query methods
   * return immutable snapshots.
   *
   * @param endWorkerResponse reply to use for each `endWorker` call; `None` leaves the RPC
   *                          pending until the test completes it with [[fulfill]]
   */
  class ControllerRpcProbe(endWorkerResponse: WorkerRpcCall => Option[ControlReturn]) {
    // Guarded by its own monitor. Prefer the snapshot-returning accessors below over direct reads.
    val calls: mutable.ArrayBuffer[WorkerRpcCall] = mutable.ArrayBuffer()
    val inputGateway = new NetworkInputGateway(CONTROLLER)
    val outputGateway = new NetworkOutputGateway(CONTROLLER, handleOutput)
    val asyncRPCClient = new AsyncRPCClient(inputGateway, outputGateway, CONTROLLER)

    /** Method names of every captured RPC, in send order. */
    def methodTrace: Seq[String] = snapshot.map(_.methodName)

    /** Receivers of captured `initializeExecutor` RPCs, in send order. */
    def initializedWorkers: Seq[ActorVirtualIdentity] =
      snapshot.filter(_.methodName == InitializeExecutor).map(_.receiver)

    /** Receivers of captured `startWorker` RPCs, in send order. */
    def startedWorkers: Seq[ActorVirtualIdentity] =
      snapshot.filter(_.methodName == StartWorker).map(_.receiver)

    /** All captured `endWorker` RPCs (including failed attempts), in send order. */
    def endWorkerCalls: Seq[WorkerRpcCall] =
      snapshot.filter(_.methodName == EndWorker)

    /** The single captured `endWorker` RPC; fails if there is not exactly one. */
    def onlyEndWorkerCall: WorkerRpcCall = {
      // Take one snapshot so the size check and the returned element agree.
      val endCalls = endWorkerCalls
      assert(endCalls.size == 1, s"expected exactly one $EndWorker call but saw ${endCalls.size}")
      endCalls.head
    }

    /** Completes the pending RPC `call` with `returnValue`. */
    def fulfill(call: WorkerRpcCall, returnValue: ControlReturn): Unit = {
      asyncRPCClient.fulfillPromise(ReturnInvocation(call.commandId, returnValue))
    }

    // Immutable copy of the captured calls, taken under the buffer's lock.
    private def snapshot: List[WorkerRpcCall] = calls.synchronized(calls.toList)

    private def handleOutput(message: WorkflowFIFOMessage): Unit = {
      message.payload match {
        case invocation: ControlInvocation =>
          recordAndMaybeFulfill(invocation)
        case _ =>
        // Client events and stats updates are irrelevant to the coordinator lifecycle assertions.
      }
    }

    private def recordAndMaybeFulfill(invocation: ControlInvocation): Unit = {
      val call = WorkerRpcCall(
        methodName = invocation.methodName,
        receiver = invocation.context.receiver,
        commandId = invocation.commandId
      )
      calls.synchronized {
        calls += call
      }
      // Fulfil outside the lock: completing the promise may run coordinator callbacks that emit
      // further RPCs through this probe.
      immediateReturn(call).foreach(fulfill(call, _))
    }

    // Reply for `call` if it should complete immediately; `None` keeps it pending.
    private def immediateReturn(call: WorkerRpcCall): Option[ControlReturn] = {
      call.methodName match {
        case InitializeExecutor | OpenExecutor =>
          Some(EmptyReturn())
        case StartWorker =>
          Some(WorkerStateResponse(WorkerState.RUNNING))
        case EndWorker =>
          endWorkerResponse(call)
        case other =>
          throw new AssertionError(s"Unexpected worker RPC in test: $other")
      }
    }
  }

  /** An actor that ignores every message; stands in for a live worker. */
  class IdleActor extends Actor {
    override def receive: Receive = { case _ => () }
  }

  /** A controller-identity `WorkflowActor` whose message and lifecycle hooks are all no-ops. */
  class ControllerHarness extends WorkflowActor(None, CONTROLLER) {
    override def handleInputMessage(id: Long, workflowMsg: WorkflowFIFOMessage): Unit = ()
    override def getQueuedCredit(channelId: ChannelIdentity): Long = 0
    override def handleBackpressure(isBackpressured: Boolean): Unit = ()
    override def initState(): Unit = ()
    override def loadFromCheckpoint(chkpt: CheckpointState): Unit = ()
  }

  /** A source physical operator (layer "main") for logical operator `logicalOpId`. */
  def createSourceOp(logicalOpId: String): PhysicalOp =
    PhysicalOp.sourcePhysicalOp(
      PhysicalOpIdentity(OperatorIdentity(logicalOpId), "main"),
      DEFAULT_WORKFLOW_ID,
      DEFAULT_EXECUTION_ID,
      OpExecWithClassName("unused")
    )

  /** Identity of worker index 0 of `physicalOp` in the default workflow. */
  def createWorkerId(physicalOp: PhysicalOp): ActorVirtualIdentity =
    VirtualIdentityUtils.createWorkerIdentity(DEFAULT_WORKFLOW_ID, physicalOp.id, 0)

  /** A region holding only `physicalOp`, configured with the single worker `workerId`. */
  def createSingleWorkerRegion(
      regionId: Long,
      physicalOp: PhysicalOp,
      workerId: ActorVirtualIdentity
  ): Region =
    Region(
      RegionIdentity(regionId),
      physicalOps = Set(physicalOp),
      physicalLinks = Set.empty,
      resourceConfig = Some(
        ResourceConfig(
          operatorConfigs = Map(physicalOp.id -> OperatorConfig(List(WorkerConfig(workerId))))
        )
      )
    )

  /**
   * Registers an execution for `physicalOp` / `workerId` under region `seedRegionId` so that a
   * later region reusing the same operator finds an existing execution.
   */
  def seedReusableWorkerExecution(
      workflowExecution: WorkflowExecution,
      seedRegionId: Long,
      physicalOp: PhysicalOp,
      workerId: ActorVirtualIdentity
  ): Unit = {
    // RegionExecutionCoordinator skips real worker creation when an execution for this operator
    // already exists.
    workflowExecution
      .initRegionExecution(createSingleWorkerRegion(seedRegionId, physicalOp, workerId))
      .initOperatorExecution(physicalOp.id)
      .initWorkerExecution(workerId)
  }

  /** Blocks for `future`'s result, failing after [[testTimeout]]. */
  def await[T](future: Future[T]): T = Await.result(future, testTimeout)

  /**
   * Polls `condition` every 20 ms until it holds or [[testTimeout]] elapses, then asserts it.
   *
   * @param condition re-evaluated on every poll
   */
  def waitUntil(condition: => Boolean): Unit = {
    val deadline = System.nanoTime() + testTimeout.inNanoseconds
    while (!condition && System.nanoTime() < deadline) {
      Thread.sleep(20)
    }
    assert(condition, s"condition not satisfied within $testTimeout")
  }
}
trait RegionCoordinatorTestSupport { self: TestKit =>
  import RegionCoordinatorTestSupport._

  /**
   * Spawns a [[ControllerHarness]] in the test actor system and returns its actor services.
   * The harness's actor service reports the harness's own address as the only available node.
   */
  protected def createControllerHarness(): ControllerHarnessFixture = {
    val controllerRef = TestActorRef(new ControllerHarness)
    val harness = controllerRef.underlyingActor
    harness.actorService.getAvailableNodeAddressesFunc = () => Array(controllerRef.path.address)
    ControllerHarnessFixture(
      actorService = harness.actorService,
      actorRefService = harness.actorRefMappingService
    )
  }

  /**
   * Starts an [[IdleActor]] and registers it as the actor ref for `workerId`.
   *
   * @return the started actor's ref
   */
  protected def registerLiveWorker(
      actorRefService: PekkoActorRefMappingService,
      workerId: ActorVirtualIdentity
  ): ActorRef = {
    val idleWorker = system.actorOf(Props(new IdleActor), s"worker-${System.nanoTime()}")
    actorRefService.registerActorRef(workerId, idleWorker)
    idleWorker
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinatorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import com.twitter.util.Future
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.testkit.TestKit
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.apache.texera.amber.engine.architecture.common.PekkoActorRefMappingService
import org.apache.texera.amber.engine.architecture.controller.ControllerConfig
import org.apache.texera.amber.engine.architecture.controller.execution.WorkflowExecution
import org.apache.texera.amber.engine.architecture.rpc.controlreturns._
import org.apache.texera.amber.engine.architecture.scheduling.RegionCoordinatorTestSupport._
import org.apache.texera.amber.engine.architecture.worker.statistics.WorkerState
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpecLike
import java.util.concurrent.atomic
/**
 * Tests the real region-coordination lifecycle around synchronous region kill.
 *
 * The tests let the coordinator call the real `AsyncRPCClient.workerInterface`, capture the
 * generated `ControlInvocation`s at the controller output gateway, and fulfill those RPC promises
 * explicitly. This keeps the important production behavior under test:
 *
 *  - regular launch RPCs (`initializeExecutor`, `openExecutor`, `startWorker`) are allowed to
 *    complete immediately;
 *  - `endWorker` can be held pending or failed to model worker-side drain/termination behavior;
 *  - the real coordinator then decides when to remove actor refs, clean control channels, mark
 *    workers terminated, and allow the next region to start.
 */
class RegionExecutionCoordinatorSpec
    extends TestKit(ActorSystem("RegionExecutionCoordinatorSpec", AmberRuntime.pekkoConfig))
    with AnyFlatSpecLike
    with BeforeAndAfterAll
    with RegionCoordinatorTestSupport {

  // Shut down this spec's actor system, then still run any stacked `afterAll` hooks even if the
  // shutdown throws.
  override def afterAll(): Unit = {
    try TestKit.shutdownActorSystem(system)
    finally super.afterAll()
  }

  "RegionExecutionCoordinator" should "send gracefulStop only after EndWorker succeeds" in {
    // `None` keeps the single EndWorker RPC pending until the test fulfills it below.
    val fixture = createSingleRegionFixture(endWorkerResponse = _ => None)
    launchRegion(fixture.coordinator)
    val completion = requestRegionCompletion(fixture.coordinator)
    assert(
      fixture.rpcProbe.methodTrace == Seq(InitializeExecutor, OpenExecutor, StartWorker, EndWorker)
    )
    // While EndWorker is pending, the region must not complete and the worker must stay reachable.
    assert(completion.poll.isEmpty, "region completed before EndWorker returned")
    assert(!fixture.coordinator.isCompleted, "coordinator completed before EndWorker returned")
    assert(
      fixture.actorRefService.hasActorRef(fixture.workerId),
      "worker actor ref removed before EndWorker returned"
    )
    fixture.rpcProbe.fulfill(fixture.rpcProbe.onlyEndWorkerCall, EmptyReturn())
    await(completion)
    assert(fixture.coordinator.isCompleted)
    assert(!fixture.actorRefService.hasActorRef(fixture.workerId))
    assert(workerState(fixture) == WorkerState.TERMINATED)
    assertControlChannelsAreRemoved(fixture)
  }

  it should "retry EndWorker failures and delay gracefulStop until a retry succeeds" in {
    val attempts = new atomic.AtomicInteger(0)
    // The first EndWorker attempt fails; every later attempt is held pending.
    val fixture = createSingleRegionFixture(endWorkerResponse =
      _ =>
        if (attempts.incrementAndGet() == 1) {
          Some(transientEndWorkerFailure)
        } else {
          None
        }
    )
    launchRegion(fixture.coordinator)
    val completion = requestRegionCompletion(fixture.coordinator)
    // The retry is issued asynchronously, so poll for it rather than asserting immediately.
    waitUntil(fixture.rpcProbe.endWorkerCalls.size >= 2)
    assert(completion.poll.isEmpty, "region completed while the EndWorker retry was pending")
    assert(!fixture.coordinator.isCompleted, "coordinator completed while the retry was pending")
    assert(
      fixture.actorRefService.hasActorRef(fixture.workerId),
      "worker actor ref removed while the EndWorker retry was pending"
    )
    fixture.rpcProbe.fulfill(fixture.rpcProbe.endWorkerCalls.last, EmptyReturn())
    await(completion)
    assert(fixture.coordinator.isCompleted)
    assert(fixture.rpcProbe.endWorkerCalls.size == 2, "unexpected extra EndWorker attempts")
    assert(!fixture.actorRefService.hasActorRef(fixture.workerId))
    assert(workerState(fixture) == WorkerState.TERMINATED)
  }

  /** Everything a single-region lifecycle test needs to drive and inspect the coordinator. */
  private case class SingleRegionFixture(
      coordinator: RegionExecutionCoordinator,
      rpcProbe: ControllerRpcProbe,
      workflowExecution: WorkflowExecution,
      region: Region,
      physicalOp: PhysicalOp,
      workerId: ActorVirtualIdentity,
      actorRefService: PekkoActorRefMappingService
  )

  /**
   * Builds a coordinator for a one-operator, one-worker region whose worker execution was seeded
   * by an earlier region, so no real worker actor is created.
   *
   * @param endWorkerResponse reply to each `endWorker` RPC; `None` holds it pending
   */
  private def createSingleRegionFixture(
      endWorkerResponse: WorkerRpcCall => Option[ControlReturn]
  ): SingleRegionFixture = {
    val physicalOp = createSourceOp("test-op")
    val workerId = createWorkerId(physicalOp)
    val region = createSingleWorkerRegion(1, physicalOp, workerId)
    val workflowExecution = WorkflowExecution()
    seedReusableWorkerExecution(workflowExecution, seedRegionId = 0, physicalOp, workerId)
    workflowExecution.initRegionExecution(region)
    val rpcProbe = new ControllerRpcProbe(endWorkerResponse)
    val controller = createControllerHarness()
    registerLiveWorker(controller.actorRefService, workerId)
    // Seed stale control channels to verify that successful termination removes them.
    rpcProbe.inputGateway.getChannel(ChannelIdentity(workerId, CONTROLLER, isControl = true))
    rpcProbe.outputGateway.getSequenceNumber(
      ChannelIdentity(CONTROLLER, workerId, isControl = true)
    )
    val coordinator = new RegionExecutionCoordinator(
      region,
      isRestart = false,
      workflowExecution,
      rpcProbe.asyncRPCClient,
      ControllerConfig(None, None, None, None),
      controller.actorService,
      controller.actorRefService
    )
    SingleRegionFixture(
      coordinator = coordinator,
      rpcProbe = rpcProbe,
      workflowExecution = workflowExecution,
      region = region,
      physicalOp = physicalOp,
      workerId = workerId,
      actorRefService = controller.actorRefService
    )
  }

  // First phase transition: issues the launch RPCs and waits for it to finish.
  private def launchRegion(coordinator: RegionExecutionCoordinator): Unit = {
    await(coordinator.syncStatusAndTransitionRegionExecutionPhase())
  }

  // Next phase transition, returned un-awaited so the test can inspect it while EndWorker pends.
  private def requestRegionCompletion(
      coordinator: RegionExecutionCoordinator
  ): Future[Unit] = {
    coordinator.syncStatusAndTransitionRegionExecutionPhase()
  }

  private def workerState(fixture: SingleRegionFixture): WorkerState =
    fixture.workflowExecution
      .getRegionExecution(fixture.region.id)
      .getOperatorExecution(fixture.physicalOp.id)
      .getWorkerExecution(fixture.workerId)
      .getState

  // Both directions of the worker<->controller control channel must be gone after termination.
  private def assertControlChannelsAreRemoved(fixture: SingleRegionFixture): Unit = {
    assert(
      !fixture.rpcProbe.inputGateway.getAllControlChannels.exists(
        _.channelId == ChannelIdentity(fixture.workerId, CONTROLLER, isControl = true)
      )
    )
    assert(
      !fixture.rpcProbe.outputGateway.getActiveChannels.exists(
        _ == ChannelIdentity(CONTROLLER, fixture.workerId, isControl = true)
      )
    )
  }

  private def transientEndWorkerFailure: ControlError =
    ControlError(
      errorMessage = "transient EndWorker failure",
      errorDetails = "",
      stackTrace = "",
      language = ErrorLanguage.SCALA
    )
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionPlanSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.{
GlobalPortIdentity,
PhysicalLink,
PhysicalOp,
PortIdentity
}
import org.scalatest.flatspec.AnyFlatSpec
class RegionPlanSpec extends AnyFlatSpec {

  // Physical-operator identity for logical operator `opId` on the "main" layer.
  private def physicalOpId(opId: String): PhysicalOpIdentity =
    PhysicalOpIdentity(OperatorIdentity(opId), "main")

  // A physical operator with empty exec-init info in workflow 0 / execution 0.
  private def op(opId: String): PhysicalOp =
    PhysicalOp(
      physicalOpId(opId),
      WorkflowIdentity(0),
      ExecutionIdentity(0),
      OpExecInitInfo.Empty
    )

  // A link from output port 0 of `fromOp` to input port 0 of `toOp`.
  private def link(fromOp: String, toOp: String): PhysicalLink =
    PhysicalLink(physicalOpId(fromOp), PortIdentity(0), physicalOpId(toOp), PortIdentity(0))

  // The global identity of input port 0 of `opId`.
  private def globalPort(opId: String): GlobalPortIdentity =
    GlobalPortIdentity(physicalOpId(opId), PortIdentity(0), input = true)

  "RegionPlan.getRegion" should "return the region with the given id" in {
    val r0 = Region(RegionIdentity(0), Set(op("a")), Set.empty)
    val r1 = Region(RegionIdentity(1), Set(op("b")), Set.empty)
    val plan = RegionPlan(Set(r0, r1), Set.empty)
    assert(plan.getRegion(RegionIdentity(0)) == r0)
    assert(plan.getRegion(RegionIdentity(1)) == r1)
  }

  it should "throw NoSuchElementException for an unknown region id" in {
    val plan = RegionPlan(Set(Region(RegionIdentity(0), Set(op("a")), Set.empty)), Set.empty)
    assertThrows[NoSuchElementException] {
      plan.getRegion(RegionIdentity(99))
    }
  }

  "RegionPlan.getRegionOfLink" should "return the region whose physicalLinks include the link" in {
    val ab = link("a", "b")
    val r0 = Region(RegionIdentity(0), Set(op("a"), op("b")), Set(ab))
    val r1 = Region(RegionIdentity(1), Set(op("c")), Set.empty)
    val plan = RegionPlan(Set(r0, r1), Set.empty)
    assert(plan.getRegionOfLink(ab) == r0)
  }

  it should "throw NoSuchElementException when no region claims the link" in {
    val r0 = Region(RegionIdentity(0), Set(op("a")), Set.empty)
    val plan = RegionPlan(Set(r0), Set.empty)
    assertThrows[NoSuchElementException] {
      plan.getRegionOfLink(link("a", "missing"))
    }
  }

  "RegionPlan.getRegionOfPortId" should "find the region whose ports contain the global port id" in {
    val portA = globalPort("a")
    val r0 = Region(RegionIdentity(0), Set(op("a")), Set.empty, ports = Set(portA))
    val r1 = Region(RegionIdentity(1), Set(op("b")), Set.empty)
    val plan = RegionPlan(Set(r0, r1), Set.empty)
    assert(plan.getRegionOfPortId(portA).contains(r0))
  }

  it should "return None when no region claims the port" in {
    val r0 = Region(RegionIdentity(0), Set(op("a")), Set.empty)
    val plan = RegionPlan(Set(r0), Set.empty)
    assert(plan.getRegionOfPortId(globalPort("missing")).isEmpty)
  }

  "RegionPlan.topologicalIterator" should "yield region ids in topological order based on regionLinks" in {
    val r0 = Region(RegionIdentity(0), Set(op("a")), Set.empty)
    val r1 = Region(RegionIdentity(1), Set(op("b")), Set.empty)
    val r2 = Region(RegionIdentity(2), Set(op("c")), Set.empty)
    val plan = RegionPlan(
      regions = Set(r0, r1, r2),
      regionLinks = Set(
        RegionLink(r0.id, r1.id),
        RegionLink(r1.id, r2.id)
      )
    )
    assert(plan.topologicalIterator().toList == List(r0.id, r1.id, r2.id))
  }

  // ---------------------------------------------------------------------------
  // Larger / more complex region plan exercises
  // ---------------------------------------------------------------------------

  /**
   * Build a "diamond" of regions:
   *
   * {{{
   *              src (0)
   *            /    |    \
   *      mid1 (1) mid2 (2) mid3 (3)
   *            \    |    /
   *             sink (4)
   * }}}
   *
   * src fans out to three middle regions, which are parallel siblings of each other (children of
   * src), and all three middle regions feed a single sink.
   *
   * Region contents:
   *  - src: 3 operators, 2 internal links, no ports;
   *  - mid1: 2 operators, 1 internal link, 1 input port;
   *  - mid2: 1 operator, no links, 1 input port;
   *  - mid3: 4 operators, 3 internal links, no ports;
   *  - sink: 2 operators, 1 internal link, 1 input port.
   */
  private def buildDiamondPlan(): RegionPlan = {
    val src = Region(
      RegionIdentity(0),
      physicalOps = Set(op("src-1"), op("src-2"), op("src-3")),
      physicalLinks = Set(link("src-1", "src-2"), link("src-2", "src-3"))
    )
    val mid1 = Region(
      RegionIdentity(1),
      physicalOps = Set(op("mid1-1"), op("mid1-2")),
      physicalLinks = Set(link("mid1-1", "mid1-2")),
      ports = Set(globalPort("mid1-1"))
    )
    val mid2 = Region(
      RegionIdentity(2),
      physicalOps = Set(op("mid2-1")),
      physicalLinks = Set.empty,
      ports = Set(globalPort("mid2-1"))
    )
    val mid3 = Region(
      RegionIdentity(3),
      physicalOps = Set(op("mid3-1"), op("mid3-2"), op("mid3-3"), op("mid3-4")),
      physicalLinks = Set(
        link("mid3-1", "mid3-2"),
        link("mid3-2", "mid3-3"),
        link("mid3-3", "mid3-4")
      )
    )
    val sink = Region(
      RegionIdentity(4),
      physicalOps = Set(op("sink-1"), op("sink-2")),
      physicalLinks = Set(link("sink-1", "sink-2")),
      ports = Set(globalPort("sink-1"))
    )
    RegionPlan(
      regions = Set(src, mid1, mid2, mid3, sink),
      regionLinks = Set(
        RegionLink(src.id, mid1.id),
        RegionLink(src.id, mid2.id),
        RegionLink(src.id, mid3.id),
        RegionLink(mid1.id, sink.id),
        RegionLink(mid2.id, sink.id),
        RegionLink(mid3.id, sink.id)
      )
    )
  }

  "RegionPlan (diamond fan-out / fan-in)" should "look up every region by id" in {
    val plan = buildDiamondPlan()
    val ids = (0L to 4L).map(RegionIdentity).toList
    ids.foreach(id => assert(plan.getRegion(id).id == id))
  }

  it should "find the region containing each physical link across multiple regions" in {
    val plan = buildDiamondPlan()
    // src has 2 internal links, mid1 has 1, mid3 has 3, sink has 1 → 7 internal links total.
    val internalLinks = Seq(
      ("src-1", "src-2", RegionIdentity(0)),
      ("src-2", "src-3", RegionIdentity(0)),
      ("mid1-1", "mid1-2", RegionIdentity(1)),
      ("mid3-1", "mid3-2", RegionIdentity(3)),
      ("mid3-2", "mid3-3", RegionIdentity(3)),
      ("mid3-3", "mid3-4", RegionIdentity(3)),
      ("sink-1", "sink-2", RegionIdentity(4))
    )
    internalLinks.foreach {
      case (from, to, expectedRegion) =>
        assert(plan.getRegionOfLink(link(from, to)).id == expectedRegion)
    }
  }

  it should "find each port-bearing region by its global port id" in {
    val plan = buildDiamondPlan()
    assert(plan.getRegionOfPortId(globalPort("mid1-1")).map(_.id).contains(RegionIdentity(1)))
    assert(plan.getRegionOfPortId(globalPort("mid2-1")).map(_.id).contains(RegionIdentity(2)))
    assert(plan.getRegionOfPortId(globalPort("sink-1")).map(_.id).contains(RegionIdentity(4)))
    // Unknown port → None.
    assert(plan.getRegionOfPortId(globalPort("not-in-plan")).isEmpty)
  }

  it should "produce a topological ordering with src first, sink last, and middles in the middle" in {
    val plan = buildDiamondPlan()
    val order = plan.topologicalIterator().toList
    assert(order.size == 5)
    assert(order.head == RegionIdentity(0), "src must come first")
    assert(order.last == RegionIdentity(4), "sink must come last")
    // The three middle regions are unordered relative to each other.
    assert(order.slice(1, 4).toSet == Set(RegionIdentity(1), RegionIdentity(2), RegionIdentity(3)))
  }

  "RegionPlan.topologicalIterator" should
    "respect a wide DAG with multiple parallel branches and joins" in {
    // Construct a 9-region DAG with these edges:
    //
    //   0       → 1, 2, 3
    //   1       → 4, 5
    //   2       → 5, 6
    //   3       → 6
    //   4, 5, 6 → 7
    //   7       → 8
    //
    // 0 is the only source and 8 the only sink. Regions 5 and 6 each join two branches and
    // region 7 joins three, which makes the test more meaningful than a linked list.
    val ids = (0L to 8L).map(RegionIdentity)
    val regs = ids.map(rid => Region(rid, Set(op(s"r${rid.id}-x")), Set.empty)).toSet
    val edges = Set(
      RegionLink(ids(0), ids(1)),
      RegionLink(ids(0), ids(2)),
      RegionLink(ids(0), ids(3)),
      RegionLink(ids(1), ids(4)),
      RegionLink(ids(1), ids(5)),
      RegionLink(ids(2), ids(5)),
      RegionLink(ids(2), ids(6)),
      RegionLink(ids(3), ids(6)),
      RegionLink(ids(4), ids(7)),
      RegionLink(ids(5), ids(7)),
      RegionLink(ids(6), ids(7)),
      RegionLink(ids(7), ids(8))
    )
    val plan = RegionPlan(regs, edges)
    val order = plan.topologicalIterator().toList
    val pos = order.zipWithIndex.toMap
    // Every edge must point forward in the produced order.
    edges.foreach { e =>
      assert(
        pos(e.fromRegionId) < pos(e.toRegionId),
        s"topological order violates edge $e: " +
          s"${e.fromRegionId}@${pos(e.fromRegionId)} should come before " +
          s"${e.toRegionId}@${pos(e.toRegionId)}"
      )
    }
    assert(order.head == ids(0))
    assert(order.last == ids(8))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.{
GlobalPortIdentity,
PhysicalLink,
PhysicalOp,
PortIdentity
}
import org.scalatest.flatspec.AnyFlatSpec
class RegionSpec extends AnyFlatSpec {

  /** Physical-operator identity for logical operator `name` on the "main" layer. */
  private def opIdentity(name: String): PhysicalOpIdentity =
    PhysicalOpIdentity(OperatorIdentity(name), "main")

  /** A physical operator with empty exec-init info in workflow 0 / execution 0. */
  private def physOp(name: String): PhysicalOp =
    PhysicalOp(opIdentity(name), WorkflowIdentity(0), ExecutionIdentity(0), OpExecInitInfo.Empty)

  /** A link from output port 0 of `from` to input port 0 of `to`. */
  private def linkBetween(from: String, to: String): PhysicalLink =
    PhysicalLink(opIdentity(from), PortIdentity(0), opIdentity(to), PortIdentity(0))

  "Region" should "expose the physical operators provided at construction" in {
    val operators = Set(physOp("a"), physOp("b"))
    val region = Region(RegionIdentity(1), operators, Set.empty)
    assert(region.getOperators == operators)
  }

  it should "expose the physical links provided at construction" in {
    val abLink = linkBetween("a", "b")
    val region = Region(RegionIdentity(1), Set(physOp("a"), physOp("b")), Set(abLink))
    assert(region.getLinks == Set(abLink))
  }

  it should "default ports to an empty set" in {
    val portless = Region(RegionIdentity(1), Set(physOp("a")), Set.empty)
    assert(portless.getPorts.isEmpty)
  }

  it should "expose the ports provided at construction" in {
    val inputPort = GlobalPortIdentity(opIdentity("a"), PortIdentity(0), input = true)
    val region = Region(RegionIdentity(1), Set(physOp("a")), Set.empty, ports = Set(inputPort))
    assert(region.getPorts == Set(inputPort))
  }

  "Region.getOperator" should "look up a physical operator by id" in {
    val opA = physOp("a")
    val opB = physOp("b")
    val region = Region(RegionIdentity(1), Set(opA, opB), Set.empty)
    Seq("a" -> opA, "b" -> opB).foreach {
      case (name, expected) => assert(region.getOperator(opIdentity(name)) == expected)
    }
  }

  it should "throw NoSuchElementException for an unknown operator id" in {
    val region = Region(RegionIdentity(1), Set(physOp("a")), Set.empty)
    assertThrows[NoSuchElementException](region.getOperator(opIdentity("missing")))
  }

  "Region.topologicalIterator" should "yield operators in topological order based on physical links" in {
    val chain = List("a", "b", "c")
    val region = Region(
      RegionIdentity(1),
      Set(physOp("a"), physOp("b"), physOp("c")),
      Set(linkBetween("a", "b"), linkBetween("b", "c"))
    )
    assert(region.topologicalIterator().toList == chain.map(opIdentity))
  }

  "Region.getSourceOperators" should "treat operators without input ports as sources" in {
    val sources = Set(physOp("a"), physOp("b"))
    val region = Region(RegionIdentity(1), sources, Set.empty)
    assert(region.getSourceOperators == sources)
  }

  "Region.getStarterOperators" should "match getSourceOperators when no resource config is provided" in {
    val unconfigured = Region(RegionIdentity(1), Set(physOp("a"), physOp("b")), Set.empty)
    assert(unconfigured.getStarterOperators == unconfigured.getSourceOperators)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/ScheduleSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.scalatest.flatspec.AnyFlatSpec
class ScheduleSpec extends AnyFlatSpec {

  /** A region containing exactly one "main"-layer physical operator for logical operator `opId`. */
  private def region(regionId: Long, opId: String): Region =
    Region(
      RegionIdentity(regionId),
      Set(
        PhysicalOp(
          PhysicalOpIdentity(OperatorIdentity(opId), "main"),
          WorkflowIdentity(0),
          ExecutionIdentity(0),
          OpExecInitInfo.Empty
        )
      ),
      Set.empty
    )

  "Schedule.getRegions" should "return all regions across all levels" in {
    val regionA = region(0, "a")
    val regionB = region(1, "b")
    val regionC = region(2, "c")
    val levels = Map(0 -> Set(regionA), 1 -> Set(regionB, regionC))
    assert(Schedule(levels).getRegions.toSet == Set(regionA, regionB, regionC))
  }

  it should "return an empty list when the schedule is empty" in {
    val emptySchedule = Schedule(Map.empty)
    assert(emptySchedule.getRegions.isEmpty)
  }

  "Schedule" should "iterate level sets in ascending key order starting from the minimum level" in {
    val first = region(0, "a")
    val second = region(1, "b")
    val third = region(2, "c")
    // Levels are supplied out of key order on purpose.
    val schedule = Schedule(Map(1 -> Set(second), 0 -> Set(first), 2 -> Set(third)))
    val expectedLevels = List(first, second, third).map(r => Set(r))
    assert(schedule.toList == expectedLevels)
  }

  it should "report hasNext as false for an empty schedule" in {
    val emptySchedule = Schedule(Map.empty)
    assert(!emptySchedule.hasNext)
  }

  it should "report hasNext as false after the last contiguous level is consumed" in {
    val twoLevels = Schedule(Map(0 -> Set(region(0, "a")), 1 -> Set(region(1, "b"))))
    for (_ <- 1 to 2) twoLevels.next()
    assert(!twoLevels.hasNext)
  }

  it should "reject construction when level keys contain a gap" in {
    // Level 1 is missing.
    assertThrows[IllegalArgumentException](
      Schedule(Map(0 -> Set(region(0, "a")), 2 -> Set(region(2, "b"))))
    )
  }

  it should "reject construction when level keys do not start at zero" in {
    // Levels start at 3 instead of 0.
    assertThrows[IllegalArgumentException](
      Schedule(Map(3 -> Set(region(3, "a")), 4 -> Set(region(4, "b"))))
    )
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/SchedulingUtilsSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.jgrapht.graph.DirectedAcyclicGraph
import org.scalatest.flatspec.AnyFlatSpec
import scala.jdk.CollectionConverters.CollectionHasAsScala
/**
 * Spec for [[SchedulingUtils.replaceVertex]] on a JGraphT DAG of regions.
 *
 * Note: in these tests the replaced and replacing regions share the same region id (e.g.
 * `region(1, "a")` and `region(1, "a-prime")`), so a `RegionLink` built from either vertex's id
 * compares equal. Edge attachment to the new vertex is therefore verified through
 * `incomingEdgesOf`/`outgoingEdgesOf`/`getEdgeSource`/`getEdgeTarget`, not only via link equality.
 */
class SchedulingUtilsSpec extends AnyFlatSpec {

  /** Builds a region holding a single "main" physical operator for logical operator `opId`. */
  private def region(regionId: Long, opId: String): Region = {
    val physicalOp = PhysicalOp(
      PhysicalOpIdentity(OperatorIdentity(opId), "main"),
      WorkflowIdentity(0),
      ExecutionIdentity(0),
      OpExecInitInfo.Empty
    )
    Region(RegionIdentity(regionId), Set(physicalOp), Set.empty)
  }

  /** Creates an empty region DAG whose edges are `RegionLink`s. */
  private def newGraph(): DirectedAcyclicGraph[Region, RegionLink] =
    new DirectedAcyclicGraph[Region, RegionLink](classOf[RegionLink])

  "SchedulingUtils.replaceVertex" should "replace an isolated vertex with no incident edges" in {
    val graph = newGraph()
    val oldVertex = region(1, "a")
    val newVertex = region(1, "a-prime")
    graph.addVertex(oldVertex)
    SchedulingUtils.replaceVertex(graph, oldVertex, newVertex)
    assert(!graph.containsVertex(oldVertex))
    assert(graph.containsVertex(newVertex))
    assert(graph.edgeSet().isEmpty)
  }

  it should "rewrite outgoing edges to originate from the new vertex" in {
    val graph = newGraph()
    val oldVertex = region(1, "a")
    val downstream = region(2, "b")
    val newVertex = region(1, "a-prime")
    graph.addVertex(oldVertex)
    graph.addVertex(downstream)
    graph.addEdge(oldVertex, downstream, RegionLink(oldVertex.id, downstream.id))
    SchedulingUtils.replaceVertex(graph, oldVertex, newVertex)
    assert(!graph.containsVertex(oldVertex))
    assert(graph.containsVertex(newVertex))
    val outgoing = graph.outgoingEdgesOf(newVertex).asScala.toList
    assert(outgoing.size == 1)
    assert(graph.getEdgeTarget(outgoing.head) == downstream)
    assert(outgoing.head == RegionLink(newVertex.id, downstream.id))
  }

  it should "rewrite incoming edges to terminate at the new vertex" in {
    val graph = newGraph()
    val upstream = region(0, "u")
    val oldVertex = region(1, "a")
    val newVertex = region(1, "a-prime")
    graph.addVertex(upstream)
    graph.addVertex(oldVertex)
    graph.addEdge(upstream, oldVertex, RegionLink(upstream.id, oldVertex.id))
    SchedulingUtils.replaceVertex(graph, oldVertex, newVertex)
    assert(!graph.containsVertex(oldVertex))
    // Checked explicitly (as in the sibling tests) so a missing vertex fails with a clear
    // assertion instead of an exception from incomingEdgesOf.
    assert(graph.containsVertex(newVertex))
    val incoming = graph.incomingEdgesOf(newVertex).asScala.toList
    assert(incoming.size == 1)
    assert(graph.getEdgeSource(incoming.head) == upstream)
    assert(incoming.head == RegionLink(upstream.id, newVertex.id))
  }

  it should "preserve both upstream and downstream edges in a chain" in {
    val graph = newGraph()
    val upstream = region(0, "u")
    val oldVertex = region(1, "a")
    val downstream = region(2, "d")
    val newVertex = region(1, "a-prime")
    graph.addVertex(upstream)
    graph.addVertex(oldVertex)
    graph.addVertex(downstream)
    graph.addEdge(upstream, oldVertex, RegionLink(upstream.id, oldVertex.id))
    graph.addEdge(oldVertex, downstream, RegionLink(oldVertex.id, downstream.id))
    SchedulingUtils.replaceVertex(graph, oldVertex, newVertex)
    assert(graph.vertexSet().asScala.toSet == Set(upstream, newVertex, downstream))
    assert(
      graph.edgeSet().asScala.toSet ==
        Set(
          RegionLink(upstream.id, newVertex.id),
          RegionLink(newVertex.id, downstream.id)
        )
    )
  }

  it should "leave the graph unchanged when old and new vertices are equal" in {
    val graph = newGraph()
    val upstream = region(0, "u")
    val vertex = region(1, "a")
    val downstream = region(2, "d")
    graph.addVertex(upstream)
    graph.addVertex(vertex)
    graph.addVertex(downstream)
    graph.addEdge(upstream, vertex, RegionLink(upstream.id, vertex.id))
    graph.addEdge(vertex, downstream, RegionLink(vertex.id, downstream.id))
    SchedulingUtils.replaceVertex(graph, vertex, vertex)
    assert(graph.vertexSet().asScala.toSet == Set(upstream, vertex, downstream))
    // Compare the exact edge set rather than only its size: a size check would also accept
    // edges that were rewired or reversed by the self-replacement.
    assert(
      graph.edgeSet().asScala.toSet ==
        Set(
          RegionLink(upstream.id, vertex.id),
          RegionLink(vertex.id, downstream.id)
        )
    )
    assert(graph.incomingEdgesOf(vertex).asScala.map(graph.getEdgeSource).toSet == Set(upstream))
    assert(
      graph.outgoingEdgesOf(vertex).asScala.map(graph.getEdgeTarget).toSet == Set(downstream)
    )
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/WorkflowExecutionCoordinatorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.testkit.TestKit
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.apache.texera.amber.engine.architecture.controller.ControllerConfig
import org.apache.texera.amber.engine.architecture.controller.execution.WorkflowExecution
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.EmptyReturn
import org.apache.texera.amber.engine.architecture.scheduling.RegionCoordinatorTestSupport._
import org.apache.texera.amber.engine.common.AmberRuntime
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpecLike
/**
 * Tests for [[WorkflowExecutionCoordinator]].
 *
 * Two areas are covered:
 *  - sequencing: the second region is not initialized until termination of the first region's
 *    worker has completed;
 *  - jump-to-operator-region: replacing `coordinator.schedule` with a copy whose cursor points at
 *    the level that contains a given operator, then observing which regions are pulled next.
 *
 * Runs inside a Pekko TestKit actor system built from `AmberRuntime.pekkoConfig`. Harness helpers
 * (probes, seeded executions, waiting utilities) come from RegionCoordinatorTestSupport: its
 * companion is wildcard-imported and the trait is mixed in.
 */
class WorkflowExecutionCoordinatorSpec
    extends TestKit(ActorSystem("WorkflowExecutionCoordinatorSpec", AmberRuntime.pekkoConfig))
    with AnyFlatSpecLike
    with BeforeAndAfterAll
    with RegionCoordinatorTestSupport {

  // Shut down the TestKit actor system once all tests in this suite have run.
  override def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }

  // -- Helpers used only by the jump-to-operator-region tests --

  /** Builds a region containing a single "main" physical operator for logical operator `opId`. */
  private def jumpRegion(regionId: Long, opId: String): Region = {
    val physicalOp = PhysicalOp(
      PhysicalOpIdentity(OperatorIdentity(opId), "main"),
      WorkflowIdentity(0),
      ExecutionIdentity(0),
      OpExecInitInfo.Empty
    )
    Region(RegionIdentity(regionId), Set(physicalOp), Set.empty)
  }

  /**
   * Builds a schedule with one region per level: level 0 holds operator "first", level 1
   * "second", and level 2 "third".
   *
   * @return the three regions, in level order, followed by the schedule
   */
  private def threeLevelSchedule(): (Region, Region, Region, Schedule) = {
    val first = jumpRegion(1, "first")
    val second = jumpRegion(2, "second")
    val third = jumpRegion(3, "third")
    val schedule = Schedule(
      Map(
        0 -> Set(first),
        1 -> Set(second),
        2 -> Set(third)
      )
    )
    (first, second, third, schedule)
  }

  /**
   * Creates a coordinator over `schedule` for the jump tests. The controller config and RPC
   * client are passed as null: these tests only read and replace `coordinator.schedule` and
   * never execute a region.
   */
  private def newJumpCoordinator(schedule: Schedule): WorkflowExecutionCoordinator = {
    val coordinator = new WorkflowExecutionCoordinator(WorkflowExecution(), null, null)
    coordinator.schedule = schedule
    coordinator
  }

  /**
   * Pulls the next level's region set from the coordinator's current schedule, or returns an
   * empty set once the schedule is exhausted. Reads `coordinator.schedule` on every call, so
   * replacements made by `jumpTo` take effect on the next pull.
   */
  private def nextRegions(coordinator: WorkflowExecutionCoordinator): Set[Region] = {
    val schedule = coordinator.schedule
    if (schedule.hasNext) schedule.next() else Set.empty
  }

  // Mirrors what JumpToOperatorRegionHandler does: read the current schedule, scan for the
  // level containing the target operator, and replace the schedule with a copy whose cursor is
  // at that level.
  // If no level contains the operator, `collectFirst` yields None and the schedule is left as is.
  // NOTE(review): `collectFirst` follows `levelSets`' iteration order; if an operator could ever
  // appear in more than one level, which level wins depends on that order — confirm against the
  // real handler.
  private def jumpTo(coordinator: WorkflowExecutionCoordinator, opName: String): Unit = {
    val opId = OperatorIdentity(opName)
    val schedule = coordinator.schedule
    schedule.levelSets
      .collectFirst {
        case (level, regions) if regions.exists(_.getOperators.exists(_.id.logicalOpId == opId)) =>
          level
      }
      .foreach { targetLevel =>
        coordinator.schedule = schedule.copy(initialLevelIndex = targetLevel)
      }
  }

  "WorkflowExecutionCoordinator" should
    "start the next region only after previous region termination succeeds" in {
    val firstOp = createSourceOp("first-op")
    val firstWorkerId = createWorkerId(firstOp)
    val firstRegion = createSingleWorkerRegion(1, firstOp, firstWorkerId)
    val secondOp = createSourceOp("second-op")
    val secondWorkerId = createWorkerId(secondOp)
    val secondRegion = createSingleWorkerRegion(2, secondOp, secondWorkerId)
    val workflowExecution = WorkflowExecution()
    // Seeded under region ids (101, 102) distinct from the scheduled regions (1, 2); presumably
    // this pre-populates worker execution state so the scheduled regions can reuse these
    // workers — see RegionCoordinatorTestSupport to confirm.
    seedReusableWorkerExecution(workflowExecution, seedRegionId = 101, firstOp, firstWorkerId)
    seedReusableWorkerExecution(workflowExecution, seedRegionId = 102, secondOp, secondWorkerId)
    // First region's worker holds endWorker pending until we explicitly fulfill it; the second
    // region's worker terminates immediately. This lets us assert the second region cannot start
    // until termination of the first finishes.
    val rpcProbe = new ControllerRpcProbe(
      endWorkerResponse = call => if (call.receiver == firstWorkerId) None else Some(EmptyReturn())
    )
    val controller = createControllerHarness()
    registerLiveWorker(controller.actorRefService, firstWorkerId)
    registerLiveWorker(controller.actorRefService, secondWorkerId)
    val workflowCoordinator = new WorkflowExecutionCoordinator(
      workflowExecution,
      ControllerConfig(None, None, None, None),
      rpcProbe.asyncRPCClient
    )
    workflowCoordinator.schedule = Schedule(Map(0 -> Set(firstRegion), 1 -> Set(secondRegion)))
    workflowCoordinator.setupActorRefService(controller.actorRefService)
    // First coordination pass: only the first region's worker is started.
    await(workflowCoordinator.coordinateRegionExecutors(controller.actorService))
    assert(rpcProbe.startedWorkers == Seq(firstWorkerId))
    // Second pass: once the first worker's endWorker call is issued (and left pending), the pass
    // must still be incomplete, the second worker must not be initialized yet, and the first
    // worker's actor ref must still be registered.
    val coordination = workflowCoordinator.coordinateRegionExecutors(controller.actorService)
    waitUntil(rpcProbe.endWorkerCalls.size == 1)
    assert(coordination.poll.isEmpty)
    assert(!rpcProbe.initializedWorkers.contains(secondWorkerId))
    assert(controller.actorRefService.hasActorRef(firstWorkerId))
    // Release the pending endWorker: the pass completes, the first worker's actor ref is gone,
    // and the second region's worker is initialized and started.
    rpcProbe.fulfill(rpcProbe.onlyEndWorkerCall, EmptyReturn())
    await(coordination)
    assert(!controller.actorRefService.hasActorRef(firstWorkerId))
    assert(rpcProbe.initializedWorkers.contains(secondWorkerId))
    assert(rpcProbe.startedWorkers.contains(secondWorkerId))
  }

  "Jumping to an operator's region" should
    "make the next scheduled region contain the target operator's region" in {
    val (first, second, _, schedule) = threeLevelSchedule()
    val coordinator = newJumpCoordinator(schedule)
    assert(nextRegions(coordinator) == Set(first))
    assert(nextRegions(coordinator) == Set(second))
    // Jump backwards: the next pull replays level 0.
    jumpTo(coordinator, "first")
    assert(nextRegions(coordinator) == Set(first))
  }

  it should "support multiple sequential jumps interleaved with region pulls" in {
    val (first, second, third, schedule) = threeLevelSchedule()
    val coordinator = newJumpCoordinator(schedule)
    assert(nextRegions(coordinator) == Set(first))
    assert(nextRegions(coordinator) == Set(second))
    jumpTo(coordinator, "first")
    assert(nextRegions(coordinator) == Set(first))
    jumpTo(coordinator, "second")
    assert(nextRegions(coordinator) == Set(second))
    assert(nextRegions(coordinator) == Set(third))
    jumpTo(coordinator, "first")
    assert(nextRegions(coordinator) == Set(first))
  }

  it should "be a no-op when the target operator is not in any scheduled region" in {
    val (first, second, _, schedule) = threeLevelSchedule()
    val coordinator = newJumpCoordinator(schedule)
    assert(nextRegions(coordinator) == Set(first))
    jumpTo(coordinator, "does-not-exist")
    // Iteration position must be unaffected by an unknown target.
    assert(nextRegions(coordinator) == Set(second))
  }

  it should "leave the schedule untouched when called repeatedly with unknown operators" in {
    val (first, second, third, schedule) = threeLevelSchedule()
    val coordinator = newJumpCoordinator(schedule)
    jumpTo(coordinator, "ghost-1")
    jumpTo(coordinator, "ghost-2")
    jumpTo(coordinator, "ghost-3")
    // The full schedule is still pulled from level 0 in order.
    assert(nextRegions(coordinator) == Set(first))
    assert(nextRegions(coordinator) == Set(second))
    assert(nextRegions(coordinator) == Set(third))
  }

  it should "allow jumping back to the first region after the schedule is exhausted" in {
    val (first, second, third, schedule) = threeLevelSchedule()
    val coordinator = newJumpCoordinator(schedule)
    assert(nextRegions(coordinator) == Set(first))
    assert(nextRegions(coordinator) == Set(second))
    assert(nextRegions(coordinator) == Set(third))
    // Exhausted: nextRegions reports an empty set.
    assert(nextRegions(coordinator) == Set.empty)
    jumpTo(coordinator, "first")
    assert(nextRegions(coordinator) == Set(first))
  }

  it should "support jumping forward past regions that have not yet been pulled" in {
    val (first, _, third, schedule) = threeLevelSchedule()
    val coordinator = newJumpCoordinator(schedule)
    assert(nextRegions(coordinator) == Set(first))
    // Skip level 1 entirely.
    jumpTo(coordinator, "third")
    assert(nextRegions(coordinator) == Set(third))
    assert(nextRegions(coordinator) == Set.empty)
  }

  it should "replay the target-onward range each time it jumps back" in {
    // Schedule ABCDEF: jumping from E back to C yields the visible sequence ABCDECDEF; jumping
    // again from E back to C yields ABCDECDECDEF.
    val a = jumpRegion(1, "a")
    val b = jumpRegion(2, "b")
    val c = jumpRegion(3, "c")
    val d = jumpRegion(4, "d")
    val e = jumpRegion(5, "e")
    val f = jumpRegion(6, "f")
    val schedule = Schedule(
      Map(0 -> Set(a), 1 -> Set(b), 2 -> Set(c), 3 -> Set(d), 4 -> Set(e), 5 -> Set(f))
    )
    val coordinator = newJumpCoordinator(schedule)
    Seq(a, b, c, d, e).foreach { region =>
      assert(nextRegions(coordinator) == Set(region))
    }
    jumpTo(coordinator, "c")
    Seq(c, d, e).foreach { region =>
      assert(nextRegions(coordinator) == Set(region))
    }
    jumpTo(coordinator, "c")
    Seq(c, d, e, f).foreach { region =>
      assert(nextRegions(coordinator) == Set(region))
    }
    assert(nextRegions(coordinator) == Set.empty)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/config/ChannelConfigSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling.config
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.core.workflow.{
BroadcastPartition,
HashPartition,
OneToOnePartition,
PortIdentity,
RangePartition,
SinglePartition,
UnknownPartition
}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
/**
 * Spec for [[ChannelConfig.generateChannelConfigs]], grouped by the partitioning arm that decides
 * how sender workers are wired to receiver workers.
 */
class ChannelConfigSpec extends AnyFlatSpec with Matchers {

  // Default destination port; individual tests may supply a different one.
  private val port: PortIdentity = PortIdentity(id = 0, internal = false)

  private def actor(name: String): ActorVirtualIdentity = ActorVirtualIdentity(name)

  // Sender (upstream) workers.
  private val w1 = actor("w1")
  private val w2 = actor("w2")
  private val w3 = actor("w3")

  // Receiver (downstream) workers.
  private val u1 = actor("u1")
  private val u2 = actor("u2")
  private val u3 = actor("u3")

  /** Projects each config onto its (sender name, receiver name) pair, keeping the input order. */
  private def senderReceiverPairs(configs: List[ChannelConfig]): List[(String, String)] =
    configs.map { config =>
      val channel = config.channelId
      channel.fromWorkerId.name -> channel.toWorkerId.name
    }

  // ----- arms that emit the full sender x receiver cross product -----

  "generateChannelConfigs" should "produce the full from*to cross product for HashPartition" in {
    val configs =
      ChannelConfig.generateChannelConfigs(List(w1, w2), List(u1, u2, u3), port, HashPartition())
    senderReceiverPairs(configs) shouldBe List(
      "w1" -> "u1",
      "w1" -> "u2",
      "w1" -> "u3",
      "w2" -> "u1",
      "w2" -> "u2",
      "w2" -> "u3"
    )
    configs.foreach { config =>
      config.channelId.isControl shouldBe false
      config.toPortId shouldBe port
    }
  }

  it should "produce the full from*to cross product for BroadcastPartition" in {
    val configs =
      ChannelConfig.generateChannelConfigs(List(w1, w2), List(u1, u2), port, BroadcastPartition())
    senderReceiverPairs(configs) shouldBe
      List("w1" -> "u1", "w1" -> "u2", "w2" -> "u1", "w2" -> "u2")
  }

  it should "produce the full from*to cross product for RangePartition" in {
    val range = RangePartition(List("k"), 0L, 100L)
    val configs = ChannelConfig.generateChannelConfigs(List(w1), List(u1, u2), port, range)
    senderReceiverPairs(configs) shouldBe List("w1" -> "u1", "w1" -> "u2")
  }

  it should "produce the full from*to cross product for UnknownPartition" in {
    val configs =
      ChannelConfig.generateChannelConfigs(List(w1, w2), List(u1), port, UnknownPartition())
    senderReceiverPairs(configs) shouldBe List("w1" -> "u1", "w2" -> "u1")
  }

  // ----- SinglePartition: every sender fans in to the one receiver -----

  "SinglePartition" should "produce one channel per from-worker to the single to-worker" in {
    val configs =
      ChannelConfig.generateChannelConfigs(List(w1, w2, w3), List(u1), port, SinglePartition())
    senderReceiverPairs(configs) shouldBe List("w1" -> "u1", "w2" -> "u1", "w3" -> "u1")
  }

  it should "raise an AssertionError when more than one to-worker is supplied" in {
    // SinglePartition is only meaningful when collapsing onto exactly one receiver; the source
    // asserts that precondition, so two receivers must trip it.
    assertThrows[AssertionError](
      ChannelConfig.generateChannelConfigs(List(w1, w2), List(u1, u2), port, SinglePartition())
    )
  }

  // ----- OneToOnePartition: senders and receivers are paired positionally -----

  "OneToOnePartition" should "zip equal-length from and to lists pairwise" in {
    val configs = ChannelConfig.generateChannelConfigs(
      List(w1, w2, w3),
      List(u1, u2, u3),
      port,
      OneToOnePartition()
    )
    senderReceiverPairs(configs) shouldBe List("w1" -> "u1", "w2" -> "u2", "w3" -> "u3")
  }

  it should "truncate to the shorter list when from and to lengths differ (current behavior)" in {
    // Pins today's contract: List.zip drops the unmatched tail of the longer side instead of
    // failing. Callers are expected to pass equal lengths. If the generator is tightened to
    // reject mismatched lengths, this test should break deliberately so the contract change is
    // reviewed.
    val moreSenders = ChannelConfig.generateChannelConfigs(
      List(w1, w2, w3),
      List(u1, u2),
      port,
      OneToOnePartition()
    )
    senderReceiverPairs(moreSenders) shouldBe List("w1" -> "u1", "w2" -> "u2")

    val moreReceivers = ChannelConfig.generateChannelConfigs(
      List(w1),
      List(u1, u2, u3),
      port,
      OneToOnePartition()
    )
    senderReceiverPairs(moreReceivers) shouldBe List("w1" -> "u1")
  }

  // ----- empty inputs -----
  // The subject is reset to `generateChannelConfigs` here. With `it should`, the empty-input
  // cases below (which also exercise the Hash arm) and the toPortId check would be reported
  // under the preceding "OneToOnePartition" subject.

  "generateChannelConfigs" should "return an empty list when fromWorkerIds is empty (cross-product arm)" in {
    val configs = ChannelConfig.generateChannelConfigs(Nil, List(u1, u2), port, HashPartition())
    configs shouldBe empty
  }

  it should "return an empty list when toWorkerIds is empty (cross-product arm)" in {
    val configs = ChannelConfig.generateChannelConfigs(List(w1, w2), Nil, port, HashPartition())
    configs shouldBe empty
  }

  it should "return an empty list when both inputs are empty (OneToOne)" in {
    val configs = ChannelConfig.generateChannelConfigs(Nil, Nil, port, OneToOnePartition())
    configs shouldBe empty
  }

  // ----- destination port propagation -----

  it should "propagate the same toPortId onto every produced ChannelConfig" in {
    val internalPort = PortIdentity(id = 7, internal = true)
    val configs = ChannelConfig.generateChannelConfigs(
      List(w1, w2),
      List(u1, u2),
      internalPort,
      BroadcastPartition()
    )
    configs.foreach { config =>
      config.toPortId shouldBe internalPort
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/config/LinkConfigSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling.config
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.core.workflow.{
BroadcastPartition,
HashPartition,
OneToOnePartition,
RangePartition,
SinglePartition,
UnknownPartition
}
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.{
BroadcastPartitioning,
HashBasedShufflePartitioning,
OneToOnePartitioning,
RangeBasedShufflePartitioning,
RoundRobinPartitioning
}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
/**
 * Spec for [[LinkConfig.toPartitioning]]: which `Partitioning` each `PartitionInfo` arm maps to,
 * which sender/receiver channels it carries, and that the batch size is forwarded unchanged.
 */
class LinkConfigSpec extends AnyFlatSpec with Matchers {
  // Sender (upstream) workers.
  private val w1 = ActorVirtualIdentity("w1")
  private val w2 = ActorVirtualIdentity("w2")
  private val w3 = ActorVirtualIdentity("w3")
  // Receiver (downstream) workers.
  private val u1 = ActorVirtualIdentity("u1")
  private val u2 = ActorVirtualIdentity("u2")
  private val u3 = ActorVirtualIdentity("u3")
  private val batch = 64

  /** Projects channels onto (sender name, receiver name) pairs, preserving order. */
  private def endpoints(channels: Seq[ChannelIdentity]): Seq[(String, String)] =
    channels.map(c => (c.fromWorkerId.name, c.toWorkerId.name))

  // ----- HashPartition -----
  "toPartitioning" should "produce a HashBasedShufflePartitioning with full cross product channels" in {
    val out = LinkConfig.toPartitioning(
      List(w1, w2),
      List(u1, u2, u3),
      HashPartition(List("k1", "k2")),
      batch
    )
    out shouldBe a[HashBasedShufflePartitioning]
    val hp = out.asInstanceOf[HashBasedShufflePartitioning]
    hp.batchSize shouldBe batch
    hp.hashAttributeNames shouldBe Seq("k1", "k2")
    endpoints(hp.channels) shouldBe Seq(
      ("w1", "u1"),
      ("w1", "u2"),
      ("w1", "u3"),
      ("w2", "u1"),
      ("w2", "u2"),
      ("w2", "u3")
    )
    hp.channels.foreach(_.isControl shouldBe false)
  }

  // ----- RangePartition -----
  "RangePartition" should "produce a RangeBasedShufflePartitioning carrying the range bounds and cross-product channels" in {
    val out = LinkConfig.toPartitioning(
      List(w1),
      List(u1, u2),
      RangePartition(List("k"), 0L, 100L),
      batch
    )
    out shouldBe a[RangeBasedShufflePartitioning]
    val rp = out.asInstanceOf[RangeBasedShufflePartitioning]
    rp.batchSize shouldBe batch
    rp.rangeAttributeNames shouldBe Seq("k")
    rp.rangeMin shouldBe 0L
    rp.rangeMax shouldBe 100L
    endpoints(rp.channels) shouldBe Seq(("w1", "u1"), ("w1", "u2"))
  }

  // ----- SinglePartition -----
  "SinglePartition" should "produce a OneToOnePartitioning with one channel per from-worker to the single to-worker" in {
    val out = LinkConfig.toPartitioning(
      List(w1, w2, w3),
      List(u1),
      SinglePartition(),
      batch
    )
    out shouldBe a[OneToOnePartitioning]
    val op = out.asInstanceOf[OneToOnePartitioning]
    op.batchSize shouldBe batch
    endpoints(op.channels) shouldBe Seq(("w1", "u1"), ("w2", "u1"), ("w3", "u1"))
  }

  it should "raise an AssertionError when more than one to-worker is supplied" in {
    assertThrows[AssertionError] {
      LinkConfig.toPartitioning(List(w1, w2), List(u1, u2), SinglePartition(), batch)
    }
  }

  // ----- OneToOnePartition -----
  "OneToOnePartition" should "produce a OneToOnePartitioning with zip pairing for equal-length inputs" in {
    val out = LinkConfig.toPartitioning(
      List(w1, w2, w3),
      List(u1, u2, u3),
      OneToOnePartition(),
      batch
    )
    out shouldBe a[OneToOnePartitioning]
    val op = out.asInstanceOf[OneToOnePartitioning]
    endpoints(op.channels) shouldBe Seq(("w1", "u1"), ("w2", "u2"), ("w3", "u3"))
  }

  it should "silently truncate when from and to lengths differ (current behavior)" in {
    // Pin: same `List.zip` truncation hazard as ChannelConfig (Bug #4799).
    // Documenting the parallel here so a fix that aligns the two helpers
    // surfaces this spec at the same time.
    val out = LinkConfig.toPartitioning(
      List(w1, w2, w3),
      List(u1, u2),
      OneToOnePartition(),
      batch
    )
    val op = out.asInstanceOf[OneToOnePartitioning]
    endpoints(op.channels) shouldBe Seq(("w1", "u1"), ("w2", "u2"))
  }

  // ----- BroadcastPartition -----
  "BroadcastPartition" should "produce a BroadcastPartitioning whose channels follow zip pairing today (current behavior)" in {
    // Pin: BroadcastPartition currently uses `fromWorkerIds.zip(toWorkerIds)`
    // — the SAME 1:1 pairing as OneToOnePartition. ChannelConfig in the same
    // package emits a full cross product for the BroadcastPartition arm,
    // which matches broadcast semantics ("each sender targets every
    // receiver"). The two helpers diverge today; pinning this so a fix that
    // realigns the contract surfaces here. Filed as a Bug.
    val out = LinkConfig.toPartitioning(
      List(w1, w2, w3),
      List(u1, u2, u3),
      BroadcastPartition(),
      batch
    )
    out shouldBe a[BroadcastPartitioning]
    val bp = out.asInstanceOf[BroadcastPartitioning]
    bp.batchSize shouldBe batch
    endpoints(bp.channels) shouldBe Seq(("w1", "u1"), ("w2", "u2"), ("w3", "u3"))
  }

  it should "silently truncate broadcast pairings when sides differ in length (current behavior)" in {
    val out = LinkConfig.toPartitioning(
      List(w1, w2, w3),
      List(u1, u2),
      BroadcastPartition(),
      batch
    )
    val bp = out.asInstanceOf[BroadcastPartitioning]
    endpoints(bp.channels) shouldBe Seq(("w1", "u1"), ("w2", "u2"))
  }

  // ----- UnknownPartition -----
  "UnknownPartition" should "produce a RoundRobinPartitioning with the full cross product" in {
    val out = LinkConfig.toPartitioning(
      List(w1, w2),
      List(u1, u2),
      UnknownPartition(),
      batch
    )
    out shouldBe a[RoundRobinPartitioning]
    val rr = out.asInstanceOf[RoundRobinPartitioning]
    rr.batchSize shouldBe batch
    endpoints(rr.channels) shouldBe Seq(
      ("w1", "u1"),
      ("w1", "u2"),
      ("w2", "u1"),
      ("w2", "u2")
    )
  }

  // ----- empty inputs -----
  // The previous block ended with a `"UnknownPartition" should ...` subject.
  // Switch back to "toPartitioning" so test reports for the empty-input,
  // batch-propagation, and unsupported-branch cases below don't get
  // misattributed to UnknownPartition.
  "toPartitioning" should "return empty channels when fromWorkerIds is empty (cross-product arm)" in {
    val out = LinkConfig.toPartitioning(
      Nil,
      List(u1, u2),
      HashPartition(),
      batch
    )
    out.asInstanceOf[HashBasedShufflePartitioning].channels shouldBe empty
  }

  it should "return empty channels when toWorkerIds is empty (cross-product arm)" in {
    val out = LinkConfig.toPartitioning(
      List(w1, w2),
      Nil,
      HashPartition(),
      batch
    )
    out.asInstanceOf[HashBasedShufflePartitioning].channels shouldBe empty
  }

  // ----- batch size propagation -----
  it should "propagate dataTransferBatchSize verbatim regardless of partitioning arm" in {
    val customBatch = 1024
    // Exercise every PartitionInfo arm rather than a single one, so the test lives up to its
    // name. One sender and one receiver is a valid input for all arms: SinglePartition requires
    // exactly one receiver, and the zip-based arms pair the two lists positionally.
    val partitionInfos = List(
      HashPartition(List("k")),
      RangePartition(List("k"), 0L, 100L),
      SinglePartition(),
      OneToOnePartition(),
      BroadcastPartition(),
      UnknownPartition()
    )
    partitionInfos.foreach { info =>
      val batchSize = LinkConfig.toPartitioning(List(w1), List(u1), info, customBatch) match {
        case p: HashBasedShufflePartitioning  => p.batchSize
        case p: RangeBasedShufflePartitioning => p.batchSize
        case p: OneToOnePartitioning          => p.batchSize
        case p: BroadcastPartitioning         => p.batchSize
        case p: RoundRobinPartitioning        => p.batchSize
        case other                            => fail(s"unexpected partitioning $other for $info")
      }
      withClue(s"partitionInfo = $info: ") {
        batchSize shouldBe customBatch
      }
    }
  }

  // ----- unsupported branch -----
  it should "throw UnsupportedOperationException when partitionInfo is unrecognized" in {
    // PartitionInfo is sealed, so the only way to reach the catch-all
    // `case _` branch from a test is to pass an off-domain value such as
    // null. This pins the contract that an unknown PartitionInfo subtype
    // results in UnsupportedOperationException rather than silently
    // dropping into a default partitioning.
    assertThrows[UnsupportedOperationException] {
      LinkConfig.toPartitioning(
        List(w1),
        List(u1),
        null,
        batch
      )
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/config/SchedulingConfigsSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling.config
import org.apache.texera.amber.config.ApplicationConfig
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow._
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings._
import org.scalatest.flatspec.AnyFlatSpec
import java.net.URI
class SchedulingConfigsSpec extends AnyFlatSpec {
/** Builds a worker identity with the given name. */
private def actor(name: String): ActorVirtualIdentity = ActorVirtualIdentity(name = name)

/** Builds a data (non-control) channel identity from `from` to `to`. */
private def chan(from: ActorVirtualIdentity, to: ActorVirtualIdentity): ChannelIdentity =
  ChannelIdentity(fromWorkerId = from, toWorkerId = to, isControl = false)

// ---------------------------------------------------------------------------
// ChannelConfig.generateChannelConfigs
// ---------------------------------------------------------------------------
"ChannelConfig.generateChannelConfigs" should "produce a full cross-product for HashPartition" in {
  val senders = List(actor("f1"), actor("f2"))
  val receivers = List(actor("t1"), actor("t2"), actor("t3"))
  val configs = ChannelConfig.generateChannelConfigs(
    senders,
    receivers,
    PortIdentity(0),
    HashPartition(List("k"))
  )
  // Every sender paired with every receiver.
  val expectedChannels = senders.flatMap(f => receivers.map(t => chan(f, t))).toSet
  assert(configs.size == 6)
  assert(configs.map(_.channelId).toSet == expectedChannels)
  assert(configs.forall(_.toPortId == PortIdentity(0)))
}
it should "produce a full cross-product for RangePartition" in {
val from = List(actor("f1"))
val to = List(actor("t1"), actor("t2"))
val configs = ChannelConfig.generateChannelConfigs(
from,
to,
PortIdentity(1),
new RangePartition(List("k"), 0L, 10L)
)
assert(configs.size == 2)
}
it should "produce a full cross-product for BroadcastPartition" in {
val from = List(actor("f1"), actor("f2"))
val to = List(actor("t1"), actor("t2"))
val configs =
ChannelConfig.generateChannelConfigs(from, to, PortIdentity(0), BroadcastPartition())
assert(configs.size == 4)
}
it should "produce a full cross-product for UnknownPartition" in {
val from = List(actor("f1"))
val to = List(actor("t1"), actor("t2"))
val configs =
ChannelConfig.generateChannelConfigs(from, to, PortIdentity(0), UnknownPartition())
assert(configs.size == 2)
}
it should "fan-in to a single receiver for SinglePartition" in {
val from = List(actor("f1"), actor("f2"), actor("f3"))
val to = List(actor("only-receiver"))
val configs =
ChannelConfig.generateChannelConfigs(from, to, PortIdentity(0), SinglePartition())
assert(configs.size == 3)
assert(configs.forall(_.channelId.toWorkerId == actor("only-receiver")))
}
it should "fail the SinglePartition assertion when toWorkerIds has more than one entry" in {
val from = List(actor("f1"))
val to = List(actor("t1"), actor("t2"))
assertThrows[AssertionError] {
ChannelConfig.generateChannelConfigs(from, to, PortIdentity(0), SinglePartition())
}
}
it should "zip from/to in OneToOnePartition" in {
val from = List(actor("f1"), actor("f2"), actor("f3"))
val to = List(actor("t1"), actor("t2"), actor("t3"))
val configs =
ChannelConfig.generateChannelConfigs(from, to, PortIdentity(0), OneToOnePartition())
assert(configs.size == 3)
val pairs = configs.map(c => (c.channelId.fromWorkerId, c.channelId.toWorkerId))
assert(
pairs == List(
(actor("f1"), actor("t1")),
(actor("f2"), actor("t2")),
(actor("f3"), actor("t3"))
)
)
}
it should "produce empty list for unhandled partition cases" in {
// PartitionInfo is sealed, so `null` is the only value that falls through
// the named cases without adding a new subtype. This pins the catch-all
// `case _ => List()` branch.
val configs = ChannelConfig.generateChannelConfigs(
List(actor("f")),
List(actor("t")),
PortIdentity(0),
null.asInstanceOf[PartitionInfo]
)
assert(configs.isEmpty)
}
// ---------------------------------------------------------------------------
// LinkConfig.toPartitioning
// ---------------------------------------------------------------------------
"LinkConfig.toPartitioning" should "map HashPartition to HashBasedShufflePartitioning carrying its hash attributes" in {
val from = List(actor("f"))
val to = List(actor("t1"), actor("t2"))
val partitioning =
LinkConfig.toPartitioning(from, to, HashPartition(List("a", "b")), dataTransferBatchSize = 50)
val hashed = partitioning.asInstanceOf[HashBasedShufflePartitioning]
assert(hashed.batchSize == 50)
assert(hashed.hashAttributeNames == List("a", "b"))
assert(hashed.channels.size == 2)
}
it should "map RangePartition to RangeBasedShufflePartitioning carrying its range bounds" in {
val from = List(actor("f"))
val to = List(actor("t1"))
val partitioning = LinkConfig.toPartitioning(
from,
to,
new RangePartition(List("a"), 0L, 99L),
dataTransferBatchSize = 10
)
val ranged = partitioning.asInstanceOf[RangeBasedShufflePartitioning]
assert(ranged.batchSize == 10)
assert(ranged.rangeMin == 0L)
assert(ranged.rangeMax == 99L)
assert(ranged.rangeAttributeNames == List("a"))
}
it should "map SinglePartition to OneToOnePartitioning fanned in to the single receiver" in {
val from = List(actor("f1"), actor("f2"))
val to = List(actor("only"))
val partitioning =
LinkConfig.toPartitioning(from, to, SinglePartition(), dataTransferBatchSize = 1)
val one = partitioning.asInstanceOf[OneToOnePartitioning]
assert(one.channels.forall(_.toWorkerId == actor("only")))
assert(one.channels.size == 2)
}
it should "fail the SinglePartition assertion when toWorkerIds has more than one entry" in {
val from = List(actor("f"))
val to = List(actor("t1"), actor("t2"))
assertThrows[AssertionError] {
LinkConfig.toPartitioning(from, to, SinglePartition(), dataTransferBatchSize = 1)
}
}
it should "map OneToOnePartition to OneToOnePartitioning over zipped pairs" in {
val from = List(actor("f1"), actor("f2"))
val to = List(actor("t1"), actor("t2"))
val partitioning =
LinkConfig.toPartitioning(from, to, OneToOnePartition(), dataTransferBatchSize = 1)
val one = partitioning.asInstanceOf[OneToOnePartitioning]
assert(one.channels.size == 2)
assert(one.channels.head == chan(actor("f1"), actor("t1")))
}
it should "map BroadcastPartition to BroadcastPartitioning over zipped pairs" in {
val from = List(actor("f1"), actor("f2"))
val to = List(actor("t1"), actor("t2"))
val partitioning =
LinkConfig.toPartitioning(from, to, BroadcastPartition(), dataTransferBatchSize = 1)
assert(partitioning.isInstanceOf[BroadcastPartitioning])
}
it should "map UnknownPartition to RoundRobinPartitioning across the cross-product" in {
val from = List(actor("f1"), actor("f2"))
val to = List(actor("t1"), actor("t2"))
val partitioning =
LinkConfig.toPartitioning(from, to, UnknownPartition(), dataTransferBatchSize = 1)
val rr = partitioning.asInstanceOf[RoundRobinPartitioning]
assert(rr.channels.size == 4)
}
it should "throw UnsupportedOperationException for unhandled partition cases" in {
// PartitionInfo is sealed; `null` is the only value that falls through
// the named cases without adding a new subtype. This pins the catch-all
// `case _ => throw new UnsupportedOperationException()` branch.
assertThrows[UnsupportedOperationException] {
LinkConfig.toPartitioning(
List(actor("f")),
List(actor("t")),
null.asInstanceOf[PartitionInfo],
dataTransferBatchSize = 1
)
}
}
// ---------------------------------------------------------------------------
// PortConfig hierarchy
// ---------------------------------------------------------------------------
"OutputPortConfig" should "expose its single storage URI via storageURIs" in {
val uri = new URI("vfs:///wid/1/eid/1/result")
val cfg = OutputPortConfig(uri)
assert(cfg.storageURIs == List(uri))
}
"IntermediateInputPortConfig" should "expose every URI it was constructed with" in {
val uris = List(new URI("vfs:///a"), new URI("vfs:///b"))
val cfg = IntermediateInputPortConfig(uris)
assert(cfg.storageURIs == uris)
}
"InputPortConfig" should "expose the URI projection of its storage pairs in order" in {
val a = new URI("vfs:///a")
val b = new URI("vfs:///b")
val partitioningA = OneToOnePartitioning(1, Seq.empty)
val partitioningB = OneToOnePartitioning(2, Seq.empty)
val cfg = InputPortConfig(List((a, partitioningA), (b, partitioningB)))
assert(cfg.storageURIs == List(a, b))
}
// ---------------------------------------------------------------------------
// OperatorConfig
// ---------------------------------------------------------------------------
"OperatorConfig.empty" should "have no worker configs" in {
assert(OperatorConfig.empty.workerConfigs.isEmpty)
}
it should "preserve the workerConfigs given at construction" in {
val configs = List(WorkerConfig(actor("w1")), WorkerConfig(actor("w2")))
val op = OperatorConfig(configs)
assert(op.workerConfigs == configs)
}
// ---------------------------------------------------------------------------
// ResourceConfig defaults
// ---------------------------------------------------------------------------
"ResourceConfig" should "default all three maps to empty" in {
val rc = ResourceConfig()
assert(rc.operatorConfigs.isEmpty)
assert(rc.linkConfigs.isEmpty)
assert(rc.portConfigs.isEmpty)
}
// ---------------------------------------------------------------------------
// WorkerConfig.generateWorkerConfigs
// ---------------------------------------------------------------------------
private def physicalOp(parallelizable: Boolean, suggested: Option[Int]): PhysicalOp =
PhysicalOp(
PhysicalOpIdentity(OperatorIdentity("op"), "main"),
WorkflowIdentity(0),
ExecutionIdentity(0),
OpExecInitInfo.Empty,
parallelizable = parallelizable,
suggestedWorkerNum = suggested
)
"WorkerConfig.generateWorkerConfigs" should "produce exactly one WorkerConfig for non-parallelizable ops" in {
val configs =
WorkerConfig.generateWorkerConfigs(physicalOp(parallelizable = false, suggested = None))
assert(configs.size == 1)
}
it should "ignore a suggested worker count for non-parallelizable ops" in {
val configs =
WorkerConfig.generateWorkerConfigs(physicalOp(parallelizable = false, suggested = Some(8)))
assert(configs.size == 1)
}
it should "honor the suggested worker count for parallelizable ops" in {
val configs =
WorkerConfig.generateWorkerConfigs(physicalOp(parallelizable = true, suggested = Some(5)))
assert(configs.size == 5)
// distinct worker ids
assert(configs.map(_.workerId).distinct.size == 5)
}
it should "fall back to the configured default when no suggested count is given for a parallelizable op" in {
val configs =
WorkerConfig.generateWorkerConfigs(physicalOp(parallelizable = true, suggested = None))
assert(configs.size == ApplicationConfig.numWorkerPerOperatorByDefault)
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/scheduling/resourcePolicies/ResourcePoliciesSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.scheduling.resourcePolicies
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.architecture.scheduling.{Region, RegionIdentity}
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.{
BroadcastPartitioning,
HashBasedShufflePartitioning,
OneToOnePartitioning,
Partitioning,
RangeBasedShufflePartitioning,
RoundRobinPartitioning
}
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.workflow.LogicalLink
import org.scalatest.flatspec.AnyFlatSpec
/**
  * Tests for the resource-policy types: `ExecutionClusterInfo` construction and
  * `DefaultResourceAllocator.allocate`. The allocator tests run it against a
  * small linear `csv -> keyword` workflow.
  */
class ResourcePoliciesSpec extends AnyFlatSpec {

  // ---------------------------------------------------------------------------
  // ExecutionClusterInfo
  // ---------------------------------------------------------------------------

  "ExecutionClusterInfo" should "construct without arguments" in {
    // No-arg constructor must not throw; the type currently has no observable
    // state to assert beyond that.
    new ExecutionClusterInfo()
  }

  // ---------------------------------------------------------------------------
  // DefaultResourceAllocator (helpers + tests)
  // ---------------------------------------------------------------------------

  /** Build a small linear `csv -> keyword` workflow to feed the allocator. */
  private def buildLinearWorkflow() = {
    val csv = TestOperators.headerlessSmallCsvScanOpDesc()
    val keyword = TestOperators.keywordSearchOpDesc("column-1", "Asia")
    buildWorkflow(
      List(csv, keyword),
      List(
        LogicalLink(
          csv.operatorIdentifier,
          PortIdentity(0),
          keyword.operatorIdentifier,
          PortIdentity(0)
        )
      ),
      new WorkflowContext()
    )
  }

  /**
    * Build an allocator for the linear workflow together with a single region
    * that covers every physical operator and link of that workflow.
    */
  private def newAllocator(): (DefaultResourceAllocator, Region) = {
    val workflow = buildLinearWorkflow()
    val allocator = new DefaultResourceAllocator(
      workflow.physicalPlan,
      new ExecutionClusterInfo(),
      workflow.context.workflowSettings
    )
    val region = Region(
      id = RegionIdentity(0),
      physicalOps = workflow.physicalPlan.operators,
      physicalLinks = workflow.physicalPlan.links
    )
    (allocator, region)
  }

  /**
    * Extract the channel list from any Partitioning variant the allocator is
    * expected to emit. Any other variant fails the calling test.
    */
  private def partitioningOf(p: Partitioning) =
    p match {
      case x: OneToOnePartitioning          => x.channels
      case x: RoundRobinPartitioning        => x.channels
      case x: HashBasedShufflePartitioning  => x.channels
      case x: RangeBasedShufflePartitioning => x.channels
      case x: BroadcastPartitioning         => x.channels
      case other                            => fail(s"allocator emitted unexpected Partitioning: $other")
    }

  "DefaultResourceAllocator.allocate" should "return zero cost (placeholder)" in {
    val (allocator, region) = newAllocator()
    val (_, cost) = allocator.allocate(region)
    assert(cost == 0d)
  }

  it should "produce an OperatorConfig entry for every operator in the region" in {
    val (allocator, region) = newAllocator()
    val (resourceConfig, _) = allocator.allocate(region)
    val opIds = region.getOperators.map(_.id)
    assert(resourceConfig.operatorConfigs.keySet == opIds)
  }

  it should "respect parallelizable / suggested-worker settings on each PhysicalOp" in {
    val (allocator, region) = newAllocator()
    val (resourceConfig, _) = allocator.allocate(region)
    region.getOperators.foreach { op =>
      val workers = resourceConfig.operatorConfigs(op.id).workerConfigs.size
      val expected =
        if (!op.parallelizable) 1
        else
          op.suggestedWorkerNum.getOrElse(
            org.apache.texera.amber.config.ApplicationConfig.numWorkerPerOperatorByDefault
          )
      assert(workers == expected, s"unexpected worker count for ${op.id}")
    }
  }

  it should "honor an explicit suggestedWorkerNum on a parallelizable op" in {
    val workflow = buildLinearWorkflow()
    // Any parallelizable op works here; fail with a clear message (instead of
    // an opaque NoSuchElementException from Option.get) if the fixture ever
    // stops containing one.
    val targetOpId = workflow.physicalPlan.operators
      .find(_.parallelizable)
      .map(_.id)
      .getOrElse(fail("fixture workflow contains no parallelizable physical operator"))
    val rebuiltOps = workflow.physicalPlan.operators.map { op =>
      if (op.id == targetOpId) op.withSuggestedWorkerNum(7) else op
    }
    val rebuiltPlan = workflow.physicalPlan.copy(operators = rebuiltOps)
    val allocator = new DefaultResourceAllocator(
      rebuiltPlan,
      new ExecutionClusterInfo(),
      workflow.context.workflowSettings
    )
    val region = Region(
      id = RegionIdentity(0),
      physicalOps = rebuiltOps,
      physicalLinks = rebuiltPlan.links
    )
    val (resourceConfig, _) = allocator.allocate(region)
    assert(resourceConfig.operatorConfigs(targetOpId).workerConfigs.size == 7)
  }

  it should "emit distinct worker ids per operator" in {
    val (allocator, region) = newAllocator()
    val (resourceConfig, _) = allocator.allocate(region)
    val ids = resourceConfig.operatorConfigs.values.flatMap(_.workerConfigs.map(_.workerId)).toList
    assert(ids.distinct.size == ids.size, s"duplicate worker ids in $ids")
  }

  it should "produce a LinkConfig entry for every physical link in the region" in {
    val (allocator, region) = newAllocator()
    val (resourceConfig, _) = allocator.allocate(region)
    assert(resourceConfig.linkConfigs.keySet == region.getLinks)
  }

  it should "wire each LinkConfig so its Partitioning channels match its channelConfigs" in {
    val (allocator, region) = newAllocator()
    val (resourceConfig, _) = allocator.allocate(region)
    resourceConfig.linkConfigs.values.foreach { link =>
      assert(link.channelConfigs.nonEmpty)
      val partitioningChannels = partitioningOf(link.partitioning)
      assert(partitioningChannels == link.channelConfigs.map(_.channelId))
    }
  }

  it should "leave portConfigs empty when the region has no prior resourceConfig" in {
    val (allocator, region) = newAllocator()
    val (resourceConfig, _) = allocator.allocate(region)
    assert(resourceConfig.portConfigs.isEmpty)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/sendsemantics/partitioners/NetworkOutputBufferSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.sendsemantics.partitioners
import org.apache.texera.amber.config.ApplicationConfig
import org.apache.texera.amber.core.state.State
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.engine.architecture.messaginglayer.NetworkOutputGateway
import org.apache.texera.amber.engine.common.ambermessage.{
DataFrame,
StateFrame,
WorkflowFIFOMessage
}
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable.ArrayBuffer
/**
  * Tests for `NetworkOutputBuffer`. The tests cover:
  *  - construction defaults;
  *  - batching and auto-flush in `addTuple`;
  *  - explicit `flush`;
  *  - ordering of `sendState` relative to buffered tuples;
  *  - sequence numbering;
  *  - `batchSize` edge cases.
  *
  * Messages are captured by a real `NetworkOutputGateway` whose handler appends
  * every outgoing `WorkflowFIFOMessage` to an in-memory buffer.
  */
class NetworkOutputBufferSpec extends AnyFlatSpec {

  // --- fixtures --------------------------------------------------------------

  private val sender = ActorVirtualIdentity("sender")
  private val receiver = ActorVirtualIdentity("receiver-1")

  private val intAttr = new Attribute("v", AttributeType.INTEGER)
  private val schema: Schema = Schema().add(intAttr)

  /** Single-column integer tuple holding `value`. */
  private def tuple(value: Int): Tuple =
    Tuple.builder(schema).add(intAttr, value).build()

  /** Recording wrapper around a real `NetworkOutputGateway`. */
  private class Capture {
    val messages: ArrayBuffer[WorkflowFIFOMessage] = ArrayBuffer.empty
    val gateway: NetworkOutputGateway =
      new NetworkOutputGateway(sender, m => messages += m)
  }

  /** Fresh buffer to `receiver` plus the capture that records what it sends. */
  private def newBuffer(batchSize: Int = 4): (NetworkOutputBuffer, Capture) = {
    val cap = new Capture
    val buf = new NetworkOutputBuffer(receiver, cap.gateway, batchSize = batchSize)
    (buf, cap)
  }

  // --- construction defaults -------------------------------------------------

  "NetworkOutputBuffer" should "default batchSize to ApplicationConfig.defaultDataTransferBatchSize" in {
    val cap = new Capture
    val buf = new NetworkOutputBuffer(receiver, cap.gateway)
    assert(buf.batchSize == ApplicationConfig.defaultDataTransferBatchSize)
  }

  it should "expose `to` and `dataOutputPort` as immutable accessors" in {
    val cap = new Capture
    val buf = new NetworkOutputBuffer(receiver, cap.gateway, batchSize = 4)
    assert(buf.to == receiver)
    assert(buf.dataOutputPort eq cap.gateway)
  }

  it should "start with an empty buffer (no implicit auto-flush at construction)" in {
    val (_, cap) = newBuffer()
    assert(cap.messages.isEmpty)
  }

  // --- addTuple buffering / auto-flush --------------------------------------

  "NetworkOutputBuffer.addTuple" should "NOT flush while the buffer is below batchSize" in {
    val (buf, cap) = newBuffer(batchSize = 4)
    buf.addTuple(tuple(0))
    buf.addTuple(tuple(1))
    buf.addTuple(tuple(2))
    assert(cap.messages.isEmpty, "no DataFrame should be sent until batchSize is reached")
  }

  it should "auto-flush when the buffer exactly reaches batchSize" in {
    val (buf, cap) = newBuffer(batchSize = 3)
    buf.addTuple(tuple(0))
    buf.addTuple(tuple(1))
    buf.addTuple(tuple(2)) // boundary: now size == batchSize
    assert(cap.messages.size == 1, "exactly one DataFrame should be auto-flushed at the boundary")
    val frame = cap.messages.head.payload.asInstanceOf[DataFrame]
    assert(frame.frame.toList == List(tuple(0), tuple(1), tuple(2)))
  }

  it should "produce a separate DataFrame for each successive batch" in {
    val (buf, cap) = newBuffer(batchSize = 2)
    (0 until 6).foreach(i => buf.addTuple(tuple(i)))
    assert(cap.messages.size == 3, "three full batches → three DataFrames")
    val payloads = cap.messages.map(_.payload.asInstanceOf[DataFrame].frame.toList)
    assert(payloads.head == List(tuple(0), tuple(1)))
    assert(payloads(1) == List(tuple(2), tuple(3)))
    assert(payloads(2) == List(tuple(4), tuple(5)))
  }

  it should "send DataFrames to the configured receiver only" in {
    val (buf, cap) = newBuffer(batchSize = 2)
    buf.addTuple(tuple(0))
    buf.addTuple(tuple(1))
    assert(cap.messages.size == 1)
    val msg = cap.messages.head
    assert(msg.channelId.fromWorkerId == sender)
    assert(msg.channelId.toWorkerId == receiver)
    assert(!msg.channelId.isControl, "data path must not use the control channel")
  }

  // --- flush() ----------------------------------------------------------------

  "NetworkOutputBuffer.flush" should "send a DataFrame and reset the buffer when the buffer is non-empty" in {
    val (buf, cap) = newBuffer(batchSize = 100) // never auto-flushes
    buf.addTuple(tuple(7))
    buf.addTuple(tuple(8))
    buf.flush()
    assert(cap.messages.size == 1)
    val frame = cap.messages.head.payload.asInstanceOf[DataFrame]
    assert(frame.frame.toList == List(tuple(7), tuple(8)))
    // A second flush() with nothing buffered must not send another frame.
    buf.flush()
    assert(cap.messages.size == 1, "flush() on an empty buffer must be a no-op")
  }

  it should "be a no-op when called on an empty buffer (no DataFrame, no StateFrame)" in {
    val (buf, cap) = newBuffer()
    buf.flush()
    buf.flush()
    buf.flush()
    assert(cap.messages.isEmpty)
  }

  it should "assign monotonically increasing sequence numbers across multiple flushes" in {
    // The gateway tracks sequence numbers per channel; each successive
    // DataFrame on the same channel gets the next number. Pin so a
    // regression that resets seq on flush is visible.
    val (buf, cap) = newBuffer(batchSize = 1) // each addTuple flushes
    (0 until 4).foreach(i => buf.addTuple(tuple(i)))
    val seqs = cap.messages.map(_.sequenceNumber).toList
    assert(seqs == List(0L, 1L, 2L, 3L), s"unexpected sequence: $seqs")
  }

  // --- sendState ----------------------------------------------------------

  "NetworkOutputBuffer.sendState" should "flush pending tuples FIRST, then send the StateFrame" in {
    val (buf, cap) = newBuffer(batchSize = 100)
    buf.addTuple(tuple(0))
    buf.addTuple(tuple(1))
    val state = State(Map("checkpoint" -> 99))
    buf.sendState(state)
    // Expected order: DataFrame (the buffered tuples) → StateFrame.
    assert(cap.messages.size == 2)
    val first = cap.messages.head.payload
    val second = cap.messages(1).payload
    assert(first.isInstanceOf[DataFrame], s"first frame should be DataFrame, got $first")
    assert(first.asInstanceOf[DataFrame].frame.toList == List(tuple(0), tuple(1)))
    assert(second == StateFrame(state))
  }

  it should "send only the StateFrame when no tuples are pending (empty pre-flush is a no-op)" in {
    val (buf, cap) = newBuffer()
    val state = State(Map("k" -> "v"))
    buf.sendState(state)
    assert(cap.messages.size == 1)
    assert(cap.messages.head.payload == StateFrame(state))
  }

  it should "leave the tuple buffer empty after sendState (trailing flush no-op)" in {
    // Pin that sendState does not double-send buffered tuples and that a
    // subsequent addTuple starts from a clean buffer.
    val (buf, cap) = newBuffer(batchSize = 100)
    buf.addTuple(tuple(0))
    buf.sendState(State(Map.empty))
    val countBefore = cap.messages.size // DataFrame + StateFrame = 2
    assert(countBefore == 2)
    // Add another tuple and explicit flush — must produce one fresh frame.
    buf.addTuple(tuple(99))
    buf.flush()
    assert(cap.messages.size == 3)
    val third = cap.messages(2).payload.asInstanceOf[DataFrame]
    assert(third.frame.toList == List(tuple(99)), "post-state buffer must start empty")
  }

  it should "share a single sequence-number stream across DataFrames and the StateFrame on the same channel" in {
    // Pin: DataFrame and StateFrame are sent to the same receiver on the
    // same data channel, so they share the gateway's sequence-number
    // counter. A regression that opens a side-channel for StateFrame
    // would produce a non-monotonic stream and fail this.
    val (buf, cap) = newBuffer(batchSize = 100)
    buf.addTuple(tuple(0))
    buf.addTuple(tuple(1))
    buf.sendState(State(Map("x" -> 1)))
    buf.addTuple(tuple(2))
    buf.flush()
    val seqs = cap.messages.map(_.sequenceNumber).toList
    assert(seqs == List(0L, 1L, 2L), s"unexpected sequence: $seqs")
  }

  // --- batchSize edge cases -------------------------------------------------

  "NetworkOutputBuffer with batchSize = 1" should "flush immediately after every addTuple" in {
    val (buf, cap) = newBuffer(batchSize = 1)
    buf.addTuple(tuple(0))
    assert(cap.messages.size == 1)
    buf.addTuple(tuple(1))
    assert(cap.messages.size == 2)
    val frames = cap.messages.toList.map(_.payload.asInstanceOf[DataFrame].frame.toList)
    assert(frames == List(List(tuple(0)), List(tuple(1))))
  }

  // `batchSize <= 0` IS reachable from production today: the
  // workflow-settings UI restricts the value to `>= 1`, but
  // `SyncExecutionResource` accepts `request.workflowSettings` directly
  // from the API and the backend forwards `workflowSettings
  // .dataTransferBatchSize` into `NetworkOutputBuffer` without
  // validating it.
  //
  // That path is covered by two tests:
  //  1. a characterization test of today's lenient behavior: with the
  //     `>=` guard, every tuple is flushed on its own;
  //  2. a `pendingUntilFixed` test of the intended hardening: rejecting
  //     a non-positive batch size at construction.
  //
  // When the hardening lands, test 1 breaks on purpose. At the same time
  // `pendingUntilFixed` stops reporting "pending" and fails the second
  // test on purpose. That forces both markers to be updated together:
  // delete test 1 and remove the `pendingUntilFixed` wrapper.
  "NetworkOutputBuffer with non-positive batchSize" should
    "currently flush per-tuple under the `>=` guard (characterization, today's lenient behavior)" in {
    // Guards against a partial change (e.g. disabling flushing entirely for
    // non-positive batch sizes) that would silently hold tuples back.
    val (buf0, cap0) = newBuffer(batchSize = 0)
    buf0.addTuple(tuple(1))
    buf0.addTuple(tuple(2))
    val frames0 = cap0.messages.toList.map(_.payload.asInstanceOf[DataFrame].frame.toList)
    assert(frames0 == List(List(tuple(1)), List(tuple(2))))

    val (bufNeg, capNeg) = newBuffer(batchSize = -1)
    bufNeg.addTuple(tuple(99))
    val framesNeg = capNeg.messages.toList.map(_.payload.asInstanceOf[DataFrame].frame.toList)
    assert(framesNeg == List(List(tuple(99))))
  }

  it should "eventually reject construction (pendingUntilFixed)" in pendingUntilFixed {
    // Intended contract: a non-positive batch size is invalid input and is
    // rejected at construction (e.g. `require(batchSize > 0, ...)`).
    //
    // Today the constructor accepts it, so an `intercept` below fails and
    // `pendingUntilFixed` reports the test as pending.
    //
    // Once the hardening lands, both intercepts succeed. `pendingUntilFixed`
    // then FAILS the test on purpose. At that point, remove the
    // `pendingUntilFixed` wrapper so this becomes a normal passing test, and
    // delete the characterization test above.
    val cap = new Capture
    intercept[IllegalArgumentException] {
      new NetworkOutputBuffer(receiver, cap.gateway, batchSize = 0)
    }
    intercept[IllegalArgumentException] {
      new NetworkOutputBuffer(receiver, cap.gateway, batchSize = -1)
    }
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/sendsemantics/partitioners/PartitionersSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.sendsemantics.partitioners
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.{
BroadcastPartitioning,
HashBasedShufflePartitioning,
OneToOnePartitioning,
RoundRobinPartitioning
}
import org.scalatest.flatspec.AnyFlatSpec
class PartitionersSpec extends AnyFlatSpec {
private val sender: ActorVirtualIdentity = ActorVirtualIdentity("sender")
private val r1: ActorVirtualIdentity = ActorVirtualIdentity("rec1")
private val r2: ActorVirtualIdentity = ActorVirtualIdentity("rec2")
private val r3: ActorVirtualIdentity = ActorVirtualIdentity("rec3")
private def channel(to: ActorVirtualIdentity): ChannelIdentity =
ChannelIdentity(sender, to, isControl = false)
private val intAttr: Attribute = new Attribute("v", AttributeType.INTEGER)
private val intSchema: Schema = Schema().add(intAttr)
private def intTuple(value: Int): Tuple =
Tuple.builder(intSchema).add(intAttr, value).build()
private val twoStringSchema: Schema = Schema()
.add(new Attribute("k", AttributeType.STRING))
.add(new Attribute("v", AttributeType.STRING))
private def stringTuple(k: String, v: String): Tuple =
Tuple
.builder(twoStringSchema)
.add(new Attribute("k", AttributeType.STRING), k)
.add(new Attribute("v", AttributeType.STRING), v)
.build()
// -- OneToOnePartitioner --------------------------------------------------
"OneToOnePartitioner.getBucketIndex" should "always return Iterator(0)" in {
val partitioning = OneToOnePartitioning(
batchSize = 100,
channels = Seq(channel(r1))
)
val partitioner = OneToOnePartitioner(partitioning, sender)
assert(partitioner.getBucketIndex(intTuple(7)).toList == List(0))
assert(partitioner.getBucketIndex(intTuple(42)).toList == List(0))
}
"OneToOnePartitioner.allReceivers" should "return the receiver from the channel matching the actor id" in {
val partitioning = OneToOnePartitioning(
batchSize = 100,
channels = Seq(
ChannelIdentity(ActorVirtualIdentity("other-sender"), r2, isControl = false),
channel(r1)
)
)
val partitioner = OneToOnePartitioner(partitioning, sender)
assert(partitioner.allReceivers == Seq(r1))
}
// -- BroadcastPartitioner -------------------------------------------------
"BroadcastPartitioner.getBucketIndex" should "yield every receiver index for any tuple" in {
val partitioning = BroadcastPartitioning(
batchSize = 100,
channels = Seq(channel(r1), channel(r2), channel(r3))
)
val partitioner = BroadcastPartitioner(partitioning)
assert(partitioner.getBucketIndex(intTuple(0)).toList == List(0, 1, 2))
}
"BroadcastPartitioner" should "deduplicate receivers when channels list a worker twice" in {
val partitioning = BroadcastPartitioning(
batchSize = 100,
channels = Seq(channel(r1), channel(r1), channel(r2))
)
val partitioner = BroadcastPartitioner(partitioning)
assert(partitioner.allReceivers == Seq(r1, r2))
assert(partitioner.getBucketIndex(intTuple(0)).toList == List(0, 1))
}
// -- RoundRobinPartitioner ------------------------------------------------
"RoundRobinPartitioner.getBucketIndex" should "cycle through bucket indices" in {
  val partitioner = RoundRobinPartitioner(
    RoundRobinPartitioning(
      batchSize = 100,
      channels = Seq(channel(r1), channel(r2), channel(r3))
    )
  )
  // The cursor is advanced before an index is emitted, so starting from 0 the first value
  // produced is 1, after which the sequence wraps around modulo 3.
  val observed = List.fill(7)(partitioner.getBucketIndex(intTuple(0)).next())
  assert(observed == List(1, 2, 0, 1, 2, 0, 1))
}
"RoundRobinPartitioner.allReceivers" should "preserve channel order while deduplicating" in {
  // r2 is listed first and again last: it keeps its first position and appears once.
  val channelsWithRepeat = Seq(r2, r1, r2).map(receiver => channel(receiver))
  val partitioner =
    RoundRobinPartitioner(RoundRobinPartitioning(batchSize = 100, channels = channelsWithRepeat))
  assert(partitioner.allReceivers == Seq(r2, r1))
}
// -- HashBasedShufflePartitioner ------------------------------------------
"HashBasedShufflePartitioner.getBucketIndex" should "return a non-negative index within the receiver count" in {
  val partitioner = HashBasedShufflePartitioner(
    HashBasedShufflePartitioning(
      batchSize = 100,
      channels = Seq(channel(r1), channel(r2), channel(r3)),
      hashAttributeNames = Seq("k")
    )
  )
  // Fifty distinct keys; every resulting bucket must be a valid receiver index in [0, 3).
  for (i <- 0 until 50) {
    val idx = partitioner.getBucketIndex(stringTuple(s"key-$i", "v")).next()
    assert(idx >= 0 && idx < 3, s"index $idx out of range for tuple key-$i")
  }
}
it should "be deterministic for the same input tuple" in {
  val partitioner = HashBasedShufflePartitioner(
    HashBasedShufflePartitioning(
      batchSize = 100,
      channels = Seq(channel(r1), channel(r2), channel(r3)),
      hashAttributeNames = Seq("k")
    )
  )
  def bucketOf(t: Tuple) = partitioner.getBucketIndex(t).next()
  // The very same tuple instance is routed twice in a row; both calls must agree.
  val tuple = stringTuple("alpha", "ignored")
  val first = bucketOf(tuple)
  val second = bucketOf(tuple)
  assert(first == second)
}
it should "depend only on the hash-attribute subset, not on other fields" in {
  val partitioner = HashBasedShufflePartitioner(
    HashBasedShufflePartitioning(
      batchSize = 100,
      channels = Seq(channel(r1), channel(r2), channel(r3)),
      hashAttributeNames = Seq("k")
    )
  )
  // For each key, only the non-hashed field "v" is varied. An implementation that wrongly hashed
  // the whole tuple would need every one of these variations to collide modulo 3 for every key.
  // That is not plausible for any reasonable hash, so a single bucket per key is a meaningful check.
  val secondFieldValues = (0 until 8).map(i => s"v-$i")
  for (k <- Seq("alpha", "beta", "gamma", "delta", "epsilon", "zeta")) {
    val buckets = secondFieldValues.map { v =>
      partitioner.getBucketIndex(stringTuple(k, v)).next()
    }
    assert(
      buckets.toSet.size == 1,
      s"key=$k produced different buckets when varying the non-hash field: $buckets"
    )
  }
}
it should "use the full tuple when no hash attributes are configured" in {
  val partitioner = HashBasedShufflePartitioner(
    HashBasedShufflePartitioning(
      batchSize = 100,
      channels = Seq(channel(r1), channel(r2), channel(r3)),
      hashAttributeNames = Seq.empty
    )
  )
  // "k" is held fixed and only "v" changes. Hashing just the (empty) hash-attribute subset would
  // send every sample to one bucket. Hashing the whole tuple must spread 50 samples over more
  // than one of the 3 receivers.
  val sampleSize = 50
  val buckets =
    for (i <- 0 until sampleSize)
      yield partitioner.getBucketIndex(stringTuple("k", s"v-$i")).next()
  for (idx <- buckets) assert(idx >= 0 && idx < 3)
  assert(
    buckets.distinct.size > 1,
    s"empty hashAttributeNames should hash the full tuple, but $sampleSize samples all landed in: ${buckets.distinct}"
  )
}
"HashBasedShufflePartitioner.allReceivers" should "deduplicate channel destinations" in {
  // r1 is listed at both ends of the channel list; it must be reported only once.
  val partitioner = HashBasedShufflePartitioner(
    HashBasedShufflePartitioning(
      batchSize = 100,
      channels = Seq(r1, r2, r1).map(receiver => channel(receiver)),
      hashAttributeNames = Seq("k")
    )
  )
  assert(partitioner.allReceivers == Seq(r1, r2))
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/DPThreadSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker
import org.apache.texera.amber.core.executor.OperatorExecutor
import org.apache.texera.amber.core.tuple.{AttributeType, Schema, Tuple, TupleLike}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.engine.architecture.logreplay.{ReplayLogManager, ReplayLogRecord}
import org.apache.texera.amber.engine.architecture.messaginglayer.WorkerTimerService
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
AsyncRPCContext,
EmptyRequest
}
import org.apache.texera.amber.engine.architecture.rpc.workerservice.WorkerServiceGrpc.{
METHOD_PAUSE_WORKER,
METHOD_RESUME_WORKER
}
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{
DPInputQueueElement,
FIFOMessageElement,
TimerBasedControlElement
}
import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage}
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.ControlInvocation
import org.apache.texera.amber.engine.common.storage.SequentialRecordStorage
import org.apache.texera.amber.engine.common.virtualidentity.util.SELF
import org.scalamock.scalatest.MockFactory
import org.scalatest.flatspec.AnyFlatSpec
import java.net.URI
import java.util.concurrent.LinkedBlockingQueue
/**
  * Tests for [[DPThread]]: a real DP thread is started and fed through its input queue, and the
  * mocked executor verifies that every tuple is processed on port 0.
  *
  * NOTE(review): the DPThread started by each test is never stopped here. Confirm whether DPThread
  * exposes a shutdown method that should be called at the end of each test.
  */
class DPThreadSpec extends AnyFlatSpec with MockFactory {
  private val workerId: ActorVirtualIdentity = ActorVirtualIdentity("DP mock")
  private val senderWorkerId: ActorVirtualIdentity = ActorVirtualIdentity("mock sender")
  private val dataChannelId = ChannelIdentity(senderWorkerId, workerId, isControl = false)
  private val controlChannelId = ChannelIdentity(senderWorkerId, workerId, isControl = true)
  private val executor = mock[OperatorExecutor]
  private val mockInputPortId = PortIdentity()
  private val schema: Schema = Schema().add("field1", AttributeType.INTEGER)
  private val tuples: Array[Tuple] = (0 until 5000)
    .map(i => TupleLike(i).enforceSchema(schema))
    .toArray
  // Log manager shared by the tests that never read logs back (storage created with no URI).
  private val logStorage = SequentialRecordStorage.getStorage[ReplayLogRecord](None)
  private val logManager: ReplayLogManager =
    ReplayLogManager.createLogManager(logStorage, "none", x => {})

  // Upper bound for any single polling wait. Without it, a DP thread that stalls would leave the
  // test spinning forever instead of failing it.
  private val maxWaitMillis: Long = 60000L

  /**
    * Blocks while `condition` holds, re-checking it every 100 ms.
    *
    * @param condition re-evaluated on every iteration (by-name)
    * @param what      description of the awaited state, used in the failure message
    * Fails the current test if `condition` is still true after `maxWaitMillis`.
    */
  private def waitWhile(condition: => Boolean, what: String): Unit = {
    val deadline = System.currentTimeMillis() + maxWaitMillis
    while (condition) {
      if (System.currentTimeMillis() > deadline) {
        fail(s"timed out after $maxWaitMillis ms waiting for $what")
      }
      Thread.sleep(100)
    }
  }

  /** Registers one `processTupleMultiPort(tuple, 0)` expectation (default: once) per element of `tuples`. */
  private def expectAllTuplesProcessed(): Unit =
    tuples.foreach { x =>
      (
        (
            tuple: Tuple,
            input: Int
        ) => executor.processTupleMultiPort(tuple, input)
      )
        .expects(x, 0)
    }

  "DP Thread" should "handle pause/resume during processing" in {
    val inputQueue = new LinkedBlockingQueue[DPInputQueueElement]()
    val dp = new DataProcessor(workerId, x => {}, inputMessageQueue = inputQueue)
    dp.executor = executor
    dp.inputManager.addPort(mockInputPortId, schema, List.empty, List.empty)
    dp.inputGateway.getChannel(dataChannelId).setPortId(mockInputPortId)
    dp.adaptiveBatchingMonitor = mock[WorkerTimerService]
    (dp.adaptiveBatchingMonitor.resumeAdaptiveBatching _).expects().anyNumberOfTimes()
    val dpThread = new DPThread(workerId, dp, logManager, inputQueue)
    dpThread.start()
    expectAllTuplesProcessed()
    val message = WorkflowFIFOMessage(dataChannelId, 0, DataFrame(tuples))
    inputQueue.put(FIFOMessageElement(message))
    // Pause and resume arrive as TimerBasedControlElement entries rather than on a channel.
    inputQueue.put(
      TimerBasedControlElement(
        ControlInvocation(METHOD_PAUSE_WORKER, EmptyRequest(), AsyncRPCContext(SELF, SELF), 0)
      )
    )
    Thread.sleep(1000)
    assert(dp.pauseManager.isPaused)
    inputQueue.put(
      TimerBasedControlElement(
        ControlInvocation(METHOD_RESUME_WORKER, EmptyRequest(), AsyncRPCContext(SELF, SELF), 1)
      )
    )
    // Give the DP thread time to resume before polling for completion.
    Thread.sleep(1000)
    waitWhile(dp.inputManager.hasUnfinishedInput, "all queued input to be processed")
  }

  "DP Thread" should "handle pause/resume using fifo messages" in {
    val inputQueue = new LinkedBlockingQueue[DPInputQueueElement]()
    val dp = new DataProcessor(workerId, x => {}, inputMessageQueue = inputQueue)
    dp.inputManager.addPort(mockInputPortId, schema, List.empty, List.empty)
    dp.inputGateway.getChannel(dataChannelId).setPortId(mockInputPortId)
    dp.adaptiveBatchingMonitor = mock[WorkerTimerService]
    (dp.adaptiveBatchingMonitor.resumeAdaptiveBatching _).expects().anyNumberOfTimes()
    val dpThread = new DPThread(workerId, dp, logManager, inputQueue)
    dp.executor = executor
    dpThread.start()
    expectAllTuplesProcessed()
    val message = WorkflowFIFOMessage(dataChannelId, 0, DataFrame(tuples))
    // Here pause/resume travel over the control channel as ordinary FIFO messages (seq 0 and 1).
    val pauseControl = WorkflowFIFOMessage(
      controlChannelId,
      0,
      ControlInvocation(METHOD_PAUSE_WORKER, EmptyRequest(), AsyncRPCContext(SELF, SELF), 0)
    )
    val resumeControl =
      WorkflowFIFOMessage(
        controlChannelId,
        1,
        ControlInvocation(METHOD_RESUME_WORKER, EmptyRequest(), AsyncRPCContext(SELF, SELF), 1)
      )
    inputQueue.put(FIFOMessageElement(message))
    inputQueue.put(
      FIFOMessageElement(pauseControl)
    )
    Thread.sleep(1000)
    assert(dp.pauseManager.isPaused)
    inputQueue.put(FIFOMessageElement(resumeControl))
    Thread.sleep(1000)
    waitWhile(dp.inputManager.hasUnfinishedInput, "all queued input to be processed")
  }

  "DP Thread" should "handle multiple batches from multiple sources" in {
    val inputQueue = new LinkedBlockingQueue[DPInputQueueElement]()
    val dp = new DataProcessor(workerId, x => {}, inputMessageQueue = inputQueue)
    dp.executor = executor
    val anotherSenderWorkerId = ActorVirtualIdentity("another")
    dp.inputManager.addPort(mockInputPortId, schema, List.empty, List.empty)
    dp.inputGateway.getChannel(dataChannelId).setPortId(mockInputPortId)
    dp.inputGateway
      .getChannel(ChannelIdentity(anotherSenderWorkerId, workerId, isControl = false))
      .setPortId(mockInputPortId)
    dp.adaptiveBatchingMonitor = mock[WorkerTimerService]
    (dp.adaptiveBatchingMonitor.resumeAdaptiveBatching _).expects().anyNumberOfTimes()
    val dpThread = new DPThread(workerId, dp, logManager, inputQueue)
    dpThread.start()
    expectAllTuplesProcessed()
    val dataChannelID2 = ChannelIdentity(anotherSenderWorkerId, workerId, isControl = false)
    // Batches from the two channels are interleaved; each channel keeps its own sequence numbers.
    val message1 = WorkflowFIFOMessage(dataChannelId, 0, DataFrame(tuples.slice(0, 100)))
    val message2 = WorkflowFIFOMessage(dataChannelId, 1, DataFrame(tuples.slice(100, 200)))
    val message3 = WorkflowFIFOMessage(dataChannelID2, 0, DataFrame(tuples.slice(300, 1000)))
    val message4 = WorkflowFIFOMessage(dataChannelId, 2, DataFrame(tuples.slice(200, 300)))
    val message5 = WorkflowFIFOMessage(dataChannelID2, 1, DataFrame(tuples.slice(1000, 5000)))
    inputQueue.put(FIFOMessageElement(message1))
    inputQueue.put(FIFOMessageElement(message2))
    inputQueue.put(FIFOMessageElement(message3))
    inputQueue.put(FIFOMessageElement(message4))
    inputQueue.put(FIFOMessageElement(message5))
    Thread.sleep(1000)
    waitWhile(dp.inputManager.hasUnfinishedInput, "all queued input to be processed")
  }

  "DP Thread" should "write determinant logs to local storage while processing" in {
    val inputQueue = new LinkedBlockingQueue[DPInputQueueElement]()
    val dp = new DataProcessor(workerId, _ => {}, inputMessageQueue = inputQueue)
    dp.executor = executor
    val anotherSenderWorkerId = ActorVirtualIdentity("another")
    dp.inputManager.addPort(mockInputPortId, schema, List.empty, List.empty)
    dp.inputGateway.getChannel(dataChannelId).setPortId(mockInputPortId)
    dp.inputGateway
      .getChannel(ChannelIdentity(anotherSenderWorkerId, workerId, isControl = false))
      .setPortId(mockInputPortId)
    dp.adaptiveBatchingMonitor = mock[WorkerTimerService]
    (dp.adaptiveBatchingMonitor.resumeAdaptiveBatching _).expects().anyNumberOfTimes()
    // Local storage and log manager shadow the shared "none" ones so logs can be read back.
    val logStorage = SequentialRecordStorage.getStorage[ReplayLogRecord](
      Some(new URI("ram:///recovery-logs/tmp"))
    )
    logStorage.deleteStorage()
    val logManager: ReplayLogManager =
      ReplayLogManager.createLogManager(logStorage, "tmpLog", _ => {})
    val dpThread = new DPThread(workerId, dp, logManager, inputQueue)
    dpThread.start()
    expectAllTuplesProcessed()
    val dataChannelId2 = ChannelIdentity(anotherSenderWorkerId, workerId, isControl = false)
    val message1 = WorkflowFIFOMessage(dataChannelId, 0, DataFrame(tuples.slice(0, 100)))
    val message2 = WorkflowFIFOMessage(dataChannelId, 1, DataFrame(tuples.slice(100, 200)))
    val message3 = WorkflowFIFOMessage(dataChannelId2, 0, DataFrame(tuples.slice(300, 1000)))
    val message4 = WorkflowFIFOMessage(dataChannelId, 2, DataFrame(tuples.slice(200, 300)))
    val message5 = WorkflowFIFOMessage(dataChannelId2, 1, DataFrame(tuples.slice(1000, 5000)))
    inputQueue.put(FIFOMessageElement(message1))
    inputQueue.put(FIFOMessageElement(message2))
    inputQueue.put(FIFOMessageElement(message3))
    Thread.sleep(1000)
    inputQueue.put(FIFOMessageElement(message4))
    inputQueue.put(FIFOMessageElement(message5))
    Thread.sleep(1000)
    waitWhile(logManager.getStep < 4999, "the log manager to reach step 4999")
    logManager.sendCommitted(null) // drain in-mem records to flush
    logManager.terminate()
    val logs = logStorage.getReader("tmpLog").mkRecordIterator().toArray
    logStorage.deleteStorage()
    assert(logs.length > 1)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/DataProcessorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker
import org.apache.texera.amber.core.executor.OperatorExecutor
import org.apache.texera.amber.core.tuple.{AttributeType, Schema, Tuple, TupleLike}
import org.apache.texera.amber.core.virtualidentity._
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.core.workflow.WorkflowContext.DEFAULT_WORKFLOW_ID
import org.apache.texera.amber.engine.architecture.logreplay.{ReplayLogManager, ReplayLogRecord}
import org.apache.texera.amber.engine.architecture.messaginglayer.WorkerTimerService
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
AsyncRPCContext,
EmbeddedControlMessage,
EmbeddedControlMessageType,
EmptyRequest
}
import org.apache.texera.amber.engine.architecture.rpc.workerservice.WorkerServiceGrpc.{
METHOD_END_CHANNEL,
METHOD_FLUSH_NETWORK_BUFFER,
METHOD_OPEN_EXECUTOR
}
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{
DPInputQueueElement,
MainThreadDelegateMessage
}
import org.apache.texera.amber.engine.architecture.worker.statistics.WorkerState.READY
import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage}
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.ControlInvocation
import org.apache.texera.amber.engine.common.storage.SequentialRecordStorage
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.apache.texera.amber.util.VirtualIdentityUtils
import org.scalamock.scalatest.MockFactory
import org.scalatest.BeforeAndAfterEach
import org.scalatest.flatspec.AnyFlatSpec
import java.util.concurrent.LinkedBlockingQueue
/**
 * Tests that drive [[DataProcessor]] synchronously from the test thread (no DPThread is started):
 * control invocations go through `processDCM`, data through `processDataPayload`, the end-of-channel
 * marker through `processECM`, and pending work is pumped with `continueDataProcessing()`.
 */
class DataProcessorSpec extends AnyFlatSpec with MockFactory with BeforeAndAfterEach {
// Worker under test and its single upstream sender, both under DEFAULT_WORKFLOW_ID.
private val testOpId = PhysicalOpIdentity(OperatorIdentity("testop"), "main")
private val upstreamOpId = PhysicalOpIdentity(OperatorIdentity("sender"), "main")
private val testWorkerId: ActorVirtualIdentity = VirtualIdentityUtils.createWorkerIdentity(
DEFAULT_WORKFLOW_ID,
testOpId,
0
)
private val senderWorkerId: ActorVirtualIdentity = VirtualIdentityUtils.createWorkerIdentity(
DEFAULT_WORKFLOW_ID,
upstreamOpId,
0
)
// Mocks shared by every test; each test registers its own expectations on them.
private val executor = mock[OperatorExecutor]
private val inputPortId = PortIdentity()
private val outputPortId = PortIdentity()
private val outputHandler = mock[Either[MainThreadDelegateMessage, WorkflowFIFOMessage] => Unit]
private val adaptiveBatchingMonitor = mock[WorkerTimerService]
// 400 single-column INTEGER tuples, delivered to the processor as one DataFrame.
private val schema: Schema = Schema().add("field1", AttributeType.INTEGER)
private val tuples: Array[Tuple] = (0 until 400)
.map(i => TupleLike(i).enforceSchema(schema))
.toArray
// Log manager passed to processECM; its storage is created with no URI.
private val logStorage = SequentialRecordStorage.getStorage[ReplayLogRecord](None)
private val logManager: ReplayLogManager =
ReplayLogManager.createLogManager(logStorage, "none", x => {})
// PORT_ALIGNMENT embedded control message whose only command, keyed by the test worker's name,
// is an END_CHANNEL invocation. Feeding it through processECM ends the data channel.
private val endChannelPayload = EmbeddedControlMessage(
EmbeddedControlMessageIdentity("EndChannel"),
EmbeddedControlMessageType.PORT_ALIGNMENT,
Seq(),
Map(
testWorkerId.name ->
ControlInvocation(
METHOD_END_CHANNEL.getBareMethodName,
EmptyRequest(),
AsyncRPCContext(ActorVirtualIdentity(""), ActorVirtualIdentity("")),
-1
)
)
)
/** Creates a fresh DataProcessor wired to the shared output-handler and timer-service mocks. */
def mkDataProcessor: DataProcessor = {
val dp: DataProcessor = new DataProcessor(
testWorkerId,
outputHandler,
inputMessageQueue = new LinkedBlockingQueue[DPInputQueueElement]()
)
dp.initTimerService(adaptiveBatchingMonitor)
dp
}
// Full lifecycle: open executor -> process 400 tuples on port 0 -> END_CHANNEL -> finish and close.
"data processor" should "process data messages" in {
val dp = mkDataProcessor
dp.executor = executor
dp.stateManager.transitTo(READY)
(outputHandler.apply _).expects(*).once()
(executor.open _).expects().once()
tuples.foreach { x =>
(
(
tuple: Tuple,
input: Int
) => executor.processTupleMultiPort(tuple, input)
)
.expects(x, 0)
}
// Finishing port 0: no state is produced, then onFinishMultiPort(0) is invoked.
(
(
input: Int
) => executor.produceStateOnFinish(input)
)
.expects(0)
.returning(None)
(
(
input: Int
) => executor.onFinishMultiPort(input)
)
.expects(
0
)
(adaptiveBatchingMonitor.startAdaptiveBatching _).expects().anyNumberOfTimes()
(adaptiveBatchingMonitor.stopAdaptiveBatching _).expects().once()
(executor.close _).expects().once()
(outputHandler.apply _).expects(*).anyNumberOfTimes()
dp.inputManager.addPort(inputPortId, schema, List.empty, List.empty)
dp.inputGateway
.getChannel(ChannelIdentity(senderWorkerId, testWorkerId, isControl = false))
.setPortId(inputPortId)
dp.outputManager.addPort(outputPortId, schema, None)
dp.processDCM(
ChannelIdentity(CONTROLLER, testWorkerId, isControl = true),
ControlInvocation(
METHOD_OPEN_EXECUTOR,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, testWorkerId),
0
)
)
dp.processDataPayload(
ChannelIdentity(senderWorkerId, testWorkerId, isControl = false),
DataFrame(tuples)
)
// Pump until both input and output sides are drained.
while (dp.inputManager.hasUnfinishedInput || dp.outputManager.hasUnfinishedOutput) {
dp.continueDataProcessing()
}
dp.processECM(
ChannelIdentity(senderWorkerId, testWorkerId, isControl = false),
endChannelPayload,
logManager
)
while (dp.inputManager.hasUnfinishedInput || dp.outputManager.hasUnfinishedOutput) {
dp.continueDataProcessing()
}
}
// Same lifecycle, but a FLUSH_NETWORK_BUFFER control invocation is handled before every data step.
"data processor" should "process control messages during data processing" in {
val dp = mkDataProcessor
dp.executor = executor
dp.stateManager.transitTo(READY)
(outputHandler.apply _).expects(*).anyNumberOfTimes()
(executor.open _).expects().once()
tuples.foreach { x =>
(
(
tuple: Tuple,
input: Int
) => executor.processTupleMultiPort(tuple, input)
)
.expects(x, 0)
}
(
(
input: Int
) => executor.produceStateOnFinish(input)
)
.expects(0)
.returning(None)
(
(
input: Int
) => executor.onFinishMultiPort(input)
)
.expects(0)
(adaptiveBatchingMonitor.startAdaptiveBatching _).expects().anyNumberOfTimes()
dp.inputManager.addPort(inputPortId, schema, List.empty, List.empty)
dp.inputGateway
.getChannel(ChannelIdentity(senderWorkerId, testWorkerId, isControl = false))
.setPortId(inputPortId)
dp.outputManager.addPort(outputPortId, schema, None)
dp.processDCM(
ChannelIdentity(CONTROLLER, testWorkerId, isControl = true),
ControlInvocation(
METHOD_OPEN_EXECUTOR,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, testWorkerId),
0
)
)
dp.processDataPayload(
ChannelIdentity(senderWorkerId, testWorkerId, isControl = false),
DataFrame(tuples)
)
while (dp.inputManager.hasUnfinishedInput || dp.outputManager.hasUnfinishedOutput) {
dp.processDCM(
ChannelIdentity(CONTROLLER, testWorkerId, isControl = true),
ControlInvocation(
METHOD_FLUSH_NETWORK_BUFFER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, testWorkerId),
1
)
)
dp.continueDataProcessing()
}
// Registered only after the data loop, so stopAdaptiveBatching/close calls made while data was
// still being processed would have had no matching expectation.
(adaptiveBatchingMonitor.stopAdaptiveBatching _).expects().once()
(executor.close _).expects().once()
dp.processECM(
ChannelIdentity(senderWorkerId, testWorkerId, isControl = false),
endChannelPayload,
logManager
)
while (dp.inputManager.hasUnfinishedInput || dp.outputManager.hasUnfinishedOutput) {
dp.continueDataProcessing()
}
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/PauseTypeSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker
import org.apache.texera.amber.core.virtualidentity.EmbeddedControlMessageIdentity
import org.scalatest.flatspec.AnyFlatSpec
/**
  * Value-level contract of [[PauseType]]: equality and identity of the singleton kinds, value
  * semantics of [[ECMPause]], exhaustive matching, and behaviour as elements of a Set.
  */
class PauseTypeSpec extends AnyFlatSpec {

  // ---------------------------------------------------------------------------------------------
  // Singletons
  //
  // Membership in the sealed PauseType hierarchy is not checked at runtime on purpose: the
  // `: PauseType` ascriptions below already require it. If any singleton ever stopped extending
  // the trait, this file would not compile, so a runtime check would be redundant.
  // ---------------------------------------------------------------------------------------------

  "PauseType singletons" should "compare equal to themselves and unequal to each other" in {
    // Element type is PauseType (not the singleton types). This keeps the compiler from reducing
    // the cross-kind comparisons to a constant `false`.
    val widened: List[PauseType] = List(UserPause, BackpressurePause, OperatorLogicPause)
    val originals = List(UserPause, BackpressurePause, OperatorLogicPause)
    widened.zip(originals).foreach { case (actual, singleton) => assert(actual == singleton) }
    for {
      i <- widened.indices
      j <- (i + 1) until widened.size
    } assert(widened(i) != widened(j))
  }

  it should "be the same singleton instance per access (object identity)" in {
    val firstAccess: List[AnyRef] = List(UserPause, BackpressurePause, OperatorLogicPause)
    val secondAccess: List[AnyRef] = List(UserPause, BackpressurePause, OperatorLogicPause)
    firstAccess.zip(secondAccess).foreach { case (a, b) => assert(a eq b) }
  }

  // ---------------------------------------------------------------------------------------------
  // ECMPause
  // ---------------------------------------------------------------------------------------------

  "ECMPause" should "carry the EmbeddedControlMessageIdentity it was constructed with" in {
    val id = EmbeddedControlMessageIdentity("ckpt-1")
    assert(ECMPause(id).id == id)
  }

  it should "support case-class value equality and hashCode (same id → equal)" in {
    def pauseFor(name: String): ECMPause = ECMPause(EmbeddedControlMessageIdentity(name))
    val (a, b, c) = (pauseFor("ckpt-1"), pauseFor("ckpt-1"), pauseFor("ckpt-2"))
    assert(a == b)
    assert(a.hashCode == b.hashCode)
    assert(a != c)
  }

  it should "not equal any of the singleton PauseTypes" in {
    // The ascription already establishes subtyping. What is pinned here is cross-kind inequality:
    // an ECMPause, whatever its id, must never collide with a singleton kind.
    val ecm: PauseType = ECMPause(EmbeddedControlMessageIdentity("ckpt"))
    Seq(UserPause, BackpressurePause, OperatorLogicPause).foreach { singleton =>
      assert(ecm != singleton)
    }
  }

  // ---------------------------------------------------------------------------------------------
  // Pattern matching
  // ---------------------------------------------------------------------------------------------

  "PauseType" should "support exhaustive pattern matching that distinguishes each subtype" in {
    def label(p: PauseType): String =
      p match {
        case UserPause          => "user"
        case BackpressurePause  => "backpressure"
        case OperatorLogicPause => "operator-logic"
        case ECMPause(_)        => "ecm"
      }
    val expectations: Seq[(PauseType, String)] = Seq(
      UserPause -> "user",
      BackpressurePause -> "backpressure",
      OperatorLogicPause -> "operator-logic",
      ECMPause(EmbeddedControlMessageIdentity("x")) -> "ecm"
    )
    expectations.foreach { case (pause, expected) => assert(label(pause) == expected) }
  }

  // ---------------------------------------------------------------------------------------------
  // Set-based coexistence
  //
  // PauseManager keeps its active pauses in a `HashSet[PauseType]`. Pauses are additive and have
  // no priority: resuming one kind removes only that kind. PauseType therefore only needs to behave
  // correctly as a Set element (equals/hashCode), and that is what these tests pin. The
  // multi-pause behaviour of PauseManager.pause/resume/isPaused itself is covered in
  // WorkerManagersSpec.
  // ---------------------------------------------------------------------------------------------

  it should "coexist as distinct elements in a Set without aliasing" in {
    // `def` yields a fresh ECMPause on every use, so containment is checked by value, not reference.
    def allKinds: Seq[PauseType] =
      Seq(
        UserPause,
        BackpressurePause,
        OperatorLogicPause,
        ECMPause(EmbeddedControlMessageIdentity("ckpt-1"))
      )
    val active: Set[PauseType] = allKinds.toSet
    assert(active.size == 4, "all four pause kinds must be distinct Set elements")
    allKinds.foreach(kind => assert(active.contains(kind)))
  }

  it should "deduplicate identical pauses inside a Set" in {
    // Re-pausing with an already-active kind is a no-op in PauseManager only because Set
    // deduplication relies on PauseType.equals/hashCode.
    val userTwice: Seq[PauseType] = Seq.fill(2)(UserPause)
    val checkpointTwice: Seq[PauseType] =
      Seq.fill(2)(ECMPause(EmbeddedControlMessageIdentity("ckpt-1")))
    val active: Set[PauseType] = (userTwice ++ checkpointTwice).toSet
    assert(active.size == 2)
  }

  it should "treat ECMPause instances with different ids as distinct Set elements" in {
    // Checkpoint pauses with different ids are tracked independently, so resuming one must leave
    // the other in place.
    def ckpt(name: String): PauseType = ECMPause(EmbeddedControlMessageIdentity(name))
    val active: Set[PauseType] = Set(ckpt("ckpt-1"), ckpt("ckpt-2"))
    assert(active.size == 2)
    val afterResumeFirst = active - ckpt("ckpt-1")
    assert(afterResumeFirst.size == 1)
    assert(afterResumeFirst.contains(ckpt("ckpt-2")))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/WorkerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker
import org.apache.pekko.actor.{ActorRef, ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestActorRef, TestKit}
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.executor.{OpExecWithClassName, OperatorExecutor}
import org.apache.texera.amber.core.tuple._
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
OperatorIdentity,
PhysicalOpIdentity
}
import org.apache.texera.amber.core.workflow.{PhysicalLink, PortIdentity}
import org.apache.texera.amber.engine.architecture.common.WorkflowActor.NetworkMessage
import org.apache.texera.amber.engine.architecture.rpc.controlcommands._
import org.apache.texera.amber.engine.architecture.rpc.workerservice.WorkerServiceGrpc._
import org.apache.texera.amber.engine.architecture.scheduling.config.WorkerConfig
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.OneToOnePartitioning
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{
DPInputQueueElement,
MainThreadDelegateMessage,
WorkerReplayInitialization
}
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.ambermessage.{
DataFrame,
DataPayload,
WorkflowFIFOMessage
}
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.scalamock.scalatest.MockFactory
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpecLike
import java.util.concurrent.{CompletableFuture, LinkedBlockingQueue}
import scala.collection.mutable
import scala.concurrent.duration.MILLISECONDS
import scala.util.Random
/** Pass-through executor for the worker tests: each input tuple is emitted unchanged, once. */
class DummyOperatorExecutor extends OperatorExecutor {
  override def processTuple(tuple: Tuple, port: Int): Iterator[TupleLike] =
    Iterator.single(tuple)
}
class WorkerSpec
extends TestKit(ActorSystem("WorkerSpec", AmberRuntime.pekkoConfig))
with ImplicitSender
with AnyFlatSpecLike
with BeforeAndAfterAll
with MockFactory {
/**
  * Builds a schema with one `AttributeType.ANY` column per argument, named "field0", "field1", ...
  *
  * Only the number of arguments matters; their values are ignored.
  */
def mkSchema(fields: Any*): Schema =
  fields.indices.foldLeft(Schema()) { (schema, i) =>
    schema.add(new Attribute("field" + i, AttributeType.ANY))
  }
/** Builds a tuple whose values are `fields`, in order, over a matching `mkSchema` schema. */
def mkTuple(fields: Any*): Tuple = {
  val schema = mkSchema(fields: _*)
  val values = fields.toArray
  Tuple.builder(schema).addSequentially(values).build()
}
// Before any test runs, start the SingleNodeListener actor under the name "cluster-info".
override def beforeAll(): Unit = {
  val listenerProps = Props[SingleNodeListener]()
  system.actorOf(listenerProps, "cluster-info")
}

// Tear down the TestKit actor system once all tests in this suite have finished.
override def afterAll(): Unit =
  TestKit.shutdownActorSystem(system)
// Worker under test (identifier1) and its peer (identifier2).
private val identifier1 = ActorVirtualIdentity("Worker:WF1-E1-op-layer-1")
private val identifier2 = ActorVirtualIdentity("Worker:WF1-E1-op-layer-2")
private val operatorIdentity = OperatorIdentity("testOperator")
// Default port identity, used for both ends of `mockLink`.
private val mockPortId = PortIdentity()
// Link between two physical ops of the same logical operator ("1st-physical-op" -> "2nd-physical-op").
private val mockLink =
PhysicalLink(
PhysicalOpIdentity(operatorIdentity, "1st-physical-op"),
mockPortId,
PhysicalOpIdentity(operatorIdentity, "2nd-physical-op"),
mockPortId
)
// One-to-one partitioning (batch size 10) over the single data channel identifier1 -> identifier2.
private val mockPolicy =
OneToOnePartitioning(10, Seq(ChannelIdentity(identifier1, identifier2, isControl = false)))
/**
  * Sends each control invocation to `worker` as a NetworkMessage on the CONTROLLER -> identifier1
  * control channel.
  *
  * @param worker      destination actor
  * @param controls    invocations to send, in order
  * @param beginSeqNum sequence number of the first message; each following message uses the next
  *                    number (used both as the network id and the FIFO sequence number)
  */
def sendControlToWorker(
    worker: ActorRef,
    controls: Array[ControlInvocation],
    beginSeqNum: Long = 0
): Unit = {
  // The channel is identical for every message, so build it once.
  val controlChannel = ChannelIdentity(CONTROLLER, identifier1, isControl = true)
  controls.zipWithIndex.foreach {
    case (ctrl, offset) =>
      val seq = beginSeqNum + offset
      worker ! NetworkMessage(seq, WorkflowFIFOMessage(controlChannel, seq, ctrl))
  }
}
def mkWorker(expectedOutput: Iterable[TupleLike]): (ActorRef, CompletableFuture[Boolean]) = {
val expected = mutable.Queue.from(expectedOutput)
val completeStatus = new CompletableFuture[Boolean]()
val mockHandler: Either[MainThreadDelegateMessage, WorkflowFIFOMessage] => Unit = {
case Left(value) => ???
case Right(value) =>
value match {
case WorkflowFIFOMessage(_, _, payload) =>
payload match {
case payload: DataPayload =>
payload.asInstanceOf[DataFrame].frame.foreach { item =>
val expectedOutput = expected.dequeue()
if (expectedOutput != item) {
completeStatus.complete(false)
} else {
if (expected.isEmpty) {
completeStatus.complete(true)
}
}
}
case _ => //skip
}
}
}
val worker = TestActorRef(
new WorkflowWorker(
WorkerConfig(identifier1),
WorkerReplayInitialization(restoreConfOpt = None, faultToleranceConfOpt = None)
) {
this.dp = new DataProcessor(
identifier1,
mockHandler,
inputMessageQueue = new LinkedBlockingQueue[DPInputQueueElement]()
)
this.dp.initTimerService(timerService)
dpThread = new DPThread(
actorId,
dp,
logManager,
inputQueue
)
}
)
val invocation = AsyncRPCClient.ControlInvocation(
METHOD_ADD_PARTITIONING,
AddPartitioningRequest(mockLink, mockPolicy),
AsyncRPCContext(CONTROLLER, identifier1),
0
)
val addPort1 = AsyncRPCClient.ControlInvocation(
METHOD_ASSIGN_PORT,
AssignPortRequest(mockPortId, input = true, mkSchema(1).toRawSchema, List(""), List()),
AsyncRPCContext(CONTROLLER, identifier1),
1
)
val addPort2 = AsyncRPCClient.ControlInvocation(
METHOD_ASSIGN_PORT,
AssignPortRequest(mockPortId, input = false, mkSchema(1).toRawSchema, List(""), List()),
AsyncRPCContext(CONTROLLER, identifier1),
2
)
val addInputChannel = AsyncRPCClient.ControlInvocation(
METHOD_ADD_INPUT_CHANNEL,
AddInputChannelRequest(
ChannelIdentity(identifier2, identifier1, isControl = false),
mockLink.toPortId
),
AsyncRPCContext(CONTROLLER, identifier1),
3
)
val initializeOperatorLogic = AsyncRPCClient.ControlInvocation(
METHOD_INITIALIZE_EXECUTOR,
InitializeExecutorRequest(
1,
OpExecWithClassName(
"org.apache.texera.amber.engine.architecture.worker.DummyOperatorExecutor"
),
isSource = false
),
AsyncRPCContext(CONTROLLER, identifier1),
4
)
sendControlToWorker(
worker,
Array(invocation, addPort1, addPort2, addInputChannel, initializeOperatorLogic)
)
(worker, completeStatus)
}
"Worker" should "process data messages correctly" in {
val (worker, future) = mkWorker(Array(mkTuple(1)))
worker ! NetworkMessage(
0,
WorkflowFIFOMessage(
ChannelIdentity(identifier2, identifier1, isControl = false),
0,
DataFrame(Array(mkTuple(1)))
)
)
worker ! AsyncRPCClient.ControlInvocation(
METHOD_FLUSH_NETWORK_BUFFER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, identifier1),
1
)
//wait test to finish
assert(future.get(3000, MILLISECONDS))
}
"Worker" should "process batches correctly" in {
ignoreMsg {
case a => println(a); true
}
def mkBatch(start: Int, end: Int): Array[Tuple] = {
(start until end).map { x =>
mkTuple(x)
}.toArray
}
val batch1 = mkBatch(0, 400)
val batch2 = mkBatch(400, 500)
val batch3 = mkBatch(500, 800)
val (worker, future) = mkWorker(mkBatch(0, 800))
worker ! NetworkMessage(
3,
WorkflowFIFOMessage(
ChannelIdentity(identifier2, identifier1, isControl = false),
0,
DataFrame(batch1)
)
)
worker ! NetworkMessage(
2,
WorkflowFIFOMessage(
ChannelIdentity(identifier2, identifier1, isControl = false),
1,
DataFrame(batch2)
)
)
Thread.sleep(1000)
worker ! NetworkMessage(
4,
WorkflowFIFOMessage(
ChannelIdentity(identifier2, identifier1, isControl = false),
2,
DataFrame(batch3)
)
)
//wait test to finish
assert(future.get(3000, MILLISECONDS))
}
"Worker" should "accept messages in fifo order" in {
ignoreMsg {
case a => println(a); true
}
val (worker, future) = mkWorker((0 until 100).map(mkTuple(_)))
Random
.shuffle((0 until 50).map { i =>
NetworkMessage(
i + 2,
WorkflowFIFOMessage(
ChannelIdentity(identifier2, identifier1, isControl = false),
i,
DataFrame(Array(mkTuple(i)))
)
)
})
.foreach { x =>
worker ! x
}
Thread.sleep(1000)
Random
.shuffle((50 until 100).map { i =>
NetworkMessage(
i + 2,
WorkflowFIFOMessage(
ChannelIdentity(identifier2, identifier1, isControl = false),
i,
DataFrame(Array(mkTuple(i)))
)
)
})
.foreach { x =>
worker ! x
}
//wait test to finish
assert(future.get(3000, MILLISECONDS))
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/OutputPortResultWriterThreadSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker.managers
import org.apache.texera.amber.core.storage.model.BufferedItemWriter
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.core.workflow.PortIdentity
import org.apache.texera.amber.engine.architecture.messaginglayer.{
NetworkOutputGateway,
OutputManager
}
import org.apache.texera.amber.engine.common.ambermessage.WorkflowFIFOMessage
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
/**
 * Tests how OutputPortResultWriterThread captures failures raised by its writer. Also tests that
 * `OutputManager.closeOutputStorageWriterIfNeeded` re-throws such a captured failure.
 */
class OutputPortResultWriterThreadSpec extends AnyFlatSpec {

  // Upper bound on how long a test waits for a writer thread to finish. A regression that keeps
  // the thread alive then fails that test instead of blocking the whole suite inside join().
  private val joinTimeoutMillis: Long = 10000L

  /**
   * BufferedItemWriter that runs a caller-supplied hook in `putOne` and in `close`, so a test can
   * make either call throw. `closeCalled` records whether `close()` was invoked.
   */
  private class StubWriter(
      onPutOne: () => Unit = () => (),
      onClose: () => Unit = () => ()
  ) extends BufferedItemWriter[Tuple] {
    val bufferSize: Int = 1024
    // Set on the writer thread and read on the test thread. @volatile keeps the read well-defined
    // even when it is not ordered by a completed join().
    @volatile var closeCalled = false
    def open(): Unit = ()
    def putOne(item: Tuple): Unit = onPutOne()
    def removeOne(item: Tuple): Unit = ()
    def close(): Unit = {
      closeCalled = true
      onClose()
    }
  }

  /** Returns a hook that always throws a RuntimeException carrying `msg`. */
  private def throwing(msg: String): () => Unit = () => throw new RuntimeException(msg)

  /**
   * Waits for `thread` to finish. If it is still alive after `joinTimeoutMillis`, the test fails
   * rather than hanging.
   */
  private def joinWithin(thread: Thread): Unit = {
    thread.join(joinTimeoutMillis)
    assert(!thread.isAlive, s"writer thread did not terminate within $joinTimeoutMillis ms")
  }

  "OutputPortResultWriterThread" should "leave getFailure empty on a clean run" in {
    val writer = new StubWriter()
    val thread = new OutputPortResultWriterThread(writer)
    thread.start()
    thread.queue.put(Right(PortStorageWriterTerminateSignal))
    joinWithin(thread)
    assert(thread.getFailure.isEmpty)
    assert(writer.closeCalled)
  }

  it should "capture a close() exception in getFailure so the worker can re-throw" in {
    val writer = new StubWriter(onClose = throwing("test close failure"))
    val thread = new OutputPortResultWriterThread(writer)
    thread.start()
    thread.queue.put(Right(PortStorageWriterTerminateSignal))
    joinWithin(thread)
    assert(thread.getFailure.exists(_.getMessage.contains("test close failure")))
    assert(writer.closeCalled)
  }

  it should "capture a putOne exception and still call close()" in {
    val writer = new StubWriter(onPutOne = throwing("test putOne failure"))
    val thread = new OutputPortResultWriterThread(writer)
    thread.start()
    thread.queue.put(Left(null.asInstanceOf[Tuple]))
    thread.queue.put(Right(PortStorageWriterTerminateSignal))
    joinWithin(thread)
    assert(thread.getFailure.exists(_.getMessage.contains("test putOne failure")))
    // close() must still run after putOne threw; otherwise the underlying writer's resources
    // would never be released.
    assert(writer.closeCalled)
  }

  it should "preserve both errors when putOne and close() fail in the same run" in {
    val writer = new StubWriter(
      onPutOne = throwing("test putOne failure"),
      onClose = throwing("test close failure")
    )
    val thread = new OutputPortResultWriterThread(writer)
    thread.start()
    thread.queue.put(Left(null.asInstanceOf[Tuple]))
    thread.queue.put(Right(PortStorageWriterTerminateSignal))
    joinWithin(thread)
    val captured = thread.getFailure.getOrElse(fail("expected putOne failure"))
    assert(captured.getMessage.contains("test putOne failure"))
    assert(
      captured.getSuppressed.exists(_.getMessage.contains("test close failure")),
      "close() failure should be attached as suppressed on the original putOne failure"
    )
  }

  // Installs `thread` into OutputManager's private outputPortResultWriterThreads map via
  // reflection. The test below uses it to hand OutputManager a writer whose close() fails, which
  // pins the contract that closeOutputStorageWriterIfNeeded re-throws the captured failure.
  private def installWriterThread(
      manager: OutputManager,
      portId: PortIdentity,
      thread: OutputPortResultWriterThread
  ): Unit = {
    val field = classOf[OutputManager]
      .getDeclaredField("outputPortResultWriterThreads")
    field.setAccessible(true)
    field
      .get(manager)
      .asInstanceOf[mutable.HashMap[PortIdentity, OutputPortResultWriterThread]]
      .put(portId, thread)
  }

  "OutputManager.closeOutputStorageWriterIfNeeded" should
    "re-throw the writer thread's captured failure" in {
    val identifier = ActorVirtualIdentity("test-worker")
    val outputManager = new OutputManager(
      identifier,
      new NetworkOutputGateway(identifier, (_: WorkflowFIFOMessage) => ())
    )
    val portId = PortIdentity()
    val failingWriter = new StubWriter(onClose = throwing("test close failure"))
    val failingThread = new OutputPortResultWriterThread(failingWriter)
    failingThread.start()
    installWriterThread(outputManager, portId, failingThread)
    val ex = intercept[RuntimeException] {
      outputManager.closeOutputStorageWriterIfNeeded(portId)
    }
    assert(ex.getMessage.contains("test close failure"))
  }

  it should "be a no-op when the port has no writer thread" in {
    val identifier = ActorVirtualIdentity("test-worker")
    val outputManager = new OutputManager(
      identifier,
      new NetworkOutputGateway(identifier, (_: WorkflowFIFOMessage) => ())
    )
    // No installWriterThread call: the port has never had a writer, so this must not throw.
    outputManager.closeOutputStorageWriterIfNeeded(PortIdentity())
    succeed
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/managers/WorkerManagersSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker.managers
import org.apache.texera.amber.core.executor.OperatorExecutor
import org.apache.texera.amber.core.tuple.{Tuple, TupleLike}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.core.workflow.PortIdentity
import org.scalatest.flatspec.AnyFlatSpec
/**
 * Unit tests for the worker-side helper managers: `StatisticsManager`, `SerializationManager`
 * and `PauseManager`. `PauseManager` is driven through a stub `InputGateway` backed by real
 * `AmberFIFOChannel`s.
 */
class WorkerManagersSpec extends AnyFlatSpec {

  // ===========================================================================
  // StatisticsManager
  // ===========================================================================

  // getStatistics requires an executor argument. This one never emits anything and only fills
  // that slot (the statistics code is said not to inspect it).
  private val placeholderExecutor: OperatorExecutor = new OperatorExecutor {
    override def processTuple(t: Tuple, port: Int): Iterator[TupleLike] = Iterator.empty
  }

  /** Takes a statistics snapshot of `stats` using the placeholder executor. */
  private def snapshotOf(stats: StatisticsManager) = stats.getStatistics(placeholderExecutor)

  "StatisticsManager" should "default all counters to zero" in {
    val stats = new StatisticsManager()
    assert(stats.getInputTupleCount == 0L)
    assert(stats.getOutputTupleCount == 0L)

    val snapshot = snapshotOf(stats)
    assert(snapshot.inputTupleMetrics.isEmpty)
    assert(snapshot.outputTupleMetrics.isEmpty)
    assert(snapshot.dataProcessingTime == 0L)
    assert(snapshot.controlProcessingTime == 0L)
    // idle = total - data - control, and every term is still zero.
    assert(snapshot.idleTime == 0L)
  }

  "StatisticsManager.increaseInputStatistics" should "accumulate count and size per port" in {
    val stats = new StatisticsManager()
    val recorded = Seq(PortIdentity(0) -> 100, PortIdentity(0) -> 50, PortIdentity(1) -> 25)
    recorded.foreach { case (port, tupleSize) => stats.increaseInputStatistics(port, tupleSize) }
    assert(stats.getInputTupleCount == 3L)

    val countAndSizeByPort = snapshotOf(stats).inputTupleMetrics.map { metric =>
      (metric.portId, (metric.tupleMetrics.count, metric.tupleMetrics.size))
    }.toMap
    assert(countAndSizeByPort(PortIdentity(0)) == (2L, 150L))
    assert(countAndSizeByPort(PortIdentity(1)) == (1L, 25L))
  }

  it should "reject negative tuple sizes" in {
    val stats = new StatisticsManager()
    assertThrows[IllegalArgumentException](stats.increaseInputStatistics(PortIdentity(0), -1))
  }

  "StatisticsManager.increaseOutputStatistics" should "accumulate count and size per port" in {
    val stats = new StatisticsManager()
    List(30, 70).foreach(tupleSize => stats.increaseOutputStatistics(PortIdentity(0), tupleSize))
    assert(stats.getOutputTupleCount == 2L)

    val outputMetrics = snapshotOf(stats).outputTupleMetrics
    assert(outputMetrics.size == 1)
    val onlyPortMetrics = outputMetrics.head.tupleMetrics
    assert(onlyPortMetrics.count == 2L)
    assert(onlyPortMetrics.size == 100L)
  }

  it should "reject negative tuple sizes" in {
    val stats = new StatisticsManager()
    assertThrows[IllegalArgumentException](stats.increaseOutputStatistics(PortIdentity(0), -1))
  }

  "StatisticsManager.increaseDataProcessingTime" should "accumulate time and reject negatives" in {
    val stats = new StatisticsManager()
    Seq(100, 50).foreach(delta => stats.increaseDataProcessingTime(delta))
    assert(snapshotOf(stats).dataProcessingTime == 150L)
    assertThrows[IllegalArgumentException](stats.increaseDataProcessingTime(-1))
  }

  "StatisticsManager.increaseControlProcessingTime" should "accumulate time and reject negatives" in {
    val stats = new StatisticsManager()
    Seq(20, 40).foreach(delta => stats.increaseControlProcessingTime(delta))
    assert(snapshotOf(stats).controlProcessingTime == 60L)
    assertThrows[IllegalArgumentException](stats.increaseControlProcessingTime(-1))
  }

  "StatisticsManager.updateTotalExecutionTime" should "compute idleTime as total - data - control" in {
    val (startTime, currentTime, dataTime, controlTime) = (1000L, 2000L, 200L, 100L)
    val stats = new StatisticsManager()
    stats.initializeWorkerStartTime(startTime)
    stats.increaseDataProcessingTime(dataTime)
    stats.increaseControlProcessingTime(controlTime)
    stats.updateTotalExecutionTime(currentTime)

    val snapshot = snapshotOf(stats)
    assert(snapshot.dataProcessingTime == dataTime)
    assert(snapshot.controlProcessingTime == controlTime)
    assert(snapshot.idleTime == currentTime - startTime - dataTime - controlTime)
  }

  it should "reject a current time before workerStartTime" in {
    val stats = new StatisticsManager()
    stats.initializeWorkerStartTime(1000L)
    assertThrows[IllegalArgumentException](stats.updateTotalExecutionTime(500L))
  }

  // ===========================================================================
  // SerializationManager
  // ===========================================================================

  "SerializationManager.applySerialization" should "be a no-op when no callback is registered" in {
    val manager = new SerializationManager(ActorVirtualIdentity("worker-1"))
    manager.applySerialization() // nothing registered: must simply return
    succeed
  }

  it should "invoke the registered callback exactly once and then clear it" in {
    val manager = new SerializationManager(ActorVirtualIdentity("worker-1"))
    var invocations = 0
    manager.registerSerialization(() => invocations += 1)
    // The second application must find the callback already cleared.
    (1 to 2).foreach(_ => manager.applySerialization())
    assert(invocations == 1)
  }

  it should "let the latest registered callback overwrite any previous one" in {
    val manager = new SerializationManager(ActorVirtualIdentity("worker-1"))
    var firstInvocations = 0
    var secondInvocations = 0
    manager.registerSerialization(() => firstInvocations += 1)
    manager.registerSerialization(() => secondInvocations += 1)
    manager.applySerialization()
    assert(firstInvocations == 0)
    assert(secondInvocations == 1)
  }

  // ===========================================================================
  // PauseManager (driven through a stub InputGateway)
  // ===========================================================================

  import org.apache.texera.amber.engine.architecture.logreplay.OrderEnforcer
  import org.apache.texera.amber.engine.architecture.messaginglayer.{AmberFIFOChannel, InputGateway}
  import org.apache.texera.amber.engine.architecture.worker.{
    BackpressurePause,
    OperatorLogicPause,
    PauseManager,
    UserPause
  }

  /**
   * InputGateway over a fixed channel map. Channel picking is not exercised by these tests, so
   * both `tryPick*` methods return None. The data and control views are selected by each
   * channel's `isControl` flag.
   */
  private class StubGateway(channels: Map[ChannelIdentity, AmberFIFOChannel]) extends InputGateway {
    private def channelsWhere(control: Boolean): Iterable[AmberFIFOChannel] =
      channels.collect { case (id, channel) if id.isControl == control => channel }

    override def tryPickControlChannel: Option[AmberFIFOChannel] = None
    override def tryPickChannel: Option[AmberFIFOChannel] = None
    override def getAllChannels: Iterable[AmberFIFOChannel] = channels.values
    override def getAllDataChannels: Iterable[AmberFIFOChannel] = channelsWhere(control = false)
    override def getChannel(channelId: ChannelIdentity): AmberFIFOChannel = channels(channelId)
    override def getAllControlChannels: Iterable[AmberFIFOChannel] = channelsWhere(control = true)
    override def addEnforcer(enforcer: OrderEnforcer): Unit = ()
  }

  private val workerId = ActorVirtualIdentity("w")
  private val dataA = ChannelIdentity(ActorVirtualIdentity("up1"), workerId, isControl = false)
  private val dataB = ChannelIdentity(ActorVirtualIdentity("up2"), workerId, isControl = false)
  private val ctrl = ChannelIdentity(ActorVirtualIdentity("ctrl"), workerId, isControl = true)

  /**
   * Builds a fresh gateway. Returns it with its channels for dataA, dataB and ctrl, in that order.
   */
  private def newGateway(): (StubGateway, AmberFIFOChannel, AmberFIFOChannel, AmberFIFOChannel) = {
    val channelA = new AmberFIFOChannel(dataA)
    val channelB = new AmberFIFOChannel(dataB)
    val controlChannel = new AmberFIFOChannel(ctrl)
    val gateway =
      new StubGateway(Map(dataA -> channelA, dataB -> channelB, ctrl -> controlChannel))
    (gateway, channelA, channelB, controlChannel)
  }

  "PauseManager.isPaused" should "be false initially" in {
    val (gateway, _, _, _) = newGateway()
    val pauses = new PauseManager(workerId, gateway)
    assert(!pauses.isPaused)
  }

  "PauseManager.pause" should "disable every data channel and report paused" in {
    val (gateway, channelA, channelB, controlChannel) = newGateway()
    val pauses = new PauseManager(workerId, gateway)
    pauses.pause(UserPause)
    assert(pauses.isPaused)
    assert(Seq(channelA, channelB).forall(channel => !channel.isEnabled))
    // The control channel is not reported by getAllDataChannels, so it keeps running.
    assert(controlChannel.isEnabled)
  }

  "PauseManager.resume" should "re-enable all data channels when no specific input pauses remain" in {
    val (gateway, channelA, channelB, controlChannel) = newGateway()
    val pauses = new PauseManager(workerId, gateway)
    pauses.pause(UserPause)
    pauses.resume(UserPause)
    assert(!pauses.isPaused)
    assert(Seq(channelA, channelB, controlChannel).forall(_.isEnabled))
  }

  it should "stay paused if other global pauses are still active" in {
    val (gateway, channelA, _, _) = newGateway()
    val pauses = new PauseManager(workerId, gateway)
    pauses.pause(UserPause)
    pauses.pause(BackpressurePause)
    pauses.resume(UserPause)
    // BackpressurePause is still in effect, so the data channel stays disabled.
    assert(pauses.isPaused)
    assert(!channelA.isEnabled)
  }

  "PauseManager.pauseInputChannel" should "disable only the listed channels" in {
    val (gateway, channelA, channelB, _) = newGateway()
    val pauses = new PauseManager(workerId, gateway)
    pauses.pauseInputChannel(OperatorLogicPause, List(dataA))
    // A channel-specific pause is not a global pause, so isPaused stays false.
    assert(!pauses.isPaused)
    assert(!channelA.isEnabled)
    assert(channelB.isEnabled)
  }

  it should "leave still-paused specific channels disabled when only one of multiple specific pauses is resumed" in {
    val (gateway, channelA, channelB, _) = newGateway()
    val pauses = new PauseManager(workerId, gateway)
    pauses.pauseInputChannel(OperatorLogicPause, List(dataA))
    pauses.pauseInputChannel(BackpressurePause, List(dataB))
    pauses.resume(OperatorLogicPause)
    // dataA was held only by OperatorLogicPause, so it comes back.
    // dataB is still held by BackpressurePause.
    assert(channelA.isEnabled)
    assert(!channelB.isEnabled)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/architecture/worker/promisehandlers/EndHandlerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.architecture.worker.promisehandlers
import com.twitter.util.{Await, Duration, Future}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
AsyncRPCContext,
EmptyRequest
}
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.EmptyReturn
import org.apache.texera.amber.engine.architecture.rpc.workerservice.WorkerServiceGrpc.METHOD_QUERY_STATISTICS
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{
ActorCommandElement,
DPInputQueueElement,
FIFOMessageElement,
MainThreadDelegateMessage
}
import org.apache.texera.amber.engine.architecture.worker.{
DataProcessor,
DataProcessorRPCHandlerInitializer
}
import org.apache.texera.amber.engine.common.actormessage.Backpressure
import org.apache.texera.amber.engine.common.ambermessage.WorkflowFIFOMessage
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.ControlInvocation
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.scalatest.flatspec.AnyFlatSpec
import java.util.concurrent.LinkedBlockingQueue
/**
 * The controller waits for `endWorker` to be acknowledged before it sends the actor-level
 * `gracefulStop`.
 *
 * A successful reply therefore has to mean the worker's DP input queue is already drained. While
 * any element is still queued (a workflow FIFO message or an actor command), the handler must
 * fail. The region coordinator can then retry the kill instead of stopping the actor too early.
 */
class EndHandlerSpec extends AnyFlatSpec {
  private val workerId = ActorVirtualIdentity("Worker:WF1-test-op-main-0")
  private val rpcContext = AsyncRPCContext(CONTROLLER, workerId)
  private val awaitTimeout = Duration.fromSeconds(1)

  /** A DP input queue pre-filled with `pending`, in the given order. */
  private def queueOf(pending: DPInputQueueElement*): LinkedBlockingQueue[DPInputQueueElement] = {
    val queue = new LinkedBlockingQueue[DPInputQueueElement]()
    pending.foreach(element => queue.put(element))
    queue
  }

  /** RPC handlers over a DataProcessor that reads `queue` and discards everything it emits. */
  private def handlersFor(
      queue: LinkedBlockingQueue[DPInputQueueElement]
  ): DataProcessorRPCHandlerInitializer = {
    val discardOutput: Either[MainThreadDelegateMessage, WorkflowFIFOMessage] => Unit = _ => ()
    new DataProcessorRPCHandlerInitializer(new DataProcessor(workerId, discardOutput, queue))
  }

  /** Blocks on `pending` for at most `awaitTimeout` and returns its value. */
  private def resultOf[T](pending: Future[T]): T = Await.result(pending, awaitTimeout)

  /** Calls endWorker and expects the "unprocessed messages" IllegalStateException. */
  private def expectEndWorkerRejected(handlers: DataProcessorRPCHandlerInitializer): Unit = {
    val error = intercept[IllegalStateException] {
      resultOf(handlers.endWorker(EmptyRequest(), rpcContext))
    }
    assert(error.getMessage == "worker still has unprocessed messages")
  }

  // A controller statistics query that is still waiting on the worker's control channel.
  private def queuedControlMessage: DPInputQueueElement =
    FIFOMessageElement(
      WorkflowFIFOMessage(
        ChannelIdentity(CONTROLLER, workerId, isControl = true),
        0,
        ControlInvocation(METHOD_QUERY_STATISTICS, EmptyRequest(), rpcContext, 1)
      )
    )

  // A backpressure actor command that is still waiting in the queue.
  private def queuedActorCommand: DPInputQueueElement =
    ActorCommandElement(Backpressure(enableBackpressure = true))

  "EndHandler" should "reply successfully when there are no unprocessed messages" in {
    val handlers = handlersFor(queueOf())
    assert(resultOf(handlers.endWorker(EmptyRequest(), rpcContext)) == EmptyReturn())
  }

  it should "fail when a FIFO control message is still queued" in {
    expectEndWorkerRejected(handlersFor(queueOf(queuedControlMessage)))
  }

  it should "fail when an actor command is still queued" in {
    expectEndWorkerRejected(handlersFor(queueOf(queuedActorCommand)))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/common/CheckpointSubsystemSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.common
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.serialization.{Serialization, SerializationExtension}
import org.apache.pekko.testkit.TestKit
import org.apache.texera.amber.core.tuple.TupleLike
import org.apache.texera.amber.core.workflow.PortIdentity
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
class CheckpointSubsystemSpec extends AnyFlatSpec with BeforeAndAfterAll {
  // Suite-local actor system. CheckpointState.save/load use AmberRuntime.serde, so this system and
  // its Serialization are injected into AmberRuntime via reflection for the duration of the suite.
  // Whatever AmberRuntime held beforehand (possibly state installed by another suite in the same
  // JVM) is captured in beforeAll and put back in afterAll, rather than being overwritten with
  // null. The suite-local system is always shut down in afterAll, so no Pekko threads outlive it.
  private val testSystem: ActorSystem =
    ActorSystem("CheckpointSubsystemSpec-test", AmberRuntime.pekkoConfig)
  private val testSerde: Serialization = SerializationExtension(testSystem)

  // AmberRuntime values displaced by beforeAll. afterAll restores them; None restores null.
  private var displacedActorSystem: Option[AnyRef] = None
  private var displacedSerde: Option[AnyRef] = None

  /** Returns the accessible reflective handle for AmberRuntime's field `name`. */
  private def amberRuntimeField(name: String): java.lang.reflect.Field = {
    val field = AmberRuntime.getClass.getDeclaredField(name)
    field.setAccessible(true)
    field
  }

  private def getAmberRuntimeField(name: String): AnyRef =
    amberRuntimeField(name).get(AmberRuntime)

  private def setAmberRuntimeField(name: String, value: AnyRef): Unit =
    amberRuntimeField(name).set(AmberRuntime, value)

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    displacedActorSystem = Option(getAmberRuntimeField("_actorSystem"))
    displacedSerde = Option(getAmberRuntimeField("_serde"))
    setAmberRuntimeField("_actorSystem", testSystem)
    setAmberRuntimeField("_serde", testSerde)
  }

  override protected def afterAll(): Unit = {
    // Nested try/finally: a failure while restoring AmberRuntime must not prevent shutting down
    // the suite-local system, and neither step may skip super.afterAll().
    try {
      setAmberRuntimeField("_serde", displacedSerde.orNull)
      setAmberRuntimeField("_actorSystem", displacedActorSystem.orNull)
    } finally {
      try TestKit.shutdownActorSystem(testSystem)
      finally super.afterAll()
    }
  }

  // ---------------------------------------------------------------------------
  // SerializedState
  // ---------------------------------------------------------------------------

  "SerializedState" should "expose stable well-known key constants" in {
    // These constants are referenced from outside the engine. Pinning the strings makes a rename
    // surface as a test failure.
    assert(SerializedState.CP_STATE_KEY == "Amber_CPState")
    assert(SerializedState.DP_STATE_KEY == "Amber_DPState")
    assert(SerializedState.IN_FLIGHT_MSG_KEY == "Amber_Inflight_Messages")
    assert(SerializedState.DP_QUEUED_MSG_KEY == "Amber_DP_Queued_Messages")
    assert(SerializedState.OUTPUT_MSG_KEY == "Amber_Output_Messages")
  }

  it should "round-trip a value through fromObject / toObject using a suite-local Serialization" in {
    // Uses the suite-local serde directly, so this case does not touch AmberRuntime at all.
    val original: java.lang.Integer = Integer.valueOf(42)
    val state = SerializedState.fromObject(original, testSerde)
    assert(state.bytes.length > 0)
    assert(state.size() == state.bytes.length.toLong)
    val restored = state.toObject[java.lang.Integer](testSerde)
    assert(restored == original)
  }

  it should "carry the serializer id and manifest given at construction" in {
    val s = SerializedState(Array[Byte](1, 2, 3), serializerId = 7, manifest = "manifest-x")
    assert(s.bytes.toSeq == Seq[Byte](1, 2, 3))
    assert(s.serializerId == 7)
    assert(s.manifest == "manifest-x")
    assert(s.size() == 3L)
  }

  // ---------------------------------------------------------------------------
  // CheckpointState
  // ---------------------------------------------------------------------------

  "CheckpointState" should "default to size = 0 with no entries" in {
    val cp = new CheckpointState()
    assert(cp.size() == 0L)
    assert(!cp.has("anything"))
  }

  "CheckpointState.save / load" should "round-trip a primitive value" in {
    val cp = new CheckpointState()
    cp.save("answer", java.lang.Integer.valueOf(42))
    assert(cp.has("answer"))
    val restored: java.lang.Integer = cp.load[java.lang.Integer]("answer")
    assert(restored == java.lang.Integer.valueOf(42))
  }

  it should "round-trip a String value" in {
    val cp = new CheckpointState()
    cp.save("greeting", "hello")
    assert(cp.load[String]("greeting") == "hello")
  }

  it should "overwrite a previously saved key" in {
    val cp = new CheckpointState()
    cp.save("k", java.lang.Integer.valueOf(1))
    cp.save("k", java.lang.Integer.valueOf(2))
    assert(cp.load[java.lang.Integer]("k") == java.lang.Integer.valueOf(2))
  }

  it should "track distinct keys independently" in {
    val cp = new CheckpointState()
    cp.save("a", "alpha")
    cp.save("b", "beta")
    assert(cp.load[String]("a") == "alpha")
    assert(cp.load[String]("b") == "beta")
  }

  "CheckpointState.load" should "raise RuntimeException for an unknown key" in {
    val cp = new CheckpointState()
    val ex = intercept[RuntimeException] {
      cp.load[Any]("missing")
    }
    assert(ex.getMessage.contains("missing"))
  }

  "CheckpointState.size" should "be the sum of every entry's serialized byte length" in {
    // Only growth is asserted: the exact byte counts depend on the configured serializer.
    val cp = new CheckpointState()
    cp.save("a", "x")
    val sizeAfterOne = cp.size()
    assert(sizeAfterOne > 0L)
    cp.save("b", "yy")
    assert(cp.size() > sizeAfterOne)
  }

  // ---------------------------------------------------------------------------
  // CheckpointSupport (trait shape)
  // ---------------------------------------------------------------------------

  "CheckpointSupport" should "be implementable by a custom subclass forwarding to a CheckpointState" in {
    val support = new CheckpointSupport {
      override def serializeState(
          currentIteratorState: Iterator[(TupleLike, Option[PortIdentity])],
          checkpoint: CheckpointState
      ): Iterator[(TupleLike, Option[PortIdentity])] = {
        checkpoint.save("marker", java.lang.Integer.valueOf(1))
        currentIteratorState
      }
      override def deserializeState(
          checkpoint: CheckpointState
      ): Iterator[(TupleLike, Option[PortIdentity])] = Iterator.empty
      override def getEstimatedCheckpointCost: Long = 7L
    }
    val cp = new CheckpointState()
    val out = support.serializeState(Iterator.empty, cp)
    assert(out.isEmpty)
    assert(cp.has("marker"))
    assert(support.deserializeState(cp).isEmpty)
    assert(support.getEstimatedCheckpointCost == 7L)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/common/UtilsSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.common
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.scalatest.flatspec.AnyFlatSpec
import java.util.concurrent.locks.ReentrantLock
class UtilsSpec extends AnyFlatSpec {

  // -- aggregatedStateToString ----------------------------------------------

  "Utils.aggregatedStateToString" should "round-trip every named WorkflowAggregatedState through stringToAggregatedState" in {
    // Every named (i.e. non-Unrecognized) enum value must survive render -> parse.
    val everyNamedState = Seq(
      WorkflowAggregatedState.UNINITIALIZED,
      WorkflowAggregatedState.READY,
      WorkflowAggregatedState.RUNNING,
      WorkflowAggregatedState.PAUSING,
      WorkflowAggregatedState.PAUSED,
      WorkflowAggregatedState.RESUMING,
      WorkflowAggregatedState.COMPLETED,
      WorkflowAggregatedState.TERMINATED,
      WorkflowAggregatedState.FAILED,
      WorkflowAggregatedState.KILLED,
      WorkflowAggregatedState.UNKNOWN
    )
    for (original <- everyNamedState) {
      val rendered = Utils.aggregatedStateToString(original)
      val parsed = Utils.stringToAggregatedState(rendered)
      assert(parsed == original, s"round-trip failed for $original")
    }
  }

  it should "render an unrecognized aggregated state with its raw value" in {
    val rendered = Utils.aggregatedStateToString(WorkflowAggregatedState.Unrecognized(99))
    assert(rendered == "Unrecognized(99)")
  }

  // -- stringToAggregatedState ----------------------------------------------

  "Utils.stringToAggregatedState" should "be case-insensitive and tolerant of surrounding whitespace" in {
    Seq("RUNNING", "running", " Running ").foreach { spelling =>
      assert(Utils.stringToAggregatedState(spelling) == WorkflowAggregatedState.RUNNING)
    }
  }

  it should "accept 'Initializing' as an alias for READY" in {
    Seq("Initializing", "ready").foreach { alias =>
      assert(Utils.stringToAggregatedState(alias) == WorkflowAggregatedState.READY)
    }
  }

  it should "throw IllegalArgumentException for an unrecognized state name" in {
    assertThrows[IllegalArgumentException](Utils.stringToAggregatedState("not-a-real-state"))
  }

  // -- maptoStatusCode ------------------------------------------------------

  "Utils.maptoStatusCode" should "map known states to their documented byte codes" in {
    val documentedCodes: Seq[(WorkflowAggregatedState, Byte)] = Seq(
      WorkflowAggregatedState.UNINITIALIZED -> 0.toByte,
      WorkflowAggregatedState.READY -> 0.toByte,
      WorkflowAggregatedState.RUNNING -> 1.toByte,
      WorkflowAggregatedState.PAUSED -> 2.toByte,
      WorkflowAggregatedState.COMPLETED -> 3.toByte,
      WorkflowAggregatedState.FAILED -> 4.toByte,
      WorkflowAggregatedState.KILLED -> 5.toByte
    )
    documentedCodes.foreach {
      case (state, code) => assert(Utils.maptoStatusCode(state) == code)
    }
  }

  it should "return -1 for states that have no documented code" in {
    val undocumented = Seq(
      WorkflowAggregatedState.PAUSING,
      WorkflowAggregatedState.RESUMING,
      WorkflowAggregatedState.TERMINATED,
      WorkflowAggregatedState.UNKNOWN
    )
    for (state <- undocumented) {
      assert(Utils.maptoStatusCode(state) == (-1).toByte, s"expected -1 for $state")
    }
  }

  // -- retry ---------------------------------------------------------------

  "Utils.retry" should "return the value on the first successful attempt without retrying" in {
    var invocations = 0
    val outcome = Utils.retry(attempts = 3, baseBackoffTimeInMS = 0L) {
      invocations += 1
      "ok"
    }
    assert(outcome == "ok")
    assert(invocations == 1)
  }

  it should "retry on failure until success and return the eventual result" in {
    var invocations = 0
    val outcome = Utils.retry(attempts = 3, baseBackoffTimeInMS = 0L) {
      invocations += 1
      // Only the very first invocation fails.
      if (invocations == 1) throw new RuntimeException("transient")
      "ok"
    }
    assert(outcome == "ok")
    assert(invocations == 2)
  }

  it should "rethrow the last exception after exhausting all attempts" in {
    var invocations = 0
    val thrown = intercept[RuntimeException] {
      Utils.retry(attempts = 2, baseBackoffTimeInMS = 0L) {
        invocations += 1
        throw new RuntimeException(s"failure-$invocations")
      }
    }
    assert(invocations == 2)
    // The message identifies which attempt produced the surfaced exception.
    assert(thrown.getMessage == "failure-2")
  }

  // -- withLock ------------------------------------------------------------

  "Utils.withLock" should "release the lock after the body returns" in {
    implicit val lock: ReentrantLock = new ReentrantLock()
    val value = Utils.withLock {
      assert(lock.isHeldByCurrentThread)
      42
    }
    assert(value == 42)
    assert(!lock.isHeldByCurrentThread)
  }

  it should "release the lock when the body throws" in {
    implicit val lock: ReentrantLock = new ReentrantLock()
    assertThrows[RuntimeException] {
      Utils.withLock[Unit] {
        throw new RuntimeException("boom")
      }
    }
    assert(!lock.isHeldByCurrentThread)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/common/ambermessage/AmberMessageEnvelopesSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.common.ambermessage
import org.apache.pekko.actor.{Address, ActorSystem}
import org.apache.pekko.testkit.TestKit
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
class AmberMessageEnvelopesSpec extends AnyFlatSpec with BeforeAndAfterAll {

  // Suite-local actor system used only to obtain a real ActorRef for the
  // ResendOutputTo tests below; shut down via TestKit.shutdownActorSystem in
  // afterAll so threads do not outlive the test, matching the cleanup pattern
  // in ControllerSpec / WorkerSpec.
  private val pekkoSystem: ActorSystem = ActorSystem("amber-message-envelopes-test")

  override protected def afterAll(): Unit = {
    TestKit.shutdownActorSystem(pekkoSystem)
    super.afterAll()
  }

  private val channel =
    ChannelIdentity(ActorVirtualIdentity("from"), ActorVirtualIdentity("to"), isControl = false)

  private val intSchema: Schema = Schema().add(new Attribute("v", AttributeType.INTEGER))

  private def tuple(v: Int): Tuple =
    Tuple
      .builder(intSchema)
      .add(intSchema.getAttribute("v"), Integer.valueOf(v))
      .build()

  // ---------------------------------------------------------------------------
  // WorkflowFIFOMessage / WorkflowRecoveryMessage envelope shape
  // ---------------------------------------------------------------------------

  "WorkflowFIFOMessage" should "carry channelId, sequenceNumber, and payload as constructed" in {
    val payload = DataFrame(Array(tuple(1)))
    val msg = WorkflowFIFOMessage(channel, 7L, payload)
    assert(msg.channelId == channel)
    assert(msg.sequenceNumber == 7L)
    assert(msg.payload == payload)
  }

  it should "be a WorkflowMessage and Serializable" in {
    val msg = WorkflowFIFOMessage(channel, 0L, DataFrame(Array.empty))
    assert(msg.isInstanceOf[WorkflowMessage])
    assert(msg.isInstanceOf[Serializable])
  }

  "WorkflowRecoveryMessage" should "carry the sender and payload as constructed" in {
    val from = ActorVirtualIdentity("worker-1")
    val payload = UpdateRecoveryStatus(isRecovering = true)
    val msg = WorkflowRecoveryMessage(from, payload)
    assert(msg.from == from)
    assert(msg.payload == payload)
  }

  // ---------------------------------------------------------------------------
  // RecoveryPayload subtypes
  // ---------------------------------------------------------------------------

  "RecoveryPayload subtypes" should "carry their constructor arguments" in {
    val update = UpdateRecoveryStatus(isRecovering = true)
    assert(update.isRecovering)
    val updateOff = UpdateRecoveryStatus(isRecovering = false)
    assert(!updateOff.isRecovering)
    val nodeFailure = NotifyFailedNode(Address("pekko", "test"))
    assert(nodeFailure.addr == Address("pekko", "test"))
  }

  it should "exercise ResendOutputTo via a real ActorRef so the case class wires correctly" in {
    val deadRef = pekkoSystem.deadLetters
    val vid = ActorVirtualIdentity("downstream")
    val payload = ResendOutputTo(vid, deadRef)
    assert(payload.vid == vid)
    assert(payload.ref == deadRef)
  }

  it should "be Serializable on every subtype" in {
    // Include ResendOutputTo as well: it needs a real ActorRef, so it borrows
    // the suite's deadLetters rather than being silently left out.
    val payloads: Seq[RecoveryPayload] = Seq(
      UpdateRecoveryStatus(isRecovering = true),
      NotifyFailedNode(Address("pekko", "n")),
      ResendOutputTo(ActorVirtualIdentity("downstream"), pekkoSystem.deadLetters)
    )
    payloads.foreach(p => assert(p.isInstanceOf[Serializable], s"$p is not Serializable"))
  }

  // ---------------------------------------------------------------------------
  // WorkflowMessage.getInMemSize
  // ---------------------------------------------------------------------------

  // A non-DataFrame payload so getInMemSize falls into the 200L default branch.
  private case class FixedSizePayload() extends WorkflowFIFOMessagePayload

  "WorkflowMessage.getInMemSize" should "be the DataFrame's inMemSize for a WorkflowFIFOMessage carrying a DataFrame" in {
    val df = DataFrame(Array(tuple(1), tuple(2)))
    val msg = WorkflowFIFOMessage(channel, 0L, df)
    assert(WorkflowMessage.getInMemSize(msg) == df.inMemSize)
  }

  it should "be zero for an empty-frame WorkflowFIFOMessage" in {
    val msg = WorkflowFIFOMessage(channel, 0L, DataFrame(Array.empty))
    assert(WorkflowMessage.getInMemSize(msg) == 0L)
  }

  it should "default to 200L for a non-DataFrame WorkflowFIFOMessagePayload" in {
    val msg = WorkflowFIFOMessage(channel, 0L, FixedSizePayload())
    assert(WorkflowMessage.getInMemSize(msg) == 200L)
  }

  // The catch-all `case _ => 200L` for non-WorkflowFIFOMessage subtypes is
  // guarded by `WorkflowMessage` being sealed. Today the sealed hierarchy
  // only has `WorkflowFIFOMessage`, so this branch is dead by construction;
  // we leave it untested rather than open the seal.

  // ---------------------------------------------------------------------------
  // WorkflowFIFOMessagePayload trait wiring (sanity)
  // ---------------------------------------------------------------------------

  "WorkflowFIFOMessagePayload trait" should "be implementable as a custom payload" in {
    val payload: WorkflowFIFOMessagePayload = FixedSizePayload()
    assert(payload.isInstanceOf[Serializable])
  }

  "DirectControlMessagePayload trait" should "be a WorkflowFIFOMessagePayload subtype" in {
    val custom: DirectControlMessagePayload = new DirectControlMessagePayload {}
    assert(custom.isInstanceOf[WorkflowFIFOMessagePayload])
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/common/ambermessage/DataPayloadSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.common.ambermessage
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
import org.scalatest.flatspec.AnyFlatSpec
class DataPayloadSpec extends AnyFlatSpec {
  private val vAttr = new Attribute("v", AttributeType.INTEGER)
  private val schema: Schema = Schema().add(vAttr)

  // Tuples are built through the schema's own Attribute (looked up by name)
  // so the helper cannot drift out of sync with the schema under test.
  private def tuple(v: Int): Tuple =
    Tuple
      .builder(schema)
      .add(schema.getAttribute("v"), Integer.valueOf(v))
      .build()

  /** A DataFrame holding one single-column tuple per given value, in the given order. */
  private def frame(values: Int*): DataFrame = DataFrame(values.map(tuple).toArray)

  "DataFrame.inMemSize" should "be zero for an empty frame" in {
    assert(frame().inMemSize == 0L)
  }

  it should "be the sum of inMemSize across the contained tuples" in {
    val members = Array(tuple(1), tuple(2))
    assert(DataFrame(members).inMemSize == members.map(_.inMemSize).sum)
  }

  "DataFrame.equals" should "be reflexive on a single empty frame instance" in {
    val empty = frame()
    assert(empty == empty)
  }

  it should "consider two distinct empty frames equal" in {
    assert(frame() == frame())
  }

  it should "reject comparison against non-DataFrame values" in {
    val single = frame(1)
    assert(!single.equals("not a dataframe"))
    assert(!single.equals(null))
  }

  it should "reject frames whose lengths differ" in {
    assert(frame(1) != frame(1, 2))
  }

  it should "treat element-wise equal frames as equal" in {
    assert(frame(1, 2) == frame(1, 2))
  }

  it should "respect element order" in {
    assert(frame(1, 2) != frame(2, 1))
  }

  it should "reject frames whose elements differ" in {
    assert(frame(1) != frame(2))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/common/statetransition/StateManagerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.common.statetransition
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.engine.common.statetransition.StateManager.{
InvalidStateException,
InvalidTransitionException
}
import org.scalatest.flatspec.AnyFlatSpec
class StateManagerSpec extends AnyFlatSpec {
  private sealed trait DummyState
  private case object S0 extends DummyState
  private case object S1 extends DummyState
  private case object S2 extends DummyState
  private case object Orphan extends DummyState

  private val actorId: ActorVirtualIdentity = ActorVirtualIdentity("test-actor")

  /**
    * Builds a manager over the linear graph S0 -> S1 -> S2.
    * S2 is a key with an empty successor set (terminal); Orphan is not a key at all.
    */
  private def linear(initial: DummyState = S0): StateManager[DummyState] =
    new StateManager[DummyState](
      actorId,
      Map(
        S0 -> Set(S1),
        S1 -> Set(S2),
        S2 -> Set.empty
      ),
      initial
    )

  "StateManager" should "report the initial state via getCurrentState" in {
    val manager = linear(S1)
    assert(manager.getCurrentState == S1)
  }

  "StateManager.transitTo" should "advance to a state listed as a successor in the transition graph" in {
    val manager = linear()
    manager.transitTo(S1)
    assert(manager.getCurrentState == S1)
  }

  it should "be a no-op when transitioning to the current state" in {
    val manager = linear(S1)
    manager.transitTo(S1)
    assert(manager.getCurrentState == S1)
  }

  it should "throw InvalidTransitionException when the target is not a successor of the current state" in {
    val manager = linear()
    val thrown = intercept[InvalidTransitionException](manager.transitTo(S2))
    // The message must name both the source and the rejected target.
    Seq(S0, S2).foreach(s => assert(thrown.getMessage.contains(s.toString)))
  }

  it should "throw InvalidTransitionException when transitioning out of a terminal state with no listed successors" in {
    // S2 is a key in the graph, but its successor set is empty.
    val manager = linear(S2)
    intercept[InvalidTransitionException](manager.transitTo(S0))
  }

  it should "throw InvalidTransitionException when the current state is not a key in the transition graph" in {
    // Orphan never appears as a key, so the manager's successor lookup yields
    // nothing and every target must be refused.
    val manager = linear(Orphan)
    intercept[InvalidTransitionException](manager.transitTo(S0))
  }

  "StateManager.assertState" should "succeed when the current state matches" in {
    val manager = linear()
    manager.assertState(S0) // single-state form: must not throw
    manager.assertState(S0, S1) // varargs form: any-of
  }

  it should "throw InvalidStateException when the current state does not match the expected state" in {
    val manager = linear()
    intercept[InvalidStateException](manager.assertState(S1))
  }

  it should "throw InvalidStateException when none of the expected states match (varargs form)" in {
    val manager = linear()
    intercept[InvalidStateException](manager.assertState(S1, S2))
  }

  "StateManager.confirmState" should "report whether the current state matches" in {
    val manager = linear()
    assert(manager.confirmState(S0))
    assert(!manager.confirmState(S1))
  }

  it should "report whether the current state is one of the given states (varargs form)" in {
    val manager = linear()
    assert(manager.confirmState(S0, S1))
    assert(!manager.confirmState(S1, S2))
  }

  "StateManager.conditionalTransitTo" should "transition and run the callback when the precondition matches" in {
    val manager = linear()
    var callbackRan = false
    manager.conditionalTransitTo(S0, S1, () => callbackRan = true)
    assert(manager.getCurrentState == S1)
    assert(callbackRan)
  }

  it should "do nothing and skip the callback when the precondition does not match" in {
    val manager = linear()
    var callbackRan = false
    manager.conditionalTransitTo(S1, S2, () => callbackRan = true)
    assert(manager.getCurrentState == S0)
    assert(!callbackRan)
  }

  it should "still validate the transition graph and throw on an invalid transition" in {
    val manager = linear()
    intercept[InvalidTransitionException](manager.conditionalTransitTo(S0, S2, () => ()))
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/common/statetransition/WorkerStateManagerSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.common.statetransition
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.engine.architecture.worker.statistics.WorkerState
import org.apache.texera.amber.engine.architecture.worker.statistics.WorkerState._
import org.apache.texera.amber.engine.common.statetransition.StateManager.InvalidTransitionException
import org.scalatest.flatspec.AnyFlatSpec
class WorkerStateManagerSpec extends AnyFlatSpec {
  private val actorId: ActorVirtualIdentity = ActorVirtualIdentity("test-worker")

  // No default here on purpose: the "default to UNINITIALIZED" test below must
  // exercise WorkerStateManager's own default-parameter contract, not the
  // helper's.
  private def newManager(initial: WorkerState): WorkerStateManager =
    new WorkerStateManager(actorId, initial)

  /** Starts a manager in `from`, requests a move to `to`, and returns the resulting state. */
  private def stateAfter(from: WorkerState, to: WorkerState): WorkerState = {
    val manager = newManager(from)
    manager.transitTo(to)
    manager.getCurrentState
  }

  /** Starts a manager in `from` and expects the move to `to` to be refused. */
  private def rejectedTransition(
      from: WorkerState,
      to: WorkerState
  ): InvalidTransitionException =
    intercept[InvalidTransitionException] {
      newManager(from).transitTo(to)
    }

  "WorkerStateManager" should "default to the UNINITIALIZED state" in {
    // Construct without an explicit initial so the test would catch a change
    // to WorkerStateManager's default-parameter value.
    assert(new WorkerStateManager(actorId).getCurrentState == UNINITIALIZED)
  }

  it should "honor the explicit initial state when provided" in {
    assert(newManager(READY).getCurrentState == READY)
  }

  // -- Allowed transitions per the documented graph (one test per edge) --

  it should "allow UNINITIALIZED -> READY" in {
    assert(stateAfter(UNINITIALIZED, READY) == READY)
  }

  it should "allow READY -> RUNNING" in {
    assert(stateAfter(READY, RUNNING) == RUNNING)
  }

  it should "allow RUNNING -> PAUSED" in {
    assert(stateAfter(RUNNING, PAUSED) == PAUSED)
  }

  it should "allow PAUSED -> RUNNING" in {
    assert(stateAfter(PAUSED, RUNNING) == RUNNING)
  }

  it should "allow RUNNING -> COMPLETED" in {
    assert(stateAfter(RUNNING, COMPLETED) == COMPLETED)
  }

  it should "allow READY -> PAUSED" in {
    assert(stateAfter(READY, PAUSED) == PAUSED)
  }

  it should "allow READY -> COMPLETED" in {
    assert(stateAfter(READY, COMPLETED) == COMPLETED)
  }

  // -- Disallowed transitions --

  it should "reject UNINITIALIZED -> RUNNING (must go through READY)" in {
    rejectedTransition(UNINITIALIZED, RUNNING)
  }

  it should "treat COMPLETED as a terminal state" in {
    val manager = newManager(COMPLETED)
    Seq(RUNNING, READY).foreach { target =>
      intercept[InvalidTransitionException](manager.transitTo(target))
    }
    // Self-transition is a no-op, not an exception.
    manager.transitTo(COMPLETED)
    assert(manager.getCurrentState == COMPLETED)
  }

  it should "reject transitions into TERMINATED from every reachable source state" in {
    // TERMINATED is absent from the graph entirely, so it must be unreachable
    // from any state in the graph — including the COMPLETED terminal state.
    for (source <- Seq(UNINITIALIZED, READY, RUNNING, PAUSED, COMPLETED)) {
      rejectedTransition(source, TERMINATED)
    }
  }

  // Only RUNNING and READY may move to COMPLETED (both edges are tested above).
  it should "reject PAUSED -> COMPLETED (only RUNNING and READY may move to COMPLETED)" in {
    rejectedTransition(PAUSED, COMPLETED)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/e2e/BatchSizePropagationSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.e2e
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.pekko.util.Timeout
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext, WorkflowSettings}
import org.apache.texera.amber.engine.architecture.controller._
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings._
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.amber.operator.aggregate.AggregationFunction
import org.apache.texera.workflow.LogicalLink
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import scala.concurrent.duration.DurationInt
class BatchSizePropagationSpec
    extends TestKit(ActorSystem("BatchSizePropagationSpec"))
    with ImplicitSender
    with AnyFlatSpecLike
    with BeforeAndAfterAll
    with BeforeAndAfterEach {

  implicit val timeout: Timeout = Timeout(5.seconds)

  override def beforeAll(): Unit = {
    super.beforeAll()
    system.actorOf(Props[SingleNodeListener](), "cluster-info")
  }

  override def afterAll(): Unit = {
    // Stop the actor system first, then let mixed-in traits run their cleanup.
    TestKit.shutdownActorSystem(system)
    super.afterAll()
  }

  /** A fresh WorkflowContext whose settings request `batchSize` for data transfer. */
  private def contextWithBatchSize(batchSize: Int): WorkflowContext =
    new WorkflowContext(workflowSettings = WorkflowSettings(dataTransferBatchSize = batchSize))

  /**
    * Repeatedly pulls regions from `workflowScheduler` (via getNextRegions) until
    * none are returned, and asserts that every link partitioning found in those
    * regions' resource configs carries `expectedBatchSize`.
    *
    * @throws IllegalArgumentException if a partitioning type not handled below is found
    */
  def verifyBatchSizeInPartitioning(
      workflowScheduler: WorkflowScheduler,
      expectedBatchSize: Int
  ): Unit = {
    var nextRegions = workflowScheduler.getNextRegions
    while (nextRegions.nonEmpty) {
      nextRegions.foreach { region =>
        region.resourceConfig.foreach { resourceConfig =>
          resourceConfig.linkConfigs.foreach {
            case (_, linkConfig) =>
              val partitioning = linkConfig.partitioning
              // Each partitioning type declares its own batchSize field.
              val actualBatchSize: Int = partitioning match {
                case oneToOne: OneToOnePartitioning           => oneToOne.batchSize
                case roundRobin: RoundRobinPartitioning       => roundRobin.batchSize
                case hashBased: HashBasedShufflePartitioning  => hashBased.batchSize
                case rangeBased: RangeBasedShufflePartitioning => rangeBased.batchSize
                case broadcast: BroadcastPartitioning         => broadcast.batchSize
                case other =>
                  throw new IllegalArgumentException(
                    s"Unknown partitioning type encountered: ${other.getClass.getName}"
                  )
              }
              assert(
                actualBatchSize == expectedBatchSize,
                s"Batch size mismatch for ${partitioning.getClass.getSimpleName}: " +
                  s"$actualBatchSize != $expectedBatchSize"
              )
          }
        }
      }
      nextRegions = workflowScheduler.getNextRegions
    }
  }

  "Engine" should "propagate the correct batch size for headerlessCsv workflow" in {
    val expectedBatchSize = 1
    val context = contextWithBatchSize(expectedBatchSize)
    val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc),
      List(),
      context
    )
    val workflowScheduler = new WorkflowScheduler(context, CONTROLLER)
    workflowScheduler.updateSchedule(workflow.physicalPlan)
    verifyBatchSizeInPartitioning(workflowScheduler, expectedBatchSize)
  }

  it should "propagate the correct batch size for headerlessCsv->keyword workflow" in {
    val expectedBatchSize = 500
    val context = contextWithBatchSize(expectedBatchSize)
    val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("column-1", "Asia")
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          headerlessCsvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      context
    )
    val workflowScheduler = new WorkflowScheduler(context, CONTROLLER)
    workflowScheduler.updateSchedule(workflow.physicalPlan)
    verifyBatchSizeInPartitioning(workflowScheduler, expectedBatchSize)
  }

  it should "propagate the correct batch size for csv->keyword->count workflow" in {
    val expectedBatchSize = 100
    val context = contextWithBatchSize(expectedBatchSize)
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val countOpDesc =
      TestOperators.aggregateAndGroupByDesc("Region", AggregationFunction.COUNT, List[String]())
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc, countOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          keywordOpDesc.operatorIdentifier,
          PortIdentity(),
          countOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      context
    )
    val workflowScheduler = new WorkflowScheduler(context, CONTROLLER)
    workflowScheduler.updateSchedule(workflow.physicalPlan)
    verifyBatchSizeInPartitioning(workflowScheduler, expectedBatchSize)
  }

  it should "propagate the correct batch size for csv->keyword->averageAndGroupBy workflow" in {
    val expectedBatchSize = 300
    val context = contextWithBatchSize(expectedBatchSize)
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val averageAndGroupByOpDesc =
      TestOperators.aggregateAndGroupByDesc(
        "Units Sold",
        AggregationFunction.AVERAGE,
        List[String]("Country")
      )
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc, averageAndGroupByOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          keywordOpDesc.operatorIdentifier,
          PortIdentity(),
          averageAndGroupByOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      context
    )
    val workflowScheduler = new WorkflowScheduler(context, CONTROLLER)
    workflowScheduler.updateSchedule(workflow.physicalPlan)
    verifyBatchSizeInPartitioning(workflowScheduler, expectedBatchSize)
  }

  it should "propagate the correct batch size for csv->(csv->)->join workflow" in {
    val expectedBatchSize = 1
    val context = contextWithBatchSize(expectedBatchSize)
    val headerlessCsvOpDesc1 = TestOperators.headerlessSmallCsvScanOpDesc()
    val headerlessCsvOpDesc2 = TestOperators.headerlessSmallCsvScanOpDesc()
    val joinOpDesc = TestOperators.joinOpDesc("column-1", "column-1")
    val workflow = buildWorkflow(
      List(
        headerlessCsvOpDesc1,
        headerlessCsvOpDesc2,
        joinOpDesc
      ),
      List(
        LogicalLink(
          headerlessCsvOpDesc1.operatorIdentifier,
          PortIdentity(),
          joinOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          headerlessCsvOpDesc2.operatorIdentifier,
          PortIdentity(),
          joinOpDesc.operatorIdentifier,
          PortIdentity(1)
        )
      ),
      context
    )
    val workflowScheduler = new WorkflowScheduler(context, CONTROLLER)
    workflowScheduler.updateSchedule(workflow.physicalPlan)
    verifyBatchSizeInPartitioning(workflowScheduler, expectedBatchSize)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/e2e/DataProcessingSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.e2e
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.pekko.util.Timeout
import com.twitter.util.{Await, Duration, Promise}
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.storage.DocumentFactory
import org.apache.texera.amber.core.storage.model.VirtualDocument
import org.apache.texera.amber.core.tuple.{AttributeType, Tuple}
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow.{
ExecutionMode,
PortIdentity,
WorkflowContext,
WorkflowSettings
}
import org.apache.texera.amber.engine.architecture.controller._
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.EmptyRequest
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState.COMPLETED
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.amber.engine.e2e.TestUtils.{
buildWorkflow,
cleanupWorkflowExecutionData,
initiateTexeraDBForTestCases,
setUpWorkflowExecutionData
}
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.amber.operator.aggregate.AggregationFunction
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowExecutionsResource.getResultUriByLogicalPortId
import org.apache.texera.workflow.LogicalLink
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Outcome, Retries}
import scala.concurrent.duration.DurationInt
class DataProcessingSpec
    extends TestKit(ActorSystem("DataProcessingSpec", AmberRuntime.pekkoConfig))
    with ImplicitSender
    with AnyFlatSpecLike
    with BeforeAndAfterAll
    with BeforeAndAfterEach
    with Retries {

  /**
   * This block retries each test once if it fails.
   * In the CI environment, there is a chance that executeWorkflow does not receive "COMPLETED" status.
   * Until we find the root cause of this issue, we use a retry mechanism here to stabilize CI runs.
   */
  override def withFixture(test: NoArgTest): Outcome =
    withRetry { super.withFixture(test) }

  implicit val timeout: Timeout = Timeout(5.seconds)

  val workflowContext: WorkflowContext = new WorkflowContext()

  val materializedWorkflowContext: WorkflowContext = new WorkflowContext(
    workflowSettings = WorkflowSettings(
      dataTransferBatchSize = 400,
      executionMode = ExecutionMode.MATERIALIZED
    )
  )

  override protected def beforeEach(): Unit = {
    setUpWorkflowExecutionData()
  }

  override protected def afterEach(): Unit = {
    cleanupWorkflowExecutionData()
  }

  override def beforeAll(): Unit = {
    system.actorOf(Props[SingleNodeListener](), "cluster-info")
    // These test cases access postgres in CI, but occasionally the jdbc driver cannot be found during CI run.
    // Explicitly load the JDBC driver to avoid flaky CI failures.
    Class.forName("org.postgresql.Driver")
    initiateTexeraDBForTestCases()
  }

  override def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }

  /**
   * Reads the stored result of the default output port of every terminal operator of
   * `workflow`. Terminal operators for which no result URI is recorded are omitted.
   *
   * The lookup uses the workflow's own execution id, so it is correct for workflows
   * compiled under any context (e.g. `materializedWorkflowContext`).
   */
  private def readTerminalResults(workflow: Workflow): Map[OperatorIdentity, List[Tuple]] =
    workflow.logicalPlan.getTerminalOperatorIds
      .flatMap { terminalOpId =>
        // query the URI once; operators without a stored result yield no entry
        getResultUriByLogicalPortId(
          workflow.context.executionId,
          terminalOpId,
          PortIdentity()
        ).map { uri =>
          terminalOpId -> DocumentFactory
            .openDocument(uri)
            ._1
            .asInstanceOf[VirtualDocument[Tuple]]
            .get()
            .toList
        }
      }
      .toMap

  /**
   * Runs `workflow` to completion and returns the results of its terminal operators.
   *
   * Fails with the reported exception if the engine emits a `FatalError`, or with a
   * timeout if the workflow does not complete within one minute.
   */
  def executeWorkflow(workflow: Workflow): Map[OperatorIdentity, List[Tuple]] = {
    val client = new AmberClient(
      system,
      workflow.context,
      workflow.physicalPlan,
      ControllerConfig.default,
      error => {}
    )
    // The promise carries the collected results, so no mutable result holder is needed.
    val completion = Promise[Map[OperatorIdentity, List[Tuple]]]()
    client.registerCallback[FatalError](evt => {
      completion.setException(evt.e)
      client.shutdown()
    })
    client
      .registerCallback[ExecutionStateUpdate](evt => {
        if (evt.state == COMPLETED) {
          completion.setValue(readTerminalResults(workflow))
        }
      })
    Await.result(client.controllerInterface.startWorkflow(EmptyRequest(), ()))
    Await.result(completion, Duration.fromMinutes(1))
  }

  "Engine" should "execute headerlessCsv workflow normally" in {
    val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc),
      List(),
      workflowContext
    )
    val results = executeWorkflow(workflow)(headerlessCsvOpDesc.operatorIdentifier)
    assert(results.size == 100)
  }

  "Engine" should "execute headerlessMultiLineDataCsv workflow normally" in {
    val headerlessCsvOpDesc = TestOperators.headerlessSmallMultiLineDataCsvScanOpDesc()
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc),
      List(),
      workflowContext
    )
    val results = executeWorkflow(workflow)(headerlessCsvOpDesc.operatorIdentifier)
    assert(results.size == 100)
  }

  "Engine" should "execute jsonl workflow normally" in {
    val jsonlOp = TestOperators.smallJSONLScanOpDesc()
    val workflow = buildWorkflow(
      List(jsonlOp),
      List(),
      workflowContext
    )
    val results = executeWorkflow(workflow)(jsonlOp.operatorIdentifier)
    assert(results.size == 100)
    for (result <- results) {
      val schema = result.getSchema
      assert(schema.getAttribute("id").getType == AttributeType.LONG)
      assert(schema.getAttribute("first_name").getType == AttributeType.STRING)
      assert(schema.getAttribute("flagged").getType == AttributeType.BOOLEAN)
      assert(schema.getAttribute("year").getType == AttributeType.INTEGER)
      assert(schema.getAttribute("created_at").getType == AttributeType.TIMESTAMP)
      assert(schema.getAttributes.length == 9)
    }
  }

  "Engine" should "execute mediumFlattenJsonl workflow normally" in {
    val jsonlOp = TestOperators.mediumFlattenJSONLScanOpDesc()
    val workflow = buildWorkflow(
      List(jsonlOp),
      List(),
      workflowContext
    )
    val results = executeWorkflow(workflow)(jsonlOp.operatorIdentifier)
    assert(results.size == 1000)
    for (result <- results) {
      val schema = result.getSchema
      assert(schema.getAttribute("id").getType == AttributeType.LONG)
      assert(schema.getAttribute("first_name").getType == AttributeType.STRING)
      assert(schema.getAttribute("flagged").getType == AttributeType.BOOLEAN)
      assert(schema.getAttribute("year").getType == AttributeType.INTEGER)
      assert(schema.getAttribute("created_at").getType == AttributeType.TIMESTAMP)
      assert(schema.getAttribute("test_object.array2.another").getType == AttributeType.INTEGER)
      assert(schema.getAttributes.length == 13)
    }
  }

  "Engine" should "execute headerlessCsv->keyword workflow normally" in {
    val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("column-1", "Asia")
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          headerlessCsvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      workflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv workflow normally" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val workflow = buildWorkflow(
      List(csvOpDesc),
      List(),
      workflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->keyword workflow normally" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      workflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->keyword->count workflow normally" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val countOpDesc =
      TestOperators.aggregateAndGroupByDesc("Region", AggregationFunction.COUNT, List[String]())
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc, countOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          keywordOpDesc.operatorIdentifier,
          PortIdentity(),
          countOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      workflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->keyword->averageAndGroupBy workflow normally" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val averageAndGroupByOpDesc =
      TestOperators.aggregateAndGroupByDesc(
        "Units Sold",
        AggregationFunction.AVERAGE,
        List[String]("Country")
      )
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc, averageAndGroupByOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          keywordOpDesc.operatorIdentifier,
          PortIdentity(),
          averageAndGroupByOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      workflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->(csv->)->join workflow normally" in {
    val headerlessCsvOpDesc1 = TestOperators.headerlessSmallCsvScanOpDesc()
    val headerlessCsvOpDesc2 = TestOperators.headerlessSmallCsvScanOpDesc()
    val joinOpDesc = TestOperators.joinOpDesc("column-1", "column-1")
    val workflow = buildWorkflow(
      List(
        headerlessCsvOpDesc1,
        headerlessCsvOpDesc2,
        joinOpDesc
      ),
      List(
        LogicalLink(
          headerlessCsvOpDesc1.operatorIdentifier,
          PortIdentity(),
          joinOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          headerlessCsvOpDesc2.operatorIdentifier,
          PortIdentity(),
          joinOpDesc.operatorIdentifier,
          PortIdentity(1)
        )
      ),
      workflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute headerlessCsv->keyword workflow with MATERIALIZED mode" in {
    val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("column-1", "Asia")
    val workflow = buildWorkflow(
      List(headerlessCsvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          headerlessCsvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      materializedWorkflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv workflow with MATERIALIZED mode" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val workflow = buildWorkflow(
      List(csvOpDesc),
      List(),
      materializedWorkflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->keyword workflow with MATERIALIZED mode" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      materializedWorkflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->keyword->count workflow with MATERIALIZED mode" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val countOpDesc =
      TestOperators.aggregateAndGroupByDesc("Region", AggregationFunction.COUNT, List[String]())
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc, countOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          keywordOpDesc.operatorIdentifier,
          PortIdentity(),
          countOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      materializedWorkflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->keyword->averageAndGroupBy workflow with MATERIALIZED mode" in {
    val csvOpDesc = TestOperators.smallCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    val averageAndGroupByOpDesc =
      TestOperators.aggregateAndGroupByDesc(
        "Units Sold",
        AggregationFunction.AVERAGE,
        List[String]("Country")
      )
    val workflow = buildWorkflow(
      List(csvOpDesc, keywordOpDesc, averageAndGroupByOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          keywordOpDesc.operatorIdentifier,
          PortIdentity(),
          averageAndGroupByOpDesc.operatorIdentifier,
          PortIdentity()
        )
      ),
      materializedWorkflowContext
    )
    executeWorkflow(workflow)
  }

  "Engine" should "execute csv->(csv->)->join workflow with MATERIALIZED mode" in {
    val headerlessCsvOpDesc1 = TestOperators.headerlessSmallCsvScanOpDesc()
    val headerlessCsvOpDesc2 = TestOperators.headerlessSmallCsvScanOpDesc()
    val joinOpDesc = TestOperators.joinOpDesc("column-1", "column-1")
    val workflow = buildWorkflow(
      List(
        headerlessCsvOpDesc1,
        headerlessCsvOpDesc2,
        joinOpDesc
      ),
      List(
        LogicalLink(
          headerlessCsvOpDesc1.operatorIdentifier,
          PortIdentity(),
          joinOpDesc.operatorIdentifier,
          PortIdentity()
        ),
        LogicalLink(
          headerlessCsvOpDesc2.operatorIdentifier,
          PortIdentity(),
          joinOpDesc.operatorIdentifier,
          PortIdentity(1)
        )
      ),
      materializedWorkflowContext
    )
    executeWorkflow(workflow)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/e2e/PauseSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.e2e
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.pekko.util.Timeout
import com.twitter.util.{Await, Duration, Promise}
import com.typesafe.scalalogging.Logger
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.architecture.controller.{
ControllerConfig,
ExecutionStateUpdate
}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.EmptyRequest
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState.{
COMPLETED,
PAUSED
}
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.amber.engine.e2e.TestUtils.{
cleanupWorkflowExecutionData,
initiateTexeraDBForTestCases,
setUpWorkflowExecutionData,
stateReached
}
import org.apache.texera.amber.operator.{LogicalOp, TestOperators}
import org.apache.texera.workflow.LogicalLink
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Outcome, Retries}
import scala.concurrent.duration._
class PauseSpec
    extends TestKit(ActorSystem("PauseSpec", AmberRuntime.pekkoConfig))
    with ImplicitSender
    with AnyFlatSpecLike
    with BeforeAndAfterAll
    with BeforeAndAfterEach
    with Retries {

  /**
   * This block retries each test once if it fails.
   * In the CI environment, there is a chance that shouldPause does not receive "COMPLETED" status.
   * Until we find the root cause of this issue, we use a retry mechanism here to stabilize CI runs.
   */
  override def withFixture(test: NoArgTest): Outcome =
    withRetry { super.withFixture(test) }

  implicit val timeout: Timeout = Timeout(5.seconds)

  val logger = Logger("PauseSpecLogger")

  override protected def beforeEach(): Unit = {
    setUpWorkflowExecutionData()
  }

  override protected def afterEach(): Unit = {
    cleanupWorkflowExecutionData()
  }

  override def beforeAll(): Unit = {
    system.actorOf(Props[SingleNodeListener](), "cluster-info")
    // These test cases access postgres in CI, but occasionally the jdbc driver cannot be found during CI run.
    // Explicitly load the JDBC driver to avoid flaky CI failures.
    Class.forName("org.postgresql.Driver")
    initiateTexeraDBForTestCases()
  }

  override def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }

  /**
   * Builds and starts the workflow described by `operators` and `links`, then pauses and
   * resumes it twice and waits for it to complete.
   *
   * Every wait is bounded, so a lost RPC reply or state update fails the test with a timeout
   * instead of hanging it. A fatal error reported by the engine fails the completion promise,
   * so it surfaces as that error rather than as a completion timeout.
   *
   * @param operators logical operators of the workflow
   * @param links     logical links between the operators
   */
  def shouldPause(
      operators: List[LogicalOp],
      links: List[LogicalLink]
  ): Unit = {
    val workflow =
      TestUtils.buildWorkflow(operators, links, new WorkflowContext())
    val client =
      new AmberClient(
        system,
        workflow.context,
        workflow.physicalPlan,
        ControllerConfig.default,
        error => {}
      )
    val completion = Promise[Unit]()
    // Same fatal-error handling as the other e2e specs: propagate the error and stop the client.
    client.registerCallback[org.apache.texera.amber.engine.architecture.controller.FatalError](
      evt => {
        completion.setException(evt.e)
        client.shutdown()
      }
    )
    client
      .registerCallback[ExecutionStateUpdate](evt => {
        if (evt.state == COMPLETED) {
          completion.setDone()
        }
      })
    val stateWaitTimeout = Duration.fromSeconds(10)
    Await.result(client.controllerInterface.startWorkflow(EmptyRequest(), ()), stateWaitTimeout)
    // Register the PAUSED listener before pausing: client observables do not replay past events.
    val firstPaused = stateReached(client, PAUSED)
    Await.result(client.controllerInterface.pauseWorkflow(EmptyRequest(), ()), stateWaitTimeout)
    Await.result(firstPaused, stateWaitTimeout)
    Await.result(client.controllerInterface.resumeWorkflow(EmptyRequest(), ()), stateWaitTimeout)
    val secondPaused = stateReached(client, PAUSED)
    Await.result(client.controllerInterface.pauseWorkflow(EmptyRequest(), ()), stateWaitTimeout)
    Await.result(secondPaused, stateWaitTimeout)
    Await.result(client.controllerInterface.resumeWorkflow(EmptyRequest(), ()), stateWaitTimeout)
    Await.result(completion, Duration.fromMinutes(1))
  }

  "Engine" should "be able to pause csv workflow" in {
    val csvOpDesc = TestOperators.mediumCsvScanOpDesc()
    logger.info(s"csv-id ${csvOpDesc.operatorIdentifier}")
    shouldPause(
      List(csvOpDesc),
      List()
    )
  }

  "Engine" should "be able to pause csv->keyword workflow" in {
    val csvOpDesc = TestOperators.mediumCsvScanOpDesc()
    val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
    logger.info(
      s"csv-id ${csvOpDesc.operatorIdentifier}, keyword-id ${keywordOpDesc.operatorIdentifier}"
    )
    shouldPause(
      List(csvOpDesc, keywordOpDesc),
      List(
        LogicalLink(
          csvOpDesc.operatorIdentifier,
          PortIdentity(),
          keywordOpDesc.operatorIdentifier,
          PortIdentity()
        )
      )
    )
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/e2e/ReconfigurationSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.e2e
import com.typesafe.scalalogging.Logger
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.pekko.util.Timeout
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.e2e.TestUtils.{
cleanupWorkflowExecutionData,
initiateTexeraDBForTestCases,
setUpWorkflowExecutionData
}
import org.apache.texera.amber.operator.{LogicalOp, TestOperators}
import org.apache.texera.workflow.LogicalLink
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Outcome, Retries}
import org.scalatest.flatspec.AnyFlatSpecLike
import scala.concurrent.duration._
class ReconfigurationSpec
    extends TestKit(ActorSystem("ReconfigurationSpec", AmberRuntime.pekkoConfig))
    with ImplicitSender
    with AnyFlatSpecLike
    with BeforeAndAfterAll
    with BeforeAndAfterEach
    with Retries {

  /**
   * Every test gets one retry on failure. CI occasionally never delivers the "COMPLETED"
   * status to executeWorkflow; until the root cause is known, retrying keeps CI runs stable.
   */
  override def withFixture(test: NoArgTest): Outcome =
    withRetry(super.withFixture(test))

  implicit val timeout: Timeout = Timeout(5.seconds)

  val logger = Logger("ReconfigurationSpecLogger")

  val ctx = new WorkflowContext()

  override protected def beforeEach(): Unit = setUpWorkflowExecutionData()

  override protected def afterEach(): Unit = cleanupWorkflowExecutionData()

  override def beforeAll(): Unit = {
    system.actorOf(Props[SingleNodeListener](), "cluster-info")
    // The JDBC driver is sometimes not found during CI runs even though postgres is used;
    // loading it explicitly keeps those runs from flaking.
    Class.forName("org.postgresql.Driver")
    initiateTexeraDBForTestCases()
  }

  override def afterAll(): Unit = TestKit.shutdownActorSystem(system)

  /** A link from the default output port of `upstream` to the default input port of `downstream`. */
  private def defaultPortLink(upstream: LogicalOp, downstream: LogicalOp): LogicalLink =
    LogicalLink(
      upstream.operatorIdentifier,
      PortIdentity(),
      downstream.operatorIdentifier,
      PortIdentity()
    )

  /**
   * Delegates to the shared TestUtils driver, binding this spec's actor system and context
   * so the call sites below only pass the workflow-specific arguments. The driver itself is
   * shared with ReconfigurationIntegrationSpec.
   */
  def shouldReconfigure(
      operators: List[LogicalOp],
      links: List[LogicalLink],
      targetOps: Seq[LogicalOp],
      newOpExecInitInfo: OpExecInitInfo
  ): Map[OperatorIdentity, List[Tuple]] =
    TestUtils.shouldReconfigure(system, ctx, operators, links, targetOps, newOpExecInitInfo)

  "Engine" should "be able to modify a java operator in workflow" in {
    val csvSource = TestOperators.mediumCsvScanOpDesc()
    val matchesNothing = TestOperators.keywordSearchOpDesc("Region", "ShouldMatchNone")
    val matchesAsia = TestOperators.keywordSearchOpDesc("Region", "Asia")

    val outputs = shouldReconfigure(
      operators = List(csvSource, matchesNothing),
      links = List(defaultPortLink(csvSource, matchesNothing)),
      targetOps = Seq(matchesNothing),
      newOpExecInitInfo =
        matchesAsia.getPhysicalOp(ctx.workflowId, ctx.executionId).opExecInitInfo
    )

    // Swapping in the "Asia" executor must make the formerly empty operator emit tuples.
    val reconfiguredOutput = outputs(matchesNothing.operatorIdentifier)
    assert(reconfiguredOutput.nonEmpty)
  }

  "Engine" should "not be able to modify a source operator in workflow" in {
    val csvSource = TestOperators.mediumCsvScanOpDesc()
    val replacementSource = TestOperators.mediumCsvScanOpDesc()
    val matchesNothing = TestOperators.keywordSearchOpDesc("Region", "ShouldMatchNone")

    val thrown = intercept[Throwable] {
      shouldReconfigure(
        operators = List(csvSource, matchesNothing),
        links = List(defaultPortLink(csvSource, matchesNothing)),
        targetOps = Seq(csvSource),
        newOpExecInitInfo =
          replacementSource.getPhysicalOp(ctx.workflowId, ctx.executionId).opExecInitInfo
      )
    }

    val expectedMessage =
      "java.lang.IllegalStateException: Reconfiguration cannot be applied to source operators"
    assert(thrown.getMessage == expectedMessage)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/e2e/TestUtils.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.e2e
import com.twitter.util.{Await, Duration, Promise, Return}
import org.apache.pekko.actor.ActorSystem
import org.apache.texera.amber.config.StorageConfig
import org.apache.texera.amber.core.executor.OpExecInitInfo
import org.apache.texera.amber.core.storage.DocumentFactory
import org.apache.texera.amber.core.storage.model.VirtualDocument
import org.apache.texera.amber.core.tuple.Tuple
import org.apache.texera.amber.core.virtualidentity.OperatorIdentity
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.architecture.controller.{
ControllerConfig,
ExecutionStateUpdate,
Workflow
}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
EmptyRequest,
UpdateExecutorRequest,
WorkflowReconfigureRequest
}
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.WorkflowAggregatedState.{
COMPLETED,
PAUSED
}
import org.apache.texera.amber.engine.common.client.AmberClient
import org.apache.texera.amber.operator.LogicalOp
import org.apache.texera.dao.SqlServer
import org.apache.texera.dao.jooq.generated.enums.UserRoleEnum
import org.apache.texera.dao.jooq.generated.tables.daos.{
UserDao,
WorkflowDao,
WorkflowExecutionsDao,
WorkflowVersionDao
}
import org.apache.texera.dao.jooq.generated.tables.pojos.{
User,
WorkflowExecutions,
WorkflowVersion,
Workflow => WorkflowPojo
}
import org.apache.texera.web.model.websocket.request.LogicalPlanPojo
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowExecutionsResource.getResultUriByLogicalPortId
import org.apache.texera.workflow.{LogicalLink, WorkflowCompiler}
object TestUtils {

  /**
   * Compiles the given logical operators and links into an executable [[Workflow]] under
   * `context`, using [[WorkflowCompiler]].
   */
  def buildWorkflow(
      operators: List[LogicalOp],
      links: List[LogicalLink],
      context: WorkflowContext
  ): Workflow = {
    val workflowCompiler = new WorkflowCompiler(
      context
    )
    workflowCompiler.compile(
      LogicalPlanPojo(operators, links, List(), List())
    )
  }

  /**
   * If a test case accesses the user system through singleton resources that cache the DSLContext (e.g., executes a
   * workflow, which accesses WorkflowExecutionsResource), we use a separate texera_db specifically for such test cases.
   * Note such test cases need to clean up the database at the end of running each test case.
   */
  def initiateTexeraDBForTestCases(): Unit = {
    SqlServer.initConnection(
      StorageConfig.jdbcUrlForTestCases,
      StorageConfig.jdbcUsername,
      StorageConfig.jdbcPassword
    )
  }

  /** User row inserted by [[setUpWorkflowExecutionData]]. */
  val testUser: User = {
    val user = new User
    user.setUid(Integer.valueOf(1))
    user.setName("test_user")
    user.setRole(UserRoleEnum.ADMIN)
    user.setPassword("123")
    user.setEmail("test_user@test.com")
    user
  }

  /** Workflow row inserted by [[setUpWorkflowExecutionData]]. */
  val testWorkflowEntry: WorkflowPojo = {
    val workflow = new WorkflowPojo
    workflow.setName("test workflow")
    workflow.setWid(Integer.valueOf(1))
    workflow.setContent("test workflow content")
    workflow.setDescription("test description")
    workflow
  }

  /** Workflow-version row inserted by [[setUpWorkflowExecutionData]]. */
  val testWorkflowVersionEntry: WorkflowVersion = {
    val workflowVersion = new WorkflowVersion
    workflowVersion.setWid(Integer.valueOf(1))
    workflowVersion.setVid(Integer.valueOf(1))
    workflowVersion.setContent("test version content")
    workflowVersion
  }

  /** Workflow-execution row inserted by [[setUpWorkflowExecutionData]]. */
  val testWorkflowExecutionEntry: WorkflowExecutions = {
    val workflowExecution = new WorkflowExecutions
    workflowExecution.setEid(Integer.valueOf(1))
    workflowExecution.setVid(Integer.valueOf(1))
    workflowExecution.setUid(Integer.valueOf(1))
    workflowExecution.setStatus(3.toByte)
    workflowExecution.setEnvironmentVersion("test engine")
    workflowExecution
  }

  /**
   * Inserts the test user, workflow, workflow version and workflow execution rows, in that
   * order. Must be paired with [[cleanupWorkflowExecutionData]].
   */
  def setUpWorkflowExecutionData(): Unit = {
    val dslConfig = SqlServer.getInstance().context.configuration()
    val userDao = new UserDao(dslConfig)
    val workflowDao = new WorkflowDao(dslConfig)
    val workflowExecutionsDao = new WorkflowExecutionsDao(dslConfig)
    val workflowVersionDao = new WorkflowVersionDao(dslConfig)
    userDao.insert(testUser)
    workflowDao.insert(testWorkflowEntry)
    workflowVersionDao.insert(testWorkflowVersionEntry)
    workflowExecutionsDao.insert(testWorkflowExecutionEntry)
  }

  /**
   * Returns a Promise that completes the next time the client emits an
   * ExecutionStateUpdate with the given target state. Must be called BEFORE
   * the action that triggers the state change, since AmberClient observables
   * do not replay past events.
   */
  def stateReached(
      client: AmberClient,
      target: WorkflowAggregatedState
  ): Promise[Unit] = {
    val p = Promise[Unit]()
    client.registerCallback[ExecutionStateUpdate](evt => {
      if (evt.state == target) {
        p.updateIfEmpty(Return(()))
      }
    })
    p
  }

  /**
   * Reads the stored result of the default output port of every terminal operator of
   * `workflow`, keyed by operator id. Operators with no recorded result URI are omitted.
   */
  private def collectTerminalResults(workflow: Workflow): Map[OperatorIdentity, List[Tuple]] =
    workflow.logicalPlan.getTerminalOperatorIds
      .flatMap { terminalOpId =>
        // look the URI up once instead of once for the filter and once more for the read
        getResultUriByLogicalPortId(
          workflow.context.executionId,
          terminalOpId,
          PortIdentity()
        ).map { uri =>
          terminalOpId -> DocumentFactory
            .openDocument(uri)
            ._1
            .asInstanceOf[VirtualDocument[Tuple]]
            .get()
            .toList
        }
      }
      .toMap

  /**
   * Pause a freshly-started workflow, swap the executor for the given target
   * operators via WorkflowReconfigureRequest, resume, and collect the
   * terminal-port outputs once the run completes. Shared by ReconfigurationSpec
   * (pure-Scala) and ReconfigurationIntegrationSpec (Python-tagged), so the
   * driver logic does not drift between the two as new e2e specs land.
   * The caller passes its own `system` (TestKit) and `ctx` (WorkflowContext)
   * since both are tied to the spec lifecycle.
   *
   * Any failure of the start/pause/reconfigure/resume RPCs is rethrown by the
   * corresponding `Await.result`.
   */
  def shouldReconfigure(
      system: ActorSystem,
      ctx: WorkflowContext,
      operators: List[LogicalOp],
      links: List[LogicalLink],
      targetOps: Seq[LogicalOp],
      newOpExecInitInfo: OpExecInitInfo
  ): Map[OperatorIdentity, List[Tuple]] = {
    val workflow = buildWorkflow(operators, links, ctx)
    val client = new AmberClient(
      system,
      workflow.context,
      workflow.physicalPlan,
      ControllerConfig.default,
      error => {}
    )
    // The promise carries the results; updateIfEmpty keeps a repeated COMPLETED event harmless.
    val completion = Promise[Map[OperatorIdentity, List[Tuple]]]()
    client.registerCallback[ExecutionStateUpdate](evt => {
      if (evt.state == COMPLETED) {
        completion.updateIfEmpty(Return(collectTerminalResults(workflow)))
      }
    })
    Await.result(
      client.controllerInterface.startWorkflow(EmptyRequest(), ()),
      Duration.fromSeconds(5)
    )
    // Register before pausing: client observables do not replay past events.
    val pausedReached = stateReached(client, PAUSED)
    Await.result(
      client.controllerInterface.pauseWorkflow(EmptyRequest(), ()),
      Duration.fromSeconds(5)
    )
    Await.result(pausedReached, Duration.fromSeconds(10))
    val physicalOps = targetOps.flatMap(op =>
      workflow.physicalPlan.getPhysicalOpsOfLogicalOp(op.operatorIdentifier)
    )
    Await.result(
      client.controllerInterface.reconfigureWorkflow(
        WorkflowReconfigureRequest(
          reconfiguration = physicalOps.map(op => UpdateExecutorRequest(op.id, newOpExecInitInfo)),
          reconfigurationId = "test-reconfigure-1"
        ),
        ()
      ),
      Duration.fromSeconds(5)
    )
    Await.result(
      client.controllerInterface.resumeWorkflow(EmptyRequest(), ()),
      Duration.fromSeconds(5)
    )
    Await.result(completion, Duration.fromMinutes(1))
  }

  /**
   * Deletes the rows inserted by [[setUpWorkflowExecutionData]], in reverse insertion order.
   * The ids are taken from the fixture entries so setup and cleanup cannot disagree.
   */
  def cleanupWorkflowExecutionData(): Unit = {
    val dslConfig = SqlServer.getInstance().context.configuration()
    val userDao = new UserDao(dslConfig)
    val workflowDao = new WorkflowDao(dslConfig)
    val workflowExecutionsDao = new WorkflowExecutionsDao(dslConfig)
    val workflowVersionDao = new WorkflowVersionDao(dslConfig)
    workflowExecutionsDao.deleteById(testWorkflowExecutionEntry.getEid)
    workflowVersionDao.deleteById(testWorkflowVersionEntry.getVid)
    workflowDao.deleteById(testWorkflowEntry.getWid)
    userDao.deleteById(testUser.getUid)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/faulttolerance/CheckpointSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.faulttolerance
import org.apache.pekko.actor.{ActorSystem, Props}
import org.apache.texera.amber.clustering.SingleNodeListener
import org.apache.texera.amber.core.workflow.{PortIdentity, WorkflowContext}
import org.apache.texera.amber.engine.architecture.controller.{
ControllerConfig,
ControllerProcessor
}
import org.apache.texera.amber.engine.architecture.worker.DataProcessor
import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.DPInputQueueElement
import org.apache.texera.amber.engine.common.SerializedState.{CP_STATE_KEY, DP_STATE_KEY}
import org.apache.texera.amber.engine.common.virtualidentity.util.{CONTROLLER, SELF}
import org.apache.texera.amber.engine.common.{AmberRuntime, CheckpointState}
import org.apache.texera.amber.engine.e2e.TestUtils.buildWorkflow
import org.apache.texera.amber.operator.TestOperators
import org.apache.texera.workflow.LogicalLink
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpecLike
import java.util.concurrent.LinkedBlockingQueue
class CheckpointSpec extends AnyFlatSpecLike with BeforeAndAfterAll {
var system: ActorSystem = _
val csvOpDesc = TestOperators.mediumCsvScanOpDesc()
val keywordOpDesc = TestOperators.keywordSearchOpDesc("Region", "Asia")
val workflow = buildWorkflow(
List(csvOpDesc, keywordOpDesc),
List(
LogicalLink(
csvOpDesc.operatorIdentifier,
PortIdentity(),
keywordOpDesc.operatorIdentifier,
PortIdentity()
)
),
new WorkflowContext()
)
override def beforeAll(): Unit = {
system = ActorSystem("CheckpointSpec", AmberRuntime.pekkoConfig)
system.actorOf(Props[SingleNodeListener](), "cluster-info")
}
"Default controller state" should "be serializable" in {
val cp =
new ControllerProcessor(
workflow.context,
ControllerConfig.default,
CONTROLLER,
msg => {}
)
val chkpt = new CheckpointState()
chkpt.save(CP_STATE_KEY, cp)
}
"Default worker state" should "be serializable" in {
val dp = new DataProcessor(
SELF,
msg => {},
inputMessageQueue = new LinkedBlockingQueue[DPInputQueueElement]()
)
val chkpt = new CheckpointState()
chkpt.save(DP_STATE_KEY, dp)
}
"CheckpointState" should "fail loudly on an unknown key" in {
// Pin the documented contract precisely: load throws
// RuntimeException("no state saved for key = $key"). A bare
// `contains("unknown")` would still pass if the message ever drifts to
// something like "unknown checkpoint", silently weakening the assertion.
val chkpt = new CheckpointState()
assert(!chkpt.has("unknown"))
val ex = intercept[RuntimeException] {
chkpt.load[Any]("unknown")
}
assert(ex.getMessage == "no state saved for key = unknown")
}
// "CSVScanOperator" should "be serializable" in {
// val chkpt = new CheckpointState()
// val headerlessCsvOpDesc = TestOperators.headerlessSmallCsvScanOpDesc()
// val context = new WorkflowContext()
// headerlessCsvOpDesc.setContext(context)
// val phyOp = headerlessCsvOpDesc.getPhysicalOp(WorkflowIdentity(1), ExecutionIdentity(1))
// phyOp.opExecInitInfo match {
// case OpExecInitInfoWithCode(codeGen) => ???
// case OpExecInitInfoWithFunc(opGen) =>
// val operator = opGen(1, 1)
// operator.open()
// val outputIter =
// operator.asInstanceOf[SourceOperatorExecutor].produceTuple().map(t => (t, None))
// outputIter.next()
// outputIter.next()
// operator.asInstanceOf[CheckpointSupport].serializeState(outputIter, chkpt)
// chkpt.save("deserialization", opGen)
// val opGen2 = chkpt.load("deserialization").asInstanceOf[(Int, Int) => OperatorExecutor]
// val op = opGen2.apply(1, 1)
// op.asInstanceOf[CheckpointSupport].deserializeState(chkpt)
// }
// }
//
// "Workflow " should "take global checkpoint, reload and continue" in {
// val client1 = new AmberClient(
// system,
// workflow.context,
// workflow.physicalPlan,
// resultStorage,
// ControllerConfig.default,
// error => {}
// )
// Await.result(client1.controllerInterface.startWorkflow(EmptyRequest(), ()))
// Thread.sleep(100)
// Await.result(client1.controllerInterface.pauseWorkflow(EmptyRequest(), ()))
// val checkpointId = EmbeddedControlMessageIdentity(s"Checkpoint_test_1")
// val uri = new URI("ram:///recovery-logs/tmp/")
// Await.result(
// client1.controllerInterface.takeGlobalCheckpoint(
// TakeGlobalCheckpointRequest(estimationOnly = false, checkpointId, uri.toString),
// ()
// ),
// Duration.fromSeconds(30)
// )
// client1.shutdown()
// Thread.sleep(100)
// var controllerConfig = ControllerConfig.default
// controllerConfig =
// controllerConfig.copy(stateRestoreConfOpt = Some(StateRestoreConfig(uri, checkpointId)))
// val completableFuture = new CompletableFuture[Unit]()
// val client2 = new AmberClient(
// system,
// workflow.context,
// workflow.physicalPlan,
// resultStorage,
// controllerConfig,
// error => {}
// )
// client2.registerCallback[ExecutionStateUpdate] { evt =>
// if (evt.state == COMPLETED) {
// completableFuture.complete(())
// }
// }
// Thread.sleep(1000)
// assert(
// Await
// .result(client2.controllerInterface.startWorkflow(EmptyRequest(), ()))
// .workflowState == PAUSED
// )
// Thread.sleep(5000)
// Await.result(client2.controllerInterface.resumeWorkflow(EmptyRequest(), ()))
// completableFuture.get(30000, TimeUnit.MILLISECONDS)
// }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/faulttolerance/LoggingSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.faulttolerance
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.texera.amber.core.tuple.{AttributeType, Schema, TupleLike}
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ChannelIdentity,
OperatorIdentity,
PhysicalOpIdentity
}
import org.apache.texera.amber.core.workflow.{PhysicalLink, PortIdentity}
import org.apache.texera.amber.engine.architecture.logreplay.{ReplayLogManager, ReplayLogRecord}
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
AddPartitioningRequest,
AsyncRPCContext,
EmptyRequest
}
import org.apache.texera.amber.engine.architecture.rpc.controllerservice.ControllerServiceGrpc.METHOD_WORKER_EXECUTION_COMPLETED
import org.apache.texera.amber.engine.architecture.rpc.workerservice.WorkerServiceGrpc.{
METHOD_ADD_PARTITIONING,
METHOD_PAUSE_WORKER,
METHOD_RESUME_WORKER,
METHOD_START_WORKER
}
import org.apache.texera.amber.engine.architecture.sendsemantics.partitionings.OneToOnePartitioning
import org.apache.texera.amber.engine.common.AmberRuntime
import org.apache.texera.amber.engine.common.ambermessage.{
DataFrame,
WorkflowFIFOMessage,
WorkflowFIFOMessagePayload
}
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.ControlInvocation
import org.apache.texera.amber.engine.common.storage.SequentialRecordStorage
import org.apache.texera.amber.engine.common.virtualidentity.util.{CONTROLLER, SELF}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.TimeLimitedTests
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
import java.net.URI
class LoggingSpec
extends TestKit(ActorSystem("LoggingSpec", AmberRuntime.pekkoConfig))
with ImplicitSender
with AnyFlatSpecLike
with BeforeAndAfterAll
with TimeLimitedTests {
private val identifier1 = ActorVirtualIdentity("Worker:WF1-E1-op-layer-1")
private val identifier2 = ActorVirtualIdentity("Worker:WF1-E1-op-layer-2")
private val operatorIdentity = OperatorIdentity("testOperator")
private val physicalOpId1 = PhysicalOpIdentity(operatorIdentity, "1st-layer")
private val physicalOpId2 = PhysicalOpIdentity(operatorIdentity, "2nd-layer")
private val mockLink = PhysicalLink(physicalOpId1, PortIdentity(), physicalOpId2, PortIdentity())
private val mockPolicy =
OneToOnePartitioning(10, Seq(ChannelIdentity(identifier1, identifier2, isControl = false)))
val payloadToLog: Array[WorkflowFIFOMessagePayload] = Array(
ControlInvocation(
METHOD_START_WORKER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, identifier1),
0
),
ControlInvocation(
METHOD_ADD_PARTITIONING,
AddPartitioningRequest(mockLink, mockPolicy),
AsyncRPCContext(CONTROLLER, identifier1),
0
),
ControlInvocation(
METHOD_PAUSE_WORKER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, identifier1),
0
),
ControlInvocation(
METHOD_RESUME_WORKER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, identifier1),
0
),
DataFrame(
(0 to 400)
.map(i =>
TupleLike(i, i.toString, i.toDouble).enforceSchema(
Schema()
.add("field1", AttributeType.INTEGER)
.add("field2", AttributeType.STRING)
.add("field3", AttributeType.DOUBLE)
)
)
.toArray
),
ControlInvocation(
METHOD_START_WORKER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, identifier1),
0
),
ControlInvocation(
METHOD_WORKER_EXECUTION_COMPLETED,
EmptyRequest(),
AsyncRPCContext(identifier1, CONTROLLER),
0
)
)
"determinant logger" should "log processing steps in local storage" in {
Thread.sleep(1000) // wait for serializer to be registered
val logStorage = SequentialRecordStorage.getStorage[ReplayLogRecord](
Some(new URI("ram:///recovery-logs/tmp"))
)
logStorage.deleteStorage()
val logManager = ReplayLogManager.createLogManager(logStorage, "tmpLog", x => {})
payloadToLog.foreach { payload =>
val channel = ChannelIdentity(CONTROLLER, SELF, isControl = true)
val msgOpt = Some(WorkflowFIFOMessage(channel, 0, payload))
logManager.withFaultTolerant(channel, msgOpt) {
// do nothing
}
}
logManager.sendCommitted(null)
logManager.terminate()
val logRecords = logStorage.getReader("tmpLog").mkRecordIterator().toArray
logStorage.deleteStorage()
assert(logRecords.length == 15)
}
override def timeLimit: Span = 30.seconds
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/engine/faulttolerance/ReplaySpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.engine.faulttolerance
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.testkit.{ImplicitSender, TestKit}
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
import org.apache.texera.amber.engine.architecture.logreplay.{
ProcessingStep,
ReplayLogManagerImpl,
ReplayLogRecord,
ReplayOrderEnforcer
}
import org.apache.texera.amber.engine.architecture.messaginglayer.NetworkInputGateway
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
AsyncRPCContext,
EmptyRequest
}
import org.apache.texera.amber.engine.architecture.rpc.workerservice.WorkerServiceGrpc.METHOD_START_WORKER
import org.apache.texera.amber.engine.common.ambermessage.WorkflowFIFOMessage
import org.apache.texera.amber.engine.common.rpc.AsyncRPCClient.ControlInvocation
import org.apache.texera.amber.engine.common.storage.SequentialRecordStorage
import org.apache.texera.amber.engine.common.storage.SequentialRecordStorage.SequentialRecordReader
import org.apache.texera.amber.engine.common.virtualidentity.util.CONTROLLER
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpecLike
import scala.collection.mutable
class ReplaySpec
extends TestKit(ActorSystem("ReplaySpec"))
with ImplicitSender
with AnyFlatSpecLike
with BeforeAndAfterAll {
class IterableReadOnlyLogStore(iter: Iterable[ReplayLogRecord])
extends SequentialRecordStorage[ReplayLogRecord] {
override def getWriter(
fileName: String
): SequentialRecordStorage.SequentialRecordWriter[ReplayLogRecord] = ???
override def getReader(
fileName: String
): SequentialRecordStorage.SequentialRecordReader[ReplayLogRecord] =
new SequentialRecordReader[ReplayLogRecord](null) {
override def mkRecordIterator(): Iterator[ReplayLogRecord] = iter.iterator
}
override def deleteStorage(): Unit = ???
override def containsFolder(folderName: String): Boolean = ???
}
private val actorId = ActorVirtualIdentity("test")
private val actorId2 = ActorVirtualIdentity("upstream1")
private val actorId3 = ActorVirtualIdentity("upstream2")
private val channelId1 = ChannelIdentity(CONTROLLER, actorId, isControl = true)
private val channelId2 = ChannelIdentity(actorId2, actorId, isControl = false)
private val channelId3 = ChannelIdentity(actorId3, actorId, isControl = false)
private val channelId4 = ChannelIdentity(actorId2, actorId, isControl = true)
private val logManager = new ReplayLogManagerImpl(x => {})
"replay input gate" should "replay the message payload in log order" in {
val logRecords = mutable.Queue[ProcessingStep](
ProcessingStep(channelId1, -1),
ProcessingStep(channelId4, 1),
ProcessingStep(channelId3, 2),
ProcessingStep(channelId1, 3),
ProcessingStep(channelId2, 4)
)
val inputGateway = new NetworkInputGateway(actorId)
def inputMessage(channelId: ChannelIdentity, seq: Long): Unit = {
inputGateway
.getChannel(channelId)
.acceptMessage(
WorkflowFIFOMessage(
channelId,
seq,
ControlInvocation(
METHOD_START_WORKER,
EmptyRequest(),
AsyncRPCContext(CONTROLLER, actorId),
0
)
)
)
}
val orderEnforcer = new ReplayOrderEnforcer(logManager, logRecords, -1, () => {})
inputGateway.addEnforcer(orderEnforcer)
def processMessage(channelId: ChannelIdentity, seq: Long): Unit = {
val msg = inputGateway.tryPickChannel.get.take
logManager.withFaultTolerant(msg.channelId, Some(msg)) {
assert(msg.channelId == channelId && msg.sequenceNumber == seq)
}
}
assert(inputGateway.tryPickChannel.isEmpty)
inputMessage(channelId2, 0)
assert(inputGateway.tryPickChannel.isEmpty)
inputMessage(channelId4, 0)
assert(inputGateway.tryPickChannel.isEmpty)
inputMessage(channelId1, 0)
inputMessage(channelId1, 1)
inputMessage(channelId1, 2)
assert(
inputGateway.tryPickChannel.nonEmpty && inputGateway.tryPickChannel.get.channelId == channelId1
)
processMessage(channelId1, 0)
assert(inputGateway.tryPickChannel.nonEmpty)
processMessage(channelId1, 1)
assert(inputGateway.tryPickChannel.nonEmpty)
processMessage(channelId4, 0)
assert(inputGateway.tryPickChannel.isEmpty)
inputMessage(channelId3, 0)
processMessage(channelId3, 0)
assert(inputGateway.tryPickChannel.nonEmpty)
processMessage(channelId1, 2)
assert(inputGateway.tryPickChannel.nonEmpty)
processMessage(channelId2, 0)
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/amber/error/ErrorUtilsSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.amber.error
import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.ConsoleMessageType.ERROR
import org.apache.texera.amber.engine.architecture.rpc.controlreturns.{ControlError, ErrorLanguage}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import scala.util.control.ControlThrowable
class ErrorUtilsSpec extends AnyFlatSpec with Matchers {
// ----- safely -----
"safely" should "rethrow ControlThrowable even when the handler is defined for it" in {
val ct = new ControlThrowable {}
val swallowAll: PartialFunction[Throwable, String] = { case _ => "swallowed" }
val wrapped = ErrorUtils.safely(swallowAll)
val thrown = intercept[ControlThrowable](wrapped(ct))
thrown should be theSameInstanceAs ct
}
it should "delegate to the supplied handler when it is defined for the throwable" in {
val handler: PartialFunction[Throwable, String] = {
case e: IllegalStateException => s"handled:${e.getMessage}"
}
val wrapped = ErrorUtils.safely(handler)
wrapped(new IllegalStateException("boom")) shouldBe "handled:boom"
}
it should "leave the wrapped partial function undefined for unhandled throwables" in {
// The wrapped PartialFunction must report isDefinedAt=false for inputs the
// user's handler does not cover, so callers can fall through to other
// catch clauses.
val handler: PartialFunction[Throwable, String] = {
case _: IllegalStateException => "ok"
}
val wrapped = ErrorUtils.safely(handler)
wrapped.isDefinedAt(new RuntimeException("nope")) shouldBe false
}
// ----- mkConsoleMessage -----
"mkConsoleMessage" should "use Unknown Source when the throwable has no stack frames" in {
val err = new RuntimeException("kaboom")
err.setStackTrace(Array.empty)
val msg = ErrorUtils.mkConsoleMessage(ActorVirtualIdentity("worker-A"), err)
msg.workerId shouldBe "worker-A"
msg.source shouldBe "(Unknown Source)"
msg.title shouldBe err.toString
msg.msgType shouldBe ERROR
msg.message shouldBe ""
}
it should "encode the top stack frame as (file:line) when available" in {
val err = new RuntimeException("kaboom")
err.setStackTrace(
Array(new StackTraceElement("com.x.Foo", "bar", "Foo.scala", 42))
)
val msg = ErrorUtils.mkConsoleMessage(ActorVirtualIdentity("worker-A"), err)
msg.source shouldBe "(Foo.scala:42)"
msg.message should include("Foo.scala")
}
// ----- mkControlError -----
"mkControlError" should "leave errorDetails empty and language=SCALA when the cause is null" in {
val err = new RuntimeException("no-cause")
err.setStackTrace(Array(new StackTraceElement("Cls", "m", "F.scala", 7)))
val ce = ErrorUtils.mkControlError(err)
ce.errorMessage shouldBe err.toString
ce.errorDetails shouldBe ""
ce.language shouldBe ErrorLanguage.SCALA
ce.stackTrace should startWith("at ")
ce.stackTrace should include("F.scala:7")
}
it should "populate errorDetails with the cause's toString when present" in {
val cause = new IllegalStateException("root")
val err = new RuntimeException("outer", cause)
val ce = ErrorUtils.mkControlError(err)
ce.errorMessage shouldBe err.toString
ce.errorDetails shouldBe cause.toString
}
// ----- reconstructThrowable -----
"reconstructThrowable" should "skip stack-trace parsing for PYTHON-language errors" in {
// Pin: PYTHON path returns a bare new Throwable(message) and never
// touches the supplied errorDetails/stackTrace strings. The reconstructed
// throwable will still carry the JVM-captured stack from `new Throwable`,
// so the test only asserts what's specific to this branch.
val ce = ControlError(
"py.boom",
"ignored-details",
"at com.x.Foo.bar(Foo.scala:42)",
ErrorLanguage.PYTHON
)
val reconstructed = ErrorUtils.reconstructThrowable(ce)
reconstructed.getMessage shouldBe "py.boom"
reconstructed.getCause shouldBe null
// None of the parsed-stack frames should leak through on the Python path.
reconstructed.getStackTrace.exists(f => f.getClassName == "com.x.Foo.bar") shouldBe false
}
it should "leave the cause null when errorDetails is empty for SCALA errors" in {
val ce = ControlError("scala.boom", "", "", ErrorLanguage.SCALA)
val reconstructed = ErrorUtils.reconstructThrowable(ce)
reconstructed.getMessage shouldBe "scala.boom"
reconstructed.getCause shouldBe null
}
it should "attach a cause Throwable when errorDetails is non-empty" in {
val ce = ControlError("scala.boom", "root-cause", "", ErrorLanguage.SCALA)
val reconstructed = ErrorUtils.reconstructThrowable(ce)
reconstructed.getCause should not be null
reconstructed.getCause.getMessage shouldBe "root-cause"
}
it should "parse stacktrace lines that match the at-className(location) pattern" in {
val ce = ControlError(
"scala.boom",
"",
"at com.x.Foo.bar(Foo.scala:42)\nat com.x.Baz.qux(Baz.scala:7)",
ErrorLanguage.SCALA
)
val reconstructed = ErrorUtils.reconstructThrowable(ce)
val frames = reconstructed.getStackTrace
frames.length shouldBe 2
frames(0).getClassName shouldBe "com.x.Foo.bar"
frames(0).getFileName shouldBe "Foo.scala:42"
frames(1).getClassName shouldBe "com.x.Baz.qux"
}
it should "drop lines that do not match the at-className(location) pattern" in {
val ce = ControlError(
"scala.boom",
"",
"garbage line\nat com.x.Foo.bar(Foo.scala:42)\nmore garbage",
ErrorLanguage.SCALA
)
val reconstructed = ErrorUtils.reconstructThrowable(ce)
reconstructed.getStackTrace.length shouldBe 1
}
// ----- getStackTraceWithAllCauses -----
"getStackTraceWithAllCauses" should "use the developer header at the top level" in {
val err = new RuntimeException("top")
err.setStackTrace(Array.empty)
val out = ErrorUtils.getStackTraceWithAllCauses(err)
out should startWith("Stack trace for developers:")
out should include(err.toString)
}
it should "recurse into nested causes with a Caused by section" in {
val cause = new IllegalStateException("inner")
cause.setStackTrace(Array.empty)
val err = new RuntimeException("outer", cause)
err.setStackTrace(Array.empty)
val out = ErrorUtils.getStackTraceWithAllCauses(err)
out should include("Caused by:")
out should include("inner")
out should include("outer")
}
// ----- getOperatorFromActorIdOpt -----
"getOperatorFromActorIdOpt" should "default to unknown operator and empty worker id when the option is empty" in {
ErrorUtils.getOperatorFromActorIdOpt(None) shouldBe ("unknown operator", "")
}
it should "extract operator id from a worker actor name following the WF/op/layer pattern" in {
val actor = ActorVirtualIdentity("Worker:WF1-E1-myOp-main-0")
val (operatorId, workerId) = ErrorUtils.getOperatorFromActorIdOpt(Some(actor))
// The pattern is Worker:WF---; greedy on operator,
// so layer=`main`, workerIdx=`0`, and the operator captures `E1-myOp`.
operatorId shouldBe "E1-myOp"
workerId shouldBe "Worker:WF1-E1-myOp-main-0"
}
it should "fall back to the dummy operator id for actor names that do not match the pattern" in {
val actor = ActorVirtualIdentity("CONTROLLER")
val (operatorId, workerId) = ErrorUtils.getOperatorFromActorIdOpt(Some(actor))
operatorId shouldBe "__DummyOperator"
workerId shouldBe "CONTROLLER"
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/auth/UserAuthenticatorSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.auth
import org.apache.texera.auth.JwtAuth
import org.apache.texera.dao.jooq.generated.enums.UserRoleEnum
import org.jose4j.jwt.JwtClaims
import org.jose4j.jwt.consumer.JwtContext
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
class UserAuthenticatorSpec extends AnyFlatSpec with Matchers {
// Mirror exactly what JwtAuth.jwtClaims would write at issue time, so
// the spec doubles as a contract check between the issuer and the
// amber-side authenticator.
private def buildClaims(): JwtClaims = {
val claims = new JwtClaims
claims.setSubject("alice")
claims.setClaim("userId", 42)
claims.setClaim("googleId", "g-123")
claims.setClaim("email", "alice@example.com")
claims.setClaim("role", UserRoleEnum.ADMIN.name)
claims.setClaim("googleAvatar", "avatar-blob")
claims.setExpirationTimeMinutesInTheFuture(10f)
claims
}
// Run a token through the production consumer to get a real JwtContext —
// matches what the toastshaman filter hands the authenticator at runtime.
private def contextFor(claims: JwtClaims): JwtContext =
JwtAuth.jwtConsumer.process(JwtAuth.jwtToken(claims))
"UserAuthenticator.authenticate" should "delegate to JwtParser and return a populated SessionUser" in {
val result = UserAuthenticator.authenticate(contextFor(buildClaims()))
result.isPresent shouldBe true
val u = result.get().getUser
u.getUid shouldBe 42
u.getName shouldBe "alice"
u.getEmail shouldBe "alice@example.com"
u.getGoogleId shouldBe "g-123"
u.getGoogleAvatar shouldBe "avatar-blob"
u.getRole shouldBe UserRoleEnum.ADMIN
}
it should "return empty when a required custom claim (userId) is missing" in {
val claims = new JwtClaims
claims.setSubject("alice")
claims.setClaim("role", UserRoleEnum.ADMIN.name)
claims.setExpirationTimeMinutesInTheFuture(10f)
UserAuthenticator.authenticate(contextFor(claims)).isPresent shouldBe false
}
it should "return empty when a required custom claim (role) is missing" in {
val claims = new JwtClaims
claims.setSubject("alice")
claims.setClaim("userId", 42)
claims.setExpirationTimeMinutesInTheFuture(10f)
UserAuthenticator.authenticate(contextFor(claims)).isPresent shouldBe false
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/resource/dashboard/file/WorkflowResourceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.resource.dashboard.file
import org.apache.texera.auth.SessionUser
import org.apache.texera.dao.MockTexeraDB
import org.apache.texera.dao.jooq.generated.Tables.{USER, WORKFLOW, WORKFLOW_OF_PROJECT}
import org.apache.texera.dao.jooq.generated.enums.UserRoleEnum
import org.apache.texera.dao.jooq.generated.tables.daos.UserDao
import org.apache.texera.dao.jooq.generated.tables.pojos.{Project, User, Workflow}
import org.apache.texera.web.resource.dashboard.DashboardResource.SearchQueryParams
import org.apache.texera.web.resource.dashboard.user.project.ProjectResource
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowResource
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowResource.{
DashboardWorkflow,
WorkflowIDs
}
import org.apache.texera.web.resource.dashboard.{DashboardResource, FulltextSearchQueryUtils}
import org.jooq.Condition
import org.jooq.impl.DSL.noCondition
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import java.sql.Timestamp
import java.text.{ParseException, SimpleDateFormat}
import java.time.{Duration, OffsetDateTime, ZoneOffset}
import java.util
import java.util.Collections
import java.util.concurrent.TimeUnit
class WorkflowResourceSpec
extends AnyFlatSpec
with BeforeAndAfterAll
with BeforeAndAfterEach
with MockTexeraDB {
// An example creation time to test Account Creation Time attribute
private val exampleCreationTime: OffsetDateTime =
OffsetDateTime.parse("2025-01-01T00:00:00Z")
private val testUser: User = {
val user = new User
user.setUid(Integer.valueOf(1))
user.setName("test_user")
user.setRole(UserRoleEnum.ADMIN)
user.setPassword("123")
user.setComment("test_comment")
user.setAccountCreationTime(exampleCreationTime)
user
}
private val testUser2: User = {
val user = new User
user.setUid(Integer.valueOf(2))
user.setName("test_user2")
user.setRole(UserRoleEnum.ADMIN)
user.setPassword("123")
user.setComment("test_comment2")
user.setAccountCreationTime(exampleCreationTime)
user
}
private val keywordInWorkflow1Content = "keyword_in_workflow1_content"
private val textPhrase = "text phrases"
private val exampleContent =
"{\"x\":5,\"y\":\"" + keywordInWorkflow1Content + "\",\"z\":\"" + textPhrase + "\"}"
private val testWorkflow1: Workflow = {
val workflow = new Workflow()
workflow.setName("test_workflow1")
workflow.setDescription("keyword_in_workflow_description")
workflow.setContent(exampleContent)
workflow
}
private val testWorkflow2: Workflow = {
val workflow = new Workflow()
workflow.setName("test_workflow2")
workflow.setDescription("another_text")
workflow.setContent("{\"x\":5,\"y\":\"example2\",\"z\":\"\"}")
workflow
}
private val testWorkflow3: Workflow = {
val workflow = new Workflow()
workflow.setName("test_workflow3")
workflow.setDescription("")
workflow.setContent("{\"x\":5,\"y\":\"example3\",\"z\":\"\"}")
workflow
}
private val testProject1: Project = {
val project = new Project()
project.setName("test_project1")
project.setDescription("this is project description")
project
}
private val exampleEmailAddress = "name@example.com"
private val exampleWord1 = "Lorem"
private val exampleWord2 = "Ipsum"
private val testWorkflowWithSpecialCharacters: Workflow = {
val workflow = new Workflow()
workflow.setName("workflow_with_special_characters")
workflow.setDescription(exampleWord1 + " " + exampleWord2 + " " + exampleEmailAddress)
workflow.setContent(exampleContent)
workflow
}
private val sessionUser1: SessionUser = {
new SessionUser(testUser)
}
private val sessionUser2: SessionUser = {
new SessionUser(testUser2)
}
private val workflowResource: WorkflowResource = {
new WorkflowResource()
}
private val projectResource: ProjectResource = {
new ProjectResource()
}
private val dashboardResource: DashboardResource = {
new DashboardResource()
}
override protected def beforeAll(): Unit = {
initializeDBAndReplaceDSLContext()
FulltextSearchQueryUtils.usePgroonga = false // disable pgroonga
// add test user directly
val userDao = new UserDao(getDSLContext.configuration())
userDao.insert(testUser)
userDao.insert(testUser2)
}
override protected def beforeEach(): Unit = {
  // Intentionally a no-op: workflows and projects created by a test are removed in afterEach,
  // so no per-test setup is needed here.
  ()
}
override protected def afterEach(): Unit = {
  val testUsers = Seq(sessionUser1, sessionUser2)
  // First delete every workflow owned by each test user (user 1, then user 2) ...
  for (user <- testUsers) {
    for (entry <- workflowResource.retrieveWorkflowsBySessionUser(user)) {
      workflowResource.deleteWorkflow(WorkflowIDs(List(entry.workflow.getWid), None), user)
    }
  }
  // ... then every project listed for each of them.
  for (user <- testUsers) {
    projectResource
      .getProjectList(user)
      .forEach(project => projectResource.deleteProject(project.pid))
  }
}
// Tear down the mock database once the whole suite has finished.
override protected def afterAll(): Unit = shutdownDB()
/** Copies the given keywords, in order, into a fresh Java ArrayList for SearchQueryParams. */
private def getKeywordsArray(keywords: String*): util.ArrayList[String] = {
  val collected = new util.ArrayList[String](keywords.length)
  keywords.foreach(keyword => collected.add(keyword))
  collected
}
/**
 * Inserts a temporary regular user with the given uid and account creation time, asserts that
 * the stored time is non-null and equal (as an instant) to `ts`, and always deletes the user
 * again afterwards.
 */
private def insertAndAssertAccountCreation(uid: Int, ts: OffsetDateTime): Unit = {
  val dao = new UserDao(getDSLContext.configuration())
  val boxedUid = Integer.valueOf(uid)
  val tempUser = new User
  tempUser.setUid(boxedUid)
  tempUser.setName(s"tmp_user_$uid")
  tempUser.setRole(UserRoleEnum.REGULAR)
  tempUser.setPassword("pw")
  tempUser.setComment("tmp")
  tempUser.setAccountCreationTime(ts)
  dao.insert(tempUser)
  try {
    val stored = dao.fetchOneByUid(boxedUid)
    assert(stored.getAccountCreationTime != null)
    assert(stored.getAccountCreationTime.isEqual(ts))
  } finally dao.deleteById(boxedUid)
}
/** Asserts that a dashboard search hit refers to the expected workflow, compared by name only. */
private def assertSameWorkflow(a: Workflow, b: DashboardWorkflow): Unit =
  assert(b.workflow.getName == a.getName)
"User.accountCreationTime" should "be persisted and retrievable via UserDao" in {
val userDao = new UserDao(getDSLContext.configuration())
val u1 = userDao.fetchOneByUid(Integer.valueOf(1))
val u2 = userDao.fetchOneByUid(Integer.valueOf(2))
assert(u1.getAccountCreationTime != null)
assert(u2.getAccountCreationTime != null)
assert(u1.getAccountCreationTime.isEqual(exampleCreationTime))
assert(u2.getAccountCreationTime.isEqual(exampleCreationTime))
}
it should "remain unchanged when updating unrelated fields" in {
  val userDao = new UserDao(getDSLContext.configuration())
  val u1 = userDao.fetchOneByUid(Integer.valueOf(1))
  val originalTime = u1.getAccountCreationTime
  val originalComment = u1.getComment
  // Update an unrelated column and verify the creation time is untouched.
  u1.setComment("updated_comment")
  userDao.update(u1)
  try {
    val updatedU1 = userDao.fetchOneByUid(Integer.valueOf(1))
    assert(updatedU1.getAccountCreationTime.isEqual(originalTime))
  } finally {
    // User 1 is a suite-wide fixture seeded once in beforeAll; restore it so later tests
    // do not observe this test's modification.
    u1.setComment(originalComment)
    userDao.update(u1)
  }
}
it should "fallback to DB default when not explicitly set on insert" in {
  // account_creation_time TIMESTAMPTZ NOT NULL DEFAULT now()
  val userDao = new UserDao(getDSLContext.configuration())
  // Temporary user 3, on top of the seeded test users 1 and 2.
  val userId = Integer.valueOf(3)
  val tmp = new User
  tmp.setUid(userId)
  tmp.setName("tmp_user")
  tmp.setRole(UserRoleEnum.REGULAR)
  tmp.setPassword("pw")
  tmp.setComment("tmp")
  // Account creation time deliberately not set, so the column default must apply.
  userDao.insert(tmp)
  try {
    val fetched = userDao.fetchOneByUid(userId)
    assert(fetched.getAccountCreationTime != null)
    val now = OffsetDateTime.now(ZoneOffset.UTC)
    val diff = Duration.between(fetched.getAccountCreationTime, now).abs()
    assert(diff.toMinutes <= 2)
  } finally {
    // Like insertAndAssertAccountCreation, never leave the temporary user behind.
    userDao.deleteById(userId)
  }
}
// Uses temporary uid 4; insertAndAssertAccountCreation deletes it again after asserting.
it should "persist and retrieve a non-UTC offset time (ex: +09:00 JST)" in {
  val jstTime = OffsetDateTime.parse("2020-06-15T12:34:56+09:00")
  insertAndAssertAccountCreation(uid = 4, ts = jstTime)
}
// Uses temporary uid 5; insertAndAssertAccountCreation deletes it again after asserting.
it should "persist and retrieve a leap day timestamp" in {
  val leapDay = OffsetDateTime.parse("2024-02-29T23:59:59Z")
  insertAndAssertAccountCreation(uid = 5, ts = leapDay)
}
// Uses temporary uid 6; insertAndAssertAccountCreation deletes it again after asserting.
it should "persist and retrieve a future timestamp" in {
  val farFuture = OffsetDateTime.parse("2100-12-31T23:59:59Z")
  insertAndAssertAccountCreation(uid = 6, ts = farFuture)
}
"WorkflowResource /owner_name" should "return owner name as plain text" in {
workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
val workflows = workflowResource.retrieveWorkflowsBySessionUser(sessionUser1)
assert(workflows.nonEmpty)
val wid =
workflows
.find(_.workflow.getName == testWorkflow1.getName)
.map(_.workflow.getWid)
.getOrElse(workflows.head.workflow.getWid)
val ownerName = workflowResource.getOwnerName(wid)
assert(ownerName == testUser.getName)
}
"/search API " should "be able to search for workflows in different columns in Workflow table" in {
// testWorkflow1: {name: test_name, descrption: test_description, content: test_content}
// search "test_name" or "test_description" or "test_content" should return testWorkflow1
workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
workflowResource.persistWorkflow(testWorkflow3, sessionUser1)
// search
val DashboardWorkflowEntryList =
dashboardResource
.searchAllResourcesCall(
sessionUser1,
SearchQueryParams(keywords = getKeywordsArray(keywordInWorkflow1Content))
)
.results
assert(DashboardWorkflowEntryList.head.workflow.get.ownerName.equals(testUser.getName))
assert(DashboardWorkflowEntryList.length == 1)
assertSameWorkflow(testWorkflow1, DashboardWorkflowEntryList.head.workflow.get)
}
it should "be able to search text phrases" in {
  // The content keyword finds exactly testWorkflow1, while the truncated phrase "text sear"
  // is expected to match nothing.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  workflowResource.persistWorkflow(testWorkflow3, sessionUser1)
  def search(keyword: String) =
    dashboardResource
      .searchAllResourcesCall(sessionUser1, SearchQueryParams(keywords = getKeywordsArray(keyword)))
      .results
  val keywordHits = search(keywordInWorkflow1Content)
  assert(keywordHits.length == 1)
  assertSameWorkflow(testWorkflow1, keywordHits.head.workflow.get)
  assert(search("text sear").isEmpty)
}
it should "return an all workflows when given an empty list of keywords" in {
  // With default SearchQueryParams() every workflow the user owns is returned.
  Seq(testWorkflow1, testWorkflow3).foreach(wf => workflowResource.persistWorkflow(wf, sessionUser1))
  val everything = dashboardResource.searchAllResourcesCall(sessionUser1, SearchQueryParams())
  assert(everything.results.length == 2)
}
it should "be able to search with arbitrary number of keywords in different combinations" in {
  // The content keyword plus testWorkflow1's description finds testWorkflow1 in either order,
  // while appending an unmatched keyword yields no results.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  workflowResource.persistWorkflow(testWorkflow3, sessionUser1)
  def search(keywords: util.ArrayList[String]) =
    dashboardResource
      .searchAllResourcesCall(sessionUser1, SearchQueryParams(keywords = keywords))
      .results
  def expectOnlyWorkflow1(keywords: util.ArrayList[String]): Unit = {
    val hits = search(keywords)
    assert(hits.size == 1)
    assert(hits.head.workflow.get.ownerName.equals(testUser.getName))
    assertSameWorkflow(testWorkflow1, hits.head.workflow.get)
  }
  val description = testWorkflow1.getDescription
  expectOnlyWorkflow1(getKeywordsArray(keywordInWorkflow1Content, description))
  assert(search(getKeywordsArray(keywordInWorkflow1Content, description, "nonexistent")).isEmpty)
  expectOnlyWorkflow1(getKeywordsArray(description, keywordInWorkflow1Content))
}
it should "handle reserved characters in the keywords" in {
  // Full-text operator characters around or between the content keyword are expected not to
  // affect matching: each query below must still return exactly testWorkflow1.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  workflowResource.persistWorkflow(testWorkflow3, sessionUser1)
  val reserved = "+-@()<>~*\""
  val kw = keywordInWorkflow1Content
  Seq(kw + reserved + kw, kw + "@" + kw, kw + reserved, reserved + kw).foreach { query =>
    val hits = dashboardResource
      .searchAllResourcesCall(sessionUser1, SearchQueryParams(keywords = getKeywordsArray(query)))
      .results
    assert(hits.size == 1)
    assert(hits.head.workflow.get.ownerName.equals(testUser.getName))
    assertSameWorkflow(testWorkflow1, hits.head.workflow.get)
  }
}
it should "return all workflows when keywords only contains reserved keywords +-@()<>~*\"" in {
  // A query made up solely of reserved characters behaves like an empty query.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  workflowResource.persistWorkflow(testWorkflow3, sessionUser1)
  val onlyReserved = getKeywordsArray("+-@()<>~*\"")
  val hits = dashboardResource
    .searchAllResourcesCall(sessionUser1, SearchQueryParams(onlyReserved))
    .results
  assert(hits.size == 2)
}
it should "not be able to search workflows from different user accounts" in {
  // sessionUser1 owns testWorkflow1 and testWorkflow3; sessionUser2 owns testWorkflow2.
  // Users should only be able to search for workflows they have access to.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  workflowResource.persistWorkflow(testWorkflow2, sessionUser2)
  workflowResource.persistWorkflow(testWorkflow3, sessionUser1)
  def test(user: SessionUser, workflow: Workflow): Unit = {
    // Search by the workflow's description; only the caller's own matching workflow may be
    // returned, never the other account's.
    val DashboardWorkflowEntryList =
      dashboardResource
        .searchAllResourcesCall(
          user,
          SearchQueryParams(getKeywordsArray(workflow.getDescription))
        )
        .results
    assert(DashboardWorkflowEntryList.size == 1)
    assert(DashboardWorkflowEntryList.head.workflow.get.ownerName.equals(user.getName()))
    assertSameWorkflow(workflow, DashboardWorkflowEntryList.head.workflow.get)
  }
  test(sessionUser1, testWorkflow1)
  test(sessionUser2, testWorkflow2)
}
it should "return a proper condition for a single owner" in {
  // jOOQ conditions are compared through their toString rendering.
  val owners = new java.util.ArrayList[String](util.Arrays.asList("owner1"))
  val expected = USER.EMAIL.eq("owner1")
  val actual: Condition = FulltextSearchQueryUtils.getContainsFilter(owners, USER.EMAIL)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for multiple owners" in {
  // Several values become equality checks OR-ed together, in list order.
  val owners = new java.util.ArrayList[String](util.Arrays.asList("owner1", "owner2"))
  val expected = USER.EMAIL.eq("owner1").or(USER.EMAIL.eq("owner2"))
  val actual: Condition = FulltextSearchQueryUtils.getContainsFilter(owners, USER.EMAIL)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for a single projectId" in {
  val projectIds = new java.util.ArrayList[Integer](util.Arrays.asList(Integer.valueOf(1)))
  val expected = WORKFLOW_OF_PROJECT.PID.eq(Integer.valueOf(1))
  val actual: Condition =
    FulltextSearchQueryUtils.getContainsFilter(projectIds, WORKFLOW_OF_PROJECT.PID)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for multiple projectIds" in {
  val projectIds = new java.util.ArrayList[Integer](
    util.Arrays.asList(Integer.valueOf(1), Integer.valueOf(2))
  )
  val pid = WORKFLOW_OF_PROJECT.PID
  val expected = pid.eq(Integer.valueOf(1)).or(pid.eq(Integer.valueOf(2)))
  val actual: Condition = FulltextSearchQueryUtils.getContainsFilter(projectIds, pid)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for a single workflowID" in {
  val workflowIds = new java.util.ArrayList[Integer](util.Arrays.asList(Integer.valueOf(1)))
  val expected = WORKFLOW.WID.eq(Integer.valueOf(1))
  val actual: Condition = FulltextSearchQueryUtils.getContainsFilter(workflowIds, WORKFLOW.WID)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for multiple workflowIDs" in {
  val workflowIds = new java.util.ArrayList[Integer](
    util.Arrays.asList(Integer.valueOf(1), Integer.valueOf(2))
  )
  val wid = WORKFLOW.WID
  val expected = wid.eq(Integer.valueOf(1)).or(wid.eq(Integer.valueOf(2)))
  val actual: Condition = FulltextSearchQueryUtils.getContainsFilter(workflowIds, wid)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for creation date type with specific start and end date" in {
  // Expected range: the start of the start date through the last millisecond of the end date,
  // both parsed with SimpleDateFormat in the default time zone.
  val fmt = new SimpleDateFormat("yyyy-MM-dd")
  val rangeStart = new Timestamp(fmt.parse("2023-01-01").getTime)
  val rangeEnd =
    new Timestamp(fmt.parse("2023-12-31").getTime + TimeUnit.DAYS.toMillis(1) - 1)
  val expected = WORKFLOW.CREATION_TIME.between(rangeStart, rangeEnd)
  val actual: Condition =
    FulltextSearchQueryUtils.getDateFilter("2023-01-01", "2023-12-31", WORKFLOW.CREATION_TIME)
  assert(actual.toString == expected.toString)
}
it should "return a proper condition for modification date type with specific start and end date" in {
  // Same inclusive day range as the creation-date case, applied to LAST_MODIFIED_TIME.
  val fmt = new SimpleDateFormat("yyyy-MM-dd")
  val rangeStart = new Timestamp(fmt.parse("2023-01-01").getTime)
  val rangeEnd =
    new Timestamp(fmt.parse("2023-12-31").getTime + TimeUnit.DAYS.toMillis(1) - 1)
  val expected = WORKFLOW.LAST_MODIFIED_TIME.between(rangeStart, rangeEnd)
  val actual: Condition = FulltextSearchQueryUtils.getDateFilter(
    "2023-01-01",
    "2023-12-31",
    WORKFLOW.LAST_MODIFIED_TIME
  )
  assert(actual.toString == expected.toString)
}
it should "throw a ParseException when endDate is invalid" in {
  // A valid start date does not rescue an unparseable end date.
  assertThrows[ParseException](
    FulltextSearchQueryUtils.getDateFilter("2023-01-01", "invalidDate", WORKFLOW.CREATION_TIME)
  )
}
"getOperatorsFilter" should "return a noCondition when the input operators list is empty" in {
val operatorsFilter: Condition =
FulltextSearchQueryUtils.getOperatorsFilter(
Collections.emptyList[String](),
WORKFLOW.CONTENT
)
assert(operatorsFilter.toString == noCondition().toString)
}
it should "return a proper condition for a single operator" in {
  // The operator becomes a case-insensitive LIKE on the `"operatorType":"<op>"` fragment.
  val operators = new java.util.ArrayList[String](util.Arrays.asList("operator1"))
  val actual: Condition = FulltextSearchQueryUtils.getOperatorsFilter(operators, WORKFLOW.CONTENT)
  val pattern = "%\"operatorType\":\"" + "operator1" + "\"%"
  assert(actual.toString == WORKFLOW.CONTENT.likeIgnoreCase(pattern).toString)
}
it should "return a proper condition for multiple operators" in {
  // Each operator becomes a case-insensitive LIKE on its `"operatorType":"<op>"` fragment,
  // and the per-operator conditions are OR-ed together.
  def operatorPattern(op: String): String = "%\"operatorType\":\"" + op + "\"%"
  val operators = new java.util.ArrayList[String](util.Arrays.asList("operator1", "operator2"))
  val actual: Condition = FulltextSearchQueryUtils.getOperatorsFilter(operators, WORKFLOW.CONTENT)
  val expected = WORKFLOW.CONTENT
    .likeIgnoreCase(operatorPattern("operator1"))
    .or(WORKFLOW.CONTENT.likeIgnoreCase(operatorPattern("operator2")))
  assert(actual.toString == expected.toString)
}
"/search API" should "be able to search for resources in different tables" in {
// create different types of resources, project, workflow, and file
projectResource.createProject(sessionUser1, "test project1")
workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
// search
val DashboardClickableFileEntryList =
dashboardResource.searchAllResourcesCall(
sessionUser1,
SearchQueryParams(getKeywordsArray("test"))
)
assert(DashboardClickableFileEntryList.results.length == 2)
}
it should "return all resources when no keyword provided" in {
  // A single empty-string keyword behaves like no keyword: the project and workflow both return.
  projectResource.createProject(sessionUser1, "test project1")
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  val blankQuery = SearchQueryParams(getKeywordsArray(""))
  val everything = dashboardResource.searchAllResourcesCall(sessionUser1, blankQuery)
  assert(everything.results.length == 2)
}
it should "return multiple matching resources from a single resource type" in {
  // Only the two "common" projects match; testWorkflow1 does not.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  Seq("common project1", "common project2").foreach(name =>
    projectResource.createProject(sessionUser1, name)
  )
  val commonHits = dashboardResource.searchAllResourcesCall(
    sessionUser1,
    SearchQueryParams(getKeywordsArray("common"))
  )
  assert(commonHits.results.length == 2)
}
it should "handle multiple keywords correctly" in {
  // Only the project named "test project1" matches both "test" and "project1".
  projectResource.createProject(sessionUser1, "test project1")
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  val bothKeywords = SearchQueryParams(getKeywordsArray("test", "project1"))
  val hits = dashboardResource.searchAllResourcesCall(sessionUser1, bothKeywords)
  assert(hits.results.length == 1)
}
it should "filter results by different resourceType" in {
  // Create 3 projects and 1 workflow, all matching the keyword "test" (no files are involved).
  (1 to 3).foreach(i => projectResource.createProject(sessionUser1, s"test project$i"))
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  def countMatches(params: SearchQueryParams): Int =
    dashboardResource.searchAllResourcesCall(sessionUser1, params).results.length
  // No resourceType filter: all four resources are returned.
  assert(countMatches(SearchQueryParams(getKeywordsArray("test"))) == 4)
  // Restricting to workflows leaves only testWorkflow1.
  assert(
    countMatches(SearchQueryParams(resourceType = "workflow", keywords = getKeywordsArray("test"))) == 1
  )
  // Restricting to projects leaves the three projects.
  assert(
    countMatches(SearchQueryParams(resourceType = "project", keywords = getKeywordsArray("test"))) == 3
  )
}
it should "return resources that match any of all provided keywords" in {
  // Despite the test name, the expected count shows keyword conjunction: a resource is
  // returned only if it matches ALL keywords. The project "test project" matches both
  // "test" and "project"; testWorkflow1 is expected to match "test" but not "project".
  projectResource.createProject(sessionUser1, "test project")
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  val results =
    dashboardResource
      .searchAllResourcesCall(
        sessionUser1,
        SearchQueryParams(keywords = getKeywordsArray("test", "project"))
      )
      .results
  assert(results.length == 1)
  // The single hit must be the project, not the workflow.
  assert(results.forall(_.workflow.isEmpty))
}
it should "not return resources that belong to a different user" in {
  // Only sessionUser2 owns a matching project, so sessionUser1's search must come back empty.
  projectResource.createProject(sessionUser2, "test project2")
  val hitsForUser1 =
    dashboardResource
      .searchAllResourcesCall(sessionUser1, SearchQueryParams(keywords = getKeywordsArray("test")))
      .results
  assert(hitsForUser1.isEmpty)
}
it should "paginate results correctly" in {
  // 11 resources in total (1 workflow + 10 projects), fetched in pages of 10.
  workflowResource.persistWorkflow(testWorkflow1, sessionUser1)
  (1 to 10).foreach(i => projectResource.createProject(sessionUser1, s"test project $i"))
  val pageSize = 10
  val firstPage =
    dashboardResource.searchAllResourcesCall(sessionUser1, SearchQueryParams(count = pageSize))
  assert(firstPage.results.length == 10)
  assert(firstPage.more) // a further page exists
  val secondPage = dashboardResource.searchAllResourcesCall(
    sessionUser1,
    SearchQueryParams(count = pageSize, offset = pageSize)
  )
  assert(secondPage.results.length == 1)
  // No entry may appear on both pages.
  val combined = firstPage.results ++ secondPage.results
  assert(combined.distinct.length == combined.length)
}
it should "order workflow by name correctly" in {
  // Persist out of name order so the sort is actually exercised.
  Seq(testWorkflow1, testWorkflow3, testWorkflow2).foreach(wf =>
    workflowResource.persistWorkflow(wf, sessionUser1)
  )
  // Names of the user's workflows, in the order returned for the given orderBy value.
  def namesOrderedBy(order: String): List[String] = {
    val page = dashboardResource.searchAllResourcesCall(
      sessionUser1,
      SearchQueryParams(resourceType = "workflow", orderBy = order)
    )
    page.results.map(_.workflow.get.workflow.getName).toList
  }
  // Comparing whole lists reports a readable failure (instead of an index error) if a
  // workflow is missing or extra.
  assert(namesOrderedBy("NameAsc") == List("test_workflow1", "test_workflow2", "test_workflow3"))
  assert(namesOrderedBy("NameDesc") == List("test_workflow3", "test_workflow2", "test_workflow1"))
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/resource/dashboard/user/workflow/WorkflowAccessResourceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.resource.dashboard.user.workflow
import org.apache.texera.auth.SessionUser
import org.apache.texera.dao.MockTexeraDB
import org.apache.texera.dao.jooq.generated.Tables._
import org.apache.texera.dao.jooq.generated.enums.PrivilegeEnum
import org.apache.texera.dao.jooq.generated.tables.daos.{
UserDao,
WorkflowDao,
WorkflowOfUserDao,
WorkflowUserAccessDao
}
import org.apache.texera.dao.jooq.generated.tables.pojos.{
User,
Workflow,
WorkflowOfUser,
WorkflowUserAccess
}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import java.sql.Timestamp
import javax.ws.rs.{BadRequestException, ForbiddenException}
class WorkflowAccessResourceSpec
extends AnyFlatSpec
with BeforeAndAfterAll
with BeforeAndAfterEach
with MockTexeraDB {
private val ownerUid = 1000 + scala.util.Random.nextInt(1000)
private val userWithWriteUid = 2000 + scala.util.Random.nextInt(1000)
private val userWithReadUid = 3000 + scala.util.Random.nextInt(1000)
private val targetUserUid = 4000 + scala.util.Random.nextInt(1000)
private val testWorkflowWid = 5000 + scala.util.Random.nextInt(1000)
private var owner: User = _
private var userWithWrite: User = _
private var userWithRead: User = _
private var targetUser: User = _
private var testWorkflow: Workflow = _
private var userDao: UserDao = _
private var workflowDao: WorkflowDao = _
private var workflowOfUserDao: WorkflowOfUserDao = _
private var workflowUserAccessDao: WorkflowUserAccessDao = _
private var workflowAccessResource: WorkflowAccessResource = _
override protected def beforeAll(): Unit = {
initializeDBAndReplaceDSLContext()
}
override protected def beforeEach(): Unit = {
// Initialize DAOs
userDao = new UserDao(getDSLContext.configuration())
workflowDao = new WorkflowDao(getDSLContext.configuration())
workflowOfUserDao = new WorkflowOfUserDao(getDSLContext.configuration())
workflowUserAccessDao = new WorkflowUserAccessDao(getDSLContext.configuration())
workflowAccessResource = new WorkflowAccessResource()
// Create test users
owner = new User
owner.setUid(ownerUid)
owner.setName("owner")
owner.setEmail("owner@test.com")
owner.setPassword("password")
userWithWrite = new User
userWithWrite.setUid(userWithWriteUid)
userWithWrite.setName("user_with_write")
userWithWrite.setEmail("write@test.com")
userWithWrite.setPassword("password")
userWithRead = new User
userWithRead.setUid(userWithReadUid)
userWithRead.setName("user_with_read")
userWithRead.setEmail("read@test.com")
userWithRead.setPassword("password")
targetUser = new User
targetUser.setUid(targetUserUid)
targetUser.setName("target_user")
targetUser.setEmail("target@test.com")
targetUser.setPassword("password")
// Create test workflow
testWorkflow = new Workflow
testWorkflow.setWid(testWorkflowWid)
testWorkflow.setName("test_workflow")
testWorkflow.setContent("{}")
testWorkflow.setDescription("test description")
testWorkflow.setCreationTime(new Timestamp(System.currentTimeMillis()))
testWorkflow.setLastModifiedTime(new Timestamp(System.currentTimeMillis()))
// Clean up before each test
cleanupTestData()
// Insert test data
userDao.insert(owner)
userDao.insert(userWithWrite)
userDao.insert(userWithRead)
userDao.insert(targetUser)
workflowDao.insert(testWorkflow)
// Set up workflow ownership
val workflowOfUser = new WorkflowOfUser
workflowOfUser.setUid(ownerUid)
workflowOfUser.setWid(testWorkflowWid)
workflowOfUserDao.insert(workflowOfUser)
// Grant write access to userWithWrite
val writeAccess = new WorkflowUserAccess
writeAccess.setUid(userWithWriteUid)
writeAccess.setWid(testWorkflowWid)
writeAccess.setPrivilege(PrivilegeEnum.WRITE)
workflowUserAccessDao.insert(writeAccess)
// Grant read access to userWithRead
val readAccess = new WorkflowUserAccess
readAccess.setUid(userWithReadUid)
readAccess.setWid(testWorkflowWid)
readAccess.setPrivilege(PrivilegeEnum.READ)
workflowUserAccessDao.insert(readAccess)
// Grant write access to targetUser
val targetAccess = new WorkflowUserAccess
targetAccess.setUid(targetUserUid)
targetAccess.setWid(testWorkflowWid)
targetAccess.setPrivilege(PrivilegeEnum.WRITE)
workflowUserAccessDao.insert(targetAccess)
}
override protected def afterEach(): Unit = {
cleanupTestData()
}
private def cleanupTestData(): Unit = {
getDSLContext
.deleteFrom(WORKFLOW_USER_ACCESS)
.where(WORKFLOW_USER_ACCESS.WID.eq(testWorkflowWid))
.execute()
getDSLContext
.deleteFrom(WORKFLOW_OF_USER)
.where(WORKFLOW_OF_USER.WID.eq(testWorkflowWid))
.execute()
getDSLContext
.deleteFrom(WORKFLOW)
.where(WORKFLOW.WID.eq(testWorkflowWid))
.execute()
getDSLContext
.deleteFrom(USER)
.where(
USER.UID.in(ownerUid, userWithWriteUid, userWithReadUid, targetUserUid)
)
.execute()
}
override protected def afterAll(): Unit = {
shutdownDB()
}
"WorkflowAccessResource.revokeAccess" should "successfully revoke access when user has WRITE permission" in {
val sessionUser = new SessionUser(userWithWrite)
// Verify target user has access before revocation
val accessBefore = getDSLContext
.selectFrom(WORKFLOW_USER_ACCESS)
.where(
WORKFLOW_USER_ACCESS.WID
.eq(testWorkflowWid)
.and(
WORKFLOW_USER_ACCESS.UID.eq(targetUserUid)
)
)
.fetchOne()
assert(accessBefore != null, "Target user should have access before revocation")
// Revoke access
workflowAccessResource.revokeAccess(testWorkflowWid, "target@test.com", sessionUser)
// Verify access has been revoked
val accessAfter = getDSLContext
.selectFrom(WORKFLOW_USER_ACCESS)
.where(
WORKFLOW_USER_ACCESS.WID
.eq(testWorkflowWid)
.and(
WORKFLOW_USER_ACCESS.UID.eq(targetUserUid)
)
)
.fetchOne()
assert(accessAfter == null, "Target user's access should be revoked")
}
it should "successfully allow user to revoke their own access" in {
val sessionUser = new SessionUser(userWithRead)
// Verify user has access before revocation
val accessBefore = getDSLContext
.selectFrom(WORKFLOW_USER_ACCESS)
.where(
WORKFLOW_USER_ACCESS.WID
.eq(testWorkflowWid)
.and(
WORKFLOW_USER_ACCESS.UID.eq(userWithReadUid)
)
)
.fetchOne()
assert(accessBefore != null, "User should have access before revocation")
// User revokes their own access
workflowAccessResource.revokeAccess(testWorkflowWid, "read@test.com", sessionUser)
// Verify access has been revoked
val accessAfter = getDSLContext
.selectFrom(WORKFLOW_USER_ACCESS)
.where(
WORKFLOW_USER_ACCESS.WID
.eq(testWorkflowWid)
.and(
WORKFLOW_USER_ACCESS.UID.eq(userWithReadUid)
)
)
.fetchOne()
assert(accessAfter == null, "User's own access should be revoked")
}
it should "throw ForbiddenException when user without WRITE permission tries to revoke others' access" in {
val sessionUser = new SessionUser(userWithRead)
assertThrows[ForbiddenException] {
workflowAccessResource.revokeAccess(testWorkflowWid, "target@test.com", sessionUser)
}
// Verify target user's access is still intact
val access = getDSLContext
.selectFrom(WORKFLOW_USER_ACCESS)
.where(
WORKFLOW_USER_ACCESS.WID
.eq(testWorkflowWid)
.and(
WORKFLOW_USER_ACCESS.UID.eq(targetUserUid)
)
)
.fetchOne()
assert(access != null, "Target user's access should remain intact")
}
it should "throw ForbiddenException when trying to revoke owner's access" in {
val sessionUser = new SessionUser(userWithWrite)
val exception = intercept[ForbiddenException] {
workflowAccessResource.revokeAccess(testWorkflowWid, "owner@test.com", sessionUser)
}
assert(
exception.getMessage.contains("owner cannot revoke their own access"),
"Exception message should indicate owner cannot revoke their own access"
)
}
it should "throw ForbiddenException when owner tries to revoke their own access" in {
val sessionUser = new SessionUser(owner)
val exception = intercept[ForbiddenException] {
workflowAccessResource.revokeAccess(testWorkflowWid, "owner@test.com", sessionUser)
}
assert(
exception.getMessage.contains("owner cannot revoke their own access"),
"Exception message should indicate owner cannot revoke their own access"
)
}
it should "throw BadRequestException when email does not exist" in {
val sessionUser = new SessionUser(userWithWrite)
assertThrows[BadRequestException] {
workflowAccessResource.revokeAccess(
testWorkflowWid,
"nonexistent@test.com",
sessionUser
)
}
}
it should "not affect other users' access when revoking one user's access" in {
  val writer = new SessionUser(userWithWrite)

  // WORKFLOW_USER_ACCESS row for (testWorkflowWid, uid); jOOQ's fetchOne yields null when absent.
  def accessRowFor(uid: Integer) =
    getDSLContext
      .selectFrom(WORKFLOW_USER_ACCESS)
      .where(
        WORKFLOW_USER_ACCESS.WID
          .eq(testWorkflowWid)
          .and(WORKFLOW_USER_ACCESS.UID.eq(uid))
      )
      .fetchOne()

  // Precondition: both the read user and the target user start with access.
  assert(accessRowFor(userWithReadUid) != null, "Read user should have access before revocation")
  assert(accessRowFor(targetUserUid) != null, "Target user should have access before revocation")

  // Only the target user is revoked.
  workflowAccessResource.revokeAccess(testWorkflowWid, "target@test.com", writer)

  // The read user keeps access; the target user loses it.
  assert(accessRowFor(userWithReadUid) != null, "Read user's access should remain intact")
  assert(accessRowFor(targetUserUid) == null, "Target user's access should be revoked")
}
it should "handle revoking access for a user who already has no access gracefully" in {
  val writer = new SessionUser(userWithWrite)

  // Target user's WORKFLOW_USER_ACCESS row for this workflow, or null when there is none.
  def targetAccessRow() =
    getDSLContext
      .selectFrom(WORKFLOW_USER_ACCESS)
      .where(
        WORKFLOW_USER_ACCESS.WID
          .eq(testWorkflowWid)
          .and(WORKFLOW_USER_ACCESS.UID.eq(targetUserUid))
      )
      .fetchOne()

  // First revocation removes the row.
  workflowAccessResource.revokeAccess(testWorkflowWid, "target@test.com", writer)
  assert(targetAccessRow() == null, "Target user's access should be revoked")

  // Revoking again must neither throw nor resurrect the row.
  workflowAccessResource.revokeAccess(testWorkflowWid, "target@test.com", writer)
  assert(targetAccessRow() == null, "Target user's access should still be revoked")
}
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/resource/dashboard/user/workflow/WorkflowExecutionsResourceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.resource.dashboard.user.workflow
import org.apache.texera.amber.core.virtualidentity.{
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity
}
import org.apache.texera.amber.core.workflow.{GlobalPortIdentity, PortIdentity}
import org.apache.texera.amber.util.serde.GlobalPortIdentitySerde.SerdeOps
import org.apache.texera.dao.MockTexeraDB
import org.apache.texera.dao.jooq.generated.Tables._
import org.apache.texera.dao.jooq.generated.tables.daos.{
UserDao,
WorkflowDao,
WorkflowExecutionsDao,
WorkflowVersionDao
}
import org.apache.texera.dao.jooq.generated.tables.pojos.{
User,
Workflow,
WorkflowExecutions,
WorkflowVersion
}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester}
import java.net.URI
import java.sql.Timestamp
import java.util.UUID
import java.util.concurrent.TimeUnit
import scala.collection.mutable.ArrayBuffer
class WorkflowExecutionsResourceSpec
    extends AnyFlatSpec
    with BeforeAndAfterAll
    with BeforeAndAfterEach
    with MockTexeraDB
    with PrivateMethodTester {

  // Randomized ids for the workflow and user rows this spec creates and removes.
  private val testWorkflowWid = 3000 + scala.util.Random.nextInt(1000)
  private val testUserId = 1000 + scala.util.Random.nextInt(1000)

  // Fixtures re-created (and re-inserted) before every test.
  private var testWorkflow: Workflow = _
  private var testVersion: WorkflowVersion = _
  private var testUser: User = _

  private var userDao: UserDao = _
  private var workflowDao: WorkflowDao = _
  private var workflowVersionDao: WorkflowVersionDao = _
  private var workflowExecutionsDao: WorkflowExecutionsDao = _

  override protected def beforeAll(): Unit = {
    initializeDBAndReplaceDSLContext()
  }

  /** Rebuilds the user/workflow/version fixtures, wipes leftovers, and inserts fresh rows. */
  override protected def beforeEach(): Unit = {
    testUser = newTestUser()
    testWorkflow = newTestWorkflow()
    testVersion = newTestVersion()

    val config = getDSLContext.configuration()
    workflowDao = new WorkflowDao(config)
    workflowVersionDao = new WorkflowVersionDao(config)
    userDao = new UserDao(config)
    workflowExecutionsDao = new WorkflowExecutionsDao(config)

    cleanupTestData()
    userDao.insert(testUser)
    workflowDao.insert(testWorkflow)
    workflowVersionDao.insert(testVersion)
  }

  override protected def afterEach(): Unit = {
    cleanupTestData()
  }

  /** User POJO with id `testUserId` and fixed test credentials. */
  private def newTestUser(): User = {
    val user = new User
    user.setUid(testUserId)
    user.setName("test_user")
    user.setEmail("test@example.com")
    user.setPassword("password")
    user.setGoogleAvatar("avatar_url")
    user
  }

  /** Workflow POJO with id `testWorkflowWid`, a unique name and empty JSON content. */
  private def newTestWorkflow(): Workflow = {
    val workflow = new Workflow
    workflow.setWid(testWorkflowWid)
    workflow.setName("test_workflow_" + UUID.randomUUID().toString.substring(0, 8))
    workflow.setContent("{}")
    workflow.setDescription("test description")
    workflow.setCreationTime(new Timestamp(System.currentTimeMillis()))
    workflow.setLastModifiedTime(new Timestamp(System.currentTimeMillis()))
    workflow
  }

  /** Version POJO for `testWorkflowWid`; its VID is filled in by the DAO insert. */
  private def newTestVersion(): WorkflowVersion = {
    val version = new WorkflowVersion
    version.setWid(testWorkflowWid)
    version.setContent("{}")
    version.setCreationTime(new Timestamp(System.currentTimeMillis()))
    version
  }

  /** Unsaved execution of `testVersion` by `testUser` with the given name and start time. */
  private def newExecution(name: String, startingTime: Timestamp): WorkflowExecutions = {
    val execution = new WorkflowExecutions
    execution.setVid(testVersion.getVid)
    execution.setUid(testUser.getUid)
    execution.setStatus(0.toByte)
    execution.setResult("")
    execution.setStartingTime(startingTime)
    execution.setBookmarked(false)
    execution.setName(name)
    execution.setEnvironmentVersion("test-env-1.0")
    execution
  }

  /** Deletes executions of the test workflow's versions, then the versions, workflow and user. */
  private def cleanupTestData(): Unit = {
    val ctx = getDSLContext
    val testVersionIds = ctx
      .select(WORKFLOW_VERSION.VID)
      .from(WORKFLOW_VERSION)
      .where(WORKFLOW_VERSION.WID.eq(testWorkflowWid))

    ctx.deleteFrom(WORKFLOW_EXECUTIONS).where(WORKFLOW_EXECUTIONS.VID.in(testVersionIds)).execute()
    ctx.deleteFrom(WORKFLOW_VERSION).where(WORKFLOW_VERSION.WID.eq(testWorkflowWid)).execute()
    ctx.deleteFrom(WORKFLOW).where(WORKFLOW.WID.eq(testWorkflowWid)).execute()
    ctx.deleteFrom(USER).where(USER.UID.eq(testUserId)).execute()
  }

  override protected def afterAll(): Unit = {
    shutdownDB()
  }

  "WorkflowExecutionsResource.getWorkflowExecutions" should "return executions with EIDs in descending order" in {
    val numExecutions = 10
    val insertedEids = ArrayBuffer.empty[Integer]

    // Execution i starts (numExecutions - i) days ago, so later inserts are also more recent.
    (1 to numExecutions).foreach { i =>
      val daysAgo = numExecutions - i
      val execution = newExecution(
        s"Execution $i",
        new Timestamp(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(daysAgo))
      )
      workflowExecutionsDao.insert(execution)
      insertedEids.append(execution.getEid)
    }

    val result = WorkflowExecutionsResource.getWorkflowExecutions(testWorkflowWid, getDSLContext)
    assert(result.nonEmpty, "Result should not be empty")
    assert(
      result.size == numExecutions,
      s"Expected $numExecutions executions, but got ${result.size}"
    )

    // Every adjacent pair must be strictly decreasing by EID.
    result.zip(result.tail).foreach {
      case (earlier, later) =>
        assert(
          earlier.eId > later.eId,
          s"Executions are not in descending order: ${earlier.eId} should be > ${later.eId}"
        )
    }

    val returnedEids = result.map(_.eId).toSet
    assert(
      insertedEids.toSet.subsetOf(returnedEids),
      "All inserted execution IDs should be returned"
    )
  }

  "WorkflowExecutionsResource.insertOperatorPortResultUri" should "insert a result URI row" in {
    val execution = newExecution(
      "Execution with duplicate result URI insert",
      new Timestamp(System.currentTimeMillis())
    )
    workflowExecutionsDao.insert(execution)

    val executionId = ExecutionIdentity(execution.getEid.longValue())
    val outputPort = GlobalPortIdentity(
      PhysicalOpIdentity(OperatorIdentity("operator-1"), "main"),
      PortIdentity(),
      input = false
    )
    val resultUri = URI.create("vfs:///test-result")

    WorkflowExecutionsResource.insertOperatorPortResultUri(executionId, outputPort, resultUri)

    val storedRows = getDSLContext
      .selectFrom(OPERATOR_PORT_EXECUTIONS)
      .where(OPERATOR_PORT_EXECUTIONS.WORKFLOW_EXECUTION_ID.eq(execution.getEid))
      .and(OPERATOR_PORT_EXECUTIONS.GLOBAL_PORT_ID.eq(outputPort.serializeAsString))
      .fetch()
    assert(storedRows.size() == 1)
    assert(storedRows.get(0).getResultUri == resultUri.toString)
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/resource/dashboard/user/workflow/WorkflowVersionResourceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.resource.dashboard.user.workflow
import org.apache.texera.amber.util.JSONUtils.objectMapper
import org.apache.texera.dao.MockTexeraDB
import org.apache.texera.dao.jooq.generated.Tables
import org.apache.texera.dao.jooq.generated.tables.daos.{WorkflowDao, WorkflowVersionDao}
import org.apache.texera.dao.jooq.generated.tables.pojos.{Workflow, WorkflowVersion}
import org.jooq.impl.DSL
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import java.sql.Timestamp
import java.util.UUID
import java.util.concurrent.TimeUnit
import scala.collection.mutable.ArrayBuffer
class WorkflowVersionResourceSpec
    extends AnyFlatSpec
    with BeforeAndAfterAll
    with BeforeAndAfterEach
    with MockTexeraDB {

  // Randomized WID of the workflow row this spec creates and removes.
  private val testWorkflowWid = 2000 + scala.util.Random.nextInt(1000)
  private var testWorkflow: Workflow = _
  private var workflowDao: WorkflowDao = _
  private var workflowVersionDao: WorkflowVersionDao = _

  override protected def beforeAll(): Unit = {
    initializeDBAndReplaceDSLContext()
  }

  /** Re-creates the test workflow (content `{"value":"initial"}`) after wiping leftovers. */
  override protected def beforeEach(): Unit = {
    testWorkflow = new Workflow
    testWorkflow.setWid(Integer.valueOf(testWorkflowWid))
    testWorkflow.setName("test_workflow_" + UUID.randomUUID().toString.substring(0, 8))
    testWorkflow.setContent(createWorkflowContent("initial"))
    testWorkflow.setDescription("test description")
    workflowDao = new WorkflowDao(getDSLContext.configuration())
    workflowVersionDao = new WorkflowVersionDao(getDSLContext.configuration())
    cleanupTestData()
    workflowDao.insert(testWorkflow)
  }

  override protected def afterEach(): Unit = {
    cleanupTestData()
  }

  /** Deletes this spec's version rows, then its workflow row. */
  private def cleanupTestData(): Unit = {
    getDSLContext
      .deleteFrom(Tables.WORKFLOW_VERSION)
      .where(Tables.WORKFLOW_VERSION.WID.eq(testWorkflowWid))
      .execute()
    getDSLContext
      .deleteFrom(Tables.WORKFLOW)
      .where(Tables.WORKFLOW.WID.eq(testWorkflowWid))
      .execute()
  }

  override protected def afterAll(): Unit = {
    shutdownDB()
  }

  /** Serializes `{"value": <value>}` as the workflow content JSON. */
  private def createWorkflowContent(value: String): String = {
    val jsonNode = objectMapper.createObjectNode()
    jsonNode.put("value", value)
    jsonNode.toString
  }

  /** JSON Patch (as a string) that turns `{"value": oldValue}` into `{"value": newValue}`. */
  private def createVersionDiff(oldValue: String, newValue: String): String = {
    val oldJson = objectMapper.createObjectNode()
    oldJson.put("value", oldValue)
    val newJson = objectMapper.createObjectNode()
    newJson.put("value", newValue)
    val patch = com.flipkart.zjsonpatch.JsonDiff.asJson(
      oldJson,
      newJson
    )
    patch.toString
  }

  "WorkflowVersionResource" should "return versions in descending order from fetchSubsequentVersions and apply patches correctly" in {
    // Insert versions 1..10; version i stores the patch from the previous content to "version_i".
    // The generated VIDs are recorded instead of assumed: they only equal 1..10 when the
    // VID sequence happens to start fresh for this workflow.
    val insertedVids = ArrayBuffer.empty[Integer]
    var currentContent = "initial"
    for (i <- 1 to 10) {
      val newContent = s"version_$i"
      val version = new WorkflowVersion
      version.setWid(testWorkflow.getWid)
      version.setContent(createVersionDiff(currentContent, newContent))
      version.setCreationTime(
        new Timestamp(System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(10 - i))
      )
      workflowVersionDao.insert(version)
      insertedVids.append(version.getVid)
      currentContent = newContent
    }
    testWorkflow.setContent(createWorkflowContent(currentContent))
    workflowDao.update(testWorkflow)

    // Target the 5th inserted version by its real VID rather than a hard-coded 5.
    val midVersionIndex = 5
    val midVersionId: Integer = insertedVids(midVersionIndex - 1)
    val expectedMidContent = s"version_$midVersionIndex"

    val versions = WorkflowVersionResource.fetchSubsequentVersions(
      testWorkflow.getWid,
      midVersionId,
      getDSLContext
    )
    assert(versions.nonEmpty, "No versions were returned")
    for (i <- 0 until versions.length - 1) {
      assert(
        versions(i).getVid > versions(i + 1).getVid,
        s"Versions not in descending order: ${versions(i).getVid} should be > ${versions(i + 1).getVid}"
      )
    }

    val highestVersionId = getDSLContext
      .select(DSL.max(Tables.WORKFLOW_VERSION.VID))
      .from(Tables.WORKFLOW_VERSION)
      .where(Tables.WORKFLOW_VERSION.WID.eq(testWorkflowWid))
      .fetchOneInto(classOf[Integer])
    assert(versions.head.getVid === highestVersionId, "First version should have the highest VID")

    // Rolling the stored (latest) workflow back through the fetched patches yields the mid version.
    val workflowFromDb = workflowDao.fetchOneByWid(testWorkflow.getWid)
    val workflowVersionDirect = WorkflowVersionResource.applyPatch(versions, workflowFromDb)
    val directVersionContent =
      objectMapper.readTree(workflowVersionDirect.getContent).get("value").asText()
    assert(
      directVersionContent === expectedMidContent,
      s"Workflow content from direct applyPatch should be '$expectedMidContent' but was '$directVersionContent'"
    )

    // A second fetch must be deterministic (same VIDs, same order) and give the same rollback result.
    val combinedVersions = WorkflowVersionResource.fetchSubsequentVersions(
      testWorkflow.getWid,
      midVersionId,
      getDSLContext
    )
    assert(
      combinedVersions.map(_.getVid).toList == versions.map(_.getVid).toList,
      "A repeated fetchSubsequentVersions call should return the same versions in the same order"
    )
    val currentWorkflowForCombined = workflowDao.fetchOneByWid(testWorkflow.getWid)
    val workflowVersion =
      WorkflowVersionResource.applyPatch(combinedVersions, currentWorkflowForCombined)
    val midVersionContent = objectMapper.readTree(workflowVersion.getContent).get("value").asText()
    assert(
      midVersionContent === expectedMidContent,
      s"Workflow content should be '$expectedMidContent' but was '$midVersionContent'"
    )
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/resource/pythonvirtualenvironment/PveResourceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.resource.pythonvirtualenvironment
import org.scalatest.BeforeAndAfterEach
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.LinkedBlockingQueue
import scala.jdk.CollectionConverters._
class PveResourceSpec extends AnyFlatSpec with Matchers with BeforeAndAfterEach {
  // Computing-unit id whose environments these tests create and delete.
  private val testCuid = 256
  private var testPveName: String = _
  private var testRoot: Path = _
  private var queue: LinkedBlockingQueue[String] = _

  /** Fresh environment name, the directory the tests expect for `testCuid`, and an empty log queue. */
  override protected def beforeEach(): Unit = {
    testPveName = "testenv" + System.currentTimeMillis()
    testRoot = Paths.get("/tmp/texera-pve/venvs", testCuid.toString)
    queue = new LinkedBlockingQueue[String]()
  }

  /** Removes every environment created for `testCuid`. */
  override protected def afterEach(): Unit = {
    PveManager.deleteEnvironments(testCuid)
  }

  /** All log lines currently in `queue`, newline-joined in queue order. */
  private def queueText(): String =
    queue.iterator().asScala.mkString("\n")

  "PveManager" should "create a new PVE and list it" in {
    PveManager.createNewPve(testCuid, queue, testPveName, isLocal = true)

    val logs = queueText()
    logs should not include "[PVE][ERR]"
    logs should include(s"[PVE] Created new environment for cuid = $testCuid")

    // The venv directory and its python/pip executables must all exist on disk.
    val pvePath = testRoot.resolve(testPveName).resolve("pve")
    val binDir = pvePath.resolve("bin")
    Seq(pvePath, binDir.resolve("python"), binDir.resolve("pip")).foreach { path =>
      Files.exists(path) shouldBe true
    }

    PveManager.getEnvironments(testCuid) should contain(testPveName)
  }

  it should "delete all PVEs for a computing unit" in {
    PveManager.createNewPve(testCuid, queue, testPveName, isLocal = true)
    Files.exists(testRoot.resolve(testPveName)) shouldBe true

    PveManager.deleteEnvironments(testCuid)

    // The whole per-cuid directory is gone and nothing is listed any more.
    Files.exists(testRoot) shouldBe false
    PveManager.getEnvironments(testCuid) shouldBe empty
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/service/ExecutionConsoleServiceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import com.google.protobuf.timestamp.Timestamp
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.{
ConsoleMessage,
ConsoleMessageType
}
import org.apache.texera.amber.engine.common.executionruntimestate.ExecutionConsoleStore
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import java.time.Instant
class ExecutionConsoleServiceSpec extends AnyFlatSpec with Matchers {
  // Buffer capacities and title display length shared by the tests below.
  val standardBufferSize: Int = 100
  val smallBufferSize: Int = 2
  val messageDisplayLength: Int = 100

  /** PRINT message from "worker1" (4th field "test") stamped with the current time. */
  private def printMessage(title: String, content: String): ConsoleMessage =
    new ConsoleMessage(
      "worker1",
      Timestamp(Instant.now),
      ConsoleMessageType.PRINT,
      "test",
      title,
      content
    )

  /** Feeds `messages` in order into `opId`'s console, threading the returned store through. */
  private def appendAll(
      store: ExecutionConsoleStore,
      opId: String,
      bufferSize: Int,
      messages: ConsoleMessage*
  ): ExecutionConsoleStore =
    messages.foldLeft(store) { (current, message) =>
      ConsoleMessageProcessor.addMessageToOperatorConsole(current, opId, message, bufferSize)
    }

  "processConsoleMessage" should "truncate message title when it exceeds display length" in {
    // Ten characters over the limit; expect (limit - 3) chars followed by "...".
    val longTitle = "a" * (messageDisplayLength + 10)
    val processed = ConsoleMessageProcessor.processConsoleMessage(
      printMessage(longTitle, "message content"),
      messageDisplayLength
    )
    processed.title shouldBe "a" * (messageDisplayLength - 3) + "..."
  }

  it should "not truncate message title when it does not exceed display length" in {
    val shortTitle = "Short Title"
    val processed = ConsoleMessageProcessor.processConsoleMessage(
      printMessage(shortTitle, "message content"),
      messageDisplayLength
    )
    processed.title shouldBe shortTitle
  }

  "addMessageToOperatorConsole" should "add message to buffer when buffer is not full" in {
    val opId = "op1"
    val first = printMessage("Message 1", "content 1")
    val second = printMessage("Message 2", "content 2")

    val store = appendAll(new ExecutionConsoleStore(), opId, standardBufferSize, first, second)

    // Both messages are kept, in insertion order.
    val buffered = store.operatorConsole(opId).consoleMessages
    buffered.size shouldBe 2
    buffered.head.title shouldBe "Message 1"
    buffered(1).title shouldBe "Message 2"
  }

  it should "remove oldest message when buffer is full" in {
    val opId = "op1"
    val first = printMessage("Message 1", "content 1")
    val second = printMessage("Message 2", "content 2")
    val third = printMessage("Message 3", "content 3")

    // Two messages fill the small buffer; the third must evict the oldest.
    val store = appendAll(new ExecutionConsoleStore(), opId, smallBufferSize, first, second, third)

    val buffered = store.operatorConsole(opId).consoleMessages
    buffered.size shouldBe 2
    buffered.head.title shouldBe "Message 2"
    buffered(1).title shouldBe "Message 3"
  }

  "the complete message processing flow" should "handle messages correctly" in {
    val opId = "op1"

    // Process (truncate) first, then store the processed message.
    val processed = ConsoleMessageProcessor.processConsoleMessage(
      printMessage("a" * (messageDisplayLength + 10), "message content"),
      messageDisplayLength
    )
    val store = appendAll(new ExecutionConsoleStore(), opId, standardBufferSize, processed)

    val buffered = store.operatorConsole(opId).consoleMessages
    buffered.size shouldBe 1
    buffered.head.title shouldBe "a" * (messageDisplayLength - 3) + "..."
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/service/ExecutionReconfigurationServiceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import org.apache.texera.amber.core.executor.OpExecWithClassName
import org.apache.texera.amber.core.virtualidentity.{
ActorVirtualIdentity,
ExecutionIdentity,
OperatorIdentity,
PhysicalOpIdentity,
WorkflowIdentity
}
import org.apache.texera.amber.core.workflow.PhysicalOp
import org.apache.texera.amber.engine.architecture.rpc.controlcommands.WorkflowReconfigureRequest
import org.apache.texera.web.storage.{ExecutionReconfigurationStore, ExecutionStateStore}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import scala.collection.mutable.ArrayBuffer
/**
 * Web-service-layer tests for ExecutionReconfigurationService.
 *
 * These tests drive performReconfigurationOnResume and onWorkerReconfigured without a running
 * engine. Dispatched requests are recorded instead of sent, and the registration hooks are
 * stubbed out. Covered here: the empty short-circuit, how the dispatched request is assembled,
 * and how the reconfiguration store is reset afterwards. The end-to-end engine path
 * (reconfigureWorkflow → Fries algorithm → UpdateExecutor on workers) is covered by
 * ReconfigurationSpec.
 */
class ExecutionReconfigurationServiceSpec extends AnyFlatSpec with Matchers {

  /** Minimal PhysicalOp whose id and exec-init class name are both derived from `name`. */
  private def mkPhysicalOp(name: String): PhysicalOp = {
    val opId = PhysicalOpIdentity(OperatorIdentity(name), "main")
    val execInit = OpExecWithClassName(s"$name.Class", "")
    PhysicalOp(
      id = opId,
      workflowId = WorkflowIdentity(0L),
      executionId = ExecutionIdentity(0L),
      opExecInitInfo = execInit
    )
  }

  /**
   * Service that appends each outgoing request to `dispatched` instead of sending it.
   * The AmberClient callback and the workflow-dependent diff handler are not registered,
   * so it can be built with a null client and workflow.
   */
  private class RecordingService(stateStore: ExecutionStateStore)
      extends ExecutionReconfigurationService(client = null, stateStore, workflow = null) {
    val dispatched: ArrayBuffer[WorkflowReconfigureRequest] = ArrayBuffer.empty
    override protected def dispatch(request: WorkflowReconfigureRequest): Unit =
      dispatched += request
    override protected def registerWorkerCompletionCallback(): Unit = ()
    override protected def registerCompletionDiffHandler(): Unit = ()
  }

  "performReconfigurationOnResume" should
    "return without dispatching when no reconfigurations are pending" in {
    val stateStore = new ExecutionStateStore()
    val service = new RecordingService(stateStore)

    noException should be thrownBy service.performReconfigurationOnResume()
    service.dispatched shouldBe empty

    // The store stays in its initial, empty state.
    val state = stateStore.reconfigurationStore.getState
    state.unscheduledReconfigurations shouldBe empty
    state.currentReconfigId shouldBe None
    state.completedReconfigurations shouldBe empty
  }

  it should "dispatch one request carrying every pending reconfiguration and reset the store" in {
    val stateStore = new ExecutionStateStore()
    val service = new RecordingService(stateStore)
    val first = mkPhysicalOp("op-1")
    val second = mkPhysicalOp("op-2")
    stateStore.reconfigurationStore.updateState(_ =>
      ExecutionReconfigurationStore(unscheduledReconfigurations = List((first, None), (second, None)))
    )

    service.performReconfigurationOnResume()

    // Exactly one request, covering both ops in queue order.
    service.dispatched should have size 1
    val request = service.dispatched.head
    request.reconfigurationId should not be empty
    val expectedOps = Seq(first, second)
    request.reconfiguration.map(_.targetOpId) should contain theSameElementsInOrderAs
      expectedOps.map(_.id)
    request.reconfiguration.map(_.newExecInitInfo) should contain theSameElementsInOrderAs
      expectedOps.map(_.opExecInitInfo)

    // The pending queue is drained and the request id becomes the current reconfiguration.
    val state = stateStore.reconfigurationStore.getState
    state.unscheduledReconfigurations shouldBe empty
    state.currentReconfigId shouldBe Some(request.reconfigurationId)
    state.completedReconfigurations shouldBe empty
  }

  it should "use a fresh reconfigurationId on each dispatch" in {
    val stateStore = new ExecutionStateStore()
    val service = new RecordingService(stateStore)

    // Queues one op, triggers a resume, and returns the id of the request just dispatched.
    def queueAndDispatch(opName: String): String = {
      stateStore.reconfigurationStore.updateState(previous =>
        previous.copy(unscheduledReconfigurations = List((mkPhysicalOp(opName), None)))
      )
      service.performReconfigurationOnResume()
      service.dispatched.last.reconfigurationId
    }

    val firstId = queueAndDispatch("op-a")
    val secondId = queueAndDispatch("op-b")

    firstId should not be secondId
    stateStore.reconfigurationStore.getState.currentReconfigId shouldBe Some(secondId)
  }

  "onWorkerReconfigured" should
    "add the worker id to completedReconfigurations so the diff handler can fire" in {
    val stateStore = new ExecutionStateStore()
    val service = new RecordingService(stateStore)
    val workers = Seq(
      ActorVirtualIdentity("Worker:WF1-E1-op-main-0"),
      ActorVirtualIdentity("Worker:WF1-E1-op-main-1")
    )

    workers.foreach(worker => service.onWorkerReconfigured(worker))
    // Reporting the first worker again must be idempotent (Set semantics).
    service.onWorkerReconfigured(workers.head)

    stateStore.reconfigurationStore.getState.completedReconfigurations should contain theSameElementsAs
      workers.toSet
  }
}
================================================
FILE: amber/src/test/scala/org/apache/texera/web/service/ExecutionResultServiceSpec.scala
================================================
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.web.service
import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
class ExecutionResultServiceSpec extends AnyFlatSpec with Matchers {
"convertTuplesToJson" should "convert tuples with various field types correctly" in {
// Create a schema with different attribute types
val attributes = List(
new Attribute("stringCol", AttributeType.STRING),
new Attribute("intCol", AttributeType.INTEGER),
new Attribute("boolCol", AttributeType.BOOLEAN),
new Attribute("nullCol", AttributeType.ANY),
new Attribute("longStringCol", AttributeType.STRING),
new Attribute("shortBinaryCol", AttributeType.BINARY),
new Attribute("longBinaryCol", AttributeType.BINARY)
)
val schema = new Schema(attributes)
// Create a string longer than maxStringLength (100)
val longString = "a" * 150
// Create binary data
val shortBinaryData = Array[Byte](1, 2, 3, 4, 5)
val longBinaryData = Array.tabulate[Byte](100)(_.toByte)
// Create a tuple with all the test data
val tuple = Tuple
.builder(schema)
.add("stringCol", AttributeType.STRING, "regular string")
.add("intCol", AttributeType.INTEGER, 42)
.add("boolCol", AttributeType.BOOLEAN, true)
.add("nullCol", AttributeType.ANY, null)
.add("longStringCol", AttributeType.STRING, longString)
.add("shortBinaryCol", AttributeType.BINARY, shortBinaryData)
.add("longBinaryCol", AttributeType.BINARY, longBinaryData)
.build()
// Convert to JSON
val result = ExecutionResultService.convertTuplesToJson(List(tuple))
// Verify the result
result should have size 1
val jsonNode = result.head
// Check regular values
jsonNode.get("stringCol").asText() shouldBe "regular string"
jsonNode.get("intCol").asInt() shouldBe 42
jsonNode.get("boolCol").asBoolean() shouldBe true
// Check NULL value
jsonNode.get("nullCol").asText() shouldBe "NULL"
// Check long string truncation
jsonNode.get("longStringCol").asText() should (
have length 103 and // 100 chars + "..."
startWith("a" * 100) and
endWith("...")
)
// Check short binary representation
val shortBinaryString = jsonNode.get("shortBinaryCol").asText()
shortBinaryString should (
startWith("