27 changes: 24 additions & 3 deletions farm/file_utils.py
@@ -96,20 +96,41 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag
 
 
-def download_from_s3(s3_url: str, cache_dir: str = None):
+def download_from_s3(s3_url: str, cache_dir: str = None, access_key: str = None,
+                     secret_access_key: str = None, region_name: str = None):
     """
     Download a "folder" from s3 to local. Skip already existing files. Useful for downloading all files of one model.
+
+    The default and recommended authentication follows boto3's standard chain of checking ENV variables,
+    the .aws/credentials file etc. (see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).
+    However, there's also the option to pass `access_key`, `secret_access_key` and `region_name` directly,
+    as this is needed in some enterprise environments with local s3 deployments.
 
     :param s3_url: URL of the "folder" in s3 (e.g. s3://mybucket/my_modelname)
     :param cache_dir: Optional local directory where the files shall be stored.
                       If not supplied, we'll use a subfolder in torch's cache dir (~/.cache/torch/farm)
+    :param access_key: Optional S3 Access Key
+    :param secret_access_key: Optional S3 Secret Access Key
+    :param region_name: Optional Region Name
     :return: local path of the folder
     """
 
-    logger.info(f"Downloading from {s3_url}")
     if cache_dir is None:
         cache_dir = FARM_CACHE
-    s3_resource = boto3.resource('s3')
 
+    logger.info(f"Downloading from {s3_url} to {cache_dir}")
+
+    if access_key or secret_access_key:
+        assert secret_access_key and access_key, "You only supplied one of secret_access_key and access_key. We need both."
+
+        session = boto3.Session(
+            aws_access_key_id=access_key,
+            aws_secret_access_key=secret_access_key,
+            region_name=region_name
+        )
+        s3_resource = session.resource('s3')
+    else:
+        s3_resource = boto3.resource('s3')
+
     bucket_name, s3_path = split_s3_path(s3_url)
     bucket = s3_resource.Bucket(bucket_name)
     objects = bucket.objects.filter(Prefix=s3_path)
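For reviewers, a minimal usage sketch of the two authentication paths this change enables. The bucket, model name, key values, and region below are hypothetical placeholders; the first call assumes boto3 can resolve credentials via its standard chain (ENV variables, ~/.aws/credentials, etc.).

```python
from farm.file_utils import download_from_s3

# Default path: no credentials passed, boto3's standard credential
# chain (ENV variables, ~/.aws/credentials, instance roles) is used.
local_path = download_from_s3("s3://mybucket/my_modelname")

# Explicit path: pass credentials directly, e.g. for an enterprise
# s3 deployment. Values here are placeholders, not real keys.
# access_key and secret_access_key must be supplied together,
# otherwise the function raises an AssertionError.
local_path = download_from_s3(
    "s3://mybucket/my_modelname",
    access_key="AKIA...",
    secret_access_key="...",
    region_name="eu-central-1",
)
```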