diff --git a/core/src/main/java/feast/core/config/TrainingConfig.java b/core/src/main/java/feast/core/config/TrainingConfig.java index 6fdfe31b90c..9bcff052f1b 100644 --- a/core/src/main/java/feast/core/config/TrainingConfig.java +++ b/core/src/main/java/feast/core/config/TrainingConfig.java @@ -1,7 +1,5 @@ package feast.core.config; -import com.google.cloud.bigquery.BigQuery; -import com.google.cloud.bigquery.BigQueryOptions; import com.google.common.base.Charsets; import com.google.common.io.CharStreams; import com.hubspot.jinjava.Jinjava; @@ -9,6 +7,7 @@ import feast.core.dao.FeatureInfoRepository; import feast.core.training.BigQueryDatasetTemplater; import feast.core.training.BigQueryTraningDatasetCreator; +import feast.core.util.RandomUuidProvider; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -18,9 +17,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; -/** - * Configuration related to training API - */ +/** Configuration related to training API */ @Configuration public class TrainingConfig { @@ -37,10 +34,9 @@ public BigQueryDatasetTemplater getBigQueryTrainingDatasetTemplater( @Bean public BigQueryTraningDatasetCreator getBigQueryTrainingDatasetCreator( BigQueryDatasetTemplater templater, - StorageSpecs storageSpecs, @Value("${feast.core.projectId}") String projectId, @Value("${feast.core.datasetPrefix}") String datasetPrefix) { - BigQuery bigquery = BigQueryOptions.newBuilder().setProjectId(projectId).build().getService(); - return new BigQueryTraningDatasetCreator(templater, projectId, datasetPrefix); + return new BigQueryTraningDatasetCreator( + templater, projectId, datasetPrefix, new RandomUuidProvider()); } } diff --git a/core/src/main/java/feast/core/grpc/DatasetServiceImpl.java b/core/src/main/java/feast/core/grpc/DatasetServiceImpl.java index 20081cc809d..a4211de9726 100644 --- a/core/src/main/java/feast/core/grpc/DatasetServiceImpl.java +++ b/core/src/main/java/feast/core/grpc/DatasetServiceImpl.java @@ -66,7 +66,8 @@ public void createDataset( request.getStartDate(), request.getEndDate(), request.getLimit(), - request.getNamePrefix()); + request.getNamePrefix(), + request.getFiltersMap()); CreateDatasetResponse response = CreateDatasetResponse.newBuilder().setDatasetInfo(datasetInfo).build(); diff --git a/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java b/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java index d5f0019f26d..7edf453b5a1 100644 --- a/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java +++ b/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java @@ -23,18 +23,18 @@ import feast.core.model.FeatureInfo; import feast.core.storage.BigQueryStorageManager; import feast.specs.StorageSpecProto.StorageSpec; +import feast.types.ValueProto.ValueType.Enum; import java.time.Instant; import java.time.ZoneId; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import java.util.stream.Collectors; -import lombok.Getter; - public class BigQueryDatasetTemplater { @@ -45,7 +45,9 @@ public class BigQueryDatasetTemplater { private final DateTimeFormatter formatter; public BigQueryDatasetTemplater( - Jinjava jinjava, String templateString, StorageSpec storageSpec, + Jinjava jinjava, + String templateString, + StorageSpec storageSpec, FeatureInfoRepository featureInfoRepository) { this.storageSpec = storageSpec; this.featureInfoRepository = featureInfoRepository; @@ -65,35 +67,84 @@ protected StorageSpec getStorageSpec() { * @param startDate start date * @param endDate end date * @param limit limit + * @param filters additional WHERE clause * @return SQL query for creating training table. */ - String createQuery(FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) { + String createQuery( + FeatureSet featureSet, + Timestamp startDate, + Timestamp endDate, + long limit, + Map filters) { List featureIds = featureSet.getFeatureIdsList(); - List featureInfos = featureInfoRepository.findAllById(featureIds); - String tableId = featureInfos.size() > 0 ? getBqTableId(featureInfos.get(0)) : ""; - Features features = new Features(featureInfos, tableId); + List featureInfos = getFeatureInfosOrThrow(featureIds); + + // split filter based on ValueType of the feature + Map tmpFilter = new HashMap<>(filters); + Map numberFilters = new HashMap<>(); + Map stringFilters = new HashMap<>(); + if (filters.containsKey("job_id")) { + stringFilters.put("job_id", tmpFilter.get("job_id")); + tmpFilter.remove("job_id"); + } + + List featureFilterInfos = getFeatureInfosOrThrow(new ArrayList<>(tmpFilter.keySet())); + Map featureInfoMap = new HashMap<>(); + for (FeatureInfo featureInfo: featureFilterInfos) { + featureInfoMap.put(featureInfo.getId(), featureInfo); + } + + + for (Map.Entry filter : tmpFilter.entrySet()) { + FeatureInfo featureInfo = featureInfoMap.get(filter.getKey()); + if (isMappableToString(featureInfo.getValueType())) { + stringFilters.put(featureInfo.getName(), filter.getValue()); + } else { + numberFilters.put(featureInfo.getName(), filter.getValue()); + } + } + List featureNames = getFeatureNames(featureInfos); + String tableId = getBqTableId(featureInfos.get(0)); + String startDateStr = formatDateString(startDate); + String endDateStr = formatDateString(endDate); + String limitStr = (limit != 0) ? String.valueOf(limit) : null; + return renderTemplate(tableId, featureNames, startDateStr, endDateStr, limitStr, + numberFilters, stringFilters); + } + + private boolean isMappableToString(Enum valueType) { + return valueType.equals(Enum.STRING); + } + + private List getFeatureNames(List featureInfos) { + return featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList()); + } + + private List getFeatureInfosOrThrow(List featureIds) { + List featureInfos = featureInfoRepository.findAllById(featureIds); if (featureInfos.size() < featureIds.size()) { Set foundFeatureIds = featureInfos.stream().map(FeatureInfo::getId).collect(Collectors.toSet()); featureIds.removeAll(foundFeatureIds); throw new NoSuchElementException("features not found: " + featureIds); } - - String startDateStr = formatDateString(startDate); - String endDateStr = formatDateString(endDate); - String limitStr = (limit != 0) ? String.valueOf(limit) : null; - return renderTemplate(features, startDateStr, endDateStr, limitStr); + return featureInfos; } private String renderTemplate( - Features features, String startDateStr, String endDateStr, String limitStr) { + String tableId, List features, String startDateStr, String endDateStr, String limitStr, + Map numberFilters, + Map stringFilters) { Map context = new HashMap<>(); - context.put("feature_set", features); + context.put("table_id", tableId); + context.put("features", features); context.put("start_date", startDateStr); context.put("end_date", endDateStr); context.put("limit", limitStr); + context.put("number_filters", numberFilters); + context.put("string_filters", stringFilters); return jinjava.render(template, context); } @@ -117,16 +168,4 @@ private String formatDateString(Timestamp timestamp) { Instant instant = Instant.ofEpochSecond(timestamp.getSeconds()).truncatedTo(ChronoUnit.DAYS); return formatter.format(instant); } - - @Getter - static final class Features { - - final List columns; - final String tableId; - - Features(List featureInfos, String tableId) { - columns = featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList()); - this.tableId = tableId; - } - } } diff --git a/core/src/main/java/feast/core/training/BigQueryTraningDatasetCreator.java b/core/src/main/java/feast/core/training/BigQueryTraningDatasetCreator.java index 5414364d40b..4584aa7987b 100644 --- a/core/src/main/java/feast/core/training/BigQueryTraningDatasetCreator.java +++ b/core/src/main/java/feast/core/training/BigQueryTraningDatasetCreator.java @@ -30,15 +30,10 @@ import feast.core.DatasetServiceProto.DatasetInfo; import feast.core.DatasetServiceProto.FeatureSet; import feast.core.exception.TrainingDatasetCreationException; -import java.math.BigInteger; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; +import feast.core.util.UuidProvider; import java.time.Instant; import java.time.ZoneId; import java.time.format.DateTimeFormatter; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Map; import lombok.extern.slf4j.Slf4j; @@ -49,13 +44,19 @@ public class BigQueryTraningDatasetCreator { private final DateTimeFormatter formatter; private final String projectId; private final String datasetPrefix; + private final UuidProvider uuidProvider; private transient BigQuery bigQuery; public BigQueryTraningDatasetCreator( BigQueryDatasetTemplater templater, String projectId, - String datasetPrefix) { - this(templater, projectId, datasetPrefix, + String datasetPrefix, + UuidProvider uuidProvider) { + this( + templater, + projectId, + datasetPrefix, + uuidProvider, BigQueryOptions.newBuilder().setProjectId(projectId).build().getService()); } @@ -63,12 +64,14 @@ public BigQueryTraningDatasetCreator( BigQueryDatasetTemplater templater, String projectId, String datasetPrefix, + UuidProvider uuidProvider, BigQuery bigQuery) { this.templater = templater; this.formatter = DateTimeFormatter.ofPattern("yyyyMMdd").withZone(ZoneId.of("UTC")); this.projectId = projectId; this.datasetPrefix = datasetPrefix; this.bigQuery = bigQuery; + this.uuidProvider = uuidProvider; } /** @@ -80,6 +83,7 @@ public BigQueryTraningDatasetCreator( * @param endDate end date of the training dataset (inclusive) * @param limit maximum number of row should be created. * @param namePrefix prefix for dataset name + * @param filters additional where clause * @return dataset info associated with the created training dataset */ public DatasetInfo createDataset( @@ -87,11 +91,11 @@ public DatasetInfo createDataset( Timestamp startDate, Timestamp endDate, long limit, - String namePrefix) { + String namePrefix, + Map filters) { try { - String query = templater.createQuery(featureSet, startDate, endDate, limit); - String tableName = - createBqTableName(datasetPrefix, featureSet, startDate, endDate, namePrefix); + String query = templater.createQuery(featureSet, startDate, endDate, limit, filters); + String tableName = createBqTableName(datasetPrefix, featureSet, namePrefix); String tableDescription = createBqTableDescription(featureSet, startDate, endDate, query); Map options = templater.getStorageSpec().getOptionsMap(); @@ -124,47 +128,22 @@ public DatasetInfo createDataset( throw new TrainingDatasetCreationException("Failed creating training dataset", e); } catch (InterruptedException e) { log.error("Training dataset creation was interrupted", e); - throw new TrainingDatasetCreationException("Training dataset creation was interrupted", - e); + throw new TrainingDatasetCreationException("Training dataset creation was interrupted", e); } } - private String createBqTableName( - String datasetPrefix, - FeatureSet featureSet, - Timestamp startDate, - Timestamp endDate, - String namePrefix) { - - List features = new ArrayList(featureSet.getFeatureIdsList()); - Collections.sort(features); + private String createBqTableName(String datasetPrefix, FeatureSet featureSet, String namePrefix) { - String datasetId = String.format("%s_%s_%s", features, startDate, endDate); - StringBuilder hashText; - - // create hash from datasetId - try { - MessageDigest md = MessageDigest.getInstance("SHA-1"); - byte[] messageDigest = md.digest(datasetId.getBytes()); - BigInteger no = new BigInteger(1, messageDigest); - hashText = new StringBuilder(no.toString(16)); - while (hashText.length() < 32) { - hashText.insert(0, "0"); - } - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e); - } + String suffix = uuidProvider.getUuid(); if (!Strings.isNullOrEmpty(namePrefix)) { // only alphanumeric and underscore are allowed namePrefix = namePrefix.replaceAll("[^a-zA-Z0-9_]", "_"); return String.format( - "%s_%s_%s_%s", datasetPrefix, featureSet.getEntityName(), namePrefix, - hashText.toString()); + "%s_%s_%s_%s", datasetPrefix, featureSet.getEntityName(), namePrefix, suffix); } - return String.format( - "%s_%s_%s", datasetPrefix, featureSet.getEntityName(), hashText.toString()); + return String.format("%s_%s_%s", datasetPrefix, featureSet.getEntityName(), suffix); } private String createBqTableDescription( diff --git a/core/src/main/java/feast/core/util/RandomUuidProvider.java b/core/src/main/java/feast/core/util/RandomUuidProvider.java new file mode 100644 index 00000000000..67e155d5775 --- /dev/null +++ b/core/src/main/java/feast/core/util/RandomUuidProvider.java @@ -0,0 +1,10 @@ +package feast.core.util; + +import java.util.UUID; + +public class RandomUuidProvider implements UuidProvider { + @Override + public String getUuid() { + return UUID.randomUUID().toString().replace("-",""); + } +} diff --git a/core/src/main/java/feast/core/util/UuidProvider.java b/core/src/main/java/feast/core/util/UuidProvider.java new file mode 100644 index 00000000000..c537560d1c6 --- /dev/null +++ b/core/src/main/java/feast/core/util/UuidProvider.java @@ -0,0 +1,5 @@ +package feast.core.util; + +public interface UuidProvider { + String getUuid(); +} diff --git a/core/src/main/resources/templates/bq_training.tmpl b/core/src/main/resources/templates/bq_training.tmpl index df7c301ae2b..a92666e91bd 100644 --- a/core/src/main/resources/templates/bq_training.tmpl +++ b/core/src/main/resources/templates/bq_training.tmpl @@ -1,10 +1,12 @@ SELECT id, - event_timestamp{%- if feature_set.columns | length > 0 %},{%- endif %} - {{ feature_set.columns | join(',') }} + event_timestamp{%- if features | length > 0 %},{%- endif %} + {{ features | join(',') }} FROM - `{{ feature_set.tableId }}` + `{{ table_id }}` WHERE event_timestamp >= TIMESTAMP("{{ start_date }}") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("{{ end_date }}", INTERVAL 1 DAY)) +{%- for key, val in number_filters.items() %} AND {{ key }} = {{ val }} {%- endfor %} +{%- for key, val in string_filters.items() %} AND {{ key }} = "{{ val }}" {%- endfor %} {% if limit is not none -%} LIMIT {{ limit }} {%- endif %} \ No newline at end of file diff --git a/core/src/test/java/feast/core/grpc/DatasetServiceImplTest.java b/core/src/test/java/feast/core/grpc/DatasetServiceImplTest.java index 91baf76f219..88bea40d9e8 100644 --- a/core/src/test/java/feast/core/grpc/DatasetServiceImplTest.java +++ b/core/src/test/java/feast/core/grpc/DatasetServiceImplTest.java @@ -4,6 +4,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -21,6 +22,9 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.testing.GrpcCleanupRule; import java.text.ParseException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -78,7 +82,8 @@ public void shouldCallcreateDatasetWithCorrectRequest() { any(Timestamp.class), any(Timestamp.class), anyLong(), - anyString())) + anyString(), + anyMap())) .thenReturn(datasetInfo); long limit = 9999; @@ -95,7 +100,44 @@ public void shouldCallcreateDatasetWithCorrectRequest() { client.createDataset(request); verify(trainingDatasetCreator) - .createDataset(validFeatureSet, validStartDate, validEndDate, limit, namePrefix); + .createDataset(validFeatureSet, validStartDate, validEndDate, limit, namePrefix, Collections + .emptyMap()); + } + + @SuppressWarnings("ResultOfMethodCallIgnored") + @Test + public void shouldCallcreateDatasetWithCorrectRequestWithFilters() { + DatasetInfo datasetInfo = + DatasetInfo.newBuilder().setName("mydataset").setTableUrl("project.dataset.table").build(); + when(trainingDatasetCreator.createDataset( + any(FeatureSet.class), + any(Timestamp.class), + any(Timestamp.class), + anyLong(), + anyString(), + anyMap())) + .thenReturn(datasetInfo); + + long limit = 9999; + String namePrefix = "mydataset"; + Map filters = new HashMap<>(); + filters.put("key1", "value1"); + filters.put("key2", "value2"); + CreateDatasetRequest request = + CreateDatasetRequest.newBuilder() + .setFeatureSet(validFeatureSet) + .setStartDate(validStartDate) + .setEndDate(validEndDate) + .setLimit(limit) + .setNamePrefix(namePrefix) + .putAllFilters(filters) + .build(); + + client.createDataset(request); + + + verify(trainingDatasetCreator) + .createDataset(validFeatureSet, validStartDate, validEndDate, limit, namePrefix, filters); } @Test @@ -107,7 +149,8 @@ public void shouldPropagateCreatedDatasetInfo() { any(Timestamp.class), any(Timestamp.class), anyLong(), - anyString())) + anyString(), + anyMap())) .thenReturn(datasetInfo); long limit = 9999; diff --git a/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java b/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java index 5a072eaf90c..90d73834f15 100644 --- a/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java +++ b/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java @@ -33,16 +33,18 @@ import feast.core.model.EntityInfo; import feast.core.model.FeatureInfo; import feast.core.storage.BigQueryStorageManager; -import feast.core.training.BigQueryDatasetTemplater.Features; import feast.specs.EntitySpecProto.EntitySpec; import feast.specs.FeatureSpecProto.FeatureSpec; import feast.specs.StorageSpecProto.StorageSpec; +import feast.types.ValueProto.ValueType; +import feast.types.ValueProto.ValueType.Enum; import java.io.InputStream; import java.io.InputStreamReader; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; @@ -60,26 +62,26 @@ public class BigQueryDatasetTemplaterTest { private BigQueryDatasetTemplater templater; private BasicFormatterImpl formatter = new BasicFormatterImpl(); - @Mock - private FeatureInfoRepository featureInfoRespository; + @Mock private FeatureInfoRepository featureInfoRespository; private String sqlTemplate; @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); - StorageSpec storageSpec = StorageSpec.newBuilder() - .setId("BIGQUERY1") - .setType(BigQueryStorageManager.TYPE) - .putOptions("project", "project") - .putOptions("dataset", "dataset") - .build(); + StorageSpec storageSpec = + StorageSpec.newBuilder() + .setId("BIGQUERY1") + .setType(BigQueryStorageManager.TYPE) + .putOptions("project", "project") + .putOptions("dataset", "dataset") + .build(); Jinjava jinjava = new Jinjava(); Resource resource = new ClassPathResource("templates/bq_training.tmpl"); InputStream resourceInputStream = resource.getInputStream(); sqlTemplate = CharStreams.toString(new InputStreamReader(resourceInputStream, Charsets.UTF_8)); - templater = new BigQueryDatasetTemplater(jinjava, sqlTemplate, storageSpec, - featureInfoRespository); + templater = + new BigQueryDatasetTemplater(jinjava, sqlTemplate, storageSpec, featureInfoRespository); } @Test(expected = NoSuchElementException.class) @@ -89,21 +91,23 @@ public void shouldThrowNoSuchElementExceptionIfFeatureNotFound() { .setEntityName("myentity") .addAllFeatureIds(Arrays.asList("myentity.feature1", "myentity.feature2")) .build(); - templater.createQuery(fs, Timestamps.fromSeconds(0), Timestamps.fromSeconds(1), 0); + templater.createQuery( + fs, Timestamps.fromSeconds(0), Timestamps.fromSeconds(1), 0, Collections.emptyMap()); } @Test public void shouldPassCorrectArgumentToTemplateEngine() { - StorageSpec storageSpec = StorageSpec.newBuilder() - .setId("BIGQUERY1") - .setType(BigQueryStorageManager.TYPE) - .putOptions("project", "project") - .putOptions("dataset", "dataset") - .build(); + StorageSpec storageSpec = + StorageSpec.newBuilder() + .setId("BIGQUERY1") + .setType(BigQueryStorageManager.TYPE) + .putOptions("project", "project") + .putOptions("dataset", "dataset") + .build(); Jinjava jinjava = mock(Jinjava.class); - templater = new BigQueryDatasetTemplater(jinjava, sqlTemplate, storageSpec, - featureInfoRespository); + templater = + new BigQueryDatasetTemplater(jinjava, sqlTemplate, storageSpec, featureInfoRespository); Timestamp startDate = Timestamps.fromSeconds(Instant.parse("2018-01-01T00:00:00.00Z").getEpochSecond()); @@ -114,7 +118,7 @@ public void shouldPassCorrectArgumentToTemplateEngine() { String featureName = "feature1"; when(featureInfoRespository.findAllById(any(List.class))) - .thenReturn(Collections.singletonList(createFeatureInfo(featureId, featureName))); + .thenReturn(Collections.singletonList(createFeatureInfo(featureId, featureName, Enum.INT64))); FeatureSet fs = FeatureSet.newBuilder() @@ -122,7 +126,7 @@ public void shouldPassCorrectArgumentToTemplateEngine() { .addAllFeatureIds(Arrays.asList(featureId)) .build(); - templater.createQuery(fs, startDate, endDate, limit); + templater.createQuery(fs, startDate, endDate, limit, Collections.emptyMap()); ArgumentCaptor templateArg = ArgumentCaptor.forClass(String.class); ArgumentCaptor> contextArg = ArgumentCaptor.forClass(Map.class); @@ -136,9 +140,8 @@ public void shouldPassCorrectArgumentToTemplateEngine() { assertThat(actualContext.get("end_date"), equalTo("2019-01-01")); assertThat(actualContext.get("limit"), equalTo(String.valueOf(limit))); - Features features = (Features) actualContext.get("feature_set"); - assertThat(features.getColumns().size(), equalTo(1)); - assertThat(features.getColumns().get(0), equalTo(featureName)); + List features = (List) actualContext.get("features"); + assertThat(features.get(0), equalTo(featureName)); } @Test @@ -148,12 +151,12 @@ public void shouldRenderCorrectQuery1() throws Exception { String featureId2 = "myentity.feature2"; String featureName2 = "feature2"; - FeatureInfo featureInfo1 = createFeatureInfo(featureId1, featureName1); - FeatureInfo featureInfo2 = createFeatureInfo(featureId2, featureName2); + FeatureInfo featureInfo1 = createFeatureInfo(featureId1, featureName1, Enum.INT64); + FeatureInfo featureInfo2 = createFeatureInfo(featureId2, featureName2, Enum.INT64); String featureId3 = "myentity.feature3"; String featureName3 = "feature3"; - FeatureInfo featureInfo3 = createFeatureInfo(featureId3, featureName3); + FeatureInfo featureInfo3 = createFeatureInfo(featureId3, featureName3, Enum.INT64); when(featureInfoRespository.findAllById(any(List.class))) .thenReturn(Arrays.asList(featureInfo1, featureInfo2, featureInfo3)); @@ -169,7 +172,7 @@ public void shouldRenderCorrectQuery1() throws Exception { Timestamps.fromSeconds(Instant.parse("2018-01-30T12:11:11.00Z").getEpochSecond()); int limit = 100; - String query = templater.createQuery(fs, startDate, endDate, limit); + String query = templater.createQuery(fs, startDate, endDate, limit, Collections.emptyMap()); checkExpectedQuery(query, "expQuery1.sql"); } @@ -182,7 +185,7 @@ public void shouldRenderCorrectQuery2() throws Exception { String featureId = "myentity.feature1"; String featureName = "feature1"; - featureInfos.add(createFeatureInfo(featureId, featureName)); + featureInfos.add(createFeatureInfo(featureId, featureName, Enum.INT64)); featureIds.add(featureId); when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos); @@ -194,11 +197,149 @@ public void shouldRenderCorrectQuery2() throws Exception { FeatureSet featureSet = FeatureSet.newBuilder().setEntityName("myentity").addAllFeatureIds(featureIds).build(); - String query = templater.createQuery(featureSet, startDate, endDate, 1000); + String query = + templater.createQuery(featureSet, startDate, endDate, 1000, Collections.emptyMap()); checkExpectedQuery(query, "expQuery2.sql"); } + @Test + public void shouldRenderCorrectQueryWithNumberFilter() throws Exception { + List featureInfos = new ArrayList<>(); + List featureIds = new ArrayList<>(); + + String featureId = "myentity.feature1"; + String featureId2 = "myentity.feature2"; + String featureName = "feature1"; + String featureName2 = "feature2"; + + featureInfos.add(createFeatureInfo(featureId, featureName, Enum.INT64)); + featureInfos.add(createFeatureInfo(featureId2, featureName2, Enum.INT64)); + featureIds.add(featureId); + featureIds.add(featureId2); + + when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos); + + Timestamp startDate = + Timestamps.fromSeconds(Instant.parse("2018-01-02T00:00:00.00Z").getEpochSecond()); + Timestamp endDate = + Timestamps.fromSeconds(Instant.parse("2018-01-30T12:11:11.00Z").getEpochSecond()); + FeatureSet featureSet = + FeatureSet.newBuilder().setEntityName("myentity").addAllFeatureIds(featureIds).build(); + + Map filter = new HashMap<>(); + filter.put("myentity.feature1", "10"); + + String query = + templater.createQuery(featureSet, startDate, endDate, 1000, filter); + + checkExpectedQuery(query, "expQueryWithNumberFilter.sql"); + } + + @Test + public void shouldRenderCorrectQueryWithStringFilter() throws Exception { + List featureInfos = new ArrayList<>(); + List featureIds = new ArrayList<>(); + + String featureId = "myentity.feature1"; + String featureId2 = "myentity.feature2"; + String featureName = "feature1"; + String featureName2 = "feature2"; + + featureInfos.add(createFeatureInfo(featureId, featureName, Enum.STRING)); + featureInfos.add(createFeatureInfo(featureId2, featureName2, Enum.STRING)); + featureIds.add(featureId); + featureIds.add(featureId2); + + when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos); + + Timestamp startDate = + Timestamps.fromSeconds(Instant.parse("2018-01-02T00:00:00.00Z").getEpochSecond()); + Timestamp endDate = + Timestamps.fromSeconds(Instant.parse("2018-01-30T12:11:11.00Z").getEpochSecond()); + FeatureSet featureSet = + FeatureSet.newBuilder().setEntityName("myentity").addAllFeatureIds(featureIds).build(); + + Map filter = new HashMap<>(); + filter.put("myentity.feature1", "10"); + + String query = + templater.createQuery(featureSet, startDate, endDate, 1000, filter); + + checkExpectedQuery(query, "expQueryWithStringFilter.sql"); + } + + + @Test + public void shouldRenderCorrectQueryWithStringAndNumberFilter() throws Exception { + List featureInfos = new ArrayList<>(); + List featureIds = new ArrayList<>(); + + String featureId = "myentity.feature1"; + String featureId2 = "myentity.feature2"; + String featureName = "feature1"; + String featureName2 = "feature2"; + + featureInfos.add(createFeatureInfo(featureId, featureName, Enum.INT64)); + featureInfos.add(createFeatureInfo(featureId2, featureName2, Enum.STRING)); + featureIds.add(featureId); + featureIds.add(featureId2); + + when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos); + + Timestamp startDate = + Timestamps.fromSeconds(Instant.parse("2018-01-02T00:00:00.00Z").getEpochSecond()); + Timestamp endDate = + Timestamps.fromSeconds(Instant.parse("2018-01-30T12:11:11.00Z").getEpochSecond()); + FeatureSet featureSet = + FeatureSet.newBuilder().setEntityName("myentity").addAllFeatureIds(featureIds).build(); + + Map filter = new HashMap<>(); + filter.put("myentity.feature1", "10"); + filter.put("myentity.feature2", "HELLO"); + + String query = + templater.createQuery(featureSet, startDate, endDate, 1000, filter); + + checkExpectedQuery(query, "expQueryWithNumberAndStringFilter.sql"); + } + + + @Test + public void shouldRenderCorrectQueryWithJobIdFilter() throws Exception { + List featureInfos = new ArrayList<>(); + List featureIds = new ArrayList<>(); + + String featureId = "myentity.feature1"; + String featureId2 = "myentity.feature2"; + String featureName = "feature1"; + String featureName2 = "feature2"; + + featureInfos.add(createFeatureInfo(featureId, featureName, Enum.INT64)); + featureInfos.add(createFeatureInfo(featureId2, featureName2, Enum.STRING)); + featureIds.add(featureId); + featureIds.add(featureId2); + + when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos); + + Timestamp startDate = + Timestamps.fromSeconds(Instant.parse("2018-01-02T00:00:00.00Z").getEpochSecond()); + Timestamp endDate = + Timestamps.fromSeconds(Instant.parse("2018-01-30T12:11:11.00Z").getEpochSecond()); + FeatureSet featureSet = + FeatureSet.newBuilder().setEntityName("myentity").addAllFeatureIds(featureIds).build(); + + Map filter = new HashMap<>(); + filter.put("myentity.feature1", "10"); + filter.put("myentity.feature2", "HELLO"); + filter.put("job_id", "1234567890"); + + String query = + templater.createQuery(featureSet, startDate, endDate, 1000, filter); + + checkExpectedQuery(query, "expQueryWithJobIdFilter.sql"); + } + private void checkExpectedQuery(String query, String pathToExpQuery) throws Exception { String tmpl = CharStreams.toString( @@ -212,11 +353,12 @@ private void checkExpectedQuery(String query, String pathToExpQuery) throws Exce assertThat(query, equalTo(expQuery)); } - private FeatureInfo createFeatureInfo(String featureId, String featureName) { + private FeatureInfo createFeatureInfo(String featureId, String featureName, ValueType.Enum valueType) { FeatureSpec fs = FeatureSpec.newBuilder() .setId(featureId) .setName(featureName) + .setValueType(valueType) .build(); EntitySpec entitySpec = EntitySpec.newBuilder().setName(featureId.split("\\.")[0]).build(); diff --git a/core/src/test/java/feast/core/training/BigQueryTraningDatasetCreatorTest.java b/core/src/test/java/feast/core/training/BigQueryTraningDatasetCreatorTest.java index fff75eefae2..5cbaa1017d2 100644 --- a/core/src/test/java/feast/core/training/BigQueryTraningDatasetCreatorTest.java +++ b/core/src/test/java/feast/core/training/BigQueryTraningDatasetCreatorTest.java @@ -22,9 +22,11 @@ import feast.core.DatasetServiceProto.DatasetInfo; import feast.core.DatasetServiceProto.FeatureSet; import feast.core.storage.BigQueryStorageManager; +import feast.core.util.UuidProvider; import feast.specs.StorageSpecProto.StorageSpec; import java.time.Instant; import java.util.Arrays; +import java.util.Collections; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -34,6 +36,7 @@ import static org.junit.Assert.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -55,6 +58,8 @@ public class BigQueryTraningDatasetCreatorTest { private BigQueryDatasetTemplater templater; @Mock private BigQuery bq; + @Mock + private UuidProvider uuidProvider; @Before public void setUp() { @@ -65,10 +70,11 @@ public void setUp() { .putOptions("project", "project") .putOptions("dataset", "dataset") .build()); - creator = new BigQueryTraningDatasetCreator(templater, projectId, datasetPrefix, bq); + creator = new BigQueryTraningDatasetCreator(templater, projectId, datasetPrefix, uuidProvider, bq); + when(uuidProvider.getUuid()).thenReturn("b0009f0f7df634ddc130571319e0deb9742eb1da"); when(templater.createQuery( - any(FeatureSet.class), any(Timestamp.class), any(Timestamp.class), anyLong())) + any(FeatureSet.class), any(Timestamp.class), any(Timestamp.class), anyLong(), anyMap())) .thenReturn("SELECT * FROM `project.dataset.table`"); } @@ -89,7 +95,8 @@ public void shouldCreateCorrectDatasetIfPrefixNotSpecified() { long limit = 999; String namePrefix = ""; - DatasetInfo dsInfo = creator.createDataset(featureSet, startDate, endDate, limit, namePrefix); + DatasetInfo dsInfo = creator.createDataset(featureSet, startDate, endDate, limit, namePrefix, Collections + .emptyMap()); assertThat( dsInfo.getName(), equalTo("feast_myentity_b0009f0f7df634ddc130571319e0deb9742eb1da")); assertThat( @@ -117,7 +124,7 @@ public void shouldCreateCorrectDatasetIfPrefixIsSpecified() { long limit = 999; String namePrefix = "mydataset"; - DatasetInfo dsInfo = creator.createDataset(featureSet, startDate, endDate, limit, namePrefix); + DatasetInfo dsInfo = creator.createDataset(featureSet, startDate, endDate, limit, namePrefix, Collections.emptyMap()); assertThat( dsInfo.getTableUrl(), equalTo( @@ -146,8 +153,8 @@ public void shouldPassArgumentToTemplater() { long limit = 999; String namePrefix = ""; - creator.createDataset(featureSet, startDate, endDate, limit, namePrefix); + creator.createDataset(featureSet, startDate, endDate, limit, namePrefix, Collections.emptyMap()); - verify(templater).createQuery(featureSet, startDate, endDate, limit); + verify(templater).createQuery(featureSet, startDate, endDate, limit, Collections.emptyMap()); } } diff --git a/core/src/test/resources/sql/expQueryWithJobIdFilter.sql b/core/src/test/resources/sql/expQueryWithJobIdFilter.sql new file mode 100644 index 00000000000..96c076b6a6d --- /dev/null +++ b/core/src/test/resources/sql/expQueryWithJobIdFilter.sql @@ -0,0 +1,8 @@ +SELECT + id, + event_timestamp, + feature1,feature2 +FROM + `project.dataset.myentity` +WHERE event_timestamp >= TIMESTAMP("2018-01-02") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) AND feature1 = 10 AND feature2 = "HELLO" AND job_id = "1234567890" +LIMIT 1000 \ No newline at end of file diff --git a/core/src/test/resources/sql/expQueryWithNumberAndStringFilter.sql b/core/src/test/resources/sql/expQueryWithNumberAndStringFilter.sql new file mode 100644 index 00000000000..8769fe44af7 --- /dev/null +++ b/core/src/test/resources/sql/expQueryWithNumberAndStringFilter.sql @@ -0,0 +1,8 @@ +SELECT + id, + event_timestamp, + feature1,feature2 +FROM + `project.dataset.myentity` +WHERE event_timestamp >= TIMESTAMP("2018-01-02") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) AND feature1 = 10 AND feature2 = "HELLO" +LIMIT 1000 \ No newline at end of file diff --git a/core/src/test/resources/sql/expQueryWithNumberFilter.sql b/core/src/test/resources/sql/expQueryWithNumberFilter.sql new file mode 100644 index 00000000000..6b199b7c4b8 --- /dev/null +++ b/core/src/test/resources/sql/expQueryWithNumberFilter.sql @@ -0,0 +1,8 @@ +SELECT + id, + event_timestamp, + feature1,feature2 +FROM + `project.dataset.myentity` +WHERE event_timestamp >= TIMESTAMP("2018-01-02") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) AND feature1 = 10 +LIMIT 1000 \ No newline at end of file diff --git a/core/src/test/resources/sql/expQueryWithStringFilter.sql b/core/src/test/resources/sql/expQueryWithStringFilter.sql new file mode 100644 index 00000000000..8c0a3805041 --- /dev/null +++ b/core/src/test/resources/sql/expQueryWithStringFilter.sql @@ -0,0 +1,8 @@ +SELECT + id, + event_timestamp, + feature1,feature2 +FROM + `project.dataset.myentity` +WHERE event_timestamp >= TIMESTAMP("2018-01-02") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) AND feature1 = "10" +LIMIT 1000 \ No newline at end of file diff --git a/protos/feast/core/DatasetService.proto b/protos/feast/core/DatasetService.proto index fa17444566f..c094b41b374 100644 --- a/protos/feast/core/DatasetService.proto +++ b/protos/feast/core/DatasetService.proto @@ -42,6 +42,8 @@ message DatasetServiceTypes { int64 limit = 4; // (optional) prefix for dataset name string namePrefix = 5; + // (optional) additional WHERE clause, all filter entry will be combined with logic AND + map filters = 6; } message CreateDatasetResponse { diff --git a/sdk/python/feast/core/DatasetService_pb2.py b/sdk/python/feast/core/DatasetService_pb2.py index 25171bad57d..0c94041a4b1 100644 --- a/sdk/python/feast/core/DatasetService_pb2.py +++ b/sdk/python/feast/core/DatasetService_pb2.py @@ -20,13 +20,50 @@ package='feast.core', syntax='proto3', serialized_options=_b('\n\nfeast.coreB\023DatasetServiceProtoZ5github.com/gojek/feast/protos/generated/go/feast/core'), - serialized_pb=_b('\n\x1f\x66\x65\x61st/core/DatasetService.proto\x12\nfeast.core\x1a\x1fgoogle/protobuf/timestamp.proto\"\xa0\x02\n\x13\x44\x61tasetServiceTypes\x1a\xc1\x01\n\x14\x43reateDatasetRequest\x12*\n\nfeatureSet\x18\x01 \x01(\x0b\x32\x16.feast.core.FeatureSet\x12-\n\tstartDate\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12+\n\x07\x65ndDate\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\r\n\x05limit\x18\x04 \x01(\x03\x12\x12\n\nnamePrefix\x18\x05 \x01(\t\x1a\x45\n\x15\x43reateDatasetResponse\x12,\n\x0b\x64\x61tasetInfo\x18\x01 \x01(\x0b\x32\x17.feast.core.DatasetInfo\"4\n\nFeatureSet\x12\x12\n\nentityName\x18\x01 \x01(\t\x12\x12\n\nfeatureIds\x18\x02 \x03(\t\"-\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08tableUrl\x18\x02 \x01(\t2\x90\x01\n\x0e\x44\x61tasetService\x12~\n\rCreateDataset\x12\x34.feast.core.DatasetServiceTypes.CreateDatasetRequest\x1a\x35.feast.core.DatasetServiceTypes.CreateDatasetResponse\"\x00\x42X\n\nfeast.coreB\x13\x44\x61tasetServiceProtoZ5github.com/gojek/feast/protos/generated/go/feast/coreb\x06proto3') + serialized_pb=_b('\n\x1f\x66\x65\x61st/core/DatasetService.proto\x12\nfeast.core\x1a\x1fgoogle/protobuf/timestamp.proto\"\xa4\x03\n\x13\x44\x61tasetServiceTypes\x1a\xc5\x02\n\x14\x43reateDatasetRequest\x12*\n\nfeatureSet\x18\x01 \x01(\x0b\x32\x16.feast.core.FeatureSet\x12-\n\tstartDate\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12+\n\x07\x65ndDate\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\r\n\x05limit\x18\x04 \x01(\x03\x12\x12\n\nnamePrefix\x18\x05 \x01(\t\x12R\n\x07\x66ilters\x18\x06 \x03(\x0b\x32\x41.feast.core.DatasetServiceTypes.CreateDatasetRequest.FiltersEntry\x1a.\n\x0c\x46iltersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x45\n\x15\x43reateDatasetResponse\x12,\n\x0b\x64\x61tasetInfo\x18\x01 \x01(\x0b\x32\x17.feast.core.DatasetInfo\"4\n\nFeatureSet\x12\x12\n\nentityName\x18\x01 \x01(\t\x12\x12\n\nfeatureIds\x18\x02 \x03(\t\"-\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08tableUrl\x18\x02 \x01(\t2\x90\x01\n\x0e\x44\x61tasetService\x12~\n\rCreateDataset\x12\x34.feast.core.DatasetServiceTypes.CreateDatasetRequest\x1a\x35.feast.core.DatasetServiceTypes.CreateDatasetResponse\"\x00\x42X\n\nfeast.coreB\x13\x44\x61tasetServiceProtoZ5github.com/gojek/feast/protos/generated/go/feast/coreb\x06proto3') , dependencies=[google_dot_protobuf_dot_timestamp__pb2.DESCRIPTOR,]) +_DATASETSERVICETYPES_CREATEDATASETREQUEST_FILTERSENTRY = _descriptor.Descriptor( + name='FiltersEntry', + full_name='feast.core.DatasetServiceTypes.CreateDatasetRequest.FiltersEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='feast.core.DatasetServiceTypes.CreateDatasetRequest.FiltersEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='feast.core.DatasetServiceTypes.CreateDatasetRequest.FiltersEntry.value', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=_b('8\001'), + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=384, + serialized_end=430, +) + _DATASETSERVICETYPES_CREATEDATASETREQUEST = _descriptor.Descriptor( name='CreateDatasetRequest', full_name='feast.core.DatasetServiceTypes.CreateDatasetRequest', @@ -69,10 +106,17 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='filters', full_name='feast.core.DatasetServiceTypes.CreateDatasetRequest.filters', index=5, + number=6, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], - nested_types=[], + nested_types=[_DATASETSERVICETYPES_CREATEDATASETREQUEST_FILTERSENTRY, ], enum_types=[ ], serialized_options=None, @@ -82,7 +126,7 @@ oneofs=[ ], serialized_start=105, - serialized_end=298, + serialized_end=430, ) _DATASETSERVICETYPES_CREATEDATASETRESPONSE = _descriptor.Descriptor( @@ -111,8 +155,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=300, - serialized_end=369, + serialized_start=432, + serialized_end=501, ) _DATASETSERVICETYPES = _descriptor.Descriptor( @@ -135,7 +179,7 @@ oneofs=[ ], serialized_start=81, - serialized_end=369, + serialized_end=501, ) @@ -172,8 +216,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=371, - serialized_end=423, + serialized_start=503, + serialized_end=555, ) @@ -210,13 +254,15 @@ extension_ranges=[], oneofs=[ ], - serialized_start=425, - serialized_end=470, + serialized_start=557, + serialized_end=602, ) +_DATASETSERVICETYPES_CREATEDATASETREQUEST_FILTERSENTRY.containing_type = _DATASETSERVICETYPES_CREATEDATASETREQUEST _DATASETSERVICETYPES_CREATEDATASETREQUEST.fields_by_name['featureSet'].message_type = _FEATURESET _DATASETSERVICETYPES_CREATEDATASETREQUEST.fields_by_name['startDate'].message_type = google_dot_protobuf_dot_timestamp__pb2._TIMESTAMP _DATASETSERVICETYPES_CREATEDATASETREQUEST.fields_by_name['endDate'].message_type = google_dot_protobuf_dot_timestamp__pb2._TIMESTAMP +_DATASETSERVICETYPES_CREATEDATASETREQUEST.fields_by_name['filters'].message_type = _DATASETSERVICETYPES_CREATEDATASETREQUEST_FILTERSENTRY _DATASETSERVICETYPES_CREATEDATASETREQUEST.containing_type = _DATASETSERVICETYPES _DATASETSERVICETYPES_CREATEDATASETRESPONSE.fields_by_name['datasetInfo'].message_type = _DATASETINFO _DATASETSERVICETYPES_CREATEDATASETRESPONSE.containing_type = _DATASETSERVICETYPES @@ -228,6 +274,13 @@ DatasetServiceTypes = _reflection.GeneratedProtocolMessageType('DatasetServiceTypes', (_message.Message,), dict( CreateDatasetRequest = _reflection.GeneratedProtocolMessageType('CreateDatasetRequest', (_message.Message,), dict( + + FiltersEntry = _reflection.GeneratedProtocolMessageType('FiltersEntry', (_message.Message,), dict( + DESCRIPTOR = _DATASETSERVICETYPES_CREATEDATASETREQUEST_FILTERSENTRY, + __module__ = 'feast.core.DatasetService_pb2' + # @@protoc_insertion_point(class_scope:feast.core.DatasetServiceTypes.CreateDatasetRequest.FiltersEntry) + )) + , DESCRIPTOR = _DATASETSERVICETYPES_CREATEDATASETREQUEST, __module__ = 'feast.core.DatasetService_pb2' # @@protoc_insertion_point(class_scope:feast.core.DatasetServiceTypes.CreateDatasetRequest) @@ -246,6 +299,7 @@ )) _sym_db.RegisterMessage(DatasetServiceTypes) _sym_db.RegisterMessage(DatasetServiceTypes.CreateDatasetRequest) +_sym_db.RegisterMessage(DatasetServiceTypes.CreateDatasetRequest.FiltersEntry) _sym_db.RegisterMessage(DatasetServiceTypes.CreateDatasetResponse) FeatureSet = _reflection.GeneratedProtocolMessageType('FeatureSet', (_message.Message,), dict( @@ -264,6 +318,7 @@ DESCRIPTOR._options = None +_DATASETSERVICETYPES_CREATEDATASETREQUEST_FILTERSENTRY._options = None _DATASETSERVICE = _descriptor.ServiceDescriptor( name='DatasetService', @@ -271,8 +326,8 @@ file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=473, - serialized_end=617, + serialized_start=605, + serialized_end=749, methods=[ _descriptor.MethodDescriptor( name='CreateDataset', diff --git a/sdk/python/feast/sdk/client.py b/sdk/python/feast/sdk/client.py index 1d1fe0a93d6..79db6575ff8 100644 --- a/sdk/python/feast/sdk/client.py +++ b/sdk/python/feast/sdk/client.py @@ -169,7 +169,8 @@ def run( return response.jobId def create_dataset( - self, feature_set, start_date, end_date, limit=None, name_prefix=None + self, feature_set, start_date, end_date, limit=None, + name_prefix=None, filters=None ): """ Create training dataset for a feature set. The training dataset @@ -187,11 +188,21 @@ def create_dataset( limit (int, optional): (default: None) maximum number of row returned name_prefix (str, optional): (default: None) name prefix. + filters (dict, optional): (default: None) conditional clause + that will be used to filter dataset. Keys of filters could be + feature id or job_id. :return: feast.resources.feature_set.DatasetInfo: DatasetInfo containing - the information of training dataset + the information of training dataset. """ - self._check_create_dataset_args(feature_set, start_date, end_date, limit) + self._check_create_dataset_args(feature_set, start_date, end_date, + limit, filters) + + conv_filters = None + if filters is not None: + conv_filters = {} + for k, v in filters.items(): + conv_filters[str(k)] = str(v) req = DatasetServiceTypes.CreateDatasetRequest( featureSet=feature_set.proto, @@ -199,6 +210,7 @@ def create_dataset( endDate=_timestamp_from_datetime(_parse_date(end_date)), limit=limit, namePrefix=name_prefix, + filters=conv_filters ) if self.verbose: print( @@ -421,7 +433,8 @@ def _apply_storage(self, storage): ) return response.storageId - def _check_create_dataset_args(self, feature_set, start_date, end_date, limit): + def _check_create_dataset_args(self, feature_set, start_date, end_date, + limit, filters): if len(feature_set.features) < 1: raise ValueError("feature set is empty") @@ -433,6 +446,9 @@ def _check_create_dataset_args(self, feature_set, start_date, end_date, limit): if limit is not None and limit < 1: raise ValueError("limit is not a positive integer") + if filters is not None and not isinstance(filters, dict): + raise ValueError("filters is not dictionary") + def _parse_date(date): try: diff --git a/sdk/python/tests/sdk/test_client.py b/sdk/python/tests/sdk/test_client.py index 54f4d2790c3..9e76cde301d 100644 --- a/sdk/python/tests/sdk/test_client.py +++ b/sdk/python/tests/sdk/test_client.py @@ -177,6 +177,10 @@ def test_create_dataset_invalid_args(self, client): ValueError, match="limit is not a positive integer"): client.create_dataset(feature_set, "2018-12-01", "2018-12-02", -1) + with pytest.raises(ValueError, match="filters is not dictionary"): + client.create_dataset(feature_set, "2018-12-01", "2018-12-02", + 10, filters="filter") + def test_create_dataset(self, client, mocker): entity_name = "myentity" feature_ids = ["myentity.feature1", "myentity.feature2"] @@ -207,6 +211,40 @@ def test_create_dataset(self, client, mocker): limit=None, namePrefix=None)) + def test_create_dataset_with_filters(self, client, mocker): + entity_name = "myentity" + feature_ids = ["myentity.feature1", "myentity.feature2"] + fs = FeatureSet(entity_name, feature_ids) + start_date = "2018-01-02" + end_date = "2018-12-31" + + ds_pb = DatasetInfo_pb( + name="dataset_name", tableUrl="project.dataset.table") + + mock_trn_stub = training.DatasetServiceStub(grpc.insecure_channel("")) + mocker.patch.object( + mock_trn_stub, + "CreateDataset", + return_value=DatasetServiceTypes.CreateDatasetResponse( + datasetInfo=ds_pb)) + client._dataset_service_stub = mock_trn_stub + + job_filter = {"job_id": 12345} + ds = client.create_dataset(fs, start_date, end_date, + filters=job_filter) + + assert "dataset_name" == ds.name + assert "project.dataset.table" == ds.full_table_id + mock_trn_stub.CreateDataset.assert_called_once_with( + DatasetServiceTypes.CreateDatasetRequest( + featureSet=fs.proto, + startDate=_timestamp_from_datetime(_parse_date(start_date)), + endDate=_timestamp_from_datetime(_parse_date(end_date)), + limit=None, + namePrefix=None, + filters={"job_id": "12345"})) + + def test_create_dataset_with_limit(self, client, mocker): entity_name = "myentity" feature_ids = ["myentity.feature1", "myentity.feature2"]