diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index a110af96..be5d2de6 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -77,6 +77,32 @@ log = logging.getLogger(__name__) escaper = TokenEscaper() +# Minimum redis-py version for hash field expiration support +_HASH_FIELD_EXPIRATION_MIN_VERSION = (5, 1, 0) + + +def supports_hash_field_expiration() -> bool: + """ + Check if the installed redis-py version supports hash field expiration commands. + + Hash field expiration (HEXPIRE, HTTL, HPERSIST, etc.) was added in redis-py 5.1.0 + and requires Redis server 7.4+. + + Returns: + True if redis-py >= 5.1.0 and has the hexpire method, False otherwise. + """ + try: + import redis as redis_lib + + version_str = getattr(redis_lib, "__version__", "0.0.0") + version_parts = tuple(int(x) for x in version_str.split(".")[:3]) + if version_parts >= _HASH_FIELD_EXPIRATION_MIN_VERSION: + # Also check that the method actually exists + return hasattr(redis_lib.asyncio.Redis, "hexpire") + return False + except (ValueError, AttributeError): + return False + def convert_datetime_to_timestamp(obj): """Convert datetime objects to Unix timestamps for storage.""" @@ -1879,6 +1905,7 @@ def __init__(self, default: Any = ..., **kwargs: Any) -> None: index = kwargs.pop("index", None) full_text_search = kwargs.pop("full_text_search", None) vector_options = kwargs.pop("vector_options", None) + expire = kwargs.pop("expire", None) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.sortable = sortable @@ -1886,6 +1913,7 @@ def __init__(self, default: Any = ..., **kwargs: Any) -> None: self.index = index self.full_text_search = full_text_search self.vector_options = vector_options + self.expire = expire class RelationshipInfo(Representation): @@ -1996,8 +2024,27 @@ def Field( index: Union[bool, UndefinedType] = Undefined, full_text_search: Union[bool, UndefinedType] = Undefined, vector_options: Optional[VectorFieldOptions] = None, + expire: Optional[int] = None, **kwargs: Unpack[_FromFieldInfoInputs], ) -> Any: + """ + Create a field with Redis OM specific options. + + Args: + default: Default value for the field. + primary_key: Whether this field is the primary key. + sortable: Whether this field should be sortable in queries. + case_sensitive: Whether string matching should be case-sensitive. + index: Whether this field should be indexed for searching. + full_text_search: Whether to enable full-text search on this field. + vector_options: Vector field configuration for similarity search. + expire: TTL in seconds for this field (HashModel only, requires Redis 7.4+). + When set, the field will automatically expire after save(). + **kwargs: Additional Pydantic field options. + + Returns: + FieldInfo instance with the configured options. + """ field_info = FieldInfo( **kwargs, default=default, @@ -2007,6 +2054,7 @@ def Field( index=index, full_text_search=full_text_search, vector_options=vector_options, + expire=expire, ) return field_info @@ -2631,12 +2679,62 @@ def __init_subclass__(cls, **kwargs): f"HashModels cannot index dataclass fields. Field: {name}" ) + def _get_field_expirations( + self, field_expirations: Optional[Dict[str, int]] = None + ) -> Dict[str, int]: + """ + Collect field expirations from Field(expire=N) defaults and overrides. + + Args: + field_expirations: Optional dict of {field_name: ttl_seconds} to override defaults. + + Returns: + Dict of field names to TTL in seconds. + """ + expirations: Dict[str, int] = {} + + # Collect default expirations from Field(expire=N) + for name, field in self.model_fields.items(): + field_info = field + # Handle metadata-wrapped FieldInfo + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] + + expire = getattr(field_info, "expire", None) + if expire is not None: + expirations[name] = expire + + # Override with explicit field_expirations + if field_expirations: + expirations.update(field_expirations) + + return expirations + async def save( self: "Model", pipeline: Optional[redis.client.Pipeline] = None, nx: bool = False, xx: bool = False, + field_expirations: Optional[Dict[str, int]] = None, ) -> Optional["Model"]: + """ + Save the model to Redis. + + Args: + pipeline: Optional Redis pipeline for batching commands. + nx: Only save if the key doesn't exist. + xx: Only save if the key already exists. + field_expirations: Dict of {field_name: ttl_seconds} to set field expirations. + Overrides any Field(expire=N) defaults. Requires Redis 7.4+. + + Returns: + The saved model, or None if nx/xx conditions weren't met. + """ if nx and xx: raise ValueError("Cannot specify both nx and xx") if pipeline and (nx or xx): @@ -2666,6 +2764,12 @@ async def save( key = self.key() + # Collect field expirations + expirations = self._get_field_expirations(field_expirations) + + # Check if we're using a pipeline (pipelines don't support TTL preservation) + is_pipeline = pipeline is not None + async def _do_save(conn): # Check nx/xx conditions (HSET doesn't support these natively) if nx or xx: @@ -2675,7 +2779,37 @@ async def _do_save(conn): if xx and not exists: return None # Key doesn't exist, xx means only update existing + # Preserve existing field TTLs before HSET (HSET removes field-level TTLs) + # See issue #753: .save() conflicts with TTL on unrelated field + # Note: TTL preservation is skipped when using pipelines because + # pipeline commands return futures, not actual values + preserved_ttls: Dict[str, int] = {} + if supports_hash_field_expiration() and not is_pipeline: + fields_to_check = [f for f in document.keys() if f != "pk"] + if fields_to_check: + current_ttls = await conn.httl(key, *fields_to_check) + if current_ttls: + for i, field_name in enumerate(fields_to_check): + if current_ttls[i] > 0: # Has a TTL + preserved_ttls[field_name] = current_ttls[i] + await conn.hset(key, mapping=document) + + # Apply field expirations after HSET (requires Redis 7.4+) + # When using pipelines, we can still apply default expirations but + # can't preserve manually-set TTLs + if supports_hash_field_expiration(): + for field_name in document.keys(): + if field_name == "pk": + continue + # Priority: preserved TTL > explicit field_expirations > Field(expire=N) default + if field_name in preserved_ttls: + # Restore the TTL that was removed by HSET + await conn.hexpire(key, preserved_ttls[field_name], field_name) + elif field_name in expirations: + # Apply new expiration (from Field(expire=N) or field_expirations param) + await conn.hexpire(key, expirations[field_name], field_name) + return self # TODO: Wrap any Redis response errors in a custom exception? @@ -2861,6 +2995,101 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): return schema + # ========================================================================= + # Hash Field Expiration Methods (Redis 7.4+) + # ========================================================================= + + async def expire_field( + self, + field_name: str, + seconds: int, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> int: + """ + Set a TTL on a specific hash field. + + Requires Redis 7.4+ and redis-py >= 5.1.0. + + Args: + field_name: The name of the field to expire. + seconds: TTL in seconds. + nx: Only set expiry if field has no expiry. + xx: Only set expiry if field already has an expiry. + gt: Only set expiry if new expiry is greater than current. + lt: Only set expiry if new expiry is less than current. + + Returns: + 1 if expiry was set, -2 if field doesn't exist, 0 if conditions not met. + + Raises: + NotImplementedError: If redis-py version doesn't support HEXPIRE. + """ + if not supports_hash_field_expiration(): + raise NotImplementedError( + "Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+" + ) + + db = self.db() + key = self.key() + result = await db.hexpire(key, seconds, field_name, nx=nx, xx=xx, gt=gt, lt=lt) + # hexpire returns a list of results, one per field + return result[0] if result else -2 + + async def field_ttl(self, field_name: str) -> int: + """ + Get the remaining TTL of a hash field in seconds. + + Requires Redis 7.4+ and redis-py >= 5.1.0. + + Args: + field_name: The name of the field. + + Returns: + TTL in seconds, -1 if no expiry, -2 if field doesn't exist. + + Raises: + NotImplementedError: If redis-py version doesn't support HTTL. + """ + if not supports_hash_field_expiration(): + raise NotImplementedError( + "Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+" + ) + + db = self.db() + key = self.key() + result = await db.httl(key, field_name) + # httl returns a list of results, one per field + return result[0] if result else -2 + + async def persist_field(self, field_name: str) -> int: + """ + Remove the expiration from a hash field. + + Requires Redis 7.4+ and redis-py >= 5.1.0. + + Args: + field_name: The name of the field. + + Returns: + 1 if expiry was removed, -1 if no expiry, -2 if field doesn't exist. + + Raises: + NotImplementedError: If redis-py version doesn't support HPERSIST. + """ + if not supports_hash_field_expiration(): + raise NotImplementedError( + "Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+" + ) + + db = self.db() + key = self.key() + result = await db.hpersist(key, field_name) + # hpersist returns a list of results, one per field + return result[0] if result else -2 + class JsonModel(RedisModel, abc.ABC): def __init_subclass__(cls, **kwargs): diff --git a/tests/test_hash_field_expiration.py b/tests/test_hash_field_expiration.py new file mode 100644 index 00000000..300ebc9e --- /dev/null +++ b/tests/test_hash_field_expiration.py @@ -0,0 +1,339 @@ +# type: ignore +""" +Tests for Redis 7.4+ hash field expiration support in HashModel. + +These tests verify: +1. Field(expire=N) declarative TTL on fields +2. expire_field(), field_ttl(), persist_field() instance methods +3. field_expirations parameter on save() +4. Graceful handling when redis-py lacks HEXPIRE support +""" + +import abc +import asyncio +import datetime +import time +from collections import namedtuple +from unittest import mock + +import pytest +import pytest_asyncio + +from aredis_om import Field, HashModel, Migrator + +# We need to run this check as sync code (during tests) even in async mode +from redis_om import has_redisearch + +from .conftest import py_test_mark_asyncio + + +if not has_redisearch(): + pytestmark = pytest.mark.skip + + +@pytest_asyncio.fixture +async def models(key_prefix, redis): + """Fixture providing HashModel subclasses for testing field expiration.""" + + class BaseHashModel(HashModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + class Session(BaseHashModel, index=True): + user_id: str + token: str = Field(expire=60) # Default 60 second TTL + refresh_token: str = Field(expire=3600) # 1 hour TTL + + class SimpleModel(BaseHashModel, index=True): + name: str + value: str + + await Migrator(conn=redis).run() + + return namedtuple("Models", ["BaseHashModel", "Session", "SimpleModel"])( + BaseHashModel, Session, SimpleModel + ) + + +# ============================================================================= +# Tests for Field(expire=N) - Declarative field expiration +# ============================================================================= + + +@py_test_mark_asyncio +async def test_field_with_expire_parameter(models): + """Field(expire=N) should store expire value in field info.""" + # Check that the expire parameter is captured in the field info + token_field = models.Session.model_fields.get("token") + assert token_field is not None + # The field_info should have an 'expire' attribute + assert hasattr(token_field, "expire") or hasattr( + getattr(token_field, "json_schema_extra", None) or {}, "expire" + ), "Field should have expire attribute" + + +@py_test_mark_asyncio +async def test_save_applies_field_expiration(models, redis): + """save() should apply HEXPIRE for fields with expire= set.""" + session = models.Session( + user_id="user123", + token="abc123", + refresh_token="refresh456", + ) + await session.save() + + # Check that the field has a TTL set (should be <= 60 seconds) + ttl = await session.field_ttl("token") + assert ttl is not None + assert 0 < ttl <= 60, f"Expected TTL <= 60, got {ttl}" + + # refresh_token should have TTL <= 3600 + refresh_ttl = await session.field_ttl("refresh_token") + assert refresh_ttl is not None + assert 0 < refresh_ttl <= 3600 + + +# ============================================================================= +# Tests for expire_field() method +# ============================================================================= + + +@py_test_mark_asyncio +async def test_expire_field_sets_ttl(models, redis): + """expire_field() should set TTL on a specific field.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + # Set expiration on the 'value' field + result = await simple.expire_field("value", 120) + assert result == 1, "expire_field should return 1 on success" + + # Verify TTL was set + ttl = await simple.field_ttl("value") + assert 0 < ttl <= 120 + + +@py_test_mark_asyncio +async def test_expire_field_nonexistent_field(models, redis): + """expire_field() on non-existent field should return -2.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + result = await simple.expire_field("nonexistent", 60) + assert result == -2, "expire_field on non-existent field should return -2" + + +# ============================================================================= +# Tests for field_ttl() method +# ============================================================================= + + +@py_test_mark_asyncio +async def test_field_ttl_returns_remaining_time(models, redis): + """field_ttl() should return remaining TTL in seconds.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + await simple.expire_field("value", 300) + ttl = await simple.field_ttl("value") + + assert ttl is not None + assert 0 < ttl <= 300 + + +@py_test_mark_asyncio +async def test_field_ttl_no_expiration(models, redis): + """field_ttl() should return -1 for fields without expiration.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + +# ============================================================================= +# Tests for persist_field() method +# ============================================================================= + + +@py_test_mark_asyncio +async def test_persist_field_removes_expiration(models, redis): + """persist_field() should remove TTL from a field.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + # Set expiration first + await simple.expire_field("value", 60) + ttl_before = await simple.field_ttl("value") + assert ttl_before > 0 + + # Remove expiration + result = await simple.persist_field("value") + assert result == 1, "persist_field should return 1 on success" + + # Verify TTL was removed + ttl_after = await simple.field_ttl("value") + assert ttl_after == -1, "Field should have no expiration after persist" + + +@py_test_mark_asyncio +async def test_persist_field_no_expiration(models, redis): + """persist_field() on field without expiration should return -1.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + result = await simple.persist_field("name") + assert result == -1 + + +# ============================================================================= +# Tests for save(field_expirations={...}) parameter +# ============================================================================= + + +@py_test_mark_asyncio +async def test_save_with_field_expirations_param(models, redis): + """save(field_expirations={"field": ttl}) should apply TTLs.""" + simple = models.SimpleModel(name="test", value="important_data") + await simple.save(field_expirations={"value": 180}) + + ttl = await simple.field_ttl("value") + assert 0 < ttl <= 180 + + # name should have no TTL since it wasn't in field_expirations + name_ttl = await simple.field_ttl("name") + assert name_ttl == -1 + + +@py_test_mark_asyncio +async def test_save_field_expirations_overrides_default(models, redis): + """save(field_expirations=) should override Field(expire=) defaults.""" + session = models.Session( + user_id="user123", + token="abc123", + refresh_token="refresh456", + ) + # Override the default 60s TTL with 10s + await session.save(field_expirations={"token": 10}) + + ttl = await session.field_ttl("token") + assert 0 < ttl <= 10, f"Expected TTL <= 10, got {ttl}" + + +# ============================================================================= +# Tests for version/capability checking +# ============================================================================= + + +@py_test_mark_asyncio +async def test_hexpire_not_available_raises_or_warns(models, redis): + """When HEXPIRE is not available, should raise or handle gracefully.""" + simple = models.SimpleModel(name="test", value="data") + await simple.save() + + # Mock the redis client to simulate HEXPIRE not existing + with mock.patch.object( + redis, "hexpire", side_effect=AttributeError("hexpire not found") + ): + # Should raise a clear error or return a sentinel value + with pytest.raises((AttributeError, NotImplementedError)): + await simple.expire_field("value", 60) + + +@py_test_mark_asyncio +async def test_check_hash_field_expiration_support(): + """Test utility function to check if hash field expiration is supported.""" + from aredis_om.model.model import supports_hash_field_expiration + + # This should return True for redis-py >= 5.1.0 + # The actual value depends on installed redis-py version + result = supports_hash_field_expiration() + assert isinstance(result, bool) + + +# ============================================================================= +# Tests for field expiration with updates +# ============================================================================= + + +@py_test_mark_asyncio +async def test_update_preserves_field_expiration(models, redis): + """update() should not reset field expiration by default.""" + session = models.Session( + user_id="user123", + token="abc123", + refresh_token="refresh456", + ) + await session.save() + + # Get initial TTL + initial_ttl = await session.field_ttl("token") + assert initial_ttl > 0 + + # Wait a moment + await asyncio.sleep(0.1) + + # Update a different field + await session.update(user_id="user456") + + # Token TTL should still be set (possibly slightly lower) + updated_ttl = await session.field_ttl("token") + assert updated_ttl > 0 + assert updated_ttl <= initial_ttl + + +@py_test_mark_asyncio +async def test_save_preserves_manually_set_ttl(models, redis): + """ + Calling save() should not overwrite a manually-set TTL with the default. + + Regression test for issue #753: .save() conflicts with TTL on unrelated field + """ + session = models.Session( + user_id="user123", + token="abc123", + refresh_token="refresh456", + ) + await session.save() + + # Default TTL is 60 seconds from Field(expire=60) + default_ttl = await session.field_ttl("token") + assert default_ttl > 0 and default_ttl <= 60 + + # Manually extend TTL to 1 hour + await session.expire_field("token", 3600) + extended_ttl = await session.field_ttl("token") + assert extended_ttl > 60 # Should be ~3600 + + # Modify a different field and save + session.user_id = "user456" + await session.save() + + # The manually-set TTL should be preserved, not reset to 60 seconds + ttl_after_save = await session.field_ttl("token") + assert ttl_after_save > 60, f"TTL was reset to default! Got {ttl_after_save}" + + +@py_test_mark_asyncio +async def test_field_expires_after_ttl(models, redis): + """Field should be deleted after TTL expires.""" + simple = models.SimpleModel(name="test", value="temporary") + await simple.save() + + # Set short expiration (2 seconds to allow for CI slowness) + await simple.expire_field("value", 2) + + # Verify field exists initially + ttl_before = await simple.field_ttl("value") + assert ttl_before > 0 + + # Wait for expiration (3 seconds to ensure it expires even in slow CI) + # Use time.sleep for sync compatibility (asyncio.sleep doesn't convert via unasync) + time.sleep(3) + + # Verify field has expired (TTL should be -2 for non-existent field) + ttl_after = await simple.field_ttl("value") + assert ttl_after == -2, f"Expected -2 (field expired), got {ttl_after}" + + # Check directly with Redis that the field is gone + key = simple.key() + value_exists = await redis.hexists(key, "value") + assert not value_exists, "Field 'value' should have expired"