diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py index 9ac4627353204..89048ef7f9422 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py @@ -180,6 +180,11 @@ def _check_inputs(self) -> None: 'The destination Google Cloud Storage path must end with a slash "/" or be empty.' ) + def _destination_uris_for_s3_keys(self, s3_keys: list[str]) -> list[str]: + """Build ``gs://`` URIs for each transferred S3 object key.""" + gcs_bucket, _ = _parse_gcs_url(self.dest_gcs) + return [f"gs://{gcs_bucket}/{self.s3_to_gcs_object(s3_object=k)}" for k in s3_keys] + def _get_files(self, context: Context, gcs_hook: GCSHook) -> list[str]: # use the super method to list all the files in an S3 bucket/key s3_objects = super().execute(context) @@ -189,7 +194,7 @@ def _get_files(self, context: Context, gcs_hook: GCSHook) -> list[str]: return s3_objects - def execute(self, context: Context): + def execute(self, context: Context) -> list[str]: self._check_inputs() gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, @@ -206,7 +211,7 @@ def execute(self, context: Context): else: self.transfer_files(s3_objects, gcs_hook, s3_hook) - return s3_objects + return self._destination_uris_for_s3_keys(s3_objects) def exclude_existing_objects(self, s3_objects: list[str], gcs_hook: GCSHook) -> list[str]: """Excludes from the list objects that already exist in GCS bucket.""" @@ -339,14 +344,17 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> list[str] """ Handle the trigger callback when transfer jobs complete. - Returns the list of copied file paths when available (deferrable mode with + Returns the list of destination ``gs://`` URIs for copied objects when available (deferrable mode with files passed via trigger), so subsequent tasks can consume them via XCom. Returns None when event does not contain files (e.g. legacy triggers). """ if event["status"] == "error": raise AirflowException(event["message"]) self.log.info("%s completed with response %s ", self.task_id, event["message"]) - return event.get("files") + files = event.get("files") + if files is None: + return None + return self._destination_uris_for_s3_keys(files) def get_transfer_hook(self): return CloudDataTransferServiceHook( diff --git a/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py b/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py index 1e17364b6954e..472c356769b9b 100644 --- a/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py +++ b/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py @@ -25,6 +25,7 @@ from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import CloudDataTransferServiceHook +from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator from airflow.utils.timezone import utcnow @@ -132,8 +133,8 @@ def test_execute(self, gcs_mock_hook, s3_mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) - # we expect MOCK_FILES to be uploaded - assert sorted(MOCK_FILES) == sorted(uploaded_files) + expected_uris = [f"gs://{GCS_BUCKET}/{GCS_PREFIX}{f}" for f in MOCK_FILES] + assert sorted(expected_uris) == sorted(uploaded_files) @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") @@ -264,7 +265,9 @@ def test_execute_apply_gcs_prefix( impersonation_chain=IMPERSONATION_CHAIN, ) - assert sorted([s3_prefix + s3_object]) == sorted(uploaded_files) + dest_bucket, _ = _parse_gcs_url(gcs_destination) + expected_uri = f"gs://{dest_bucket}/{gcs_object}" + assert uploaded_files == [expected_uri] @pytest.mark.parametrize( ("s3_prefix", "gcs_destination", "apply_gcs_prefix", "expected_input", "expected_output"), @@ -499,17 +502,23 @@ def test_execute_complete_success(self, mock_log): "airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock ) def test_execute_complete_success_returns_copied_files(self, mock_log): - """Deferrable mode returns list of copied files for use in subsequent tasks via XCom.""" + """Deferrable mode returns destination GCS URIs for XCom.""" expected_files = [MOCK_FILE_1, MOCK_FILE_2] event = { "status": "success", "message": "Transfer completed", "files": expected_files, } - operator = S3ToGCSOperator(task_id=TASK_ID, bucket=S3_BUCKET) + operator = S3ToGCSOperator( + task_id=TASK_ID, + bucket=S3_BUCKET, + prefix=S3_PREFIX, + dest_gcs=GCS_PATH_PREFIX, + ) result = operator.execute_complete(context={}, event=event) - assert result == expected_files + expected_uris = [f"gs://{GCS_BUCKET}/{GCS_PREFIX}{f}" for f in expected_files] + assert result == expected_uris @mock.patch( "airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock