From 158b06d2e23b045aa8aaf69e303920bf4c55f0b7 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Fri, 19 Nov 2021 12:33:21 -0500 Subject: [PATCH 01/17] merge RefreshingManagedChannel & SafeShutdownManagedChannel into ChannelPool. The eventual goal is to allow channel pool to safely add remove channels. The code has been refactored roughly as: - SafeShutdownManagedChannel is now ChannelPool.Entry and ReleasingClientCall - RefreshingManagedChannel has been merged into ChannelPool as a pair of functions scheduleNextRefresh & refresh --- .../com/google/api/gax/grpc/ChannelPool.java | 294 +++++++++++++++--- .../google/api/gax/grpc/ChannelPoolTest.java | 8 +- 2 files changed, 258 insertions(+), 44 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index d6b85275b..010b722e6 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -32,32 +32,57 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.grpc.CallOptions; +import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.ManagedChannel; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.Nullable; +import org.threeten.bp.Duration; /** * A {@link ManagedChannel} that will send requests round robin via a set of channels. * + *

In addition to spreading requests over a set of child connections, the pool will also actively + * manage the lifecycle of the channels. Currently lifecycle management is limited to pre-emptively + * replacing channels every hour. In the future it will dynamically size the pool based on number of + * outstanding requests. + * *

Package-private for internal use. */ class ChannelPool extends ManagedChannel { + private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName()); + // size greater than 1 to allow multiple channel to refresh at the same time // size not too large so refreshing channels doesn't use too many threads private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2; - private final ImmutableList channels; + private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50); + private static final double JITTER_PERCENTAGE = 0.15; + + // A copy on write list of child channels. + private final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; // if set, ChannelPool will manage the life cycle of channelRefreshExecutorService - @Nullable private ScheduledExecutorService channelRefreshExecutorService; + @Nullable private final ScheduledExecutorService channelRefreshExecutorService; + private final ChannelFactory channelFactory; + + private volatile ScheduledFuture nextScheduledRefresh = null; /** * Factory method to create a non-refreshing channel pool @@ -66,12 +91,8 @@ class ChannelPool extends ManagedChannel { * @param channelFactory method to create the channels * @return ChannelPool of non refreshing channels */ - static ChannelPool create(int poolSize, final ChannelFactory channelFactory) throws IOException { - List channels = new ArrayList<>(); - for (int i = 0; i < poolSize; i++) { - channels.add(channelFactory.createSingleChannel()); - } - return new ChannelPool(channels, null); + static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IOException { + return new ChannelPool(channelFactory, poolSize, null); } /** @@ -88,14 +109,10 @@ static ChannelPool create(int poolSize, final ChannelFactory channelFactory) thr @VisibleForTesting static ChannelPool createRefreshing( int poolSize, - final ChannelFactory channelFactory, + ChannelFactory channelFactory, ScheduledExecutorService channelRefreshExecutorService) throws IOException { - List channels = new ArrayList<>(); - for (int i = 0; i < poolSize; i++) { - channels.add(new RefreshingManagedChannel(channelFactory, channelRefreshExecutorService)); - } - return new ChannelPool(channels, channelRefreshExecutorService); + return new ChannelPool(channelFactory, poolSize, channelRefreshExecutorService); } /** @@ -114,15 +131,30 @@ static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFa /** * Initializes the channel pool. Assumes that all channels have the same authority. * - * @param channels a List of channels to pool. + * @param channelFactory method to create the channels + * @param poolSize number of channels in the pool * @param channelRefreshExecutorService periodically refreshes the channels */ private ChannelPool( - List channels, - @Nullable ScheduledExecutorService channelRefreshExecutorService) { - this.channels = ImmutableList.copyOf(channels); - authority = channels.get(0).authority(); + ChannelFactory channelFactory, + int poolSize, + @Nullable ScheduledExecutorService channelRefreshExecutorService) + throws IOException { + this.channelFactory = channelFactory; + + ImmutableList.Builder initialListBuilder = ImmutableList.builder(); + + for (int i = 0; i < poolSize; i++) { + initialListBuilder.add(new Entry(channelFactory.createSingleChannel())); + } + + entries.set(initialListBuilder.build()); + authority = entries.get().get(0).channel.authority(); this.channelRefreshExecutorService = channelRefreshExecutorService; + + if (channelRefreshExecutorService != null) { + nextScheduledRefresh = scheduleNextRefresh(); + } } /** {@inheritDoc} */ @@ -140,17 +172,26 @@ public String authority() { @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - return getNextChannel().newCall(methodDescriptor, callOptions); + return getChannel(indexTicker.getAndIncrement()).newCall(methodDescriptor, callOptions); + } + + Channel getChannel(int affinity) { + return new AffinityChannel(affinity); } /** {@inheritDoc} */ @Override public ManagedChannel shutdown() { - for (ManagedChannel channelWrapper : channels) { - channelWrapper.shutdown(); + List localEntries = entries.get(); + for (Entry entry : localEntries) { + entry.channel.shutdown(); + } + if (nextScheduledRefresh != null) { + nextScheduledRefresh.cancel(true); } if (channelRefreshExecutorService != null) { - channelRefreshExecutorService.shutdown(); + // shutdownNow will cancel scheduled tasks + channelRefreshExecutorService.shutdownNow(); } return this; } @@ -158,8 +199,9 @@ public ManagedChannel shutdown() { /** {@inheritDoc} */ @Override public boolean isShutdown() { - for (ManagedChannel channel : channels) { - if (!channel.isShutdown()) { + List localEntries = entries.get(); + for (Entry entry : localEntries) { + if (!entry.channel.isShutdown()) { return false; } } @@ -172,8 +214,9 @@ public boolean isShutdown() { /** {@inheritDoc} */ @Override public boolean isTerminated() { - for (ManagedChannel channel : channels) { - if (!channel.isTerminated()) { + List localEntries = entries.get(); + for (Entry entry : localEntries) { + if (!entry.channel.isTerminated()) { return false; } } @@ -186,8 +229,12 @@ public boolean isTerminated() { /** {@inheritDoc} */ @Override public ManagedChannel shutdownNow() { - for (ManagedChannel channel : channels) { - channel.shutdownNow(); + List localEntries = entries.get(); + for (Entry entry : localEntries) { + entry.channel.shutdownNow(); + } + if (nextScheduledRefresh != null) { + nextScheduledRefresh.cancel(true); } if (channelRefreshExecutorService != null) { channelRefreshExecutorService.shutdownNow(); @@ -199,12 +246,13 @@ public ManagedChannel shutdownNow() { @Override public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { long endTimeNanos = System.nanoTime() + unit.toNanos(timeout); - for (ManagedChannel channel : channels) { + List localEntries = entries.get(); + for (Entry entry : localEntries) { long awaitTimeNanos = endTimeNanos - System.nanoTime(); if (awaitTimeNanos <= 0) { break; } - channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); + entry.channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); } if (channelRefreshExecutorService != null) { long awaitTimeNanos = endTimeNanos - System.nanoTime(); @@ -213,16 +261,61 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE return isTerminated(); } + /** Scheduling loop. */ + private ScheduledFuture scheduleNextRefresh() { + long delayPeriod = REFRESH_PERIOD.toMillis(); + long jitter = (long) ((Math.random() - 0.5) * JITTER_PERCENTAGE * delayPeriod); + long delay = jitter + delayPeriod; + return channelRefreshExecutorService.schedule( + () -> { + try { + refresh(); + } finally { + scheduleNextRefresh(); + } + }, + delay, + TimeUnit.MILLISECONDS); + } + /** - * Performs a simple round robin on the list of {@link ManagedChannel}s in the {@code channels} - * list. + * Replace all of the channels in the channel pool with fresh ones. This is meant to mitigate the + * hourly GFE disconnects by giving clients the ability to prime the channel on reconnect. * - * @return A {@link ManagedChannel} that can be used for a single RPC call. + *

