|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +from unittest.mock import ANY |
15 | 16 | from unittest.mock import AsyncMock |
16 | 17 | from unittest.mock import Mock |
17 | 18 | from unittest.mock import patch |
|
30 | 31 | from google.adk.auth.auth_schemes import ExtendedOAuth2 |
31 | 32 | from google.adk.auth.auth_tool import AuthConfig |
32 | 33 | from google.adk.auth.credential_manager import CredentialManager |
| 34 | +from google.adk.auth.credential_manager import ServiceAccountCredentialExchanger |
33 | 35 | from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata |
34 | 36 | import pytest |
35 | 37 |
|
@@ -422,36 +424,32 @@ async def test_validate_credential_oauth2_missing_scheme_info( |
422 | 424 | await manager._validate_credential() |
423 | 425 |
|
424 | 426 | @pytest.mark.asyncio |
425 | | - async def test_exchange_credentials_service_account(self): |
| 427 | + async def test_exchange_credentials_service_account( |
| 428 | + self, service_account_credential, oauth2_auth_scheme |
| 429 | + ): |
426 | 430 | """Test _exchange_credential with service account credential.""" |
427 | | - mock_service_account = Mock(spec=ServiceAccount) |
428 | | - mock_credential = Mock(spec=AuthCredential) |
429 | | - mock_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT |
430 | | - |
431 | 431 | auth_config = Mock(spec=AuthConfig) |
432 | | - auth_config.auth_scheme = Mock() |
| 432 | + auth_config.auth_scheme = oauth2_auth_scheme |
433 | 433 |
|
434 | | - # Mock exchanger |
435 | | - mock_exchanger = Mock() |
436 | | - mock_exchanger.exchange = AsyncMock(return_value=mock_credential) |
| 434 | + exchanged_credential = Mock(spec=AuthCredential) |
437 | 435 |
|
438 | 436 | manager = CredentialManager(auth_config) |
439 | 437 |
|
440 | | - # Mock the exchanger registry to return our mock exchanger |
441 | 438 | with patch.object( |
442 | | - manager._exchanger_registry, |
443 | | - "get_exchanger", |
444 | | - return_value=mock_exchanger, |
445 | | - ): |
| 439 | + ServiceAccountCredentialExchanger, |
| 440 | + "exchange_credential", |
| 441 | + return_value=exchanged_credential, |
| 442 | + autospec=True, |
| 443 | + ) as mock_exchange_credential: |
446 | 444 | result, was_exchanged = await manager._exchange_credential( |
447 | | - mock_credential |
| 445 | + service_account_credential |
448 | 446 | ) |
449 | 447 |
|
450 | | - mock_exchanger.exchange.assert_called_once_with( |
451 | | - mock_credential, auth_config.auth_scheme |
452 | | - ) |
453 | | - assert result == mock_credential |
454 | | - assert was_exchanged is True |
| 448 | + mock_exchange_credential.assert_called_once_with( |
| 449 | + ANY, oauth2_auth_scheme, service_account_credential |
| 450 | + ) |
| 451 | + assert result == exchanged_credential |
| 452 | + assert was_exchanged is True |
455 | 453 |
|
456 | 454 | @pytest.mark.asyncio |
457 | 455 | async def test_exchange_credential_no_exchanger(self): |
|
0 commit comments