Skip to content

Commit 6aeb6d6

Browse files
jmelinavAndres March
authored andcommitted
Add S3 staging support to python client SDK (#706)
1 parent 83d9cfa commit 6aeb6d6

10 files changed

Lines changed: 396 additions & 112 deletions

File tree

sdk/python/feast/job.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import tempfile
21
import time
32
from datetime import datetime, timedelta
43
from typing import List
54
from urllib.parse import urlparse
65

76
import fastavro
87
import pandas as pd
9-
from google.cloud import storage
108
from google.protobuf.json_format import MessageToJson
119

1210
from feast.core.CoreService_pb2 import ListIngestionJobsRequest
@@ -23,6 +21,7 @@
2321
from feast.serving.ServingService_pb2 import Job as JobProto
2422
from feast.serving.ServingService_pb2_grpc import ServingServiceStub
2523
from feast.source import Source
24+
from feast.staging.staging_strategy import StagingStrategy
2625

2726
# Maximum no of seconds to wait until the retrieval jobs status is DONE in Feast
2827
# Currently set to the maximum query execution time limit in BigQuery
@@ -47,8 +46,7 @@ def __init__(
4746
"""
4847
self.job_proto = job_proto
4948
self.serving_stub = serving_stub
50-
# TODO: abstract away GCP depedency
51-
self.gcs_client = storage.Client(project=None)
49+
self.staging_strategy = StagingStrategy()
5250

5351
@property
5452
def id(self):
@@ -126,16 +124,7 @@ def result(self, timeout_sec: int = DEFAULT_TIMEOUT_SEC):
126124
"""
127125
uris = self.get_avro_files(timeout_sec)
128126
for file_uri in uris:
129-
if file_uri.scheme == "gs":
130-
file_obj = tempfile.TemporaryFile()
131-
self.gcs_client.download_blob_to_file(file_uri.geturl(), file_obj)
132-
elif file_uri.scheme == "file":
133-
file_obj = open(file_uri.path, "rb")
134-
else:
135-
raise Exception(
136-
f"Could not identify file URI {file_uri}. Only gs:// and file:// supported"
137-
)
138-
127+
file_obj = self.staging_strategy.execute_file_download(file_uri)
139128
file_obj.seek(0)
140129
avro_reader = fastavro.reader(file_obj)
141130

sdk/python/feast/loaders/file.py

Lines changed: 20 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313
# limitations under the License.
1414

1515
import os
16-
import re
1716
import shutil
1817
import tempfile
1918
import uuid
2019
from datetime import datetime
2120
from typing import List, Optional, Tuple, Union
22-
from urllib.parse import ParseResult, urlparse
21+
from urllib.parse import urlparse
2322

2423
import pandas as pd
25-
from google.cloud import storage
2624
from pandavro import to_avro
2725

26+
from feast.staging.staging_strategy import StagingStrategy
27+
2828

2929
def export_source_to_staging_location(
3030
source: Union[pd.DataFrame, str], staging_location_uri: str
@@ -58,6 +58,7 @@ def export_source_to_staging_location(
5858
remote staging location.
5959
"""
6060

61+
staging_strategy = StagingStrategy()
6162
uri = urlparse(staging_location_uri)
6263

6364
# Prepare Avro file to be exported to staging location
@@ -66,47 +67,34 @@ def export_source_to_staging_location(
6667
uri_path = None # type: Optional[str]
6768
if uri.scheme == "file":
6869
uri_path = uri.path
69-
7070
# Remote gs staging location provided by serving
7171
dir_path, file_name, source_path = export_dataframe_to_local(
7272
df=source, dir_path=uri_path
7373
)
74-
elif urlparse(source).scheme in ["", "file"]:
75-
# Local file provided as a source
76-
dir_path = None
77-
file_name = os.path.basename(source)
78-
source_path = os.path.abspath(
79-
os.path.join(urlparse(source).netloc, urlparse(source).path)
80-
)
81-
elif urlparse(source).scheme == "gs":
82-
# Google Cloud Storage path provided
83-
input_source_uri = urlparse(source)
84-
if "*" in source:
85-
# Wildcard path
86-
return _get_files(bucket=input_source_uri.hostname, uri=input_source_uri)
74+
elif isinstance(source, str):
75+
if urlparse(source).scheme in ["", "file"]:
76+
# Local file provided as a source
77+
dir_path = None
78+
file_name = os.path.basename(source)
79+
source_path = os.path.abspath(
80+
os.path.join(urlparse(source).netloc, urlparse(source).path)
81+
)
8782
else:
88-
return [source]
83+
# gs, s3 file provided as a source.
84+
return staging_strategy.execute_get_source_files(source)
8985
else:
9086
raise Exception(
9187
f"Only string and DataFrame types are allowed as a "
9288
f"source, {type(source)} was provided."
9389
)
9490

9591
# Push data to required staging location
96-
if uri.scheme == "gs":
97-
# Staging location is a Google Cloud Storage path
98-
upload_file_to_gcs(
99-
source_path, uri.hostname, str(uri.path).strip("/") + "/" + file_name
100-
)
101-
elif uri.scheme == "file":
102-
# Staging location is a file path
103-
# Used for end-to-end test
104-
pass
105-
else:
106-
raise Exception(
107-
f"Staging location {staging_location_uri} does not have a "
108-
f"valid URI. Only gs:// and file:// uri scheme are supported."
109-
)
92+
staging_strategy.execute_file_upload(
93+
uri.scheme,
94+
source_path,
95+
uri.hostname,
96+
str(uri.path).strip("/") + "/" + file_name,
97+
)
11098

11199
# Clean up, remove local staging file
112100
if dir_path and isinstance(source, pd.DataFrame) and len(str(dir_path)) > 4:
@@ -160,70 +148,6 @@ def export_dataframe_to_local(
160148
return dir_path, file_name, dest_path
161149

162150

163-
def upload_file_to_gcs(local_path: str, bucket: str, remote_path: str) -> None:
164-
"""
165-
Upload a file from the local file system to Google Cloud Storage (GCS).
166-
167-
Args:
168-
local_path (str):
169-
Local filesystem path of file to upload.
170-
171-
bucket (str):
172-
GCS bucket destination to upload to.
173-
174-
remote_path (str):
175-
Path within GCS bucket to upload file to, includes file name.
176-
177-
Returns:
178-
None:
179-
None
180-
"""
181-
182-
storage_client = storage.Client(project=None)
183-
bucket = storage_client.get_bucket(bucket)
184-
blob = bucket.blob(remote_path)
185-
blob.upload_from_filename(local_path)
186-
187-
188-
def _get_files(bucket: str, uri: ParseResult) -> List[str]:
189-
"""
190-
List all available files within a Google storage bucket that matches a wild
191-
card path.
192-
193-
Args:
194-
bucket (str):
195-
Google Storage bucket to reference.
196-
197-
uri (urllib.parse.ParseResult):
198-
Wild card uri path containing the "*" character.
199-
Example:
200-
* gs://feast/staging_location/*
201-
* gs://feast/staging_location/file_*.avro
202-
203-
Returns:
204-
List[str]:
205-
List of all available files matching the wildcard path.
206-
"""
207-
208-
storage_client = storage.Client(project=None)
209-
bucket = storage_client.get_bucket(bucket)
210-
path = uri.path
211-
212-
if "*" in path:
213-
regex = re.compile(path.replace("*", ".*?").strip("/"))
214-
blob_list = bucket.list_blobs(
215-
prefix=path.strip("/").split("*")[0], delimiter="/"
216-
)
217-
# File path should not be in path (file path must be longer than path)
218-
return [
219-
f"{uri.scheme}://{uri.hostname}/{file}"
220-
for file in [x.name for x in blob_list]
221-
if re.match(regex, file) and file not in path
222-
]
223-
else:
224-
raise Exception(f"{path} is not a wildcard path")
225-
226-
227151
def _get_file_name() -> str:
228152
"""
229153
Create a random file name.

sdk/python/feast/staging/__init__.py

Whitespace-only changes.
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import re
2+
from abc import ABC, ABCMeta, abstractmethod
3+
from enum import Enum
4+
from tempfile import TemporaryFile
5+
from typing import List
6+
from urllib.parse import ParseResult, urlparse
7+
8+
import boto3
9+
from google.cloud import storage
10+
11+
12+
class PROTOCOL(Enum):
13+
GS = "gs"
14+
S3 = "s3"
15+
LOCAL_FILE = "file"
16+
17+
18+
class StagingStrategy:
19+
def __init__(self):
20+
self._protocol_dict = dict()
21+
22+
def execute_file_download(self, file_uri: ParseResult) -> TemporaryFile:
23+
protocol = self._get_staging_protocol(file_uri.scheme)
24+
return protocol.download_file(file_uri)
25+
26+
def execute_get_source_files(self, source: str) -> List[str]:
27+
uri = urlparse(source)
28+
if "*" in uri.path:
29+
protocol = self._get_staging_protocol(uri.scheme)
30+
return protocol.list_files(bucket=uri.hostname, uri=uri)
31+
elif PROTOCOL(uri.scheme) in [PROTOCOL.S3, PROTOCOL.GS]:
32+
return [source]
33+
else:
34+
raise Exception(
35+
f"Could not identify file protocol {uri.scheme}. Only gs:// and file:// and s3:// supported"
36+
)
37+
38+
def execute_file_upload(
39+
self, scheme: str, local_path: str, bucket: str, remote_path: str
40+
):
41+
protocol = self._get_staging_protocol(scheme)
42+
return protocol.upload_file(local_path, bucket, remote_path)
43+
44+
def _get_staging_protocol(self, protocol):
45+
if protocol in self._protocol_dict:
46+
return self._protocol_dict[protocol]
47+
else:
48+
if PROTOCOL(protocol) == PROTOCOL.GS:
49+
self._protocol_dict[protocol] = GCSProtocol()
50+
elif PROTOCOL(protocol) == PROTOCOL.S3:
51+
self._protocol_dict[protocol] = S3Protocol()
52+
elif PROTOCOL(protocol) == PROTOCOL.LOCAL_FILE:
53+
self._protocol_dict[protocol] = LocalFSProtocol()
54+
else:
55+
raise Exception(
56+
f"Could not identify file protocol {protocol}. Only gs:// and file:// and s3:// supported"
57+
)
58+
return self._protocol_dict[protocol]
59+
60+
61+
class AbstractStagingProtocol(ABC):
62+
63+
__metaclass__ = ABCMeta
64+
65+
@abstractmethod
66+
def __init__(self):
67+
pass
68+
69+
@abstractmethod
70+
def download_file(self, uri: ParseResult) -> TemporaryFile:
71+
pass
72+
73+
@abstractmethod
74+
def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
75+
pass
76+
77+
@abstractmethod
78+
def upload_file(self, local_path: str, bucket: str, remote_path: str):
79+
pass
80+
81+
82+
class GCSProtocol(AbstractStagingProtocol):
83+
def __init__(self):
84+
self.gcs_client = storage.Client(project=None)
85+
86+
def download_file(self, uri: ParseResult) -> TemporaryFile:
87+
url = uri.geturl()
88+
file_obj = TemporaryFile()
89+
self.gcs_client.download_blob_to_file(url, file_obj)
90+
return file_obj
91+
92+
def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
93+
bucket = self.gcs_client.get_bucket(bucket)
94+
path = uri.path
95+
96+
if "*" in path:
97+
regex = re.compile(path.replace("*", ".*?").strip("/"))
98+
blob_list = bucket.list_blobs(
99+
prefix=path.strip("/").split("*")[0], delimiter="/"
100+
)
101+
# File path should not be in path (file path must be longer than path)
102+
return [
103+
f"{uri.scheme}://{uri.hostname}/{file}"
104+
for file in [x.name for x in blob_list]
105+
if re.match(regex, file) and file not in path
106+
]
107+
else:
108+
raise Exception(f"{path} is not a wildcard path")
109+
110+
def upload_file(self, local_path: str, bucket: str, remote_path: str):
111+
bucket = self.gcs_client.get_bucket(bucket)
112+
blob = bucket.blob(remote_path)
113+
blob.upload_from_filename(local_path)
114+
115+
116+
class S3Protocol(AbstractStagingProtocol):
117+
def __init__(self):
118+
self.s3_client = boto3.client("s3")
119+
120+
def download_file(self, uri: ParseResult) -> TemporaryFile:
121+
url = uri.path[1:] # removing leading / from the path
122+
bucket = uri.hostname
123+
file_obj = TemporaryFile()
124+
self.s3_client.download_fileobj(bucket, url, file_obj)
125+
return file_obj
126+
127+
def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
128+
path = uri.path
129+
130+
if "*" in path:
131+
regex = re.compile(path.replace("*", ".*?").strip("/"))
132+
blob_list = self.s3_client.list_objects(
133+
Bucket=bucket, Prefix=path.strip("/").split("*")[0], Delimiter="/"
134+
)
135+
# File path should not be in path (file path must be longer than path)
136+
return [
137+
f"{uri.scheme}://{uri.hostname}/{file}"
138+
for file in [x["Key"] for x in blob_list["Contents"]]
139+
if re.match(regex, file) and file not in path
140+
]
141+
else:
142+
raise Exception(f"{path} is not a wildcard path")
143+
144+
def upload_file(self, local_path: str, bucket: str, remote_path: str):
145+
with open(local_path, "rb") as file:
146+
self.s3_client.upload_fileobj(file, bucket, remote_path)
147+
148+
149+
class LocalFSProtocol(AbstractStagingProtocol):
150+
def __init__(self):
151+
pass
152+
153+
def download_file(self, file_uri: ParseResult) -> TemporaryFile:
154+
url = file_uri.path
155+
file_obj = open(url, "rb")
156+
return file_obj
157+
158+
def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
159+
raise NotImplementedError("list file not implemented for Local file")
160+
161+
def upload_file(self, local_path: str, bucket: str, remote_path: str):
162+
pass # For test cases

sdk/python/requirements-ci.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ pytest-timeout
1010
pytest-ordering==0.6.*
1111
pandas==0.*
1212
mock==2.0.0
13-
pandavro==1.5.*
13+
pandavro==1.5.*
14+
moto

sdk/python/requirements-dev.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,6 @@ mypy
3535
mypy-protobuf
3636
pre-commit
3737
flake8
38-
black
38+
black
39+
boto3
40+
moto

0 commit comments

Comments
 (0)