diff --git a/grpclb/BUILD.bazel b/grpclb/BUILD.bazel index b69fb234733..a18795d0485 100644 --- a/grpclb/BUILD.bazel +++ b/grpclb/BUILD.bazel @@ -9,6 +9,7 @@ java_library( deps = [ ":load_balancer_java_grpc", "//api", + "//context", "//core:internal", "//core:util", "//stub", diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 1a8dec36e38..65293d24511 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -23,6 +23,7 @@ import com.google.common.base.Stopwatch; import io.grpc.Attributes; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.Context; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; @@ -45,6 +46,7 @@ class GrpclbLoadBalancer extends LoadBalancer { private static final GrpclbConfig DEFAULT_CONFIG = GrpclbConfig.create(Mode.ROUND_ROBIN); private final Helper helper; + private final Context context; private final TimeProvider time; private final Stopwatch stopwatch; private final SubchannelPool subchannelPool; @@ -58,11 +60,13 @@ class GrpclbLoadBalancer extends LoadBalancer { GrpclbLoadBalancer( Helper helper, + Context context, SubchannelPool subchannelPool, TimeProvider time, Stopwatch stopwatch, BackoffPolicy.Provider backoffPolicyProvider) { this.helper = checkNotNull(helper, "helper"); + this.context = checkNotNull(context, "context"); this.time = checkNotNull(time, "time provider"); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); @@ -131,7 +135,7 @@ private void recreateStates() { checkState(grpclbState == null, "Should've been cleared"); grpclbState = new GrpclbState( - config, helper, subchannelPool, time, stopwatch, backoffPolicyProvider); + config, helper, context, subchannelPool, time, stopwatch, backoffPolicyProvider); } @Override diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java index badcfdcec7c..fa9b6963f33 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java @@ -17,6 +17,7 @@ package io.grpc.grpclb; import com.google.common.base.Stopwatch; +import io.grpc.Context; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; @@ -62,6 +63,7 @@ public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { return new GrpclbLoadBalancer( helper, + Context.ROOT, new CachedSubchannelPool(helper), TimeProvider.SYSTEM_TIME_PROVIDER, Stopwatch.createUnstarted(), diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index 8c638b979ed..59fa67dc2dc 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -35,6 +35,7 @@ import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; +import io.grpc.Context; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -132,6 +133,7 @@ enum Mode { private final String serviceName; private final Helper helper; + private final Context context; private final SynchronizationContext syncContext; @Nullable private final SubchannelPool subchannelPool; @@ -182,12 +184,14 @@ enum Mode { GrpclbState( GrpclbConfig config, Helper helper, + Context context, SubchannelPool subchannelPool, TimeProvider time, Stopwatch stopwatch, BackoffPolicy.Provider backoffPolicyProvider) { this.config = checkNotNull(config, "config"); this.helper = checkNotNull(helper, "helper"); + this.context = checkNotNull(context, "context"); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); if (config.getMode() == Mode.ROUND_ROBIN) { this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); @@ -368,7 +372,12 @@ private void startLbRpc() { checkState(lbStream == null, "previous lbStream has not been cleared yet"); LoadBalancerGrpc.LoadBalancerStub stub = LoadBalancerGrpc.newStub(lbCommChannel); lbStream = new LbStream(stub); - lbStream.start(); + Context prevContext = context.attach(); + try { + lbStream.start(); + } finally { + context.detach(prevContext); + } stopwatch.reset().start(); LoadBalanceRequest initRequest = LoadBalanceRequest.newBuilder() diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 0c194ae84c9..7dba20be1d0 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -55,6 +55,8 @@ import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; +import io.grpc.Context; +import io.grpc.Context.CancellableContext; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -229,6 +231,7 @@ public Void answer(InvocationOnMock invocation) { when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); balancer = new GrpclbLoadBalancer( helper, + Context.ROOT, subchannelPool, fakeClock.getTimeProvider(), fakeClock.getStopwatchSupplier().get(), @@ -2683,6 +2686,39 @@ public void grpclbWorking_lbSendsFallbackMessage() { .inOrder(); } + @Test + public void useIndependentRpcContext() { + // Simulates making RPCs within the context of an inbound RPC. + CancellableContext cancellableContext = Context.current().withCancellation(); + Context prevContext = cancellableContext.attach(); + try { + List backendList = createResolvedBackendAddresses(2); + List grpclbBalancerList = createResolvedBalancerAddresses(2); + deliverResolvedAddresses(backendList, grpclbBalancerList); + + verify(helper).createOobChannel(eq(xattr(grpclbBalancerList)), + eq(lbAuthority(0) + NO_USE_AUTHORITY_SUFFIX)); + verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture()); + StreamObserver lbResponseObserver = lbResponseObserverCaptor.getValue(); + assertEquals(1, lbRequestObservers.size()); + StreamObserver lbRequestObserver = lbRequestObservers.poll(); + verify(lbRequestObserver).onNext( + eq(LoadBalanceRequest.newBuilder() + .setInitialRequest( + InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) + .build())); + lbResponseObserver.onNext(buildInitialResponse()); + + // The inbound RPC finishes and closes its context. The outbound RPC's control plane RPC + // should not be impacted (no retry). + cancellableContext.close(); + assertEquals(0, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); + verifyNoMoreInteractions(mockLbService); + } finally { + cancellableContext.detach(prevContext); + } + } + private void deliverSubchannelState( final Subchannel subchannel, final ConnectivityStateInfo newState) { ((FakeSubchannel) subchannel).updateState(newState);