Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import operator
import struct
from copy import copy
from enum import Enum
from functools import reduce
Expand Down Expand Up @@ -323,6 +324,68 @@ def convert_base64_to_bytes(obj, model_fields):
return obj


def convert_vector_to_bytes(obj, model_fields):
    """Convert list[float] vector fields to packed bytes for HashModel storage.

    Redis Hash fields can only store scalar values (strings, bytes, numbers).
    Vector fields (list[float]) need to be serialized to bytes for storage.
    This uses little-endian float32 packing, matching the format expected by
    RediSearch for vector similarity queries.

    Args:
        obj: The document to convert. Anything that is not a dict is
            returned unchanged.
        model_fields: Mapping of field name to field info. A field is
            treated as a vector field when its info object exposes a
            non-None ``vector_options`` attribute.

    Returns:
        A new dict (the input is not mutated) in which every non-empty list
        belonging to a vector field is replaced by its packed bytes. Values
        that cannot be packed are passed through unchanged.
    """
    if not isinstance(obj, dict):
        return obj

    result = {}
    for key, value in obj.items():
        # Default to pass-through; only overwrite on a successful pack.
        result[key] = value

        if key not in model_fields or not isinstance(value, list):
            continue
        # Empty vectors are left as-is, as are lists on non-vector fields.
        if not value or getattr(model_fields[key], "vector_options", None) is None:
            continue

        try:
            # Pack floats as little-endian float32 bytes.
            result[key] = struct.pack(f"<{len(value)}f", *value)
        except (struct.error, OverflowError):
            # struct.error: an item is not a number.
            # OverflowError: a value does not fit in float32 (e.g. 1e40).
            # Either way this is best-effort: keep the original list.
            pass
    return result


def convert_bytes_to_vector(obj, model_fields):
    """Convert packed bytes back to list[float] for vector fields.

    This reverses the conversion done by convert_vector_to_bytes: the value
    is interpreted as a sequence of little-endian float32 numbers.

    Args:
        obj: The document to convert. Anything that is not a dict is
            returned unchanged.
        model_fields: Mapping of field name to field info. A field is
            treated as a vector field when its info object exposes a
            non-None ``vector_options`` attribute.

    Returns:
        A new dict (the input is not mutated) in which every bytes/str value
        belonging to a vector field is replaced by the unpacked list of
        floats. If a value cannot be unpacked, the ORIGINAL value is kept
        untouched (a str stays a str).
    """
    if not isinstance(obj, dict):
        return obj

    result = {}
    for key, value in obj.items():
        # Default to pass-through; only overwrite on a successful unpack.
        result[key] = value

        if key not in model_fields or not isinstance(value, (bytes, str)):
            continue
        if getattr(model_fields[key], "vector_options", None) is None:
            continue

        try:
            # With decode_responses=True Redis may hand back a str; map it
            # back to bytes. A separate local is used so that a failure
            # below does not leak the encoded form into the result.
            raw = value.encode("latin-1") if isinstance(value, str) else value
            # Unpack little-endian float32 values (4 bytes each). A length
            # that is not a multiple of 4 makes struct.unpack raise.
            result[key] = list(struct.unpack(f"<{len(raw) // 4}f", raw))
        except (struct.error, ValueError, UnicodeEncodeError):
            # Best-effort: leave the original value in place.
            pass
    return result


