From 6689d6535d710707b7509cb847d56cc742719b6e Mon Sep 17 00:00:00 2001 From: atakavci Date: Thu, 7 Nov 2024 01:17:52 +0300 Subject: [PATCH 01/21] tba draft --- .gitmodules | 4 + pom.xml | 6 ++ redis-authx | 1 + .../java/redis/clients/jedis/Connection.java | 23 ++++- .../clients/jedis/ConnectionFactory.java | 76 +++++++++++++--- .../redis/clients/jedis/ConnectionPool.java | 71 +++++++++++++-- .../jedis/DefaultJedisClientConfig.java | 79 ++++++++++++----- .../clients/jedis/JedisClientConfig.java | 9 +- .../java/redis/clients/jedis/JedisPooled.java | 4 +- .../authentication/JedisAuthXManager.java | 26 ++++++ .../authentication/TokenCredentials.java | 24 +++++ ...enBasedAuthenticationIntegrationTests.java | 87 +++++++++++++++++++ .../TokenBasedAuthenticationUnitTests.java | 48 ++++++++++ 13 files changed, 414 insertions(+), 44 deletions(-) create mode 100644 .gitmodules create mode 160000 redis-authx create mode 100644 src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java create mode 100644 src/main/java/redis/clients/jedis/authentication/TokenCredentials.java create mode 100644 src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java create mode 100644 src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..e974dd3048 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "redis-authx"] + path = redis-authx + url = https://github.com/redis/tbd-auth-entraid + branch = tba-draft diff --git a/pom.xml b/pom.xml index 4060d175be..0ed233657e 100644 --- a/pom.xml +++ b/pom.xml @@ -93,6 +93,12 @@ test + + redis.clients.authentication + redis-authx-core + 0.1.0 + + junit diff --git a/redis-authx b/redis-authx new file mode 160000 index 0000000000..8f56584858 --- /dev/null +++ b/redis-authx @@ -0,0 +1 @@ +Subproject commit 8f5658485897d2ca56af5238af25fc709cd0eaa9 diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index 2860866c6e..6e2de1377a 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -14,12 +14,16 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; +import java.util.concurrent.atomic.AtomicReference; +import redis.clients.authentication.core.AuthenticatedConnection; +import redis.clients.authentication.core.Token; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.Protocol.Keyword; import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.args.ClientAttributeOption; import redis.clients.jedis.args.Rawable; +import redis.clients.jedis.authentication.TokenCredentials; import redis.clients.jedis.commands.ProtocolCommand; import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisDataException; @@ -29,7 +33,7 @@ import redis.clients.jedis.util.RedisInputStream; import redis.clients.jedis.util.RedisOutputStream; -public class Connection implements Closeable { +public class Connection implements Closeable, AuthenticatedConnection { private ConnectionPool memberOf; protected RedisProtocol protocol; @@ -44,6 +48,7 @@ public class Connection implements Closeable { private String strVal; protected String server; protected String version; + protected AtomicReference currentToken = new AtomicReference(null); public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); @@ -542,6 +547,10 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c // handled in RedisCredentialsProvider.cleanUp() } + public void setToken(Token token) { + currentToken.set(token); + } + private void auth(RedisCredentials credentials) { if (credentials == null || credentials.getPassword() == null) { return; @@ -559,6 +568,13 @@ private void auth(RedisCredentials credentials) { getStatusCodeReply(); } + public void reAuth() { + Token temp = currentToken.getAndSet(null); + if (temp != null) { + auth(new TokenCredentials(temp)); + } + } + protected Map hello(byte[]... args) { sendCommand(Command.HELLO, args); return BuilderFactory.ENCODED_OBJECT_MAP.build(getOne()); @@ -585,4 +601,9 @@ public boolean ping() { } return true; } + + @Override + public void authenticate(Token token) { + this.setToken(token); + } } diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index cc53df56f0..9ccdb6b918 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -6,7 +6,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; + import redis.clients.jedis.annots.Experimental; +import redis.clients.jedis.authentication.TokenCredentials; +import redis.clients.authentication.core.AuthXManager; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.csc.CacheConnection; import redis.clients.jedis.exceptions.JedisException; @@ -20,28 +26,70 @@ public class ConnectionFactory implements PooledObjectFactory { private final JedisSocketFactory jedisSocketFactory; private final JedisClientConfig clientConfig; - private Cache clientSideCache = null; + private final Cache clientSideCache; + private final Supplier objectMaker; public ConnectionFactory(final HostAndPort hostAndPort) { - this.clientConfig = DefaultJedisClientConfig.builder().build(); - this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort); + this(hostAndPort, DefaultJedisClientConfig.builder().build(), null, null); } public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) { - this.clientConfig = clientConfig; - this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig); + this(hostAndPort, clientConfig, null, null); } @Experimental - public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, Cache csCache) { - this.clientConfig = clientConfig; - this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig); - this.clientSideCache = csCache; + public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, + Cache csCache, AuthXManager authXManager) { + this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache, + authXManager); } - public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) { - this.clientConfig = clientConfig; + public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, + final JedisClientConfig clientConfig) { + this(jedisSocketFactory, clientConfig, null, null); + } + + private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, + final JedisClientConfig clientConfig, Cache csCache, AuthXManager authXManager) { + this.jedisSocketFactory = jedisSocketFactory; + this.clientSideCache = csCache; + + if (authXManager == null) { + this.clientConfig = clientConfig; + this.objectMaker = connectionSupplier(); + } else { + this.clientConfig = replaceCredentialsProvider(clientConfig, + buildCredentialsProvider(authXManager)); + Supplier supplier = connectionSupplier(); + this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get()); + + try { + authXManager.start(true); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new JedisException("AuthXManager failed to start!", e); + } + } + } + + private JedisClientConfig replaceCredentialsProvider(JedisClientConfig origin, + Supplier newCredentialsProvider) { + return DefaultJedisClientConfig.builder().from(origin) + .credentialsProvider(newCredentialsProvider).build(); + } + + private Supplier buildCredentialsProvider(AuthXManager connManager) { + return new Supplier() { + @Override + public RedisCredentials get() { + return new TokenCredentials(connManager.getCurrentToken()); + } + }; + } + + private Supplier connectionSupplier() { + return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig) + : () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache); } @Override @@ -64,8 +112,7 @@ public void destroyObject(PooledObject pooledConnection) throws Exce @Override public PooledObject makeObject() throws Exception { try { - Connection jedis = clientSideCache == null ? new Connection(jedisSocketFactory, clientConfig) - : new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache); + Connection jedis = objectMaker.get(); return new DefaultPooledObject<>(jedis); } catch (JedisException je) { logger.debug("Error while makeObject", je); @@ -76,6 +123,8 @@ public PooledObject makeObject() throws Exception { @Override public void passivateObject(PooledObject pooledConnection) throws Exception { // TODO maybe should select db 0? Not sure right now. + Connection jedis = pooledConnection.getObject(); + jedis.reAuth(); } @Override @@ -83,6 +132,7 @@ public boolean validateObject(PooledObject pooledConnection) { final Connection jedis = pooledConnection.getObject(); try { // check HostAndPort ?? + jedis.reAuth(); return jedis.isConnected() && jedis.ping(); } catch (final Exception e) { logger.warn("Error while validating pooled Connection object.", e); diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index 40d4861f98..eba26f7375 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -1,20 +1,42 @@ package redis.clients.jedis; +import java.util.concurrent.atomic.AtomicInteger; + import org.apache.commons.pool2.PooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; + +import redis.clients.authentication.core.AuthXManagerFactory; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenListener; import redis.clients.jedis.annots.Experimental; +import redis.clients.jedis.authentication.JedisAuthXManager; import redis.clients.jedis.csc.Cache; +import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.Pool; public class ConnectionPool extends Pool { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { - this(new ConnectionFactory(hostAndPort, clientConfig)); + this(hostAndPort, clientConfig, createAuthXManager(clientConfig)); + } + + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + JedisAuthXManager authXManager) { + this(new ConnectionFactory(hostAndPort, clientConfig, null, authXManager)); + attachAuthXManager(authXManager); } @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) { - this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache)); + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + Cache clientSideCache) { + this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig)); + } + + @Experimental + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + Cache clientSideCache, JedisAuthXManager authXManager) { + this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager)); + attachAuthXManager(authXManager); } public ConnectionPool(PooledObjectFactory factory) { @@ -23,13 +45,22 @@ public ConnectionPool(PooledObjectFactory factory) { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, GenericObjectPoolConfig poolConfig) { - this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig); + this(hostAndPort, clientConfig, null, createAuthXManager(clientConfig), poolConfig); } @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache, + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + Cache clientSideCache, GenericObjectPoolConfig poolConfig) { + this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig), poolConfig); + } + + @Experimental + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + Cache clientSideCache, JedisAuthXManager authXManager, GenericObjectPoolConfig poolConfig) { - this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig); + this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager), + poolConfig); + attachAuthXManager(authXManager); } public ConnectionPool(PooledObjectFactory factory, @@ -43,4 +74,32 @@ public Connection getResource() { conn.setHandlingPool(this); return conn; } + + private static JedisAuthXManager createAuthXManager(JedisClientConfig clientConfig) { + return (clientConfig.getTokenAuthConfig() != null) + ? (JedisAuthXManager) AuthXManagerFactory.create(JedisAuthXManager.class, + clientConfig.getTokenAuthConfig()) + : null; + } + + private void attachAuthXManager(JedisAuthXManager authXManager) { + if (authXManager != null) { + authXManager.setListener(new TokenListener() { + @Override + public void onTokenRenewed(Token token) { + try { + ConnectionPool.this.evict(); + System.out.println("pool evict: " + ConnectionPool.this.hashCode()); + + } catch (Exception e) { + throw new JedisException("Failed to evict connections from pool", e); + } + } + + @Override + public void onError(Exception reason) { + } + }); + } + } } diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java index e304e961f1..8b161ca7ff 100644 --- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java @@ -5,6 +5,8 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; +import redis.clients.authentication.core.TokenAuthConfig; + public final class DefaultJedisClientConfig implements JedisClientConfig { private final RedisProtocol redisProtocol; @@ -28,11 +30,15 @@ public final class DefaultJedisClientConfig implements JedisClientConfig { private final boolean readOnlyForRedisClusterReplicas; - private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMillis, int soTimeoutMillis, - int blockingSocketTimeoutMillis, Supplier credentialsProvider, int database, - String clientName, boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, + private final TokenAuthConfig tokenAuthConfig; + + private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMillis, + int soTimeoutMillis, int blockingSocketTimeoutMillis, + Supplier credentialsProvider, int database, String clientName, boolean ssl, + SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper, - ClientSetInfoConfig clientSetInfoConfig, boolean readOnlyForRedisClusterReplicas) { + ClientSetInfoConfig clientSetInfoConfig, boolean readOnlyForRedisClusterReplicas, + TokenAuthConfig tokenAuthConfig) { this.redisProtocol = protocol; this.connectionTimeoutMillis = connectionTimeoutMillis; this.socketTimeoutMillis = soTimeoutMillis; @@ -47,6 +53,7 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi this.hostAndPortMapper = hostAndPortMapper; this.clientSetInfoConfig = clientSetInfoConfig; this.readOnlyForRedisClusterReplicas = readOnlyForRedisClusterReplicas; + this.tokenAuthConfig = tokenAuthConfig; } @Override @@ -85,6 +92,11 @@ public Supplier getCredentialsProvider() { return credentialsProvider; } + @Override + public TokenAuthConfig getTokenAuthConfig() { + return tokenAuthConfig; + } + @Override public int getDatabase() { return database; @@ -159,6 +171,8 @@ public static class Builder { private boolean readOnlyForRedisClusterReplicas = false; + private TokenAuthConfig tokenAuthConfig = null; + private Builder() { } @@ -168,10 +182,10 @@ public DefaultJedisClientConfig build() { new DefaultRedisCredentials(user, password)); } - return new DefaultJedisClientConfig(redisProtocol, connectionTimeoutMillis, socketTimeoutMillis, - blockingSocketTimeoutMillis, credentialsProvider, database, clientName, ssl, - sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, clientSetInfoConfig, - readOnlyForRedisClusterReplicas); + return new DefaultJedisClientConfig(redisProtocol, connectionTimeoutMillis, + socketTimeoutMillis, blockingSocketTimeoutMillis, credentialsProvider, database, + clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, + clientSetInfoConfig, readOnlyForRedisClusterReplicas, tokenAuthConfig); } /** @@ -272,25 +286,50 @@ public Builder readOnlyForRedisClusterReplicas() { this.readOnlyForRedisClusterReplicas = true; return this; } + + public Builder tokenAuthConfig(TokenAuthConfig tokenAuthConfig) { + this.tokenAuthConfig = tokenAuthConfig; + return this; + } + + public Builder from(JedisClientConfig instance) { + this.redisProtocol = instance.getRedisProtocol(); + this.connectionTimeoutMillis = instance.getConnectionTimeoutMillis(); + this.socketTimeoutMillis = instance.getSocketTimeoutMillis(); + this.blockingSocketTimeoutMillis = instance.getBlockingSocketTimeoutMillis(); + this.credentialsProvider = instance.getCredentialsProvider(); + this.database = instance.getDatabase(); + this.clientName = instance.getClientName(); + this.ssl = instance.isSsl(); + this.sslSocketFactory = instance.getSslSocketFactory(); + this.sslParameters = instance.getSslParameters(); + this.hostnameVerifier = instance.getHostnameVerifier(); + this.hostAndPortMapper = instance.getHostAndPortMapper(); + this.clientSetInfoConfig = instance.getClientSetInfoConfig(); + this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas(); + this.tokenAuthConfig = instance.getTokenAuthConfig(); + return this; + } } public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int soTimeoutMillis, - int blockingSocketTimeoutMillis, String user, String password, int database, String clientName, - boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, - HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper) { - return new DefaultJedisClientConfig(null, - connectionTimeoutMillis, soTimeoutMillis, blockingSocketTimeoutMillis, + int blockingSocketTimeoutMillis, String user, String password, int database, + String clientName, boolean ssl, SSLSocketFactory sslSocketFactory, + SSLParameters sslParameters, HostnameVerifier hostnameVerifier, + HostAndPortMapper hostAndPortMapper, TokenAuthConfig tokenAuthConfig) { + return new DefaultJedisClientConfig(null, connectionTimeoutMillis, soTimeoutMillis, + blockingSocketTimeoutMillis, new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(user, password)), database, clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, null, - false); + false, tokenAuthConfig); } public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { - return new DefaultJedisClientConfig(copy.getRedisProtocol(), - copy.getConnectionTimeoutMillis(), copy.getSocketTimeoutMillis(), - copy.getBlockingSocketTimeoutMillis(), copy.getCredentialsProvider(), - copy.getDatabase(), copy.getClientName(), copy.isSsl(), copy.getSslSocketFactory(), - copy.getSslParameters(), copy.getHostnameVerifier(), copy.getHostAndPortMapper(), - copy.getClientSetInfoConfig(), copy.isReadOnlyForRedisClusterReplicas()); + return new DefaultJedisClientConfig(copy.getRedisProtocol(), copy.getConnectionTimeoutMillis(), + copy.getSocketTimeoutMillis(), copy.getBlockingSocketTimeoutMillis(), + copy.getCredentialsProvider(), copy.getDatabase(), copy.getClientName(), copy.isSsl(), + copy.getSslSocketFactory(), copy.getSslParameters(), copy.getHostnameVerifier(), + copy.getHostAndPortMapper(), copy.getClientSetInfoConfig(), + copy.isReadOnlyForRedisClusterReplicas(), copy.getTokenAuthConfig()); } } diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java index abe1f35237..a8046694bf 100644 --- a/src/main/java/redis/clients/jedis/JedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java @@ -5,6 +5,8 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; +import redis.clients.authentication.core.TokenAuthConfig; + public interface JedisClientConfig { default RedisProtocol getRedisProtocol() { @@ -45,8 +47,11 @@ default String getPassword() { } default Supplier getCredentialsProvider() { - return new DefaultRedisCredentialsProvider( - new DefaultRedisCredentials(getUser(), getPassword())); + return new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(getUser(), getPassword())); + } + + default TokenAuthConfig getTokenAuthConfig() { + return null; } default int getDatabase() { diff --git a/src/main/java/redis/clients/jedis/JedisPooled.java b/src/main/java/redis/clients/jedis/JedisPooled.java index c3429319e7..c735fda7c1 100644 --- a/src/main/java/redis/clients/jedis/JedisPooled.java +++ b/src/main/java/redis/clients/jedis/JedisPooled.java @@ -295,7 +295,7 @@ public JedisPooled(final GenericObjectPoolConfig poolConfig, final S final int connectionTimeout, final int soTimeout, final int infiniteSoTimeout, final String user, final String password, final int database, final String clientName) { this(new HostAndPort(host, port), DefaultJedisClientConfig.create(connectionTimeout, soTimeout, - infiniteSoTimeout, user, password, database, clientName, false, null, null, null, null), + infiniteSoTimeout, user, password, database, clientName, false, null, null, null, null, null), poolConfig); } @@ -306,7 +306,7 @@ public JedisPooled(final GenericObjectPoolConfig poolConfig, final S final HostnameVerifier hostnameVerifier) { this(new HostAndPort(host, port), DefaultJedisClientConfig.create(connectionTimeout, soTimeout, infiniteSoTimeout, user, password, database, clientName, ssl, sslSocketFactory, sslParameters, - hostnameVerifier, null), poolConfig); + hostnameVerifier, null, null), poolConfig); } public JedisPooled(final URI uri) { diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java new file mode 100644 index 0000000000..dfda02b8ad --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java @@ -0,0 +1,26 @@ +package redis.clients.jedis.authentication; + +import redis.clients.authentication.core.AuthXManager; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenListener; +import redis.clients.authentication.core.TokenManager; + +public class JedisAuthXManager extends AuthXManager { + private TokenListener listener; + + public JedisAuthXManager(TokenManager tokenManager) { + super(tokenManager); + } + + public void setListener(TokenListener listener) { + this.listener = listener; + } + + @Override + public void authenticateConnections(Token token) { + super.authenticateConnections(token); + if (listener != null) { + listener.onTokenRenewed(token); + } + } +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java new file mode 100644 index 0000000000..9c5a54f135 --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java @@ -0,0 +1,24 @@ +package redis.clients.jedis.authentication; + +import redis.clients.authentication.core.Token; +import redis.clients.jedis.RedisCredentials; + +public class TokenCredentials implements RedisCredentials { + private final String user; + private final char[] password; + + public TokenCredentials(Token token) { + user = token.tryGet("oid"); + password = token.getValue().toCharArray(); + } + + @Override + public String getUser() { + return user; + } + + @Override + public char[] getPassword() { + return password; + } +} \ No newline at end of file diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java new file mode 100644 index 0000000000..eaa7750cb7 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -0,0 +1,87 @@ +package redis.clients.jedis.authentication; + +import static org.mockito.Mockito.when; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.jedis.CommandArguments; +/* */ +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisClientConfig; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Protocol; +import redis.clients.jedis.Protocol.Command; +import redis.clients.jedis.args.Rawable; +import redis.clients.jedis.commands.ProtocolCommand; + +public class TokenBasedAuthenticationIntegrationTests { + + protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); + + @Test + public void testJedisPooledAuth() { + String user = "default"; + String password = endpoint.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken(password, new Date(System.currentTimeMillis() + 100000), + new Date(), Collections.singletonMap("oid", user))); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecutionTimeoutInMs(1000).build(); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .tokenAuthConfig(tokenAuthConfig).build(); + + try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) { + ArgumentCaptor captor = ArgumentCaptor + .forClass(CommandArguments.class); + + try (JedisPooled jedis = new JedisPooled(endpoint.getHostAndPort(), clientConfig)) { + jedis.get("key1"); + } + + // Verify that the static method was called + mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), + Mockito.atLeast(4)); + + CommandArguments commandArgs = captor.getAllValues().get(0); + List args = StreamSupport.stream(commandArgs.spliterator(), false) + .map(Rawable::getRaw).collect(Collectors.toList()); + + assertThat(args, + contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes())); + + List cmds = captor.getAllValues().stream() + .map(item -> item.getCommand()).collect(Collectors.toList()); + assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), + cmds); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java new file mode 100644 index 0000000000..290b2a6913 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -0,0 +1,48 @@ +package redis.clients.jedis.authentication; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import java.util.Collections; +import java.util.Date; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.TokenManager; +import redis.clients.authentication.core.TokenManagerConfig; +import redis.clients.jedis.ConnectionPool; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPorts; + +public class TokenBasedAuthenticationUnitTests { + protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); + + @Test + public void testJedisAuthXManager() throws Exception { + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken("password", new Date(System.currentTimeMillis() + 100000), + new Date(), Collections.singletonMap("oid", "default"))); + + TokenManager tokenManager = new TokenManager(idProvider, + new TokenManagerConfig(0.5F, 1000, 1000, null)); + JedisAuthXManager jedisAuthXManager = new JedisAuthXManager(tokenManager); + + AtomicInteger numberOfEvictions = new AtomicInteger(0); + ConnectionPool pool = spy(new ConnectionPool(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().build(), jedisAuthXManager) { + @Override + public void evict() throws Exception { + numberOfEvictions.incrementAndGet(); + super.evict(); + } + }); + + jedisAuthXManager.start(true); + + assertEquals(1, numberOfEvictions.get()); + } +} From a18e63216c9c176710a8511ba3cae5c52f41946a Mon Sep 17 00:00:00 2001 From: atakavci Date: Fri, 8 Nov 2024 00:41:13 +0300 Subject: [PATCH 02/21] - stop authxmanager on pool close - swith to long dates --- .../redis/clients/jedis/ConnectionPool.java | 27 ++++++++++++------- ...enBasedAuthenticationIntegrationTests.java | 4 +-- .../TokenBasedAuthenticationUnitTests.java | 5 ++-- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index eba26f7375..ac9cf63679 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -1,7 +1,5 @@ package redis.clients.jedis; -import java.util.concurrent.atomic.AtomicInteger; - import org.apache.commons.pool2.PooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; @@ -16,6 +14,8 @@ public class ConnectionPool extends Pool { + private JedisAuthXManager authXManager; + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { this(hostAndPort, clientConfig, createAuthXManager(clientConfig)); } @@ -75,22 +75,29 @@ public Connection getResource() { return conn; } - private static JedisAuthXManager createAuthXManager(JedisClientConfig clientConfig) { - return (clientConfig.getTokenAuthConfig() != null) - ? (JedisAuthXManager) AuthXManagerFactory.create(JedisAuthXManager.class, - clientConfig.getTokenAuthConfig()) - : null; + @Override + public void close() { + if (authXManager != null) { + authXManager.stop(); + } + super.close(); + } + + private static JedisAuthXManager createAuthXManager(JedisClientConfig config) { + if (config.getTokenAuthConfig() != null) { + return AuthXManagerFactory.create(JedisAuthXManager.class, config.getTokenAuthConfig()); + } + return null; } private void attachAuthXManager(JedisAuthXManager authXManager) { + this.authXManager = authXManager; if (authXManager != null) { authXManager.setListener(new TokenListener() { @Override public void onTokenRenewed(Token token) { try { - ConnectionPool.this.evict(); - System.out.println("pool evict: " + ConnectionPool.this.hashCode()); - + evict(); } catch (Exception e) { throw new JedisException("Failed to evict connections from pool", e); } diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index eaa7750cb7..51eb61d617 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -46,8 +46,8 @@ public void testJedisPooledAuth() { IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken(password, new Date(System.currentTimeMillis() + 100000), - new Date(), Collections.singletonMap("oid", user))); + .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000, + System.currentTimeMillis(), Collections.singletonMap("oid", user))); IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); when(idProviderConfig.getProvider()).thenReturn(idProvider); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 290b2a6913..a59e81a824 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -5,7 +5,6 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import java.util.Collections; -import java.util.Date; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import redis.clients.authentication.core.IdentityProvider; @@ -24,8 +23,8 @@ public void testJedisAuthXManager() throws Exception { IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("password", new Date(System.currentTimeMillis() + 100000), - new Date(), Collections.singletonMap("oid", "default"))); + .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 100000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, new TokenManagerConfig(0.5F, 1000, 1000, null)); From 392b3b058a0d246b24136efc44651ea28d090222 Mon Sep 17 00:00:00 2001 From: atakavci Date: Tue, 12 Nov 2024 15:40:02 +0300 Subject: [PATCH 03/21] drop use of authxmanager and authenticatedconnection from core --- .../java/redis/clients/jedis/Connection.java | 39 ++- .../clients/jedis/ConnectionFactory.java | 18 +- .../redis/clients/jedis/ConnectionPool.java | 31 +- .../authentication/JedisAuthXManager.java | 95 ++++++- .../JedisAuthenticationException.java | 12 + .../TokenBasedAuthenticationUnitTests.java | 266 +++++++++++++++++- 6 files changed, 392 insertions(+), 69 deletions(-) create mode 100644 src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index 6e2de1377a..655158f1eb 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -16,14 +16,11 @@ import java.util.function.Supplier; import java.util.concurrent.atomic.AtomicReference; -import redis.clients.authentication.core.AuthenticatedConnection; -import redis.clients.authentication.core.Token; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.Protocol.Keyword; import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.args.ClientAttributeOption; import redis.clients.jedis.args.Rawable; -import redis.clients.jedis.authentication.TokenCredentials; import redis.clients.jedis.commands.ProtocolCommand; import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisDataException; @@ -33,7 +30,7 @@ import redis.clients.jedis.util.RedisInputStream; import redis.clients.jedis.util.RedisOutputStream; -public class Connection implements Closeable, AuthenticatedConnection { +public class Connection implements Closeable { private ConnectionPool memberOf; protected RedisProtocol protocol; @@ -48,7 +45,8 @@ public class Connection implements Closeable, AuthenticatedConnection { private String strVal; protected String server; protected String version; - protected AtomicReference currentToken = new AtomicReference(null); + protected AtomicReference currentCredentials = new AtomicReference( + null); public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); @@ -98,8 +96,8 @@ public String toIdentityString() { SocketAddress remoteAddr = socket.getRemoteSocketAddress(); SocketAddress localAddr = socket.getLocalSocketAddress(); if (remoteAddr != null) { - strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, - localAddr, (broken ? '!' : '-'), remoteAddr); + strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, localAddr, + (broken ? '!' : '-'), remoteAddr); } else if (localAddr != null) { strVal = String.format("%s{id: 0x%X, L:%s}", className, id, localAddr); } else { @@ -443,8 +441,8 @@ private static boolean validateClientInfo(String info) { for (int i = 0; i < info.length(); i++) { char c = info.charAt(i); if (c < '!' || c > '~') { - throw new JedisValidationException("client info cannot contain spaces, " - + "newlines or special characters."); + throw new JedisValidationException( + "client info cannot contain spaces, " + "newlines or special characters."); } } return true; @@ -474,7 +472,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { String clientName = config.getClientName(); if (clientName != null && validateClientInfo(clientName)) { - fireAndForgetMsg.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName)); + fireAndForgetMsg + .add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName)); } ClientSetInfoConfig setInfoConfig = config.getClientSetInfoConfig(); @@ -530,12 +529,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c if (protocol != null && credentials != null && credentials.getUser() != null) { byte[] rawPass = encodeToBytes(credentials.getPassword()); try { - helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass); + helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), + encode(credentials.getUser()), rawPass); } finally { Arrays.fill(rawPass, (byte) 0); // clear sensitive data } } else { - auth(credentials); + authenticate(credentials); helloResult = protocol == null ? null : hello(encode(protocol.version())); } if (helloResult != null) { @@ -547,11 +547,11 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c // handled in RedisCredentialsProvider.cleanUp() } - public void setToken(Token token) { - currentToken.set(token); + public void setCredentials(RedisCredentials credentials) { + currentCredentials.set(credentials); } - private void auth(RedisCredentials credentials) { + public void authenticate(RedisCredentials credentials) { if (credentials == null || credentials.getPassword() == null) { return; } @@ -569,9 +569,9 @@ private void auth(RedisCredentials credentials) { } public void reAuth() { - Token temp = currentToken.getAndSet(null); + RedisCredentials temp = currentCredentials.getAndSet(null); if (temp != null) { - auth(new TokenCredentials(temp)); + authenticate(temp); } } @@ -601,9 +601,4 @@ public boolean ping() { } return true; } - - @Override - public void authenticate(Token token) { - this.setToken(token); - } } diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index 9ccdb6b918..ce4a10cb7b 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -11,8 +11,7 @@ import java.util.function.Supplier; import redis.clients.jedis.annots.Experimental; -import redis.clients.jedis.authentication.TokenCredentials; -import redis.clients.authentication.core.AuthXManager; +import redis.clients.jedis.authentication.JedisAuthXManager; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.csc.CacheConnection; import redis.clients.jedis.exceptions.JedisException; @@ -39,7 +38,7 @@ public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig @Experimental public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, - Cache csCache, AuthXManager authXManager) { + Cache csCache, JedisAuthXManager authXManager) { this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache, authXManager); } @@ -50,7 +49,7 @@ public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, } private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, - final JedisClientConfig clientConfig, Cache csCache, AuthXManager authXManager) { + final JedisClientConfig clientConfig, Cache csCache, JedisAuthXManager authXManager) { this.jedisSocketFactory = jedisSocketFactory; this.clientSideCache = csCache; @@ -60,7 +59,7 @@ private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, this.objectMaker = connectionSupplier(); } else { this.clientConfig = replaceCredentialsProvider(clientConfig, - buildCredentialsProvider(authXManager)); + authXManager); Supplier supplier = connectionSupplier(); this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get()); @@ -78,15 +77,6 @@ private JedisClientConfig replaceCredentialsProvider(JedisClientConfig origin, .credentialsProvider(newCredentialsProvider).build(); } - private Supplier buildCredentialsProvider(AuthXManager connManager) { - return new Supplier() { - @Override - public RedisCredentials get() { - return new TokenCredentials(connManager.getCurrentToken()); - } - }; - } - private Supplier connectionSupplier() { return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig) : () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache); diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index ac9cf63679..d7dc0d85f7 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -3,9 +3,6 @@ import org.apache.commons.pool2.PooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; -import redis.clients.authentication.core.AuthXManagerFactory; -import redis.clients.authentication.core.Token; -import redis.clients.authentication.core.TokenListener; import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.authentication.JedisAuthXManager; import redis.clients.jedis.csc.Cache; @@ -23,7 +20,7 @@ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, JedisAuthXManager authXManager) { this(new ConnectionFactory(hostAndPort, clientConfig, null, authXManager)); - attachAuthXManager(authXManager); + attachAuthenticationListener(authXManager); } @Experimental @@ -36,7 +33,7 @@ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache, JedisAuthXManager authXManager) { this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager)); - attachAuthXManager(authXManager); + attachAuthenticationListener(authXManager); } public ConnectionPool(PooledObjectFactory factory) { @@ -60,7 +57,7 @@ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, GenericObjectPoolConfig poolConfig) { this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager), poolConfig); - attachAuthXManager(authXManager); + attachAuthenticationListener(authXManager); } public ConnectionPool(PooledObjectFactory factory, @@ -85,26 +82,20 @@ public void close() { private static JedisAuthXManager createAuthXManager(JedisClientConfig config) { if (config.getTokenAuthConfig() != null) { - return AuthXManagerFactory.create(JedisAuthXManager.class, config.getTokenAuthConfig()); + return new JedisAuthXManager(config.getTokenAuthConfig()); } return null; } - private void attachAuthXManager(JedisAuthXManager authXManager) { + private void attachAuthenticationListener(JedisAuthXManager authXManager) { this.authXManager = authXManager; if (authXManager != null) { - authXManager.setListener(new TokenListener() { - @Override - public void onTokenRenewed(Token token) { - try { - evict(); - } catch (Exception e) { - throw new JedisException("Failed to evict connections from pool", e); - } - } - - @Override - public void onError(Exception reason) { + authXManager.setListener(token -> { + try { + // this is to trigger validations on each connection via ConnectionFactory + evict(); + } catch (Exception e) { + throw new JedisException("Failed to evict connections from pool", e); } }); } diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java index dfda02b8ad..57a3bab894 100644 --- a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java @@ -1,26 +1,103 @@ package redis.clients.jedis.authentication; -import redis.clients.authentication.core.AuthXManager; +import java.lang.ref.WeakReference; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; import redis.clients.authentication.core.TokenListener; import redis.clients.authentication.core.TokenManager; +import redis.clients.jedis.Connection; +import redis.clients.jedis.RedisCredentials; + +public class JedisAuthXManager implements Supplier { + + private static final Logger log = LoggerFactory.getLogger(JedisAuthXManager.class); -public class JedisAuthXManager extends AuthXManager { - private TokenListener listener; + private TokenManager tokenManager; + private List> connections = Collections + .synchronizedList(new ArrayList<>()); + private Token currentToken; + private AuthenticationListener listener; + + public interface AuthenticationListener { + public void onAuthenticate(Token token); + } public JedisAuthXManager(TokenManager tokenManager) { - super(tokenManager); + this.tokenManager = tokenManager; } - public void setListener(TokenListener listener) { - this.listener = listener; + public JedisAuthXManager(TokenAuthConfig tokenAuthConfig) { + this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(), + tokenAuthConfig.getTokenManagerConfig())); + } + + public void start(boolean blockForInitialToken) + throws InterruptedException, ExecutionException, TimeoutException { + + tokenManager.start(new TokenListener() { + @Override + public void onTokenRenewed(Token token) { + currentToken = token; + authenticateConnections(token); + } + + @Override + public void onError(Exception reason) { + JedisAuthXManager.this.onError(reason); + } + }, blockForInitialToken); } - @Override public void authenticateConnections(Token token) { - super.authenticateConnections(token); + RedisCredentials credentialsFromToken = new TokenCredentials(token); + for (WeakReference connectionRef : connections) { + Connection connection = connectionRef.get(); + if (connection != null) { + try { + connection.setCredentials(credentialsFromToken); + } catch (Exception e) { + log.error("Failed to authenticate connection!", e); + } + } else { + connections.remove(connectionRef); + } + } if (listener != null) { - listener.onTokenRenewed(token); + listener.onAuthenticate(token); } } + + public void onError(Exception reason) { + throw new JedisAuthenticationException( + "Token request/renewal failed with message:" + reason.getMessage(), reason); + } + + public Connection addConnection(Connection connection) { + connections.add(new WeakReference<>(connection)); + return connection; + } + + public void stop() { + tokenManager.stop(); + } + + public void setListener(AuthenticationListener listener) { + this.listener = listener; + } + + @Override + public RedisCredentials get() { + return new TokenCredentials(this.currentToken); + } + } \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java new file mode 100644 index 0000000000..adc421e790 --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java @@ -0,0 +1,12 @@ +package redis.clients.jedis.authentication; + +public class JedisAuthenticationException extends RuntimeException { + + public JedisAuthenticationException(String message) { + super(message); + } + + public JedisAuthenticationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index a59e81a824..8a0720906e 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -1,14 +1,39 @@ package redis.clients.jedis.authentication; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.awaitility.Awaitility.await; +import static org.hamcrest.CoreMatchers.either; +import static org.hamcrest.CoreMatchers.is; import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; + +import org.hamcrest.Matchers; import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedConstruction; +import org.mockito.Mockito; + import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.core.TokenListener; import redis.clients.authentication.core.TokenManager; import redis.clients.authentication.core.TokenManagerConfig; import redis.clients.jedis.ConnectionPool; @@ -19,7 +44,25 @@ public class TokenBasedAuthenticationUnitTests { protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); @Test - public void testJedisAuthXManager() throws Exception { + public void testJedisAuthXManagerInstance() { + TokenManagerConfig tokenManagerConfig = mock(TokenManagerConfig.class); + IdentityProviderConfig identityProviderConfig = mock(IdentityProviderConfig.class); + IdentityProvider identityProvider = mock(IdentityProvider.class); + + when(identityProviderConfig.getProvider()).thenReturn(identityProvider); + + try (MockedConstruction mockedConstructor = mockConstruction(TokenManager.class, + (mock, context) -> { + assertEquals(identityProvider, context.arguments().get(0)); + assertEquals(tokenManagerConfig, context.arguments().get(1)); + })) { + + new JedisAuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig)); + } + } + + @Test + public void testJedisAuthXManagerTriggersEvict() throws Exception { IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) @@ -31,17 +74,232 @@ public void testJedisAuthXManager() throws Exception { JedisAuthXManager jedisAuthXManager = new JedisAuthXManager(tokenManager); AtomicInteger numberOfEvictions = new AtomicInteger(0); - ConnectionPool pool = spy(new ConnectionPool(endpoint.getHostAndPort(), + ConnectionPool pool = new ConnectionPool(endpoint.getHostAndPort(), endpoint.getClientConfigBuilder().build(), jedisAuthXManager) { @Override public void evict() throws Exception { numberOfEvictions.incrementAndGet(); super.evict(); } - }); + }; jedisAuthXManager.start(true); - assertEquals(1, numberOfEvictions.get()); } + + public static class TokenManagerConfigWrapper extends TokenManagerConfig { + int lower; + float ratio; + + public TokenManagerConfigWrapper() { + super(0, 0, 0, null); + } + + @Override + public int getLowerRefreshBoundMillis() { + return lower; + } + + @Override + public float getExpirationRefreshRatio() { + return ratio; + } + } + + @Test + public void testCalculateRenewalDelay() { + long delay = 0; + long duration = 0; + long issueDate; + long expireDate; + + TokenManagerConfigWrapper config = new TokenManagerConfigWrapper(); + TokenManager manager = new TokenManager(() -> null, config); + + duration = 5000; + config.lower = 2000; + config.ratio = 0.5F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertThat(delay, Matchers + .greaterThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); + + duration = 10000; + config.lower = 8000; + config.ratio = 0.2F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertThat(delay, Matchers + .greaterThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); + + duration = 10000; + config.lower = 10000; + config.ratio = 0.2F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertEquals(0, delay); + + duration = 0; + config.lower = 5000; + config.ratio = 0.2F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertEquals(0, delay); + + duration = 10000; + config.lower = 1000; + config.ratio = 0.00001F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertEquals(0, delay); + + duration = 10000; + config.lower = 1000; + config.ratio = 0.0001F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertThat(delay, either(is(0L)).or(is(1L))); + } + + @Test + public void testAuthXManagerReceivesNewToken() throws InterruptedException, ExecutionException, TimeoutException { + + IdentityProvider identityProvider = () -> new SimpleToken("tokenVal", + System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), + Collections.singletonMap("oid", "user1")); + + TokenManager tokenManager = new TokenManager(identityProvider, + new TokenManagerConfig(0.7F, 200, 2000, null)); + + JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + + final Token[] tokenHolder = new Token[1]; + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + tokenHolder[0] = (Token) args[0]; + return null; + }).when(manager).authenticateConnections(any()); + + manager.start(true); + assertEquals(tokenHolder[0].getValue(), "tokenVal"); + } + + @Test + public void testBlockForInitialToken() { + IdentityProvider identityProvider = () -> { + throw new RuntimeException("Test exception from identity provider!"); + }; + + TokenManager tokenManager = new TokenManager(identityProvider, + new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100))); + + JedisAuthXManager manager = new JedisAuthXManager(tokenManager); + ExecutionException e = assertThrows(ExecutionException.class, () -> manager.start(true)); + + assertEquals("java.lang.RuntimeException: Test exception from identity provider!", + e.getCause().getCause().getMessage()); + } + + @Test + public void testNoBlockForInitialToken() + throws InterruptedException, ExecutionException, TimeoutException { + int numberOfRetries = 5; + CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); + IdentityProvider identityProvider = () -> { + requesLatch.countDown(); + throw new RuntimeException("Test exception from identity provider!"); + }; + + TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, + 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 100))); + + JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + manager.start(false); + + requesLatch.await(); + verify(manager, Mockito.atLeastOnce()).onError(Mockito.any()); + verify(manager, Mockito.never()).authenticateConnections(Mockito.any()); + } + + @Test + public void testTokenManagerWithFailingTokenRequest() + throws InterruptedException, ExecutionException, TimeoutException { + int numberOfRetries = 5; + CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); + + IdentityProvider identityProvider = mock(IdentityProvider.class); + when(identityProvider.requestToken()).thenAnswer(invocation -> { + requesLatch.countDown(); + if (requesLatch.getCount() > 0) { + throw new RuntimeException("Test exception from identity provider!"); + } + return new SimpleToken("tokenValX", System.currentTimeMillis() + 50 * 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); + }); + + ArgumentCaptor argument = ArgumentCaptor.forClass(Token.class); + + TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, + 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 100))); + + TokenListener listener = mock(TokenListener.class); + tokenManager.start(listener, false); + requesLatch.await(); + verify(identityProvider, times(numberOfRetries)).requestToken(); + verify(listener, never()).onError(any()); + verify(listener).onTokenRenewed(argument.capture()); + assertEquals("tokenValX", argument.getValue().getValue()); + } + + @Test + public void testTokenManagerWithHangingTokenRequest() + throws InterruptedException, ExecutionException, TimeoutException { + int sleepDuration = 200; + int executionTimeout = 100; + int tokenLifetime = 50 * 1000; + int numberOfRetries = 5; + CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); + + IdentityProvider identityProvider = () -> { + requesLatch.countDown(); + if (requesLatch.getCount() > 0) { + try { + Thread.sleep(sleepDuration); + } catch (InterruptedException e) { + } + return null; + } + return new SimpleToken("tokenValX", System.currentTimeMillis() + tokenLifetime, + System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); + }; + + TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, + executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100))); + + JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + manager.start(false); + requesLatch.await(); + verify(manager, never()).onError(any()); + await().atMost(2, TimeUnit.SECONDS).untilAsserted(() -> { + verify(manager, times(1)).authenticateConnections(any()); + }); + } } From f037440f1caa5f091871c5d81ad2ba5fc00c22e5 Mon Sep 17 00:00:00 2001 From: atakavci Date: Sun, 17 Nov 2024 18:38:15 +0300 Subject: [PATCH 04/21] -update submodule ref -change exception message --- redis-authx | 2 +- .../jedis/authentication/JedisAuthXManager.java | 2 +- .../TokenBasedAuthenticationIntegrationTests.java | 2 +- .../TokenBasedAuthenticationUnitTests.java | 10 ++++++---- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/redis-authx b/redis-authx index 8f56584858..7285ce7857 160000 --- a/redis-authx +++ b/redis-authx @@ -1 +1 @@ -Subproject commit 8f5658485897d2ca56af5238af25fc709cd0eaa9 +Subproject commit 7285ce78578652ed8b8132814792b806e7b78a26 diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java index 57a3bab894..a210f7fb42 100644 --- a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java @@ -79,7 +79,7 @@ public void authenticateConnections(Token token) { public void onError(Exception reason) { throw new JedisAuthenticationException( - "Token request/renewal failed with message:" + reason.getMessage(), reason); + "Token manager failed to acquire new token!", reason); } public Connection addConnection(Connection connection) { diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index 51eb61d617..4cbf155dd3 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -54,7 +54,7 @@ public void testJedisPooledAuth() { TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) - .lowerRefreshBoundMillis(10000).tokenRequestExecutionTimeoutInMs(1000).build(); + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() .tokenAuthConfig(tokenAuthConfig).build(); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 8a0720906e..7ac68361aa 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -180,7 +180,8 @@ public void testCalculateRenewalDelay() { } @Test - public void testAuthXManagerReceivesNewToken() throws InterruptedException, ExecutionException, TimeoutException { + public void testAuthXManagerReceivesNewToken() + throws InterruptedException, ExecutionException, TimeoutException { IdentityProvider identityProvider = () -> new SimpleToken("tokenVal", System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), @@ -204,8 +205,9 @@ public void testAuthXManagerReceivesNewToken() throws InterruptedException, Exec @Test public void testBlockForInitialToken() { + String exceptionMessage = "Test exception from identity provider!"; IdentityProvider identityProvider = () -> { - throw new RuntimeException("Test exception from identity provider!"); + throw new RuntimeException(exceptionMessage); }; TokenManager tokenManager = new TokenManager(identityProvider, @@ -214,8 +216,8 @@ public void testBlockForInitialToken() { JedisAuthXManager manager = new JedisAuthXManager(tokenManager); ExecutionException e = assertThrows(ExecutionException.class, () -> manager.start(true)); - assertEquals("java.lang.RuntimeException: Test exception from identity provider!", - e.getCause().getCause().getMessage()); + assertEquals(exceptionMessage, + e.getCause().getCause().getCause().getCause().getMessage()); } @Test From eb635206ca963b68a5eb41cc85faf1cb1867c666 Mon Sep 17 00:00:00 2001 From: atakavci Date: Mon, 18 Nov 2024 04:18:19 +0300 Subject: [PATCH 05/21] - remove submodule - update dependency --- .gitmodules | 4 ---- pom.xml | 14 +++++++------- redis-authx | 1 - 3 files changed, 7 insertions(+), 12 deletions(-) delete mode 100644 .gitmodules delete mode 160000 redis-authx diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e974dd3048..0000000000 --- a/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "redis-authx"] - path = redis-authx - url = https://github.com/redis/tbd-auth-entraid - branch = tba-draft diff --git a/pom.xml b/pom.xml index 0ed233657e..7bdb3b00e2 100644 --- a/pom.xml +++ b/pom.xml @@ -9,7 +9,7 @@ jar redis.clients jedis - 5.3.0-SNAPSHOT + 5.3.1-SNAPSHOT Jedis Jedis is a blazingly small and sane Redis java client. https://github.com/redis/jedis @@ -75,6 +75,12 @@ 2.11.0 + + redis.clients.authentication + redis-authx-core + 0.1.0-SNAPSHOT + + @@ -93,12 +99,6 @@ test - - redis.clients.authentication - redis-authx-core - 0.1.0 - - junit diff --git a/redis-authx b/redis-authx deleted file mode 160000 index 7285ce7857..0000000000 --- a/redis-authx +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7285ce78578652ed8b8132814792b806e7b78a26 From 523fe42400ae56d8e8b09cd5d5ca093d3f90c98a Mon Sep 17 00:00:00 2001 From: atakavci Date: Tue, 19 Nov 2024 09:59:06 +0300 Subject: [PATCH 06/21] back to current version --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c596b99a57..a8930f2306 100644 --- a/pom.xml +++ b/pom.xml @@ -9,7 +9,7 @@ jar redis.clients jedis - 5.3.1-SNAPSHOT + 5.3.0-SNAPSHOT Jedis Jedis is a blazingly small and sane Redis java client. https://github.com/redis/jedis From 4e715352ec9846d52ab3fb8207a74b2939346398 Mon Sep 17 00:00:00 2001 From: atakavci Date: Thu, 5 Dec 2024 11:06:52 +0300 Subject: [PATCH 07/21] - move autxhmanager creation to user space - introduce authenticationeventlisteners - clenaup in connectionpool - add entraidtestcontext - add redisintegrationtests - fix failing tokenbasedauthentication unit&integ tests --- pom.xml | 7 + .../clients/jedis/ConnectionFactory.java | 39 +- .../redis/clients/jedis/ConnectionPool.java | 49 +-- .../jedis/DefaultJedisClientConfig.java | 29 +- .../clients/jedis/JedisClientConfig.java | 7 +- .../authentication/AuthXEventListener.java | 21 + ...disAuthXManager.java => AuthXManager.java} | 52 ++- .../JedisAuthenticationException.java | 4 +- .../authentication/EntraIDTestContext.java | 112 +++++ .../RedisEntraIDIntegrationTests.java | 404 ++++++++++++++++++ ...enBasedAuthenticationIntegrationTests.java | 89 ++-- .../TokenBasedAuthenticationUnitTests.java | 66 ++- 12 files changed, 734 insertions(+), 145 deletions(-) create mode 100644 src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java rename src/main/java/redis/clients/jedis/authentication/{JedisAuthXManager.java => AuthXManager.java} (63%) create mode 100644 src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java create mode 100644 src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java diff --git a/pom.xml b/pom.xml index a8930f2306..4ca1343f19 100644 --- a/pom.xml +++ b/pom.xml @@ -156,6 +156,13 @@ test + + redis.clients.authentication + redis-authx-entraid + 0.1.0-SNAPSHOT + test + + io.github.resilience4j diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index ce4a10cb7b..cdbe2ab5c7 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -11,7 +11,8 @@ import java.util.function.Supplier; import redis.clients.jedis.annots.Experimental; -import redis.clients.jedis.authentication.JedisAuthXManager; +import redis.clients.jedis.authentication.AuthXManager; +import redis.clients.jedis.authentication.AuthXEventListener; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.csc.CacheConnection; import redis.clients.jedis.exceptions.JedisException; @@ -28,41 +29,43 @@ public class ConnectionFactory implements PooledObjectFactory { private final Cache clientSideCache; private final Supplier objectMaker; + private final AuthXEventListener authenticationEventListener; + public ConnectionFactory(final HostAndPort hostAndPort) { - this(hostAndPort, DefaultJedisClientConfig.builder().build(), null, null); + this(hostAndPort, DefaultJedisClientConfig.builder().build(), null); } public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) { - this(hostAndPort, clientConfig, null, null); + this(hostAndPort, clientConfig, null); } @Experimental public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, - Cache csCache, JedisAuthXManager authXManager) { - this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache, - authXManager); + Cache csCache) { + this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache); } public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) { - this(jedisSocketFactory, clientConfig, null, null); + this(jedisSocketFactory, clientConfig, null); } private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, - final JedisClientConfig clientConfig, Cache csCache, JedisAuthXManager authXManager) { + final JedisClientConfig clientConfig, Cache csCache) { this.jedisSocketFactory = jedisSocketFactory; this.clientSideCache = csCache; + AuthXManager authXManager = clientConfig.getAuthXManager(); if (authXManager == null) { this.clientConfig = clientConfig; this.objectMaker = connectionSupplier(); + this.authenticationEventListener = AuthXEventListener.NOOP_LISTENER; } else { - this.clientConfig = replaceCredentialsProvider(clientConfig, - authXManager); + this.clientConfig = replaceCredentialsProvider(clientConfig, authXManager); Supplier supplier = connectionSupplier(); this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get()); - + this.authenticationEventListener = authXManager.getListener(); try { authXManager.start(true); } catch (InterruptedException | ExecutionException | TimeoutException e) { @@ -114,7 +117,12 @@ public PooledObject makeObject() throws Exception { public void passivateObject(PooledObject pooledConnection) throws Exception { // TODO maybe should select db 0? Not sure right now. Connection jedis = pooledConnection.getObject(); - jedis.reAuth(); + try { + jedis.reAuth(); + } catch (Exception e) { + authenticationEventListener.onConnectionAuthenticationError(e); + throw e; + } } @Override @@ -122,7 +130,12 @@ public boolean validateObject(PooledObject pooledConnection) { final Connection jedis = pooledConnection.getObject(); try { // check HostAndPort ?? - jedis.reAuth(); + try { + jedis.reAuth(); + } catch (Exception e) { + authenticationEventListener.onConnectionAuthenticationError(e); + throw e; + } return jedis.isConnected() && jedis.ping(); } catch (final Exception e) { logger.warn("Error while validating pooled Connection object.", e); diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index d7dc0d85f7..536b3a6484 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -4,36 +4,25 @@ import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import redis.clients.jedis.annots.Experimental; -import redis.clients.jedis.authentication.JedisAuthXManager; +import redis.clients.jedis.authentication.AuthXManager; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.Pool; public class ConnectionPool extends Pool { - private JedisAuthXManager authXManager; + private AuthXManager authXManager; public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { - this(hostAndPort, clientConfig, createAuthXManager(clientConfig)); - } - - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, - JedisAuthXManager authXManager) { - this(new ConnectionFactory(hostAndPort, clientConfig, null, authXManager)); - attachAuthenticationListener(authXManager); + this(new ConnectionFactory(hostAndPort, clientConfig)); + attachAuthenticationListener(clientConfig.getAuthXManager()); } @Experimental public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) { - this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig)); - } - - @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, - Cache clientSideCache, JedisAuthXManager authXManager) { - this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager)); - attachAuthenticationListener(authXManager); + this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache)); + attachAuthenticationListener(clientConfig.getAuthXManager()); } public ConnectionPool(PooledObjectFactory factory) { @@ -42,22 +31,15 @@ public ConnectionPool(PooledObjectFactory factory) { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, GenericObjectPoolConfig poolConfig) { - this(hostAndPort, clientConfig, null, createAuthXManager(clientConfig), poolConfig); + this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig); + attachAuthenticationListener(clientConfig.getAuthXManager()); } @Experimental public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache, GenericObjectPoolConfig poolConfig) { - this(hostAndPort, clientConfig, clientSideCache, createAuthXManager(clientConfig), poolConfig); - } - - @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, - Cache clientSideCache, JedisAuthXManager authXManager, - GenericObjectPoolConfig poolConfig) { - this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache, authXManager), - poolConfig); - attachAuthenticationListener(authXManager); + this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig); + attachAuthenticationListener(clientConfig.getAuthXManager()); } public ConnectionPool(PooledObjectFactory factory, @@ -80,17 +62,10 @@ public void close() { super.close(); } - private static JedisAuthXManager createAuthXManager(JedisClientConfig config) { - if (config.getTokenAuthConfig() != null) { - return new JedisAuthXManager(config.getTokenAuthConfig()); - } - return null; - } - - private void attachAuthenticationListener(JedisAuthXManager authXManager) { + private void attachAuthenticationListener(AuthXManager authXManager) { this.authXManager = authXManager; if (authXManager != null) { - authXManager.setListener(token -> { + authXManager.addPostAuthenticationHook(token -> { try { // this is to trigger validations on each connection via ConnectionFactory evict(); diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java index 8b161ca7ff..5f0e050ef4 100644 --- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java @@ -5,7 +5,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; -import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.jedis.authentication.AuthXManager; public final class DefaultJedisClientConfig implements JedisClientConfig { @@ -30,7 +30,7 @@ public final class DefaultJedisClientConfig implements JedisClientConfig { private final boolean readOnlyForRedisClusterReplicas; - private final TokenAuthConfig tokenAuthConfig; + private final AuthXManager authXManager; private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMillis, int soTimeoutMillis, int blockingSocketTimeoutMillis, @@ -38,7 +38,7 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper, ClientSetInfoConfig clientSetInfoConfig, boolean readOnlyForRedisClusterReplicas, - TokenAuthConfig tokenAuthConfig) { + AuthXManager authXManager) { this.redisProtocol = protocol; this.connectionTimeoutMillis = connectionTimeoutMillis; this.socketTimeoutMillis = soTimeoutMillis; @@ -53,7 +53,8 @@ private DefaultJedisClientConfig(RedisProtocol protocol, int connectionTimeoutMi this.hostAndPortMapper = hostAndPortMapper; this.clientSetInfoConfig = clientSetInfoConfig; this.readOnlyForRedisClusterReplicas = readOnlyForRedisClusterReplicas; - this.tokenAuthConfig = tokenAuthConfig; + this.authXManager = authXManager; + } @Override @@ -93,8 +94,8 @@ public Supplier getCredentialsProvider() { } @Override - public TokenAuthConfig getTokenAuthConfig() { - return tokenAuthConfig; + public AuthXManager getAuthXManager() { + return authXManager; } @Override @@ -171,7 +172,7 @@ public static class Builder { private boolean readOnlyForRedisClusterReplicas = false; - private TokenAuthConfig tokenAuthConfig = null; + private AuthXManager authXManager; private Builder() { } @@ -185,7 +186,7 @@ public DefaultJedisClientConfig build() { return new DefaultJedisClientConfig(redisProtocol, connectionTimeoutMillis, socketTimeoutMillis, blockingSocketTimeoutMillis, credentialsProvider, database, clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, - clientSetInfoConfig, readOnlyForRedisClusterReplicas, tokenAuthConfig); + clientSetInfoConfig, readOnlyForRedisClusterReplicas, authXManager); } /** @@ -287,8 +288,8 @@ public Builder readOnlyForRedisClusterReplicas() { return this; } - public Builder tokenAuthConfig(TokenAuthConfig tokenAuthConfig) { - this.tokenAuthConfig = tokenAuthConfig; + public Builder authXManager(AuthXManager authXManager) { + this.authXManager = authXManager; return this; } @@ -307,7 +308,7 @@ public Builder from(JedisClientConfig instance) { this.hostAndPortMapper = instance.getHostAndPortMapper(); this.clientSetInfoConfig = instance.getClientSetInfoConfig(); this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas(); - this.tokenAuthConfig = instance.getTokenAuthConfig(); + this.authXManager = instance.getAuthXManager(); return this; } } @@ -316,12 +317,12 @@ public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int s int blockingSocketTimeoutMillis, String user, String password, int database, String clientName, boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters, HostnameVerifier hostnameVerifier, - HostAndPortMapper hostAndPortMapper, TokenAuthConfig tokenAuthConfig) { + HostAndPortMapper hostAndPortMapper, AuthXManager authXManager) { return new DefaultJedisClientConfig(null, connectionTimeoutMillis, soTimeoutMillis, blockingSocketTimeoutMillis, new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(user, password)), database, clientName, ssl, sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper, null, - false, tokenAuthConfig); + false, authXManager); } public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { @@ -330,6 +331,6 @@ public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { copy.getCredentialsProvider(), copy.getDatabase(), copy.getClientName(), copy.isSsl(), copy.getSslSocketFactory(), copy.getSslParameters(), copy.getHostnameVerifier(), copy.getHostAndPortMapper(), copy.getClientSetInfoConfig(), - copy.isReadOnlyForRedisClusterReplicas(), copy.getTokenAuthConfig()); + copy.isReadOnlyForRedisClusterReplicas(), copy.getAuthXManager()); } } diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java index a8046694bf..82e9eb8e7f 100644 --- a/src/main/java/redis/clients/jedis/JedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java @@ -5,7 +5,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; -import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.jedis.authentication.AuthXManager; public interface JedisClientConfig { @@ -47,10 +47,11 @@ default String getPassword() { } default Supplier getCredentialsProvider() { - return new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(getUser(), getPassword())); + return new DefaultRedisCredentialsProvider( + new DefaultRedisCredentials(getUser(), getPassword())); } - default TokenAuthConfig getTokenAuthConfig() { + default AuthXManager getAuthXManager() { return null; } diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java new file mode 100644 index 0000000000..4750404157 --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java @@ -0,0 +1,21 @@ +package redis.clients.jedis.authentication; + +public interface AuthXEventListener { + + static AuthXEventListener NOOP_LISTENER = new AuthXEventListener() { + + @Override + public void onIdentityProviderError(Exception reason) { + } + + @Override + public void onConnectionAuthenticationError(Exception reason) { + } + + }; + + public void onIdentityProviderError(Exception reason); + + public void onConnectionAuthenticationError(Exception reason); + +} diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java similarity index 63% rename from src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java rename to src/main/java/redis/clients/jedis/authentication/AuthXManager.java index a210f7fb42..d66bddeb4c 100644 --- a/src/main/java/redis/clients/jedis/authentication/JedisAuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; +import java.util.function.Consumer; import java.util.function.Supplier; import org.slf4j.Logger; @@ -18,25 +19,22 @@ import redis.clients.jedis.Connection; import redis.clients.jedis.RedisCredentials; -public class JedisAuthXManager implements Supplier { +public final class AuthXManager implements Supplier { - private static final Logger log = LoggerFactory.getLogger(JedisAuthXManager.class); + private static final Logger log = LoggerFactory.getLogger(AuthXManager.class); private TokenManager tokenManager; private List> connections = Collections .synchronizedList(new ArrayList<>()); private Token currentToken; - private AuthenticationListener listener; + private AuthXEventListener listener = AuthXEventListener.NOOP_LISTENER; + private List> postAuthenticateHooks = new ArrayList<>(); - public interface AuthenticationListener { - public void onAuthenticate(Token token); - } - - public JedisAuthXManager(TokenManager tokenManager) { + protected AuthXManager(TokenManager tokenManager) { this.tokenManager = tokenManager; } - public JedisAuthXManager(TokenAuthConfig tokenAuthConfig) { + public AuthXManager(TokenAuthConfig tokenAuthConfig) { this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(), tokenAuthConfig.getTokenManagerConfig())); } @@ -53,7 +51,8 @@ public void onTokenRenewed(Token token) { @Override public void onError(Exception reason) { - JedisAuthXManager.this.onError(reason); + listener.onIdentityProviderError(reason); + AuthXManager.this.onError(reason); } }, blockForInitialToken); } @@ -63,23 +62,18 @@ public void authenticateConnections(Token token) { for (WeakReference connectionRef : connections) { Connection connection = connectionRef.get(); if (connection != null) { - try { - connection.setCredentials(credentialsFromToken); - } catch (Exception e) { - log.error("Failed to authenticate connection!", e); - } + connection.setCredentials(credentialsFromToken); } else { connections.remove(connectionRef); } } - if (listener != null) { - listener.onAuthenticate(token); - } + postAuthenticateHooks.forEach(hook -> hook.accept(token)); } public void onError(Exception reason) { - throw new JedisAuthenticationException( - "Token manager failed to acquire new token!", reason); + log.error("Token manager failed to acquire new token!", reason); + throw new JedisAuthenticationException("Token manager failed to acquire new token!", + reason); } public Connection addConnection(Connection connection) { @@ -91,8 +85,22 @@ public void stop() { tokenManager.stop(); } - public void setListener(AuthenticationListener listener) { - this.listener = listener; + public void setListener(AuthXEventListener listener) { + if (listener != null) { + this.listener = listener; + } + } + + public void addPostAuthenticationHook(Consumer postAuthenticateHook) { + postAuthenticateHooks.add(postAuthenticateHook); + } + + public void removePostAuthenticationHook(Consumer postAuthenticateHook) { + postAuthenticateHooks.remove(postAuthenticateHook); + } + + public AuthXEventListener getListener() { + return listener; } @Override diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java index adc421e790..c70ab98720 100644 --- a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java @@ -1,6 +1,8 @@ package redis.clients.jedis.authentication; -public class JedisAuthenticationException extends RuntimeException { +import redis.clients.jedis.exceptions.JedisException; + +public class JedisAuthenticationException extends JedisException { public JedisAuthenticationException(String message) { super(message); diff --git a/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java new file mode 100644 index 0000000000..e0cde9cfef --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java @@ -0,0 +1,112 @@ +package redis.clients.jedis.authentication; + +import java.io.ByteArrayInputStream; +import java.security.KeyFactory; +import java.security.PrivateKey; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashSet; +import java.util.Set; + +public class EntraIDTestContext { + private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID"; + private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY"; + private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"; + private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY"; + private static final String AZURE_CERT = "AZURE_CERT"; + private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES"; + + private String clientId; + private String authority; + private String clientSecret; + private PrivateKey privateKey; + private X509Certificate cert; + private Set redisScopes; + + public static final EntraIDTestContext DEFAULT = new EntraIDTestContext(); + + private EntraIDTestContext() { + clientId = System.getenv(AZURE_CLIENT_ID); + authority = System.getenv(AZURE_AUTHORITY); + clientSecret = System.getenv(AZURE_CLIENT_SECRET); + } + + public EntraIDTestContext(String clientId, String authority, String clientSecret, + Set redisScopes) { + this.clientId = clientId; + this.authority = authority; + this.clientSecret = clientSecret; + this.redisScopes = redisScopes; + } + + public String getClientId() { + return clientId; + } + + public String getAuthority() { + return authority; + } + + public String getClientSecret() { + return clientSecret; + } + + public PrivateKey getPrivateKey() { + if (privateKey == null) { + this.privateKey = getPrivateKey(System.getenv(AZURE_PRIVATE_KEY)); + } + return privateKey; + } + + public X509Certificate getCert() { + if (cert == null) { + this.cert = getCert(System.getenv(AZURE_CERT)); + } + return cert; + } + + public Set getRedisScopes() { + if (redisScopes == null) { + String redisScopesEnv = System.getenv(AZURE_REDIS_SCOPES); + this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";"))); + } + return redisScopes; + } + + private PrivateKey getPrivateKey(String privateKey) { + try { + // Decode the base64 encoded key into a byte array + byte[] decodedKey = Base64.getDecoder().decode(privateKey); + + // Generate the private key from the decoded byte array using PKCS8EncodedKeySpec + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKey); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // Use the correct algorithm (e.g., "RSA", "EC", "DSA") + PrivateKey key = keyFactory.generatePrivate(keySpec); + return key; + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private X509Certificate getCert(String cert) { + try { + // Convert the Base64 encoded string into a byte array + byte[] encoded = java.util.Base64.getDecoder().decode(cert); + + // Create a CertificateFactory for X.509 certificates + CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + + // Generate the certificate from the byte array + X509Certificate certificate = (X509Certificate) certificateFactory + .generateCertificate(new ByteArrayInputStream(encoded)); + return certificate; + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java new file mode 100644 index 0000000000..4639e34b02 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -0,0 +1,404 @@ +package redis.clients.jedis.authentication; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.awaitility.Awaitility; +import org.awaitility.Durations; +import org.junit.BeforeClass; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; +import org.mockito.MockedConstruction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.entraid.EntraIDIdentityProvider; +import redis.clients.authentication.entraid.EntraIDIdentityProviderConfig; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType; +import redis.clients.authentication.entraid.ServicePrincipalInfo; +import redis.clients.jedis.Connection; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.exceptions.JedisAccessControlException; +import redis.clients.jedis.exceptions.JedisConnectionException; +import redis.clients.jedis.scenario.FaultInjectionClient; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class RedisEntraIDIntegrationTests { + private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class); + + private static EntraIDTestContext testCtx; + private static EndpointConfig endpointConfig; + private static HostAndPort hnp; + + private final FaultInjectionClient faultClient = new FaultInjectionClient(); + + @BeforeClass + public static void before() { + try { + testCtx = EntraIDTestContext.DEFAULT; + endpointConfig = HostAndPorts.getRedisEndpoint("standalone-entraid-acl"); + hnp = endpointConfig.getHostAndPort(); + } catch (IllegalArgumentException e) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + + @Test + public void testJedisConfig() { + AtomicInteger counter = new AtomicInteger(0); + try (MockedConstruction mockedConstructor = mockConstruction( + EntraIDIdentityProvider.class, (mock, context) -> { + ServicePrincipalInfo info = (ServicePrincipalInfo) context.arguments().get(0); + + assertEquals(testCtx.getClientId(), info.getClientId()); + assertEquals(testCtx.getAuthority(), info.getAuthority()); + assertEquals(testCtx.getClientSecret(), info.getSecret()); + assertEquals(testCtx.getRedisScopes(), context.arguments().get(1)); + assertNotNull(mock); + doAnswer(invocation -> { + counter.incrementAndGet(); + return new SimpleToken("token1", System.currentTimeMillis() + 5 * 60 * 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default")); + }).when(mock).requestToken(); + })) { + + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .authority(testCtx.getAuthority()).clientId(testCtx.getClientId()) + .secret(testCtx.getClientSecret()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + JedisPooled jedis = new JedisPooled(new HostAndPort("localhost", 6379), jedisConfig); + assertNotNull(jedis); + assertEquals(1, counter.get()); + + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + // @Test + public void withUserAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()) + .userAssignedManagedIdentity(UserManagedIdentityType.CLIENT_ID, "userManagedAuthxId") + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + // @Test + public void withSystemAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).systemAssignedManagedIdentity() + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals + @Test + public void withSecret_azureServicePrincipalIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals + @Test + public void withCertificate_azureServicePrincipalIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.2.2 + // Test that the Redis client is not blocked/interrupted during token renewal. + @Test + public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException { + // set the stage with consecutive get/set operations with unique keys which takes at least for 2000 ms with a jedispooled instace, + // configure token manager to renew token approximately every 100ms + // wait till all operations are completed and verify that token was renewed at least 20 times after initial token acquisition + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.000001F).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + Consumer hook = mock(Consumer.class); + authXManager.addPostAuthenticationHook(hook); + + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + long startTime = System.currentTimeMillis(); + List> futures = new ArrayList<>(); + ExecutorService executor = Executors.newFixedThreadPool(5); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + Future future = executor.submit(() -> { + for (; System.currentTimeMillis() - startTime < 2000;) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + task.get(); + } + + verify(hook, atLeast(20)).accept(any()); + executor.shutdown(); + } + } + + // T.3.2 + // Verify that all existing connections can be re-authenticated when a new token is received. + @Test + public void allConnectionsReauthTest() throws InterruptedException, ExecutionException { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.000001F).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + + List connections = new ArrayList<>(); + + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + long startTime = System.currentTimeMillis(); + List> futures = new ArrayList<>(); + ExecutorService executor = Executors.newFixedThreadPool(5); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + Future future = executor.submit(() -> { + for (; System.currentTimeMillis() - startTime < 2000;) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + task.get(); + } + + connections.forEach(conn -> { + verify(conn, atLeast(1)).reAuth(); + }); + executor.shutdown(); + } + } + + // T.3.2 + // Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them + @Test + public void partialReauthFailureTest() { + + } + + // T.3.3 + // Verify behavior when attempting to authenticate a single connection with an expired token. + @Test + public void connectionAuthWithExpiredTokenTest() { + IdentityProvider idp = new EntraIDIdentityProviderConfig( + new ServicePrincipalInfo(testCtx.getClientId(), testCtx.getClientSecret(), + testCtx.getAuthority()), + testCtx.getRedisScopes()).getProvider(); + + IdentityProvider mockIdentityProvider = mock(IdentityProvider.class); + AtomicReference token = new AtomicReference<>(); + doAnswer(invocation -> { + if (token.get() == null) { + token.set(idp.requestToken()); + } + return token.get(); + }).when(mockIdentityProvider).requestToken(); + IdentityProviderConfig idpConfig = mock(IdentityProviderConfig.class); + when(idpConfig.getProvider()).thenReturn(mockIdentityProvider); + + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .identityProviderConfig(idpConfig).expirationRefreshRatio(0.000001F).build(); + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + + token + .set(new SimpleToken("token1", System.currentTimeMillis() - 1, System.currentTimeMillis(), + Collections.singletonMap("oid", idp.requestToken().tryGet("oid")))); + + JedisAccessControlException aclException = assertThrows(JedisAccessControlException.class, + () -> { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + + assertEquals("WRONGPASS invalid username-password pair", aclException.getMessage()); + } + } + + // T.3.4 + // Verify handling of reconnection and re-authentication after a network partition. (use cached token) + // @Test + public void networkPartitionEvictionTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.5F).build(); + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + + triggerNetworkFailure(); + + JedisConnectionException aclException = assertThrows(JedisConnectionException.class, () -> { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + + assertEquals("Unexpected end of stream.", aclException.getMessage()); + Awaitility.await().pollDelay(Durations.ONE_HUNDRED_MILLISECONDS).atMost(Durations.TWO_SECONDS) + .until(() -> { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + return true; + }); + } + } + + private void triggerNetworkFailure() { + HashMap params = new HashMap<>(); + params.put("bdb_id", endpointConfig.getBdbId()); + + FaultInjectionClient.TriggerActionResponse actionResponse = null; + String action = "network_failure"; + try { + log.info("Triggering {}", action); + actionResponse = faultClient.triggerAction(action, params); + } catch (IOException e) { + fail("Fault Injection Server error:" + e.getMessage()); + } + log.info("Action id: {}", actionResponse.getActionId()); + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index 4cbf155dd3..780c82c781 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -9,15 +9,17 @@ import java.util.Arrays; import java.util.Collections; -import java.util.Date; import java.util.List; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import redis.clients.authentication.core.IdentityProvider; import redis.clients.authentication.core.IdentityProviderConfig; @@ -36,52 +38,65 @@ import redis.clients.jedis.commands.ProtocolCommand; public class TokenBasedAuthenticationIntegrationTests { + private static final Logger log = LoggerFactory + .getLogger(TokenBasedAuthenticationIntegrationTests.class); + + private static EndpointConfig endpointConfig; + + @BeforeClass + public static void before() { + try { + endpointConfig = HostAndPorts.getRedisEndpoint("standalone0"); + } catch (IllegalArgumentException e) { + try { + endpointConfig = HostAndPorts.getRedisEndpoint("standalone"); + } catch (IllegalArgumentException ex) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + } - protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); - - @Test - public void testJedisPooledAuth() { - String user = "default"; - String password = endpoint.getPassword(); + @Test + public void testJedisPooledForInitialAuth() { + String user = "default"; + String password = endpointConfig.getPassword(); - IdentityProvider idProvider = mock(IdentityProvider.class); - when(idProvider.requestToken()) - .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000, - System.currentTimeMillis(), Collections.singletonMap("oid", user))); + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000, + System.currentTimeMillis(), Collections.singletonMap("oid", user))); - IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); - when(idProviderConfig.getProvider()).thenReturn(idProvider); + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); - TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() - .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) - .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); - JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() - .tokenAuthConfig(tokenAuthConfig).build(); + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); - try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) { - ArgumentCaptor captor = ArgumentCaptor - .forClass(CommandArguments.class); + try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) { + ArgumentCaptor captor = ArgumentCaptor.forClass(CommandArguments.class); - try (JedisPooled jedis = new JedisPooled(endpoint.getHostAndPort(), clientConfig)) { - jedis.get("key1"); - } + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + jedis.get("key1"); + } - // Verify that the static method was called - mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), - Mockito.atLeast(4)); + // Verify that the static method was called + mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), Mockito.atLeast(4)); - CommandArguments commandArgs = captor.getAllValues().get(0); - List args = StreamSupport.stream(commandArgs.spliterator(), false) - .map(Rawable::getRaw).collect(Collectors.toList()); + CommandArguments commandArgs = captor.getAllValues().get(0); + List args = StreamSupport.stream(commandArgs.spliterator(), false) + .map(Rawable::getRaw).collect(Collectors.toList()); - assertThat(args, - contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes())); + assertThat(args, + contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes())); - List cmds = captor.getAllValues().stream() - .map(item -> item.getCommand()).collect(Collectors.toList()); - assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), - cmds); - } + List cmds = captor.getAllValues().stream().map(item -> item.getCommand()) + .collect(Collectors.toList()); + assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), cmds); } + } } diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 7ac68361aa..83c441d492 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import org.awaitility.Durations; import org.hamcrest.Matchers; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -36,12 +37,15 @@ import redis.clients.authentication.core.TokenListener; import redis.clients.authentication.core.TokenManager; import redis.clients.authentication.core.TokenManagerConfig; +import redis.clients.authentication.core.TokenRequestException; import redis.clients.jedis.ConnectionPool; import redis.clients.jedis.EndpointConfig; -import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.HostAndPort; public class TokenBasedAuthenticationUnitTests { - protected static final EndpointConfig endpoint = HostAndPorts.getRedisEndpoint("standalone0"); + + private HostAndPort hnp = new HostAndPort("localhost", 6379); + private EndpointConfig endpoint = new EndpointConfig(hnp, null, null, false); @Test public void testJedisAuthXManagerInstance() { @@ -57,25 +61,51 @@ public void testJedisAuthXManagerInstance() { assertEquals(tokenManagerConfig, context.arguments().get(1)); })) { - new JedisAuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig)); + new AuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig)); } } @Test - public void testJedisAuthXManagerTriggersEvict() throws Exception { + public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() throws Exception { IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 100000, + .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, - new TokenManagerConfig(0.5F, 1000, 1000, null)); - JedisAuthXManager jedisAuthXManager = new JedisAuthXManager(tokenManager); + new TokenManagerConfig(0.4F, 100, 1000, null)); + AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); + + AtomicInteger numberOfEvictions = new AtomicInteger(0); + ConnectionPool pool = new ConnectionPool(hnp, + endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) { + @Override + public void evict() throws Exception { + numberOfEvictions.incrementAndGet(); + super.evict(); + } + }; + + await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS) + .atMost(Durations.FIVE_HUNDRED_MILLISECONDS) + .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); + } + + public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws Exception { + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); + + TokenManager tokenManager = new TokenManager(idProvider, + new TokenManagerConfig(0.9F, 600, 1000, null)); + AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); AtomicInteger numberOfEvictions = new AtomicInteger(0); ConnectionPool pool = new ConnectionPool(endpoint.getHostAndPort(), - endpoint.getClientConfigBuilder().build(), jedisAuthXManager) { + endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) { @Override public void evict() throws Exception { numberOfEvictions.incrementAndGet(); @@ -83,8 +113,9 @@ public void evict() throws Exception { } }; - jedisAuthXManager.start(true); - assertEquals(1, numberOfEvictions.get()); + await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS) + .atMost(Durations.FIVE_HUNDRED_MILLISECONDS) + .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); } public static class TokenManagerConfigWrapper extends TokenManagerConfig { @@ -190,7 +221,7 @@ public void testAuthXManagerReceivesNewToken() TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, 2000, null)); - JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + AuthXManager manager = spy(new AuthXManager(tokenManager)); final Token[] tokenHolder = new Token[1]; doAnswer(invocation -> { @@ -213,11 +244,10 @@ public void testBlockForInitialToken() { TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100))); - JedisAuthXManager manager = new JedisAuthXManager(tokenManager); - ExecutionException e = assertThrows(ExecutionException.class, () -> manager.start(true)); + AuthXManager manager = new AuthXManager(tokenManager); + TokenRequestException e = assertThrows(TokenRequestException.class, () -> manager.start(true)); - assertEquals(exceptionMessage, - e.getCause().getCause().getCause().getCause().getMessage()); + assertEquals(exceptionMessage, e.getCause().getCause().getCause().getMessage()); } @Test @@ -231,9 +261,9 @@ public void testNoBlockForInitialToken() }; TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, - 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 100))); + 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 0))); - JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + AuthXManager manager = spy(new AuthXManager(tokenManager)); manager.start(false); requesLatch.await(); @@ -296,7 +326,7 @@ public void testTokenManagerWithHangingTokenRequest() TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100))); - JedisAuthXManager manager = spy(new JedisAuthXManager(tokenManager)); + AuthXManager manager = spy(new AuthXManager(tokenManager)); manager.start(false); requesLatch.await(); verify(manager, never()).onError(any()); From 08750e5c3255d95183c596692ed7159bd5b7b1bd Mon Sep 17 00:00:00 2001 From: atakavci Date: Thu, 5 Dec 2024 17:26:37 +0300 Subject: [PATCH 08/21] - prevent use of pubsub with TBA+RESP2 combination - fix flaky test --- .../java/redis/clients/jedis/Connection.java | 6 +++ .../redis/clients/jedis/JedisPubSubBase.java | 15 +++++-- .../RedisEntraIDIntegrationTests.java | 2 +- .../TokenBasedAuthenticationUnitTests.java | 39 ++++++++++++------- 4 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index 655158f1eb..afd69e91d7 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -47,6 +47,7 @@ public class Connection implements Closeable { protected String version; protected AtomicReference currentCredentials = new AtomicReference( null); + private boolean isTokenBasedAuthenticationEnabled = false; public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); @@ -453,6 +454,7 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { connect(); protocol = config.getRedisProtocol(); + isTokenBasedAuthenticationEnabled = (config.getAuthXManager() != null); final Supplier credentialsProvider = config.getCredentialsProvider(); if (credentialsProvider instanceof RedisCredentialsProvider) { @@ -601,4 +603,8 @@ public boolean ping() { } return true; } + + public boolean isTokenBasedAuthenticationEnabled() { + return isTokenBasedAuthenticationEnabled; + } } diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java index bf9d0a32c5..5c96278fb9 100644 --- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java @@ -53,13 +53,22 @@ public final void unsubscribe(T... channels) { } public final void subscribe(T... channels) { + checkConnectionSuitableForPubSub(); sendAndFlushCommand(Command.SUBSCRIBE, channels); } public final void psubscribe(T... patterns) { + checkConnectionSuitableForPubSub(); sendAndFlushCommand(Command.PSUBSCRIBE, patterns); } + private void checkConnectionSuitableForPubSub() { + if (client.protocol == RedisProtocol.RESP2 && client.isTokenBasedAuthenticationEnabled()) { + throw new JedisException( + "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!"); + } + } + public final void punsubscribe() { sendAndFlushCommand(Command.PUNSUBSCRIBE); } @@ -108,7 +117,7 @@ public final void proceedWithPatterns(Connection client, T... patterns) { protected abstract T encode(byte[] raw); -// private void process(Client client) { + // private void process(Client client) { private void process() { do { @@ -177,7 +186,7 @@ private void process() { } } while (!Thread.currentThread().isInterrupted() && isSubscribed()); -// /* Invalidate instance since this thread is no longer listening */ -// this.client = null; + // /* Invalidate instance since this thread is no longer listening */ + // this.client = null; } } diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index 4639e34b02..e31e6ee595 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -299,7 +299,7 @@ public void connectionAuthWithExpiredTokenTest() { IdentityProvider idp = new EntraIDIdentityProviderConfig( new ServicePrincipalInfo(testCtx.getClientId(), testCtx.getClientSecret(), testCtx.getAuthority()), - testCtx.getRedisScopes()).getProvider(); + testCtx.getRedisScopes(),1000).getProvider(); IdentityProvider mockIdentityProvider = mock(IdentityProvider.class); AtomicReference token = new AtomicReference<>(); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 83c441d492..469364d118 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -1,7 +1,6 @@ package redis.clients.jedis.authentication; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -12,9 +11,13 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.atLeast; import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.*; import static org.hamcrest.CoreMatchers.either; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + import java.util.Collections; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; @@ -22,12 +25,10 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; -import org.awaitility.Durations; import org.hamcrest.Matchers; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.MockedConstruction; -import org.mockito.Mockito; import redis.clients.authentication.core.IdentityProvider; import redis.clients.authentication.core.IdentityProviderConfig; @@ -87,8 +88,7 @@ public void evict() throws Exception { } }; - await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS) - .atMost(Durations.FIVE_HUNDRED_MILLISECONDS) + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_HUNDRED_MILLISECONDS) .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); } @@ -113,8 +113,7 @@ public void evict() throws Exception { } }; - await().pollInterval(Durations.ONE_HUNDRED_MILLISECONDS) - .atMost(Durations.FIVE_HUNDRED_MILLISECONDS) + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_HUNDRED_MILLISECONDS) .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); } @@ -253,22 +252,34 @@ public void testBlockForInitialToken() { @Test public void testNoBlockForInitialToken() throws InterruptedException, ExecutionException, TimeoutException { - int numberOfRetries = 5; + int numberOfRetries = 1; CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); IdentityProvider identityProvider = () -> { - requesLatch.countDown(); + try { + System.out.println("awaiting"); + requesLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } throw new RuntimeException("Test exception from identity provider!"); }; - TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, - 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 0))); + TokenManager tokenManager = new TokenManager(identityProvider, + new TokenManagerConfig(0.7F, 200, 500, new TokenManagerConfig.RetryPolicy(5, 0))); AuthXManager manager = spy(new AuthXManager(tokenManager)); manager.start(false); - requesLatch.await(); - verify(manager, Mockito.atLeastOnce()).onError(Mockito.any()); - verify(manager, Mockito.never()).authenticateConnections(Mockito.any()); + await().during(FIVE_HUNDRED_MILLISECONDS).until(tokenManager::getCurrentToken, + Matchers.nullValue()); + verify(manager, never()).onError(any()); + verify(manager, never()).authenticateConnections(any()); + requesLatch.countDown(); + + await().during(FIVE_HUNDRED_MILLISECONDS).until(tokenManager::getCurrentToken, + Matchers.nullValue()); + verify(manager, atLeast(1)).onError(any()); + verify(manager, never()).authenticateConnections(any()); } @Test From b1ab1dbb1af22720fa729b2c6f7cea816a864673 Mon Sep 17 00:00:00 2001 From: atakavci Date: Fri, 6 Dec 2024 15:36:33 +0300 Subject: [PATCH 09/21] - support tba with clusters - add cluster+tba tests --- .../java/redis/clients/jedis/Connection.java | 11 +- .../clients/jedis/ConnectionFactory.java | 29 ++-- .../clients/jedis/JedisClusterInfoCache.java | 3 + .../jedis/authentication/AuthXManager.java | 31 ++++- ...AuthenticationClusterIntegrationTests.java | 131 ++++++++++++++++++ .../TokenBasedAuthenticationUnitTests.java | 9 +- 6 files changed, 182 insertions(+), 32 deletions(-) create mode 100644 src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index afd69e91d7..fe57fb80fb 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -21,6 +21,7 @@ import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.args.ClientAttributeOption; import redis.clients.jedis.args.Rawable; +import redis.clients.jedis.authentication.AuthXManager; import redis.clients.jedis.commands.ProtocolCommand; import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisDataException; @@ -454,9 +455,15 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { connect(); protocol = config.getRedisProtocol(); - isTokenBasedAuthenticationEnabled = (config.getAuthXManager() != null); - final Supplier credentialsProvider = config.getCredentialsProvider(); + Supplier credentialsProvider = config.getCredentialsProvider(); + + AuthXManager authXManager = config.getAuthXManager(); + if (authXManager != null) { + isTokenBasedAuthenticationEnabled = true; + credentialsProvider = authXManager; + } + if (credentialsProvider instanceof RedisCredentialsProvider) { final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider; try { diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index cdbe2ab5c7..45e89fc2da 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -6,8 +6,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; import java.util.function.Supplier; import redis.clients.jedis.annots.Experimental; @@ -29,7 +27,7 @@ public class ConnectionFactory implements PooledObjectFactory { private final Cache clientSideCache; private final Supplier objectMaker; - private final AuthXEventListener authenticationEventListener; + private final AuthXEventListener authXEventListener; public ConnectionFactory(final HostAndPort hostAndPort) { this(hostAndPort, DefaultJedisClientConfig.builder().build(), null); @@ -55,31 +53,20 @@ private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, this.jedisSocketFactory = jedisSocketFactory; this.clientSideCache = csCache; - AuthXManager authXManager = clientConfig.getAuthXManager(); + this.clientConfig = clientConfig; + AuthXManager authXManager = clientConfig.getAuthXManager(); if (authXManager == null) { - this.clientConfig = clientConfig; this.objectMaker = connectionSupplier(); - this.authenticationEventListener = AuthXEventListener.NOOP_LISTENER; + this.authXEventListener = AuthXEventListener.NOOP_LISTENER; } else { - this.clientConfig = replaceCredentialsProvider(clientConfig, authXManager); Supplier supplier = connectionSupplier(); this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get()); - this.authenticationEventListener = authXManager.getListener(); - try { - authXManager.start(true); - } catch (InterruptedException | ExecutionException | TimeoutException e) { - throw new JedisException("AuthXManager failed to start!", e); - } + this.authXEventListener = authXManager.getListener(); + authXManager.start(); } } - private JedisClientConfig replaceCredentialsProvider(JedisClientConfig origin, - Supplier newCredentialsProvider) { - return DefaultJedisClientConfig.builder().from(origin) - .credentialsProvider(newCredentialsProvider).build(); - } - private Supplier connectionSupplier() { return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig) : () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache); @@ -120,7 +107,7 @@ public void passivateObject(PooledObject pooledConnection) throws Ex try { jedis.reAuth(); } catch (Exception e) { - authenticationEventListener.onConnectionAuthenticationError(e); + authXEventListener.onConnectionAuthenticationError(e); throw e; } } @@ -133,7 +120,7 @@ public boolean validateObject(PooledObject pooledConnection) { try { jedis.reAuth(); } catch (Exception e) { - authenticationEventListener.onConnectionAuthenticationError(e); + authXEventListener.onConnectionAuthenticationError(e); throw e; } return jedis.isConnected() && jedis.ping(); diff --git a/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java b/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java index ec63c5206a..9462527c0f 100644 --- a/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java +++ b/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java @@ -103,6 +103,9 @@ public JedisClusterInfoCache(final JedisClientConfig clientConfig, Cache clientS this.clientConfig = clientConfig; this.clientSideCache = clientSideCache; this.startNodes = startNodes; + if (clientConfig.getAuthXManager() != null) { + clientConfig.getAuthXManager().start(); + } if (topologyRefreshPeriod != null) { logger.info("Cluster topology refresh start, period: {}, startNodes: {}", topologyRefreshPeriod, startNodes); topologyRefreshExecutor = Executors.newSingleThreadScheduledExecutor(); diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java index d66bddeb4c..8a91d9001d 100644 --- a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java @@ -4,8 +4,10 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; @@ -29,6 +31,7 @@ public final class AuthXManager implements Supplier { private Token currentToken; private AuthXEventListener listener = AuthXEventListener.NOOP_LISTENER; private List> postAuthenticateHooks = new ArrayList<>(); + private AtomicReference> uniqueStarterTask = new AtomicReference<>(); protected AuthXManager(TokenManager tokenManager) { this.tokenManager = tokenManager; @@ -39,9 +42,29 @@ public AuthXManager(TokenAuthConfig tokenAuthConfig) { tokenAuthConfig.getTokenManagerConfig())); } - public void start(boolean blockForInitialToken) - throws InterruptedException, ExecutionException, TimeoutException { + public void start() { + Future safeStarter = safeStart(this::tokenManagerStart); + try { + safeStarter.get(); + } catch (InterruptedException | ExecutionException e) { + throw new JedisAuthenticationException("AuthXManager failed to start!", + (e instanceof ExecutionException) ? e.getCause() : e); + } + } + + private Future safeStart(Runnable starter) { + if (uniqueStarterTask.compareAndSet(null, new CompletableFuture())) { + try { + starter.run(); + uniqueStarterTask.get().complete(null); + } catch (Exception e) { + uniqueStarterTask.get().completeExceptionally(e); + } + } + return uniqueStarterTask.get(); + } + private void tokenManagerStart() { tokenManager.start(new TokenListener() { @Override public void onTokenRenewed(Token token) { @@ -54,7 +77,7 @@ public void onError(Exception reason) { listener.onIdentityProviderError(reason); AuthXManager.this.onError(reason); } - }, blockForInitialToken); + }, true); } public void authenticateConnections(Token token) { diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java new file mode 100644 index 0000000000..d711804335 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java @@ -0,0 +1,131 @@ +package redis.clients.jedis.authentication; + +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import org.junit.Test; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.jedis.Connection; +import redis.clients.jedis.ConnectionPoolConfig; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisClientConfig; +import redis.clients.jedis.JedisCluster; +import redis.clients.jedis.JedisClusterTestBase; + +public class TokenBasedAuthenticationClusterIntegrationTests extends JedisClusterTestBase { + + @Test + public void testClusterInitWithAuthXManager() { + IdentityProviderConfig idpConfig = new IdentityProviderConfig() { + @Override + public IdentityProvider getProvider() { + return new IdentityProvider() { + @Override + public Token requestToken() { + return new SimpleToken("cluster", System.currentTimeMillis() + 5 * 1000, + System.currentTimeMillis(), + Collections.singletonMap("oid", "default")); + } + }; + } + }; + AuthXManager manager = new AuthXManager(EntraIDTokenAuthConfigBuilder.builder() + .lowerRefreshBoundMillis(1000).identityProviderConfig(idpConfig).build()); + + HostAndPort hp = HostAndPorts.getClusterServers().get(0); + int defaultDirections = 5; + JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(manager).build(); + + ConnectionPoolConfig DEFAULT_POOL_CONFIG = new ConnectionPoolConfig(); + try (JedisCluster jc = new JedisCluster(hp, config, defaultDirections, + DEFAULT_POOL_CONFIG)) { + + assertEquals("OK", jc.set("foo", "bar")); + assertEquals("bar", jc.get("foo")); + assertEquals(1, jc.del("foo")); + } + } + + @Test + public void testClusterWithReAuth() throws InterruptedException, ExecutionException { + IdentityProviderConfig idpConfig = new IdentityProviderConfig() { + @Override + public IdentityProvider getProvider() { + return new IdentityProvider() { + @Override + public Token requestToken() { + return new SimpleToken("cluster", System.currentTimeMillis() + 5 * 1000, + System.currentTimeMillis(), + Collections.singletonMap("oid", "default")); + } + }; + } + }; + AuthXManager authXManager = new AuthXManager(EntraIDTokenAuthConfigBuilder.builder() + .lowerRefreshBoundMillis(4600).identityProviderConfig(idpConfig).build()); + + authXManager = spy(authXManager); + + List connections = new ArrayList<>(); + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + HostAndPort hp = HostAndPorts.getClusterServers().get(0); + JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(authXManager) + .build(); + + ExecutorService executorService = Executors.newFixedThreadPool(2); + CountDownLatch latch = new CountDownLatch(1); + try (JedisCluster jc = new JedisCluster(Collections.singleton(hp), config)) { + Runnable task = () -> { + while (latch.getCount() > 0) { + assertEquals("OK", jc.set("foo", "bar")); + } + }; + Future task1 = executorService.submit(task); + Future task2 = executorService.submit(task); + + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) + .until(connections::size, greaterThanOrEqualTo(2)); + + connections.forEach(conn -> { + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) + .untilAsserted(() -> verify(conn, atLeast(2)).reAuth()); + }); + latch.countDown(); + task1.get(); + task2.get(); + } finally { + latch.countDown(); + executorService.shutdown(); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 469364d118..a3ceeadf1e 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -229,7 +229,7 @@ public void testAuthXManagerReceivesNewToken() return null; }).when(manager).authenticateConnections(any()); - manager.start(true); + manager.start(); assertEquals(tokenHolder[0].getValue(), "tokenVal"); } @@ -244,7 +244,7 @@ public void testBlockForInitialToken() { new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100))); AuthXManager manager = new AuthXManager(tokenManager); - TokenRequestException e = assertThrows(TokenRequestException.class, () -> manager.start(true)); + TokenRequestException e = assertThrows(TokenRequestException.class, () -> manager.start()); assertEquals(exceptionMessage, e.getCause().getCause().getCause().getMessage()); } @@ -256,7 +256,6 @@ public void testNoBlockForInitialToken() CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); IdentityProvider identityProvider = () -> { try { - System.out.println("awaiting"); requesLatch.await(); } catch (InterruptedException e) { e.printStackTrace(); @@ -268,7 +267,7 @@ public void testNoBlockForInitialToken() new TokenManagerConfig(0.7F, 200, 500, new TokenManagerConfig.RetryPolicy(5, 0))); AuthXManager manager = spy(new AuthXManager(tokenManager)); - manager.start(false); + manager.start(); await().during(FIVE_HUNDRED_MILLISECONDS).until(tokenManager::getCurrentToken, Matchers.nullValue()); @@ -338,7 +337,7 @@ public void testTokenManagerWithHangingTokenRequest() executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100))); AuthXManager manager = spy(new AuthXManager(tokenManager)); - manager.start(false); + manager.start(); requesLatch.await(); verify(manager, never()).onError(any()); await().atMost(2, TimeUnit.SECONDS).untilAsserted(() -> { From 1ec523987dab908e58d9338efe5945d807d82025 Mon Sep 17 00:00:00 2001 From: atakavci Date: Fri, 6 Dec 2024 17:40:36 +0300 Subject: [PATCH 10/21] - remove onerror from authxmanager - fix flaky tests --- .../jedis/authentication/AuthXManager.java | 7 --- .../TokenBasedAuthenticationUnitTests.java | 55 +++++++++---------- 2 files changed, 25 insertions(+), 37 deletions(-) diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java index 8a91d9001d..a4363e712e 100644 --- a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java @@ -75,7 +75,6 @@ public void onTokenRenewed(Token token) { @Override public void onError(Exception reason) { listener.onIdentityProviderError(reason); - AuthXManager.this.onError(reason); } }, true); } @@ -93,12 +92,6 @@ public void authenticateConnections(Token token) { postAuthenticateHooks.forEach(hook -> hook.accept(token)); } - public void onError(Exception reason) { - log.error("Token manager failed to acquire new token!", reason); - throw new JedisAuthenticationException("Token manager failed to acquire new token!", - reason); - } - public Connection addConnection(Connection connection) { connections.add(new WeakReference<>(connection)); return connection; diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index a3ceeadf1e..a70fec0704 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -11,12 +11,12 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.atLeast; import static org.awaitility.Awaitility.await; import static org.awaitility.Durations.*; import static org.hamcrest.CoreMatchers.either; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import java.util.Collections; import java.util.concurrent.CountDownLatch; @@ -38,7 +38,6 @@ import redis.clients.authentication.core.TokenListener; import redis.clients.authentication.core.TokenManager; import redis.clients.authentication.core.TokenManagerConfig; -import redis.clients.authentication.core.TokenRequestException; import redis.clients.jedis.ConnectionPool; import redis.clients.jedis.EndpointConfig; import redis.clients.jedis.HostAndPort; @@ -154,8 +153,8 @@ public void testCalculateRenewalDelay() { delay = manager.calculateRenewalDelay(expireDate, issueDate); - assertThat(delay, Matchers - .greaterThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); + assertThat(delay, + lessThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); duration = 10000; config.lower = 8000; @@ -165,8 +164,8 @@ public void testCalculateRenewalDelay() { delay = manager.calculateRenewalDelay(expireDate, issueDate); - assertThat(delay, Matchers - .greaterThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); + assertThat(delay, + lessThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); duration = 10000; config.lower = 10000; @@ -234,7 +233,7 @@ public void testAuthXManagerReceivesNewToken() } @Test - public void testBlockForInitialToken() { + public void testBlockForInitialTokenWhenException() { String exceptionMessage = "Test exception from identity provider!"; IdentityProvider identityProvider = () -> { throw new RuntimeException(exceptionMessage); @@ -244,41 +243,33 @@ public void testBlockForInitialToken() { new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100))); AuthXManager manager = new AuthXManager(tokenManager); - TokenRequestException e = assertThrows(TokenRequestException.class, () -> manager.start()); + JedisAuthenticationException e = assertThrows(JedisAuthenticationException.class, + () -> manager.start()); - assertEquals(exceptionMessage, e.getCause().getCause().getCause().getMessage()); + assertEquals(exceptionMessage, e.getCause().getCause().getMessage()); } @Test - public void testNoBlockForInitialToken() - throws InterruptedException, ExecutionException, TimeoutException { - int numberOfRetries = 1; - CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); + public void testBlockForInitialTokenWhenHangs() { + String exceptionMessage = "AuthXManager failed to start!"; + CountDownLatch latch = new CountDownLatch(1); IdentityProvider identityProvider = () -> { try { - requesLatch.await(); + latch.await(); } catch (InterruptedException e) { - e.printStackTrace(); } - throw new RuntimeException("Test exception from identity provider!"); + return null; }; TokenManager tokenManager = new TokenManager(identityProvider, - new TokenManagerConfig(0.7F, 200, 500, new TokenManagerConfig.RetryPolicy(5, 0))); - - AuthXManager manager = spy(new AuthXManager(tokenManager)); - manager.start(); + new TokenManagerConfig(0.7F, 200, 1000, new TokenManagerConfig.RetryPolicy(2, 100))); - await().during(FIVE_HUNDRED_MILLISECONDS).until(tokenManager::getCurrentToken, - Matchers.nullValue()); - verify(manager, never()).onError(any()); - verify(manager, never()).authenticateConnections(any()); - requesLatch.countDown(); + AuthXManager manager = new AuthXManager(tokenManager); + JedisAuthenticationException e = assertThrows(JedisAuthenticationException.class, + () -> manager.start()); - await().during(FIVE_HUNDRED_MILLISECONDS).until(tokenManager::getCurrentToken, - Matchers.nullValue()); - verify(manager, atLeast(1)).onError(any()); - verify(manager, never()).authenticateConnections(any()); + latch.countDown(); + assertEquals(exceptionMessage, e.getMessage()); } @Test @@ -337,9 +328,13 @@ public void testTokenManagerWithHangingTokenRequest() executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100))); AuthXManager manager = spy(new AuthXManager(tokenManager)); + AuthXEventListener listener = mock(AuthXEventListener.class); + manager.setListener(listener); manager.start(); requesLatch.await(); - verify(manager, never()).onError(any()); + verify(listener, never()).onIdentityProviderError(any()); + verify(listener, never()).onConnectionAuthenticationError(any()); + await().atMost(2, TimeUnit.SECONDS).untilAsserted(() -> { verify(manager, times(1)).authenticateConnections(any()); }); From 7d3a0ae43b4c7e60aadfb2486b9176dc0adf8abb Mon Sep 17 00:00:00 2001 From: atakavci Date: Mon, 9 Dec 2024 11:52:05 +0300 Subject: [PATCH 11/21] - fix flaky test --- .../RedisEntraIDIntegrationTests.java | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index e31e6ee595..0dc0fe26dd 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -195,7 +195,17 @@ public void withCertificate_azureServicePrincipalIntegrationTest() { public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException { // set the stage with consecutive get/set operations with unique keys which takes at least for 2000 ms with a jedispooled instace, // configure token manager to renew token approximately every 100ms - // wait till all operations are completed and verify that token was renewed at least 20 times after initial token acquisition + // wait till all operations are completed and verify that token was renewed at least 10 times after initial token acquisition + // Additional note: Assumptions made on the time taken for token renewal and operations are based on the current implementation and may vary in future + // Assumptions: + // - TTL of token is 2 hour + // - expirationRefreshRatio is 0.000001F + // - renewal delay is 7 ms each time a token is acquired + // - each auth command takes 40 ms in total to complete(considering the cloud test environments) + // - each auth command would need to wait for an ongoing customer operation(GET/SET/DEL) to complete, which would take another 40 ms + // - each renewal happens in 40+40+7 = 87 ms + // - total number of renewals would be 2000 / 87 = 22.9885 ~ 23 + // - to avoid a flaky test results, we will consider approximately half of it as 10 renewals TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) @@ -228,7 +238,7 @@ public void renewalDuringOperationsTest() throws InterruptedException, Execution task.get(); } - verify(hook, atLeast(20)).accept(any()); + verify(hook, atLeast(10)).accept(any()); executor.shutdown(); } } @@ -299,7 +309,7 @@ public void connectionAuthWithExpiredTokenTest() { IdentityProvider idp = new EntraIDIdentityProviderConfig( new ServicePrincipalInfo(testCtx.getClientId(), testCtx.getClientSecret(), testCtx.getAuthority()), - testCtx.getRedisScopes(),1000).getProvider(); + testCtx.getRedisScopes(), 1000).getProvider(); IdentityProvider mockIdentityProvider = mock(IdentityProvider.class); AtomicReference token = new AtomicReference<>(); From 2176505f7277df5a908da5d4c448302bc0225302 Mon Sep 17 00:00:00 2001 From: atakavci Date: Mon, 9 Dec 2024 12:55:37 +0300 Subject: [PATCH 12/21] fix renewalDuringOperationsTest --- .../RedisEntraIDIntegrationTests.java | 64 ++++++++++++------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index 0dc0fe26dd..5d1a6d289b 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -1,5 +1,8 @@ package redis.clients.jedis.authentication; +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.FIVE_SECONDS; +import static org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; @@ -23,6 +26,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -193,9 +197,9 @@ public void withCertificate_azureServicePrincipalIntegrationTest() { // Test that the Redis client is not blocked/interrupted during token renewal. @Test public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException { - // set the stage with consecutive get/set operations with unique keys which takes at least for 2000 ms with a jedispooled instace, - // configure token manager to renew token approximately every 100ms - // wait till all operations are completed and verify that token was renewed at least 10 times after initial token acquisition + // set the stage with consecutive get/set operations with unique keys which keeps running with a jedispooled instace, + // configure token manager to renew token approximately approximately every 10ms + // wait till token was renewed at least 10 times after initial token acquisition // Additional note: Assumptions made on the time taken for token renewal and operations are based on the current implementation and may vary in future // Assumptions: // - TTL of token is 2 hour @@ -204,8 +208,7 @@ public void renewalDuringOperationsTest() throws InterruptedException, Execution // - each auth command takes 40 ms in total to complete(considering the cloud test environments) // - each auth command would need to wait for an ongoing customer operation(GET/SET/DEL) to complete, which would take another 40 ms // - each renewal happens in 40+40+7 = 87 ms - // - total number of renewals would be 2000 / 87 = 22.9885 ~ 23 - // - to avoid a flaky test results, we will consider approximately half of it as 10 renewals + // - total number of renewals would take 87 * 10 = 870 ms TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) @@ -218,29 +221,42 @@ public void renewalDuringOperationsTest() throws InterruptedException, Execution DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() .authXManager(authXManager).build(); - long startTime = System.currentTimeMillis(); - List> futures = new ArrayList<>(); - ExecutorService executor = Executors.newFixedThreadPool(5); - - try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { - for (int i = 0; i < 5; i++) { - Future future = executor.submit(() -> { - for (; System.currentTimeMillis() - startTime < 2000;) { - String key = UUID.randomUUID().toString(); - jedis.set(key, "value"); - assertEquals("value", jedis.get(key)); - jedis.del(key); + ExecutorService jedisExecutors = Executors.newFixedThreadPool(5); + AtomicBoolean completed = new AtomicBoolean(false); + + ExecutorService runner = Executors.newSingleThreadExecutor(); + runner.submit(() -> { + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + List> futures = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + Future future = jedisExecutors.submit(() -> { + while (!completed.get()) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + try { + task.get(); + } catch (InterruptedException | ExecutionException e) { + e.printStackTrace(); } - }); - futures.add(future); - } - for (Future task : futures) { - task.get(); + } } + }); + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_SECONDS).untilAsserted(() -> { verify(hook, atLeast(10)).accept(any()); - executor.shutdown(); - } + }); + + completed.set(true); + runner.shutdown(); + jedisExecutors.shutdown(); } // T.3.2 From 9ea510d6612c3d1e34b06a7e817adc4daa6d7368 Mon Sep 17 00:00:00 2001 From: atakavci Date: Tue, 10 Dec 2024 12:13:24 +0300 Subject: [PATCH 13/21] -reviews from @sazzad16 --- .../java/redis/clients/jedis/Connection.java | 9 +++---- .../clients/jedis/ConnectionFactory.java | 26 ++++++++++--------- .../redis/clients/jedis/ConnectionPool.java | 9 ++++--- .../authentication/TokenCredentials.java | 2 +- .../RedisEntraIDIntegrationTests.java | 2 +- ...AuthenticationClusterIntegrationTests.java | 2 +- .../TokenBasedAuthenticationUnitTests.java | 9 +------ 7 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index fe57fb80fb..a59d359e34 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -560,7 +560,7 @@ public void setCredentials(RedisCredentials credentials) { currentCredentials.set(credentials); } - public void authenticate(RedisCredentials credentials) { + private void authenticate(RedisCredentials credentials) { if (credentials == null || credentials.getPassword() == null) { return; } @@ -577,11 +577,8 @@ public void authenticate(RedisCredentials credentials) { getStatusCodeReply(); } - public void reAuth() { - RedisCredentials temp = currentCredentials.getAndSet(null); - if (temp != null) { - authenticate(temp); - } + public void reAuthenticate() { + authenticate(currentCredentials.getAndSet(null)); } protected Map hello(byte[]... args) { diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index 45e89fc2da..6ce7c3663e 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -104,12 +104,7 @@ public PooledObject makeObject() throws Exception { public void passivateObject(PooledObject pooledConnection) throws Exception { // TODO maybe should select db 0? Not sure right now. Connection jedis = pooledConnection.getObject(); - try { - jedis.reAuth(); - } catch (Exception e) { - authXEventListener.onConnectionAuthenticationError(e); - throw e; - } + reAuthenticate(jedis); } @Override @@ -117,16 +112,23 @@ public boolean validateObject(PooledObject pooledConnection) { final Connection jedis = pooledConnection.getObject(); try { // check HostAndPort ?? - try { - jedis.reAuth(); - } catch (Exception e) { - authXEventListener.onConnectionAuthenticationError(e); - throw e; + if (!jedis.isConnected()) { + return false; } - return jedis.isConnected() && jedis.ping(); + reAuthenticate(jedis); + return jedis.ping(); } catch (final Exception e) { logger.warn("Error while validating pooled Connection object.", e); return false; } } + + private void reAuthenticate(Connection jedis) throws Exception { + try { + jedis.reAuthenticate(); + } catch (Exception e) { + authXEventListener.onConnectionAuthenticationError(e); + throw e; + } + } } diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index 536b3a6484..2ae1401081 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -56,10 +56,13 @@ public Connection getResource() { @Override public void close() { - if (authXManager != null) { - authXManager.stop(); + try { + if (authXManager != null) { + authXManager.stop(); + } + } finally { + super.close(); } - super.close(); } private void attachAuthenticationListener(AuthXManager authXManager) { diff --git a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java index 9c5a54f135..471c34bc40 100644 --- a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java +++ b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java @@ -3,7 +3,7 @@ import redis.clients.authentication.core.Token; import redis.clients.jedis.RedisCredentials; -public class TokenCredentials implements RedisCredentials { +class TokenCredentials implements RedisCredentials { private final String user; private final char[] password; diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index 5d1a6d289b..b6010ca28f 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -305,7 +305,7 @@ public void allConnectionsReauthTest() throws InterruptedException, ExecutionExc } connections.forEach(conn -> { - verify(conn, atLeast(1)).reAuth(); + verify(conn, atLeast(1)).reAuthenticate(); }); executor.shutdown(); } diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java index d711804335..2b6e4e3256 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java @@ -118,7 +118,7 @@ public Token requestToken() { connections.forEach(conn -> { await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) - .untilAsserted(() -> verify(conn, atLeast(2)).reAuth()); + .untilAsserted(() -> verify(conn, atLeast(2)).reAuthenticate()); }); latch.countDown(); task1.get(); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index a70fec0704..699dc47f31 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -3,14 +3,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.awaitility.Awaitility.await; import static org.awaitility.Durations.*; import static org.hamcrest.CoreMatchers.either; From 2175c152b4397a628070c2c44a685de2e71ee169 Mon Sep 17 00:00:00 2001 From: atakavci Date: Wed, 11 Dec 2024 17:20:45 +0300 Subject: [PATCH 14/21] - fix config for managedIdentity - set audiences with scopes - managed identity tests --- .../authentication/EntraIDTestContext.java | 13 ++- .../RedisEntraIDIntegrationTests.java | 40 --------- ...ntraIDManagedIdentityIntegrationTests.java | 81 +++++++++++++++++++ 3 files changed, 93 insertions(+), 41 deletions(-) create mode 100644 src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java diff --git a/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java index e0cde9cfef..b58ee2fd21 100644 --- a/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java +++ b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java @@ -18,6 +18,7 @@ public class EntraIDTestContext { private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY"; private static final String AZURE_CERT = "AZURE_CERT"; private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES"; + private static final String AZURE_USER_ASSIGNED_MANAGED_ID = "AZURE_USER_ASSIGNED_MANAGED_ID"; private String clientId; private String authority; @@ -25,6 +26,7 @@ public class EntraIDTestContext { private PrivateKey privateKey; private X509Certificate cert; private Set redisScopes; + private String userAssignedManagedIdentity; public static final EntraIDTestContext DEFAULT = new EntraIDTestContext(); @@ -32,14 +34,19 @@ private EntraIDTestContext() { clientId = System.getenv(AZURE_CLIENT_ID); authority = System.getenv(AZURE_AUTHORITY); clientSecret = System.getenv(AZURE_CLIENT_SECRET); + userAssignedManagedIdentity = System.getenv(AZURE_USER_ASSIGNED_MANAGED_ID); } public EntraIDTestContext(String clientId, String authority, String clientSecret, - Set redisScopes) { + PrivateKey privateKey, X509Certificate cert, Set redisScopes, + String userAssignedManagedIdentity) { this.clientId = clientId; this.authority = authority; this.clientSecret = clientSecret; + this.privateKey = privateKey; + this.cert = cert; this.redisScopes = redisScopes; + this.userAssignedManagedIdentity = userAssignedManagedIdentity; } public String getClientId() { @@ -76,6 +83,10 @@ public Set getRedisScopes() { return redisScopes; } + public String getUserAssignedManagedIdentity() { + return userAssignedManagedIdentity; + } + private PrivateKey getPrivateKey(String privateKey) { try { // Decode the base64 encoded key into a byte array diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index b6010ca28f..d57e0da3d2 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -49,7 +49,6 @@ import redis.clients.authentication.entraid.EntraIDIdentityProvider; import redis.clients.authentication.entraid.EntraIDIdentityProviderConfig; import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; -import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType; import redis.clients.authentication.entraid.ServicePrincipalInfo; import redis.clients.jedis.Connection; import redis.clients.jedis.DefaultJedisClientConfig; @@ -116,45 +115,6 @@ public void testJedisConfig() { } } - // T.1.1 - // Verify authentication using Azure AD with managed identities - // @Test - public void withUserAssignedId_azureManagedIdentityIntegrationTest() { - TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() - .clientId(testCtx.getClientId()) - .userAssignedManagedIdentity(UserManagedIdentityType.CLIENT_ID, "userManagedAuthxId") - .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); - - DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() - .authXManager(new AuthXManager(tokenAuthConfig)).build(); - - try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { - String key = UUID.randomUUID().toString(); - jedis.set(key, "value"); - assertEquals("value", jedis.get(key)); - jedis.del(key); - } - } - - // T.1.1 - // Verify authentication using Azure AD with managed identities - // @Test - public void withSystemAssignedId_azureManagedIdentityIntegrationTest() { - TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() - .clientId(testCtx.getClientId()).systemAssignedManagedIdentity() - .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); - - DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() - .authXManager(new AuthXManager(tokenAuthConfig)).build(); - - try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { - String key = UUID.randomUUID().toString(); - jedis.set(key, "value"); - assertEquals("value", jedis.get(key)); - jedis.del(key); - } - } - // T.1.1 // Verify authentication using Azure AD with service principals @Test diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java new file mode 100644 index 0000000000..7e305ab766 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java @@ -0,0 +1,81 @@ +package redis.clients.jedis.authentication; + +import static org.junit.Assert.assertEquals; + +import java.util.Collections; +import java.util.Set; +import java.util.UUID; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisPooled; + +public class RedisEntraIDManagedIdentityIntegrationTests { + private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class); + + private static EntraIDTestContext testCtx; + private static EndpointConfig endpointConfig; + private static HostAndPort hnp; + private static Set managedIdentityAudience = Collections + .singleton("https://redis.azure.com"); + + @BeforeClass + public static void before() { + try { + testCtx = EntraIDTestContext.DEFAULT; + endpointConfig = HostAndPorts.getRedisEndpoint("standalone-entraid-acl"); + hnp = endpointConfig.getHostAndPort(); + } catch (IllegalArgumentException e) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + @Test + public void withUserAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .userAssignedManagedIdentity(UserManagedIdentityType.OBJECT_ID, + testCtx.getUserAssignedManagedIdentity()) + .scopes(managedIdentityAudience).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + @Test + public void withSystemAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .systemAssignedManagedIdentity().scopes(managedIdentityAudience).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } +} From 86cf6f6655fcfd09f239be80505432c2c0def0d1 Mon Sep 17 00:00:00 2001 From: atakavci Date: Thu, 12 Dec 2024 14:48:26 +0300 Subject: [PATCH 15/21] review from @ggivo - use getuser instead oid from Token --- .../clients/jedis/authentication/AuthXManager.java | 1 + .../jedis/authentication/TokenCredentials.java | 2 +- .../authentication/RedisEntraIDIntegrationTests.java | 10 ++++------ ...enBasedAuthenticationClusterIntegrationTests.java | 12 ++++++------ .../TokenBasedAuthenticationIntegrationTests.java | 5 ++--- .../TokenBasedAuthenticationUnitTests.java | 10 +++++----- 6 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java index a4363e712e..eba5d8428f 100644 --- a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java +++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java @@ -47,6 +47,7 @@ public void start() { try { safeStarter.get(); } catch (InterruptedException | ExecutionException e) { + log.error("AuthXManager failed to start!", e); throw new JedisAuthenticationException("AuthXManager failed to start!", (e instanceof ExecutionException) ? e.getCause() : e); } diff --git a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java index 471c34bc40..143ee60b9d 100644 --- a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java +++ b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java @@ -8,7 +8,7 @@ class TokenCredentials implements RedisCredentials { private final char[] password; public TokenCredentials(Token token) { - user = token.tryGet("oid"); + user = token.getUser(); password = token.getValue().toCharArray(); } diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index d57e0da3d2..fa920e1daa 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.UUID; @@ -96,8 +95,8 @@ public void testJedisConfig() { assertNotNull(mock); doAnswer(invocation -> { counter.incrementAndGet(); - return new SimpleToken("token1", System.currentTimeMillis() + 5 * 60 * 1000, - System.currentTimeMillis(), Collections.singletonMap("oid", "default")); + return new SimpleToken("default", "token1", System.currentTimeMillis() + 5 * 60 * 1000, + System.currentTimeMillis(), null); }).when(mock).requestToken(); })) { @@ -312,9 +311,8 @@ public void connectionAuthWithExpiredTokenTest() { jedis.del(key); } - token - .set(new SimpleToken("token1", System.currentTimeMillis() - 1, System.currentTimeMillis(), - Collections.singletonMap("oid", idp.requestToken().tryGet("oid")))); + token.set(new SimpleToken(idp.requestToken().getUser(), "token1", + System.currentTimeMillis() - 1, System.currentTimeMillis(), null)); JedisAccessControlException aclException = assertThrows(JedisAccessControlException.class, () -> { diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java index 2b6e4e3256..cd7e8eb6f4 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java @@ -45,9 +45,9 @@ public IdentityProvider getProvider() { return new IdentityProvider() { @Override public Token requestToken() { - return new SimpleToken("cluster", System.currentTimeMillis() + 5 * 1000, - System.currentTimeMillis(), - Collections.singletonMap("oid", "default")); + return new SimpleToken("default", "cluster", + System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), + null); } }; } @@ -77,9 +77,9 @@ public IdentityProvider getProvider() { return new IdentityProvider() { @Override public Token requestToken() { - return new SimpleToken("cluster", System.currentTimeMillis() + 5 * 1000, - System.currentTimeMillis(), - Collections.singletonMap("oid", "default")); + return new SimpleToken("default", "cluster", + System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), + null); } }; } diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index 780c82c781..b9fcfa8218 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -8,7 +8,6 @@ import static org.hamcrest.Matchers.contains; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -64,8 +63,8 @@ public void testJedisPooledForInitialAuth() { IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken(password, System.currentTimeMillis() + 100000, - System.currentTimeMillis(), Collections.singletonMap("oid", user))); + .thenReturn(new SimpleToken(user, password, System.currentTimeMillis() + 100000, + System.currentTimeMillis(), null)); IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); when(idProviderConfig.getProvider()).thenReturn(idProvider); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index 699dc47f31..eb345fffc6 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -63,7 +63,7 @@ public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() thro IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000, + .thenReturn(new SimpleToken("default","password", System.currentTimeMillis() + 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, @@ -88,7 +88,7 @@ public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws E IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("password", System.currentTimeMillis() + 1000, + .thenReturn(new SimpleToken("default","password", System.currentTimeMillis() + 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, @@ -205,7 +205,7 @@ public void testCalculateRenewalDelay() { public void testAuthXManagerReceivesNewToken() throws InterruptedException, ExecutionException, TimeoutException { - IdentityProvider identityProvider = () -> new SimpleToken("tokenVal", + IdentityProvider identityProvider = () -> new SimpleToken("user1","tokenVal", System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); @@ -277,7 +277,7 @@ public void testTokenManagerWithFailingTokenRequest() if (requesLatch.getCount() > 0) { throw new RuntimeException("Test exception from identity provider!"); } - return new SimpleToken("tokenValX", System.currentTimeMillis() + 50 * 1000, + return new SimpleToken("user1","tokenValX", System.currentTimeMillis() + 50 * 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); }); @@ -313,7 +313,7 @@ public void testTokenManagerWithHangingTokenRequest() } return null; } - return new SimpleToken("tokenValX", System.currentTimeMillis() + tokenLifetime, + return new SimpleToken("user1","tokenValX", System.currentTimeMillis() + tokenLifetime, System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); }; From 9717c9a5b14ce7d9145fa7755ef73e1ccd020afe Mon Sep 17 00:00:00 2001 From: atakavci Date: Thu, 12 Dec 2024 15:55:18 +0300 Subject: [PATCH 16/21] handle and propogate from unsuccessful AUTH response --- src/main/java/redis/clients/jedis/Connection.java | 10 +++++----- .../java/redis/clients/jedis/ConnectionFactory.java | 11 ++++++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index a59d359e34..c14dfa08f6 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -560,9 +560,9 @@ public void setCredentials(RedisCredentials credentials) { currentCredentials.set(credentials); } - private void authenticate(RedisCredentials credentials) { + private String authenticate(RedisCredentials credentials) { if (credentials == null || credentials.getPassword() == null) { - return; + return null; } byte[] rawPass = encodeToBytes(credentials.getPassword()); try { @@ -574,11 +574,11 @@ private void authenticate(RedisCredentials credentials) { } finally { Arrays.fill(rawPass, (byte) 0); // clear sensitive data } - getStatusCodeReply(); + return getStatusCodeReply(); } - public void reAuthenticate() { - authenticate(currentCredentials.getAndSet(null)); + public String reAuthenticate() { + return authenticate(currentCredentials.getAndSet(null)); } protected Map hello(byte[]... args) { diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index 6ce7c3663e..7440417152 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -10,6 +10,7 @@ import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.authentication.AuthXManager; +import redis.clients.jedis.authentication.JedisAuthenticationException; import redis.clients.jedis.authentication.AuthXEventListener; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.csc.CacheConnection; @@ -125,8 +126,16 @@ public boolean validateObject(PooledObject pooledConnection) { private void reAuthenticate(Connection jedis) throws Exception { try { - jedis.reAuthenticate(); + String result = jedis.reAuthenticate(); + if (result != null && !result.equals("OK")) { + String msg = "Re-authentication failed with server response: " + result; + Exception failedAuth = new JedisAuthenticationException(msg); + logger.error(failedAuth.getMessage(), failedAuth); + authXEventListener.onConnectionAuthenticationError(failedAuth); + return; + } } catch (Exception e) { + logger.error("Error while re-authenticating connection", e); authXEventListener.onConnectionAuthenticationError(e); throw e; } From 9185f44813f1e5f43c87597e947de27ffb0841f6 Mon Sep 17 00:00:00 2001 From: atakavci Date: Sun, 15 Dec 2024 02:36:53 +0300 Subject: [PATCH 17/21] adding reauth support for both pubsub and shardedpubsub --- .../java/redis/clients/jedis/Connection.java | 17 +- .../redis/clients/jedis/JedisPubSubBase.java | 53 +++-- .../clients/jedis/JedisSafeAuthenticator.java | 104 +++++++++ .../clients/jedis/JedisShardedPubSubBase.java | 31 ++- .../RedisEntraIDIntegrationTests.java | 7 - ...enBasedAuthenticationIntegrationTests.java | 209 +++++++++++++++--- 6 files changed, 346 insertions(+), 75 deletions(-) create mode 100644 src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index c14dfa08f6..de473d0b8e 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -46,9 +46,8 @@ public class Connection implements Closeable { private String strVal; protected String server; protected String version; - protected AtomicReference currentCredentials = new AtomicReference( - null); - private boolean isTokenBasedAuthenticationEnabled = false; + private AtomicReference currentCredentials = new AtomicReference<>(null); + private AuthXManager authXManager; public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); @@ -68,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC public Connection(final JedisSocketFactory socketFactory) { this.socketFactory = socketFactory; + this.authXManager = null; } public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) { @@ -458,9 +458,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { Supplier credentialsProvider = config.getCredentialsProvider(); - AuthXManager authXManager = config.getAuthXManager(); + authXManager = config.getAuthXManager(); if (authXManager != null) { - isTokenBasedAuthenticationEnabled = true; credentialsProvider = authXManager; } @@ -608,7 +607,11 @@ public boolean ping() { return true; } - public boolean isTokenBasedAuthenticationEnabled() { - return isTokenBasedAuthenticationEnabled; + protected boolean isTokenBasedAuthenticationEnabled() { + return authXManager != null; + } + + protected AuthXManager getAuthXManager() { + return authXManager; } } diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java index 5c96278fb9..4f72f546d7 100644 --- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java @@ -4,6 +4,7 @@ import java.util.Arrays; import java.util.List; +import java.util.function.Consumer; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.exceptions.JedisException; @@ -12,7 +13,8 @@ public abstract class JedisPubSubBase { private int subscribedChannels = 0; - private volatile Connection client; + private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator(); + private final Consumer pingResultHandler = this::processPingReply; public void onMessage(T channel, T message) { } @@ -36,12 +38,7 @@ public void onPong(T pattern) { } private void sendAndFlushCommand(Command command, T... args) { - if (client == null) { - throw new JedisException(getClass() + " is not connected to a Connection."); - } - CommandArguments cargs = new CommandArguments(command).addObjects(args); - client.sendCommand(cargs); - client.flush(); + authenticator.sendAndFlushCommand(command, args); } public final void unsubscribe() { @@ -63,7 +60,8 @@ public final void psubscribe(T... patterns) { } private void checkConnectionSuitableForPubSub() { - if (client.protocol == RedisProtocol.RESP2 && client.isTokenBasedAuthenticationEnabled()) { + if (authenticator.client.protocol != RedisProtocol.RESP3 + && authenticator.client.isTokenBasedAuthenticationEnabled()) { throw new JedisException( "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!"); } @@ -78,7 +76,13 @@ public final void punsubscribe(T... patterns) { } public final void ping() { - sendAndFlushCommand(Command.PING); + authenticator.commandSync.lock(); + try { + sendAndFlushCommand(Command.PING); + authenticator.resultHandler.add(pingResultHandler); + } finally { + authenticator.commandSync.unlock(); + } } public final void ping(T argument) { @@ -94,24 +98,24 @@ public final int getSubscribedChannels() { } public final void proceed(Connection client, T... channels) { - this.client = client; - this.client.setTimeoutInfinite(); + authenticator.registerForAuthentication(client); + authenticator.client.setTimeoutInfinite(); try { subscribe(channels); process(); } finally { - this.client.rollbackTimeout(); + authenticator.client.rollbackTimeout(); } } public final void proceedWithPatterns(Connection client, T... patterns) { - this.client = client; - this.client.setTimeoutInfinite(); + authenticator.registerForAuthentication(client); + authenticator.client.setTimeoutInfinite(); try { psubscribe(patterns); process(); } finally { - this.client.rollbackTimeout(); + authenticator.client.rollbackTimeout(); } } @@ -121,7 +125,7 @@ public final void proceedWithPatterns(Connection client, T... patterns) { private void process() { do { - Object reply = client.getUnflushedObject(); + Object reply = authenticator.client.getUnflushedObject(); if (reply instanceof List) { List listReply = (List) reply; @@ -175,12 +179,8 @@ private void process() { throw new JedisException("Unknown message type: " + firstObj); } } else if (reply instanceof byte[]) { - byte[] resp = (byte[]) reply; - if ("PONG".equals(SafeEncoder.encode(resp))) { - onPong(null); - } else { - onPong(encode(resp)); - } + Consumer resultHandler = authenticator.resultHandler.remove(); + resultHandler.accept(reply); } else { throw new JedisException("Unknown message type: " + reply); } @@ -189,4 +189,13 @@ private void process() { // /* Invalidate instance since this thread is no longer listening */ // this.client = null; } + + private void processPingReply(Object reply) { + byte[] resp = (byte[]) reply; + if ("PONG".equals(SafeEncoder.encode(resp))) { + onPong(null); + } else { + onPong(encode(resp)); + } + } } diff --git a/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java b/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java new file mode 100644 index 0000000000..16b72f1684 --- /dev/null +++ b/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java @@ -0,0 +1,104 @@ +package redis.clients.jedis; + +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.jedis.Protocol.Command; +import redis.clients.jedis.authentication.JedisAuthenticationException; +import redis.clients.jedis.exceptions.JedisException; +import redis.clients.jedis.util.SafeEncoder; + +public class JedisSafeAuthenticator { + + private static final Token PLACEHOLDER_TOKEN = new SimpleToken(null, null, 0, 0, null); + private static final Logger logger = LoggerFactory.getLogger(JedisSafeAuthenticator.class); + + protected volatile Connection client; + protected final Consumer authResultHandler = this::processAuthReply; + protected final Consumer authenticationHandler = this::safeReAuthenticate; + + protected final AtomicReference pendingTokenRef = new AtomicReference(null); + protected final ReentrantLock commandSync = new ReentrantLock(); + protected final Queue> resultHandler = new ConcurrentLinkedQueue>(); + + protected void sendAndFlushCommand(Command command, Object... args) { + if (client == null) { + throw new JedisException(getClass() + " is not connected to a Connection."); + } + CommandArguments cargs = new CommandArguments(command).addObjects(args); + + Token newToken = pendingTokenRef.getAndSet(PLACEHOLDER_TOKEN); + + // lets send the command without locking !!IF!! we know that pendingTokenRef is null replaced with PLACEHOLDER_TOKEN and no re-auth will go into action + // !!ELSE!! we are locking since we already know a re-auth is still in progress in another thread and we need to wait for it to complete, we do nothing but wait on it! + if (newToken != null) { + commandSync.lock(); + } + try { + client.sendCommand(cargs); + client.flush(); + } finally { + Token newerToken = pendingTokenRef.getAndSet(null); + // lets check if a newer token received since the beginning of this sendAndFlushCommand call + if (newerToken != null && newerToken != PLACEHOLDER_TOKEN) { + safeReAuthenticate(newerToken); + } + if (newToken != null) { + commandSync.unlock(); + } + } + } + + protected void registerForAuthentication(Connection newClient) { + Connection oldClient = this.client; + if (oldClient == newClient) return; + if (oldClient != null && oldClient.getAuthXManager() != null) { + oldClient.getAuthXManager().removePostAuthenticationHook(authenticationHandler); + } + if (newClient != null && newClient.getAuthXManager() != null) { + newClient.getAuthXManager().addPostAuthenticationHook(authenticationHandler); + } + this.client = newClient; + } + + private void safeReAuthenticate(Token token) { + try { + byte[] rawPass = client.encodeToBytes(token.getValue().toCharArray()); + byte[] rawUser = client.encodeToBytes(token.getUser().toCharArray()); + + Token newToken = pendingTokenRef.getAndSet(token); + if (newToken == null) { + commandSync.lock(); + try { + sendAndFlushCommand(Command.AUTH, rawUser, rawPass); + resultHandler.add(this.authResultHandler); + } finally { + pendingTokenRef.set(null); + commandSync.unlock(); + } + } + } catch (Exception e) { + logger.error("Error while re-authenticating connection", e); + client.getAuthXManager().getListener().onConnectionAuthenticationError(e); + } + } + + protected void processAuthReply(Object reply) { + byte[] resp = (byte[]) reply; + String response = SafeEncoder.encode(resp); + if (!"OK".equals(response)) { + String msg = "Re-authentication failed with server response: " + response; + Exception failedAuth = new JedisAuthenticationException(msg); + logger.error(failedAuth.getMessage(), failedAuth); + client.getAuthXManager().getListener().onConnectionAuthenticationError(failedAuth); + } + } +} diff --git a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java index 2b2ce944fe..a52e9fbadf 100644 --- a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java @@ -4,6 +4,7 @@ import java.util.Arrays; import java.util.List; +import java.util.function.Consumer; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.exceptions.JedisException; @@ -11,7 +12,7 @@ public abstract class JedisShardedPubSubBase { private int subscribedChannels = 0; - private volatile Connection client; + private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator(); public void onSMessage(T channel, T message) { } @@ -23,12 +24,7 @@ public void onSUnsubscribe(T channel, int subscribedChannels) { } private void sendAndFlushCommand(Command command, T... args) { - if (client == null) { - throw new JedisException(getClass() + " is not connected to a Connection."); - } - CommandArguments cargs = new CommandArguments(command).addObjects(args); - client.sendCommand(cargs); - client.flush(); + authenticator.sendAndFlushCommand(command, args); } public final void sunsubscribe() { @@ -40,9 +36,18 @@ public final void sunsubscribe(T... channels) { } public final void ssubscribe(T... channels) { + checkConnectionSuitableForPubSub(); sendAndFlushCommand(Command.SSUBSCRIBE, channels); } + private void checkConnectionSuitableForPubSub() { + if (authenticator.client.protocol != RedisProtocol.RESP3 + && authenticator.client.isTokenBasedAuthenticationEnabled()) { + throw new JedisException( + "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!"); + } + } + public final boolean isSubscribed() { return subscribedChannels > 0; } @@ -52,23 +57,22 @@ public final int getSubscribedChannels() { } public final void proceed(Connection client, T... channels) { - this.client = client; - this.client.setTimeoutInfinite(); + authenticator.registerForAuthentication(client); + authenticator.client.setTimeoutInfinite(); try { ssubscribe(channels); process(); } finally { - this.client.rollbackTimeout(); + authenticator.client.rollbackTimeout(); } } protected abstract T encode(byte[] raw); -// private void process(Client client) { private void process() { do { - Object reply = client.getUnflushedObject(); + Object reply = authenticator.client.getUnflushedObject(); if (reply instanceof List) { List listReply = (List) reply; @@ -96,6 +100,9 @@ private void process() { } else { throw new JedisException("Unknown message type: " + firstObj); } + } else if (reply instanceof byte[]) { + Consumer resultHandler = authenticator.resultHandler.remove(); + resultHandler.accept(reply); } else { throw new JedisException("Unknown message type: " + reply); } diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index fa920e1daa..55551331ed 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -270,13 +270,6 @@ public void allConnectionsReauthTest() throws InterruptedException, ExecutionExc } } - // T.3.2 - // Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them - @Test - public void partialReauthFailureTest() { - - } - // T.3.3 // Verify behavior when attempting to authenticate a single connection with an expired token. @Test diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index b9fcfa8218..336fa416dd 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -2,21 +2,27 @@ import static org.mockito.Mockito.when; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS; +import static org.awaitility.Durations.ONE_SECOND; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; - +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,16 +31,17 @@ import redis.clients.authentication.core.SimpleToken; import redis.clients.authentication.core.TokenAuthConfig; import redis.clients.jedis.CommandArguments; +import redis.clients.jedis.Connection; /* */ import redis.clients.jedis.DefaultJedisClientConfig; import redis.clients.jedis.EndpointConfig; import redis.clients.jedis.HostAndPorts; import redis.clients.jedis.JedisClientConfig; import redis.clients.jedis.JedisPooled; -import redis.clients.jedis.Protocol; +import redis.clients.jedis.JedisPubSub; +import redis.clients.jedis.RedisProtocol; import redis.clients.jedis.Protocol.Command; -import redis.clients.jedis.args.Rawable; -import redis.clients.jedis.commands.ProtocolCommand; +import redis.clients.jedis.exceptions.JedisException; public class TokenBasedAuthenticationIntegrationTests { private static final Logger log = LoggerFactory @@ -62,9 +69,8 @@ public void testJedisPooledForInitialAuth() { String password = endpointConfig.getPassword(); IdentityProvider idProvider = mock(IdentityProvider.class); - when(idProvider.requestToken()) - .thenReturn(new SimpleToken(user, password, System.currentTimeMillis() + 100000, - System.currentTimeMillis(), null)); + when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password, + System.currentTimeMillis() + 100000, System.currentTimeMillis(), null)); IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); when(idProviderConfig.getProvider()).thenReturn(idProvider); @@ -76,26 +82,175 @@ public void testJedisPooledForInitialAuth() { JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() .authXManager(new AuthXManager(tokenAuthConfig)).build(); - try (MockedStatic mockedStatic = Mockito.mockStatic(Protocol.class)) { - ArgumentCaptor captor = ArgumentCaptor.forClass(CommandArguments.class); + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + jedis.get("key1"); + } + } + + @Test + public void testJedisPooledReauth() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenAnswer(invocation -> new SimpleToken(user, password, + System.currentTimeMillis() + 5000, System.currentTimeMillis(), null)); - try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { - jedis.get("key1"); + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(4800).tokenRequestExecTimeoutInMs(1000).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + List connections = new ArrayList<>(); + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().authXManager(authXManager) + .build(); + + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + AtomicBoolean stop = new AtomicBoolean(false); + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.submit(() -> { + while (!stop.get()) { + jedis.get("key1"); + } + }); + + for (Connection connection : connections) { + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND).untilAsserted(() -> { + verify(connection, atLeast(3)).reAuthenticate(); + }); } + stop.set(true); + executor.shutdown(); + } + } + + @Test + public void testPubSubForInitialAuth() throws InterruptedException { + String user = "default"; + String password = endpointConfig.getPassword(); - // Verify that the static method was called - mockedStatic.verify(() -> Protocol.sendCommand(any(), captor.capture()), Mockito.atLeast(4)); + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password, + System.currentTimeMillis() + 100000, System.currentTimeMillis(), null)); - CommandArguments commandArgs = captor.getAllValues().get(0); - List args = StreamSupport.stream(commandArgs.spliterator(), false) - .map(Rawable::getRaw).collect(Collectors.toList()); + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).protocol(RedisProtocol.RESP3).build(); + + JedisPubSub pubSub = new JedisPubSub() { + public void onSubscribe(String channel, int subscribedChannels) { + this.unsubscribe(); + } + }; + + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + jedis.subscribe(pubSub, "channel1"); + } + } - assertThat(args, - contains(Protocol.Command.AUTH.getRaw(), user.getBytes(), password.getBytes())); + @Test + public void testJedisPubSubReauth() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenAnswer(invocation -> new SimpleToken(user, password, + System.currentTimeMillis() + 5000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(4800).tokenRequestExecTimeoutInMs(1000).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + List connections = new ArrayList<>(); + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().authXManager(authXManager) + .protocol(RedisProtocol.RESP3).build(); + + JedisPubSub pubSub = new JedisPubSub() { + }; + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.submit(() -> { + jedis.subscribe(pubSub, "channel1"); + }); + + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) + .until(pubSub::getSubscribedChannels, greaterThan(0)); + + assertEquals(1, connections.size()); + for (Connection connection : connections) { + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND).untilAsserted(() -> { + ArgumentCaptor captor = ArgumentCaptor.forClass(CommandArguments.class); + + verify(connection, atLeast(3)).sendCommand(captor.capture()); + assertThat(captor.getAllValues().stream() + .filter((item) -> item.getCommand() == Command.AUTH).count(), + greaterThan(3L)); + + }); + } + pubSub.unsubscribe(); + executor.shutdown(); + } + } + + @Test + public void testJedisPubSubWithResp2() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password, + System.currentTimeMillis() + 100000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); - List cmds = captor.getAllValues().stream().map(item -> item.getCommand()) - .collect(Collectors.toList()); - assertEquals(Arrays.asList(Command.AUTH, Command.CLIENT, Command.CLIENT, Command.GET), cmds); + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + JedisPubSub pubSub = new JedisPubSub() { + }; + JedisException e = assertThrows(JedisException.class, + () -> jedis.subscribe(pubSub, "channel1")); + assertEquals( + "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!", + e.getMessage()); } } } From 5f8159d0f392108f44a8937563385f704f406733 Mon Sep 17 00:00:00 2001 From: atakavci Date: Sun, 15 Dec 2024 04:24:21 +0300 Subject: [PATCH 18/21] fix ping issue with pubsub --- .../java/redis/clients/jedis/JedisPubSubBase.java | 13 +++++++++++-- .../redis/clients/jedis/JedisShardedPubSubBase.java | 6 +++++- .../TokenBasedAuthenticationUnitTests.java | 13 +++++++------ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java index 4f72f546d7..91fee36c58 100644 --- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java @@ -86,7 +86,13 @@ public final void ping() { } public final void ping(T argument) { - sendAndFlushCommand(Command.PING, argument); + authenticator.commandSync.lock(); + try { + sendAndFlushCommand(Command.PING, argument); + authenticator.resultHandler.add(pingResultHandler); + } finally { + authenticator.commandSync.unlock(); + } } public final boolean isSubscribed() { @@ -179,7 +185,10 @@ private void process() { throw new JedisException("Unknown message type: " + firstObj); } } else if (reply instanceof byte[]) { - Consumer resultHandler = authenticator.resultHandler.remove(); + Consumer resultHandler = authenticator.resultHandler.poll(); + if (resultHandler == null) { + throw new JedisException("Unexpected message : " + SafeEncoder.encode((byte[]) reply)); + } resultHandler.accept(reply); } else { throw new JedisException("Unknown message type: " + reply); diff --git a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java index a52e9fbadf..9020693929 100644 --- a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java @@ -8,6 +8,7 @@ import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.exceptions.JedisException; +import redis.clients.jedis.util.SafeEncoder; public abstract class JedisShardedPubSubBase { @@ -101,7 +102,10 @@ private void process() { throw new JedisException("Unknown message type: " + firstObj); } } else if (reply instanceof byte[]) { - Consumer resultHandler = authenticator.resultHandler.remove(); + Consumer resultHandler = authenticator.resultHandler.poll(); + if (resultHandler == null) { + throw new JedisException("Unexpected message : " + SafeEncoder.encode((byte[]) reply)); + } resultHandler.accept(reply); } else { throw new JedisException("Unknown message type: " + reply); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index eb345fffc6..ce5f3b9245 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -63,7 +63,7 @@ public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() thro IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("default","password", System.currentTimeMillis() + 1000, + .thenReturn(new SimpleToken("default", "password", System.currentTimeMillis() + 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, @@ -88,7 +88,7 @@ public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws E IdentityProvider idProvider = mock(IdentityProvider.class); when(idProvider.requestToken()) - .thenReturn(new SimpleToken("default","password", System.currentTimeMillis() + 1000, + .thenReturn(new SimpleToken("default", "password", System.currentTimeMillis() + 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, @@ -205,7 +205,7 @@ public void testCalculateRenewalDelay() { public void testAuthXManagerReceivesNewToken() throws InterruptedException, ExecutionException, TimeoutException { - IdentityProvider identityProvider = () -> new SimpleToken("user1","tokenVal", + IdentityProvider identityProvider = () -> new SimpleToken("user1", "tokenVal", System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); @@ -277,7 +277,7 @@ public void testTokenManagerWithFailingTokenRequest() if (requesLatch.getCount() > 0) { throw new RuntimeException("Test exception from identity provider!"); } - return new SimpleToken("user1","tokenValX", System.currentTimeMillis() + 50 * 1000, + return new SimpleToken("user1", "tokenValX", System.currentTimeMillis() + 50 * 1000, System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); }); @@ -289,9 +289,10 @@ public void testTokenManagerWithFailingTokenRequest() TokenListener listener = mock(TokenListener.class); tokenManager.start(listener, false); requesLatch.await(); + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_HUNDRED_MILLISECONDS) + .untilAsserted(() -> verify(listener).onTokenRenewed(argument.capture())); verify(identityProvider, times(numberOfRetries)).requestToken(); verify(listener, never()).onError(any()); - verify(listener).onTokenRenewed(argument.capture()); assertEquals("tokenValX", argument.getValue().getValue()); } @@ -313,7 +314,7 @@ public void testTokenManagerWithHangingTokenRequest() } return null; } - return new SimpleToken("user1","tokenValX", System.currentTimeMillis() + tokenLifetime, + return new SimpleToken("user1", "tokenValX", System.currentTimeMillis() + tokenLifetime, System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); }; From edf631aec3115f7998b1af04e4452e9f41ca2edc Mon Sep 17 00:00:00 2001 From: atakavci Date: Fri, 20 Dec 2024 03:06:42 +0300 Subject: [PATCH 19/21] - review from @sazzad16 : make JedisSafeAuthenticator protected - fix failing unit tests --- .../redis/clients/jedis/JedisSafeAuthenticator.java | 2 +- .../TokenBasedAuthenticationUnitTests.java | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java b/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java index 16b72f1684..9c7f95dba1 100644 --- a/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java +++ b/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java @@ -16,7 +16,7 @@ import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.SafeEncoder; -public class JedisSafeAuthenticator { +class JedisSafeAuthenticator { private static final Token PLACEHOLDER_TOKEN = new SimpleToken(null, null, 0, 0, null); private static final Logger logger = LoggerFactory.getLogger(JedisSafeAuthenticator.class); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index ce5f3b9245..802dda2b86 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -31,6 +31,7 @@ import redis.clients.authentication.core.TokenListener; import redis.clients.authentication.core.TokenManager; import redis.clients.authentication.core.TokenManagerConfig; +import redis.clients.authentication.core.TokenManagerConfig.RetryPolicy; import redis.clients.jedis.ConnectionPool; import redis.clients.jedis.EndpointConfig; import redis.clients.jedis.HostAndPort; @@ -67,7 +68,7 @@ public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() thro System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, - new TokenManagerConfig(0.4F, 100, 1000, null)); + new TokenManagerConfig(0.4F, 100, 1000, new RetryPolicy(1, 1))); AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); AtomicInteger numberOfEvictions = new AtomicInteger(0); @@ -92,7 +93,7 @@ public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws E System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); TokenManager tokenManager = new TokenManager(idProvider, - new TokenManagerConfig(0.9F, 600, 1000, null)); + new TokenManagerConfig(0.9F, 600, 1000, new RetryPolicy(1, 1))); AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); AtomicInteger numberOfEvictions = new AtomicInteger(0); @@ -126,6 +127,11 @@ public int getLowerRefreshBoundMillis() { public float getExpirationRefreshRatio() { return ratio; } + + @Override + public RetryPolicy getRetryPolicy() { + return new RetryPolicy(1, 1); + } } @Test @@ -210,7 +216,7 @@ public void testAuthXManagerReceivesNewToken() Collections.singletonMap("oid", "user1")); TokenManager tokenManager = new TokenManager(identityProvider, - new TokenManagerConfig(0.7F, 200, 2000, null)); + new TokenManagerConfig(0.7F, 200, 2000, new RetryPolicy(1, 1))); AuthXManager manager = spy(new AuthXManager(tokenManager)); From 88a20c24fb4357f616867a83dbdfdb62e44378c0 Mon Sep 17 00:00:00 2001 From: atakavci Date: Fri, 20 Dec 2024 17:48:52 +0300 Subject: [PATCH 20/21] update authx version --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 97467d6d89..add6b378ea 100644 --- a/pom.xml +++ b/pom.xml @@ -78,7 +78,7 @@ redis.clients.authentication redis-authx-core - 0.1.0-SNAPSHOT + 0.1.1-beta1 @@ -159,7 +159,7 @@ redis.clients.authentication redis-authx-entraid - 0.1.0-SNAPSHOT + 0.1.1-beta1 test From 93f53a2fdafd38924310dd9f15a141b88c4c00f4 Mon Sep 17 00:00:00 2001 From: atakavci Date: Fri, 20 Dec 2024 18:22:45 +0300 Subject: [PATCH 21/21] - remove workaround for standalone endpoint --- .../TokenBasedAuthenticationIntegrationTests.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java index 336fa416dd..9060f80719 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -54,12 +54,8 @@ public static void before() { try { endpointConfig = HostAndPorts.getRedisEndpoint("standalone0"); } catch (IllegalArgumentException e) { - try { - endpointConfig = HostAndPorts.getRedisEndpoint("standalone"); - } catch (IllegalArgumentException ex) { - log.warn("Skipping test because no Redis endpoint is configured"); - org.junit.Assume.assumeTrue(false); - } + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); } }