Skip to content

Commit 845f775

Browse files
authored
Replace API server’s direct Connection access workaround in BaseHook (#54083)
1 parent 6e92943 commit 845f775

File tree

22 files changed

+282
-201
lines changed

22 files changed

+282
-201
lines changed

devel-common/src/tests_common/pytest_plugin.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,35 @@ def mock_supervisor_comms(monkeypatch):
20712071
yield comms
20722072

20732073

2074+
@pytest.fixture
2075+
def sdk_connection_not_found(mock_supervisor_comms):
2076+
"""
2077+
Fixture that mocks supervisor comms to return CONNECTION_NOT_FOUND error.
2078+
2079+
This eliminates the need to manually set up the mock in every test that
2080+
needs a connection not found message through supervisor comms.
2081+
2082+
Example:
2083+
@pytest.mark.db_test
2084+
def test_invalid_location(self, sdk_connection_not_found):
2085+
# Test logic that expects CONNECTION_NOT_FOUND error
2086+
with pytest.raises(AirflowException):
2087+
operator.execute(context)
2088+
"""
2089+
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
2090+
2091+
if not AIRFLOW_V_3_0_PLUS:
2092+
yield None
2093+
return
2094+
2095+
from airflow.sdk.exceptions import ErrorType
2096+
from airflow.sdk.execution_time.comms import ErrorResponse
2097+
2098+
mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
2099+
2100+
yield mock_supervisor_comms
2101+
2102+
20742103
@pytest.fixture
20752104
def mocked_parse(spy_agency):
20762105
"""

providers/amazon/tests/unit/amazon/aws/bundles/test_s3.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
from airflow.exceptions import AirflowException
2828
from airflow.models import Connection
2929
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
30-
from airflow.utils import db
3130

3231
from tests_common.test_utils.config import conf_vars
33-
from tests_common.test_utils.db import clear_db_connections
3432

3533
AWS_CONN_ID_WITH_REGION = "s3_dags_connection"
3634
AWS_CONN_ID_REGION = "eu-central-1"
@@ -78,13 +76,9 @@ def bundle_temp_dir(tmp_path):
7876

7977
@pytest.mark.skipif(not airflow.version.version.strip().startswith("3"), reason="Airflow >=3.0.0 test")
8078
class TestS3DagBundle:
81-
@classmethod
82-
def teardown_class(cls) -> None:
83-
clear_db_connections()
84-
85-
@classmethod
86-
def setup_class(cls) -> None:
87-
db.merge_conn(
79+
@pytest.fixture(autouse=True)
80+
def setup_connections(self, create_connection_without_db):
81+
create_connection_without_db(
8882
Connection(
8983
conn_id=AWS_CONN_ID_DEFAULT,
9084
conn_type="aws",
@@ -93,8 +87,8 @@ def setup_class(cls) -> None:
9387
},
9488
)
9589
)
96-
db.merge_conn(
97-
conn=Connection(
90+
create_connection_without_db(
91+
Connection(
9892
conn_id=AWS_CONN_ID_WITH_REGION,
9993
conn_type="aws",
10094
extra={
@@ -104,7 +98,6 @@ def setup_class(cls) -> None:
10498
)
10599
)
106100

107-
@pytest.mark.db_test
108101
def test_view_url_generates_presigned_url(self):
109102
bundle = S3DagBundle(
110103
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1/dags", bucket_name=S3_BUCKET_NAME
@@ -113,15 +106,13 @@ def test_view_url_generates_presigned_url(self):
113106
url: str = bundle.view_url("test_version")
114107
assert url.startswith("https://my-airflow-dags-bucket.s3.amazonaws.com/project1/dags")
115108

116-
@pytest.mark.db_test
117109
def test_view_url_template_generates_presigned_url(self):
118110
bundle = S3DagBundle(
119111
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1/dags", bucket_name=S3_BUCKET_NAME
120112
)
121113
url: str = bundle.view_url_template()
122114
assert url.startswith("https://my-airflow-dags-bucket.s3.amazonaws.com/project1/dags")
123115

124-
@pytest.mark.db_test
125116
def test_supports_versioning(self):
126117
bundle = S3DagBundle(
127118
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1/dags", bucket_name=S3_BUCKET_NAME
@@ -136,14 +127,12 @@ def test_supports_versioning(self):
136127
with pytest.raises(AirflowException, match="S3 url with version is not supported"):
137128
bundle.view_url("test_version")
138129

139-
@pytest.mark.db_test
140130
def test_correct_bundle_path_used(self):
141131
bundle = S3DagBundle(
142132
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1_dags", bucket_name="airflow_dags"
143133
)
144134
assert str(bundle.base_dir) == str(bundle.s3_dags_dir)
145135

146-
@pytest.mark.db_test
147136
def test_s3_bucket_and_prefix_validated(self, s3_bucket):
148137
hook = S3Hook(aws_conn_id=AWS_CONN_ID_DEFAULT)
149138
assert hook.check_for_bucket(s3_bucket.name) is True
@@ -195,7 +184,6 @@ def _upload_fixtures(self, bucket: str, fixtures_dir: str) -> None:
195184
key = os.path.relpath(path, fixtures_dir)
196185
client.upload_file(Filename=path, Bucket=bucket, Key=key)
197186

198-
@pytest.mark.db_test
199187
def test_refresh(self, s3_bucket, s3_client):
200188
bundle = S3DagBundle(
201189
name="test",
@@ -218,7 +206,6 @@ def test_refresh(self, s3_bucket, s3_client):
218206
assert bundle._log.debug.call_count == 3
219207
assert bundle._log.debug.call_args_list == [download_log_call, download_log_call, download_log_call]
220208

221-
@pytest.mark.db_test
222209
def test_refresh_without_prefix(self, s3_bucket, s3_client):
223210
bundle = S3DagBundle(
224211
name="test",

providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
5151

5252
from tests_common.test_utils.config import conf_vars
53+
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
5354

5455
pytest.importorskip("aiobotocore")
5556

@@ -430,9 +431,8 @@ def test_user_agent_caller_target_function_found(self, mock_class_name, found_cl
430431
assert mock_class_name.call_count == len(found_classes)
431432
assert user_agent_tags["Caller"] == found_classes[-1]
432433

433-
@pytest.mark.db_test
434434
@mock.patch.object(AwsEcsExecutor, "_load_run_kwargs")
435-
def test_user_agent_caller_target_executor_found(self, mock_load_run_kwargs):
435+
def test_user_agent_caller_target_executor_found(self, mock_load_run_kwargs, sdk_connection_not_found):
436436
with conf_vars(
437437
{
438438
("aws_ecs_executor", "cluster"): "foo",
@@ -456,7 +456,16 @@ def test_user_agent_caller_target_function_not_found(self):
456456
@pytest.mark.db_test
457457
@pytest.mark.parametrize("env_var, expected_version", [({"AIRFLOW_CTX_DAG_ID": "banana"}, 5), [{}, None]])
458458
@mock.patch.object(AwsBaseHook, "_get_caller", return_value="Test")
459-
def test_user_agent_dag_run_key_is_hashed_correctly(self, _, env_var, expected_version):
459+
def test_user_agent_dag_run_key_is_hashed_correctly(
460+
self, _, env_var, expected_version, mock_supervisor_comms
461+
):
462+
if AIRFLOW_V_3_0_PLUS:
463+
from airflow.sdk.execution_time.comms import ConnectionResult
464+
465+
mock_supervisor_comms.send.return_value = ConnectionResult(
466+
conn_id="aws_default",
467+
conn_type="aws",
468+
)
460469
with mock.patch.dict(os.environ, env_var, clear=True):
461470
dag_run_key = self.fetch_tags()["DagRunKey"]
462471

providers/amazon/tests/unit/amazon/aws/hooks/test_emr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def test_empty_emr_conn_id(self, mock_boto3_client):
195195

196196
@pytest.mark.db_test
197197
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
198-
def test_missing_emr_conn_id(self, mock_boto3_client):
198+
def test_missing_emr_conn_id(self, mock_boto3_client, sdk_connection_not_found):
199199
"""Test not exists ``emr_conn_id``."""
200200
mock_run_job_flow = mock.MagicMock()
201201
mock_boto3_client.return_value.run_job_flow = mock_run_job_flow

providers/amazon/tests/unit/amazon/aws/sensors/test_eks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_poke_reached_unexpected_terminal_state(self, mock_get_cluster_state, un
106106
mock_get_cluster_state.assert_called_once_with(clusterName=CLUSTER_NAME)
107107

108108
@pytest.mark.db_test
109-
def test_region_argument(self):
109+
def test_region_argument(self, sdk_connection_not_found):
110110
with pytest.warns(AirflowProviderDeprecationWarning) as w:
111111
w.sensor = EksClusterStateSensor(
112112
task_id=TASK_ID,

providers/amazon/tests/unit/amazon/aws/transfers/test_redshift_to_s3.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -600,24 +600,39 @@ def test_table_unloading_using_redshift_data_api(
600600
# test sql arg
601601
assert_equal_ignore_multiple_spaces(mock_rs.execute_statement.call_args.kwargs["Sql"], unload_query)
602602

603-
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
604-
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
605603
@mock.patch("boto3.session.Session")
606604
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
607605
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
608606
def test_get_openlineage_facets_on_complete_default(
609-
self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
607+
self, mock_get_facets, mock_run, mock_session, create_connection_without_db
610608
):
609+
create_connection_without_db(
610+
Connection(
611+
conn_id="aws_conn_id",
612+
conn_type="aws",
613+
schema="database",
614+
port=5439,
615+
host="cluster.id.region.redshift.amazonaws.com",
616+
extra={},
617+
)
618+
)
619+
create_connection_without_db(
620+
Connection(
621+
conn_id="redshift_conn_id",
622+
conn_type="redshift",
623+
schema="database",
624+
port=5439,
625+
host="cluster.id.region.redshift.amazonaws.com",
626+
extra={},
627+
)
628+
)
611629
access_key = "aws_access_key_id"
612630
secret_key = "aws_secret_access_key"
613631
mock_session.return_value = Session(access_key, secret_key)
614632
mock_session.return_value.access_key = access_key
615633
mock_session.return_value.secret_key = secret_key
616634
mock_session.return_value.token = None
617635

618-
mock_connection.return_value = mock.MagicMock(
619-
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
620-
)
621636
mock_facets = {
622637
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
623638
"documentation": DocumentationDatasetFacet(description="mock_description"),
@@ -671,24 +686,38 @@ def test_get_openlineage_facets_on_complete_default(
671686
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
672687
}
673688

674-
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
675-
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
676689
@mock.patch("boto3.session.Session")
677690
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
678691
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
679692
def test_get_openlineage_facets_on_complete_with_select_query(
680-
self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
693+
self, mock_get_facets, mock_run, mock_session, create_connection_without_db
681694
):
695+
create_connection_without_db(
696+
Connection(
697+
conn_id="redshift_conn_id",
698+
conn_type="redshift",
699+
schema="database",
700+
port=5439,
701+
host="cluster.id.region.redshift.amazonaws.com",
702+
extra={},
703+
)
704+
)
705+
create_connection_without_db(
706+
Connection(
707+
conn_id="aws_conn_id",
708+
conn_type="aws",
709+
schema="database",
710+
port=5439,
711+
host="cluster.id.region.redshift.amazonaws.com",
712+
extra={},
713+
)
714+
)
682715
access_key = "aws_access_key_id"
683716
secret_key = "aws_secret_access_key"
684717
mock_session.return_value = Session(access_key, secret_key)
685718
mock_session.return_value.access_key = access_key
686719
mock_session.return_value.secret_key = secret_key
687720
mock_session.return_value.token = None
688-
689-
mock_connection.return_value = mock.MagicMock(
690-
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
691-
)
692721
mock_facets = {
693722
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
694723
"documentation": DocumentationDatasetFacet(description="mock_description"),
@@ -835,8 +864,6 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
835864
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
836865
}
837866

838-
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
839-
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
840867
@mock.patch("boto3.session.Session")
841868
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
842869
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@@ -846,22 +873,38 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
846873
)
847874
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
848875
def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
849-
self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
876+
self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, create_connection_without_db
850877
):
851878
"""
852879
Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage.
853880
"""
881+
create_connection_without_db(
882+
Connection(
883+
conn_id="redshift_conn_id",
884+
conn_type="redshift",
885+
schema="database",
886+
port=5439,
887+
host="cluster.id.region.redshift.amazonaws.com",
888+
extra={},
889+
)
890+
)
891+
create_connection_without_db(
892+
Connection(
893+
conn_id="aws_conn_id",
894+
conn_type="aws",
895+
schema="database",
896+
port=5439,
897+
host="cluster.id.region.redshift.amazonaws.com",
898+
extra={},
899+
)
900+
)
901+
854902
access_key = "aws_access_key_id"
855903
secret_key = "aws_secret_access_key"
856904
mock_session.return_value = Session(access_key, secret_key)
857905
mock_session.return_value.access_key = access_key
858906
mock_session.return_value.secret_key = secret_key
859907
mock_session.return_value.token = None
860-
861-
mock_connection.return_value = mock.MagicMock(
862-
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
863-
)
864-
mock_hook.return_value = Connection()
865908
mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
866909
mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
867910

0 commit comments

Comments
 (0)