Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def _check_inputs(self) -> None:
'The destination Google Cloud Storage path must end with a slash "/" or be empty.'
)

def _destination_uris_for_s3_keys(self, s3_keys: list[str]) -> list[str]:
    """Map each transferred S3 object key to its destination ``gs://`` URI.

    The destination bucket is extracted from ``self.dest_gcs``; the object
    name is derived via ``s3_to_gcs_object`` so prefix-rewriting rules apply.
    """
    bucket_name, _ = _parse_gcs_url(self.dest_gcs)
    destination_uris: list[str] = []
    for key in s3_keys:
        dest_object = self.s3_to_gcs_object(s3_object=key)
        destination_uris.append(f"gs://{bucket_name}/{dest_object}")
    return destination_uris

def _get_files(self, context: Context, gcs_hook: GCSHook) -> list[str]:
# use the super method to list all the files in an S3 bucket/key
s3_objects = super().execute(context)
Expand All @@ -189,7 +194,7 @@ def _get_files(self, context: Context, gcs_hook: GCSHook) -> list[str]:

return s3_objects

def execute(self, context: Context):
def execute(self, context: Context) -> list[str]:
self._check_inputs()
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -206,7 +211,7 @@ def execute(self, context: Context):
else:
self.transfer_files(s3_objects, gcs_hook, s3_hook)

return s3_objects
return self._destination_uris_for_s3_keys(s3_objects)

def exclude_existing_objects(self, s3_objects: list[str], gcs_hook: GCSHook) -> list[str]:
"""Excludes from the list objects that already exist in GCS bucket."""
Expand Down Expand Up @@ -339,14 +344,17 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> list[str]
"""
Handle the trigger callback when transfer jobs complete.

Returns the list of copied file paths when available (deferrable mode with
Returns the list of destination ``gs://`` URIs for copied objects when available (deferrable mode with
files passed via trigger), so subsequent tasks can consume them via XCom.
Returns None when event does not contain files (e.g. legacy triggers).
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info("%s completed with response %s ", self.task_id, event["message"])
return event.get("files")
files = event.get("files")
if files is None:
return None
return self._destination_uris_for_s3_keys(files)

def get_transfer_hook(self):
return CloudDataTransferServiceHook(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import CloudDataTransferServiceHook
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator
from airflow.utils.timezone import utcnow

Expand Down Expand Up @@ -132,8 +133,8 @@ def test_execute(self, gcs_mock_hook, s3_mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)

# we expect MOCK_FILES to be uploaded
assert sorted(MOCK_FILES) == sorted(uploaded_files)
expected_uris = [f"gs://{GCS_BUCKET}/{GCS_PREFIX}{f}" for f in MOCK_FILES]
assert sorted(expected_uris) == sorted(uploaded_files)

@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
Expand Down Expand Up @@ -264,7 +265,9 @@ def test_execute_apply_gcs_prefix(
impersonation_chain=IMPERSONATION_CHAIN,
)

assert sorted([s3_prefix + s3_object]) == sorted(uploaded_files)
dest_bucket, _ = _parse_gcs_url(gcs_destination)
expected_uri = f"gs://{dest_bucket}/{gcs_object}"
assert uploaded_files == [expected_uri]

@pytest.mark.parametrize(
("s3_prefix", "gcs_destination", "apply_gcs_prefix", "expected_input", "expected_output"),
Expand Down Expand Up @@ -499,17 +502,23 @@ def test_execute_complete_success(self, mock_log):
@mock.patch(
    "airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock
)
def test_execute_complete_success_returns_copied_files(self, mock_log):
    """Deferrable mode returns destination GCS URIs for XCom."""
    # Keys that the trigger reports as transferred.
    transferred_keys = [MOCK_FILE_1, MOCK_FILE_2]
    trigger_event = {
        "status": "success",
        "message": "Transfer completed",
        "files": transferred_keys,
    }
    op = S3ToGCSOperator(
        task_id=TASK_ID,
        bucket=S3_BUCKET,
        prefix=S3_PREFIX,
        dest_gcs=GCS_PATH_PREFIX,
    )

    returned = op.execute_complete(context={}, event=trigger_event)

    # Each key must come back as a fully-qualified destination gs:// URI.
    assert returned == [f"gs://{GCS_BUCKET}/{GCS_PREFIX}{key}" for key in transferred_keys]

@mock.patch(
"airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock
Expand Down
Loading