66
77
88import functools
9+ import hmac
910import http
1011from typing import Any , Awaitable , Callable , Iterable , Optional , Tuple , Union , cast
1112
@@ -132,24 +133,23 @@ def basic_auth_protocol_factory(
132133
133134 if credentials is not None :
134135 if is_credentials (credentials ):
135-
136- async def check_credentials (username : str , password : str ) -> bool :
137- return (username , password ) == credentials
138-
136+ credentials_list = [cast (Credentials , credentials )]
139137 elif isinstance (credentials , Iterable ):
140138 credentials_list = list (credentials )
141- if all (is_credentials (item ) for item in credentials_list ):
142- credentials_dict = dict (credentials_list )
143-
144- async def check_credentials (username : str , password : str ) -> bool :
145- return credentials_dict .get (username ) == password
146-
147- else :
139+ if not all (is_credentials (item ) for item in credentials_list ):
148140 raise TypeError (f"invalid credentials argument: { credentials } " )
149-
150141 else :
151142 raise TypeError (f"invalid credentials argument: { credentials } " )
152143
144+ credentials_dict = dict (credentials_list )
145+
146+ async def check_credentials (username : str , password : str ) -> bool :
147+ try :
148+ expected_password = credentials_dict [username ]
149+ except KeyError :
150+ return False
151+ return hmac .compare_digest (expected_password , password )
152+
153153 if create_protocol is None :
154154 # Not sure why mypy cannot figure this out.
155155 create_protocol = cast (
@@ -158,5 +158,7 @@ async def check_credentials(username: str, password: str) -> bool:
158158 )
159159
160160 return functools .partial (
161- create_protocol , realm = realm , check_credentials = check_credentials
161+ create_protocol ,
162+ realm = realm ,
163+ check_credentials = check_credentials ,
162164 )
0 commit comments