Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e1adb3b

Browse files
authored
Merge pull request aio-libs#1836 from pfreixes/master
Choose addr cached based on a round robin strategy
2 parents 6d6f222 + 3da3a01 commit e1adb3b

2 files changed

Lines changed: 161 additions & 53 deletions

File tree

aiohttp/connector.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import warnings
77
from collections import defaultdict
88
from hashlib import md5, sha1, sha256
9+
from itertools import cycle, islice
10+
from time import monotonic
911
from types import MappingProxyType
1012

1113
from . import hdrs, helpers
@@ -483,6 +485,55 @@ def _create_connection(self, req):
483485
_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
484486

485487

488+
class _DNSCacheTable:
489+
490+
def __init__(self, ttl=None):
491+
self._addrs = {}
492+
self._addrs_rr = {}
493+
self._timestamps = {}
494+
self._ttl = ttl
495+
496+
def __contains__(self, host):
497+
return host in self._addrs
498+
499+
@property
500+
def addrs(self):
501+
return self._addrs
502+
503+
def add(self, host, addrs):
504+
self._addrs[host] = addrs
505+
self._addrs_rr[host] = cycle(addrs)
506+
507+
if self._ttl:
508+
self._timestamps[host] = monotonic()
509+
510+
def remove(self, host):
511+
self._addrs.pop(host, None)
512+
self._addrs_rr.pop(host, None)
513+
514+
if self._ttl:
515+
self._timestamps.pop(host, None)
516+
517+
def clear(self):
518+
self._addrs.clear()
519+
self._addrs_rr.clear()
520+
self._timestamps.clear()
521+
522+
def next_addrs(self, host):
523+
# Return an iterator that will get at maximum as many addrs
524+
# there are for the specific host starting from the last
525+
# not itereated addr.
526+
return islice(self._addrs_rr[host], len(self._addrs[host]))
527+
528+
def expired(self, host):
529+
if self._ttl is None:
530+
return False
531+
532+
return (
533+
self._timestamps[host] + self._ttl
534+
) < monotonic()
535+
536+
486537
class TCPConnector(BaseConnector):
487538
"""TCP connector.
488539
@@ -545,9 +596,7 @@ def __init__(self, *, verify_ssl=True, fingerprint=None,
545596
self._resolver = resolver
546597

547598
self._use_dns_cache = use_dns_cache
548-
self._ttl_dns_cache = ttl_dns_cache
549-
self._cached_hosts = {}
550-
self._cached_hosts_timestamp = {}
599+
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
551600
self._ssl_context = ssl_context
552601
self._family = family
553602
self._local_addr = local_addr
@@ -593,26 +642,17 @@ def use_dns_cache(self):
593642
@property
594643
def cached_hosts(self):
595644
"""Read-only dict of cached DNS record."""
596-
return MappingProxyType(self._cached_hosts)
645+
return MappingProxyType(self._cached_hosts.addrs)
597646

598647
def clear_dns_cache(self, host=None, port=None):
599648
"""Remove specified host/port or clear all dns local cache."""
600649
if host is not None and port is not None:
601-
self._cached_hosts.pop((host, port), None)
602-
self._cached_hosts_timestamp.pop((host, port), None)
650+
self._cached_hosts.remove((host, port))
603651
elif host is not None or port is not None:
604652
raise ValueError("either both host and port "
605653
"or none of them are allowed")
606654
else:
607655
self._cached_hosts.clear()
608-
self._cached_hosts_timestamp.clear()
609-
610-
def _dns_entry_expired(self, key):
611-
if self._ttl_dns_cache is None:
612-
return False
613-
return (
614-
self._cached_hosts_timestamp[key] + self._ttl_dns_cache
615-
) < self._loop.time()
616656

617657
@asyncio.coroutine
618658
def _resolve_host(self, host, port):
@@ -623,12 +663,13 @@ def _resolve_host(self, host, port):
623663
if self._use_dns_cache:
624664
key = (host, port)
625665

626-
if key not in self._cached_hosts or (self._dns_entry_expired(key)):
627-
self._cached_hosts[key] = yield from \
666+
if key not in self._cached_hosts or\
667+
self._cached_hosts.expired(key):
668+
addrs = yield from \
628669
self._resolver.resolve(host, port, family=self._family)
629-
self._cached_hosts_timestamp[key] = self._loop.time()
670+
self._cached_hosts.add(key, addrs)
630671

631-
return self._cached_hosts[key]
672+
return self._cached_hosts.next_addrs(key)
632673
else:
633674
res = yield from self._resolver.resolve(
634675
host, port, family=self._family)

tests/test_connector.py

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import ssl
1010
import tempfile
1111
import unittest
12+
from time import sleep
1213
from unittest import mock
1314

1415
import pytest
@@ -17,7 +18,7 @@
1718
import aiohttp
1819
from aiohttp import client, helpers, web
1920
from aiohttp.client import ClientRequest
20-
from aiohttp.connector import Connection
21+
from aiohttp.connector import Connection, _DNSCacheTable
2122
from aiohttp.test_utils import unused_port
2223

2324

@@ -365,40 +366,57 @@ def test_tcp_connector_resolve_host(loop):
365366

366367

367368
@asyncio.coroutine
368-
def test_tcp_connector_dns_cache_not_expired(loop):
369-
conn = aiohttp.TCPConnector(
370-
loop=loop,
371-
use_dns_cache=True,
372-
ttl_dns_cache=10
373-
)
369+
def dns_response():
370+
return ["127.0.0.1"]
374371

375-
res = yield from conn._resolve_host('localhost', 8080)
376-
res2 = yield from conn._resolve_host('localhost', 8080)
377372

378-
assert res is res2
373+
@asyncio.coroutine
374+
def test_tcp_connector_dns_cache_not_expired(loop):
375+
with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
376+
conn = aiohttp.TCPConnector(
377+
loop=loop,
378+
use_dns_cache=True,
379+
ttl_dns_cache=10
380+
)
381+
m_resolver().resolve.return_value = dns_response()
382+
yield from conn._resolve_host('localhost', 8080)
383+
yield from conn._resolve_host('localhost', 8080)
384+
m_resolver().resolve.assert_called_once_with(
385+
'localhost',
386+
8080,
387+
family=0
388+
)
379389

380390

381391
@asyncio.coroutine
382392
def test_tcp_connector_dns_cache_forever(loop):
383-
conn = aiohttp.TCPConnector(
384-
loop=loop,
385-
use_dns_cache=True,
386-
ttl_dns_cache=None
387-
)
388-
389-
res = yield from conn._resolve_host('localhost', 8080)
390-
res2 = yield from conn._resolve_host('localhost', 8080)
391-
assert res is res2
393+
with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
394+
conn = aiohttp.TCPConnector(
395+
loop=loop,
396+
use_dns_cache=True,
397+
ttl_dns_cache=10
398+
)
399+
m_resolver().resolve.return_value = dns_response()
400+
yield from conn._resolve_host('localhost', 8080)
401+
yield from conn._resolve_host('localhost', 8080)
402+
m_resolver().resolve.assert_called_once_with(
403+
'localhost',
404+
8080,
405+
family=0
406+
)
392407

393408

394409
@asyncio.coroutine
395410
def test_tcp_connector_use_dns_cache_disabled(loop):
396-
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
397-
398-
res = yield from conn._resolve_host('localhost', 8080)
399-
res2 = yield from conn._resolve_host('localhost', 8080)
400-
401-
assert res is not res2
411+
with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
412+
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
413+
m_resolver().resolve.return_value = dns_response()
414+
yield from conn._resolve_host('localhost', 8080)
415+
yield from conn._resolve_host('localhost', 8080)
416+
m_resolver().resolve.assert_has_calls([
417+
mock.call('localhost', 8080, family=0),
418+
mock.call('localhost', 8080, family=0)
419+
])
402420

403421

404422
def test_get_pop_empty_conns(loop):
@@ -631,20 +649,15 @@ def test_tcp_connector_fingerprint_invalid(loop):
631649

632650
def test_tcp_connector_clear_dns_cache(loop):
633651
conn = aiohttp.TCPConnector(loop=loop)
634-
info = object()
635-
conn._cached_hosts[('localhost', 123)] = info
636-
conn._cached_hosts_timestamp[('localhost', 123)] = 100
637-
conn._cached_hosts[('localhost', 124)] = info
638-
conn._cached_hosts_timestamp[('localhost', 124)] = 101
652+
hosts = ['a', 'b']
653+
conn._cached_hosts.add(('localhost', 123), hosts)
654+
conn._cached_hosts.add(('localhost', 124), hosts)
639655
conn.clear_dns_cache('localhost', 123)
640-
assert conn.cached_hosts == {('localhost', 124): info}
641-
assert conn._cached_hosts_timestamp == {('localhost', 124): 101}
656+
assert ('localhost', 123) not in conn.cached_hosts
642657
conn.clear_dns_cache('localhost', 123)
643-
assert conn.cached_hosts == {('localhost', 124): info}
644-
assert conn._cached_hosts_timestamp == {('localhost', 124): 101}
658+
assert ('localhost', 123) not in conn.cached_hosts
645659
conn.clear_dns_cache()
646660
assert conn.cached_hosts == {}
647-
assert conn._cached_hosts_timestamp == {}
648661

649662

650663
def test_tcp_connector_clear_dns_cache_bad_args(loop):
@@ -1181,3 +1194,57 @@ def test_resolver_not_called_with_address_is_ip(self):
11811194
self.loop.run_until_complete(connector.connect(req))
11821195

11831196
resolver.resolve.assert_not_called()
1197+
1198+
1199+
class TestDNSCacheTable:
1200+
1201+
@pytest.fixture
1202+
def dns_cache_table(self):
1203+
return _DNSCacheTable()
1204+
1205+
def test_addrs(self, dns_cache_table):
1206+
dns_cache_table.add('localhost', ['127.0.0.1'])
1207+
dns_cache_table.add('foo', ['127.0.0.2'])
1208+
assert dns_cache_table.addrs == {
1209+
'localhost': ['127.0.0.1'],
1210+
'foo': ['127.0.0.2']
1211+
}
1212+
1213+
def test_remove(self, dns_cache_table):
1214+
dns_cache_table.add('localhost', ['127.0.0.1'])
1215+
dns_cache_table.remove('localhost')
1216+
assert dns_cache_table.addrs == {}
1217+
1218+
def test_clear(self, dns_cache_table):
1219+
dns_cache_table.add('localhost', ['127.0.0.1'])
1220+
dns_cache_table.clear()
1221+
assert dns_cache_table.addrs == {}
1222+
1223+
def test_not_expired_ttl_None(self, dns_cache_table):
1224+
dns_cache_table.add('localhost', ['127.0.0.1'])
1225+
assert not dns_cache_table.expired('localhost')
1226+
1227+
def test_not_expired_ttl(self):
1228+
dns_cache_table = _DNSCacheTable(ttl=0.1)
1229+
dns_cache_table.add('localhost', ['127.0.0.1'])
1230+
assert not dns_cache_table.expired('localhost')
1231+
1232+
def test_expired_ttl(self):
1233+
dns_cache_table = _DNSCacheTable(ttl=0.1)
1234+
dns_cache_table.add('localhost', ['127.0.0.1'])
1235+
sleep(0.1)
1236+
assert dns_cache_table.expired('localhost')
1237+
1238+
def test_next_addrs(self, dns_cache_table):
1239+
dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2'])
1240+
1241+
# max elements returned are the full list of addrs
1242+
addrs = list(dns_cache_table.next_addrs('foo'))
1243+
assert addrs == ['127.0.0.1', '127.0.0.2']
1244+
1245+
# different calls to next_addrs return the hosts using
1246+
# a round robin strategy.
1247+
addrs = dns_cache_table.next_addrs('foo')
1248+
assert next(addrs) == '127.0.0.1'
1249+
addrs = dns_cache_table.next_addrs('foo')
1250+
assert next(addrs) == '127.0.0.2'

0 commit comments

Comments
 (0)