Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 38f65c2

Browse files
authored
Add type hints to base and authentication (auth0#472)
2 parents ea52e49 + b8a4a2c commit 38f65c2

20 files changed

+384
-186
lines changed

EXAMPLES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
- [Connections](#connections)
88
- [Error handling](#error-handling)
99
- [Asynchronous environments](#asynchronous-environments)
10-
10+
1111
## Authentication SDK
1212

1313
### ID token validation

auth0/authentication/async_token_verifier.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
"""Token Verifier module"""
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING, Any
5+
26
from .. import TokenValidationError
37
from ..rest_async import AsyncRestClient
48
from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier
59

10+
if TYPE_CHECKING:
11+
from aiohttp import ClientSession
12+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
13+
614

715
class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
816
"""Async verifier for RSA signatures, which rely on public key certificates.
@@ -12,11 +20,11 @@ class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
1220
algorithm (str, optional): The expected signing algorithm. Defaults to "RS256".
1321
"""
1422

15-
def __init__(self, jwks_url, algorithm="RS256"):
23+
def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None:
1624
super().__init__(jwks_url, algorithm)
1725
self._fetcher = AsyncJwksFetcher(jwks_url)
1826

19-
def set_session(self, session):
27+
def set_session(self, session: ClientSession) -> None:
2028
"""Set Client Session to improve performance by reusing session.
2129
2230
Args:
@@ -32,7 +40,7 @@ async def _fetch_key(self, key_id=None):
3240
key_id (str): The key's key id."""
3341
return await self._fetcher.get_key(key_id)
3442

35-
async def verify_signature(self, token):
43+
async def verify_signature(self, token) -> dict[str, Any]:
3644
"""Verifies the signature of the given JSON web token.
3745
3846
Args:
@@ -57,11 +65,11 @@ class AsyncJwksFetcher(JwksFetcher):
5765
cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds.
5866
"""
5967

60-
def __init__(self, *args, **kwargs):
68+
def __init__(self, *args: Any, **kwargs: Any) -> None:
6169
super().__init__(*args, **kwargs)
6270
self._async_client = AsyncRestClient(None)
6371

64-
def set_session(self, session):
72+
def set_session(self, session: ClientSession) -> None:
6573
"""Set Client Session to improve performance by reusing session.
6674
6775
Args:
@@ -70,7 +78,7 @@ def set_session(self, session):
7078
"""
7179
self._async_client.set_session(session)
7280

73-
async def _fetch_jwks(self, force=False):
81+
async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
7482
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.
7583
When not, it will perform a network request to the jwks_url to obtain a fresh result
7684
and update the cache value with it.
@@ -90,7 +98,7 @@ async def _fetch_jwks(self, force=False):
9098
self._cache_is_fresh = False
9199
return self._cache_value
92100

93-
async def get_key(self, key_id):
101+
async def get_key(self, key_id: str) -> RSAPublicKey:
94102
"""Obtains the JWK associated with the given key id.
95103
96104
Args:
@@ -126,7 +134,13 @@ class AsyncTokenVerifier(TokenVerifier):
126134
Defaults to 60 seconds.
127135
"""
128136

129-
def __init__(self, signature_verifier, issuer, audience, leeway=0):
137+
def __init__(
138+
self,
139+
signature_verifier: AsyncAsymmetricSignatureVerifier,
140+
issuer: str,
141+
audience: str,
142+
leeway: int = 0,
143+
) -> None:
130144
if not signature_verifier or not isinstance(
131145
signature_verifier, AsyncAsymmetricSignatureVerifier
132146
):
@@ -140,7 +154,7 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0):
140154
self._sv = signature_verifier
141155
self._clock = None # legacy testing requirement
142156

143-
def set_session(self, session):
157+
def set_session(self, session: ClientSession) -> None:
144158
"""Set Client Session to improve performance by reusing session.
145159
146160
Args:
@@ -149,7 +163,13 @@ def set_session(self, session):
149163
"""
150164
self._sv.set_session(session)
151165

152-
async def verify(self, token, nonce=None, max_age=None, organization=None):
166+
async def verify(
167+
self,
168+
token: str,
169+
nonce: str | None = None,
170+
max_age: int | None = None,
171+
organization: str | None = None,
172+
) -> dict[str, Any]:
153173
"""Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.
154174
155175
Args:

auth0/authentication/base.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
15
from auth0.rest import RestClient, RestClientOptions
6+
from auth0.types import RequestData, TimeoutType
27

38
from .client_authentication import add_client_authentication
49

@@ -21,15 +26,15 @@ class AuthenticationBase:
2126

2227
def __init__(
2328
self,
24-
domain,
25-
client_id,
26-
client_secret=None,
27-
client_assertion_signing_key=None,
28-
client_assertion_signing_alg=None,
29-
telemetry=True,
30-
timeout=5.0,
31-
protocol="https",
32-
):
29+
domain: str,
30+
client_id: str,
31+
client_secret: str | None = None,
32+
client_assertion_signing_key: str | None = None,
33+
client_assertion_signing_alg: str | None = None,
34+
telemetry: bool = True,
35+
timeout: TimeoutType = 5.0,
36+
protocol: str = "https",
37+
) -> None:
3338
self.domain = domain
3439
self.client_id = client_id
3540
self.client_secret = client_secret
@@ -41,7 +46,7 @@ def __init__(
4146
options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0),
4247
)
4348

44-
def _add_client_authentication(self, payload):
49+
def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]:
4550
return add_client_authentication(
4651
payload,
4752
self.domain,
@@ -51,13 +56,28 @@ def _add_client_authentication(self, payload):
5156
self.client_assertion_signing_alg,
5257
)
5358

54-
def post(self, url, data=None, headers=None):
59+
def post(
60+
self,
61+
url: str,
62+
data: RequestData | None = None,
63+
headers: dict[str, str] | None = None,
64+
) -> Any:
5565
return self.client.post(url, data=data, headers=headers)
5666

57-
def authenticated_post(self, url, data=None, headers=None):
67+
def authenticated_post(
68+
self,
69+
url: str,
70+
data: dict[str, Any],
71+
headers: dict[str, str] | None = None,
72+
) -> Any:
5873
return self.client.post(
5974
url, data=self._add_client_authentication(data), headers=headers
6075
)
6176

62-
def get(self, url, params=None, headers=None):
77+
def get(
78+
self,
79+
url: str,
80+
params: dict[str, Any] | None = None,
81+
headers: dict[str, str] | None = None,
82+
) -> Any:
6383
return self.client.get(url, params, headers)

auth0/authentication/client_authentication.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
1+
from __future__ import annotations
2+
13
import datetime
24
import uuid
5+
from typing import Any
36

47
import jwt
58

69

710
def create_client_assertion_jwt(
8-
domain, client_id, client_assertion_signing_key, client_assertion_signing_alg
9-
):
11+
domain: str,
12+
client_id: str,
13+
client_assertion_signing_key: str,
14+
client_assertion_signing_alg: str | None,
15+
) -> str:
1016
"""Creates a JWT for the client_assertion field.
1117
1218
Args:
1319
domain (str): The domain of your Auth0 tenant
1420
client_id (str): Your application's client ID
15-
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
21+
client_assertion_signing_key (str): Private key used to sign the client assertion JWT
1622
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')
1723
1824
Returns:
@@ -35,20 +41,20 @@ def create_client_assertion_jwt(
3541

3642

3743
def add_client_authentication(
38-
payload,
39-
domain,
40-
client_id,
41-
client_secret,
42-
client_assertion_signing_key,
43-
client_assertion_signing_alg,
44-
):
44+
payload: dict[str, Any],
45+
domain: str,
46+
client_id: str,
47+
client_secret: str | None,
48+
client_assertion_signing_key: str | None,
49+
client_assertion_signing_alg: str | None,
50+
) -> dict[str, Any]:
4551
"""Adds the client_assertion or client_secret fields to authenticate a payload.
4652
4753
Args:
4854
payload (dict): The POST payload that needs additional fields to be authenticated.
4955
domain (str): The domain of your Auth0 tenant
5056
client_id (str): Your application's client ID
51-
client_secret (str): Your application's client secret
57+
client_secret (str, optional): Your application's client secret
5258
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
5359
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')
5460

auth0/authentication/database.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import warnings
1+
from __future__ import annotations
2+
3+
from typing import Any
24

35
from .base import AuthenticationBase
46

@@ -12,17 +14,17 @@ class Database(AuthenticationBase):
1214

1315
def signup(
1416
self,
15-
email,
16-
password,
17-
connection,
18-
username=None,
19-
user_metadata=None,
20-
given_name=None,
21-
family_name=None,
22-
name=None,
23-
nickname=None,
24-
picture=None,
25-
):
17+
email: str,
18+
password: str,
19+
connection: str,
20+
username: str | None = None,
21+
user_metadata: dict[str, Any] | None = None,
22+
given_name: str | None = None,
23+
family_name: str | None = None,
24+
name: str | None = None,
25+
nickname: str | None = None,
26+
picture: str | None = None,
27+
) -> dict[str, Any]:
2628
"""Signup using email and password.
2729
2830
Args:
@@ -50,7 +52,7 @@ def signup(
5052
5153
See: https://auth0.com/docs/api/authentication#signup
5254
"""
53-
body = {
55+
body: dict[str, Any] = {
5456
"client_id": self.client_id,
5557
"email": email,
5658
"password": password,
@@ -71,11 +73,14 @@ def signup(
7173
if picture:
7274
body.update({"picture": picture})
7375

74-
return self.post(
76+
data: dict[str, Any] = self.post(
7577
f"{self.protocol}://{self.domain}/dbconnections/signup", data=body
7678
)
79+
return data
7780

78-
def change_password(self, email, connection, password=None):
81+
def change_password(
82+
self, email: str, connection: str, password: str | None = None
83+
) -> str:
7984
"""Asks to change a password for a given user.
8085
8186
email (str): The user's email address.
@@ -88,7 +93,8 @@ def change_password(self, email, connection, password=None):
8893
"connection": connection,
8994
}
9095

91-
return self.post(
96+
data: str = self.post(
9297
f"{self.protocol}://{self.domain}/dbconnections/change_password",
9398
data=body,
9499
)
100+
return data

auth0/authentication/delegated.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
15
from .base import AuthenticationBase
26

37

@@ -10,13 +14,13 @@ class Delegated(AuthenticationBase):
1014

1115
def get_token(
1216
self,
13-
target,
14-
api_type,
15-
grant_type,
16-
id_token=None,
17-
refresh_token=None,
18-
scope="openid",
19-
):
17+
target: str,
18+
api_type: str,
19+
grant_type: str,
20+
id_token: str | None = None,
21+
refresh_token: str | None = None,
22+
scope: str = "openid",
23+
) -> Any:
2024
"""Obtain a delegation token."""
2125

2226
if id_token and refresh_token:

auth0/authentication/enterprise.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
from .base import AuthenticationBase
24

35

@@ -9,7 +11,7 @@ class Enterprise(AuthenticationBase):
911
domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com)
1012
"""
1113

12-
def saml_metadata(self):
14+
def saml_metadata(self) -> Any:
1315
"""Get SAML2.0 Metadata."""
1416

1517
return self.get(
@@ -18,7 +20,7 @@ def saml_metadata(self):
1820
)
1921
)
2022

21-
def wsfed_metadata(self):
23+
def wsfed_metadata(self) -> Any:
2224
"""Returns the WS-Federation Metadata."""
2325

2426
url = "{}://{}/wsfed/FederationMetadata/2007-06/FederationMetadata.xml"

0 commit comments

Comments
 (0)