diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c5c1339b8..775324b5d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,7 @@ jobs: report_paths: "**/build/test-results/test/TEST-*.xml" unit_test_jdk8: - name: Unit test with docker service [JDK8] + name: Unit test with CLI runs-on: ubuntu-latest-16-cores timeout-minutes: 30 steps: @@ -82,9 +82,9 @@ jobs: - name: Set up Gradle uses: gradle/actions/setup-gradle@ac396bf1a80af16236baf54bd7330ae21dc6ece5 # v6 - - name: Start containerized server and dependencies + - name: Start CLI server env: - TEMPORAL_CLI_VERSION: 1.6.1-server-1.31.0-151.0 + TEMPORAL_CLI_VERSION: 1.7.0 run: | wget -O temporal_cli.tar.gz https://github.com/temporalio/cli/releases/download/v${TEMPORAL_CLI_VERSION}/temporal_cli_${TEMPORAL_CLI_VERSION}_linux_amd64.tar.gz tar -xzf temporal_cli.tar.gz diff --git a/README.md b/README.md index 5e24c5116..52be87a5b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +![Temporal Java SDK](https://raw.githubusercontent.com/temporalio/assets/main/files/w/java.png) + # Temporal Java SDK [![Build status](https://github.com/temporalio/sdk-java/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/temporalio/sdk-java/actions/workflows/ci.yml) [![Coverage Status](https://coveralls.io/repos/github/temporalio/sdk-java/badge.svg?branch=master)](https://coveralls.io/github/temporalio/sdk-java?branch=master) [Temporal](https://github.com/temporalio/temporal) is a Workflow-as-Code platform for building and operating @@ -95,4 +97,4 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file +limitations under the License. diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java index b01503956..1dceb67fb 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java @@ -37,6 +37,7 @@ public ActivityPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, double activitiesPerSecond, @Nonnull TrackingSlotSupplier slotSupplier, @@ -53,6 +54,7 @@ public ActivityPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java index 520ce7a37..d2fddde3f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java @@ -105,6 +105,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), taskQueueActivitiesPerSecond, this.slotSupplier, @@ -113,7 +114,7 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { @@ -125,6 +126,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), taskQueueActivitiesPerSecond, this.slotSupplier, @@ -133,7 +135,8 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java index 60ebcbf65..b23d16184 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java @@ -43,6 +43,7 @@ public AsyncActivityPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, double activitiesPerSecond, @Nonnull TrackingSlotSupplier slotSupplier, @@ -59,6 +60,7 @@ public AsyncActivityPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java index efc4dc807..1ba3b84d1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java @@ -41,6 +41,7 @@ public AsyncNexusPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, @@ -57,6 +58,8 @@ public AsyncNexusPollTask( .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); + if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequest.setDeploymentOptions( WorkerVersioningProtoUtils.deploymentOptionsToProto( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java index 510634379..7859484bb 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java @@ -29,7 +29,6 @@ final class AsyncPoller extends BasePoller { private final List> asyncTaskPollers; private final PollerOptions pollerOptions; private final PollerBehaviorAutoscaling pollerBehavior; - private final boolean serverSupportsAutoscaling; private final Scope workerMetricsScope; private Throttler pollRateThrottler; private final Thread.UncaughtExceptionHandler uncaughtExceptionHandler = @@ -43,7 +42,7 @@ final class AsyncPoller extends BasePoller { PollTaskAsync asyncTaskPoller, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - boolean serverSupportsAutoscaling, + NamespaceCapabilities namespaceCapabilities, Scope workerMetricsScope) { this( slotSupplier, @@ -51,7 +50,7 @@ final class AsyncPoller extends BasePoller { Collections.singletonList(asyncTaskPoller), taskExecutor, pollerOptions, - serverSupportsAutoscaling, + namespaceCapabilities, workerMetricsScope); } @@ -61,9 +60,9 @@ final class AsyncPoller extends BasePoller { List> asyncTaskPollers, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - boolean serverSupportsAutoscaling, + NamespaceCapabilities namespaceCapabilities, Scope workerMetricsScope) { - super(taskExecutor); + super(taskExecutor, namespaceCapabilities); Objects.requireNonNull(slotSupplier, "slotSupplier cannot be null"); Objects.requireNonNull(slotReservationData, "slotReservation data should not be null"); Objects.requireNonNull(asyncTaskPollers, "asyncTaskPollers should not be null"); @@ -82,7 +81,6 @@ final class AsyncPoller extends BasePoller { + " is not supported for AsyncPoller. Only PollerBehaviorAutoscaling is supported."); } this.pollerBehavior = (PollerBehaviorAutoscaling) pollerOptions.getPollerBehavior(); - this.serverSupportsAutoscaling = serverSupportsAutoscaling; this.pollerOptions = pollerOptions; this.workerMetricsScope = workerMetricsScope; } @@ -114,7 +112,7 @@ public boolean start() { pollerBehavior.getMinConcurrentTaskPollers(), pollerBehavior.getMaxConcurrentTaskPollers(), pollerBehavior.getInitialConcurrentTaskPollers(), - serverSupportsAutoscaling, + namespaceCapabilities.isPollerAutoscaling(), (newTarget) -> { log.debug( "Updating maximum number of pollers for {} to: {}", @@ -136,12 +134,14 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean return super.shutdown(shutdownManager, interruptTasks) .thenApply( (f) -> { - for (PollTaskAsync asyncTaskPoller : asyncTaskPollers) { - try { - log.debug("Shutting down async poller: {}", asyncTaskPoller.getLabel()); - asyncTaskPoller.cancel(new RuntimeException("Shutting down poller")); - } catch (Throwable e) { - log.error("Error while cancelling poll task", e); + if (!namespaceCapabilities.isGracefulPollShutdown()) { + for (PollTaskAsync asyncTaskPoller : asyncTaskPollers) { + try { + log.debug("Shutting down async poller: {}", asyncTaskPoller.getLabel()); + asyncTaskPoller.cancel(new RuntimeException("Shutting down poller")); + } catch (Throwable e) { + log.error("Error while cancelling poll task", e); + } } } return null; diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java index c30dbc9e1..3bfa796a3 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java @@ -52,6 +52,7 @@ public AsyncWorkflowPollTask( @Nonnull String taskQueue, @Nullable String stickyTaskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @@ -67,6 +68,8 @@ public AsyncWorkflowPollTask( .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); + pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); + if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( WorkerVersioningProtoUtils.deploymentOptionsToProto( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java index 9b8141fc0..febd6241a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java @@ -27,9 +27,14 @@ abstract class BasePoller implements SuspendableWorker { protected ExecutorService pollExecutor; - protected BasePoller(ShutdownableTaskExecutor taskExecutor) { + protected final NamespaceCapabilities namespaceCapabilities; + + protected BasePoller( + ShutdownableTaskExecutor taskExecutor, NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(taskExecutor, "taskExecutor should not be null"); this.taskExecutor = taskExecutor; + this.namespaceCapabilities = + Objects.requireNonNull(namespaceCapabilities, "namespaceCapabilities should not be null"); } @Override @@ -55,15 +60,24 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean return CompletableFuture.completedFuture(null); } - return shutdownManager - // it's ok to forcefully shutdown pollers, because they are stuck in a long poll call - // so we don't risk loosing any progress doing that. - .shutdownExecutorNow(pollExecutor, this + "#pollExecutor", Duration.ofSeconds(1)) - .exceptionally( - e -> { - log.error("Unexpected exception during shutdown", e); - return null; - }); + CompletableFuture pollExecutorShutdown; + if (namespaceCapabilities.isGracefulPollShutdown()) { + // When graceful poll shutdown is enabled, the server will complete outstanding polls with + // empty responses after ShutdownWorker is called. We simply wait for polls to return. + pollExecutorShutdown = + shutdownManager.shutdownExecutor( + pollExecutor, this + "#pollExecutor", Duration.ofSeconds(80)); + } else { + // Old behaviour forcibly stops outstanding polls. + pollExecutorShutdown = + shutdownManager.shutdownExecutorNow( + pollExecutor, this + "#pollExecutor", Duration.ofSeconds(1)); + } + return pollExecutorShutdown.exceptionally( + e -> { + log.error("Unexpected exception during shutdown", e); + return null; + }); } @Override diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java index 8dcaa6f33..7fe0335b1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java @@ -52,8 +52,9 @@ public MultiThreadedPoller( PollTask pollTask, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - Scope workerMetricsScope) { - super(taskExecutor); + Scope workerMetricsScope, + NamespaceCapabilities namespaceCapabilities) { + super(taskExecutor, namespaceCapabilities); Objects.requireNonNull(identity, "identity cannot be null"); Objects.requireNonNull(pollTask, "poll service should not be null"); Objects.requireNonNull(pollerOptions, "pollerOptions should not be null"); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java index 4fa9d09a5..a3410fa25 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java @@ -1,5 +1,6 @@ package io.temporal.internal.worker; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -9,14 +10,28 @@ */ public final class NamespaceCapabilities { private final AtomicBoolean pollerAutoscaling = new AtomicBoolean(false); + private final AtomicBoolean gracefulPollShutdown = new AtomicBoolean(false); private final AtomicBoolean workerHeartbeats = new AtomicBoolean(false); + public void setFromCapabilities(Capabilities capabilities) { + if (capabilities.getPollerAutoscaling()) { + pollerAutoscaling.set(true); + } + if (capabilities.getWorkerPollCompleteOnShutdown()) { + gracefulPollShutdown.set(true); + } + } + public boolean isPollerAutoscaling() { return pollerAutoscaling.get(); } - public void setPollerAutoscaling(boolean value) { - pollerAutoscaling.set(value); + public boolean isGracefulPollShutdown() { + return gracefulPollShutdown.get(); + } + + public void setGracefulPollShutdown(boolean value) { + gracefulPollShutdown.set(value); } public boolean isWorkerHeartbeats() { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java index 4116825b9..0ccab5944 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java @@ -34,6 +34,7 @@ public NexusPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @@ -49,6 +50,7 @@ public NexusPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequest.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java index d826e5543..ac364a747 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java @@ -111,6 +111,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), workerMetricsScope, service.getServerCapabilities(), @@ -118,7 +119,7 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { poller = @@ -129,6 +130,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), this.slotSupplier, workerMetricsScope, @@ -136,7 +138,8 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java index f8baba01d..559370772 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java @@ -40,6 +40,7 @@ public static final class Builder { private Duration drainStickyTaskQueueTimeout; private boolean usingVirtualThreads; private WorkerDeploymentOptions deploymentOptions; + private String workerInstanceKey; private Builder() {} @@ -64,6 +65,7 @@ private Builder(SingleWorkerOptions options) { this.drainStickyTaskQueueTimeout = options.getDrainStickyTaskQueueTimeout(); this.usingVirtualThreads = options.isUsingVirtualThreads(); this.deploymentOptions = options.getDeploymentOptions(); + this.workerInstanceKey = options.getWorkerInstanceKey(); } public Builder setIdentity(String identity) { @@ -155,6 +157,11 @@ public Builder setDeploymentOptions(WorkerDeploymentOptions deploymentOptions) { return this; } + public Builder setWorkerInstanceKey(String workerInstanceKey) { + this.workerInstanceKey = workerInstanceKey; + return this; + } + public SingleWorkerOptions build() { PollerOptions pollerOptions = this.pollerOptions; if (pollerOptions == null) { @@ -193,7 +200,8 @@ public SingleWorkerOptions build() { this.defaultHeartbeatThrottleInterval, drainStickyTaskQueueTimeout, usingVirtualThreads, - this.deploymentOptions); + this.deploymentOptions, + this.workerInstanceKey); } } @@ -214,6 +222,7 @@ public SingleWorkerOptions build() { private final Duration drainStickyTaskQueueTimeout; private final boolean usingVirtualThreads; private final WorkerDeploymentOptions deploymentOptions; + private final String workerInstanceKey; private SingleWorkerOptions( String identity, @@ -232,7 +241,8 @@ private SingleWorkerOptions( Duration defaultHeartbeatThrottleInterval, Duration drainStickyTaskQueueTimeout, boolean usingVirtualThreads, - WorkerDeploymentOptions deploymentOptions) { + WorkerDeploymentOptions deploymentOptions, + String workerInstanceKey) { this.identity = identity; this.binaryChecksum = binaryChecksum; this.buildId = buildId; @@ -250,6 +260,7 @@ private SingleWorkerOptions( this.drainStickyTaskQueueTimeout = drainStickyTaskQueueTimeout; this.usingVirtualThreads = usingVirtualThreads; this.deploymentOptions = deploymentOptions; + this.workerInstanceKey = workerInstanceKey; } public String getIdentity() { @@ -331,6 +342,10 @@ public WorkerDeploymentOptions getDeploymentOptions() { return deploymentOptions; } + public String getWorkerInstanceKey() { + return workerInstanceKey; + } + public WorkerVersioningOptions getWorkerVersioningOptions() { return new WorkerVersioningOptions( this.getBuildId(), this.isUsingBuildIdForVersioning(), this.getDeploymentOptions()); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java index 51ab7a700..18cf7fd4a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java @@ -3,9 +3,7 @@ import static io.temporal.internal.common.InternalUtils.createStickyTaskQueue; import io.temporal.api.common.v1.Payloads; -import io.temporal.api.enums.v1.TaskQueueType; import io.temporal.api.taskqueue.v1.TaskQueue; -import io.temporal.api.worker.v1.WorkerHeartbeat; import io.temporal.client.WorkflowClient; import io.temporal.common.converter.DataConverter; import io.temporal.common.converter.EncodedValues; @@ -24,11 +22,9 @@ import io.temporal.workflow.Functions.Func1; import java.lang.reflect.Type; import java.time.Duration; -import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.*; -import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -64,8 +60,6 @@ public SyncWorkflowWorker( @Nonnull WorkflowClient client, @Nonnull String namespace, @Nonnull String taskQueue, - @Nonnull String workerInstanceKey, - @Nonnull Supplier> activeTaskQueueTypesSupplier, @Nonnull SingleWorkerOptions singleWorkerOptions, @Nonnull SingleWorkerOptions localActivityOptions, @Nonnull WorkflowRunLockManager runLocks, @@ -123,8 +117,6 @@ public SyncWorkflowWorker( client.getWorkflowServiceStubs(), namespace, taskQueue, - workerInstanceKey, - activeTaskQueueTypesSupplier, stickyTaskQueueName, singleWorkerOptions, runLocks, @@ -250,10 +242,6 @@ public TrackingSlotSupplier getLocalActivitySlotSupplier( return laWorker.getSlotSupplier(); } - public void setHeartbeatSupplier(Supplier supplier) { - workflowWorker.setHeartbeatSupplier(supplier); - } - public boolean hasStickyQueue() { return workflowWorker.hasStickyQueue(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java index cdb5e5163..18607b5d1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java @@ -47,6 +47,7 @@ public WorkflowPollTask( @Nonnull String taskQueue, @Nullable String stickyTaskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull StickyQueueBalancer stickyQueueBalancer, @@ -73,6 +74,7 @@ public WorkflowPollTask( PollWorkflowTaskQueueRequest.newBuilder() .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); + pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java index f98316d5d..a128c7b75 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java @@ -13,11 +13,8 @@ import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.enums.v1.QueryResultType; import io.temporal.api.enums.v1.TaskQueueKind; -import io.temporal.api.enums.v1.TaskQueueType; -import io.temporal.api.enums.v1.WorkerStatus; import io.temporal.api.enums.v1.WorkflowTaskFailedCause; import io.temporal.api.failure.v1.Failure; -import io.temporal.api.worker.v1.WorkerHeartbeat; import io.temporal.api.workflowservice.v1.*; import io.temporal.failure.ApplicationFailure; import io.temporal.internal.logging.LoggerTag; @@ -33,7 +30,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -41,7 +37,6 @@ import org.slf4j.MDC; final class WorkflowWorker implements SuspendableWorker { - private static final String GRACEFUL_SHUTDOWN_MESSAGE = "graceful shutdown"; private static final Logger log = LoggerFactory.getLogger(WorkflowWorker.class); private final WorkflowRunLockManager runLocks; @@ -58,9 +53,6 @@ final class WorkflowWorker implements SuspendableWorker { private final GrpcRetryer grpcRetryer; private final EagerActivityDispatcher eagerActivityDispatcher; private final TrackingSlotSupplier slotSupplier; - private volatile Supplier heartbeatSupplier; - private final String workerInstanceKey; - private final Supplier> activeTaskQueueTypesSupplier; private final TaskCounter taskCounter = new TaskCounter(); private final PollerTracker pollerTracker = new PollerTracker(); @@ -79,8 +71,6 @@ public WorkflowWorker( @Nonnull WorkflowServiceStubs service, @Nonnull String namespace, @Nonnull String taskQueue, - @Nonnull String workerInstanceKey, - @Nonnull Supplier> activeTaskQueueTypesSupplier, @Nullable String stickyTaskQueueName, @Nonnull SingleWorkerOptions options, @Nonnull WorkflowRunLockManager runLocks, @@ -92,8 +82,6 @@ public WorkflowWorker( this.service = Objects.requireNonNull(service); this.namespace = Objects.requireNonNull(namespace); this.taskQueue = Objects.requireNonNull(taskQueue); - this.workerInstanceKey = Objects.requireNonNull(workerInstanceKey); - this.activeTaskQueueTypesSupplier = Objects.requireNonNull(activeTaskQueueTypesSupplier); this.options = Objects.requireNonNull(options); this.stickyTaskQueueName = stickyTaskQueueName; this.pollerOptions = getPollerOptions(options); @@ -133,6 +121,7 @@ public boolean start() { taskQueue, null, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -146,6 +135,7 @@ public boolean start() { taskQueue, stickyTaskQueueName, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -162,6 +152,7 @@ public boolean start() { taskQueue, null, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -175,7 +166,7 @@ public boolean start() { pollers, this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { PollerBehaviorSimpleMaximum pollerBehavior = @@ -193,6 +184,7 @@ public boolean start() { taskQueue, stickyTaskQueueName, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, stickyQueueBalancer, @@ -202,7 +194,8 @@ public boolean start() { stickyPollerTracker), pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); @@ -232,46 +225,23 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean stickyQueueBalancer, options.getDrainStickyTaskQueueTimeout()) : CompletableFuture.completedFuture(null)) .thenCompose(ignore -> poller.shutdown(shutdownManager, interruptTasks)); - return CompletableFuture.allOf( - pollerShutdown.thenCompose( - ignore -> { - ShutdownWorkerRequest.Builder shutdownReq = - ShutdownWorkerRequest.newBuilder() - .setIdentity(options.getIdentity()) - .setNamespace(namespace) - .setTaskQueue(taskQueue) - .setWorkerInstanceKey(workerInstanceKey) - .setReason(GRACEFUL_SHUTDOWN_MESSAGE) - .addAllTaskQueueTypes(activeTaskQueueTypesSupplier.get()); - if (stickyTaskQueueName != null) { - shutdownReq.setStickyTaskQueue(stickyTaskQueueName); - } - if (heartbeatSupplier != null) { - shutdownReq.setWorkerHeartbeat( - heartbeatSupplier.get().toBuilder() - .setStatus(WorkerStatus.WORKER_STATUS_SHUTTING_DOWN) - .build()); - } - return shutdownManager.waitOnWorkerShutdownRequest( - service.futureStub().shutdownWorker(shutdownReq.build())); - }), - pollerShutdown - .thenCompose( - ignore -> - !interruptTasks - ? shutdownManager.waitForSupplierPermitsReleasedUnlimited( - slotSupplier, supplierName) - : CompletableFuture.completedFuture(null)) - .thenCompose( - ignore -> - pollTaskExecutor != null - ? pollTaskExecutor.shutdown(shutdownManager, interruptTasks) - : CompletableFuture.completedFuture(null)) - .exceptionally( - e -> { - log.error("Unexpected exception during shutdown", e); - return null; - })); + return pollerShutdown + .thenCompose( + ignore -> + !interruptTasks + ? shutdownManager.waitForSupplierPermitsReleasedUnlimited( + slotSupplier, supplierName) + : CompletableFuture.completedFuture(null)) + .thenCompose( + ignore -> + pollTaskExecutor != null + ? pollTaskExecutor.shutdown(shutdownManager, interruptTasks) + : CompletableFuture.completedFuture(null)) + .exceptionally( + e -> { + log.error("Unexpected exception during shutdown", e); + return null; + }); } @Override @@ -363,10 +333,6 @@ public WorkflowTaskDispatchHandle reserveWorkflowExecutor() { .orElse(null); } - public void setHeartbeatSupplier(Supplier supplier) { - this.heartbeatSupplier = supplier; - } - public TrackingSlotSupplier getSlotSupplier() { return slotSupplier; } diff --git a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java index ce599c6ad..19a52c65e 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java @@ -13,6 +13,7 @@ import io.temporal.api.worker.v1.WorkerHostInfo; import io.temporal.api.worker.v1.WorkerPollerInfo; import io.temporal.api.worker.v1.WorkerSlotsInfo; +import io.temporal.api.workflowservice.v1.ShutdownWorkerRequest; import io.temporal.client.WorkflowClient; import io.temporal.client.WorkflowClientOptions; import io.temporal.common.Experimental; @@ -27,6 +28,7 @@ import io.temporal.internal.worker.TaskCounter; import io.temporal.serviceclient.MetricsTag; import io.temporal.serviceclient.Version; +import io.temporal.serviceclient.WorkflowServiceStubs; import io.temporal.worker.tuning.*; import io.temporal.workflow.Functions; import io.temporal.workflow.Functions.Func; @@ -59,17 +61,22 @@ public final class Worker { private static final Logger log = LoggerFactory.getLogger(Worker.class); private final WorkerOptions options; private final String taskQueue; + private final String workerInstanceKey = UUID.randomUUID().toString(); private final List plugins; + private final WorkflowServiceStubs service; + private final String namespace; + private final String identity; + private final String stickyTaskQueueName; final SyncWorkflowWorker workflowWorker; final SyncActivityWorker activityWorker; final SyncNexusWorker nexusWorker; private final AtomicBoolean started = new AtomicBoolean(); private volatile boolean shuttingDown = false; - private final String workerInstanceKey = UUID.randomUUID().toString(); private volatile Instant startTime; private final WorkflowClientOptions clientOptions; private final @Nonnull WorkflowExecutorCache cache; private final Map previousHeartbeatSnapshots = new ConcurrentHashMap<>(); + private volatile Supplier heartbeatSupplier; private static final class TaskSnapshot { final int processed; @@ -106,22 +113,30 @@ private static final class TaskSnapshot { @Nonnull NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(client, "client should not be null"); + Objects.requireNonNull(namespaceCapabilities, "namespaceCapabilities should not be null"); this.plugins = Objects.requireNonNull(plugins, "plugins should not be null"); Preconditions.checkArgument( !Strings.isNullOrEmpty(taskQueue), "taskQueue should not be an empty string"); this.taskQueue = taskQueue; + this.service = client.getWorkflowServiceStubs(); this.options = WorkerOptions.newBuilder(options).validateAndBuildWithDefaults(); this.clientOptions = client.getOptions(); this.cache = cache; factoryOptions = WorkerFactoryOptions.newBuilder(factoryOptions).validateAndBuildWithDefaults(); WorkflowClientOptions clientOptions = client.getOptions(); String namespace = clientOptions.getNamespace(); + this.namespace = namespace; Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); Scope taggedScope = metricsScope.tagged(tags); SingleWorkerOptions activityOptions = toActivityOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); if (this.options.isLocalActivityWorkerOnly()) { activityWorker = null; } else { @@ -149,7 +164,12 @@ private static final class TaskSnapshot { SingleWorkerOptions nexusOptions = toNexusOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); SlotSupplier nexusSlotSupplier = this.options.getWorkerTuner() == null ? new FixedSizeSlotSupplier<>(this.options.getMaxConcurrentNexusExecutionSize()) @@ -167,10 +187,16 @@ private static final class TaskSnapshot { clientOptions, taskQueue, contextPropagators, - taggedScope); + taggedScope, + workerInstanceKey); SingleWorkerOptions localActivityOptions = toLocalActivityOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); SlotSupplier workflowSlotSupplier = this.options.getWorkerTuner() == null @@ -183,18 +209,20 @@ private static final class TaskSnapshot { : this.options.getWorkerTuner().getLocalActivitySlotSupplier(); attachMetricsToResourceController(taggedScope, localActivitySlotSupplier); + this.identity = singleWorkerOptions.getIdentity(); + this.stickyTaskQueueName = + useStickyTaskQueue ? getStickyTaskQueueName(client.getOptions().getIdentity()) : null; + workflowWorker = new SyncWorkflowWorker( client, namespace, taskQueue, - workerInstanceKey, - this::getActiveTaskQueueTypes, singleWorkerOptions, localActivityOptions, runLocks, cache, - useStickyTaskQueue ? getStickyTaskQueueName(client.getOptions().getIdentity()) : null, + stickyTaskQueueName, workflowThreadExecutor, eagerActivityDispatcher, workflowSlotSupplier, @@ -454,19 +482,48 @@ void start() { } CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptUserTasks) { - shuttingDown = true; - CompletableFuture workflowWorkerShutdownFuture = - workflowWorker.shutdown(shutdownManager, interruptUserTasks); - CompletableFuture nexusWorkerShutdownFuture = - nexusWorker.shutdown(shutdownManager, interruptUserTasks); - if (activityWorker != null) { - return CompletableFuture.allOf( - activityWorker.shutdown(shutdownManager, interruptUserTasks), - workflowWorkerShutdownFuture, - nexusWorkerShutdownFuture); - } else { - return CompletableFuture.allOf(workflowWorkerShutdownFuture, nexusWorkerShutdownFuture); + ShutdownWorkerRequest.Builder requestBuilder = + ShutdownWorkerRequest.newBuilder() + .setNamespace(namespace) + .setIdentity(identity) + .setWorkerInstanceKey(workerInstanceKey) + .setTaskQueue(taskQueue) + .setReason("graceful shutdown") + .addAllTaskQueueTypes(getActiveTaskQueueTypes()); + if (stickyTaskQueueName != null) { + requestBuilder.setStickyTaskQueue(stickyTaskQueueName); + } + if (heartbeatSupplier != null) { + requestBuilder.setWorkerHeartbeat( + heartbeatSupplier.get().toBuilder() + .setStatus(WorkerStatus.WORKER_STATUS_SHUTTING_DOWN) + .build()); } + CompletableFuture shutdownWorkerRpc = + shutdownManager.waitOnWorkerShutdownRequest( + service.futureStub().shutdownWorker(requestBuilder.build())); + + // When interrupting tasks (shutdownNow), fire the RPC but don't block on it — proceed to + // shut down pollers immediately. For graceful shutdown, wait for the RPC so the server can + // complete outstanding polls with empty responses before we start waiting on them. + CompletableFuture preShutdown = + interruptUserTasks ? CompletableFuture.completedFuture(null) : shutdownWorkerRpc; + + return preShutdown.thenCompose( + ignore -> { + CompletableFuture workflowWorkerShutdownFuture = + workflowWorker.shutdown(shutdownManager, interruptUserTasks); + CompletableFuture nexusWorkerShutdownFuture = + nexusWorker.shutdown(shutdownManager, interruptUserTasks); + if (activityWorker != null) { + return CompletableFuture.allOf( + activityWorker.shutdown(shutdownManager, interruptUserTasks), + workflowWorkerShutdownFuture, + nexusWorkerShutdownFuture); + } else { + return CompletableFuture.allOf(workflowWorkerShutdownFuture, nexusWorkerShutdownFuture); + } + }); } boolean isTerminated() { @@ -491,6 +548,10 @@ String getWorkerInstanceKey() { return workerInstanceKey; } + void setHeartbeatSupplier(Supplier supplier) { + this.heartbeatSupplier = supplier; + } + List getActiveTaskQueueTypes() { List types = new ArrayList<>(); if (workflowWorker.isAnyTypeSupported()) { @@ -826,8 +887,10 @@ private static SingleWorkerOptions toActivityOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setUsingVirtualThreads(options.isUsingVirtualThreadsOnActivityWorker()) .setPollerOptions( PollerOptions.newBuilder() @@ -848,8 +911,10 @@ private static SingleWorkerOptions toNexusOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -870,7 +935,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( WorkflowClientOptions clientOptions, String taskQueue, List contextPropagators, - Scope metricsScope) { + Scope metricsScope, + String workerInstanceKey) { Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); @@ -899,7 +965,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( } } - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -921,8 +988,10 @@ private static SingleWorkerOptions toLocalActivityOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -939,7 +1008,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( WorkerFactoryOptions factoryOptions, WorkerOptions options, WorkflowClientOptions clientOptions, - List contextPropagators) { + List contextPropagators, + String workerInstanceKey) { String buildId = null; if (options.getBuildId() != null) { buildId = options.getBuildId(); @@ -962,7 +1032,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( .setWorkerInterceptors(factoryOptions.getWorkerInterceptors()) .setMaxHeartbeatThrottleInterval(options.getMaxHeartbeatThrottleInterval()) .setDefaultHeartbeatThrottleInterval(options.getDefaultHeartbeatThrottleInterval()) - .setDeploymentOptions(options.getDeploymentOptions()); + .setDeploymentOptions(options.getDeploymentOptions()) + .setWorkerInstanceKey(workerInstanceKey); } /** diff --git a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java index 741990624..c9bb8eb21 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java @@ -268,17 +268,8 @@ public synchronized void start() { DescribeNamespaceRequest.newBuilder() .setNamespace(workflowClient.getOptions().getNamespace()) .build()); - if (describeNamespaceResponse.getNamespaceInfo().getCapabilities().getWorkerHeartbeats()) { - namespaceCapabilities.setWorkerHeartbeats(true); - } else { - log.debug( - "Server does not support worker heartbeats for namespace {}", - workflowClient.getOptions().getNamespace()); - } - - if (describeNamespaceResponse.getNamespaceInfo().getCapabilities().getPollerAutoscaling()) { - namespaceCapabilities.setPollerAutoscaling(true); - } + namespaceCapabilities.setFromCapabilities( + describeNamespaceResponse.getNamespaceInfo().getCapabilities()); // Build plugin execution chain (reverse order for proper nesting) Consumer startChain = WorkerFactory::doStart; @@ -321,7 +312,7 @@ private void doStart() { Supplier heartbeatSupplier = worker.buildHeartbeatCallback(workerGroupingKey); hbManager.registerWorker(namespace, worker.getWorkerInstanceKey(), heartbeatSupplier); - worker.workflowWorker.setHeartbeatSupplier(heartbeatSupplier); + worker.setHeartbeatSupplier(heartbeatSupplier); } } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java index 2ade97762..5faa34ca7 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java @@ -133,7 +133,7 @@ private AsyncPoller newPoller( pollTask, taskExecutor, options, - false, + new NamespaceCapabilities(), new NoopScope()); } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java new file mode 100644 index 000000000..ef0e93495 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java @@ -0,0 +1,150 @@ +package io.temporal.internal.worker; + +import static org.junit.Assert.*; + +import com.uber.m3.tally.NoopScope; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; +import io.temporal.worker.tuning.PollerBehaviorSimpleMaximum; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nonnull; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Tests that an in-flight poll survives shutdown when graceful poll shutdown is enabled, and is + * killed promptly when it is not. + */ +@RunWith(Parameterized.class) +public class GracefulPollShutdownTest { + + @Parameterized.Parameter public boolean graceful; + + @Parameterized.Parameters(name = "graceful={0}") + public static Object[] data() { + return new Object[] {true, false}; + } + + @Test(timeout = 10_000) + public void inflightPollSurvivesShutdownOnlyWhenGraceful() throws Exception { + NamespaceCapabilities capabilities = new NamespaceCapabilities(); + capabilities.setFromCapabilities( + Capabilities.newBuilder().setWorkerPollCompleteOnShutdown(graceful).build()); + + AtomicReference processedTask = new AtomicReference<>(); + CountDownLatch taskProcessedLatch = new CountDownLatch(1); + ShutdownableTaskExecutor taskExecutor = + new ShutdownableTaskExecutor() { + @Override + public void process(@Nonnull String task) { + processedTask.set(task); + taskProcessedLatch.countDown(); + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public CompletableFuture shutdown( + ShutdownManager shutdownManager, boolean interruptTasks) { + return CompletableFuture.completedFuture(null); + } + + @Override + public void awaitTermination(long timeout, TimeUnit unit) {} + }; + + // -- poll task: first call returns immediately, second blocks until released -- + CountDownLatch secondPollStarted = new CountDownLatch(1); + CountDownLatch releasePoll = new CountDownLatch(1); + + MultiThreadedPoller.PollTask pollTask = + new MultiThreadedPoller.PollTask() { + private int callCount = 0; + + @Override + public synchronized String poll() { + callCount++; + if (callCount == 1) { + return "task-1"; + } else if (callCount == 2) { + secondPollStarted.countDown(); + try { + releasePoll.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + return "task-2"; + } + // Subsequent calls just block until interrupted (simulates long poll) + try { + Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return null; + } + }; + + // -- create poller with 1 thread so polls are sequential -- + MultiThreadedPoller poller = + new MultiThreadedPoller<>( + "test-identity", + pollTask, + taskExecutor, + PollerOptions.newBuilder() + .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) + .setPollThreadNamePrefix("test-poller") + .build(), + new NoopScope(), + capabilities); + + poller.start(); + + // Wait for the first task to be processed (proves poller is running) + assertTrue("first task should be processed", taskProcessedLatch.await(5, TimeUnit.SECONDS)); + assertEquals("task-1", processedTask.get()); + + // Wait for the second poll to be in-flight + assertTrue("second poll should start", secondPollStarted.await(5, TimeUnit.SECONDS)); + + // Trigger shutdown (don't interrupt tasks) + ShutdownManager shutdownManager = new ShutdownManager(); + CompletableFuture shutdownFuture = poller.shutdown(shutdownManager, false); + + if (graceful) { + // In graceful mode the poller waits for the in-flight poll to complete. + // The shutdown should NOT have completed yet since the poll is still blocked. + assertFalse("shutdown should not complete while poll is in-flight", shutdownFuture.isDone()); + + // Simulate the server returning the poll response (as it would after ShutdownWorker RPC) + releasePoll.countDown(); + + // Wait for shutdown to complete - the poll should return "task-2" and be processed + shutdownFuture.get(5, TimeUnit.SECONDS); + + assertEquals("task-2", processedTask.get()); + } else { + // In legacy mode the poller forcefully interrupts in-flight polls. + // Shutdown should complete quickly without releasing the blocked poll. + shutdownFuture.get(5, TimeUnit.SECONDS); + + // The second task should NOT have been processed since the poll was killed. + assertNotEquals( + "task-2 should not be processed in legacy mode", "task-2", processedTask.get()); + } + + shutdownManager.close(); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java index e4223c0b5..c6f11a61a 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java @@ -80,6 +80,7 @@ public void supplierIsCalledAppropriately() { TASK_QUEUE, "stickytaskqueue", "", + "test-instance-key", new WorkerVersioningOptions("", false, null), trackingSS, stickyQueueBalancer, @@ -172,6 +173,7 @@ public void asyncPollerSupplierIsCalledAppropriately() throws Exception { TASK_QUEUE, null, "", + "test-instance-key", new WorkerVersioningOptions("", false, null), trackingSS, metricsScope, diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java index 59538ac8b..ab806c960 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java @@ -68,6 +68,7 @@ public void stickyQueueBacklogResetTest() { "taskqueue", "stickytaskqueue", "", + "test-instance-key", new WorkerVersioningOptions("", false, null), slotSupplier, stickyQueueBalancer, @@ -97,6 +98,7 @@ public void stickyQueueBacklogResetTest() { .setKind(TaskQueueKind.TASK_QUEUE_KIND_STICKY) .build()) .setNamespace("default") + .setWorkerInstanceKey("test-instance-key") .build()))) .thenReturn(pollResponse); if (throwOnPoll) { diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java index d4e6e947c..d4f1824c2 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java @@ -14,7 +14,6 @@ import com.uber.m3.util.ImmutableMap; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.common.v1.WorkflowType; -import io.temporal.api.enums.v1.TaskQueueType; import io.temporal.api.workflowservice.v1.*; import io.temporal.common.reporter.TestStatsReporter; import io.temporal.internal.common.InternalUtils; @@ -30,12 +29,8 @@ import io.temporal.worker.tuning.SlotSupplier; import io.temporal.worker.tuning.WorkflowSlotInfo; import java.time.Duration; -import java.util.Arrays; -import java.util.List; import java.util.UUID; import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; import org.junit.Test; import org.mockito.stubbing.Answer; import org.slf4j.Logger; @@ -74,12 +69,11 @@ public void concurrentPollRequestLockTest() throws Exception { client, "default", "task_queue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky_task_queue", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(3)) @@ -246,12 +240,11 @@ public void respondWorkflowTaskFailureMetricTest() throws Exception { client, "default", "task_queue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky_task_queue", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -391,12 +384,11 @@ public boolean isAnyTypeSupported() { client, "default", "taskQueue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -444,80 +436,6 @@ public boolean isAnyTypeSupported() { worker.shutdown(new ShutdownManager(), true).get(); } - @Test - public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { - WorkflowServiceStubs client = mock(WorkflowServiceStubs.class); - when(client.getServerCapabilities()) - .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); - - WorkflowRunLockManager runLockManager = new WorkflowRunLockManager(); - Scope metricsScope = new NoopScope(); - WorkflowExecutorCache cache = new WorkflowExecutorCache(10, runLockManager, metricsScope); - SlotSupplier slotSupplier = new FixedSizeSlotSupplier<>(10); - - WorkflowTaskHandler taskHandler = mock(WorkflowTaskHandler.class); - when(taskHandler.isAnyTypeSupported()).thenReturn(true); - - // Supplier that starts with WORKFLOW only, then adds NEXUS later - AtomicReference> typesRef = - new AtomicReference<>(Arrays.asList(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); - Supplier> supplier = typesRef::get; - - EagerActivityDispatcher eagerActivityDispatcher = mock(EagerActivityDispatcher.class); - WorkflowWorker worker = - new WorkflowWorker( - client, - "default", - "task_queue", - "test-worker-instance-key", - supplier, - null, - SingleWorkerOptions.newBuilder() - .setIdentity("test_identity") - .setBuildId(UUID.randomUUID().toString()) - .setPollerOptions( - PollerOptions.newBuilder() - .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) - .build()) - .setMetricsScope(metricsScope) - .build(), - runLockManager, - cache, - taskHandler, - eagerActivityDispatcher, - slotSupplier, - new NamespaceCapabilities()); - - // Simulate registering Nexus after construction - typesRef.set( - Arrays.asList( - TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW, - TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY, - TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); - - WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = - mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); - when(client.futureStub()).thenReturn(futureStub); - when(futureStub.shutdownWorker(any(ShutdownWorkerRequest.class))) - .thenReturn(Futures.immediateFuture(ShutdownWorkerResponse.newBuilder().build())); - - worker.shutdown(new ShutdownManager(), true).get(5, TimeUnit.SECONDS); - - org.mockito.ArgumentCaptor captor = - org.mockito.ArgumentCaptor.forClass(ShutdownWorkerRequest.class); - verify(futureStub).shutdownWorker(captor.capture()); - List shutdownTypes = captor.getValue().getTaskQueueTypesList(); - assertTrue( - "ShutdownWorkerRequest should include NEXUS type added after construction", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); - assertTrue( - "ShutdownWorkerRequest should include WORKFLOW type", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); - assertTrue( - "ShutdownWorkerRequest should include ACTIVITY type", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY)); - } - private ReplayWorkflowFactory setUpMockWorkflowFactory() throws Throwable { ReplayWorkflow mockWorkflow = mock(ReplayWorkflow.class); ReplayWorkflowFactory mockFactory = mock(ReplayWorkflowFactory.class); diff --git a/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java new file mode 100644 index 000000000..d48e39725 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java @@ -0,0 +1,141 @@ +package io.temporal.worker; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.google.common.util.concurrent.Futures; +import com.uber.m3.tally.NoopScope; +import com.uber.m3.tally.Scope; +import io.nexusrpc.handler.OperationHandler; +import io.nexusrpc.handler.OperationImpl; +import io.nexusrpc.handler.ServiceImpl; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.api.enums.v1.TaskQueueType; +import io.temporal.api.workflowservice.v1.GetSystemInfoResponse; +import io.temporal.api.workflowservice.v1.ShutdownWorkerRequest; +import io.temporal.api.workflowservice.v1.ShutdownWorkerResponse; +import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.internal.sync.WorkflowThreadExecutor; +import io.temporal.internal.worker.NamespaceCapabilities; +import io.temporal.internal.worker.ShutdownManager; +import io.temporal.internal.worker.WorkflowExecutorCache; +import io.temporal.internal.worker.WorkflowRunLockManager; +import io.temporal.serviceclient.WorkflowServiceStubs; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import io.temporal.workflow.shared.TestNexusServices; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +public class WorkerShutdownTest { + + @WorkflowInterface + public interface TestWorkflow { + @WorkflowMethod + void run(); + } + + public static class TestWorkflowImpl implements TestWorkflow { + @Override + public void run() {} + } + + @ActivityInterface + public interface TestActivity { + @ActivityMethod + void doThing(); + } + + public static class TestActivityImpl implements TestActivity { + @Override + public void doThing() {} + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return OperationHandler.sync((ctx, details, now) -> "Hello " + now); + } + } + + /** + * Verifies that the active task queue types in the ShutdownWorkerRequest are evaluated at + * shutdown time, not at Worker construction time. Types registered after construction must be + * reflected in the request. + */ + @Test + public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { + WorkflowServiceStubs service = mock(WorkflowServiceStubs.class); + when(service.getServerCapabilities()) + .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); + + WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = + mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); + when(service.futureStub()).thenReturn(futureStub); + when(futureStub.shutdownWorker(any(ShutdownWorkerRequest.class))) + .thenReturn(Futures.immediateFuture(ShutdownWorkerResponse.newBuilder().build())); + + WorkflowServiceGrpc.WorkflowServiceBlockingStub blockingStub = + mock(WorkflowServiceGrpc.WorkflowServiceBlockingStub.class); + when(service.blockingStub()).thenReturn(blockingStub); + when(blockingStub.withOption(any(), any())).thenReturn(blockingStub); + + WorkflowClient client = mock(WorkflowClient.class); + when(client.getWorkflowServiceStubs()).thenReturn(service); + when(client.getOptions()) + .thenReturn( + WorkflowClientOptions.newBuilder() + .setNamespace("test-ns") + .validateAndBuildWithDefaults()); + + Scope metricsScope = new NoopScope(); + WorkflowRunLockManager runLocks = new WorkflowRunLockManager(); + WorkflowExecutorCache cache = new WorkflowExecutorCache(10, runLocks, metricsScope); + WorkflowThreadExecutor wfThreadExecutor = mock(WorkflowThreadExecutor.class); + + Worker worker = + new Worker( + client, + "test-task-queue", + WorkerFactoryOptions.newBuilder().build(), + WorkerOptions.newBuilder().build(), + metricsScope, + runLocks, + cache, + false, + wfThreadExecutor, + Collections.emptyList(), + Collections.emptyList(), + new NamespaceCapabilities()); + + // Register types AFTER worker construction. The request built by shutdown should reflect + // these registrations, proving that getActiveTaskQueueTypes() is evaluated lazily. + worker.registerWorkflowImplementationTypes(TestWorkflowImpl.class); + worker.registerActivitiesImplementations(new TestActivityImpl()); + worker.registerNexusServiceImplementation(new TestNexusServiceImpl()); + + worker.shutdown(new ShutdownManager(), true).get(5, TimeUnit.SECONDS); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(ShutdownWorkerRequest.class); + verify(futureStub).shutdownWorker(captor.capture()); + List shutdownTypes = captor.getValue().getTaskQueueTypesList(); + assertTrue( + "ShutdownWorkerRequest should include WORKFLOW type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); + assertTrue( + "ShutdownWorkerRequest should include ACTIVITY type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY)); + assertTrue( + "ShutdownWorkerRequest should include NEXUS type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); + } +}