diff --git a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionImpl.java b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionImpl.java index a2ef4f619821..f79142171f16 100644 --- a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionImpl.java +++ b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionImpl.java @@ -324,7 +324,7 @@ VRpc newCall(VRpcDescriptor descriptor) { long rpcId = nextRpcId; nextRpcId = Math.incrementExact(nextRpcId); - return new VRpcImpl<>(this, descriptor, rpcId, stream.getPeerInfo(), debugTagTracer); + return new VRpcImpl<>(this, descriptor, rpcId, stream.getPeerInfo()); } } diff --git a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionPoolImpl.java b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionPoolImpl.java index 35884cb74319..d06065946aa0 100644 --- a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionPoolImpl.java +++ b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/SessionPoolImpl.java @@ -634,6 +634,9 @@ public void start(ReqT req, VRpcCallContext ctx, VRpcListener listener) { this.deadlineMonitor = monitorDeadline(executorService, ctx.getOperationInfo().getDeadline()); synchronized (SessionPoolImpl.this) { + if (isCancelled) { + return; + } if (SessionPoolImpl.this.poolState != PoolState.STARTED) { listener.onClose( VRpcResult.createUncommitedError( diff --git a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/VRpcImpl.java b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/VRpcImpl.java index a2d841bd1a68..82760001c06c 100644 --- a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/VRpcImpl.java +++ b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/internal/session/VRpcImpl.java @@ -20,14 +20,13 @@ import com.google.bigtable.v2.VirtualRpcRequest; import com.google.bigtable.v2.VirtualRpcRequest.Metadata; import com.google.bigtable.v2.VirtualRpcResponse; -import com.google.cloud.bigtable.data.v2.internal.csm.tracers.DebugTagTracer; import com.google.cloud.bigtable.data.v2.internal.middleware.VRpc; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.Message; import com.google.protobuf.MessageLite; import com.google.protobuf.util.Durations; import io.grpc.Status; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -61,7 +60,9 @@ interface VRpcSessionApi { private enum State { NEW, + STARTING, STARTED, + CANCELLED, CLOSED } @@ -71,91 +72,118 @@ private enum State { private VRpcListener listener; private PeerInfo peerInfo; - private AtomicReference state; + private final Object lock = new Object(); - private final DebugTagTracer debugTagTracer; + @GuardedBy("lock") + private State state = State.NEW; + + @GuardedBy("lock") + private Status cancelStatus = null; public VRpcImpl( VRpcSessionApi session, VRpcDescriptor desc, long rpcId, - PeerInfo peerInfo, - DebugTagTracer debugTagTracer) { + PeerInfo peerInfo) { this.session = session; this.desc = desc; this.rpcId = rpcId; - this.state = new AtomicReference<>(State.NEW); this.peerInfo = peerInfo; - this.debugTagTracer = debugTagTracer; } @Override public void start(ReqT req, VRpcCallContext ctx, VRpcListener listener) { this.listener = listener; - Status status; - boolean retryable = true; - - if (!state.compareAndSet(State.NEW, State.STARTED)) { - status = Status.INTERNAL.withDescription("VRpc already started in state: " + state.get()); - retryable = false; - } else if (ctx.getOperationInfo().getDeadline().timeRemaining(TimeUnit.MICROSECONDS) - < TimeUnit.MILLISECONDS.toMicros(1)) { - // Don't send RPCs that don't have any hope of succeeding - status = - Status.DEADLINE_EXCEEDED.withDescription("Remaining deadline is too short to send RPC"); - retryable = false; - } else { - Metadata vRpcMetadata = - Metadata.newBuilder() - .setAttemptNumber(ctx.getOperationInfo().getAttemptNumber()) - .setTraceparent(ctx.getTraceParent()) - .build(); - ctx.getTracer().onRequestSent(peerInfo); - status = - session.startRpc( - this, - VirtualRpcRequest.newBuilder() - .setRpcId(rpcId) - .setMetadata(vRpcMetadata) - .setDeadline( - Durations.fromNanos( - ctx.getOperationInfo().getDeadline().timeRemaining(TimeUnit.NANOSECONDS))) - .setPayload(desc.encode(req)) - .build()); - // if status is not OK, the session might not be ready and the vRPC can be retried on a - // different session - } + Status status = null; + VirtualRpcRequest request = null; - if (!status.isOk()) { - debugTagTracer.checkPrecondition( - state.compareAndSet(State.STARTED, State.CLOSED), - "vrpc_incorrect_start_state", - "VRpc has incorrect state. Expected to be started but was %s", - state); - // TODO: loop through the session executor - if (retryable) { - listener.onClose(VRpcResult.createUncommitedError(status)); + synchronized (lock) { + if (state == State.CANCELLED) { + state = State.CLOSED; + status = cancelStatus; + } else if (state != State.NEW) { + status = Status.INTERNAL.withDescription("VRpc already started in state: " + state); + } else if (ctx.getOperationInfo().getDeadline().timeRemaining(TimeUnit.MICROSECONDS) + < TimeUnit.MILLISECONDS.toMicros(1)) { + // Don't send RPCs that don't have any hope of succeeding + state = State.CLOSED; + status = + Status.DEADLINE_EXCEEDED.withDescription("Remaining deadline is too short to send RPC"); } else { - listener.onClose(VRpcResult.createRejectedError(status)); + state = State.STARTING; + Metadata vRpcMetadata = + Metadata.newBuilder() + .setAttemptNumber(ctx.getOperationInfo().getAttemptNumber()) + .setTraceparent(ctx.getTraceParent()) + .build(); + ctx.getTracer().onRequestSent(peerInfo); + request = + VirtualRpcRequest.newBuilder() + .setRpcId(rpcId) + .setMetadata(vRpcMetadata) + .setDeadline( + Durations.fromNanos( + ctx.getOperationInfo().getDeadline().timeRemaining(TimeUnit.NANOSECONDS))) + .setPayload(desc.encode(req)) + .build(); } } + + if (status != null) { + listener.onClose(VRpcResult.createRejectedError(status)); + return; + } + + Status startRpcStatus = session.startRpc(this, request); + + if (startRpcStatus.isOk()) { + boolean shouldCancelSession = false; + Status localCancel = null; + synchronized (lock) { + if (state == State.STARTING) { + state = State.STARTED; + } else if (state == State.CANCELLED) { + shouldCancelSession = true; + localCancel = cancelStatus; + } + } + if (shouldCancelSession) { + session.cancelRpc(rpcId, localCancel.getDescription(), localCancel.getCause()); + } + } else { + VRpcResult result; + synchronized (lock) { + if (state == State.CANCELLED) { + result = VRpcResult.createRejectedError(cancelStatus); + } else { + result = VRpcResult.createUncommitedError(startRpcStatus); + } + state = State.CLOSED; + } + listener.onClose(result); + } } void handleSessionClose(VRpcResult result) { - if (!state.compareAndSet(State.STARTED, State.CLOSED)) { - logger.warning("tried to close a vRPC after it was already closed state: " + state.get()); - return; + synchronized (lock) { + if (state == State.CLOSED) { + return; + } + state = State.CLOSED; } listener.onClose(result); } void handleResponse(VirtualRpcResponse response) { - if (!state.compareAndSet(State.STARTED, State.CLOSED)) { - // This can happen if the call was cancelled just before the response arrived. - // Silently ignore it. - return; + synchronized (lock) { + if (state != State.STARTED && state != State.STARTING) { + // This can happen if the call was cancelled just before the response arrived. + // Silently ignore it. + return; + } + state = State.CLOSED; } // TODO: handle streaming @@ -186,8 +214,11 @@ void handleResponse(VirtualRpcResponse response) { } void handleError(VRpcResult result) { - if (state.getAndSet(State.CLOSED) == State.CLOSED) { - return; + synchronized (lock) { + if (state == State.CLOSED) { + return; + } + state = State.CLOSED; } listener.onClose(result); @@ -195,7 +226,27 @@ void handleError(VRpcResult result) { @Override public void cancel(@Nullable String message, @Nullable Throwable cause) { - session.cancelRpc(rpcId, message, cause); + boolean isRealCall = false; + synchronized (lock) { + if (state == State.NEW || state == State.STARTING) { + state = State.CANCELLED; + Status status = Status.CANCELLED; + if (message != null) { + status = status.withDescription(message); + } + if (cause != null) { + status = status.withCause(cause); + } + cancelStatus = status; + return; + } else if (state == State.STARTED) { + isRealCall = true; + } + // ignore closed vRPCs + } + if (isRealCall) { + session.cancelRpc(rpcId, message, cause); + } } @Override