diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 6060b2278d6c..ca2412d818e4 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -372,16 +372,18 @@ public GapicSpannerRpc(final SpannerOptions options) { options, headerProviderWithUserAgent, isEnableDirectAccess); GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator = createGrpcGcpEndpointChannelConfigurator(defaultChannelProviderBuilder, options); - maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); - - if (options.getChannelProvider() == null - && isEnableDirectAccess - && options.isEnableGcpFallback()) { + boolean useGcpFallback = + options.getChannelProvider() == null + && isEnableDirectAccess + && options.isEnableGcpFallback(); + if (useGcpFallback) { setupGcpFallback( defaultChannelProviderBuilder, options, headerProviderWithUserAgent, credentialsProvider); + } else { + maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); } boolean enableLocationApi = options.isEnableLocationApi(); @@ -656,8 +658,11 @@ private void setupGcpFallback( final HeaderProvider headerProviderWithUserAgent, final CredentialsProvider credentialsProvider) { InstantiatingGrpcChannelProvider.Builder cloudPathProviderBuilder = - createChannelProviderBuilder( + createBaseChannelProviderBuilder( options, headerProviderWithUserAgent, /* isEnableDirectAccess= */ false); + if (options.isGrpcGcpExtensionEnabled()) { + cloudPathProviderBuilder.setPoolSize(1); + } InstantiatingGrpcChannelProvider cloudPathProvider = cloudPathProviderBuilder.build(); ManagedChannelBuilder cloudPathBuilder; @@ -689,6 +694,9 @@ public ClientCall interceptCall( final ApiFunction existingConfigurator = defaultChannelProviderBuilder.getChannelConfigurator(); + if (options.isGrpcGcpExtensionEnabled()) { + defaultChannelProviderBuilder.setPoolSize(1); + } defaultChannelProviderBuilder.setChannelConfigurator( directPathBuilder -> { ManagedChannelBuilder builder = directPathBuilder; @@ -696,22 +704,24 @@ public ClientCall interceptCall( builder = existingConfigurator.apply(builder); } - String jsonApiConfig = parseGrpcGcpApiConfig(); - GcpManagedChannelOptions gcpOptions = grpcGcpOptionsWithMetricsAndDcp(options); - if (gcpOptions == null) { - gcpOptions = GcpManagedChannelOptions.newBuilder().build(); + ManagedChannelBuilder primaryBuilder = builder; + ManagedChannelBuilder fallbackBuilder = cloudPathBuilder; + if (options.isGrpcGcpExtensionEnabled()) { + String jsonApiConfig = parseGrpcGcpApiConfig(); + GcpManagedChannelOptions gcpOptions = grpcGcpOptionsWithMetricsAndDcp(options); + if (gcpOptions == null) { + gcpOptions = GcpManagedChannelOptions.newBuilder().build(); + } + primaryBuilder = + GcpManagedChannelBuilder.forDelegateBuilder(builder) + .withApiConfigJsonString(jsonApiConfig) + .withOptions(gcpOptions); + fallbackBuilder = + GcpManagedChannelBuilder.forDelegateBuilder(cloudPathBuilder) + .withApiConfigJsonString(jsonApiConfig) + .withOptions(gcpOptions); } - GcpManagedChannelBuilder primaryGcpBuilder = - GcpManagedChannelBuilder.forDelegateBuilder(builder) - .withApiConfigJsonString(jsonApiConfig) - .withOptions(gcpOptions); - - GcpManagedChannelBuilder fallbackGcpBuilder = - GcpManagedChannelBuilder.forDelegateBuilder(cloudPathBuilder) - .withApiConfigJsonString(jsonApiConfig) - .withOptions(gcpOptions); - GcpFallbackOpenTelemetry fallbackTelemetry = GcpFallbackOpenTelemetry.newBuilder() .withSdk(getFallbackOpenTelemetry(options)) @@ -720,9 +730,7 @@ public ClientCall interceptCall( .build(); return new FallbackChannelBuilder( - primaryGcpBuilder, - fallbackGcpBuilder, - createFallbackChannelOptions(fallbackTelemetry, 1)); + primaryBuilder, fallbackBuilder, createFallbackChannelOptions(fallbackTelemetry, 1)); }); } @@ -2595,15 +2603,15 @@ private static class FallbackChannelBuilder extends ForwardingChannelBuilder2 { private final GcpFallbackChannelOptions options; - private final GcpManagedChannelBuilder primaryGcpBuilder; - private final GcpManagedChannelBuilder fallbackGcpBuilder; + private final ManagedChannelBuilder primaryBuilder; + private final ManagedChannelBuilder fallbackBuilder; private FallbackChannelBuilder( - GcpManagedChannelBuilder primary, - GcpManagedChannelBuilder fallback, + ManagedChannelBuilder primary, + ManagedChannelBuilder fallback, GcpFallbackChannelOptions options) { - this.primaryGcpBuilder = primary; - this.fallbackGcpBuilder = fallback; + this.primaryBuilder = primary; + this.fallbackBuilder = fallback; this.options = options; } @@ -2613,7 +2621,7 @@ private FallbackChannelBuilder( */ @Override protected ManagedChannelBuilder delegate() { - return primaryGcpBuilder; + return primaryBuilder; } /** @@ -2622,7 +2630,7 @@ protected ManagedChannelBuilder delegate() { */ @Override public ManagedChannel build() { - return new GcpFallbackChannel(options, primaryGcpBuilder, fallbackGcpBuilder); + return new GcpFallbackChannel(options, primaryBuilder, fallbackBuilder); } } } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java index e6576d370bb0..cdb04816a012 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java @@ -103,12 +103,17 @@ import io.opentelemetry.sdk.trace.SdkTracerProvider; import io.opentelemetry.sdk.trace.samplers.Sampler; import java.io.IOException; +import java.lang.reflect.Array; +import java.lang.reflect.Modifier; import java.net.InetSocketAddress; import java.time.Duration; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -172,6 +177,25 @@ public class GapicSpannerRpcTest { new java.util.Date( System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS)))); + private static final String GRPC_GCP_CHANNEL_REF_CLASS_NAME = + "com.google.cloud.grpc.GcpManagedChannel$ChannelRef"; + + private static final class GrpcGcpObjectCounts { + int gcpManagedChannels; + int channelRefs; + + GrpcGcpObjectCounts minus(GrpcGcpObjectCounts other) { + GrpcGcpObjectCounts difference = new GrpcGcpObjectCounts(); + difference.gcpManagedChannels = gcpManagedChannels - other.gcpManagedChannels; + difference.channelRefs = channelRefs - other.channelRefs; + return difference; + } + + String debugString() { + return "GcpManagedChannel=" + gcpManagedChannels + ", ChannelRef=" + channelRefs; + } + } + private static MockSpannerServiceImpl mockSpanner; private static Server server; private static InetSocketAddress address; @@ -1409,6 +1433,134 @@ private SpannerOptions createSpannerOptions() { .build(); } + @Test + public void testDirectPathFallbackCreatesOneGrpcGcpLayerPerPath() { + SpannerOptions.useEnvironment(new SpannerOptions.SpannerEnvironment() {}); + GapicSpannerRpc rpc = null; + try { + SpannerOptions options = createDirectPathFallbackObjectCountOptions().build(); + assumeTrue( + "GCP fallback must be enabled for this DirectPath fallback test", + options.isEnableGcpFallback()); + GrpcGcpObjectCounts before = countGrpcGcpObjectsFromChannelz(); + rpc = new GapicSpannerRpc(options); + GrpcGcpObjectCounts counts = countGrpcGcpObjectsFromChannelz().minus(before); + assertEquals(counts.debugString(), 6, counts.gcpManagedChannels); + assertEquals(counts.debugString(), 48, counts.channelRefs); + } finally { + if (rpc != null) { + rpc.shutdown(); + } + SpannerOptions.useDefaultEnvironment(); + } + } + + @Test + public void testDirectPathFallbackWithGaxChannelPoolDoesNotCreateGrpcGcpChannelRefs() { + SpannerOptions.useEnvironment(new SpannerOptions.SpannerEnvironment() {}); + GapicSpannerRpc rpc = null; + try { + SpannerOptions options = + createDirectPathFallbackObjectCountOptions().disableGrpcGcpExtension().build(); + assumeTrue( + "GCP fallback must be enabled for this DirectPath fallback test", + options.isEnableGcpFallback()); + GrpcGcpObjectCounts before = countGrpcGcpObjectsFromChannelz(); + rpc = new GapicSpannerRpc(options); + GrpcGcpObjectCounts counts = countGrpcGcpObjectsFromChannelz().minus(before); + assertEquals(counts.debugString(), 0, counts.gcpManagedChannels); + assertEquals(counts.debugString(), 0, counts.channelRefs); + } finally { + if (rpc != null) { + rpc.shutdown(); + } + SpannerOptions.useDefaultEnvironment(); + } + } + + private SpannerOptions.Builder createDirectPathFallbackObjectCountOptions() { + return SpannerOptions.newBuilder() + .setProjectId("test-project") + .setEnableDirectAccess(true) + .setHost("http://localhost:1") + .setCredentials(NoCredentials.getInstance()); + } + + private static GrpcGcpObjectCounts countGrpcGcpObjectsFromChannelz() { + GrpcGcpObjectCounts counts = new GrpcGcpObjectCounts(); + Object channelz = io.grpc.InternalChannelz.instance(); + Set visited = Collections.newSetFromMap(new IdentityHashMap<>()); + countGrpcGcpObjectsFromChannelzField(channelz, "rootChannels", visited, counts); + countGrpcGcpObjectsFromChannelzField(channelz, "subchannels", visited, counts); + return counts; + } + + private static void countGrpcGcpObjectsFromChannelzField( + Object channelz, String fieldName, Set visited, GrpcGcpObjectCounts counts) { + try { + java.lang.reflect.Field field = channelz.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + countGrpcGcpObjects(field.get(channelz), visited, counts); + } catch (RuntimeException | ReflectiveOperationException ignored) { + // Ignore fields that are not reflectively accessible in this runtime. + } + } + + private static void countGrpcGcpObjects( + Object object, Set visited, GrpcGcpObjectCounts counts) { + if (object == null || !visited.add(object)) { + return; + } + if (object instanceof GcpManagedChannel) { + counts.gcpManagedChannels++; + } + Class clazz = object.getClass(); + if (clazz.getName().equals(GRPC_GCP_CHANNEL_REF_CLASS_NAME)) { + counts.channelRefs++; + } + if (object instanceof Collection) { + for (Object value : (Collection) object) { + countGrpcGcpObjects(value, visited, counts); + } + return; + } + if (object instanceof Map) { + for (Map.Entry entry : ((Map) object).entrySet()) { + countGrpcGcpObjects(entry.getKey(), visited, counts); + countGrpcGcpObjects(entry.getValue(), visited, counts); + } + return; + } + if (clazz.isArray()) { + int length = Array.getLength(object); + for (int i = 0; i < length; i++) { + countGrpcGcpObjects(Array.get(object, i), visited, counts); + } + return; + } + if (!shouldInspectFields(clazz)) { + return; + } + for (Class current = clazz; current != null; current = current.getSuperclass()) { + for (java.lang.reflect.Field field : current.getDeclaredFields()) { + if (Modifier.isStatic(field.getModifiers())) { + continue; + } + try { + field.setAccessible(true); + countGrpcGcpObjects(field.get(object), visited, counts); + } catch (RuntimeException | IllegalAccessException ignored) { + // Ignore fields that are not reflectively accessible in this runtime. + } + } + } + } + + private static boolean shouldInspectFields(Class clazz) { + String name = clazz.getName(); + return name.startsWith("com.google.") || name.startsWith("io.grpc."); + } + static class TestableGapicSpannerRpc extends GapicSpannerRpc { public TestableGapicSpannerRpc(SpannerOptions options) { super(options);