Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions infra/scripts/test-end-to-end-batch-dataflow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ if [[ ${TEST_EXIT_CODE} != 0 ]]; then
fi

cd ${ORIGINAL_DIR}
exit ${TEST_EXIT_CODE}

echo "
============================================================
Expand All @@ -243,4 +242,6 @@ while read line
do
echo $line
gcloud dataflow jobs cancel $line --region=${GCLOUD_REGION}
done < ingesting_jobs.txt
done < ingesting_jobs.txt

exit ${TEST_EXIT_CODE}
3 changes: 2 additions & 1 deletion infra/scripts/test-end-to-end-batch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ if [[ ${TEST_EXIT_CODE} != 0 ]]; then
fi

cd ${ORIGINAL_DIR}
exit ${TEST_EXIT_CODE}

echo "
============================================================
Expand All @@ -290,3 +289,5 @@ Cleaning up
"

bq rm -r -f ${GOOGLE_CLOUD_PROJECT}:${DATASET_NAME}

exit ${TEST_EXIT_CODE}
23 changes: 23 additions & 0 deletions tests/e2e/bq-batch-retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from feast.feature_set import FeatureSet
from feast.type_map import ValueType
from google.cloud import storage, bigquery
from google.cloud.storage import Blob
from google.protobuf.duration_pb2 import Duration
from pandavro import to_avro

Expand Down Expand Up @@ -155,6 +156,7 @@ def test_batch_get_batch_features_with_file(client):
client.ingest(file_fs1, features_1_df, timeout=480)

# Rename column (datetime -> event_timestamp)
features_1_df['datetime'] + pd.Timedelta(seconds=1) # adds buffer to avoid rounding errors
features_1_df = features_1_df.rename(columns={"datetime": "event_timestamp"})

to_avro(
Expand All @@ -169,6 +171,7 @@ def test_batch_get_batch_features_with_file(client):
)

output = feature_retrieval_job.to_dataframe()
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["entity_id"].to_list() == [
Expand All @@ -194,6 +197,7 @@ def test_batch_get_batch_features_with_gs_path(client, gcs_path):
client.ingest(gcs_fs1, features_1_df, timeout=360)

# Rename column (datetime -> event_timestamp)
features_1_df['datetime'] + pd.Timedelta(seconds=1) # adds buffer to avoid rounding errors
features_1_df = features_1_df.rename(columns={"datetime": "event_timestamp"})

# Output file to local
Expand All @@ -220,6 +224,8 @@ def test_batch_get_batch_features_with_gs_path(client, gcs_path):
)

output = feature_retrieval_job.to_dataframe()
clean_up_remote_files(feature_retrieval_job.get_avro_files())
blob.delete()
print(output.head())

assert output["entity_id"].to_list() == [
Expand Down Expand Up @@ -256,6 +262,7 @@ def test_batch_order_by_creation_time(client):
feature_refs=[f"{PROJECT_NAME}/feature_value3"],
)
output = feature_retrieval_job.to_dataframe()
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["feature_value3"].to_list() == ["CORRECT"] * N_ROWS
Expand Down Expand Up @@ -291,6 +298,7 @@ def test_batch_additional_columns_in_entity_table(client):
entity_rows=entity_df, feature_refs=[f"{PROJECT_NAME}/feature_value4"]
)
output = feature_retrieval_job.to_dataframe().sort_values(by=["entity_id"])
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head(10))

assert np.allclose(
Expand Down Expand Up @@ -336,6 +344,7 @@ def test_batch_point_in_time_correctness_join(client):
entity_rows=entity_df, feature_refs=[f"{PROJECT_NAME}/feature_value5"]
)
output = feature_retrieval_job.to_dataframe()
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["feature_value5"].to_list() == ["CORRECT"] * N_EXAMPLES
Expand Down Expand Up @@ -384,6 +393,7 @@ def test_batch_multiple_featureset_joins(client):
],
)
output = feature_retrieval_job.to_dataframe()
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["entity_id"].to_list() == [
Expand Down Expand Up @@ -417,6 +427,7 @@ def test_batch_no_max_age(client):
)

output = feature_retrieval_job.to_dataframe()
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["entity_id"].to_list() == output["feature_value8"].to_list()
Expand Down Expand Up @@ -499,6 +510,7 @@ def test_update_featureset_apply_featureset_and_ingest_first_subset(
)

output = feature_retrieval_job.to_dataframe().sort_values(by=["entity_id"])
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["update_feature1"].to_list() == subset_df["update_feature1"].to_list()
Expand Down Expand Up @@ -552,6 +564,7 @@ def test_update_featureset_update_featureset_and_ingest_second_subset(
)

output = feature_retrieval_job.to_dataframe().sort_values(by=["entity_id"])
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head())

assert output["update_feature1"].to_list() == subset_df["update_feature1"].to_list()
Expand Down Expand Up @@ -587,6 +600,7 @@ def test_update_featureset_retrieve_valid_fields(client, update_featureset_dataf
],
)
output = feature_retrieval_job.to_dataframe().sort_values(by=["entity_id"])
clean_up_remote_files(feature_retrieval_job.get_avro_files())
print(output.head(10))
assert (
output["update_feature1"].to_list()
Expand Down Expand Up @@ -623,3 +637,12 @@ def get_rows_ingested(

for row in rows:
return row["count"]


def clean_up_remote_files(files):
    """Delete the Google Cloud Storage objects referenced by ``files``.

    Args:
        files: Iterable of parsed URI objects. Each item must expose a
            ``scheme`` attribute and a ``geturl()`` method (presumably
            ``urllib.parse.ParseResult`` values as returned by the
            retrieval job's ``get_avro_files()`` -- confirm against caller).

    Only entries whose scheme is ``"gs"`` are deleted; every other entry
    is ignored. A single storage client is created up front and shared
    by all deletions.
    """
    gcs_client = storage.Client()
    # Lazy filter: each URI is checked and then deleted before the next
    # one is examined.
    gcs_uris = (uri for uri in files if uri.scheme == "gs")
    for uri in gcs_uris:
        Blob.from_string(uri.geturl(), client=gcs_client).delete()