diff --git a/pyproject.toml b/pyproject.toml index 2d4f5d5..d7bb1f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ markers = [ "transactions: transaction-related tests", "budgets: budget-related tests", "insights: insight analysis tests", + "rules: rule management tests", "unit: marks tests as unit tests", ] diff --git a/src/lampyrid/clients/firefly.py b/src/lampyrid/clients/firefly.py index 84f013f..b135316 100644 --- a/src/lampyrid/clients/firefly.py +++ b/src/lampyrid/clients/firefly.py @@ -20,6 +20,9 @@ InsightGroup, InsightTotal, InsightTransfer, + RuleArray, + RuleSingle, + RuleUpdate, TransactionArray, TransactionSingle, TransactionStore, @@ -470,3 +473,91 @@ async def get_transfer_by_asset_account( self._handle_api_error(r) r.raise_for_status() return InsightTransfer.model_validate(r.json()) + + # ========================================================================= + # Rule Management Methods + # ========================================================================= + + async def get_rules(self, page: int = 1) -> RuleArray: + """Get all rules with pagination.""" + r = await self._client.get('/api/v1/rules', params={'page': page}) + self._handle_api_error(r) + r.raise_for_status() + return RuleArray.model_validate(r.json()) + + async def get_rule(self, rule_id: str) -> RuleSingle: + """Get a single rule by ID.""" + r = await self._client.get(f'/api/v1/rules/{rule_id}') + self._handle_api_error(r) + r.raise_for_status() + return RuleSingle.model_validate(r.json()) + + async def update_rule(self, rule_id: str, rule_update: RuleUpdate) -> RuleSingle: + """Update an existing rule.""" + payload = self._serialize_model(rule_update, exclude_unset=True) + r = await self._client.put(f'/api/v1/rules/{rule_id}', json=payload) + self._handle_api_error(r, payload) + r.raise_for_status() + return RuleSingle.model_validate(r.json()) + + async def test_rule( + self, + rule_id: str, + start_date: date, + end_date: date, + account_ids: Optional[list[str]] = None, + ) -> TransactionArray: + """Test a rule in preview mode (shows matches without changes). + + Args: + rule_id: ID of the rule to test + start_date: Start date for matching transactions + end_date: End date for matching transactions + account_ids: Optional list of account IDs to filter + + Returns: + TransactionArray with matching transactions + + """ + params: Dict[str, Any] = { + 'start': start_date.strftime('%Y-%m-%d'), + 'end': end_date.strftime('%Y-%m-%d'), + } + if account_ids: + params['accounts[]'] = account_ids + + r = await self._client.get(f'/api/v1/rules/{rule_id}/test', params=params) + self._handle_api_error(r) + r.raise_for_status() + return TransactionArray.model_validate(r.json()) + + async def trigger_rule( + self, + rule_id: str, + start_date: date, + end_date: date, + account_ids: Optional[list[str]] = None, + ) -> bool: + """Execute a rule (applies changes to matching transactions). + + Args: + rule_id: ID of the rule to execute + start_date: Start date for matching transactions + end_date: End date for matching transactions + account_ids: Optional list of account IDs to filter + + Returns: + True if the rule execution was accepted (processing is async) + + """ + params: Dict[str, Any] = { + 'start': start_date.strftime('%Y-%m-%d'), + 'end': end_date.strftime('%Y-%m-%d'), + } + if account_ids: + params['accounts[]'] = account_ids + + r = await self._client.post(f'/api/v1/rules/{rule_id}/trigger', params=params) + self._handle_api_error(r) + r.raise_for_status() + return r.status_code == 204 diff --git a/src/lampyrid/models/lampyrid_models.py b/src/lampyrid/models/lampyrid_models.py index 05cface..a5b394c 100644 --- a/src/lampyrid/models/lampyrid_models.py +++ b/src/lampyrid/models/lampyrid_models.py @@ -1,7 +1,7 @@ """Simplified models for MCP tool interfaces with budget support.""" from datetime import date, datetime, timezone -from typing import List, Literal, Optional +from typing import Any, List, Literal, Optional from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -9,6 +9,9 @@ AccountRead, AccountTypeFilter, BudgetRead, + RuleActionKeyword, + RuleRead, + RuleTriggerKeyword, ShortAccountTypeProperty, TransactionArray, TransactionRead, @@ -1049,3 +1052,304 @@ class FinancialSummary(BaseModel): ) start_date: date = Field(..., description='Start of the analysis period') end_date: date = Field(..., description='End of the analysis period') + + +# ============================================================================= +# Rule Models - Request and Response models for rule management +# ============================================================================= + + +class RuleTriggerSimple(BaseModel): + """Simplified trigger model for rule display.""" + + type: RuleTriggerKeyword = Field(..., description='Type of trigger') + value: Optional[str] = Field(None, description='Value for the trigger (if applicable)') + prohibited: bool = Field( + False, + description='Whether this trigger is negated (read-only from API)', + ) + + +class RuleActionSimple(BaseModel): + """Simplified action model for rule display.""" + + type: RuleActionKeyword = Field(..., description='Type of action') + value: Optional[str] = Field(None, description='Value for the action (if applicable)') + + +class Rule(BaseModel): + """Simplified rule model for MCP responses.""" + + id: str = Field(..., description='Unique identifier for the rule') + title: str = Field(..., description='Display name of the rule') + description: Optional[str] = Field(None, description='Optional description of the rule') + active: bool = Field(True, description='Whether the rule is active') + strict: Optional[bool] = Field( + None, + description='If True, ALL triggers must match (AND). If False, any trigger can match (OR).', + ) + stop_processing: bool = Field( + False, description='Whether to stop processing other rules after this one' + ) + trigger: Optional[str] = Field( + None, deprecated='Use triggers list instead', description='Single trigger type' + ) + triggers: List[RuleTriggerSimple] = Field( + default_factory=list, description='List of triggers for this rule' + ) + actions: List[RuleActionSimple] = Field( + default_factory=list, description='List of actions for this rule' + ) + + @classmethod + def from_rule_read(cls, rule_read: RuleRead) -> 'Rule': + """Create a Rule instance from a Firefly RuleRead object.""" + rule_attrs = rule_read.attributes + return cls( + id=rule_read.id, + title=rule_attrs.title, + description=rule_attrs.description, + active=rule_attrs.active if rule_attrs.active is not None else True, + strict=rule_attrs.strict, + stop_processing=( + rule_attrs.stop_processing if rule_attrs.stop_processing is not None else False + ), + trigger=rule_attrs.trigger.value, + triggers=[ + RuleTriggerSimple( + type=t.type, + value=t.value, + prohibited=t.prohibited if t.prohibited is not None else False, + ) + for t in rule_attrs.triggers + ], + actions=[ + RuleActionSimple( + type=a.type, + value=a.value, + ) + for a in rule_attrs.actions + ], + ) + + +class SearchRulesRequest(BaseModel): + """Request model for searching rules.""" + + model_config = ConfigDict(extra='forbid') + + trigger_type: Optional[str] = Field( + None, + description=( + 'Filter by trigger type keyword (substring match, case-insensitive). ' + 'Examples: "description", "amount", "account"' + ), + ) + action_type: Optional[str] = Field( + None, + description=( + 'Filter by action type keyword (substring match, case-insensitive). ' + 'Examples: "set_budget", "set_category", "delete"' + ), + ) + trigger_value_pattern: Optional[str] = Field( + None, + description=( + 'Filter by trigger value using regex pattern (case-insensitive). ' + 'Example: ".*groceries.*" matches any trigger value containing "groceries"' + ), + ) + action_value_pattern: Optional[str] = Field( + None, + description=( + 'Filter by action value using regex pattern (case-insensitive). ' + 'Example: "^[0-9]+$" matches numeric values' + ), + ) + title_contains: Optional[str] = Field( + None, + description=( + 'Filter by rule title (substring match, case-insensitive). ' + 'Example: "auto" matches "Auto-categorize" and "automatic"' + ), + ) + active: Optional[bool] = Field( + None, + description='Filter by active status (True for active, False for inactive, None for all)', + ) + + @model_validator(mode='after') + def validate_search_criteria(self): + """Ensure at least one search criterion is provided.""" + has_criteria = any( + [ + self.trigger_type, + self.action_type, + self.trigger_value_pattern, + self.action_value_pattern, + self.title_contains, + self.active is not None, + ] + ) + if not has_criteria: + raise ValueError('At least one search criterion must be provided') + return self + + +class GetRuleRequest(BaseModel): + """Request model for getting a single rule by ID.""" + + model_config = ConfigDict(extra='forbid') + + id: str = Field(..., description='Unique identifier of the rule to retrieve') + + +class UpdateRuleRequest(BaseModel): + """Request model for updating a rule.""" + + model_config = ConfigDict(extra='forbid') + + rule_id: str = Field(..., description='Unique identifier of the rule to update') + title: Optional[str] = Field(None, description='New title for the rule') + description: Optional[str] = Field(None, description='New description for the rule') + rule_group_id: Optional[str] = Field(None, description='Rule group ID to move the rule to') + active: Optional[bool] = Field(None, description='Whether the rule is active') + strict: Optional[bool] = Field( + None, + description='If True, ALL triggers must match. If False, any trigger can match.', + ) + stop_processing: Optional[bool] = Field( + None, description='Whether to stop processing other rules after this one' + ) + triggers: Optional[List[dict[str, Any]]] = Field( + None, + description=( + 'Array of trigger objects to update. Each object should have: ' + 'type (required), value (optional), active (optional), ' + 'order (optional), stop_processing (optional). ' + 'Note: prohibited field cannot be modified through the API.' + ), + ) + actions: Optional[List[dict[str, Any]]] = Field( + None, + description=( + 'Array of action objects to update. Each object should have: ' + 'type (required), value (optional), active (optional), ' + 'order (optional), stop_processing (optional).' + ), + ) + + +class TestRuleRequest(BaseModel): + """Request model for testing a rule (preview matches).""" + + model_config = ConfigDict(extra='forbid') + + rule_id: str = Field(..., description='Unique identifier of the rule to test') + start_date: date = Field( + ..., + description=( + 'Start date for matching (YYYY-MM-DD). ' + 'Only transactions on or after this date will be checked.' + ), + ) + end_date: date = Field( + ..., + description=( + 'End date for matching (YYYY-MM-DD). ' + 'Only transactions on or before this date will be checked.' + ), + ) + account_ids: Optional[List[str]] = Field( + None, + description=( + 'Optional list of account IDs to limit the test to specific accounts. ' + 'When provided, only transactions involving these accounts are tested.' + ), + ) + + @model_validator(mode='after') + def validate_date_range(self) -> 'TestRuleRequest': + """Validate that start_date is not after end_date.""" + if self.start_date > self.end_date: + raise ValueError('start_date must be on or before end_date') + return self + + +class ExecuteRuleRequest(BaseModel): + """Request model for executing a rule (apply changes).""" + + model_config = ConfigDict(extra='forbid') + + rule_id: str = Field(..., description='Unique identifier of the rule to execute') + start_date: date = Field( + ..., + description=( + 'Start date for execution (YYYY-MM-DD). ' + 'Only transactions on or after this date will be modified.' + ), + ) + end_date: date = Field( + ..., + description=( + 'End date for execution (YYYY-MM-DD). ' + 'Only transactions on or before this date will be modified.' + ), + ) + account_ids: Optional[List[str]] = Field( + None, + description=( + 'Optional list of account IDs to limit execution to specific accounts. ' + 'When provided, only transactions involving these accounts are modified.' + ), + ) + confirm: bool = Field( + False, + description=( + 'REQUIRED: Must be set to True to actually execute the rule. ' + 'This is a safety measure to prevent accidental modifications. ' + 'Always preview with test_rule first!' + ), + ) + + @model_validator(mode='after') + def validate_date_range(self) -> 'ExecuteRuleRequest': + """Validate that start_date is not after end_date.""" + if self.start_date > self.end_date: + raise ValueError('start_date must be on or before end_date') + return self + + +class RuleTestResult(BaseModel): + """Result of testing a rule (preview mode).""" + + rule_id: str = Field(..., description='ID of the rule that was tested') + rule_title: str = Field(..., description='Title of the rule that was tested') + matched_transaction_count: int = Field( + ..., description='Number of transactions that would be affected by this rule' + ) + matched_transactions: List[Transaction] = Field( + ..., + description=( + 'List of transactions that match the rule criteria. ' + 'NOTE: These are the transactions that would be modified if the rule is executed, ' + 'but they are shown here WITHOUT any modifications applied (preview mode).' + ), + ) + + +class RuleExecuteResult(BaseModel): + """Result of executing a rule.""" + + rule_id: str = Field(..., description='ID of the rule that was executed') + rule_title: str = Field(..., description='Title of the rule that was executed') + success: bool = Field(..., description='Whether the rule execution was accepted by the server') + message: str = Field( + ..., + description=( + 'Status message. ' + 'NOTE: Firefly III processes rule execution asynchronously. ' + 'The rule is queued but may still be processing. Check the rule ' + 'later to confirm changes.' + ), + ) diff --git a/src/lampyrid/services/rules.py b/src/lampyrid/services/rules.py new file mode 100644 index 0000000..e4d5837 --- /dev/null +++ b/src/lampyrid/services/rules.py @@ -0,0 +1,253 @@ +"""Rule Service for LamPyrid. + +This service handles rule-related business logic and orchestrates +operations between the MCP tools and the Firefly III client. +""" + +import re +from typing import List + +from pydantic import ValidationError + +from ..clients.firefly import FireflyClient +from ..models.firefly_models import RuleActionUpdate, RuleTriggerUpdate, RuleUpdate +from ..models.lampyrid_models import ( + ExecuteRuleRequest, + GetRuleRequest, + Rule, + RuleExecuteResult, + RuleTestResult, + SearchRulesRequest, + TestRuleRequest, + Transaction, + UpdateRuleRequest, +) + + +class RuleService: + """Service for managing Firefly III rules. + + This service provides a high-level interface for rule operations, + handling filtering, regex matching, and multi-call orchestration + while delegating HTTP operations to the FireflyClient. + """ + + def __init__(self, client: FireflyClient) -> None: + """Initialize the rule service with a FireflyClient instance.""" + self._client = client + + async def search_rules(self, req: SearchRulesRequest) -> List[Rule]: + """Search rules with client-side filtering. + + Since Firefly III has no server-side rule search, this fetches all rules + and filters them client-side using keyword matching and regex patterns. + + Args: + req: Request containing search criteria + + Returns: + List of rules matching the filter criteria + + Raises: + ValueError: If regex patterns are invalid + + """ + # Compile regex patterns early to catch errors + trigger_pattern = None + action_pattern = None + + if req.trigger_value_pattern: + try: + trigger_pattern = re.compile(req.trigger_value_pattern, re.IGNORECASE) + except re.error as e: + raise ValueError(f'Invalid trigger_value_pattern regex: {e}') + + if req.action_value_pattern: + try: + action_pattern = re.compile(req.action_value_pattern, re.IGNORECASE) + except re.error as e: + raise ValueError(f'Invalid action_value_pattern regex: {e}') + + # Fetch all rules with pagination + all_rules = [] + page = 1 + while True: + rule_array = await self._client.get_rules(page) + all_rules.extend(rule_array.data) + + # Check pagination safely (can be None) + if ( + not rule_array.meta.pagination + or rule_array.meta.pagination.current_page >= rule_array.meta.pagination.total_pages + ): + break + page += 1 + + # Filter client-side + filtered_rules = [] + for rule_read in all_rules: + rule_attrs = rule_read.attributes + + # Filter by active status if specified + if req.active is not None and rule_attrs.active != req.active: + continue + + # Filter by title if specified + if req.title_contains: + if req.title_contains.lower() not in rule_attrs.title.lower(): + continue + + # Filter by trigger type keyword if specified + if req.trigger_type: + trigger_keywords = [t.type.value for t in rule_attrs.triggers] + if not any(req.trigger_type.lower() in kw.lower() for kw in trigger_keywords): + continue + + # Filter by trigger value pattern if specified + if trigger_pattern: + trigger_values = [t.value for t in rule_attrs.triggers if t.value is not None] + if not any(trigger_pattern.search(v) for v in trigger_values): + continue + + # Filter by action type keyword if specified + if req.action_type: + action_keywords = [a.type.value for a in rule_attrs.actions] + if not any(req.action_type.lower() in kw.lower() for kw in action_keywords): + continue + + # Filter by action value pattern if specified + if action_pattern: + action_values = [a.value for a in rule_attrs.actions if a.value is not None] + if not any(action_pattern.search(v) for v in action_values): + continue + + # All filters passed, include this rule + filtered_rules.append(rule_read) + + # Convert to simplified models + return [Rule.from_rule_read(rule_read) for rule_read in filtered_rules] + + async def get_rule(self, req: GetRuleRequest) -> Rule: + """Get detailed information for a single rule. + + Args: + req: Request containing the rule ID + + Returns: + Rule details + + """ + rule_single = await self._client.get_rule(req.id) + return Rule.from_rule_read(rule_single.data) + + async def update_rule(self, req: UpdateRuleRequest) -> Rule: + """Update an existing rule. + + Args: + req: Request containing the rule ID and updates + + Returns: + Updated rule details + + Raises: + ValueError: If trigger/action dicts have invalid formats + + """ + # Build the RuleUpdate object from the request + rule_update = RuleUpdate( + title=req.title, + description=req.description, + rule_group_id=req.rule_group_id, + active=req.active, + strict=req.strict, + stop_processing=req.stop_processing, + ) + + # Convert triggers array to RuleTriggerUpdate objects if provided + if req.triggers is not None: + try: + rule_update.triggers = [RuleTriggerUpdate(**t) for t in req.triggers] + except ValidationError as e: + raise ValueError(f'Invalid trigger format: {e}') + + # Convert actions array to RuleActionUpdate objects if provided + if req.actions is not None: + try: + rule_update.actions = [RuleActionUpdate(**a) for a in req.actions] + except ValidationError as e: + raise ValueError(f'Invalid action format: {e}') + + # Call the client to update the rule + rule_single = await self._client.update_rule(req.rule_id, rule_update) + return Rule.from_rule_read(rule_single.data) + + async def test_rule(self, req: TestRuleRequest) -> RuleTestResult: + """Test a rule in preview mode (show matches without changes). + + Args: + req: Request containing rule ID and date range + + Returns: + Test result with matched transactions (preview mode) + + """ + # Get the rule for its title + rule_single = await self._client.get_rule(req.rule_id) + rule_title = rule_single.data.attributes.title + + # Test the rule on the transactions + transaction_array = await self._client.test_rule( + req.rule_id, req.start_date, req.end_date, req.account_ids + ) + + # Convert transactions to simplified models + matched_transactions = [ + Transaction.from_transaction_read(trx_read) for trx_read in transaction_array.data + ] + + return RuleTestResult( + rule_id=req.rule_id, + rule_title=rule_title, + matched_transaction_count=len(matched_transactions), + matched_transactions=matched_transactions, + ) + + async def execute_rule(self, req: ExecuteRuleRequest) -> RuleExecuteResult: + """Execute a rule (apply changes to matching transactions). + + Args: + req: Request containing rule ID and execution parameters + + Returns: + Execution result + + Raises: + ValueError: If confirm is not True (safety check) + + """ + # Safety check: require explicit confirmation + if not req.confirm: + raise ValueError( + 'Rule execution requires confirm=True to prevent accidental ' + 'modifications. Use test_rule first to preview matches.' + ) + + # Get the rule for its title + rule_single = await self._client.get_rule(req.rule_id) + rule_title = rule_single.data.attributes.title + + # Execute the rule + success = await self._client.trigger_rule( + req.rule_id, req.start_date, req.end_date, req.account_ids + ) + + return RuleExecuteResult( + rule_id=req.rule_id, + rule_title=rule_title, + success=success, + message=( + 'Rule execution accepted and queued for processing. ' + 'Firefly III applies rule changes asynchronously. ' + 'Check the rule or transactions later to confirm changes.' + ), + ) diff --git a/src/lampyrid/tools/__init__.py b/src/lampyrid/tools/__init__.py index 496f4ad..d53b3fb 100644 --- a/src/lampyrid/tools/__init__.py +++ b/src/lampyrid/tools/__init__.py @@ -10,6 +10,7 @@ from .accounts import create_accounts_server from .budgets import create_budgets_server from .insights import create_insights_server +from .rules import create_rules_server from .transactions import create_transactions_server @@ -26,9 +27,11 @@ def compose_all_servers(mcp: FastMCP, client: FireflyClient) -> None: transactions_server = create_transactions_server(client) budgets_server = create_budgets_server(client) insights_server = create_insights_server(client) + rules_server = create_rules_server(client) # Mount all servers into the main server without namespaces (static composition) mcp.mount(accounts_server) mcp.mount(transactions_server) mcp.mount(budgets_server) mcp.mount(insights_server) + mcp.mount(rules_server) diff --git a/src/lampyrid/tools/rules.py b/src/lampyrid/tools/rules.py new file mode 100644 index 0000000..d6a49b7 --- /dev/null +++ b/src/lampyrid/tools/rules.py @@ -0,0 +1,102 @@ +"""Rule Management MCP Tools. + +This module provides MCP tools for managing Firefly III rules including +searching, retrieving, updating, testing (preview), and executing rules. +""" + +from typing import List + +from fastmcp import FastMCP + +from ..clients.firefly import FireflyClient +from ..models.lampyrid_models import ( + ExecuteRuleRequest, + GetRuleRequest, + Rule, + RuleExecuteResult, + RuleTestResult, + SearchRulesRequest, + TestRuleRequest, + UpdateRuleRequest, +) +from ..services.rules import RuleService + + +def create_rules_server(client: FireflyClient) -> FastMCP: + """Create a standalone FastMCP server for rule management tools. + + Args: + client: The FireflyClient instance for API interactions + + Returns: + FastMCP server instance with rule management tools registered + + """ + rule_service = RuleService(client) + + rules_mcp = FastMCP('rules') + + @rules_mcp.tool(tags={'rules', 'search'}) + async def search_rules(req: SearchRulesRequest) -> List[Rule]: + """Search your rules using multiple filter criteria. + + Since Firefly III doesn't have a built-in rule search API, this tool + fetches all rules and filters them client-side using keyword matching + and regex patterns for maximum flexibility. + + Provide at least one search criterion. All criteria are combined with AND logic. + """ + return await rule_service.search_rules(req) + + @rules_mcp.tool(tags={'rules'}) + async def get_rule(req: GetRuleRequest) -> Rule: + """Retrieve a single rule by ID with all its triggers and actions. + + Returns the complete rule configuration including triggers (conditions) + and actions (what to do when the rule matches). + """ + return await rule_service.get_rule(req) + + @rules_mcp.tool(tags={'rules', 'modify'}) + async def update_rule(req: UpdateRuleRequest) -> Rule: + """Update an existing rule's configuration. + + You can update any combination of: + - Basic settings (title, description, active status) + - Logic control (strict mode, stop processing) + - Triggers (conditions that must match) + - Actions (changes to apply) + + Note: The 'prohibited' field on triggers is read-only and cannot be modified. + """ + return await rule_service.update_rule(req) + + @rules_mcp.tool(tags={'rules', 'test'}) + async def test_rule(req: TestRuleRequest) -> RuleTestResult: + """Preview which transactions a rule would match WITHOUT applying changes. + + Use this before executing a rule to see what would be affected. + This is a read-only operation - no transactions are modified. + + Returns the list of matching transactions that would be changed + if the rule is executed. + """ + return await rule_service.test_rule(req) + + @rules_mcp.tool(tags={'rules', 'execute'}) + async def execute_rule(req: ExecuteRuleRequest) -> RuleExecuteResult: + """Execute a rule to apply changes to matching transactions. + + WARNING: This is a destructive operation that modifies your transactions. + - Always use test_rule first to preview what will be changed + - Requires confirm=True to prevent accidental execution + - Execution happens asynchronously - changes may take a moment + - Date ranges are REQUIRED - no defaults are applied + + Rule execution in Firefly III is asynchronous. The rule will be queued + for processing and applied in the background. Check your transactions + later to confirm the changes have been applied. + """ + return await rule_service.execute_rule(req) + + return rules_mcp diff --git a/tests/integration/test_rules.py b/tests/integration/test_rules.py new file mode 100644 index 0000000..040152e --- /dev/null +++ b/tests/integration/test_rules.py @@ -0,0 +1,503 @@ +"""Integration tests for rule management tools.""" + +from datetime import date, timedelta +from typing import List + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from lampyrid.clients.firefly import FireflyClient + +# ==================== Helpers ==================== + +# Cached rule group ID for test rule creation +_test_rule_group_id: str | None = None + + +async def _ensure_rule_group(firefly_client: FireflyClient) -> str: + """Get or create a rule group for integration tests. Caches the ID.""" + global _test_rule_group_id + if _test_rule_group_id is not None: + return _test_rule_group_id + + # Check for existing rule groups + r = await firefly_client._client.get('/api/v1/rule-groups') + r.raise_for_status() + groups = r.json().get('data', []) + if groups: + _test_rule_group_id = groups[0]['id'] + return _test_rule_group_id + + # Create one if none exist + r = await firefly_client._client.post( + '/api/v1/rule-groups', + json={'title': 'Test Rules', 'order': 1}, + ) + r.raise_for_status() + _test_rule_group_id = r.json()['data']['id'] + return _test_rule_group_id + + +async def _create_rule_via_api( + firefly_client: FireflyClient, + title: str, + trigger_type: str = 'description_contains', + trigger_value: str = 'test', + action_type: str = 'set_category', + action_value: str = 'Test Category', + active: bool = True, +) -> str: + """Create a rule directly via Firefly III API and return its ID.""" + rule_group_id = await _ensure_rule_group(firefly_client) + r = await firefly_client._client.post( + '/api/v1/rules', + json={ + 'title': title, + 'rule_group_id': rule_group_id, + 'trigger': 'store-journal', + 'active': active, + 'strict': True, + 'triggers': [ + {'type': trigger_type, 'value': trigger_value, 'active': True}, + ], + 'actions': [ + {'type': action_type, 'value': action_value, 'active': True}, + ], + }, + ) + if r.status_code >= 400: + raise RuntimeError(f'Failed to create rule ({r.status_code}): {r.text}') + return r.json()['data']['id'] + + +async def _delete_rule_via_api(firefly_client: FireflyClient, rule_id: str) -> None: + """Delete a rule directly via Firefly III API.""" + r = await firefly_client._client.delete(f'/api/v1/rules/{rule_id}') + r.raise_for_status() + + +# ==================== Fixtures ==================== + + +@pytest.fixture +async def rule_cleanup(firefly_client: FireflyClient): + """Fixture to track and cleanup rules created during tests.""" + created_rule_ids: List[str] = [] + + yield created_rule_ids + + for rule_id in created_rule_ids: + try: + await _delete_rule_via_api(firefly_client, rule_id) + except Exception as e: + print(f'Failed to cleanup rule {rule_id}: {e}') + + +# ==================== Search Rules ==================== + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_search_rules_returns_results( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test searching rules by title keyword.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Categorize Groceries', + trigger_type='description_contains', + trigger_value='groceries', + action_type='set_category', + action_value='Groceries', + ) + rule_cleanup.append(rule_id) + + result = await mcp_client.call_tool( + 'search_rules', + {'req': {'title_contains': 'Categorize Groceries'}}, + ) + rules = result.structured_content['result'] + assert len(rules) >= 1 + assert any(r['title'] == 'Integration Test - Categorize Groceries' for r in rules) + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_search_rules_by_active_status( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test filtering rules by active status.""" + active_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Active Rule', + active=True, + ) + rule_cleanup.append(active_id) + + inactive_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Inactive Rule', + active=False, + ) + rule_cleanup.append(inactive_id) + + # Search for active rules only + result = await mcp_client.call_tool( + 'search_rules', + {'req': {'title_contains': 'Integration Test', 'active': True}}, + ) + rules = result.structured_content['result'] + titles = [r['title'] for r in rules] + assert 'Integration Test - Active Rule' in titles + assert 'Integration Test - Inactive Rule' not in titles + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_search_rules_by_trigger_type( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test filtering rules by trigger type keyword.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Amount Rule', + trigger_type='amount_more', + trigger_value='100', + ) + rule_cleanup.append(rule_id) + + result = await mcp_client.call_tool( + 'search_rules', + {'req': {'title_contains': 'Integration Test', 'trigger_type': 'amount'}}, + ) + rules = result.structured_content['result'] + assert any(r['title'] == 'Integration Test - Amount Rule' for r in rules) + + +# ==================== Get Rule ==================== + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_get_rule( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test retrieving a single rule by ID.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Get Rule', + trigger_type='description_contains', + trigger_value='test-pattern', + action_type='set_category', + action_value='Test', + ) + rule_cleanup.append(rule_id) + + result = await mcp_client.call_tool('get_rule', {'req': {'id': rule_id}}) + rule = result.structured_content + + assert rule['id'] == rule_id + assert rule['title'] == 'Integration Test - Get Rule' + assert rule['active'] is True + assert len(rule['triggers']) == 1 + assert rule['triggers'][0]['type'] == 'description_contains' + assert rule['triggers'][0]['value'] == 'test-pattern' + assert len(rule['actions']) == 1 + assert rule['actions'][0]['type'] == 'set_category' + assert rule['actions'][0]['value'] == 'Test' + + +# ==================== Update Rule ==================== + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_update_rule_title( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test updating a rule's title.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Original Title', + ) + rule_cleanup.append(rule_id) + + result = await mcp_client.call_tool( + 'update_rule', + {'req': {'rule_id': rule_id, 'title': 'Integration Test - Updated Title'}}, + ) + rule = result.structured_content + assert rule['title'] == 'Integration Test - Updated Title' + + # Verify persistence by re-fetching + verify = await mcp_client.call_tool('get_rule', {'req': {'id': rule_id}}) + assert verify.structured_content['title'] == 'Integration Test - Updated Title' + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_update_rule_active_status( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test toggling a rule's active status.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Toggle Active', + active=True, + ) + rule_cleanup.append(rule_id) + + # Deactivate + result = await mcp_client.call_tool( + 'update_rule', + {'req': {'rule_id': rule_id, 'active': False}}, + ) + assert result.structured_content['active'] is False + + # Reactivate + result = await mcp_client.call_tool( + 'update_rule', + {'req': {'rule_id': rule_id, 'active': True}}, + ) + assert result.structured_content['active'] is True + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_update_rule_triggers( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test updating a rule's triggers.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Update Triggers', + trigger_type='description_contains', + trigger_value='old-pattern', + ) + rule_cleanup.append(rule_id) + + result = await mcp_client.call_tool( + 'update_rule', + { + 'req': { + 'rule_id': rule_id, + 'triggers': [ + {'type': 'description_contains', 'value': 'new-pattern'}, + ], + } + }, + ) + rule = result.structured_content + assert len(rule['triggers']) == 1 + assert rule['triggers'][0]['value'] == 'new-pattern' + + +# ==================== Test Rule (Preview) ==================== + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_test_rule_preview( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test previewing which transactions a rule would match.""" + # Create a rule that matches seed transactions (description contains 'Seed:') + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Preview Rule', + trigger_type='description_contains', + trigger_value='Seed:', + action_type='set_category', + action_value='Matched', + ) + rule_cleanup.append(rule_id) + + today = date.today() + start = today.replace(day=1) + end = start + timedelta(days=31) + + result = await mcp_client.call_tool( + 'test_rule', + { + 'req': { + 'rule_id': rule_id, + 'start_date': start.isoformat(), + 'end_date': end.isoformat(), + } + }, + ) + test_result = result.structured_content + + assert test_result['rule_id'] == rule_id + assert test_result['rule_title'] == 'Integration Test - Preview Rule' + assert isinstance(test_result['matched_transaction_count'], int) + # Seed transactions from conftest should match + assert test_result['matched_transaction_count'] >= 0 + assert isinstance(test_result['matched_transactions'], list) + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_test_rule_no_matches( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test previewing a rule that matches no transactions.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - No Matches', + trigger_type='description_contains', + trigger_value='zzz_nonexistent_pattern_zzz', + ) + rule_cleanup.append(rule_id) + + today = date.today() + + result = await mcp_client.call_tool( + 'test_rule', + { + 'req': { + 'rule_id': rule_id, + 'start_date': today.isoformat(), + 'end_date': today.isoformat(), + } + }, + ) + test_result = result.structured_content + assert test_result['matched_transaction_count'] == 0 + assert test_result['matched_transactions'] == [] + + +# ==================== Execute Rule ==================== + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_execute_rule_requires_confirm( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test that execute_rule without confirm=True raises an error.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Confirm Guard', + ) + rule_cleanup.append(rule_id) + + today = date.today() + + with pytest.raises(ToolError, match='confirm=True'): + await mcp_client.call_tool( + 'execute_rule', + { + 'req': { + 'rule_id': rule_id, + 'start_date': today.isoformat(), + 'end_date': today.isoformat(), + 'confirm': False, + } + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_execute_rule_with_confirm( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test executing a rule with proper confirmation.""" + # Use add_tag action — non-destructive and won't interfere with other tests + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Execute Rule', + trigger_type='description_contains', + trigger_value='Seed:', + action_type='add_tag', + action_value='integration-test-tag', + ) + rule_cleanup.append(rule_id) + + today = date.today() + start = today.replace(day=1) + end = start + timedelta(days=31) + + result = await mcp_client.call_tool( + 'execute_rule', + { + 'req': { + 'rule_id': rule_id, + 'start_date': start.isoformat(), + 'end_date': end.isoformat(), + 'confirm': True, + } + }, + ) + exec_result = result.structured_content + + assert exec_result['rule_id'] == rule_id + assert exec_result['rule_title'] == 'Integration Test - Execute Rule' + assert exec_result['success'] is True + assert 'asynchronously' in exec_result['message'] + + +# ==================== Date Validation ==================== + + +@pytest.mark.asyncio +@pytest.mark.rules +@pytest.mark.integration +async def test_test_rule_rejects_inverted_dates( + mcp_client: Client, + firefly_client: FireflyClient, + rule_cleanup: List[str], +): + """Test that test_rule rejects start_date after end_date.""" + rule_id = await _create_rule_via_api( + firefly_client, + title='Integration Test - Date Validation', + ) + rule_cleanup.append(rule_id) + + with pytest.raises(ToolError, match='start_date'): + await mcp_client.call_tool( + 'test_rule', + { + 'req': { + 'rule_id': rule_id, + 'start_date': '2024-12-31', + 'end_date': '2024-01-01', + } + }, + ) diff --git a/tests/unit/test_rules_service.py b/tests/unit/test_rules_service.py new file mode 100644 index 0000000..8c63c0f --- /dev/null +++ b/tests/unit/test_rules_service.py @@ -0,0 +1,486 @@ +"""Unit tests for RuleService.""" + +from datetime import date +from unittest.mock import AsyncMock + +import pytest + +from lampyrid.models.firefly_models import ( + Meta, + ObjectLink, + PageLink, + Pagination, + RuleAction, + RuleActionKeyword, + RuleArray, + RuleRead, + RuleSingle, + RuleTrigger, + RuleTriggerKeyword, + RuleTriggerType, + TransactionArray, +) +from lampyrid.models.firefly_models import ( + Rule as RuleAttrs, +) +from lampyrid.models.lampyrid_models import ( + ExecuteRuleRequest, + GetRuleRequest, + SearchRulesRequest, + TestRuleRequest, + UpdateRuleRequest, +) +from lampyrid.services.rules import RuleService + + +def _make_rule_attrs( + title: str = 'Test Rule', + description: str = 'Test Description', + active: bool = True, + strict: bool = True, + stop_processing: bool = False, + trigger_type: str = 'description_contains', + trigger_value: str = 'test', + action_type: str = 'set_category', + action_value: str = 'Test Category', +) -> RuleAttrs: + """Create RuleAttrs for testing.""" + return RuleAttrs( + title=title, + description=description, + rule_group_id='1', + active=active, + strict=strict, + stop_processing=stop_processing, + trigger=RuleTriggerType('store-journal'), + triggers=[ + RuleTrigger( + type=RuleTriggerKeyword(trigger_type), + value=trigger_value, + prohibited=False, + active=True, + ) + ], + actions=[ + RuleAction( + type=RuleActionKeyword(action_type), + value=action_value, + active=True, + ) + ], + ) + + +def _make_rule_read( + rule_id: str = '1', + title: str = 'Test Rule', + **attrs_kwargs, +) -> RuleRead: + """Create RuleRead for testing.""" + return RuleRead( + type='rules', + id=rule_id, + attributes=_make_rule_attrs(title=title, **attrs_kwargs), + links=ObjectLink(self='http://example.com'), + ) + + +def _make_rule_array(rules: list[RuleRead]) -> RuleArray: + """Create RuleArray with pagination.""" + return RuleArray( + data=rules, + meta=Meta( + pagination=Pagination( + total=len(rules), + count=len(rules), + per_page=50, + current_page=1, + total_pages=1, + ) + ), + links=PageLink( + self='http://example.com', + first='http://example.com?page=1', + last='http://example.com?page=1', + ), + ) + + +class TestRuleService: + """Test cases for RuleService class.""" + + @pytest.fixture + def mock_client(self): + """Create a mock FireflyClient.""" + return AsyncMock() + + @pytest.fixture + def service(self, mock_client): + """Create a RuleService with mocked client.""" + return RuleService(mock_client) + + @pytest.mark.asyncio + async def test_search_rules_by_title_contains(self, service, mock_client): + """Test searching rules by title contains.""" + rule1 = _make_rule_read('1', 'Auto-categorize groceries') + rule2 = _make_rule_read('2', 'Manual invoice processing') + mock_client.get_rules.return_value = _make_rule_array([rule1, rule2]) + + req = SearchRulesRequest(title_contains='auto') + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Auto-categorize groceries' + + @pytest.mark.asyncio + async def test_search_rules_by_active_status(self, service, mock_client): + """Test searching rules by active status.""" + rule1 = _make_rule_read('1', 'Active Rule', active=True) + rule2 = _make_rule_read('2', 'Inactive Rule', active=False) + mock_client.get_rules.return_value = _make_rule_array([rule1, rule2]) + + req = SearchRulesRequest(active=True) + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Active Rule' + + @pytest.mark.asyncio + async def test_search_rules_by_trigger_type(self, service, mock_client): + """Test searching rules by trigger type keyword.""" + rule1 = _make_rule_read( + '1', + 'Description Trigger', + trigger_type='description_contains', + ) + rule2 = _make_rule_read('2', 'Amount Trigger', trigger_type='amount_more') + mock_client.get_rules.return_value = _make_rule_array([rule1, rule2]) + + req = SearchRulesRequest(trigger_type='description') + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Description Trigger' + + @pytest.mark.asyncio + async def test_search_rules_by_action_type(self, service, mock_client): + """Test searching rules by action type keyword.""" + rule1 = _make_rule_read( + '1', + 'Set Budget Rule', + action_type='set_budget', + ) + rule2 = _make_rule_read('2', 'Set Category Rule', action_type='set_category') + mock_client.get_rules.return_value = _make_rule_array([rule1, rule2]) + + req = SearchRulesRequest(action_type='budget') + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Set Budget Rule' + + @pytest.mark.asyncio + async def test_search_rules_by_trigger_value_pattern(self, service, mock_client): + """Test searching rules by trigger value regex pattern.""" + rule1 = _make_rule_read( + '1', + 'Groceries Rule', + trigger_value='groceries', + ) + rule2 = _make_rule_read('2', 'Utilities Rule', trigger_value='utilities') + mock_client.get_rules.return_value = _make_rule_array([rule1, rule2]) + + req = SearchRulesRequest(trigger_value_pattern='.*groceries.*') + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Groceries Rule' + + @pytest.mark.asyncio + async def test_search_rules_by_action_value_pattern(self, service, mock_client): + """Test searching rules by action value regex pattern.""" + rule1 = _make_rule_read( + '1', + 'Budget 100 Rule', + action_value='100.00', + ) + rule2 = _make_rule_read('2', 'Budget Text Rule', action_value='Food') + mock_client.get_rules.return_value = _make_rule_array([rule1, rule2]) + + req = SearchRulesRequest(action_value_pattern='^[0-9]+') + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Budget 100 Rule' + + @pytest.mark.asyncio + async def test_search_rules_invalid_regex_trigger(self, service, mock_client): + """Test that invalid trigger regex pattern raises ValueError.""" + mock_client.get_rules.return_value = _make_rule_array([]) + + req = SearchRulesRequest(trigger_value_pattern='[invalid') + with pytest.raises(ValueError, match='Invalid trigger_value_pattern regex'): + await service.search_rules(req) + + @pytest.mark.asyncio + async def test_search_rules_invalid_regex_action(self, service, mock_client): + """Test that invalid action regex pattern raises ValueError.""" + mock_client.get_rules.return_value = _make_rule_array([]) + + req = SearchRulesRequest(action_value_pattern='[invalid') + with pytest.raises(ValueError, match='Invalid action_value_pattern regex'): + await service.search_rules(req) + + @pytest.mark.asyncio + async def test_search_rules_with_pagination(self, service, mock_client): + """Test that search handles pagination correctly.""" + rule1 = _make_rule_read('1', 'Rule 1') + rule2 = _make_rule_read('2', 'Rule 2') + + # Create paginated responses + page1_response = RuleArray( + data=[rule1], + meta=Meta( + pagination=Pagination( + total=2, + count=1, + per_page=1, + current_page=1, + total_pages=2, + ) + ), + links=PageLink( + self='http://example.com?page=1', + first='http://example.com?page=1', + last='http://example.com?page=2', + ), + ) + page2_response = RuleArray( + data=[rule2], + meta=Meta( + pagination=Pagination( + total=2, + count=1, + per_page=1, + current_page=2, + total_pages=2, + ) + ), + links=PageLink( + self='http://example.com?page=2', + first='http://example.com?page=1', + last='http://example.com?page=2', + ), + ) + + mock_client.get_rules.side_effect = [page1_response, page2_response] + + req = SearchRulesRequest(active=True) + result = await service.search_rules(req) + + assert len(result) == 2 + assert result[0].title == 'Rule 1' + assert result[1].title == 'Rule 2' + assert mock_client.get_rules.call_count == 2 + + @pytest.mark.asyncio + async def test_search_rules_with_none_pagination(self, service, mock_client): + """Test that search handles None pagination metadata.""" + rule1 = _make_rule_read('1', 'Rule 1') + response = RuleArray( + data=[rule1], + meta=Meta(pagination=None), + links=PageLink( + self='http://example.com', + first='http://example.com', + last='http://example.com', + ), + ) + mock_client.get_rules.return_value = response + + req = SearchRulesRequest(active=True) + result = await service.search_rules(req) + + assert len(result) == 1 + assert result[0].title == 'Rule 1' + + @pytest.mark.asyncio + async def test_get_rule(self, service, mock_client): + """Test getting a single rule by ID.""" + rule_read = _make_rule_read('42', 'My Rule') + rule_single = RuleSingle(data=rule_read) + mock_client.get_rule.return_value = rule_single + + req = GetRuleRequest(id='42') + result = await service.get_rule(req) + + assert result.id == '42' + assert result.title == 'My Rule' + assert len(result.triggers) == 1 + assert len(result.actions) == 1 + mock_client.get_rule.assert_called_once_with('42') + + @pytest.mark.asyncio + async def test_update_rule_basic_fields(self, service, mock_client): + """Test updating basic rule fields.""" + rule_read = _make_rule_read('42', 'Updated Rule', active=False) + rule_single = RuleSingle(data=rule_read) + mock_client.update_rule.return_value = rule_single + + req = UpdateRuleRequest( + rule_id='42', + title='Updated Rule', + active=False, + ) + result = await service.update_rule(req) + + assert result.id == '42' + assert result.title == 'Updated Rule' + assert result.active is False + mock_client.update_rule.assert_called_once() + + @pytest.mark.asyncio + async def test_update_rule_with_triggers(self, service, mock_client): + """Test updating rule with new triggers.""" + rule_read = _make_rule_read('42', 'Rule with New Triggers') + rule_single = RuleSingle(data=rule_read) + mock_client.update_rule.return_value = rule_single + + req = UpdateRuleRequest( + rule_id='42', + triggers=[ + {'type': 'description_contains', 'value': 'groceries'}, + ], + ) + result = await service.update_rule(req) + + assert result.id == '42' + mock_client.update_rule.assert_called_once() + + # Check the call to verify triggers were converted + call_args = mock_client.update_rule.call_args + rule_update = call_args[0][1] + assert rule_update.triggers is not None + assert len(rule_update.triggers) == 1 + + @pytest.mark.asyncio + async def test_update_rule_invalid_trigger_dict(self, service, mock_client): + """Test that invalid trigger dict raises ValueError.""" + req = UpdateRuleRequest( + rule_id='42', + triggers=[ + {'type': 'not_a_valid_keyword'}, # Invalid enum value + ], + ) + with pytest.raises(ValueError, match='Invalid trigger format'): + await service.update_rule(req) + + @pytest.mark.asyncio + async def test_update_rule_invalid_action_dict(self, service, mock_client): + """Test that invalid action dict raises ValueError.""" + req = UpdateRuleRequest( + rule_id='42', + actions=[ + {'type': 'not_a_valid_keyword'}, # Invalid enum value + ], + ) + with pytest.raises(ValueError, match='Invalid action format'): + await service.update_rule(req) + + @pytest.mark.asyncio + async def test_test_rule(self, service, mock_client): + """Test the test_rule method (preview mode).""" + rule_single = RuleSingle(data=_make_rule_read('42', 'Test Rule')) + mock_client.get_rule.return_value = rule_single + + # Mock empty transaction array + mock_client.test_rule.return_value = TransactionArray( + data=[], + meta=Meta( + pagination=Pagination( + total=0, + count=0, + per_page=50, + current_page=1, + total_pages=1, + ) + ), + links=PageLink( + self='http://example.com', + first='http://example.com', + last='http://example.com', + ), + ) + + req = TestRuleRequest( + rule_id='42', + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + ) + result = await service.test_rule(req) + + assert result.rule_id == '42' + assert result.rule_title == 'Test Rule' + assert result.matched_transaction_count == 0 + assert result.matched_transactions == [] + mock_client.get_rule.assert_called_once_with('42') + mock_client.test_rule.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_rule_without_confirm(self, service, mock_client): + """Test that execute_rule without confirm=True raises ValueError.""" + req = ExecuteRuleRequest( + rule_id='42', + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + confirm=False, + ) + with pytest.raises(ValueError, match='confirm=True'): + await service.execute_rule(req) + + @pytest.mark.asyncio + async def test_execute_rule_with_confirm(self, service, mock_client): + """Test executing a rule with proper confirmation.""" + rule_single = RuleSingle(data=_make_rule_read('42', 'Execute Rule')) + mock_client.get_rule.return_value = rule_single + mock_client.trigger_rule.return_value = True + + req = ExecuteRuleRequest( + rule_id='42', + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + confirm=True, + ) + result = await service.execute_rule(req) + + assert result.rule_id == '42' + assert result.rule_title == 'Execute Rule' + assert result.success is True + assert 'asynchronously' in result.message + mock_client.get_rule.assert_called_once_with('42') + mock_client.trigger_rule.assert_called_once() + + @pytest.mark.asyncio + async def test_search_rules_no_criteria(self): + """Test that search_rules without any criteria raises ValueError.""" + with pytest.raises(ValueError, match='At least one search criterion'): + SearchRulesRequest() + + def test_test_rule_request_rejects_inverted_dates(self): + """Test that TestRuleRequest rejects start_date after end_date.""" + with pytest.raises(ValueError, match='start_date must be on or before end_date'): + TestRuleRequest( + rule_id='1', + start_date=date(2024, 12, 31), + end_date=date(2024, 1, 1), + ) + + def test_execute_rule_request_rejects_inverted_dates(self): + """Test that ExecuteRuleRequest rejects start_date after end_date.""" + with pytest.raises(ValueError, match='start_date must be on or before end_date'): + ExecuteRuleRequest( + rule_id='1', + start_date=date(2024, 12, 31), + end_date=date(2024, 1, 1), + confirm=True, + )