Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 121 additions & 1 deletion aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,113 @@ def convert_timestamp_to_datetime(obj, model_fields):
return obj


def convert_bytes_to_base64(obj):
    """Recursively replace ``bytes`` values with base64-encoded ASCII strings.

    Redis JSON and the jsonable_encoder cannot represent arbitrary binary
    data, so base64 encoding is applied to guarantee every byte value
    (0-255) can be stored and later retrieved intact. Dicts and lists are
    walked recursively; any other value is returned unchanged.
    """
    import base64

    if isinstance(obj, bytes):
        return base64.b64encode(obj).decode("ascii")
    if isinstance(obj, dict):
        return {k: convert_bytes_to_base64(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_bytes_to_base64(element) for element in obj]
    return obj


def convert_base64_to_bytes(obj, model_fields):
    """Convert base64-encoded strings back to bytes based on model field types.

    Walks *obj* (a decoded Redis document) and, for every key whose field
    annotation resolves to ``bytes`` — directly, via ``Optional[bytes]``, or
    as the item type of ``List[bytes]`` — decodes the stored base64 string
    back into raw bytes. Nested models (and lists of nested models) are
    recursed into using their own ``model_fields``. Strings that are not
    valid base64, and keys with no matching field annotation, are left
    untouched.
    """
    import base64

    def _unwrap_optional(tp):
        # Optional[T] is Union[T, None]; return the single non-None member.
        if hasattr(tp, "__origin__") and tp.__origin__ is Union:
            non_none = [
                arg for arg in getattr(tp, "__args__", ()) if arg is not type(None)  # noqa: E721
            ]
            if len(non_none) == 1:
                return non_none[0]
        return tp

    def _decode(value):
        # Best-effort base64 decode; invalid input keeps the original value.
        try:
            return base64.b64decode(value)
        except (ValueError, TypeError):
            return value

    def _list_inner_type(tp):
        # For List[T] annotations return T, otherwise None.
        # (The __origin__ of List[T] is always the builtin ``list``.)
        if hasattr(tp, "__origin__") and tp.__origin__ is list and getattr(
            tp, "__args__", ()
        ):
            return tp.__args__[0]
        return None

    def _is_model(tp):
        # True for pydantic model classes (non-empty ``model_fields``).
        try:
            return (
                isinstance(tp, type)
                and hasattr(tp, "model_fields")
                and bool(tp.model_fields)
            )
        except (TypeError, AttributeError):
            return False

    if isinstance(obj, list):
        return [convert_base64_to_bytes(item, model_fields) for item in obj]
    if not isinstance(obj, dict):
        return obj

    result = {}
    for key, value in obj.items():
        field_info = model_fields.get(key) if model_fields else None
        if field_info is None:
            # Unknown key: still recurse so nested containers are walked,
            # but with no type info no base64 decoding can occur.
            result[key] = convert_base64_to_bytes(value, {})
            continue

        field_type = _unwrap_optional(getattr(field_info, "annotation", None))

        if field_type is bytes and isinstance(value, str):
            result[key] = _decode(value)
        elif isinstance(value, dict):
            # Nested model: recurse with its own field map when available.
            nested_fields = field_type.model_fields if _is_model(field_type) else {}
            result[key] = convert_base64_to_bytes(value, nested_fields)
        elif isinstance(value, list):
            inner_type = _list_inner_type(field_type)
            if _is_model(inner_type):
                result[key] = [
                    convert_base64_to_bytes(item, inner_type.model_fields)
                    if isinstance(item, dict)
                    else item
                    for item in value
                ]
            elif inner_type is bytes:
                # Fix: List[bytes] items are base64 strings in storage and
                # previously fell through undecoded.
                result[key] = [
                    _decode(item) if isinstance(item, str) else item
                    for item in value
                ]
            else:
                result[key] = convert_base64_to_bytes(value, {})
        else:
            result[key] = convert_base64_to_bytes(value, {})
    return result


class PartialModel:
"""A partial model instance that only contains certain fields.

Expand Down Expand Up @@ -2558,10 +2665,14 @@ def to_string(s):
json_fields = convert_timestamp_to_datetime(
json_fields, cls.model_fields
)
# Convert base64 strings back to bytes for bytes fields
json_fields = convert_base64_to_bytes(json_fields, cls.model_fields)
doc = cls(**json_fields)
else:
# Convert timestamps back to datetime objects
fields = convert_timestamp_to_datetime(fields, cls.model_fields)
# Convert base64 strings back to bytes for bytes fields
fields = convert_base64_to_bytes(fields, cls.model_fields)
doc = cls(**fields)

docs.append(doc)
Expand Down Expand Up @@ -2752,9 +2863,10 @@ async def save(
self.check()
db = self._get_db(pipeline)

# Get model data and convert datetime objects first
# Get model data and apply conversions in the correct order
document = self.model_dump()
document = convert_datetime_to_timestamp(document)
document = convert_bytes_to_base64(document)

# Then apply jsonable encoding for other types
document = jsonable_encoder(document)
Expand Down Expand Up @@ -2854,6 +2966,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
try:
# Convert timestamps back to datetime objects before validation
document = convert_timestamp_to_datetime(document, cls.model_fields)
# Convert base64 strings back to bytes for bytes fields
document = convert_base64_to_bytes(document, cls.model_fields)
result = cls.model_validate(document)
except TypeError as e:
log.warning(
Expand All @@ -2865,6 +2979,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
document = decode_redis_value(document, cls.Meta.encoding)
# Convert timestamps back to datetime objects after decoding
document = convert_timestamp_to_datetime(document, cls.model_fields)
# Convert base64 strings back to bytes for bytes fields
document = convert_base64_to_bytes(document, cls.model_fields)
result = cls.model_validate(document)
return result

Expand Down Expand Up @@ -3126,6 +3242,8 @@ async def save(
data = self.model_dump()
# Convert datetime objects to timestamps for proper indexing
data = convert_datetime_to_timestamp(data)
# Convert bytes to base64 strings for safe JSON storage
data = convert_bytes_to_base64(data)
# Apply JSON encoding for complex types (Enums, UUIDs, Sets, etc.)
data = jsonable_encoder(data)

Expand Down Expand Up @@ -3199,6 +3317,8 @@ async def get(cls: Type["Model"], pk: Any) -> "Model":
raise NotFoundError
# Convert timestamps back to datetime objects before validation
document_data = convert_timestamp_to_datetime(document_data, cls.model_fields)
# Convert base64 strings back to bytes for bytes fields
document_data = convert_base64_to_bytes(document_data, cls.model_fields)
return cls.model_validate(document_data)

@classmethod
Expand Down
68 changes: 68 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,3 +1467,71 @@ async def test_save_nx_with_pipeline_raises_error(m):
async with m.Member.db().pipeline(transaction=True) as pipe:
with pytest.raises(ValueError, match="Cannot use nx or xx with pipeline"):
await member.save(pipeline=pipe, nx=True)




@py_test_mark_asyncio
async def test_bytes_field_with_binary_data(key_prefix, redis):
    """Verify a ``bytes`` field round-trips arbitrary binary content.

    Regression test for GitHub issue #779: bytes fields failed with
    UnicodeDecodeError when storing actual binary data (non-UTF8 bytes).
    """

    class FileHash(HashModel, index=True):
        filename: str = Field(index=True)
        content: bytes

        class Meta:
            global_key_prefix = key_prefix
            database = redis

    await Migrator().run()

    # A PNG header is deliberately NOT valid UTF-8.
    png_header = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"
    saved = FileHash(filename="image.png", content=png_header)
    await saved.save()

    loaded = await FileHash.get(saved.pk)
    assert loaded.filename == "image.png"
    assert loaded.content == png_header

    # Null bytes mixed with other non-printable characters.
    raw = b"\x00\x01\x02\x03\xff\xfe\xfd"
    saved_bin = FileHash(filename="binary.bin", content=raw)
    await saved_bin.save()

    loaded_bin = await FileHash.get(saved_bin.pk)
    assert loaded_bin.content == raw


@py_test_mark_asyncio
async def test_optional_bytes_field(key_prefix, redis):
    """Verify ``Optional[bytes]`` fields round-trip both ``None`` and data."""
    from typing import Optional

    class Attachment(HashModel, index=True):
        name: str = Field(index=True)
        data: Optional[bytes] = None

        class Meta:
            global_key_prefix = key_prefix
            database = redis

    await Migrator().run()

    # When no data is supplied, None must come back as None.
    empty = Attachment(name="empty")
    await empty.save()
    loaded_empty = await Attachment.get(empty.pk)
    assert loaded_empty.data is None

    # A binary (non-UTF8) payload must come back byte-for-byte.
    payload = b"\x89PNG\x00\xff"
    filled = Attachment(name="with_data", data=payload)
    await filled.save()
    loaded_filled = await Attachment.get(filled.pk)
    assert loaded_filled.data == payload
98 changes: 98 additions & 0 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,3 +1892,101 @@ class Meta:
assert "normal_field" in schema_str
# Case sensitive fields use CASESENSITIVE in schema
assert "CASESENSITIVE" in schema_str




@py_test_mark_asyncio
async def test_bytes_field_with_binary_data(key_prefix, redis):
    """Verify a JSON-model ``bytes`` field round-trips arbitrary binary content.

    Regression test for GitHub issue #779: bytes fields failed with
    UnicodeDecodeError when storing actual binary data (non-UTF8 bytes).
    """

    class FileJson(JsonModel, index=True):
        filename: str = Field(index=True)
        content: bytes

        class Meta:
            global_key_prefix = key_prefix
            database = redis

    await Migrator().run()

    # A PNG header is deliberately NOT valid UTF-8.
    png_header = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"
    saved = FileJson(filename="image.png", content=png_header)
    await saved.save()

    loaded = await FileJson.get(saved.pk)
    assert loaded.filename == "image.png"
    assert loaded.content == png_header

    # Null bytes mixed with other non-printable characters.
    raw = b"\x00\x01\x02\x03\xff\xfe\xfd"
    saved_bin = FileJson(filename="binary.bin", content=raw)
    await saved_bin.save()

    loaded_bin = await FileJson.get(saved_bin.pk)
    assert loaded_bin.content == raw


@py_test_mark_asyncio
async def test_optional_bytes_field(key_prefix, redis):
    """Verify JSON-model ``Optional[bytes]`` fields round-trip None and data."""
    from typing import Optional

    class Attachment(JsonModel, index=True):
        name: str = Field(index=True)
        data: Optional[bytes] = None

        class Meta:
            global_key_prefix = key_prefix
            database = redis

    await Migrator().run()

    # When no data is supplied, None must come back as None.
    empty = Attachment(name="empty")
    await empty.save()
    loaded_empty = await Attachment.get(empty.pk)
    assert loaded_empty.data is None

    # A binary (non-UTF8) payload must come back byte-for-byte.
    payload = b"\x89PNG\x00\xff"
    filled = Attachment(name="with_data", data=payload)
    await filled.save()
    loaded_filled = await Attachment.get(filled.pk)
    assert loaded_filled.data == payload


@py_test_mark_asyncio
async def test_bytes_field_in_embedded_model(key_prefix, redis):
    """Verify ``bytes`` fields survive a round-trip inside embedded models."""

    class FileData(EmbeddedJsonModel):
        content: bytes
        mime_type: str

    class Document(JsonModel, index=True):
        name: str = Field(index=True)
        file: FileData

        class Meta:
            global_key_prefix = key_prefix
            database = redis

    await Migrator().run()

    # Binary (non-UTF8) payload nested one level deep.
    header = b"\x89PNG\r\n\x1a\n\x00\x00"
    saved = Document(
        name="test.png",
        file=FileData(content=header, mime_type="image/png"),
    )
    await saved.save()

    loaded = await Document.get(saved.pk)
    assert loaded.file.mime_type == "image/png"
    assert loaded.file.content == header
Loading