|
9 | 9 | import ssl |
10 | 10 | import tempfile |
11 | 11 | import unittest |
| 12 | +from time import sleep |
12 | 13 | from unittest import mock |
13 | 14 |
|
14 | 15 | import pytest |
|
17 | 18 | import aiohttp |
18 | 19 | from aiohttp import client, helpers, web |
19 | 20 | from aiohttp.client import ClientRequest |
20 | | -from aiohttp.connector import Connection |
| 21 | +from aiohttp.connector import Connection, _DNSCacheTable |
21 | 22 | from aiohttp.test_utils import unused_port |
22 | 23 |
|
23 | 24 |
|
@@ -365,40 +366,57 @@ def test_tcp_connector_resolve_host(loop): |
365 | 366 |
|
366 | 367 |
|
367 | 368 | @asyncio.coroutine |
368 | | -def test_tcp_connector_dns_cache_not_expired(loop): |
369 | | - conn = aiohttp.TCPConnector( |
370 | | - loop=loop, |
371 | | - use_dns_cache=True, |
372 | | - ttl_dns_cache=10 |
373 | | - ) |
| 369 | +def dns_response(): |
| 370 | + return ["127.0.0.1"] |
374 | 371 |
|
375 | | - res = yield from conn._resolve_host('localhost', 8080) |
376 | | - res2 = yield from conn._resolve_host('localhost', 8080) |
377 | 372 |
|
378 | | - assert res is res2 |
| 373 | +@asyncio.coroutine |
| 374 | +def test_tcp_connector_dns_cache_not_expired(loop): |
| 375 | + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: |
| 376 | + conn = aiohttp.TCPConnector( |
| 377 | + loop=loop, |
| 378 | + use_dns_cache=True, |
| 379 | + ttl_dns_cache=10 |
| 380 | + ) |
| 381 | + m_resolver().resolve.return_value = dns_response() |
| 382 | + yield from conn._resolve_host('localhost', 8080) |
| 383 | + yield from conn._resolve_host('localhost', 8080) |
| 384 | + m_resolver().resolve.assert_called_once_with( |
| 385 | + 'localhost', |
| 386 | + 8080, |
| 387 | + family=0 |
| 388 | + ) |
379 | 389 |
|
380 | 390 |
|
381 | 391 | @asyncio.coroutine |
382 | 392 | def test_tcp_connector_dns_cache_forever(loop): |
383 | | - conn = aiohttp.TCPConnector( |
384 | | - loop=loop, |
385 | | - use_dns_cache=True, |
386 | | - ttl_dns_cache=None |
387 | | - ) |
388 | | - |
389 | | - res = yield from conn._resolve_host('localhost', 8080) |
390 | | - res2 = yield from conn._resolve_host('localhost', 8080) |
391 | | - assert res is res2 |
| 393 | + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: |
| 394 | + conn = aiohttp.TCPConnector( |
| 395 | + loop=loop, |
| 396 | + use_dns_cache=True, |
| 397 | + ttl_dns_cache=10 |
| 398 | + ) |
| 399 | + m_resolver().resolve.return_value = dns_response() |
| 400 | + yield from conn._resolve_host('localhost', 8080) |
| 401 | + yield from conn._resolve_host('localhost', 8080) |
| 402 | + m_resolver().resolve.assert_called_once_with( |
| 403 | + 'localhost', |
| 404 | + 8080, |
| 405 | + family=0 |
| 406 | + ) |
392 | 407 |
|
393 | 408 |
|
394 | 409 | @asyncio.coroutine |
395 | 410 | def test_tcp_connector_use_dns_cache_disabled(loop): |
396 | | - conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False) |
397 | | - |
398 | | - res = yield from conn._resolve_host('localhost', 8080) |
399 | | - res2 = yield from conn._resolve_host('localhost', 8080) |
400 | | - |
401 | | - assert res is not res2 |
| 411 | + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: |
| 412 | + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False) |
| 413 | + m_resolver().resolve.return_value = dns_response() |
| 414 | + yield from conn._resolve_host('localhost', 8080) |
| 415 | + yield from conn._resolve_host('localhost', 8080) |
| 416 | + m_resolver().resolve.assert_has_calls([ |
| 417 | + mock.call('localhost', 8080, family=0), |
| 418 | + mock.call('localhost', 8080, family=0) |
| 419 | + ]) |
402 | 420 |
|
403 | 421 |
|
404 | 422 | def test_get_pop_empty_conns(loop): |
@@ -631,20 +649,15 @@ def test_tcp_connector_fingerprint_invalid(loop): |
631 | 649 |
|
632 | 650 | def test_tcp_connector_clear_dns_cache(loop): |
633 | 651 | conn = aiohttp.TCPConnector(loop=loop) |
634 | | - info = object() |
635 | | - conn._cached_hosts[('localhost', 123)] = info |
636 | | - conn._cached_hosts_timestamp[('localhost', 123)] = 100 |
637 | | - conn._cached_hosts[('localhost', 124)] = info |
638 | | - conn._cached_hosts_timestamp[('localhost', 124)] = 101 |
| 652 | + hosts = ['a', 'b'] |
| 653 | + conn._cached_hosts.add(('localhost', 123), hosts) |
| 654 | + conn._cached_hosts.add(('localhost', 124), hosts) |
639 | 655 | conn.clear_dns_cache('localhost', 123) |
640 | | - assert conn.cached_hosts == {('localhost', 124): info} |
641 | | - assert conn._cached_hosts_timestamp == {('localhost', 124): 101} |
| 656 | + assert ('localhost', 123) not in conn.cached_hosts |
642 | 657 | conn.clear_dns_cache('localhost', 123) |
643 | | - assert conn.cached_hosts == {('localhost', 124): info} |
644 | | - assert conn._cached_hosts_timestamp == {('localhost', 124): 101} |
| 658 | + assert ('localhost', 123) not in conn.cached_hosts |
645 | 659 | conn.clear_dns_cache() |
646 | 660 | assert conn.cached_hosts == {} |
647 | | - assert conn._cached_hosts_timestamp == {} |
648 | 661 |
|
649 | 662 |
|
650 | 663 | def test_tcp_connector_clear_dns_cache_bad_args(loop): |
@@ -1181,3 +1194,57 @@ def test_resolver_not_called_with_address_is_ip(self): |
1181 | 1194 | self.loop.run_until_complete(connector.connect(req)) |
1182 | 1195 |
|
1183 | 1196 | resolver.resolve.assert_not_called() |
| 1197 | + |
| 1198 | + |
| 1199 | +class TestDNSCacheTable: |
| 1200 | + |
| 1201 | + @pytest.fixture |
| 1202 | + def dns_cache_table(self): |
| 1203 | + return _DNSCacheTable() |
| 1204 | + |
| 1205 | + def test_addrs(self, dns_cache_table): |
| 1206 | + dns_cache_table.add('localhost', ['127.0.0.1']) |
| 1207 | + dns_cache_table.add('foo', ['127.0.0.2']) |
| 1208 | + assert dns_cache_table.addrs == { |
| 1209 | + 'localhost': ['127.0.0.1'], |
| 1210 | + 'foo': ['127.0.0.2'] |
| 1211 | + } |
| 1212 | + |
| 1213 | + def test_remove(self, dns_cache_table): |
| 1214 | + dns_cache_table.add('localhost', ['127.0.0.1']) |
| 1215 | + dns_cache_table.remove('localhost') |
| 1216 | + assert dns_cache_table.addrs == {} |
| 1217 | + |
| 1218 | + def test_clear(self, dns_cache_table): |
| 1219 | + dns_cache_table.add('localhost', ['127.0.0.1']) |
| 1220 | + dns_cache_table.clear() |
| 1221 | + assert dns_cache_table.addrs == {} |
| 1222 | + |
| 1223 | + def test_not_expired_ttl_None(self, dns_cache_table): |
| 1224 | + dns_cache_table.add('localhost', ['127.0.0.1']) |
| 1225 | + assert not dns_cache_table.expired('localhost') |
| 1226 | + |
| 1227 | + def test_not_expired_ttl(self): |
| 1228 | + dns_cache_table = _DNSCacheTable(ttl=0.1) |
| 1229 | + dns_cache_table.add('localhost', ['127.0.0.1']) |
| 1230 | + assert not dns_cache_table.expired('localhost') |
| 1231 | + |
| 1232 | + def test_expired_ttl(self): |
| 1233 | + dns_cache_table = _DNSCacheTable(ttl=0.1) |
| 1234 | + dns_cache_table.add('localhost', ['127.0.0.1']) |
| 1235 | + sleep(0.1) |
| 1236 | + assert dns_cache_table.expired('localhost') |
| 1237 | + |
| 1238 | + def test_next_addrs(self, dns_cache_table): |
| 1239 | + dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2']) |
| 1240 | + |
| 1241 | + # max elements returned are the full list of addrs |
| 1242 | + addrs = list(dns_cache_table.next_addrs('foo')) |
| 1243 | + assert addrs == ['127.0.0.1', '127.0.0.2'] |
| 1244 | + |
| 1245 | + # different calls to next_addrs return the hosts using |
| 1246 | + # a round robin strategy. |
| 1247 | + addrs = dns_cache_table.next_addrs('foo') |
| 1248 | + assert next(addrs) == '127.0.0.1' |
| 1249 | + addrs = dns_cache_table.next_addrs('foo') |
| 1250 | + assert next(addrs) == '127.0.0.2' |
0 commit comments