class PartialModel:
"""A partial model instance that only contains certain fields.

Expand Down Expand Up @@ -2834,11 +2897,31 @@ class HashModel(RedisModel, abc.ABC):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

# Helper to check if a field has vector_options (making it a vector field).
# We check cls.__dict__ because model_fields may not be populated yet
# when __init_subclass__ runs during class creation.
def _has_vector_options(field_name: str) -> bool:
    """Report whether ``field_name`` is declared with vector_options.

    Looks at the raw class attribute in ``cls.__dict__`` first and, only if
    that does not settle it, at ``cls.model_fields`` when the class has one.
    """

    def _lookup_tables():
        # Yielded lazily so model_fields is touched only when the
        # cls.__dict__ lookup did not already give a positive answer.
        yield cls.__dict__
        if hasattr(cls, "model_fields"):
            yield cls.model_fields

    return any(
        field_name in table
        and getattr(table[field_name], "vector_options", None) is not None
        for table in _lookup_tables()
    )

if hasattr(cls, "__annotations__"):
for name, field_type in cls.__annotations__.items():
origin = get_origin(field_type)
for typ in (Set, Mapping, List):
if isinstance(origin, type) and issubclass(origin, typ):
# Vector fields are allowed to be lists (list[float])
if _has_vector_options(name):
continue
raise RedisModelError(
f"HashModels cannot index set, list, "
f"or mapping fields. Field: {name}"
Expand All @@ -2860,6 +2943,9 @@ def __init_subclass__(cls, **kwargs):
if origin:
for typ in (Set, Mapping, List):
if issubclass(origin, typ):
# Vector fields are allowed to be lists (list[float])
if getattr(field, "vector_options", None) is not None:
continue
raise RedisModelError(
f"HashModels cannot index set, list, "
f"or mapping fields. Field: {name}"
Expand Down Expand Up @@ -2944,6 +3030,8 @@ async def save(
# Get model data and apply conversions in the correct order
document = self.model_dump()
document = convert_datetime_to_timestamp(document)
# Convert vector fields (list[float]) to bytes before base64 encoding
document = convert_vector_to_bytes(document, self.__class__.model_fields)
document = convert_bytes_to_base64(document)

# Then apply jsonable encoding for other types
Expand Down Expand Up @@ -3046,6 +3134,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
document = convert_timestamp_to_datetime(document, cls.model_fields)
# Convert base64 strings back to bytes for bytes fields
document = convert_base64_to_bytes(document, cls.model_fields)
# Convert bytes back to list[float] for vector fields
document = convert_bytes_to_vector(document, cls.model_fields)
result = cls.model_validate(document)
except TypeError as e:
log.warning(
Expand All @@ -3059,6 +3149,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
document = convert_timestamp_to_datetime(document, cls.model_fields)
# Convert base64 strings back to bytes for bytes fields
document = convert_base64_to_bytes(document, cls.model_fields)
# Convert bytes back to list[float] for vector fields
document = convert_bytes_to_vector(document, cls.model_fields)
result = cls.model_validate(document)
return result

Expand Down
34 changes: 34 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
QueryNotSupportedError,
RedisModel,
RedisModelError,
VectorFieldOptions,
)
from aredis_om.model.model import ExpressionProxy

Expand Down Expand Up @@ -1569,3 +1570,36 @@ class Meta:
assert retrieved.pk == 42
assert retrieved.x == 42
assert retrieved.name == "test"


@py_test_mark_asyncio
async def test_hashmodel_vector_field_with_list(key_prefix, redis):
    """A HashModel may declare a list[float] field when it has vector_options.

    Regression test for GitHub issue #544: HashModel rejected list fields
    even when they were vector fields that require list[float] type.
    """

    # Defining the class is itself part of the test: it must not raise.
    class VectorDocument(HashModel, index=True):
        name: str
        embedding: list[float] = Field(
            default=[],
            vector_options=VectorFieldOptions.flat(
                type=VectorFieldOptions.TYPE.FLOAT32,
                dimension=4,
                distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
            ),
        )

        class Meta:
            global_key_prefix = key_prefix
            database = redis

    await Migrator().run()

    # Persist one document carrying a 4-dimensional vector.
    saved = VectorDocument(name="test", embedding=[0.1, 0.2, 0.3, 0.4])
    await saved.save()

    # Load it back by primary key and check the scalar field.
    loaded = await VectorDocument.get(saved.pk)
    assert loaded.name == "test"
Loading