Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,15 @@ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: dict[str, An
def _get_web_identity_credential_fetcher(
self,
) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher:
base_session = self.basic_session._session or botocore.session.get_session()
client_creator = base_session.create_client
session_config = self.config
endpoint_url = self.conn.get_service_endpoint_url("sts", sts_connection_assume=True)

def client_creator(service_name, **kwargs):
config = kwargs.pop("config", None)
if session_config:
config = session_config.merge(config) if config else session_config
return self.basic_session.client(service_name, config=config, endpoint_url=endpoint_url, **kwargs)

federation = str(self.extra_config.get("assume_role_with_web_identity_federation"))

web_identity_token_loader = {
Expand Down
61 changes: 54 additions & 7 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,21 +569,19 @@ def import_mock(name, *args):
mock_boto3.assert_has_calls(
[
mock.call.session.Session(),
mock.call.session.Session()._session.__bool__(),
mock.call.session.Session(botocore_session=mock_session.get_session.return_value),
mock.call.session.Session().get_credentials(),
mock.call.session.Session().get_credentials().get_frozen_credentials(),
]
)
mock_fetcher = mock_botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher
mock_fetcher_call_kwargs = mock_fetcher.call_args.kwargs
assert mock_fetcher_call_kwargs["role_arn"] == "arn:aws:iam::123456:role/role_arn"
assert mock_fetcher_call_kwargs["extra_args"] == {}
# client_creator should be a wrapper function, not the raw create_client
assert callable(mock_fetcher_call_kwargs["client_creator"])
mock_botocore.assert_has_calls(
[
mock.call.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator=mock_boto3.session.Session.return_value._session.create_client,
extra_args={},
role_arn="arn:aws:iam::123456:role/role_arn",
web_identity_token_loader=mock.ANY,
),
mock.call.credentials.DeferredRefreshableCredentials(
method="assume-role-with-web-identity",
refresh_using=mock_fetcher.return_value.fetch_credentials,
Expand Down Expand Up @@ -639,6 +637,55 @@ def test_get_credentials_from_token_file(self, mock_session, mock_credentials_fe
assert mock_creds_fetcher_kwargs["web_identity_token_loader"]() == "TOKEN"
assert mock_open_.call_args.args[0] == "/my-token-path"

@mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher"
)
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.botocore.session.Session")
def test_web_identity_credential_fetcher_uses_botocore_config(
self, mock_session, mock_credentials_fetcher
):
"""Test that assume_role_with_web_identity passes botocore config (e.g. proxy) to STS client."""
proxy_config = {"https": "http://proxy.example.com:8080"}
with mock.patch.object(
AwsBaseHook,
"get_connection",
return_value=Connection(
conn_id="aws_default",
conn_type="aws",
extra=json.dumps(
{
"role_arn": "arn:aws:iam::123456:role/role_arn",
"assume_role_method": "assume_role_with_web_identity",
"assume_role_with_web_identity_token_file": "/my-token-path",
"assume_role_with_web_identity_federation": "file",
"config_kwargs": {"proxies": proxy_config},
}
),
),
):
mock_open_ = mock_open(read_data="TOKEN")
with mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.botocore.utils.FileWebIdentityTokenLoader.__init__.__defaults__",
new=(mock_open_,),
):
AwsBaseHook(aws_conn_id="aws_default", client_type="airflow_test").get_session()

_, mock_creds_fetcher_kwargs = mock_credentials_fetcher.call_args
# Invoke the client_creator wrapper to verify config is merged
client_creator = mock_creds_fetcher_kwargs["client_creator"]
mock_base_client_creator = mock_session.return_value.create_client

# Simulate what botocore does internally: calls client_creator('sts', config=Config(...))
from botocore.config import Config

unsigned_config = Config(signature_version="unsigned")
client_creator("sts", config=unsigned_config)

call_kwargs = mock_base_client_creator.call_args.kwargs
merged_config = call_kwargs["config"]
# The proxy settings from the connection should be present in the merged config
assert merged_config.proxies == proxy_config

@pytest.mark.parametrize(
"sts_endpoint",
[
Expand Down
Loading