diff --git a/kafka/client_async.py b/kafka/client_async.py index 96959d9ae..9e57efd5e 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -399,13 +399,23 @@ def _should_recycle_connection(self, conn): return False def _maybe_connect(self, node_id): - """Idempotent non-blocking connection attempt to the given node id.""" + """Idempotent non-blocking connection attempt to the given node id. + + Returns True if connection object exists and is connected / connecting + """ with self._lock: conn = self._conns.get(node_id) + # Check if existing connection should be recreated because host/port changed + if conn is not None and self._should_recycle_connection(conn): + self._conns.pop(node_id).close() + conn = None + if conn is None: broker = self.cluster.broker_metadata(node_id) - assert broker, 'Broker id %s not in current metadata' % (node_id,) + if broker is None: + log.debug('Broker id %s not in current metadata', node_id) + return False log.debug("Initiating connection to node %s at %s:%s", node_id, broker.host, broker.port) @@ -417,16 +427,11 @@ def _maybe_connect(self, node_id): **self.config) self._conns[node_id] = conn - # Check if existing connection should be recreated because host/port changed - elif self._should_recycle_connection(conn): - self._conns.pop(node_id) - return False - elif conn.connected(): return True conn.connect() - return conn.connected() + return not conn.disconnected() def ready(self, node_id, metadata_priority=True): """Check whether a node is connected and ok to send more requests. @@ -621,7 +626,10 @@ def poll(self, timeout_ms=None, future=None): # Attempt to complete pending connections for node_id in list(self._connecting): - self._maybe_connect(node_id) + # False return means no more connection progress is possible + # Connected nodes will update _connecting via state_change callback + if not self._maybe_connect(node_id): + self._connecting.remove(node_id) # If we got a future that is already done, don't block in _poll if future is not None and future.is_done: @@ -965,7 +973,12 @@ def check_version(self, node_id=None, timeout=None, strict=False): if try_node is None: self._lock.release() raise Errors.NoBrokersAvailable() - self._maybe_connect(try_node) + if not self._maybe_connect(try_node): + if try_node == node_id: + raise Errors.NodeNotReadyError("Connection failed to %s" % node_id) + else: + continue + conn = self._conns[try_node] # We will intentionally cause socket failures diff --git a/test/test_client_async.py b/test/test_client_async.py index ccdd57037..16ee4291d 100644 --- a/test/test_client_async.py +++ b/test/test_client_async.py @@ -71,19 +71,14 @@ def test_can_connect(cli, conn): def test_maybe_connect(cli, conn): - try: - # Node not in metadata, raises AssertionError - cli._maybe_connect(2) - except AssertionError: - pass - else: - assert False, 'Exception not raised' + # Node not in metadata, return False + assert not cli._maybe_connect(2) # New node_id creates a conn object assert 0 not in cli._conns conn.state = ConnectionStates.DISCONNECTED conn.connect.side_effect = lambda: conn._set_conn_state(ConnectionStates.CONNECTING) - assert cli._maybe_connect(0) is False + assert cli._maybe_connect(0) is True assert cli._conns[0] is conn