diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py
index b1744168..5dcb1c77 100644
--- a/aredis_om/model/model.py
+++ b/aredis_om/model/model.py
@@ -4,6 +4,7 @@
 import json
 import logging
 import operator
+import struct
 from copy import copy
 from enum import Enum
 from functools import reduce
@@ -323,6 +324,90 @@ def convert_base64_to_bytes(obj, model_fields):
         return obj
 
 
+def convert_vector_to_bytes(obj, model_fields):
+    """Convert list[float] vector fields to packed bytes for HashModel storage.
+
+    Redis Hash fields can only store scalar values (strings, bytes, numbers).
+    Vector fields (list[float]) need to be serialized to bytes for storage.
+    This uses little-endian float32 packing, matching the format expected by
+    RediSearch for vector similarity queries.
+
+    Non-dict input is returned unchanged. Only non-empty list values whose
+    field declares ``vector_options`` are packed. A vector that cannot be
+    packed is logged and left as-is instead of aborting the save.
+    """
+    if not isinstance(obj, dict):
+        return obj
+
+    result = {}
+    for key, value in obj.items():
+        result[key] = value
+        if not isinstance(value, list) or not value:
+            continue
+        field_info = model_fields.get(key)
+        if getattr(field_info, "vector_options", None) is None:
+            continue
+        try:
+            # Pack floats as little-endian float32 bytes
+            result[key] = struct.pack(f"<{len(value)}f", *value)
+        except (struct.error, OverflowError) as e:
+            # struct raises OverflowError, not struct.error, for values
+            # outside the float32 range; handle both the same way.
+            log.warning("Could not pack vector field %s: %s", key, e)
+    return result
+
+
+def _decode_vector_payload(raw, dimension):
+    """Return the packed float32 bytes held in *raw*, undoing base64 if present.
+
+    HashModel.save() passes the packed vector through convert_bytes_to_base64,
+    so the stored value may be base64 text rather than the raw float32 blob.
+    The declared dimension tells the two apart: the payload is only treated
+    as base64 when its raw length is wrong and the decoded length is right.
+    """
+    import base64
+
+    if not dimension or len(raw) == 4 * dimension:
+        return raw
+    try:
+        decoded = base64.b64decode(raw, validate=True)
+    except ValueError:
+        return raw
+    return decoded if len(decoded) == 4 * dimension else raw
+
+
+def convert_bytes_to_vector(obj, model_fields):
+    """Convert packed bytes back to list[float] for vector fields.
+
+    This reverses the conversion done by convert_vector_to_bytes. Values are
+    accepted as bytes or as str (decode_responses=True); anything that cannot
+    be unpacked is logged and returned exactly as it was received.
+    """
+    if not isinstance(obj, dict):
+        return obj
+
+    result = {}
+    for key, value in obj.items():
+        result[key] = value
+        vector_options = getattr(model_fields.get(key), "vector_options", None)
+        if vector_options is None or not isinstance(value, (bytes, str)):
+            continue
+        try:
+            # With decode_responses=True Redis hands back str; latin-1 maps
+            # code points 0-255 straight back to single bytes.
+            raw = value.encode("latin-1") if isinstance(value, str) else value
+            raw = _decode_vector_payload(
+                raw, getattr(vector_options, "dimension", None)
+            )
+            # Unpack little-endian float32 bytes
+            result[key] = list(struct.unpack(f"<{len(raw) // 4}f", raw))
+        except (struct.error, ValueError, UnicodeEncodeError) as e:
+            # Keep the value exactly as received, not the re-encoded bytes.
+            log.warning("Could not unpack vector field %s: %s", key, e)
+    return result
+
+
 class PartialModel:
     """A partial model instance that only contains certain fields.
 
@@ -2834,11 +2919,31 @@ class HashModel(RedisModel, abc.ABC):
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
 
+        # Helper to check if a field has vector_options (making it a vector field).
+        # We check cls.__dict__ because model_fields may not be populated yet
+        # when __init_subclass__ runs during class creation.
+        def _has_vector_options(field_name: str) -> bool:
+            """Check if a field has vector_options set, making it a vector field."""
+            # First check cls.__dict__ for the original FieldInfo (before Pydantic processing)
+            if field_name in cls.__dict__:
+                field = cls.__dict__[field_name]
+                if getattr(field, "vector_options", None) is not None:
+                    return True
+            # Also check model_fields in case it's populated
+            if hasattr(cls, "model_fields") and field_name in cls.model_fields:
+                field = cls.model_fields[field_name]
+                if getattr(field, "vector_options", None) is not None:
+                    return True
+            return False
+
         if hasattr(cls, "__annotations__"):
             for name, field_type in cls.__annotations__.items():
                 origin = get_origin(field_type)
                 for typ in (Set, Mapping, List):
                     if isinstance(origin, type) and issubclass(origin, typ):
+                        # Vector fields are allowed to be lists (list[float])
+                        if _has_vector_options(name):
+                            continue
                         raise RedisModelError(
                             f"HashModels cannot index set, list, "
                             f"or mapping fields. Field: {name}"
@@ -2860,6 +2965,9 @@ def __init_subclass__(cls, **kwargs):
             if origin:
                 for typ in (Set, Mapping, List):
                     if issubclass(origin, typ):
+                        # Vector fields are allowed to be lists (list[float])
+                        if getattr(field, "vector_options", None) is not None:
+                            continue
                         raise RedisModelError(
                             f"HashModels cannot index set, list, "
                             f"or mapping fields. Field: {name}"
@@ -2944,6 +3052,10 @@ async def save(
         # Get model data and apply conversions in the correct order
         document = self.model_dump()
         document = convert_datetime_to_timestamp(document)
+        # Convert vector fields (list[float]) to bytes before base64 encoding
+        # NOTE(review): the base64 step below also sees the packed vector;
+        # confirm RediSearch still gets a raw float32 blob to index.
+        document = convert_vector_to_bytes(document, self.__class__.model_fields)
         document = convert_bytes_to_base64(document)
 
         # Then apply jsonable encoding for other types
@@ -3046,6 +3158,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
             document = convert_timestamp_to_datetime(document, cls.model_fields)
             # Convert base64 strings back to bytes for bytes fields
             document = convert_base64_to_bytes(document, cls.model_fields)
+            # Convert bytes back to list[float] for vector fields
+            document = convert_bytes_to_vector(document, cls.model_fields)
             result = cls.model_validate(document)
         except TypeError as e:
             log.warning(
@@ -3059,6 +3173,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
             document = convert_timestamp_to_datetime(document, cls.model_fields)
             # Convert base64 strings back to bytes for bytes fields
             document = convert_base64_to_bytes(document, cls.model_fields)
+            # Convert bytes back to list[float] for vector fields
+            document = convert_bytes_to_vector(document, cls.model_fields)
             result = cls.model_validate(document)
         return result
 
diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py
index 52bf7fe8..21330ab8 100644
--- a/tests/test_hash_model.py
+++ b/tests/test_hash_model.py
@@ -22,6 +22,7 @@
     QueryNotSupportedError,
     RedisModel,
     RedisModelError,
+    VectorFieldOptions,
 )
 from aredis_om.model.model import ExpressionProxy
 
@@ -1569,3 +1570,40 @@ class Meta:
     assert retrieved.pk == 42
     assert retrieved.x == 42
     assert retrieved.name == "test"
+
+
+@py_test_mark_asyncio
+async def test_hashmodel_vector_field_with_list(key_prefix, redis):
+    """Test that HashModel allows list[float] fields when used with vector_options.
+
+    Regression test for GitHub issue #544: HashModel rejected list fields
+    even when they were vector fields that require list[float] type.
+    """
+    vector_options = VectorFieldOptions.flat(
+        type=VectorFieldOptions.TYPE.FLOAT32,
+        dimension=4,
+        distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
+    )
+
+    # This should NOT raise an error - vector fields are allowed to be lists
+    class VectorDocument(HashModel, index=True):
+        name: str
+        embedding: list[float] = Field(default=[], vector_options=vector_options)
+
+        class Meta:
+            global_key_prefix = key_prefix
+            database = redis
+
+    await Migrator().run()
+
+    # Create and save a document with a vector
+    doc = VectorDocument(name="test", embedding=[0.1, 0.2, 0.3, 0.4])
+    await doc.save()
+
+    # Retrieve and verify
+    retrieved = await VectorDocument.get(doc.pk)
+    assert retrieved.name == "test"
+    # The vector must survive the round trip (float32 precision only)
+    assert len(retrieved.embedding) == 4
+    for got, want in zip(retrieved.embedding, doc.embedding):
+        assert abs(got - want) < 1e-6