Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ private static <T> Object serialize(T o, ErrorPath path) {
|| o instanceof Blob
|| o instanceof DocumentReference
|| o instanceof FieldValue
|| o instanceof Value) {
|| o instanceof Value
|| o instanceof VectorValue) {
return o;
} else if (o instanceof Instant) {
Instant instant = (Instant) o;
Expand Down Expand Up @@ -243,6 +244,8 @@ private static <T> T deserializeToClass(Object o, Class<T> clazz, DeserializeCon
return (T) convertBlob(o, context);
} else if (GeoPoint.class.isAssignableFrom(clazz)) {
return (T) convertGeoPoint(o, context);
} else if (VectorValue.class.isAssignableFrom(clazz)) {
return (T) convertVectorValue(o, context);
} else if (DocumentReference.class.isAssignableFrom(clazz)) {
return (T) convertDocumentReference(o, context);
} else if (clazz.isArray()) {
Expand Down Expand Up @@ -596,6 +599,16 @@ private static GeoPoint convertGeoPoint(Object o, DeserializeContext context) {
}
}

private static VectorValue convertVectorValue(Object o, DeserializeContext context) {
if (o instanceof VectorValue) {
return (VectorValue) o;
} else {
throw deserializeError(
context.errorPath,
"Failed to convert value of type " + o.getClass().getName() + " to VectorValue");
}
}

private static DocumentReference convertDocumentReference(Object o, DeserializeContext context) {
if (o instanceof DocumentReference) {
return (DocumentReference) o;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,18 @@ public GeoPoint getGeoPoint(@Nonnull String field) {
return (GeoPoint) get(field);
}

/**
* Returns the value of the field as a VectorValue.
*
* @param field The path to the field.
* @throws RuntimeException if the value is not a VectorValue.
* @return The value of the field.
*/
@Nullable
public VectorValue getVectorValue(@Nonnull String field) {
return (VectorValue) get(field);
}

/**
* Gets the reference to the document.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ public static FieldValue arrayRemove(@Nonnull Object... elements) {
return new ArrayRemoveFieldValue(Arrays.asList(elements));
}

/**
* Creates a new {@link VectorValue} constructed with a copy of the given array of doubles.
*
* @param values Create a {@link VectorValue} instance with a copy of this array of doubles.
* @return A new {@link VectorValue} constructed with a copy of the given array of doubles.
*/
@Nonnull
public static VectorValue vector(@Nonnull double[] values) {
return new VectorValue(values);
}

/** Whether this FieldTransform should be included in the document mask. */
abstract boolean includeInDocumentMask();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

abstract class MapType {
static final String RESERVED_MAP_KEY = "__type__";
static final String RESERVED_MAP_KEY_VECTOR_VALUE = "__vector__";
static final String VECTOR_MAP_VECTORS_KEY = "value";
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.google.cloud.firestore;

import com.google.firestore.v1.MapValue;
import com.google.firestore.v1.Value;
import com.google.firestore.v1.Value.ValueTypeCase;
import com.google.protobuf.ByteString;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
Expand All @@ -40,6 +42,7 @@ enum TypeOrder implements Comparable<TypeOrder> {
REF,
GEO_POINT,
ARRAY,
VECTOR,
OBJECT;

static TypeOrder fromValue(Value value) {
Expand All @@ -65,13 +68,24 @@ static TypeOrder fromValue(Value value) {
case ARRAY_VALUE:
return ARRAY;
case MAP_VALUE:
return OBJECT;
return fromMapValue(value.getMapValue());
default:
throw new IllegalArgumentException("Could not detect value type for " + value);
}
}
}

static TypeOrder fromMapValue(MapValue mapValue) {
switch (UserDataConverter.detectMapRepresentation(mapValue)) {
case VECTOR_VALUE:
return TypeOrder.VECTOR;
case UNKNOWN:
case NONE:
default:
return TypeOrder.OBJECT;
}
}

static final Order INSTANCE = new Order();

private Order() {}
Expand Down Expand Up @@ -113,6 +127,8 @@ public int compare(@Nonnull Value left, @Nonnull Value right) {
left.getArrayValue().getValuesList(), right.getArrayValue().getValuesList());
case OBJECT:
return compareObjects(left, right);
case VECTOR:
return compareVectors(left, right);
default:
throw new IllegalArgumentException("Cannot compare " + leftType);
}
Expand Down Expand Up @@ -209,6 +225,30 @@ private int compareObjects(Value left, Value right) {
return Boolean.compare(leftIterator.hasNext(), rightIterator.hasNext());
}

private int compareVectors(Value left, Value right) {
// The vector is a map, but only vector value is compared.
Value leftValueField =
left.getMapValue().getFieldsOrDefault(MapType.VECTOR_MAP_VECTORS_KEY, null);
Value rightValueField =
right.getMapValue().getFieldsOrDefault(MapType.VECTOR_MAP_VECTORS_KEY, null);

List<Value> leftArray =
(leftValueField != null)
? leftValueField.getArrayValue().getValuesList()
: Collections.emptyList();
List<Value> rightArray =
(rightValueField != null)
? rightValueField.getArrayValue().getValuesList()
: Collections.emptyList();

Integer lengthCompare = Long.compare(leftArray.size(), rightArray.size());
if (lengthCompare != 0) {
return lengthCompare;
}

return compareArrays(leftArray, rightArray);
}

private int compareNumbers(Value left, Value right) {
if (left.getValueTypeCase() == ValueTypeCase.DOUBLE_VALUE) {
if (right.getValueTypeCase() == ValueTypeCase.DOUBLE_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.primitives.Doubles;
import com.google.firestore.v1.ArrayValue;
import com.google.firestore.v1.MapValue;
import com.google.firestore.v1.Value;
Expand All @@ -32,10 +33,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import javax.annotation.Nullable;

/** Converts user input into the Firestore Value representation. */
class UserDataConverter {
private static final Logger LOGGER = Logger.getLogger(UserDataConverter.class.getName());

/** Controls the behavior for field deletes. */
interface EncodingOptions {
Expand Down Expand Up @@ -183,12 +186,34 @@ static Value encodeValue(
// send the map.
return null;
}
} else if (sanitizedObject instanceof VectorValue) {
VectorValue vectorValue = (VectorValue) sanitizedObject;
return Value.newBuilder().setMapValue(vectorValue.toProto()).build();
}

throw FirestoreException.forInvalidArgument(
"Cannot convert %s to Firestore Value", sanitizedObject);
}

static MapValue encodeVector(double[] rawVector) {
MapValue.Builder res = MapValue.newBuilder();

res.putFields(
MapType.RESERVED_MAP_KEY,
encodeValue(
FieldPath.fromDotSeparatedString(MapType.RESERVED_MAP_KEY),
MapType.RESERVED_MAP_KEY_VECTOR_VALUE,
ARGUMENT));
res.putFields(
MapType.VECTOR_MAP_VECTORS_KEY,
encodeValue(
FieldPath.fromDotSeparatedString(MapType.RESERVED_MAP_KEY_VECTOR_VALUE),
Doubles.asList(rawVector),
ARGUMENT));

return res.build();
}

static Object decodeValue(FirestoreRpcContext<?> rpcContext, Value v) {
Value.ValueTypeCase typeCase = v.getValueTypeCase();
switch (typeCase) {
Expand Down Expand Up @@ -220,18 +245,72 @@ static Object decodeValue(FirestoreRpcContext<?> rpcContext, Value v) {
}
return list;
case MAP_VALUE:
return decodeMap(rpcContext, v.getMapValue());
default:
throw FirestoreException.forInvalidArgument(
String.format("Unknown Value Type: %s", typeCase));
}
}

static Object decodeMap(FirestoreRpcContext<?> rpcContext, MapValue mapValue) {
MapRepresentation mapRepresentation = detectMapRepresentation(mapValue);
Map<String, Value> inputMap = mapValue.getFieldsMap();
switch (mapRepresentation) {
case UNKNOWN:
LOGGER.warning(
"Parsing unknown map type as generic map. This map type may be supported in a newer SDK version.");
case NONE:
Map<String, Object> outputMap = new HashMap<>();
Map<String, Value> inputMap = v.getMapValue().getFieldsMap();
for (Map.Entry<String, Value> entry : inputMap.entrySet()) {
outputMap.put(entry.getKey(), decodeValue(rpcContext, entry.getValue()));
}
return outputMap;
case VECTOR_VALUE:
double[] values =
inputMap.get(MapType.VECTOR_MAP_VECTORS_KEY).getArrayValue().getValuesList().stream()
.mapToDouble(val -> val.getDoubleValue())
.toArray();
return new VectorValue(values);
default:
throw FirestoreException.forInvalidArgument(
String.format("Unknown Value Type: %s", typeCase));
String.format("Unsupported MapRepresentation: %s", mapRepresentation));
}
}

/** Indicates the data type represented by a MapValue. */
enum MapRepresentation {
/** The MapValue represents an unknown data type. */
UNKNOWN,
/** The MapValue does not represent any special data type. */
NONE,
/** The MapValue represents a VectorValue. */
VECTOR_VALUE
}

static MapRepresentation detectMapRepresentation(MapValue mapValue) {
Map<String, Value> fields = mapValue.getFieldsMap();
if (!fields.containsKey(MapType.RESERVED_MAP_KEY)) {
return MapRepresentation.NONE;
}

Value typeValue = fields.get(MapType.RESERVED_MAP_KEY);
if (typeValue.getValueTypeCase() != Value.ValueTypeCase.STRING_VALUE) {
LOGGER.warning(
"Unable to parse __type__ field of map. Unsupported value type: "
+ typeValue.getValueTypeCase().toString());
return MapRepresentation.UNKNOWN;
}

String typeString = typeValue.getStringValue();

if (typeString.equals(MapType.RESERVED_MAP_KEY_VECTOR_VALUE)) {
return MapRepresentation.VECTOR_VALUE;
}

LOGGER.warning("Unsupported __type__ value for map: " + typeString);
return MapRepresentation.UNKNOWN;
}

static Object decodeGoogleProtobufValue(com.google.protobuf.Value v) {
switch (v.getKindCase()) {
case NULL_VALUE:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import com.google.firestore.v1.MapValue;
import java.io.Serializable;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
* Represents a vector in Firestore documents. Create an instance with {@link FieldValue#vector}.
*/
public final class VectorValue implements Serializable {
private final double[] values;

VectorValue(@Nullable double[] values) {
if (values == null) this.values = new double[] {};
else this.values = values.clone();
}

/**
* Returns a representation of the vector as an array of doubles.
*
* @return A representation of the vector as an array of doubles
*/
@Nonnull
public double[] toArray() {
return this.values.clone();
}

/**
* Returns true if this VectorValue is equal to the provided object.
*
* @param obj The object to compare against.
* @return Whether this VectorValue is equal to the provided object.
*/
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
VectorValue otherArray = (VectorValue) obj;
return Arrays.equals(this.values, otherArray.values);
}

@Override
public int hashCode() {
return Arrays.hashCode(values);
}

MapValue toProto() {
return UserDataConverter.encodeVector(this.values);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,8 @@ public void extractFieldMaskFromMerge() throws Exception {
"second.objectValue.foo",
"second.timestampValue",
"second.trueValue",
"second.model.foo");
"second.model.foo",
"second.vectorValue");

CommitRequest expectedCommit = commit(set(nestedUpdate, updateMask));
assertCommitEquals(expectedCommit, commitCapture.getValue());
Expand Down
Loading