diff --git a/python_http_client/exceptions.py b/python_http_client/exceptions.py index b6343f2..2a8c179 100644 --- a/python_http_client/exceptions.py +++ b/python_http_client/exceptions.py @@ -4,11 +4,23 @@ class HTTPError(Exception): """ Base of all other errors""" - def __init__(self, error): - self.status_code = error.code - self.reason = error.reason - self.body = error.read() - self.headers = error.hdrs + def __init__(self, *args): + if len(args) == 4: + self.status_code = args[0] + self.reason = args[1] + self.body = args[2] + self.headers = args[3] + else: + self.status_code = args[0].code + self.reason = args[0].reason + self.body = args[0].read() + self.headers = args[0].hdrs + + def __reduce__(self): + return ( + HTTPError, + (self.status_code, self.reason, self.body, self.headers) + ) @property def to_dict(self): diff --git a/tests/test_unit.py b/tests/test_unit.py index 36b115f..be85543 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -7,6 +7,7 @@ BadRequestsError, NotFoundError, ServiceUnavailableError, + UnauthorizedError, UnsupportedMediaTypeError, ) @@ -208,6 +209,41 @@ def test_client_pickle_unpickle(self): "original client and unpickled client must have the same state" ) + @mock.patch('python_http_client.client.urllib') + def test_pickle_error(self, mock_lib): + mock_opener = MockOpener() + mock_lib.build_opener.return_value = mock_opener + + client = self.client.__getattr__('hello') + + mock_opener.response_code = 401 + try: + client.get() + except UnauthorizedError as e: + pickled_error = pickle.dumps(e) + unpickled_error = pickle.loads(pickled_error) + + self.assertEqual( + e.status_code, + unpickled_error.status_code, + "unpickled error must have the same status code", + ) + self.assertEqual( + e.reason, + unpickled_error.reason, + "unpickled error must have the same reason", + ) + self.assertEqual( + e.body, + unpickled_error.body, + "unpickled error must have the same body", + ) + self.assertEqual( + e.headers, + unpickled_error.headers, + "unpickled error must have the same headers", + ) + if __name__ == '__main__': unittest.main()