Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package dev.openfeature.contrib.providers.flagd.resolver.common;

import dev.openfeature.contrib.providers.flagd.FlagdOptions;
import dev.openfeature.sdk.ImmutableStructure;
import dev.openfeature.sdk.ProviderEvent;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

/**
* A generic GRPC connector that manages connection states, reconnection logic, and event streaming for
* GRPC services.
* A GRPC connector that maintains a managed channel for communication with a flagd server and handles shutdown.
*/
@Slf4j
public class ChannelConnector {
Expand All @@ -29,11 +24,6 @@ public class ChannelConnector {
*/
private final long deadline;

/**
* A consumer that handles connection events such as connection loss or reconnection.
*/
private final Consumer<FlagdProviderEvent> onConnectionEvent;

/**
* Constructs a new {@code ChannelConnector} instance with the specified options and parameters.
*
Expand All @@ -45,17 +35,6 @@ public ChannelConnector(
final FlagdOptions options, final Consumer<FlagdProviderEvent> onConnectionEvent, ManagedChannel channel) {
this.channel = channel;
this.deadline = options.getDeadline();
this.onConnectionEvent = onConnectionEvent;
}

/**
* Initializes the GRPC connection by waiting for the channel to be ready and monitoring its state.
*
* @throws Exception if the channel does not reach the desired state within the deadline
*/
public void initialize() throws Exception {
log.info("Initializing GRPC connection.");
monitorChannelState(ConnectivityState.READY);
}

/**
Expand All @@ -71,27 +50,4 @@ public void shutdown() throws InterruptedException {
channel.awaitTermination(deadline, TimeUnit.MILLISECONDS);
}
}

/**
* Monitors the state of a gRPC channel and triggers the specified callbacks based on state changes.
*
* @param expectedState the initial state to monitor.
*/
private void monitorChannelState(ConnectivityState expectedState) {
channel.notifyWhenStateChanged(expectedState, this::onStateChange);
}

private void onStateChange() {
ConnectivityState currentState = channel.getState(true);
log.debug("Channel state changed to: {}", currentState);
if (currentState == ConnectivityState.TRANSIENT_FAILURE || currentState == ConnectivityState.SHUTDOWN) {
this.onConnectionEvent.accept(new FlagdProviderEvent(
ProviderEvent.PROVIDER_ERROR, Collections.emptyList(), new ImmutableStructure()));
}
if (currentState != ConnectivityState.SHUTDOWN) {
log.debug("continuing to monitor the grpc channel");
// Re-register the state monitor to watch for the next state transition.
monitorChannelState(currentState);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ protected SyncStreamQueueSource(

/** Initialize sync stream connector. */
public void init() throws Exception {
channelConnector.initialize();
Thread listener = new Thread(this::observeSyncStream);
listener.setDaemon(true);
listener.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ protected RpcResolver(
* Initialize RpcResolver resolver.
*/
public void init() throws Exception {
this.connector.initialize();

Thread listener = new Thread(() -> {
try {
observeEventStream();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,16 @@
package dev.openfeature.contrib.providers.flagd.resolver.common;
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gemini suggested I remove this whole test file. I agree, and I think I should also delete the entire ChannelConnector class as I said in my description, but I think I want to do that separately, as that would require small refactors around shutdown which I don't want to add to this PR.


import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.Lists;
import dev.openfeature.contrib.providers.flagd.FlagdOptions;
import dev.openfeature.flagd.grpc.evaluation.Evaluation;
import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Server;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentCaptor;
import org.mockito.MockitoAnnotations;

class ChannelConnectorTest {
Expand Down Expand Up @@ -84,68 +62,4 @@ private void tearDownGrpcServer() throws InterruptedException {
testServer.awaitTermination();
}
}

@Test
void whenShuttingDownGrpcConnectorConsumerReceivesDisconnectedEvent() throws Exception {
CountDownLatch sync = new CountDownLatch(1);
ArrayList<Boolean> connectionStateChanges = Lists.newArrayList();
Consumer<FlagdProviderEvent> testConsumer = event -> {
connectionStateChanges.add(!event.isDisconnected());
sync.countDown();
};

ChannelConnector instance = new ChannelConnector(FlagdOptions.builder().build(), testConsumer, testChannel);

instance.initialize();
// when shutting grpc connector
instance.shutdown();

// then consumer received DISCONNECTED and CONNECTED event
boolean finished = sync.await(10, TimeUnit.SECONDS);
Assertions.assertTrue(finished);
Assertions.assertEquals(Lists.newArrayList(DISCONNECTED), connectionStateChanges);
}

@ParameterizedTest
@EnumSource(ConnectivityState.class)
void testMonitorChannelState(ConnectivityState state) throws Exception {
ManagedChannel channel = mock(ManagedChannel.class);

// Set up the expected state
ConnectivityState expectedState = ConnectivityState.IDLE;
when(channel.getState(anyBoolean())).thenReturn(state);

// Capture the callback
ArgumentCaptor<Runnable> callbackCaptor = ArgumentCaptor.forClass(Runnable.class);
doNothing().when(channel).notifyWhenStateChanged(any(), callbackCaptor.capture());

Consumer<FlagdProviderEvent> testConsumer = spy(Consumer.class);

ChannelConnector instance = new ChannelConnector(FlagdOptions.builder().build(), testConsumer, channel);

instance.initialize();

// Simulate state change
callbackCaptor.getValue().run();

// Verify the callbacks based on the state
switch (state) {
case READY:
verify(channel, times(2)).notifyWhenStateChanged(any(), any());
verify(testConsumer, never()).accept(any());
break;
case TRANSIENT_FAILURE:
verify(channel, times(2)).notifyWhenStateChanged(any(), any());
verify(testConsumer).accept(any());
break;
case SHUTDOWN:
verify(channel, times(1)).notifyWhenStateChanged(any(), any());
verify(testConsumer).accept(any());
break;
default:
verify(channel, times(2)).notifyWhenStateChanged(any(), any());
verify(testConsumer, never()).accept(any());
break;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -48,7 +47,6 @@ public void setup() throws Exception {
when(blockingStub.getMetadata(any())).thenReturn(GetMetadataResponse.getDefaultInstance());

mockConnector = mock(ChannelConnector.class);
doNothing().when(mockConnector).initialize(); // Mock the initialize method

stub = mock(FlagSyncServiceStub.class);
when(stub.withDeadlineAfter(anyLong(), any())).thenReturn(stub);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ public void init() throws Exception {
blockingStub = mock(ServiceBlockingStub.class);

mockConnector = mock(ChannelConnector.class);
doNothing().when(mockConnector).initialize(); // Mock the initialize method

stub = mock(ServiceStub.class);
when(stub.withDeadlineAfter(anyLong(), any())).thenReturn(stub);
Expand Down