This is done on a best effort basis. If the replacement channel fails to construct, the old + * channel will continue to be used. */ - private ManagedChannel getNextChannel() { - return getChannel(indexTicker.getAndIncrement()); + private void refresh() { + List localEntries = entries.get(); + ArrayList newEntries = new ArrayList<>(localEntries); + ArrayList removedEntries = new ArrayList<>(); + + for (int i = 0; i < localEntries.size(); i++) { + try { + Entry removed = newEntries.set(i, new Entry(channelFactory.createSingleChannel())); + removedEntries.add(removed); + } catch (IOException e) { + LOG.log(Level.WARNING, "Failed to refresh channel, leaving old channel", e); + } + } + entries.set(ImmutableList.copyOf(newEntries)); + + removedEntries.forEach(Entry::requestShutdown); } + /** + * Get and retain a Channel Entry. The returned will have its rpc count incremented, preventing it + * from getting recycled. + */ + Entry getRetainedEntry(int affinity) { + for (int i = 0; i < 5; i++) { + Entry entry = getEntry(affinity); + if (entry.retain()) { + return entry; + } + } + throw new IllegalStateException("Failed to retain a channel"); + } /** * Returns one of the channels managed by this pool. The pool continues to "own" the channel, and * the caller should not shut it down. @@ -231,13 +324,136 @@ private ManagedChannel getNextChannel() { * reverse is not true: Two calls with different affinities might return the same channel. * However, the implementation should attempt to spread load evenly. */ - ManagedChannel getChannel(int affinity) { - int index = affinity % channels.size(); + private Entry getEntry(int affinity) { + List localEntries = entries.get(); + + int index = affinity % localEntries.size(); index = Math.abs(index); // If index is the most negative int, abs(index) is still negative. if (index < 0) { index = 0; } - return channels.get(index); + + return localEntries.get(index); + } + + /** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */ + private static class Entry { + private final ManagedChannel channel; + private final AtomicInteger outstandingRpcs = new AtomicInteger(0); + + // Flag that the channel should be closed once all of the outstanding RPC complete. + private final AtomicBoolean shutdownRequested = new AtomicBoolean(); + // Flag that the channel has been closed. + private final AtomicBoolean shutdownInitiated = new AtomicBoolean(); + + private Entry(ManagedChannel channel) { + this.channel = channel; + } + + /** + * Try to increment the outstanding RPC count. The method will return false if the channel is + * closing and the caller should pick a different channel. If the method returned true, the + * channel has been successfully retained and it is the responsibility of the caller to release + * it. + */ + private boolean retain() { + // register desire to start RPC + outstandingRpcs.incrementAndGet(); + + // abort if the channel is closing + if (shutdownRequested.get()) { + release(); + return false; + } + return true; + } + + /** + * Notify the channel that the number of outstanding RPCs has decreased. If shutdown has been + * previously requested, this method will shutdown the channel if its the last outstanding RPC. + */ + private void release() { + int newCount = outstandingRpcs.decrementAndGet(); + if (newCount < 0) { + LOG.log(Level.SEVERE, "Reference count is negative!: " + newCount); + } + + if (newCount == 0 && shutdownRequested.get()) { + shutdown(); + } + } + + /** + * Request a shutdown. The actual shutdown will be delayed until there are no more outstanding + * RPCs. + */ + private void requestShutdown() { + shutdownRequested.set(true); + if (outstandingRpcs.get() == 0) { + shutdown(); + } + } + + /** Ensure that shutdown is only called once. */ + private void shutdown() { + if (shutdownInitiated.compareAndSet(false, true)) { + channel.shutdown(); + } + } + } + + /** Thin wrapper to ensure that new calls are properly reference counted. */ + private class AffinityChannel extends Channel { + private final int affinity; + + public AffinityChannel(int affinity) { + this.affinity = affinity; + } + + @Override + public String authority() { + return authority; + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + + Entry entry = getRetainedEntry(affinity); + + return new ReleasingClientCall<>(entry.channel.newCall(methodDescriptor, callOptions), entry); + } + } + + /** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */ + static class ReleasingClientCall extends SimpleForwardingClientCall { + final Entry entry; + + public ReleasingClientCall(ClientCall delegate, Entry entry) { + super(delegate); + this.entry = entry; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + try { + super.start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + try { + super.onClose(status, trailers); + } finally { + entry.release(); + } + } + }, + headers); + } catch (Exception e) { + // In case start failed, make sure to release + entry.release(); + } + } } } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index d6ae65202..875a0e97a 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -91,15 +91,13 @@ private void verifyTargetChannel( @SuppressWarnings("unchecked") ClientCall expectedClientCall = Mockito.mock(ClientCall.class); - for (ManagedChannel channel : channels) { - Mockito.reset(channel); - } + channels.forEach(Mockito::reset); Mockito.doReturn(expectedClientCall).when(targetChannel).newCall(methodDescriptor, callOptions); ClientCall actualCall = pool.newCall(methodDescriptor, callOptions); - - Truth.assertThat(actualCall).isSameInstanceAs(expectedClientCall); Mockito.verify(targetChannel, Mockito.times(1)).newCall(methodDescriptor, callOptions); + actualCall.start(null, null); + Mockito.verify(expectedClientCall, Mockito.times(1)).start(Mockito.any(), Mockito.any()); for (ManagedChannel otherChannel : channels) { if (otherChannel != targetChannel) { From 7e02e4ab068573ef049b29f8d2833ef1adf30980 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Fri, 19 Nov 2021 12:48:09 -0500 Subject: [PATCH 02/17] migrate SafeShutdownManagedChannel tests --- .../com/google/api/gax/grpc/ChannelPool.java | 4 +- .../google/api/gax/grpc/ChannelPoolTest.java | 151 ++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 010b722e6..2c73b1c54 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -29,6 +29,7 @@ */ package com.google.api.gax.grpc; +import com.google.api.core.InternalApi; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.grpc.CallOptions; @@ -285,7 +286,8 @@ private ScheduledFuture scheduleNextRefresh() { *

This is done on a best effort basis. If the replacement channel fails to construct, the old * channel will continue to be used. */ - private void refresh() { + @InternalApi("Visible for testing") + void refresh() { List localEntries = entries.get(); ArrayList newEntries = new ArrayList<>(localEntries); ArrayList removedEntries = new ArrayList<>(); diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 875a0e97a..5ce13f5aa 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -30,7 +30,9 @@ package com.google.api.gax.grpc; import com.google.api.gax.grpc.testing.FakeChannelFactory; +import com.google.api.gax.grpc.testing.FakeMethodDescriptor; import com.google.api.gax.grpc.testing.FakeServiceGrpc; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.truth.Truth; import com.google.type.Color; @@ -38,7 +40,9 @@ import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ManagedChannel; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -228,4 +232,151 @@ public void channelPrimerIsCalledPeriodically() throws IOException { Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); scheduledExecutorService.shutdown(); } + + // ---- + // call should be allowed to complete and the channel should not shutdown + @Test + public void callShouldCompleteAfterCreation() throws IOException { + final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); + FakeChannelFactory channelFactory = + new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); + ChannelPool pool = ChannelPool.create(1, channelFactory); + + // create a mock call when new call comes to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + Mockito.when( + underlyingChannel.newCall( + Mockito.>any(), Mockito.any(CallOptions.class))) + .thenReturn(spyClientCall); + + Answer verifyChannelNotShutdown = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + // verify that underlying channel is not shutdown when clientCall is still sending message + Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); + + // create a new call on entry + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + ClientCall call = + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + pool.refresh(); + // shutdown is not called because there is still an outstanding call, even if it hasn't started + Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); + + // start clientCall + call.start(listener, new Metadata()); + // send message and end the call + call.sendMessage("message"); + // shutdown is called because the outstanding call has completed + Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); + + // Replacement channel shouldn't be touched + Mockito.verify(replacementChannel, Mockito.never()).shutdown(); + Mockito.verify(replacementChannel, Mockito.never()).newCall(Mockito.any(), Mockito.any()); + } + + // call should be allowed to complete and the channel should not shutdown + @Test + public void callShouldCompleteAfterStarted() throws IOException { + final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); + + FakeChannelFactory channelFactory = + new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); + ChannelPool pool = ChannelPool.create(1, channelFactory); + + // create a mock call when new call comes to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + Mockito.when( + underlyingChannel.newCall( + Mockito.>any(), Mockito.any(CallOptions.class))) + .thenReturn(spyClientCall); + + Answer verifyChannelNotShutdown = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + // verify that underlying channel is not shutdown when clientCall is still sending message + Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); + + // create a new call on safeShutdownManagedChannel + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + ClientCall call = + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + // start clientCall + call.start(listener, new Metadata()); + pool.refresh(); + + // shutdown is not called because there is still an outstanding call + Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); + // send message and end the call + call.sendMessage("message"); + // shutdown is called because the outstanding call has completed + Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); + } + + // Channel should shutdown after a refresh all the calls have completed + @Test + public void channelShouldShutdown() throws IOException { + final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); + + FakeChannelFactory channelFactory = + new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); + ChannelPool pool = ChannelPool.create(1, channelFactory); + + // create a mock call when new call comes to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + Mockito.when( + underlyingChannel.newCall( + Mockito.>any(), Mockito.any(CallOptions.class))) + .thenReturn(spyClientCall); + + Answer verifyChannelNotShutdown = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + // verify that underlying channel is not shutdown when clientCall is still sending message + Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); + + // create a new call on safeShutdownManagedChannel + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + ClientCall call = + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + // start clientCall + call.start(listener, new Metadata()); + // send message and end the call + call.sendMessage("message"); + // shutdown is not called because it has not been shutdown yet + Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); + pool.refresh(); + // shutdown is called because the outstanding call has completed + Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); + } } From 50735bdcf39906cb2e88b70bb0d031b5599c3ab1 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Fri, 19 Nov 2021 13:04:53 -0500 Subject: [PATCH 03/17] migrate old tests and remove RefreshingManagedChannel and SafeShutdownManagedChannel --- .../gax/grpc/RefreshingManagedChannel.java | 215 ------------------ .../gax/grpc/SafeShutdownManagedChannel.java | 170 -------------- .../google/api/gax/grpc/ChannelPoolTest.java | 47 +++- .../grpc/RefreshingManagedChannelTest.java | 193 ---------------- .../grpc/SafeShutdownManagedChannelTest.java | 176 -------------- 5 files changed, 44 insertions(+), 757 deletions(-) delete mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java delete mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java delete mode 100644 gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java delete mode 100644 gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java deleted file mode 100644 index ec8680d4c..000000000 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google LLC nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ -package com.google.api.gax.grpc; - -import io.grpc.CallOptions; -import io.grpc.ClientCall; -import io.grpc.ManagedChannel; -import io.grpc.MethodDescriptor; -import java.io.IOException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.logging.Level; -import java.util.logging.Logger; -import org.threeten.bp.Duration; - -/** - * A {@link ManagedChannel} that will refresh the underlying channel by swapping the underlying - * channel with a new one periodically - * - *

Package-private for internal use. - * - *

A note on the synchronization logic. refreshChannel is called periodically which updates - * delegate and nextScheduledRefresh. lock is needed to provide atomic access and update of delegate - * and nextScheduledRefresh. One example is newCall needs to be atomic to avoid context switching to - * refreshChannel that shuts down delegate before newCall is completed. - */ -class RefreshingManagedChannel extends ManagedChannel { - private static final Logger LOG = Logger.getLogger(RefreshingManagedChannel.class.getName()); - // refresh every 50 minutes with 15% jitter for a range of 42.5min to 57.5min - private static final Duration refreshPeriod = Duration.ofMinutes(50); - private static final double jitterPercentage = 0.15; - private volatile SafeShutdownManagedChannel delegate; - private volatile ScheduledFuture nextScheduledRefresh; - // Read: method calls on delegate and nextScheduledRefresh - // Write: updating references of delegate and nextScheduledRefresh - private final ReadWriteLock lock; - private final ChannelFactory channelFactory; - private final ScheduledExecutorService scheduledExecutorService; - - RefreshingManagedChannel( - ChannelFactory channelFactory, ScheduledExecutorService scheduledExecutorService) - throws IOException { - this.delegate = new SafeShutdownManagedChannel(channelFactory.createSingleChannel()); - this.channelFactory = channelFactory; - this.scheduledExecutorService = scheduledExecutorService; - this.lock = new ReentrantReadWriteLock(); - this.nextScheduledRefresh = scheduleNextRefresh(); - } - - /** - * Refresh the existing channel by swapping the current channel with a new channel and schedule - * the next refresh - * - *

refreshChannel can only be called by scheduledExecutorService and not any other methods in - * this class. This is important so no threads will try to acquire the write lock while holding - * the read lock. - */ - private void refreshChannel() { - SafeShutdownManagedChannel newChannel; - try { - newChannel = new SafeShutdownManagedChannel(channelFactory.createSingleChannel()); - } catch (IOException ioException) { - LOG.log( - Level.WARNING, - "Failed to create a new channel when refreshing channel. This has no effect on the " - + "existing channels. The existing channel will continue to be used", - ioException); - return; - } - - SafeShutdownManagedChannel oldChannel = delegate; - lock.writeLock().lock(); - try { - // This thread can be interrupted by invoking cancel on nextScheduledRefresh - // Interrupt happens when this thread is blocked on acquiring the write lock because shutdown - // was called and that thread holds the read lock. - // When shutdown completes and releases the read lock and this thread acquires the write lock. - // This thread should not continue because the channel has shutdown. This check ensures that - // this thread terminates without swapping the channel and do not schedule the next refresh. - if (Thread.currentThread().isInterrupted()) { - newChannel.shutdownNow(); - return; - } - delegate = newChannel; - nextScheduledRefresh = scheduleNextRefresh(); - } finally { - lock.writeLock().unlock(); - } - oldChannel.shutdownSafely(); - } - - /** Schedule the next instance of refreshing this channel */ - private ScheduledFuture scheduleNextRefresh() { - long delayPeriod = refreshPeriod.toMillis(); - long jitter = (long) ((Math.random() - 0.5) * jitterPercentage * delayPeriod); - long delay = jitter + delayPeriod; - return scheduledExecutorService.schedule( - new Runnable() { - @Override - public void run() { - refreshChannel(); - } - }, - delay, - TimeUnit.MILLISECONDS); - } - - /** {@inheritDoc} */ - @Override - public ClientCall newCall( - MethodDescriptor methodDescriptor, CallOptions callOptions) { - lock.readLock().lock(); - try { - return delegate.newCall(methodDescriptor, callOptions); - } finally { - lock.readLock().unlock(); - } - } - - /** {@inheritDoc} */ - @Override - public String authority() { - // no lock here because authority is constant across all channels - return delegate.authority(); - } - - /** {@inheritDoc} */ - @Override - public ManagedChannel shutdown() { - lock.readLock().lock(); - try { - nextScheduledRefresh.cancel(true); - delegate.shutdown(); - return this; - } finally { - lock.readLock().unlock(); - } - } - - /** {@inheritDoc} */ - @Override - public ManagedChannel shutdownNow() { - lock.readLock().lock(); - try { - nextScheduledRefresh.cancel(true); - delegate.shutdownNow(); - return this; - } finally { - lock.readLock().unlock(); - } - } - - /** {@inheritDoc} */ - @Override - public boolean isShutdown() { - lock.readLock().lock(); - try { - return delegate.isShutdown(); - } finally { - lock.readLock().unlock(); - } - } - - /** {@inheritDoc} */ - @Override - public boolean isTerminated() { - lock.readLock().lock(); - try { - return delegate.isTerminated(); - } finally { - lock.readLock().unlock(); - } - } - - /** {@inheritDoc} */ - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - lock.readLock().lock(); - try { - return delegate.awaitTermination(timeout, unit); - } finally { - lock.readLock().unlock(); - } - } -} diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java deleted file mode 100644 index 0ae71d2e3..000000000 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google LLC nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ -package com.google.api.gax.grpc; - -import com.google.common.base.Preconditions; -import io.grpc.CallOptions; -import io.grpc.ClientCall; -import io.grpc.ClientCall.Listener; -import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; -import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * A {@link ManagedChannel} that will complete all calls started on the underlying channel before - * shutting down. - * - *

This class is not thread-safe. Caller must synchronize in order to ensure no new calls if safe - * shutdown has started. - * - *

Package-private for internal use. - */ -class SafeShutdownManagedChannel extends ManagedChannel { - private final ManagedChannel delegate; - private final AtomicInteger outstandingCalls = new AtomicInteger(0); - private volatile boolean isShutdownSafely = false; - - SafeShutdownManagedChannel(ManagedChannel managedChannel) { - this.delegate = managedChannel; - } - - /** - * Safely shutdown channel by checking that there are no more outstanding calls. If there are - * outstanding calls, the last call will invoke this method again when it complete - * - *

Caller should take care to synchronize with newCall so no new calls are started after - * shutdownSafely is called - */ - void shutdownSafely() { - isShutdownSafely = true; - if (outstandingCalls.get() == 0) { - delegate.shutdown(); - } - } - - /** {@inheritDoc} */ - @Override - public ManagedChannel shutdown() { - delegate.shutdown(); - return this; - } - - /** {@inheritDoc} */ - @Override - public boolean isShutdown() { - return delegate.isShutdown(); - } - - /** {@inheritDoc} */ - @Override - public ManagedChannel shutdownNow() { - delegate.shutdownNow(); - return this; - } - - /** {@inheritDoc} */ - @Override - public boolean isTerminated() { - return delegate.isTerminated(); - } - - /** {@inheritDoc} */ - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return delegate.awaitTermination(timeout, unit); - } - - /** - * Decrement outstanding call counter and shutdown if there are no more outstanding calls and - * {@link SafeShutdownManagedChannel#shutdownSafely()} has been invoked - */ - private void onClientCallClose() { - if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) { - shutdownSafely(); - } - } - - /** Listener that's responsible for decrementing outstandingCalls when the call closes */ - private class DecrementOutstandingCalls extends SimpleForwardingClientCallListener { - DecrementOutstandingCalls(Listener delegate) { - super(delegate); - } - - @Override - public void onClose(Status status, Metadata trailers) { - // decrement in finally block in case onClose throws an exception - try { - super.onClose(status, trailers); - } finally { - onClientCallClose(); - } - } - } - - /** To wrap around delegate to hook in {@link DecrementOutstandingCalls} */ - private class ClientCallProxy extends SimpleForwardingClientCall { - ClientCallProxy(ClientCall delegate) { - super(delegate); - } - - @Override - public void start(Listener responseListener, Metadata headers) { - super.start(new DecrementOutstandingCalls<>(responseListener), headers); - } - } - - /** - * Caller must take care to synchronize newCall and shutdownSafely in order to avoid race - * conditions of starting new calls after shutdownSafely is called - * - * @see io.grpc.ManagedChannel#newCall(MethodDescriptor, CallOptions) - */ - @Override - public ClientCall newCall( - MethodDescriptor methodDescriptor, CallOptions callOptions) { - Preconditions.checkState(!isShutdownSafely); - // increment after client call in case newCall throws an exception - ClientCall clientCall = - new ClientCallProxy<>(delegate.newCall(methodDescriptor, callOptions)); - outstandingCalls.incrementAndGet(); - return clientCall; - } - - /** {@inheritDoc} */ - @Override - public String authority() { - return delegate.authority(); - } -} diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 5ce13f5aa..53a1447b2 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -184,9 +184,9 @@ public void channelPrimerShouldCallPoolConstruction() throws IOException { @Test public void channelPrimerIsCalledPeriodically() throws IOException { ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); - ManagedChannel channel1 = Mockito.mock(RefreshingManagedChannel.class); - ManagedChannel channel2 = Mockito.mock(RefreshingManagedChannel.class); - ManagedChannel channel3 = Mockito.mock(RefreshingManagedChannel.class); + ManagedChannel channel1 = Mockito.mock(ManagedChannel.class); + ManagedChannel channel2 = Mockito.mock(ManagedChannel.class); + ManagedChannel channel3 = Mockito.mock(ManagedChannel.class); List channelRefreshers = new ArrayList<>(); @@ -379,4 +379,45 @@ public Object answer(InvocationOnMock invocation) throws Throwable { // shutdown is called because the outstanding call has completed Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); } + + @Test + public void channelRefreshShouldSwapChannels() throws IOException { + ManagedChannel underlyingChannel1 = Mockito.mock(ManagedChannel.class); + ManagedChannel underlyingChannel2 = Mockito.mock(ManagedChannel.class); + + // mock executor service to capture the runnable scheduled so we can invoke it when we want to + ScheduledExecutorService scheduledExecutorService = + Mockito.mock(ScheduledExecutorService.class); + final List channelRefreshers = new ArrayList<>(); + Answer extractChannelRefresher = + new Answer() { + public Object answer(InvocationOnMock invocation) { + channelRefreshers.add(invocation.getArgument(0)); + return null; + } + }; + + Mockito.doAnswer(extractChannelRefresher) + .when(scheduledExecutorService) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + FakeChannelFactory channelFactory = + new FakeChannelFactory(ImmutableList.of(underlyingChannel1, underlyingChannel2)); + ChannelPool pool = ChannelPool.createRefreshing(1, channelFactory, scheduledExecutorService); + Mockito.reset(underlyingChannel1); + + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + Mockito.verify(underlyingChannel1, Mockito.only()) + .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); + + // swap channel + pool.refresh(); + + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + Mockito.verify(underlyingChannel2, Mockito.only()) + .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); + } } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java deleted file mode 100644 index b56070bf3..000000000 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google LLC nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ -package com.google.api.gax.grpc; - -import com.google.api.gax.grpc.testing.FakeChannelFactory; -import com.google.api.gax.grpc.testing.FakeMethodDescriptor; -import io.grpc.CallOptions; -import io.grpc.ClientCall; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mockito; -import org.mockito.stubbing.Answer; - -@RunWith(JUnit4.class) -public class RefreshingManagedChannelTest { - @Test - public void channelRefreshShouldSwapChannels() throws IOException { - ManagedChannel underlyingChannel1 = Mockito.mock(ManagedChannel.class); - ManagedChannel underlyingChannel2 = Mockito.mock(ManagedChannel.class); - - // mock executor service to capture the runnable scheduled so we can invoke it when we want to - ScheduledExecutorService scheduledExecutorService = - Mockito.mock(ScheduledExecutorService.class); - List channelRefreshers = new ArrayList<>(); - Answer extractChannelRefresher = - invocation -> { - channelRefreshers.add((Runnable) invocation.getArgument(0)); - return null; - }; - - Mockito.doAnswer(extractChannelRefresher) - .when(scheduledExecutorService) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); - - FakeChannelFactory channelFactory = - new FakeChannelFactory(Arrays.asList(underlyingChannel1, underlyingChannel2)); - - ManagedChannel refreshingManagedChannel = - new RefreshingManagedChannel(channelFactory, scheduledExecutorService); - - refreshingManagedChannel.newCall( - FakeMethodDescriptor.create(), CallOptions.DEFAULT); - - Mockito.verify(underlyingChannel1, Mockito.only()) - .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); - - // swap channel - channelRefreshers.get(0).run(); - - refreshingManagedChannel.newCall( - FakeMethodDescriptor.create(), CallOptions.DEFAULT); - - Mockito.verify(underlyingChannel2, Mockito.only()) - .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); - } - - @Test - public void randomizeTest() throws IOException, InterruptedException, ExecutionException { - int channelCount = 10; - ManagedChannel[] underlyingChannels = new ManagedChannel[channelCount]; - Random r = new Random(); - for (int i = 0; i < channelCount; i++) { - ManagedChannel mockManagedChannel = Mockito.mock(ManagedChannel.class); - underlyingChannels[i] = mockManagedChannel; - - Answer waitAndSendMessage = - invocation -> { - // add a little time to sleep so calls don't always complete right away - TimeUnit.MICROSECONDS.sleep(r.nextInt(1000)); - // when sending message on the call, the channel cannot be shutdown - Mockito.verify(mockManagedChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - }; - - Answer createNewCall = - invocation -> { - // create a new client call for every new call to the underlying channel - MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); - MockClientCall spyClientCall = Mockito.spy(mockClientCall); - - // spy into clientCall to verify that the channel is not shutdown - Mockito.doAnswer(waitAndSendMessage) - .when(spyClientCall) - .sendMessage(Mockito.anyString()); - - return spyClientCall; - }; - - // return a new mocked client call when requesting new call on the channel - Mockito.doAnswer(createNewCall) - .when(underlyingChannels[i]) - .newCall( - Mockito.>any(), Mockito.any(CallOptions.class)); - } - - // mock executor service to capture the runnable scheduled so we can invoke it when we want to - List channelRefreshers = new ArrayList<>(); - ScheduledExecutorService scheduledExecutorService = - Mockito.mock(ScheduledExecutorService.class); - Answer extractChannelRefresher = - invocation -> { - channelRefreshers.add((Runnable) invocation.getArgument(0)); - return null; - }; - Mockito.doAnswer(extractChannelRefresher) - .when(scheduledExecutorService) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); - - FakeChannelFactory channelFactory = new FakeChannelFactory(Arrays.asList(underlyingChannels)); - ManagedChannel refreshingManagedChannel = - new RefreshingManagedChannel(channelFactory, scheduledExecutorService); - - // send a bunch of request to RefreshingManagedChannel, executor needs more than 1 thread to - // test out concurrency - ExecutorService executor = Executors.newFixedThreadPool(10); - - // channelCount - 1 because the last channel cannot be refreshed because the FakeChannelFactory - // has no more channel to create - for (int i = 0; i < channelCount - 1; i++) { - List> futures = new ArrayList<>(); - int requestCount = 100; - int whenToRefresh = r.nextInt(requestCount); - for (int j = 0; j < requestCount; j++) { - Runnable createNewCall = - () -> { - // create a new call and send message on refreshingManagedChannel - ClientCall call = - refreshingManagedChannel.newCall( - FakeMethodDescriptor.create(), CallOptions.DEFAULT); - @SuppressWarnings("unchecked") - ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); - call.start(listener, new Metadata()); - call.sendMessage("message"); - }; - futures.add(executor.submit(createNewCall)); - // at the randomly chosen point, refresh the channel - if (j == whenToRefresh) { - futures.add(executor.submit(channelRefreshers.get(i))); - } - } - for (Future future : futures) { - future.get(); - } - Mockito.verify(underlyingChannels[i], Mockito.atLeastOnce()).shutdown(); - Mockito.verify(underlyingChannels[i + 1], Mockito.never()).shutdown(); - } - } -} diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java deleted file mode 100644 index 91b16dcbf..000000000 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google LLC nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ -package com.google.api.gax.grpc; - -import com.google.api.gax.grpc.testing.FakeMethodDescriptor; -import io.grpc.CallOptions; -import io.grpc.ClientCall; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mockito; -import org.mockito.stubbing.Answer; - -@RunWith(JUnit4.class) -public class SafeShutdownManagedChannelTest { - // call should be allowed to complete and the channel should not shutdown - @Test - public void callShouldCompleteAfterCreation() { - final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); - - SafeShutdownManagedChannel safeShutdownManagedChannel = - new SafeShutdownManagedChannel(underlyingChannel); - - // create a mock call when new call comes to the underlying channel - MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); - MockClientCall spyClientCall = Mockito.spy(mockClientCall); - Mockito.when( - underlyingChannel.newCall( - Mockito.>any(), Mockito.any(CallOptions.class))) - .thenReturn(spyClientCall); - - Answer verifyChannelNotShutdown = - invocation -> { - Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - }; - - // verify that underlying channel is not shutdown when clientCall is still sending message - Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); - - // create a new call on safeShutdownManagedChannel - @SuppressWarnings("unchecked") - ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); - ClientCall call = - safeShutdownManagedChannel.newCall( - FakeMethodDescriptor.create(), CallOptions.DEFAULT); - - safeShutdownManagedChannel.shutdownSafely(); - // shutdown is not called because there is still an outstanding call, even if it hasn't started - Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); - - // start clientCall - call.start(listener, new Metadata()); - // send message and end the call - call.sendMessage("message"); - // shutdown is called because the outstanding call has completed - Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); - } - - // call should be allowed to complete and the channel should not shutdown - @Test - public void callShouldCompleteAfterStarted() { - final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); - - SafeShutdownManagedChannel safeShutdownManagedChannel = - new SafeShutdownManagedChannel(underlyingChannel); - - // create a mock call when new call comes to the underlying channel - MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); - MockClientCall spyClientCall = Mockito.spy(mockClientCall); - Mockito.when( - underlyingChannel.newCall( - Mockito.>any(), Mockito.any(CallOptions.class))) - .thenReturn(spyClientCall); - - Answer verifyChannelNotShutdown = - invocation -> { - Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - }; - - // verify that underlying channel is not shutdown when clientCall is still sending message - Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); - - // create a new call on safeShutdownManagedChannel - @SuppressWarnings("unchecked") - ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); - ClientCall call = - safeShutdownManagedChannel.newCall( - FakeMethodDescriptor.create(), CallOptions.DEFAULT); - - // start clientCall - call.start(listener, new Metadata()); - safeShutdownManagedChannel.shutdownSafely(); - // shutdown is not called because there is still an outstanding call - Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); - // send message and end the call - call.sendMessage("message"); - // shutdown is called because the outstanding call has completed - Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); - } - - // Channel should shutdown after a refresh all the calls have completed - @Test - public void channelShouldShutdown() { - final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); - - SafeShutdownManagedChannel safeShutdownManagedChannel = - new SafeShutdownManagedChannel(underlyingChannel); - - // create a mock call when new call comes to the underlying channel - MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); - MockClientCall spyClientCall = Mockito.spy(mockClientCall); - Mockito.when( - underlyingChannel.newCall( - Mockito.>any(), Mockito.any(CallOptions.class))) - .thenReturn(spyClientCall); - - Answer verifyChannelNotShutdown = - invocation -> { - Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - }; - - // verify that underlying channel is not shutdown when clientCall is still sending message - Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); - - // create a new call on safeShutdownManagedChannel - @SuppressWarnings("unchecked") - ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); - ClientCall call = - safeShutdownManagedChannel.newCall( - FakeMethodDescriptor.create(), CallOptions.DEFAULT); - - // start clientCall - call.start(listener, new Metadata()); - // send message and end the call - call.sendMessage("message"); - // shutdown is not called because it has not been shutdown yet - Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); - safeShutdownManagedChannel.shutdownSafely(); - // shutdown is called because the outstanding call has completed - Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); - } -} From 3a501e0de709fc1055ac17768729fd89cfcc2287 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Mon, 22 Nov 2021 17:01:37 -0500 Subject: [PATCH 04/17] fix test --- .../java/com/google/api/gax/grpc/GrpcClientCallsTest.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java index ef0d58ce9..440d57209 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java @@ -30,6 +30,7 @@ package com.google.api.gax.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.verify; import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeServiceGrpc; @@ -83,8 +84,8 @@ public void testAffinity() throws IOException { ClientCall gotCallC = GrpcClientCalls.newCall(descriptor, context.withChannelAffinity(1)); - assertThat(gotCallA).isSameInstanceAs(gotCallB); - assertThat(gotCallA).isNotSameInstanceAs(gotCallC); + verify(channel0, Mockito.times(2)).newCall(Mockito.eq(descriptor), Mockito.any()); + verify(channel1, Mockito.times(1)).newCall(Mockito.eq(descriptor), Mockito.any()); } @Test From 6cb0964c470fdac1f11feb5bbcc7b60f7144949b Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 9 Dec 2021 16:19:08 -0500 Subject: [PATCH 05/17] address feedback --- .../com/google/api/gax/grpc/ChannelPool.java | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 2c73b1c54..fabb91503 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -75,8 +75,7 @@ class ChannelPool extends ManagedChannel { private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50); private static final double JITTER_PERCENTAGE = 0.15; - // A copy on write list of child channels. - private final AtomicReference> entries = new AtomicReference<>(); + private final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; // if set, ChannelPool will manage the life cycle of channelRefreshExecutorService @@ -269,11 +268,8 @@ private ScheduledFuture scheduleNextRefresh() { long delay = jitter + delayPeriod; return channelRefreshExecutorService.schedule( () -> { - try { - refresh(); - } finally { - scheduleNextRefresh(); - } + scheduleNextRefresh(); + refresh(); }, delay, TimeUnit.MILLISECONDS); @@ -300,31 +296,48 @@ void refresh() { LOG.log(Level.WARNING, "Failed to refresh channel, leaving old channel", e); } } - entries.set(ImmutableList.copyOf(newEntries)); + + ImmutableList replacedEntries = entries.getAndSet(ImmutableList.copyOf(newEntries)); + + // In the unlikely case that the list was modified while the new channels were being created, + // shutdown the unexpected channels. + for (Entry e : replacedEntries) { + if (!newEntries.contains(e) && !removedEntries.contains(e)) { + removedEntries.add(e); + } + } removedEntries.forEach(Entry::requestShutdown); } /** - * Get and retain a Channel Entry. The returned will have its rpc count incremented, preventing it - * from getting recycled. + * Get and retain a Channel Entry. The returned Entry will have its rpc count incremented, + * preventing it from getting recycled. */ Entry getRetainedEntry(int affinity) { + // The maximum number of concurrent calls to this method for any given time span is at most 2, + // so the loop can actually be 2 times. But going for 5 times for a safety margin for potential + // code evolving for (int i = 0; i < 5; i++) { Entry entry = getEntry(affinity); if (entry.retain()) { return entry; } } - throw new IllegalStateException("Failed to retain a channel"); + // It is unlikely to reach here unless the pool code evolves to increase the maximum possible + // concurrent calls to this method. If it does, this is a bug in the channel pool implementation + // the number of retries above should be greater than the number of contending maintenance + // tasks. + throw new IllegalStateException("Bug: failed to retain a channel"); } /** * Returns one of the channels managed by this pool. The pool continues to "own" the channel, and * the caller should not shut it down. * - * @param affinity Two calls to this method with the same affinity returns the same channel. The - * reverse is not true: Two calls with different affinities might return the same channel. - * However, the implementation should attempt to spread load evenly. + * @param affinity Two calls to this method with the same affinity returns the same channel most + * of the time, if the channel pool was refreshed since the last call, a new channel will be + * returned. The reverse is not true: Two calls with different affinities might return the + * same channel. However, the implementation should attempt to spread load evenly. */ private Entry getEntry(int affinity) { List localEntries = entries.get(); @@ -378,10 +391,12 @@ private boolean retain() { private void release() { int newCount = outstandingRpcs.decrementAndGet(); if (newCount < 0) { - LOG.log(Level.SEVERE, "Reference count is negative!: " + newCount); + throw new IllegalStateException("Bug: reference count is negative!: " + newCount); } - if (newCount == 0 && shutdownRequested.get()) { + // Must check outstandingRpcs after shutdownRequested (in reverse order of retain()) to ensure + // mutual exclusion. + if (shutdownRequested.get() && outstandingRpcs.get() == 0) { shutdown(); } } From 8f6e39edc1f169e62e0ffa041c06b0860c39178b Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 23 Dec 2021 16:37:09 -0500 Subject: [PATCH 06/17] fix race condition on refresh() --- .../com/google/api/gax/grpc/ChannelPool.java | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 9f705eaaf..4ebc3782c 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -278,14 +278,11 @@ private ScheduledFuture scheduleNextRefresh() { */ @InternalApi("Visible for testing") void refresh() { - List localEntries = entries.get(); - ArrayList newEntries = new ArrayList<>(localEntries); - ArrayList removedEntries = new ArrayList<>(); + ArrayList newEntries = new ArrayList<>(entries.get()); - for (int i = 0; i < localEntries.size(); i++) { + for (int i = 0; i < newEntries.size(); i++) { try { - Entry removed = newEntries.set(i, new Entry(channelFactory.createSingleChannel())); - removedEntries.add(removed); + newEntries.set(i, new Entry(channelFactory.createSingleChannel())); } catch (IOException e) { LOG.log(Level.WARNING, "Failed to refresh channel, leaving old channel", e); } @@ -293,15 +290,13 @@ void refresh() { ImmutableList replacedEntries = entries.getAndSet(ImmutableList.copyOf(newEntries)); - // In the unlikely case that the list was modified while the new channels were being created, - // shutdown the unexpected channels. + // Shutdown the channels that were cycled out. This will either be the channels we just + // refreshed or in case of a race, the channels that the other thread set. for (Entry e : replacedEntries) { - if (!newEntries.contains(e) && !removedEntries.contains(e)) { - removedEntries.add(e); + if (!newEntries.contains(e)) { + e.requestShutdown(); } } - - removedEntries.forEach(Entry::requestShutdown); } /** From c93daf4ac892d0db324abb506ce42b9a90d2823f Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 23 Dec 2021 16:37:44 -0500 Subject: [PATCH 07/17] fix warnings in test --- .../google/api/gax/grpc/ChannelPoolTest.java | 79 +++++++------------ 1 file changed, 27 insertions(+), 52 deletions(-) diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 53a1447b2..83de7e77e 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -57,7 +57,6 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mockito; -import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @RunWith(JUnit4.class) @@ -129,13 +128,9 @@ public void ensureEvenDistribution() throws InterruptedException, IOException { channels[i] = Mockito.mock(ManagedChannel.class); Mockito.when(channels[i].newCall(methodDescriptor, callOptions)) .thenAnswer( - new Answer>() { - @Override - public ClientCall answer(InvocationOnMock invocationOnMock) - throws Throwable { - counts[index].incrementAndGet(); - return clientCall; - } + (ignored) -> { + counts[index].incrementAndGet(); + return clientCall; }); } @@ -148,12 +143,9 @@ public ClientCall answer(InvocationOnMock invocationOnMock) ExecutorService executor = Executors.newFixedThreadPool(numThreads); for (int i = 0; i < numThreads; i++) { executor.submit( - new Runnable() { - @Override - public void run() { - for (int j = 0; j < numPerThread; j++) { - pool.newCall(methodDescriptor, callOptions); - } + () -> { + for (int j = 0; j < numPerThread; j++) { + pool.newCall(methodDescriptor, callOptions); } }); } @@ -195,7 +187,7 @@ public void channelPrimerIsCalledPeriodically() throws IOException { Answer extractChannelRefresher = invocation -> { - channelRefreshers.add((Runnable) invocation.getArgument(0)); + channelRefreshers.add(invocation.getArgument(0)); return Mockito.mock(ScheduledFuture.class); }; @@ -234,7 +226,7 @@ public void channelPrimerIsCalledPeriodically() throws IOException { } // ---- - // call should be allowed to complete and the channel should not shutdown + // call should be allowed to complete and the channel should not be shutdown @Test public void callShouldCompleteAfterCreation() throws IOException { final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); @@ -251,13 +243,10 @@ public void callShouldCompleteAfterCreation() throws IOException { Mockito.>any(), Mockito.any(CallOptions.class))) .thenReturn(spyClientCall); - Answer verifyChannelNotShutdown = - new Answer() { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - } + Answer verifyChannelNotShutdown = + invocation -> { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); }; // verify that underlying channel is not shutdown when clientCall is still sending message @@ -267,7 +256,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { @SuppressWarnings("unchecked") ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); ClientCall call = - pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); pool.refresh(); // shutdown is not called because there is still an outstanding call, even if it hasn't started @@ -285,7 +274,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { Mockito.verify(replacementChannel, Mockito.never()).newCall(Mockito.any(), Mockito.any()); } - // call should be allowed to complete and the channel should not shutdown + // call should be allowed to complete and the channel should not be shutdown @Test public void callShouldCompleteAfterStarted() throws IOException { final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); @@ -303,13 +292,10 @@ public void callShouldCompleteAfterStarted() throws IOException { Mockito.>any(), Mockito.any(CallOptions.class))) .thenReturn(spyClientCall); - Answer verifyChannelNotShutdown = - new Answer() { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - } + Answer verifyChannelNotShutdown = + invocation -> { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); }; // verify that underlying channel is not shutdown when clientCall is still sending message @@ -319,7 +305,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { @SuppressWarnings("unchecked") ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); ClientCall call = - pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); // start clientCall call.start(listener, new Metadata()); @@ -333,7 +319,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); } - // Channel should shutdown after a refresh all the calls have completed + // Channel should be shutdown after a refresh all the calls have completed @Test public void channelShouldShutdown() throws IOException { final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); @@ -351,13 +337,10 @@ public void channelShouldShutdown() throws IOException { Mockito.>any(), Mockito.any(CallOptions.class))) .thenReturn(spyClientCall); - Answer verifyChannelNotShutdown = - new Answer() { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); - return invocation.callRealMethod(); - } + Answer verifyChannelNotShutdown = + invocation -> { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); }; // verify that underlying channel is not shutdown when clientCall is still sending message @@ -367,7 +350,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { @SuppressWarnings("unchecked") ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); ClientCall call = - pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); + pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); // start clientCall call.start(listener, new Metadata()); @@ -385,19 +368,11 @@ public void channelRefreshShouldSwapChannels() throws IOException { ManagedChannel underlyingChannel1 = Mockito.mock(ManagedChannel.class); ManagedChannel underlyingChannel2 = Mockito.mock(ManagedChannel.class); - // mock executor service to capture the runnable scheduled so we can invoke it when we want to + // mock executor service to capture the runnable scheduled, so we can invoke it when we want to ScheduledExecutorService scheduledExecutorService = Mockito.mock(ScheduledExecutorService.class); - final List channelRefreshers = new ArrayList<>(); - Answer extractChannelRefresher = - new Answer() { - public Object answer(InvocationOnMock invocation) { - channelRefreshers.add(invocation.getArgument(0)); - return null; - } - }; - Mockito.doAnswer(extractChannelRefresher) + Mockito.doReturn(null) .when(scheduledExecutorService) .schedule( Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); From 0062b0f595c9eb1334ee9b33823d7e09869c4ae9 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 23 Dec 2021 16:46:45 -0500 Subject: [PATCH 08/17] Update gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java Co-authored-by: Chanseok Oh --- .../src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 83de7e77e..f0c6abbdf 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -229,7 +229,7 @@ public void channelPrimerIsCalledPeriodically() throws IOException { // call should be allowed to complete and the channel should not be shutdown @Test public void callShouldCompleteAfterCreation() throws IOException { - final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); From cc0b50795bc5d456705fc987fa797a5614060a63 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 23 Dec 2021 16:47:36 -0500 Subject: [PATCH 09/17] Update gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java Co-authored-by: Chanseok Oh --- .../src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index f0c6abbdf..f9a43acce 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -322,7 +322,7 @@ public void callShouldCompleteAfterStarted() throws IOException { // Channel should be shutdown after a refresh all the calls have completed @Test public void channelShouldShutdown() throws IOException { - final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); FakeChannelFactory channelFactory = From 7bf8b6df381e949d3ee3f3a20c2b16072df2f7df Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 6 Jan 2022 12:12:33 -0500 Subject: [PATCH 10/17] handle race condition --- .../com/google/api/gax/grpc/ChannelPool.java | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 4ebc3782c..7d7c837a7 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -75,6 +75,7 @@ class ChannelPool extends ManagedChannel { private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50); private static final double JITTER_PERCENTAGE = 0.15; + private final Object entryWriteLock = new Object(); private final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; @@ -278,23 +279,31 @@ private ScheduledFuture scheduleNextRefresh() { */ @InternalApi("Visible for testing") void refresh() { - ArrayList newEntries = new ArrayList<>(entries.get()); - - for (int i = 0; i < newEntries.size(); i++) { - try { - newEntries.set(i, new Entry(channelFactory.createSingleChannel())); - } catch (IOException e) { - LOG.log(Level.WARNING, "Failed to refresh channel, leaving old channel", e); + // Note: synchronization is necessary in case refresh is called concurrently: + // - thread1 fails to replace a single entry + // - thread2 succeeds replacing an entry + // - thread1 loses the race to replace the list + // - then thread2 will shut down channel that thread1 will put back into circulation (after it + // replaces the list) + synchronized (entryWriteLock) { + ArrayList newEntries = new ArrayList<>(entries.get()); + + for (int i = 0; i < newEntries.size(); i++) { + try { + newEntries.set(i, new Entry(channelFactory.createSingleChannel())); + } catch (IOException e) { + LOG.log(Level.WARNING, "Failed to refresh channel, leaving old channel", e); + } } - } - ImmutableList replacedEntries = entries.getAndSet(ImmutableList.copyOf(newEntries)); + ImmutableList replacedEntries = entries.getAndSet(ImmutableList.copyOf(newEntries)); - // Shutdown the channels that were cycled out. This will either be the channels we just - // refreshed or in case of a race, the channels that the other thread set. - for (Entry e : replacedEntries) { - if (!newEntries.contains(e)) { - e.requestShutdown(); + // Shutdown the channels that were cycled out. This will either be the channels we just + // refreshed or in case of a race, the channels that the other thread set. + for (Entry e : replacedEntries) { + if (!newEntries.contains(e)) { + e.requestShutdown(); + } } } } From d2d7830284be3a523526c98a7ac8af459cd21247 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 6 Jan 2022 14:05:56 -0500 Subject: [PATCH 11/17] Update gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java Co-authored-by: Chanseok Oh --- .../src/main/java/com/google/api/gax/grpc/ChannelPool.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 7d7c837a7..efd46f04a 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -298,8 +298,7 @@ void refresh() { ImmutableList replacedEntries = entries.getAndSet(ImmutableList.copyOf(newEntries)); - // Shutdown the channels that were cycled out. This will either be the channels we just - // refreshed or in case of a race, the channels that the other thread set. + // Shutdown the channels that were cycled out. for (Entry e : replacedEntries) { if (!newEntries.contains(e)) { e.requestShutdown(); From 4ebbebe038cc7e45a33b130266465ff28ec55faf Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Mon, 22 Nov 2021 14:40:55 -0500 Subject: [PATCH 12/17] introduce dynamic channel pool --- .../com/google/api/gax/grpc/ChannelPool.java | 252 +++++++++++++----- .../api/gax/grpc/ChannelPoolSettings.java | 180 +++++++++++++ .../InstantiatingGrpcChannelProvider.java | 96 ++++--- .../google/api/gax/grpc/ChannelPoolTest.java | 211 +++++++++++++-- .../InstantiatingGrpcChannelProviderTest.java | 9 - .../InstantiatingHttpJsonChannelProvider.java | 6 +- .../rpc/FixedTransportChannelProvider.java | 6 +- .../api/gax/rpc/TransportChannelProvider.java | 16 +- 8 files changed, 625 insertions(+), 151 deletions(-) create mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index efd46f04a..f11514989 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -31,6 +31,7 @@ import com.google.api.core.InternalApi; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import io.grpc.CallOptions; import io.grpc.Channel; @@ -46,14 +47,12 @@ import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.Nullable; import org.threeten.bp.Duration; /** @@ -68,22 +67,16 @@ */ class ChannelPool extends ManagedChannel { private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName()); - - // size greater than 1 to allow multiple channel to refresh at the same time - // size not too large so refreshing channels doesn't use too many threads - private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2; private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50); - private static final double JITTER_PERCENTAGE = 0.15; + + private final ChannelPoolSettings settings; + private final ChannelFactory channelFactory; + private final ScheduledExecutorService executor; private final Object entryWriteLock = new Object(); private final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; - // if set, ChannelPool will manage the life cycle of channelRefreshExecutorService - @Nullable private final ScheduledExecutorService channelRefreshExecutorService; - private final ChannelFactory channelFactory; - - private volatile ScheduledFuture nextScheduledRefresh = null; /** * Factory method to create a non-refreshing channel pool @@ -92,8 +85,9 @@ class ChannelPool extends ManagedChannel { * @param channelFactory method to create the channels * @return ChannelPool of non-refreshing channels */ + @VisibleForTesting static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IOException { - return new ChannelPool(channelFactory, poolSize, null); + return new ChannelPool(ChannelPoolSettings.staticallySized(poolSize), channelFactory, null); } /** @@ -103,58 +97,66 @@ static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IO * * @param poolSize number of channels in the pool * @param channelFactory method to create the channels - * @param channelRefreshExecutorService periodically refreshes the channels; its life cycle will - * be managed by ChannelPool + * @param executor used to schedule maintenance tasks like refresh channels and resizing the pool. * @return ChannelPool of refreshing channels */ @VisibleForTesting static ChannelPool createRefreshing( - int poolSize, - ChannelFactory channelFactory, - ScheduledExecutorService channelRefreshExecutorService) + int poolSize, ChannelFactory channelFactory, ScheduledExecutorService executor) throws IOException { - return new ChannelPool(channelFactory, poolSize, channelRefreshExecutorService); + return new ChannelPool( + ChannelPoolSettings.staticallySized(poolSize) + .toBuilder() + .setPreemptiveReconnectEnabled(true) + .build(), + channelFactory, + executor); } - /** - * Factory method to create a refreshing channel pool - * - * @param poolSize number of channels in the pool - * @param channelFactory method to create the channels - * @return ChannelPool of refreshing channels - */ - static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory) + static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFactory) throws IOException { - return createRefreshing( - poolSize, channelFactory, Executors.newScheduledThreadPool(CHANNEL_REFRESH_EXECUTOR_SIZE)); + return new ChannelPool(settings, channelFactory, Executors.newSingleThreadScheduledExecutor()); } /** * Initializes the channel pool. Assumes that all channels have the same authority. * + * @param settings options for controling the ChannelPool sizing behavior * @param channelFactory method to create the channels - * @param poolSize number of channels in the pool - * @param channelRefreshExecutorService periodically refreshes the channels + * @param executor periodically refreshes the channels. Must be single threaded */ - private ChannelPool( + @InternalApi("VisibleForTesting") + ChannelPool( + ChannelPoolSettings settings, ChannelFactory channelFactory, - int poolSize, - @Nullable ScheduledExecutorService channelRefreshExecutorService) + ScheduledExecutorService executor) throws IOException { + this.settings = settings; this.channelFactory = channelFactory; ImmutableList.Builder initialListBuilder = ImmutableList.builder(); - for (int i = 0; i < poolSize; i++) { + for (int i = 0; i < settings.getInitialChannelCount(); i++) { initialListBuilder.add(new Entry(channelFactory.createSingleChannel())); } entries.set(initialListBuilder.build()); authority = entries.get().get(0).channel.authority(); - this.channelRefreshExecutorService = channelRefreshExecutorService; - - if (channelRefreshExecutorService != null) { - nextScheduledRefresh = scheduleNextRefresh(); + this.executor = executor; + + if (!settings.isStaticSize()) { + executor.scheduleAtFixedRate( + this::resizeSafely, + ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(), + ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(), + TimeUnit.SECONDS); + } + if (settings.isPreemptiveReconnectEnabled()) { + executor.scheduleAtFixedRate( + this::refreshSafely, + REFRESH_PERIOD.getSeconds(), + REFRESH_PERIOD.getSeconds(), + TimeUnit.SECONDS); } } @@ -187,12 +189,9 @@ public ManagedChannel shutdown() { for (Entry entry : localEntries) { entry.channel.shutdown(); } - if (nextScheduledRefresh != null) { - nextScheduledRefresh.cancel(true); - } - if (channelRefreshExecutorService != null) { + if (executor != null) { // shutdownNow will cancel scheduled tasks - channelRefreshExecutorService.shutdownNow(); + executor.shutdownNow(); } return this; } @@ -206,7 +205,7 @@ public boolean isShutdown() { return false; } } - return channelRefreshExecutorService == null || channelRefreshExecutorService.isShutdown(); + return executor == null || executor.isShutdown(); } /** {@inheritDoc} */ @@ -218,7 +217,8 @@ public boolean isTerminated() { return false; } } - return channelRefreshExecutorService == null || channelRefreshExecutorService.isTerminated(); + + return executor == null || executor.isTerminated(); } /** {@inheritDoc} */ @@ -228,11 +228,8 @@ public ManagedChannel shutdownNow() { for (Entry entry : localEntries) { entry.channel.shutdownNow(); } - if (nextScheduledRefresh != null) { - nextScheduledRefresh.cancel(true); - } - if (channelRefreshExecutorService != null) { - channelRefreshExecutorService.shutdownNow(); + if (executor != null) { + executor.shutdownNow(); } return this; } @@ -249,25 +246,129 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE } entry.channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); } - if (channelRefreshExecutorService != null) { + if (executor != null) { long awaitTimeNanos = endTimeNanos - System.nanoTime(); - channelRefreshExecutorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); + executor.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); } return isTerminated(); } - /** Scheduling loop. */ - private ScheduledFuture scheduleNextRefresh() { - long delayPeriod = REFRESH_PERIOD.toMillis(); - long jitter = (long) ((Math.random() - 0.5) * JITTER_PERCENTAGE * delayPeriod); - long delay = jitter + delayPeriod; - return channelRefreshExecutorService.schedule( - () -> { - scheduleNextRefresh(); - refresh(); - }, - delay, - TimeUnit.MILLISECONDS); + void resizeSafely() { + try { + synchronized (entryWriteLock) { + resize(); + } + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to resize channel pool", e); + } + } + + /** + * Resize the number of channels based on the number of outstanding RPCs. + * + *

This method is expected to be called on a fixed interval. On every invocation it will: + * + *

    + *
  • Get the maximum number of outstanding RPCs since last invocation + *
  • Determine a valid range of number of channels to handle that many outstanding RPCs + *
  • If the current number of channel falls outside of that range, add or remove at most + * {@link ChannelPoolSettings#MAX_RESIZE_DELTA} to get closer to middle of that range. + *
+ * + *

Not threadsafe, must be called under the entryWriteLock monitor + */ + void resize() { + List localEntries = entries.get(); + // Estimate the peak of RPCs in the last interval by summing the peak of RPCs per channel + int actualOutstandingRpcs = + localEntries.stream().mapToInt(Entry::getAndResetMaxOutstanding).sum(); + + // Number of channels if each channel operated at max capacity + int minChannels = + (int) Math.ceil(actualOutstandingRpcs / (double) settings.getMaxRpcsPerChannel()); + // Limit the threshold to absolute range + if (minChannels < settings.getMinChannelCount()) { + minChannels = settings.getMinChannelCount(); + } + + // Number of channels if each channel operated at minimum capacity + int maxChannels = + (int) Math.ceil(actualOutstandingRpcs / (double) settings.getMinRpcsPerChannel()); + // Limit the threshold to absolute range + if (maxChannels > settings.getMaxChannelCount()) { + maxChannels = settings.getMaxChannelCount(); + } + if (maxChannels < minChannels) { + maxChannels = minChannels; + } + + // If the pool were to be resized, try to aim for the middle of the bound, but limit rate of + // change. + int tentativeTarget = (maxChannels + minChannels) / 2; + int currentSize = localEntries.size(); + int delta = tentativeTarget - currentSize; + int dampenedTarget = tentativeTarget; + if (Math.abs(delta) > ChannelPoolSettings.MAX_RESIZE_DELTA) { + dampenedTarget = + currentSize + (int) Math.copySign(ChannelPoolSettings.MAX_RESIZE_DELTA, delta); + } + + // Only resize the pool when thresholds are crossed + if (localEntries.size() < minChannels) { + LOG.fine( + String.format( + "Detected throughput peak of %d, expanding channel pool size: %d -> %d.", + actualOutstandingRpcs, currentSize, dampenedTarget)); + + expand(tentativeTarget); + } else if (localEntries.size() > maxChannels) { + LOG.fine( + String.format( + "Detected throughput drop to %d, shrinking channel pool size: %d -> %d.", + actualOutstandingRpcs, currentSize, dampenedTarget)); + + shrink(tentativeTarget); + } + } + + /** Not threadsafe, must be called under the entryWriteLock monitor */ + private void shrink(int desiredSize) { + List localEntries = entries.get(); + Preconditions.checkState( + localEntries.size() >= desiredSize, "desired size is already smaller than the current"); + + // Set the new list + entries.set(ImmutableList.copyOf(localEntries.subList(0, desiredSize))); + // clean up removed entries + List removed = localEntries.subList(desiredSize, localEntries.size()); + removed.forEach(Entry::requestShutdown); + } + + /** Not threadsafe, must be called under the entryWriteLock monitor */ + private void expand(int desiredSize) { + List localEntries = entries.get(); + Preconditions.checkState( + localEntries.size() <= desiredSize, "desired size is already bigger than the current"); + + ImmutableList.Builder newEntries = ImmutableList.builder().addAll(localEntries); + + for (int i = 0; i < desiredSize - localEntries.size(); i++) { + try { + newEntries.add(new Entry(channelFactory.createSingleChannel())); + } catch (IOException e) { + LOG.log(Level.WARNING, "Failed to add channel", e); + } + } + + entries.set(newEntries.build()); + } + + private void refreshSafely() { + try { + refresh(); + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to pre-emptively refresh channnels", e); + } } /** @@ -340,7 +441,13 @@ Entry getRetainedEntry(int affinity) { private Entry getEntry(int affinity) { List localEntries = entries.get(); - int index = Math.abs(affinity % localEntries.size()); + int index = affinity % localEntries.size(); + index = Math.abs(index); + // If index is the most negative int, abs(index) is still negative. + if (index < 0) { + index = 0; + } + return localEntries.get(index); } @@ -348,6 +455,7 @@ private Entry getEntry(int affinity) { private static class Entry { private final ManagedChannel channel; private final AtomicInteger outstandingRpcs = new AtomicInteger(0); + private final AtomicInteger maxOutstanding = new AtomicInteger(); // Flag that the channel should be closed once all of the outstanding RPC complete. private final AtomicBoolean shutdownRequested = new AtomicBoolean(); @@ -358,6 +466,10 @@ private Entry(ManagedChannel channel) { this.channel = channel; } + int getAndResetMaxOutstanding() { + return maxOutstanding.getAndSet(outstandingRpcs.get()); + } + /** * Try to increment the outstanding RPC count. The method will return false if the channel is * closing and the caller should pick a different channel. If the method returned true, the @@ -366,7 +478,13 @@ private Entry(ManagedChannel channel) { */ private boolean retain() { // register desire to start RPC - outstandingRpcs.incrementAndGet(); + int currentOutstanding = outstandingRpcs.incrementAndGet(); + + // Rough book keeping + int prevMax = maxOutstanding.get(); + if (currentOutstanding > prevMax) { + maxOutstanding.incrementAndGet(); + } // abort if the channel is closing if (shutdownRequested.get()) { diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java new file mode 100644 index 000000000..383be15ef --- /dev/null +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java @@ -0,0 +1,180 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.api.gax.grpc; + +import com.google.api.core.BetaApi; +import com.google.auto.value.AutoValue; +import com.google.common.base.Preconditions; +import java.time.Duration; + +/** + * Settings to control {@link ChannelPool} behavior. + * + *

To facilitate low latency/high throughout applications, gax provides a {@link ChannelPool}. + * The pool is meant to facilitate high throughput/low latency clients. By splitting load across + * multiple gRPC channels the client can spread load across multiple frontends and overcome gRPC's + * limit of 100 concurrent RPCs per channel. However oversizing the {@link ChannelPool} can lead to + * underutilized channels which will lead to high tail latency due to GFEs disconnecting idle + * channels. + * + *

The {@link ChannelPool} is designed to adapt to varying traffic patterns by tracking + * outstanding RPCs and resizing the pool size. This class configures the behavior. In general + * clients should aim to have less than 50 concurrent RPCs per channel and at least 1 outstanding + * per channel per minute. + * + *

The settings in this class will be applied every minute. + */ +@BetaApi("surface for channel pool sizing is not yet stable") +@AutoValue +public abstract class ChannelPoolSettings { + /** How often to check and possibly resize the {@link ChannelPool}. */ + static final Duration RESIZE_INTERVAL = Duration.ofMinutes(1); + /** The maximum number of channels that can be added or removed at a time. */ + static final int MAX_RESIZE_DELTA = 2; + + /** + * Threshold to start scaling down the channel pool. + * + *

When the average of the maximum number of outstanding RPCs in a single minute drop below + * this threshold, channels will be removed from the pool. + */ + public abstract int getMinRpcsPerChannel(); + + /** + * Threshold to start scaling up the channel pool. + * + *

When the average of the maximum number of outstanding RPCs in a single minute surpass this + * threshold, channels will be added to the pool. + */ + public abstract int getMaxRpcsPerChannel(); + + /** + * The absolute minimum size of the channel pool. + * + *

Regardless of the current throughput, the number of channels will not drop below this limit + */ + public abstract int getMinChannelCount(); + + /** + * The absolute maximum size of the channel pool. + * + *

Regardless of the current throughput, the number of channels will not exceed this limit + */ + public abstract int getMaxChannelCount(); + + /** + * The initial size of the channel pool. + * + *

During client construction the client open this many connections. This will be scaled up or + * down in the next period. + */ + public abstract int getInitialChannelCount(); + + /** + * If all of the channels should be replaced on an hourly basis. + * + *

The GFE will forcibly disconnect active channels after an hour. To minimize the cost of + * reconnects, this will create a new channel asynchronuously, prime it and then swap it with an + * old channel. + */ + public abstract boolean isPreemptiveReconnectEnabled(); + + /** Helper to check if the {@link ChannelPool} implementation can skip dynamic size logic */ + boolean isStaticSize() { + // When range is restricted to a single size + if (getMinChannelCount() == getMaxChannelCount()) { + return true; + } + // When the scaling threshold are not set + if (getMinRpcsPerChannel() == 0 && getMaxRpcsPerChannel() == Integer.MAX_VALUE) { + return true; + } + + return false; + } + + public abstract Builder toBuilder(); + + public static ChannelPoolSettings staticallySized(int size) { + return builder() + .setInitialChannelCount(size) + .setMinRpcsPerChannel(0) + .setMaxRpcsPerChannel(Integer.MAX_VALUE) + .setMinChannelCount(size) + .setMaxChannelCount(size) + .build(); + } + + public static Builder builder() { + return new AutoValue_ChannelPoolSettings.Builder() + .setInitialChannelCount(1) + .setMinChannelCount(1) + .setMaxChannelCount(200) + .setMinRpcsPerChannel(0) + .setMaxRpcsPerChannel(Integer.MAX_VALUE) + .setPreemptiveReconnectEnabled(false); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMinRpcsPerChannel(int count); + + public abstract Builder setMaxRpcsPerChannel(int count); + + public abstract Builder setMinChannelCount(int count); + + public abstract Builder setMaxChannelCount(int count); + + public abstract Builder setInitialChannelCount(int count); + + public abstract Builder setPreemptiveReconnectEnabled(boolean enabled); + + abstract ChannelPoolSettings autoBuild(); + + public ChannelPoolSettings build() { + ChannelPoolSettings s = autoBuild(); + + Preconditions.checkState( + s.getMinRpcsPerChannel() <= s.getMaxRpcsPerChannel(), "rpcsPerChannel range is invalid"); + Preconditions.checkState( + s.getMinChannelCount() > 0, "Minimum channel count must be at least 1"); + Preconditions.checkState( + s.getMinChannelCount() <= s.getMaxRpcsPerChannel(), "absolute channel range is invalid"); + Preconditions.checkState( + s.getMinChannelCount() <= s.getInitialChannelCount() + && s.getInitialChannelCount() <= s.getMaxChannelCount(), + "initial channel count must be with the absolute channel count range"); + Preconditions.checkState( + s.getInitialChannelCount() > 0, "Initial channel count must be greater than 0"); + return s; + } + } +} diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 1aeb6b179..ce2876785 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -42,7 +42,6 @@ import com.google.auth.Credentials; import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -102,7 +101,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final Duration keepAliveTime; @Nullable private final Duration keepAliveTimeout; @Nullable private final Boolean keepAliveWithoutCalls; - @Nullable private final Integer poolSize; + private final ChannelPoolSettings channelPoolSettings; @Nullable private final Credentials credentials; @Nullable private final ChannelPrimer channelPrimer; @Nullable private final Boolean attemptDirectPath; @@ -126,7 +125,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.keepAliveTime = builder.keepAliveTime; this.keepAliveTimeout = builder.keepAliveTimeout; this.keepAliveWithoutCalls = builder.keepAliveWithoutCalls; - this.poolSize = builder.poolSize; + this.channelPoolSettings = builder.channelPoolSettings; this.channelConfigurator = builder.channelConfigurator; this.credentials = builder.credentials; this.channelPrimer = builder.channelPrimer; @@ -195,16 +194,17 @@ public TransportChannelProvider withEndpoint(String endpoint) { return toBuilder().setEndpoint(endpoint).build(); } + /** @deprecated Please modify pool settings via {@link #toBuilder()} */ + @Deprecated @Override - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") public boolean acceptsPoolSize() { - return poolSize == null; + return true; } + /** @deprecated Please modify pool settings via {@link #toBuilder()} */ + @Deprecated @Override - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") public TransportChannelProvider withPoolSize(int size) { - Preconditions.checkState(acceptsPoolSize(), "pool size already set to %s", poolSize); return toBuilder().setPoolSize(size).build(); } @@ -230,26 +230,9 @@ public TransportChannel getTransportChannel() throws IOException { } private TransportChannel createChannel() throws IOException { - - int realPoolSize = MoreObjects.firstNonNull(poolSize, 1); - ChannelFactory channelFactory = - new ChannelFactory() { - @Override - public ManagedChannel createSingleChannel() throws IOException { - try { - return InstantiatingGrpcChannelProvider.this.createSingleChannel(); - } catch (GeneralSecurityException e) { - throw new IOException(e); - } - } - }; - ManagedChannel outerChannel; - if (channelPrimer != null) { - outerChannel = ChannelPool.createRefreshing(realPoolSize, channelFactory); - } else { - outerChannel = ChannelPool.create(realPoolSize, channelFactory); - } - return GrpcTransportChannel.create(outerChannel); + return GrpcTransportChannel.create( + ChannelPool.create( + channelPoolSettings, InstantiatingGrpcChannelProvider.this::createSingleChannel)); } // TODO(mohanli): Use attemptDirectPath as the only indicator once setAttemptDirectPath is adapted @@ -316,7 +299,7 @@ ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSec return null; } - private ManagedChannel createSingleChannel() throws IOException, GeneralSecurityException { + private ManagedChannel createSingleChannel() throws IOException { GrpcHeaderInterceptor headerInterceptor = new GrpcHeaderInterceptor(headerProvider.getHeaders()); GrpcMetadataHandlerInterceptor metadataHandlerInterceptor = @@ -350,7 +333,12 @@ && isOnComputeEngine()) { builder.keepAliveTime(DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS, TimeUnit.SECONDS); builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS); } else { - ChannelCredentials channelCredentials = createMtlsChannelCredentials(); + ChannelCredentials channelCredentials; + try { + channelCredentials = createMtlsChannelCredentials(); + } catch (GeneralSecurityException e) { + throw new IOException(e); + } if (channelCredentials != null) { builder = Grpc.newChannelBuilder(endpoint, channelCredentials); } else { @@ -439,7 +427,7 @@ public static Builder newBuilder() { } public static final class Builder { - private int processorCount; + @Deprecated private int processorCount; private Executor executor; private HeaderProvider headerProvider; private String endpoint; @@ -451,10 +439,10 @@ public static final class Builder { @Nullable private Duration keepAliveTime; @Nullable private Duration keepAliveTimeout; @Nullable private Boolean keepAliveWithoutCalls; - @Nullable private Integer poolSize; @Nullable private ApiFunction channelConfigurator; @Nullable private Credentials credentials; @Nullable private ChannelPrimer channelPrimer; + private ChannelPoolSettings channelPoolSettings; @Nullable private Boolean attemptDirectPath; @Nullable private Boolean allowNonDefaultServiceAccount; @Nullable private ImmutableMap directPathServiceConfig; @@ -462,6 +450,7 @@ public static final class Builder { private Builder() { processorCount = Runtime.getRuntime().availableProcessors(); envProvider = DirectPathEnvironmentProvider.getInstance(); + channelPoolSettings = ChannelPoolSettings.staticallySized(1); } private Builder(InstantiatingGrpcChannelProvider provider) { @@ -476,17 +465,23 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.keepAliveTime = provider.keepAliveTime; this.keepAliveTimeout = provider.keepAliveTimeout; this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls; - this.poolSize = provider.poolSize; this.channelConfigurator = provider.channelConfigurator; this.credentials = provider.credentials; this.channelPrimer = provider.channelPrimer; + this.channelPoolSettings = provider.channelPoolSettings; this.attemptDirectPath = provider.attemptDirectPath; this.allowNonDefaultServiceAccount = provider.allowNonDefaultServiceAccount; this.directPathServiceConfig = provider.directPathServiceConfig; this.mtlsProvider = provider.mtlsProvider; } - /** Sets the number of available CPUs, used internally for testing. */ + /** + * Sets the number of available CPUs, used internally for testing. + * + * @deprecated CPU based channel scaling is deprecated, please use RPC based scaling instead via + * {@link Builder#setChannelPoolSettings(ChannelPoolSettings)} + */ + @Deprecated Builder setProcessorCount(int processorCount) { this.processorCount = processorCount; return this; @@ -611,34 +606,27 @@ public Boolean getKeepAliveWithoutCalls() { return keepAliveWithoutCalls; } - /** - * Number of underlying grpc channels to open. Calls will be load balanced round robin across - * them. - */ + /** @deprecated Please use {@link #setChannelPoolSettings(ChannelPoolSettings)} */ + @Deprecated public int getPoolSize() { - if (poolSize == null) { - return 1; - } - return poolSize; + return channelPoolSettings.getInitialChannelCount(); } - /** - * Number of underlying grpc channels to open. Calls will be load balanced round robin across - * them - */ + /** @deprecated Please use {@link #setChannelPoolSettings(ChannelPoolSettings)} */ + @Deprecated public Builder setPoolSize(int poolSize) { - Preconditions.checkArgument(poolSize > 0, "Pool size must be positive"); - Preconditions.checkArgument( - poolSize <= MAX_POOL_SIZE, "Pool size must be less than %s", MAX_POOL_SIZE); - this.poolSize = poolSize; + channelPoolSettings = ChannelPoolSettings.staticallySized(poolSize); return this; } - /** Sets the number of channels relative to the available CPUs. */ + /** @deprecated Please use {@link #setChannelPoolSettings(ChannelPoolSettings)} */ + @Deprecated public Builder setChannelsPerCpu(double multiplier) { return setChannelsPerCpu(multiplier, 100); } + /** @deprecated Please use {@link #setChannelPoolSettings(ChannelPoolSettings)} */ + @Deprecated public Builder setChannelsPerCpu(double multiplier, int maxChannels) { Preconditions.checkArgument(multiplier > 0, "multiplier must be positive"); Preconditions.checkArgument(maxChannels > 0, "maxChannels must be positive"); @@ -647,7 +635,13 @@ public Builder setChannelsPerCpu(double multiplier, int maxChannels) { if (channelCount > maxChannels) { channelCount = maxChannels; } - return setPoolSize(channelCount); + return setChannelPoolSettings(ChannelPoolSettings.staticallySized(channelCount)); + } + + @BetaApi("Channel pool sizing api is not yet stable") + public Builder setChannelPoolSettings(ChannelPoolSettings settings) { + this.channelPoolSettings = settings; + return this; } public Builder setCredentials(Credentials credentials) { diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index f9a43acce..90ba9b59b 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -29,12 +29,13 @@ */ package com.google.api.gax.grpc; +import static com.google.common.truth.Truth.assertThat; + import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeMethodDescriptor; import com.google.api.gax.grpc.testing.FakeServiceGrpc; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import com.google.common.truth.Truth; import com.google.type.Color; import com.google.type.Money; import io.grpc.CallOptions; @@ -43,6 +44,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.stub.ClientCalls; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -56,6 +58,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.mockito.stubbing.Answer; @@ -69,7 +72,7 @@ public void testAuthority() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2))); - Truth.assertThat(pool.authority()).isEqualTo("myAuth"); + assertThat(pool.authority()).isEqualTo("myAuth"); } @Test @@ -151,11 +154,11 @@ public void ensureEvenDistribution() throws InterruptedException, IOException { } executor.shutdown(); boolean shutdown = executor.awaitTermination(1, TimeUnit.MINUTES); - Truth.assertThat(shutdown).isTrue(); + assertThat(shutdown).isTrue(); int expectedCount = (numThreads * numPerThread) / numChannels; for (AtomicInteger count : counts) { - Truth.assertThat(count.get()).isAnyOf(expectedCount, expectedCount + 1); + assertThat(count.get()).isAnyOf(expectedCount, expectedCount + 1); } } @@ -193,8 +196,8 @@ public void channelPrimerIsCalledPeriodically() throws IOException { Mockito.doAnswer(extractChannelRefresher) .when(scheduledExecutorService) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + .scheduleAtFixedRate( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.anyLong(), Mockito.any()); FakeChannelFactory channelFactory = new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer); @@ -203,26 +206,16 @@ public void channelPrimerIsCalledPeriodically() throws IOException { // 1 call during the creation Mockito.verify(mockChannelPrimer, Mockito.times(1)) .primeChannel(Mockito.any(ManagedChannel.class)); - Mockito.verify(scheduledExecutorService, Mockito.times(1)) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); channelRefreshers.get(0).run(); // 1 more call during channel refresh Mockito.verify(mockChannelPrimer, Mockito.times(2)) .primeChannel(Mockito.any(ManagedChannel.class)); - Mockito.verify(scheduledExecutorService, Mockito.times(2)) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); channelRefreshers.get(0).run(); // 1 more call during channel refresh Mockito.verify(mockChannelPrimer, Mockito.times(3)) .primeChannel(Mockito.any(ManagedChannel.class)); - Mockito.verify(scheduledExecutorService, Mockito.times(3)) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); - scheduledExecutorService.shutdown(); } // ---- @@ -395,4 +388,190 @@ public void channelRefreshShouldSwapChannels() throws IOException { Mockito.verify(underlyingChannel2, Mockito.only()) .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); } + + @Test + public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() throws Exception { + ScheduledExecutorService executor = Mockito.mock(ScheduledExecutorService.class); + + List channels = new ArrayList<>(); + List> startedCalls = new ArrayList<>(); + + ChannelFactory channelFactory = + () -> { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + Mockito.when(channel.newCall(Mockito.any(), Mockito.any())) + .thenAnswer( + (Answer>) + invocation -> { + @SuppressWarnings("unchecked") + ClientCall clientCall = Mockito.mock(ClientCall.class); + startedCalls.add(clientCall); + return clientCall; + }); + + channels.add(channel); + return channel; + }; + + ChannelPool pool = + new ChannelPool( + ChannelPoolSettings.builder() + .setInitialChannelCount(2) + .setMinRpcsPerChannel(1) + .setMaxRpcsPerChannel(2) + .build(), + channelFactory, + executor); + assertThat(pool.entries.get()).hasSize(2); + + // Start the minimum number of + for (int i = 0; i < 2; i++) { + ClientCalls.futureUnaryCall( + pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT), + Color.getDefaultInstance()); + } + pool.resize(); + assertThat(pool.entries.get()).hasSize(2); + + // Add enough RPCs to be just at the brink of expansion + for (int i = startedCalls.size(); i < 4; i++) { + ClientCalls.futureUnaryCall( + pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT), + Color.getDefaultInstance()); + } + pool.resize(); + assertThat(pool.entries.get()).hasSize(2); + + // Add another RPC to push expansion + pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT); + pool.resize(); + assertThat(pool.entries.get()).hasSize(4); // += ChannelPool::MAX_RESIZE_DELTA + assertThat(startedCalls).hasSize(5); + + // Complete RPCs to the brink of shrinking + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + Mockito.verify(startedCalls.remove(0)).start(captor.capture(), Mockito.any()); + captor.getValue().onClose(Status.ABORTED, new Metadata()); + // Resize twice: the first round maintains the peak from the last cycle + pool.resize(); + pool.resize(); + assertThat(pool.entries.get()).hasSize(4); + assertThat(startedCalls).hasSize(4); + + // Complete another RPC to trigger shrinking + Mockito.verify(startedCalls.remove(0)).start(captor.capture(), Mockito.any()); + captor.getValue().onClose(Status.ABORTED, new Metadata()); + // Resize twice: the first round maintains the peak from the last cycle + pool.resize(); + pool.resize(); + assertThat(startedCalls).hasSize(3); + // range of channels is [2-3] rounded down average is 2 + assertThat(pool.entries.get()).hasSize(2); + } + + @Test + public void removedIdleChannelsAreShutdown() throws Exception { + ScheduledExecutorService executor = Mockito.mock(ScheduledExecutorService.class); + + List channels = new ArrayList<>(); + List> startedCalls = new ArrayList<>(); + + ChannelFactory channelFactory = + () -> { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + Mockito.when(channel.newCall(Mockito.any(), Mockito.any())) + .thenAnswer( + (Answer>) + invocation -> { + @SuppressWarnings("unchecked") + ClientCall clientCall = Mockito.mock(ClientCall.class); + startedCalls.add(clientCall); + return clientCall; + }); + + channels.add(channel); + return channel; + }; + + ChannelPool pool = + new ChannelPool( + ChannelPoolSettings.builder() + .setInitialChannelCount(2) + .setMinRpcsPerChannel(1) + .setMaxRpcsPerChannel(2) + .build(), + channelFactory, + executor); + assertThat(pool.entries.get()).hasSize(2); + + // With no outstanding RPCs, the pool should shrink + pool.resize(); + assertThat(pool.entries.get()).hasSize(1); + Mockito.verify(channels.get(1), Mockito.times(1)).shutdown(); + } + + @Test + public void removedActiveChannelsAreShutdown() throws Exception { + ScheduledExecutorService executor = Mockito.mock(ScheduledExecutorService.class); + + List channels = new ArrayList<>(); + List> startedCalls = new ArrayList<>(); + + ChannelFactory channelFactory = + () -> { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + Mockito.when(channel.newCall(Mockito.any(), Mockito.any())) + .thenAnswer( + (Answer>) + invocation -> { + @SuppressWarnings("unchecked") + ClientCall clientCall = Mockito.mock(ClientCall.class); + startedCalls.add(clientCall); + return clientCall; + }); + + channels.add(channel); + return channel; + }; + + ChannelPool pool = + new ChannelPool( + ChannelPoolSettings.builder() + .setInitialChannelCount(2) + .setMinRpcsPerChannel(1) + .setMaxRpcsPerChannel(2) + .build(), + channelFactory, + executor); + assertThat(pool.entries.get()).hasSize(2); + + // Start 2 RPCs + for (int i = 0; i < 2; i++) { + ClientCalls.futureUnaryCall( + pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT), + Color.getDefaultInstance()); + } + // Complete the first one + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + Mockito.verify(startedCalls.get(0)).start(captor.capture(), Mockito.any()); + captor.getValue().onClose(Status.ABORTED, new Metadata()); + + // With a single RPC, the pool should shrink + pool.resize(); + pool.resize(); + assertThat(pool.entries.get()).hasSize(1); + + // While the RPC is outstanding, the channel should still be open + Mockito.verify(channels.get(1), Mockito.never()).shutdown(); + + // Complete the RPC + Mockito.verify(startedCalls.get(1)).start(captor.capture(), Mockito.any()); + captor.getValue().onClose(Status.ABORTED, new Metadata()); + // Now the channel should be closed + Mockito.verify(channels.get(1), Mockito.times(1)).shutdown(); + } } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index a3f88d554..a2f4ebd76 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -32,7 +32,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; import com.google.api.core.ApiFunction; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; @@ -149,15 +148,7 @@ public void testWithPoolSize() throws IOException { provider.getTransportChannel().shutdownNow(); provider = provider.withPoolSize(2); - assertThat(provider.acceptsPoolSize()).isFalse(); provider.getTransportChannel().shutdownNow(); - - try { - provider.withPoolSize(3); - fail("acceptsPoolSize() returned false; we shouldn't be able to set it again"); - } catch (IllegalStateException e) { - - } } @Test diff --git a/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java b/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java index 2e4ff935b..c349d543b 100644 --- a/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java +++ b/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java @@ -134,14 +134,16 @@ public TransportChannelProvider withEndpoint(String endpoint) { return toBuilder().setEndpoint(endpoint).build(); } + /** @deprecated REST transport channel doesn't support channel pooling */ + @Deprecated @Override - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") public boolean acceptsPoolSize() { return false; } + /** @deprecated REST transport channel doesn't support channel pooling */ + @Deprecated @Override - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") public TransportChannelProvider withPoolSize(int size) { throw new UnsupportedOperationException( "InstantiatingHttpJsonChannelProvider doesn't allow pool size customization"); diff --git a/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java b/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java index c2f08fc62..18e25f3f8 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java +++ b/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java @@ -92,14 +92,16 @@ public TransportChannelProvider withEndpoint(String endpoint) { "FixedTransportChannelProvider doesn't need an endpoint"); } + /** @deprecated FixedTransportChannelProvider doesn't support ChannelPool configuration */ + @Deprecated @Override - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") public boolean acceptsPoolSize() { return false; } + /** @deprecated FixedTransportChannelProvider doesn't support ChannelPool configuration */ + @Deprecated @Override - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") public TransportChannelProvider withPoolSize(int size) { throw new UnsupportedOperationException( "FixedTransportChannelProvider doesn't allow pool size customization"); diff --git a/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java b/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java index 160adf02f..d700d380f 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java +++ b/gax/src/main/java/com/google/api/gax/rpc/TransportChannelProvider.java @@ -99,12 +99,20 @@ public interface TransportChannelProvider { */ TransportChannelProvider withEndpoint(String endpoint); - /** Reports whether this provider allows pool size customization. */ - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") + /** + * Reports whether this provider allows pool size customization. + * + * @deprecated Pool settings should be configured on the builder of the specific implementation. + */ + @Deprecated boolean acceptsPoolSize(); - /** Number of underlying transport channels to open. Calls will be load balanced across them. */ - @BetaApi("The surface for customizing pool size is not stable yet and may change in the future.") + /** + * Number of underlying transport channels to open. Calls will be load balanced across them. + * + * @deprecated Pool settings should be configured on the builder of the specific implementation. + */ + @Deprecated TransportChannelProvider withPoolSize(int size); /** True if credentials are needed before channel creation. */ From 5f65a37da83c2c80047c3e7c5fbd773528043fd6 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Fri, 7 Jan 2022 12:33:09 -0500 Subject: [PATCH 13/17] fix test after broken merge --- .../src/main/java/com/google/api/gax/grpc/ChannelPool.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index f11514989..94ba5fdcb 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -74,7 +74,8 @@ class ChannelPool extends ManagedChannel { private final ScheduledExecutorService executor; private final Object entryWriteLock = new Object(); - private final AtomicReference> entries = new AtomicReference<>(); + @VisibleForTesting + final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; From 10c0cf042052f5ec508bc10be254a795cca00480 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Fri, 7 Jan 2022 12:33:50 -0500 Subject: [PATCH 14/17] format --- .../src/main/java/com/google/api/gax/grpc/ChannelPool.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 94ba5fdcb..3cde997d5 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -74,8 +74,7 @@ class ChannelPool extends ManagedChannel { private final ScheduledExecutorService executor; private final Object entryWriteLock = new Object(); - @VisibleForTesting - final AtomicReference> entries = new AtomicReference<>(); + @VisibleForTesting final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; From 3381e9ec568fef71cfe4bab30450dccae4b93561 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Fri, 11 Feb 2022 15:44:39 -0500 Subject: [PATCH 15/17] address feedback --- .../com/google/api/gax/grpc/ChannelPool.java | 31 +++++++++---------- .../api/gax/grpc/ChannelPoolSettings.java | 17 +++++----- .../InstantiatingGrpcChannelProvider.java | 2 -- .../google/api/gax/grpc/ChannelPoolTest.java | 13 ++++---- 4 files changed, 30 insertions(+), 33 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 3cde997d5..2827fa262 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -107,7 +107,7 @@ static ChannelPool createRefreshing( return new ChannelPool( ChannelPoolSettings.staticallySized(poolSize) .toBuilder() - .setPreemptiveReconnectEnabled(true) + .setPreemptiveRefreshEnabled(true) .build(), channelFactory, executor); @@ -123,9 +123,9 @@ static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFa * * @param settings options for controling the ChannelPool sizing behavior * @param channelFactory method to create the channels - * @param executor periodically refreshes the channels. Must be single threaded + * @param executor periodically refreshes the channels */ - @InternalApi("VisibleForTesting") + @VisibleForTesting ChannelPool( ChannelPoolSettings settings, ChannelFactory channelFactory, @@ -151,7 +151,7 @@ static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFa ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(), TimeUnit.SECONDS); } - if (settings.isPreemptiveReconnectEnabled()) { + if (settings.isPreemptiveRefreshEnabled()) { executor.scheduleAtFixedRate( this::refreshSafely, REFRESH_PERIOD.getSeconds(), @@ -253,7 +253,7 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE return isTerminated(); } - void resizeSafely() { + private void resizeSafely() { try { synchronized (entryWriteLock) { resize(); @@ -277,6 +277,7 @@ void resizeSafely() { * *

Not threadsafe, must be called under the entryWriteLock monitor */ + @VisibleForTesting void resize() { List localEntries = entries.get(); // Estimate the peak of RPCs in the last interval by summing the peak of RPCs per channel @@ -292,6 +293,7 @@ void resize() { } // Number of channels if each channel operated at minimum capacity + // Note: getMinRpcsPerChannel() can return 0, but division by 0 shouldn't cause a problem. int maxChannels = (int) Math.ceil(actualOutstandingRpcs / (double) settings.getMinRpcsPerChannel()); // Limit the threshold to absolute range @@ -320,25 +322,25 @@ void resize() { "Detected throughput peak of %d, expanding channel pool size: %d -> %d.", actualOutstandingRpcs, currentSize, dampenedTarget)); - expand(tentativeTarget); + expand(dampenedTarget); } else if (localEntries.size() > maxChannels) { LOG.fine( String.format( "Detected throughput drop to %d, shrinking channel pool size: %d -> %d.", actualOutstandingRpcs, currentSize, dampenedTarget)); - shrink(tentativeTarget); + shrink(dampenedTarget); } } /** Not threadsafe, must be called under the entryWriteLock monitor */ private void shrink(int desiredSize) { - List localEntries = entries.get(); + ImmutableList localEntries = entries.get(); Preconditions.checkState( - localEntries.size() >= desiredSize, "desired size is already smaller than the current"); + localEntries.size() >= desiredSize, "current size is already smaller than the desired"); // Set the new list - entries.set(ImmutableList.copyOf(localEntries.subList(0, desiredSize))); + entries.set(localEntries.subList(0, desiredSize)); // clean up removed entries List removed = localEntries.subList(desiredSize, localEntries.size()); removed.forEach(Entry::requestShutdown); @@ -348,7 +350,7 @@ private void shrink(int desiredSize) { private void expand(int desiredSize) { List localEntries = entries.get(); Preconditions.checkState( - localEntries.size() <= desiredSize, "desired size is already bigger than the current"); + localEntries.size() <= desiredSize, "current size is already bigger than the desired"); ImmutableList.Builder newEntries = ImmutableList.builder().addAll(localEntries); @@ -441,12 +443,7 @@ Entry getRetainedEntry(int affinity) { private Entry getEntry(int affinity) { List localEntries = entries.get(); - int index = affinity % localEntries.size(); - index = Math.abs(index); - // If index is the most negative int, abs(index) is still negative. - if (index < 0) { - index = 0; - } + int index = Math.abs(affinity % localEntries.size()); return localEntries.get(index); } diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java index 383be15ef..19e62782a 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java @@ -72,7 +72,8 @@ public abstract class ChannelPoolSettings { * Threshold to start scaling up the channel pool. * *

When the average of the maximum number of outstanding RPCs in a single minute surpass this - * threshold, channels will be added to the pool. + * threshold, channels will be added to the pool. For google services, gRPC channels will start + * locally queuing RPC when there are 100 concurrent RPCs. */ public abstract int getMaxRpcsPerChannel(); @@ -105,7 +106,7 @@ public abstract class ChannelPoolSettings { * reconnects, this will create a new channel asynchronuously, prime it and then swap it with an * old channel. */ - public abstract boolean isPreemptiveReconnectEnabled(); + public abstract boolean isPreemptiveRefreshEnabled(); /** Helper to check if the {@link ChannelPool} implementation can skip dynamic size logic */ boolean isStaticSize() { @@ -140,7 +141,7 @@ public static Builder builder() { .setMaxChannelCount(200) .setMinRpcsPerChannel(0) .setMaxRpcsPerChannel(Integer.MAX_VALUE) - .setPreemptiveReconnectEnabled(false); + .setPreemptiveRefreshEnabled(false); } @AutoValue.Builder @@ -155,7 +156,7 @@ public abstract static class Builder { public abstract Builder setInitialChannelCount(int count); - public abstract Builder setPreemptiveReconnectEnabled(boolean enabled); + public abstract Builder setPreemptiveRefreshEnabled(boolean enabled); abstract ChannelPoolSettings autoBuild(); @@ -169,9 +170,11 @@ public ChannelPoolSettings build() { Preconditions.checkState( s.getMinChannelCount() <= s.getMaxRpcsPerChannel(), "absolute channel range is invalid"); Preconditions.checkState( - s.getMinChannelCount() <= s.getInitialChannelCount() - && s.getInitialChannelCount() <= s.getMaxChannelCount(), - "initial channel count must be with the absolute channel count range"); + s.getMinChannelCount() <= s.getInitialChannelCount(), + "initial channel count be at least minChannelCount"); + Preconditions.checkState( + s.getInitialChannelCount() <= s.getMaxChannelCount(), + "initial channel count must be less than maxChannelCount"); Preconditions.checkState( s.getInitialChannelCount() > 0, "Initial channel count must be greater than 0"); return s; diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 5ba1bec22..c309019a0 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -87,8 +87,6 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP private static final String DIRECT_PATH_ENV_ENABLE_XDS = "GOOGLE_CLOUD_ENABLE_DIRECT_PATH_XDS"; static final long DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS = 3600; static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20; - // reduce the thundering herd problem of too many channels trying to (re)connect at the same time - static final int MAX_POOL_SIZE = 1000; static final String GCE_PRODUCTION_NAME_PRIOR_2016 = "Google"; static final String GCE_PRODUCTION_NAME_AFTER_2016 = "Google Compute Engine"; diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 90ba9b59b..747fc4b0d 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -401,13 +401,12 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro ManagedChannel channel = Mockito.mock(ManagedChannel.class); Mockito.when(channel.newCall(Mockito.any(), Mockito.any())) .thenAnswer( - (Answer>) - invocation -> { - @SuppressWarnings("unchecked") - ClientCall clientCall = Mockito.mock(ClientCall.class); - startedCalls.add(clientCall); - return clientCall; - }); + invocation -> { + @SuppressWarnings("unchecked") + ClientCall clientCall = Mockito.mock(ClientCall.class); + startedCalls.add(clientCall); + return clientCall; + }); channels.add(channel); return channel; From 71d177b979f0c22479a97abc548cbe8bddc72240 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Mon, 14 Feb 2022 13:49:16 -0500 Subject: [PATCH 16/17] remove unused import --- .../com/google/api/gax/rpc/FixedTransportChannelProvider.java | 1 - 1 file changed, 1 deletion(-) diff --git a/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java b/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java index 7a9d0fd13..0bf6205dd 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java +++ b/gax/src/main/java/com/google/api/gax/rpc/FixedTransportChannelProvider.java @@ -29,7 +29,6 @@ */ package com.google.api.gax.rpc; -import com.google.api.core.BetaApi; import com.google.api.core.InternalExtensionOnly; import com.google.auth.Credentials; import com.google.common.base.Preconditions; From d82c6b836a5e562a440d70570504f73b53c61895 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Tue, 15 Feb 2022 11:28:23 -0500 Subject: [PATCH 17/17] inline old factory methods --- .../com/google/api/gax/grpc/ChannelPool.java | 35 ---------- .../google/api/gax/grpc/ChannelPoolTest.java | 68 ++++++++++++------- .../api/gax/grpc/GrpcClientCallsTest.java | 5 +- 3 files changed, 49 insertions(+), 59 deletions(-) diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 2827fa262..5215b4d9b 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -78,41 +78,6 @@ class ChannelPool extends ManagedChannel { private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; - /** - * Factory method to create a non-refreshing channel pool - * - * @param poolSize number of channels in the pool - * @param channelFactory method to create the channels - * @return ChannelPool of non-refreshing channels - */ - @VisibleForTesting - static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IOException { - return new ChannelPool(ChannelPoolSettings.staticallySized(poolSize), channelFactory, null); - } - - /** - * Factory method to create a refreshing channel pool - * - *

Package-private for testing purposes only - * - * @param poolSize number of channels in the pool - * @param channelFactory method to create the channels - * @param executor used to schedule maintenance tasks like refresh channels and resizing the pool. - * @return ChannelPool of refreshing channels - */ - @VisibleForTesting - static ChannelPool createRefreshing( - int poolSize, ChannelFactory channelFactory, ScheduledExecutorService executor) - throws IOException { - return new ChannelPool( - ChannelPoolSettings.staticallySized(poolSize) - .toBuilder() - .setPreemptiveRefreshEnabled(true) - .build(), - channelFactory, - executor); - } - static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFactory) throws IOException { return new ChannelPool(settings, channelFactory, Executors.newSingleThreadScheduledExecutor()); diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 747fc4b0d..1bf472653 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -71,7 +71,10 @@ public void testAuthority() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); - ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2))); + ChannelPool pool = + ChannelPool.create( + ChannelPoolSettings.staticallySized(2), + new FakeChannelFactory(Arrays.asList(sub1, sub2))); assertThat(pool.authority()).isEqualTo("myAuth"); } @@ -83,7 +86,9 @@ public void testRoundRobin() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); ArrayList channels = Lists.newArrayList(sub1, sub2); - ChannelPool pool = ChannelPool.create(channels.size(), new FakeChannelFactory(channels)); + ChannelPool pool = + ChannelPool.create( + ChannelPoolSettings.staticallySized(channels.size()), new FakeChannelFactory(channels)); verifyTargetChannel(pool, channels, sub1); verifyTargetChannel(pool, channels, sub2); @@ -138,7 +143,9 @@ public void ensureEvenDistribution() throws InterruptedException, IOException { } final ChannelPool pool = - ChannelPool.create(numChannels, new FakeChannelFactory(Arrays.asList(channels))); + ChannelPool.create( + ChannelPoolSettings.staticallySized(numChannels), + new FakeChannelFactory(Arrays.asList(channels))); int numThreads = 20; final int numPerThread = 1000; @@ -170,7 +177,11 @@ public void channelPrimerShouldCallPoolConstruction() throws IOException { ManagedChannel channel2 = Mockito.mock(ManagedChannel.class); ChannelPool.create( - 2, new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer)); + ChannelPoolSettings.staticallySized(2) + .toBuilder() + .setPreemptiveRefreshEnabled(true) + .build(), + new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer)); Mockito.verify(mockChannelPrimer, Mockito.times(2)) .primeChannel(Mockito.any(ManagedChannel.class)); } @@ -202,7 +213,13 @@ public void channelPrimerIsCalledPeriodically() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer); - ChannelPool.createRefreshing(1, channelFactory, scheduledExecutorService); + new ChannelPool( + ChannelPoolSettings.staticallySized(1) + .toBuilder() + .setPreemptiveRefreshEnabled(true) + .build(), + channelFactory, + scheduledExecutorService); // 1 call during the creation Mockito.verify(mockChannelPrimer, Mockito.times(1)) .primeChannel(Mockito.any(ManagedChannel.class)); @@ -226,7 +243,7 @@ public void callShouldCompleteAfterCreation() throws IOException { ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class); FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); - ChannelPool pool = ChannelPool.create(1, channelFactory); + ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); // create a mock call when new call comes to the underlying channel MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); @@ -275,7 +292,7 @@ public void callShouldCompleteAfterStarted() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); - ChannelPool pool = ChannelPool.create(1, channelFactory); + ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); // create a mock call when new call comes to the underlying channel MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); @@ -320,7 +337,7 @@ public void channelShouldShutdown() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel)); - ChannelPool pool = ChannelPool.create(1, channelFactory); + ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory); // create a mock call when new call comes to the underlying channel MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); @@ -372,7 +389,14 @@ public void channelRefreshShouldSwapChannels() throws IOException { FakeChannelFactory channelFactory = new FakeChannelFactory(ImmutableList.of(underlyingChannel1, underlyingChannel2)); - ChannelPool pool = ChannelPool.createRefreshing(1, channelFactory, scheduledExecutorService); + ChannelPool pool = + new ChannelPool( + ChannelPoolSettings.staticallySized(1) + .toBuilder() + .setPreemptiveRefreshEnabled(true) + .build(), + channelFactory, + scheduledExecutorService); Mockito.reset(underlyingChannel1); pool.newCall(FakeMethodDescriptor.create(), CallOptions.DEFAULT); @@ -482,13 +506,12 @@ public void removedIdleChannelsAreShutdown() throws Exception { ManagedChannel channel = Mockito.mock(ManagedChannel.class); Mockito.when(channel.newCall(Mockito.any(), Mockito.any())) .thenAnswer( - (Answer>) - invocation -> { - @SuppressWarnings("unchecked") - ClientCall clientCall = Mockito.mock(ClientCall.class); - startedCalls.add(clientCall); - return clientCall; - }); + invocation -> { + @SuppressWarnings("unchecked") + ClientCall clientCall = Mockito.mock(ClientCall.class); + startedCalls.add(clientCall); + return clientCall; + }); channels.add(channel); return channel; @@ -523,13 +546,12 @@ public void removedActiveChannelsAreShutdown() throws Exception { ManagedChannel channel = Mockito.mock(ManagedChannel.class); Mockito.when(channel.newCall(Mockito.any(), Mockito.any())) .thenAnswer( - (Answer>) - invocation -> { - @SuppressWarnings("unchecked") - ClientCall clientCall = Mockito.mock(ClientCall.class); - startedCalls.add(clientCall); - return clientCall; - }); + invocation -> { + @SuppressWarnings("unchecked") + ClientCall clientCall = Mockito.mock(ClientCall.class); + startedCalls.add(clientCall); + return clientCall; + }); channels.add(channel); return channel; diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java index 440d57209..fcdea5afe 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java @@ -74,7 +74,10 @@ public void testAffinity() throws IOException { .thenReturn(clientCall0); Mockito.when(channel1.newCall(Mockito.eq(descriptor), Mockito.any())) .thenReturn(clientCall1); - Channel pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(channel0, channel1))); + Channel pool = + ChannelPool.create( + ChannelPoolSettings.staticallySized(2), + new FakeChannelFactory(Arrays.asList(channel0, channel1))); GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool); ClientCall gotCallA =