diff --git a/.env.example b/.env.example index cf81685d..e5a7dc1d 100644 --- a/.env.example +++ b/.env.example @@ -34,6 +34,22 @@ LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL # MultiMind Gateway Settings MULTIMIND_LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL +# Gateway Authentication & Security +# Comma-separated API keys for X-API-Key auth (optional) +# API_KEYS=key_one,key_two + +# JWT secret used to sign/verify tokens (required if using /token JWT auth) +# Use a long random value (32+ chars) +# JWT_SECRET=replace_with_a_long_random_secret + +# JSON object mapping username -> bcrypt hashed password (required for /token) +# If passlib is missing, install it: +# pip install passlib[bcrypt] + +# Generate hash example: +# python -c "from passlib.context import CryptContext; c=CryptContext(schemes=['bcrypt'], deprecated='auto'); print(c.hash('YourStrongPassword123!'))" +# JWT_USERS_JSON={"admin":"$2b$12$replace_with_bcrypt_hash","testuser":"$2b$12$replace_with_bcrypt_hash"} + # Optional: Model-specific Settings # Uncomment and modify these if you need custom settings for specific models diff --git a/docs/api_reference/rag_api_server_setup.md b/docs/api_reference/rag_api_server_setup.md index d4b00cd0..7c4f2e26 100644 --- a/docs/api_reference/rag_api_server_setup.md +++ b/docs/api_reference/rag_api_server_setup.md @@ -72,6 +72,68 @@ echo %ANTHROPIC_API_KEY% You should see your API key printed. If nothing shows, the server will use local HuggingFace models. +## Step 2.1: Configure JWT/Auth Security Variables (`.env`) + +If you want to use `/token` JWT auth (production-style), add these values to your `.env`. +This matches `.env.example` (`API_KEYS`, `JWT_SECRET`, `JWT_USERS_JSON`). + +### Required/Optional fields + +- `API_KEYS` (optional): comma-separated API keys for `X-API-Key` auth +- `JWT_SECRET` (required for `/token`): long random secret (32+ chars) +- `JWT_USERS_JSON` (required for `/token`): JSON mapping username -> bcrypt hash + +### 1) Generate `JWT_SECRET` + +PowerShell: +```powershell +python -c "import secrets; print(secrets.token_urlsafe(48))" +``` + +Copy output into: +```env +JWT_SECRET=your_generated_secret_here +``` + +### 2) Generate bcrypt password hash + +Install passlib+bcrypt (if needed): +```bash +pip install passlib[bcrypt] +``` + +Generate hash: +```powershell +python -c "from passlib.context import CryptContext; c=CryptContext(schemes=['bcrypt'], deprecated='auto'); print(c.hash('YourStrongPassword123!'))" +``` + +### 3) Set `JWT_USERS_JSON` + +Use the generated hash in JSON (single line): +```env +JWT_USERS_JSON={"admin":"$2b$12$replace_with_bcrypt_hash","testuser":"$2b$12$replace_with_bcrypt_hash"} +``` + +### 4) Optional `API_KEYS` + +```env +API_KEYS=key_one,key_two,key_three +``` + +### Example secure `.env` block + +```env +OPENAI_API_KEY=your-openai-key +API_KEYS=internal_key_1,internal_key_2 +JWT_SECRET=replace_with_long_random_secret_32plus_chars +JWT_USERS_JSON={"admin":"$2b$12$replace_with_bcrypt_hash"} +``` + +**Important:** +- Do not commit real secrets to git. +- Do not use default/fallback secrets in production. +- If `JWT_SECRET`/`JWT_USERS_JSON` are missing, `/token` auth should be considered not securely configured. + ## Step 3: Start the RAG API Server ### Method 1: Run as module (RECOMMENDED) diff --git a/multimind/api/unified_api.py b/multimind/api/unified_api.py index dab2d180..33c881a1 100644 --- a/multimind/api/unified_api.py +++ b/multimind/api/unified_api.py @@ -5,10 +5,12 @@ from fastapi import FastAPI, HTTPException from typing import Dict, Any import asyncio +import logging from ..models.moe import MoEFactory from ..types import UnifiedRequest, UnifiedResponse app = FastAPI(title="Unified Multi-Modal API") +logger = logging.getLogger(__name__) # Initialize components try: @@ -87,10 +89,13 @@ async def process_request(request: UnifiedRequest): } ) - except Exception as e: + except HTTPException: + raise + except Exception: + logger.exception("Error processing unified API request") raise HTTPException( status_code=500, - detail=f"Error processing request: {str(e)}" + detail="Internal server error" ) @app.get("/v1/models") diff --git a/multimind/compliance/audit.py b/multimind/compliance/audit.py index 4014637e..33cb0341 100644 --- a/multimind/compliance/audit.py +++ b/multimind/compliance/audit.py @@ -3,7 +3,7 @@ """ from typing import List, Dict, Any, Optional -from datetime import datetime +from datetime import datetime, timedelta from pydantic import BaseModel, Field from .governance import GovernanceConfig import json diff --git a/multimind/compliance/data_protection.py b/multimind/compliance/data_protection.py index 21bb070d..ae191092 100644 --- a/multimind/compliance/data_protection.py +++ b/multimind/compliance/data_protection.py @@ -8,6 +8,7 @@ import hashlib import hmac import json +import os from cryptography.fernet import Fernet from .governance import GovernanceConfig, DataCategory @@ -34,7 +35,6 @@ async def protect_data( ) -> Dict[str, Any]: """Protect data according to its category.""" protected_data = { - "original_data": data, "category": category, "metadata": metadata or {}, "protection_applied": [] diff --git a/multimind/compliance/gdpr.py b/multimind/compliance/gdpr.py index 5f17b3d5..cc8dcf2a 100644 --- a/multimind/compliance/gdpr.py +++ b/multimind/compliance/gdpr.py @@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional from datetime import datetime, timedelta from pydantic import BaseModel, Field -from .governance import GovernanceConfig, ComplianceMetadata, DataCategory +from .governance import GovernanceConfig, ComplianceMetadata, DataCategory, Regulation class GDPRCompliance(BaseModel): """GDPR compliance manager.""" diff --git a/multimind/compliance/model_training.py b/multimind/compliance/model_training.py index 5045322c..3f5c3bbd 100644 --- a/multimind/compliance/model_training.py +++ b/multimind/compliance/model_training.py @@ -8,7 +8,7 @@ from datetime import datetime from pydantic import BaseModel, Field import numpy as np -from dataclasses import dataclass +from dataclasses import dataclass, field try: import torch from torch.utils.data import Dataset, DataLoader @@ -27,7 +27,7 @@ class ComplianceMetrics: privacy_score: float transparency_score: float fairness_score: float - timestamp: datetime = Field(default_factory=datetime.utcnow) + timestamp: datetime = field(default_factory=datetime.utcnow) class ComplianceDataset: """Dataset wrapper that ensures compliance during training.""" diff --git a/multimind/core/chat.py b/multimind/core/chat.py index 3d897cf5..9b45b5ff 100644 --- a/multimind/core/chat.py +++ b/multimind/core/chat.py @@ -8,8 +8,7 @@ from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Union -from dataclasses import dataclass, asdict -from pydantic import BaseModel +from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -18,17 +17,17 @@ class ChatMessage(BaseModel): role: str content: str model: str - timestamp: datetime = datetime.now() - metadata: Dict = {} + timestamp: datetime = Field(default_factory=datetime.now) + metadata: Dict = Field(default_factory=dict) class ChatSession(BaseModel): """A chat session with history and metadata""" session_id: str model: str - created_at: datetime = datetime.now() - updated_at: datetime = datetime.now() - messages: List[ChatMessage] = [] - metadata: Dict = {} + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + messages: List[ChatMessage] = Field(default_factory=list) + metadata: Dict = Field(default_factory=dict) system_prompt: Optional[str] = None def add_message(self, role: str, content: str, model: str, metadata: Optional[Dict[str, Union[str, int, float]]] = None) -> None: @@ -50,10 +49,11 @@ def get_context(self, max_messages: int = 10) -> List[Dict[str, str]]: def export(self, format: str = "json") -> Union[str, Dict]: """Export session to different formats""" + session_data = self.model_dump(mode="json") if format == "json": - return json.dumps(asdict(self), default=str) + return json.dumps(session_data) elif format == "dict": - return asdict(self) + return session_data else: raise ValueError(f"Unsupported export format: {format}") @@ -75,7 +75,7 @@ def save(self, directory: Union[str, Path]) -> Path: directory.mkdir(parents=True, exist_ok=True) file_path = directory / f"chat_{self.session_id}.json" with open(file_path, "w") as f: - json.dump(asdict(self), f, default=str, indent=2) + json.dump(self.model_dump(mode="json"), f, indent=2) return file_path class ChatManager: diff --git a/multimind/core/models.py b/multimind/core/models.py index 97055751..719500e9 100644 --- a/multimind/core/models.py +++ b/multimind/core/models.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from typing import Dict, List, Optional, Union -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime logger = logging.getLogger(__name__) @@ -18,7 +18,7 @@ class ModelResponse: model: str usage: Optional[Dict[str, int]] = None finish_reason: Optional[str] = None - timestamp: str = datetime.now().isoformat() + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) class ModelHandler(ABC): """Abstract base class for model handlers""" diff --git a/multimind/core/provider.py b/multimind/core/provider.py index 6a7301ce..710189a1 100644 --- a/multimind/core/provider.py +++ b/multimind/core/provider.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Union, Any from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field from datetime import datetime class ProviderCapability(str, Enum): @@ -49,8 +49,8 @@ class GenerationResult(BaseModel): model_name: str latency_ms: float cost_estimate_usd: float - metadata: Dict[str, Any] = {} - created_at: datetime = datetime.now() + metadata: Dict[str, Any] = Field(default_factory=dict) + created_at: datetime = Field(default_factory=datetime.now) class EmbeddingResult(BaseModel): """Standardized result from embeddings generation.""" @@ -60,7 +60,7 @@ class EmbeddingResult(BaseModel): model_name: str latency_ms: float cost_estimate_usd: float - metadata: Dict[str, Any] = {} + metadata: Dict[str, Any] = Field(default_factory=dict) class ImageAnalysisResult(BaseModel): """Standardized result from image analysis.""" @@ -71,7 +71,7 @@ class ImageAnalysisResult(BaseModel): model_name: str latency_ms: float cost_estimate_usd: float - metadata: Dict[str, Any] = {} + metadata: Dict[str, Any] = Field(default_factory=dict) class ProviderAdapter(ABC): """Base class for provider adapters.""" diff --git a/multimind/core/router.py b/multimind/core/router.py index 9dd6c96a..ea7a7f7c 100644 --- a/multimind/core/router.py +++ b/multimind/core/router.py @@ -234,52 +234,55 @@ async def _handle_single_provider( call_kwargs = dict(kwargs) call_kwargs.pop("provider", None) model_arg = call_kwargs.pop("model", None) - start = time.time() - try: - if task_type == TaskType.TEXT_GENERATION: - if model_arg is not None: - result = await provider.generate_text(model=model_arg, prompt=input_data, **call_kwargs) - else: - result = await provider.generate_text(prompt=input_data, **call_kwargs) - elif task_type == TaskType.EMBEDDINGS: - if model_arg is not None: - result = await provider.generate_embeddings(text=input_data, model=model_arg, **call_kwargs) - else: - result = await provider.generate_embeddings(text=input_data, **call_kwargs) - elif task_type == TaskType.IMAGE_ANALYSIS: - if model_arg is not None: - result = await provider.analyze_image(image_data=input_data, model=model_arg, **call_kwargs) + max_attempts = self.fallback_policy.max_retries + 1 if self.fallback_policy.strategy == "retry" else 1 + last_error = None + + for _ in range(max_attempts): + start = time.time() + try: + if task_type == TaskType.TEXT_GENERATION: + if model_arg is not None: + result = await provider.generate_text(model=model_arg, prompt=input_data, **call_kwargs) + else: + result = await provider.generate_text(prompt=input_data, **call_kwargs) + elif task_type == TaskType.EMBEDDINGS: + if model_arg is not None: + result = await provider.generate_embeddings(text=input_data, model=model_arg, **call_kwargs) + else: + result = await provider.generate_embeddings(text=input_data, **call_kwargs) + elif task_type == TaskType.IMAGE_ANALYSIS: + if model_arg is not None: + result = await provider.analyze_image(image_data=input_data, model=model_arg, **call_kwargs) + else: + result = await provider.analyze_image(image_data=input_data, **call_kwargs) else: - result = await provider.analyze_image(image_data=input_data, **call_kwargs) - else: - raise ValueError(f"Unsupported task type: {task_type}") - latency = time.time() - start - quality = getattr(result, 'quality', None) or (result.metadata.get('quality') if hasattr(result, 'metadata') else None) - feedback = getattr(result, 'feedback', None) or (result.metadata.get('feedback') if hasattr(result, 'metadata') else None) - self.performance_tracker.record(provider_name, success=True, latency=latency, quality=quality, feedback=feedback) - return result - except Exception as e: - latency = time.time() - start - self.performance_tracker.record(provider_name, success=False, latency=latency) - self.fallback_policy.record_failure(provider_name) - # Centralized fallback logic - if self.fallback_policy.strategy == "retry" and self.fallback_policy.failure_counts[provider_name] <= self.fallback_policy.max_retries: - # Retry the same provider - return await self._handle_single_provider(task_type, input_data, config, use_adaptive_routing, **kwargs) - elif self.fallback_policy.strategy == "switch_provider" and len(config.preferred_providers) > 1: - # Switch to next best provider - remaining = [p for p in config.preferred_providers if p != provider_name] - if remaining: - next_provider = self.performance_tracker.get_best_provider(remaining) - if self.fallback_policy.notify_user: - print(self.fallback_policy.get_fallback_message(provider_name, e)) - # Try next provider - config_copy = config.copy() - config_copy.preferred_providers = remaining - return await self._handle_single_provider(task_type, input_data, config_copy, use_adaptive_routing, **kwargs) - if self.fallback_policy.notify_user: - print(self.fallback_policy.get_fallback_message(provider_name, e)) - raise + raise ValueError(f"Unsupported task type: {task_type}") + latency = time.time() - start + quality = getattr(result, 'quality', None) or (result.metadata.get('quality') if hasattr(result, 'metadata') else None) + feedback = getattr(result, 'feedback', None) or (result.metadata.get('feedback') if hasattr(result, 'metadata') else None) + self.performance_tracker.record(provider_name, success=True, latency=latency, quality=quality, feedback=feedback) + return result + except Exception as e: + latency = time.time() - start + self.performance_tracker.record(provider_name, success=False, latency=latency) + self.fallback_policy.record_failure(provider_name) + last_error = e + + # Centralized fallback logic after retry attempts are exhausted. + if self.fallback_policy.strategy == "switch_provider" and len(config.preferred_providers) > 1: + # Switch to next best provider + remaining = [p for p in config.preferred_providers if p != provider_name] + if remaining: + next_provider = self.performance_tracker.get_best_provider(remaining) + if self.fallback_policy.notify_user: + print(self.fallback_policy.get_fallback_message(provider_name, last_error)) + # Try next provider + config_copy = config.copy() + config_copy.preferred_providers = remaining + return await self._handle_single_provider(task_type, input_data, config_copy, use_adaptive_routing, **kwargs) + if self.fallback_policy.notify_user: + print(self.fallback_policy.get_fallback_message(provider_name, last_error)) + raise last_error async def _handle_ensemble( self, diff --git a/multimind/gateway/rag_api.py b/multimind/gateway/rag_api.py index b1ee1687..0ede8dc1 100644 --- a/multimind/gateway/rag_api.py +++ b/multimind/gateway/rag_api.py @@ -8,7 +8,6 @@ import logging import json import tempfile -import hashlib from typing import List, Dict, Any, Optional from pathlib import Path from datetime import datetime, timedelta @@ -19,7 +18,7 @@ from pydantic import BaseModel, Field import jwt -# Try to import passlib for password hashing, fallback to simple hash +# Require passlib for secure password hashing. try: from passlib.context import CryptContext pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -63,31 +62,41 @@ # Password hashing helper def hash_password(password: str) -> str: """Hash a password.""" - if HAS_PASSLIB: - return pwd_context.hash(password) - else: - # Simple hash fallback (not secure, but works for development) - return hashlib.sha256(password.encode()).hexdigest() + if not HAS_PASSLIB: + raise RuntimeError("passlib[bcrypt] is required for secure password hashing") + return pwd_context.hash(password) def verify_password(password: str, hashed: str) -> bool: """Verify a password.""" - if HAS_PASSLIB: - return pwd_context.verify(password, hashed) - else: - # Simple hash verification fallback - return hashlib.sha256(password.encode()).hexdigest() == hashed + if not HAS_PASSLIB: + raise RuntimeError("passlib[bcrypt] is required for secure password verification") + return pwd_context.verify(password, hashed) # Get API keys from environment API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else [] -JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key-change-in-production") +JWT_SECRET = os.getenv("JWT_SECRET") JWT_ALGORITHM = "HS256" JWT_EXPIRATION_MINUTES = 30 -# Default users for JWT authentication (in production, use a database) -DEFAULT_USERS = { - "testuser": hash_password("secret"), - "admin": hash_password("admin123") -} +def _load_jwt_users() -> Dict[str, str]: + """ + Load JWT users from environment variable JWT_USERS_JSON. + Format: {"username":"hashed_password", ...} + """ + raw_users = os.getenv("JWT_USERS_JSON") + if not raw_users: + return {} + try: + users = json.loads(raw_users) + if isinstance(users, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in users.items()): + return users + logger.error("JWT_USERS_JSON must be a JSON object of username->hashed_password") + return {} + except json.JSONDecodeError: + logger.error("JWT_USERS_JSON is not valid JSON") + return {} + +JWT_USERS = _load_jwt_users() # Global RAG instance and model rag_instance: Optional[RAG] = None @@ -162,6 +171,8 @@ def verify_api_key(api_key: Optional[str] = Header(None, alias="X-API-Key")) -> def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)) -> Dict[str, Any]: """Verify JWT token.""" + if not JWT_SECRET: + raise HTTPException(status_code=503, detail="JWT authentication is not configured") if not credentials: raise HTTPException(status_code=401, detail="Authorization header required") @@ -298,11 +309,15 @@ async def startup_event(): @app.post("/token", response_model=TokenResponse) async def login(username: str = Form(...), password: str = Form(...)): """Get JWT token for authentication.""" - # In production, verify against database - if username not in DEFAULT_USERS: + if not JWT_SECRET: + raise HTTPException(status_code=503, detail="JWT authentication is not configured") + if not JWT_USERS: + raise HTTPException(status_code=503, detail="No JWT users configured") + + if username not in JWT_USERS: raise HTTPException(status_code=401, detail="Invalid username or password") - if not verify_password(password, DEFAULT_USERS[username]): + if not verify_password(password, JWT_USERS[username]): raise HTTPException(status_code=401, detail="Invalid username or password") # Create token diff --git a/multimind/integrations/model_adapters.py b/multimind/integrations/model_adapters.py index b520a85c..0a26e2d2 100644 --- a/multimind/integrations/model_adapters.py +++ b/multimind/integrations/model_adapters.py @@ -2,19 +2,19 @@ Integration adapters for fine-tuned models to work with various frameworks. """ -from typing import List, Dict, Any, Optional, Union, Tuple, Se +from typing import List, Dict, Any, Optional, Union, Tuple, Sequence import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer from langchain.llms.base import LLM from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.embeddings.base import Embeddings -from langchain.schema import Documen +from langchain.schema import Document from lite_llm import LiteLLM -from superagi.agent import Agen +from superagi.agent import Agent from superagi.tools import Tool from semantic_kernel import Kernel, KernelFunction -from crewai import Agent as CrewAgen +from crewai import Agent as CrewAgent from crewai import Task import logging from ..fine_tuning import ( @@ -254,7 +254,7 @@ def invoke( # Update context with response context["response"] = response - return contex + return context class CrewAIAdapter(CrewAgent, BaseModelAdapter): """Adapter for CrewAI integration.""" diff --git a/multimind/mcp/advanced_executor.py b/multimind/mcp/advanced_executor.py index 3b121b56..7d884bb5 100644 --- a/multimind/mcp/advanced_executor.py +++ b/multimind/mcp/advanced_executor.py @@ -32,14 +32,13 @@ def __init__( retry_delay: float = 1.0 ): self.parser = parser or MCPParser() - self.metrics_collector = self.MetricsCollector() - self.model_registry = { + self.metrics_collector = metrics_collector or self.MetricsCollector() + self.model_registry = model_registry or { "ollama": OllamaModel(), "openai": OpenAIModel(), "claude": ClaudeModel(), "gemini": GeminiModel() } - print(f"Model registry contents: {self.model_registry}") self.max_retries = max_retries self.retry_delay = retry_delay self.workflow_state: Dict[str, Any] = {} diff --git a/multimind/memory/active_learning.py b/multimind/memory/active_learning.py index 514a7224..6ff3846e 100644 --- a/multimind/memory/active_learning.py +++ b/multimind/memory/active_learning.py @@ -47,7 +47,6 @@ def __init__( self.reinforcement: Dict[str, List[Dict[str, Any]]] = {} # item_id -> reinforcement data self.last_analysis = datetime.now() self.last_optimization = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and track feedback.""" @@ -210,7 +209,7 @@ async def _remove_item(self, item_id: str) -> None: if item_id in self.reinforcement: del self.reinforcement[item_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all items.""" messages = [] for item in self.items: @@ -241,7 +240,7 @@ async def save(self) -> None: "last_optimization": self.last_optimization.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load items and feedback from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/associative.py b/multimind/memory/associative.py index 1209edde..2ad079ec 100644 --- a/multimind/memory/associative.py +++ b/multimind/memory/associative.py @@ -88,7 +88,6 @@ def __init__( self.last_cluster_update = datetime.now() self.last_analysis = datetime.now() self.last_evolution = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message as new association.""" @@ -407,7 +406,7 @@ async def _remove_association(self, association_id: str) -> None: if association_id in self.learning_history: del self.learning_history[association_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all associations.""" messages = [] for association in self.associations: @@ -448,7 +447,7 @@ async def save(self) -> None: "last_evolution": self.last_evolution.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load associations from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/base.py b/multimind/memory/base.py index 41603096..e0e1b0ed 100644 --- a/multimind/memory/base.py +++ b/multimind/memory/base.py @@ -3,7 +3,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Any from datetime import datetime class BaseMemory(ABC): @@ -14,26 +14,26 @@ def __init__(self, memory_key: str = "chat_history"): self.created_at = datetime.now() @abstractmethod - def add_message(self, message: Dict[str, str]) -> None: + async def add_message(self, message: Dict[str, str]) -> None: """Add a message to memory.""" pass @abstractmethod - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from memory.""" pass @abstractmethod - def clear(self) -> None: + async def clear(self) -> None: """Clear all messages from memory.""" pass @abstractmethod - def save(self) -> None: + async def save(self) -> None: """Save memory to persistent storage.""" pass @abstractmethod - def load(self) -> None: + async def load(self) -> None: """Load memory from persistent storage.""" pass \ No newline at end of file diff --git a/multimind/memory/buffer.py b/multimind/memory/buffer.py index 7d6dd3d3..7864911b 100644 --- a/multimind/memory/buffer.py +++ b/multimind/memory/buffer.py @@ -50,9 +50,7 @@ def __init__( self.last_backup = datetime.now() self.backup_history: List[Dict[str, Any]] = [] - # Load if storage path exists - if self.storage_path and self.storage_path.exists(): - self.load() + # Load explicitly via await memory.load() when needed. async def add_message( self, @@ -94,7 +92,7 @@ async def add_message( if self.enable_backup and (datetime.now() - self.last_backup).total_seconds() >= self.backup_interval: await self._backup() - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from the buffer.""" return self.messages @@ -108,7 +106,7 @@ def get_messages_with_metadata(self) -> List[Dict[str, Any]]: for i, msg in enumerate(self.messages) ] - def clear(self) -> None: + async def clear(self) -> None: """Clear all messages from the buffer.""" self.messages = [] self.message_tokens = [] @@ -117,7 +115,7 @@ def clear(self) -> None: if self.storage_path and self.storage_path.exists(): self.storage_path.unlink() - def save(self) -> None: + async def save(self) -> None: """Save buffer to persistent storage.""" if not self.storage_path: return @@ -135,7 +133,7 @@ def save(self) -> None: with open(self.storage_path, "w") as f: json.dump(data, f) - def load(self) -> None: + async def load(self) -> None: """Load buffer from persistent storage.""" if not self.storage_path or not self.storage_path.exists(): return @@ -152,7 +150,7 @@ def load(self) -> None: self.backup_history = data["backup_history"] except Exception as e: print(f"Error loading buffer: {e}") - self.clear() + await self.clear() def _remove_oldest(self) -> None: """Remove the oldest message based on strategy.""" @@ -173,10 +171,19 @@ def _remove_oldest(self) -> None: self.metadata = new_metadata elif self.strategy == "lru": - # Remove least recently used message - # This is a simplified implementation - # In practice, you would track access times - self._remove_oldest() + # Remove least recently used message. + # Access-time tracking is not available yet, so use FIFO eviction + # as a safe, non-recursive fallback. + self.total_tokens -= self.message_tokens[0] + self.messages.pop(0) + self.message_tokens.pop(0) + if self.enable_metadata: + self.metadata.pop("0", None) + # Shift metadata indices + new_metadata = {} + for i in range(len(self.messages)): + new_metadata[str(i)] = self.metadata.get(str(i + 1), {}) + self.metadata = new_metadata else: # sliding # Remove messages from start until we have space @@ -254,7 +261,7 @@ async def _backup(self) -> None: # Save to disk if storage path exists if self.storage_path: - self.save() + await self.save() def get_stats(self) -> Dict[str, Any]: """Get buffer statistics.""" diff --git a/multimind/memory/cognitive_scratchpad.py b/multimind/memory/cognitive_scratchpad.py index 0737b674..cc324cbf 100644 --- a/multimind/memory/cognitive_scratchpad.py +++ b/multimind/memory/cognitive_scratchpad.py @@ -47,7 +47,6 @@ def __init__( self.reasoning_chains: Dict[str, List[Dict[str, Any]]] = {} # chain_id -> chain data self.last_analysis = datetime.now() self.last_optimization = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and track reasoning steps.""" @@ -206,7 +205,7 @@ async def _remove_item(self, item_id: str) -> None: s for s in chain_data if s["item_id"] != item_id ] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all items.""" messages = [] for item in self.items: @@ -237,7 +236,7 @@ async def save(self) -> None: "last_optimization": self.last_optimization.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load items and steps from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/combined.py b/multimind/memory/combined.py index df74b315..f6c4c768 100644 --- a/multimind/memory/combined.py +++ b/multimind/memory/combined.py @@ -17,32 +17,33 @@ def __init__( super().__init__(memory_key) self.memories = memories - def add_message(self, message: Dict[str, str]) -> None: + async def add_message(self, message: Dict[str, str]) -> None: """Add message to all memory types.""" for memory in self.memories: - memory.add_message(message) + await memory.add_message(message) - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get messages from all memory types.""" all_messages = [] for memory in self.memories: - all_messages.extend(memory.get_messages()) + msgs = await memory.get_messages() + all_messages.extend(msgs) return all_messages - def clear(self) -> None: + async def clear(self) -> None: """Clear all memory types.""" for memory in self.memories: - memory.clear() + await memory.clear() - def save(self) -> None: + async def save(self) -> None: """Save all memory types.""" for memory in self.memories: - memory.save() + await memory.save() - def load(self) -> None: + async def load(self) -> None: """Load all memory types.""" for memory in self.memories: - memory.load() + await memory.load() def get_memory(self, memory_type: type) -> Optional[BaseMemory]: """Get a specific memory type instance.""" diff --git a/multimind/memory/consensus.py b/multimind/memory/consensus.py index cf1d6b2b..19b01c65 100644 --- a/multimind/memory/consensus.py +++ b/multimind/memory/consensus.py @@ -76,10 +76,33 @@ def __init__( self.consensus_rounds = 0 self.leader_changes = 0 self.last_heartbeat = datetime.now() - - # Start background tasks - asyncio.create_task(self._run_election_timer()) - asyncio.create_task(self._run_heartbeat()) + self._running = False + self._election_task: Optional[asyncio.Task] = None + self._heartbeat_task: Optional[asyncio.Task] = None + + async def start_background_tasks(self) -> None: + """Start background RAFT tasks when an event loop is available.""" + if self._running: + return + self._running = True + self._election_task = asyncio.create_task(self._run_election_timer()) + self._heartbeat_task = asyncio.create_task(self._run_heartbeat()) + + async def stop_background_tasks(self) -> None: + """Stop background RAFT tasks gracefully.""" + self._running = False + tasks = [t for t in [self._election_task, self._heartbeat_task] if t is not None] + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + self._election_task = None + self._heartbeat_task = None + + async def _ensure_background_tasks_started(self) -> None: + """Lazily start background tasks from async call sites.""" + if not self._running: + await self.start_background_tasks() async def add_memory( self, @@ -88,6 +111,7 @@ async def add_memory( metadata: Optional[Dict[str, Any]] = None ) -> None: """Add a new memory through consensus.""" + await self._ensure_background_tasks_started() if self.state == NodeState.LEADER: # Create log entry entry = LogEntry( @@ -133,6 +157,7 @@ async def update_memory( updates: Dict[str, Any] ) -> None: """Update a memory through consensus.""" + await self._ensure_background_tasks_started() if self.state == NodeState.LEADER: # Create log entry entry = LogEntry( @@ -163,6 +188,7 @@ async def update_memory( async def remove_memory(self, memory_id: str) -> None: """Remove a memory through consensus.""" + await self._ensure_background_tasks_started() if self.state == NodeState.LEADER: # Create log entry entry = LogEntry( @@ -213,19 +239,25 @@ async def get_stats(self) -> Dict[str, Any]: async def _run_election_timer(self) -> None: """Run election timer for leader election.""" - while True: - if self.state != NodeState.LEADER: - # Check if election timeout - if (datetime.now() - self.last_heartbeat).total_seconds() > self.election_timeout: - await self._start_election() - await asyncio.sleep(self.election_timeout) + try: + while self._running: + if self.state != NodeState.LEADER: + # Check if election timeout + if (datetime.now() - self.last_heartbeat).total_seconds() > self.election_timeout: + await self._start_election() + await asyncio.sleep(self.election_timeout) + except asyncio.CancelledError: + return async def _run_heartbeat(self) -> None: """Run heartbeat for leader.""" - while True: - if self.state == NodeState.LEADER: - await self._send_heartbeat() - await asyncio.sleep(self.heartbeat_interval) + try: + while self._running: + if self.state == NodeState.LEADER: + await self._send_heartbeat() + await asyncio.sleep(self.heartbeat_interval) + except asyncio.CancelledError: + return async def _start_election(self) -> None: """Start leader election.""" diff --git a/multimind/memory/contextual.py b/multimind/memory/contextual.py index 88aba843..818037dc 100644 --- a/multimind/memory/contextual.py +++ b/multimind/memory/contextual.py @@ -64,7 +64,6 @@ def __init__( self.context_summaries: Dict[str, str] = {} # context_id -> summary self.context_evolution: Dict[str, List[Dict[str, Any]]] = {} # context_id -> evolution history self.last_summarization = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message to the appropriate context.""" @@ -423,7 +422,7 @@ async def _remove_context(self, context_id: str) -> None: del self.context_metadata[context_id] del self.context_weights[context_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all contexts.""" messages = [] for context in self.contexts: @@ -465,7 +464,7 @@ async def save(self) -> None: "last_summarization": self.last_summarization.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load contexts from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: @@ -540,9 +539,10 @@ async def traverse(current_id: str, depth: int) -> None: async def get_context_stats(self) -> Dict[str, Any]: """Get statistics about contexts.""" + messages = await self.get_messages() stats = { "total_contexts": len(self.contexts), - "total_messages": len(self.get_messages()), + "total_messages": len(messages), "relationship_types": { rel_type: 0 for rel_type in self.relationship_types }, diff --git a/multimind/memory/declarative.py b/multimind/memory/declarative.py index ba3df815..bba71e52 100644 --- a/multimind/memory/declarative.py +++ b/multimind/memory/declarative.py @@ -130,7 +130,6 @@ def __init__( self.last_temporal = datetime.now() self.last_causal = datetime.now() self.last_graph_update = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and analyze factual information.""" @@ -649,7 +648,7 @@ async def _remove_fact(self, fact_id: str) -> None: if fact_id in self.validation_history: del self.validation_history[fact_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all facts.""" messages = [] for fact in self.facts: @@ -706,7 +705,7 @@ async def save(self) -> None: "last_graph_update": self.last_graph_update.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load facts from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/dnc.py b/multimind/memory/dnc.py index c992c9e6..9c9e23bb 100644 --- a/multimind/memory/dnc.py +++ b/multimind/memory/dnc.py @@ -87,7 +87,6 @@ def __init__( self.last_optimization = datetime.now() self.last_analysis = datetime.now() self.last_backup = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message to DNC memory.""" @@ -368,7 +367,7 @@ async def _create_backup(self) -> None: except Exception as e: print(f"Error creating backup: {e}") - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from memory.""" messages = [] for item in self.items: @@ -426,7 +425,7 @@ async def save(self) -> None: "last_backup": self.last_backup.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load memory from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/emotional.py b/multimind/memory/emotional.py index a1ce03e2..ed9d6abe 100644 --- a/multimind/memory/emotional.py +++ b/multimind/memory/emotional.py @@ -90,7 +90,6 @@ def __init__( self.last_pattern_update = datetime.now() self.last_evolution = datetime.now() self.last_cluster_update = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and analyze emotional state.""" @@ -178,6 +177,75 @@ async def add_message(self, message: Dict[str, str]) -> None: await self.save() + async def get_messages(self) -> List[Dict[str, str]]: + """Get all emotional-memory states.""" + return self.states + + async def clear(self) -> None: + """Clear all emotional-memory state.""" + self.states = [] + self.state_embeddings = [] + self.emotion_patterns = {} + self.adaptation_history = {} + self.learning_history = {} + self.emotion_history = [] + self.evolution_history = {} + self.relationships = {} + self.clusters = {} + self.last_analysis = datetime.now() + self.last_pattern_update = datetime.now() + self.last_evolution = datetime.now() + self.last_cluster_update = datetime.now() + if self.storage_path and self.storage_path.exists(): + self.storage_path.unlink() + + async def save(self) -> None: + """Persist emotional-memory state.""" + if not self.storage_path: + return + + data = { + "states": self.states, + "state_embeddings": self.state_embeddings, + "emotion_patterns": self.emotion_patterns, + "adaptation_history": self.adaptation_history, + "learning_history": self.learning_history, + "emotion_history": self.emotion_history, + "evolution_history": self.evolution_history, + "relationships": self.relationships, + "clusters": self.clusters, + "last_analysis": self.last_analysis.isoformat(), + "last_pattern_update": self.last_pattern_update.isoformat(), + "last_evolution": self.last_evolution.isoformat(), + "last_cluster_update": self.last_cluster_update.isoformat(), + } + + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.storage_path, "w") as f: + json.dump(data, f) + + async def load(self) -> None: + """Load emotional-memory state from disk.""" + if not self.storage_path or not self.storage_path.exists(): + return + + with open(self.storage_path, "r") as f: + data = json.load(f) + + self.states = data.get("states", []) + self.state_embeddings = data.get("state_embeddings", []) + self.emotion_patterns = data.get("emotion_patterns", {}) + self.adaptation_history = data.get("adaptation_history", {}) + self.learning_history = data.get("learning_history", {}) + self.emotion_history = data.get("emotion_history", []) + self.evolution_history = data.get("evolution_history", {}) + self.relationships = data.get("relationships", {}) + self.clusters = data.get("clusters", {}) + self.last_analysis = datetime.fromisoformat(data.get("last_analysis", datetime.now().isoformat())) + self.last_pattern_update = datetime.fromisoformat(data.get("last_pattern_update", datetime.now().isoformat())) + self.last_evolution = datetime.fromisoformat(data.get("last_evolution", datetime.now().isoformat())) + self.last_cluster_update = datetime.fromisoformat(data.get("last_cluster_update", datetime.now().isoformat())) + async def _find_relationships(self, state_id: str) -> None: """Find relationships between emotional states.""" state = next(s for s in self.states if s["id"] == state_id) diff --git a/multimind/memory/episodic.py b/multimind/memory/episodic.py index d1fa11ba..bbada3cd 100644 --- a/multimind/memory/episodic.py +++ b/multimind/memory/episodic.py @@ -62,7 +62,6 @@ def __init__( self.episode_importance: Dict[str, float] = {} # episode_id -> importance score self.emotional_profiles: Dict[str, Dict[str, float]] = {} # episode_id -> emotion -> intensity self.last_consolidation = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message as a new episode with context.""" @@ -395,7 +394,7 @@ async def _remove_episode(self, episode_id: str) -> None: # Remove weight del self.episode_weights[episode_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all episodes.""" messages = [] for episode in self.episodes: @@ -440,7 +439,7 @@ async def save(self) -> None: "last_consolidation": self.last_consolidation.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load episodes from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/event_sourced.py b/multimind/memory/event_sourced.py index f7b1033b..001030f1 100644 --- a/multimind/memory/event_sourced.py +++ b/multimind/memory/event_sourced.py @@ -52,7 +52,6 @@ def __init__( self.causal_chains: Dict[str, List[Dict[str, Any]]] = {} # chain_id -> causal chain self.last_analysis = datetime.now() self.last_optimization = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and create events.""" @@ -351,7 +350,7 @@ async def _remove_item(self, item_id: str) -> None: c for c in chain_data if c["item_id"] != item_id ] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all items.""" messages = [] for item in self.items: @@ -384,7 +383,7 @@ async def save(self) -> None: "last_optimization": self.last_optimization.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load items and events from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/explicit.py b/multimind/memory/explicit.py index 740be0d5..73e4d96e 100644 --- a/multimind/memory/explicit.py +++ b/multimind/memory/explicit.py @@ -5,6 +5,7 @@ from typing import Dict, Any, Optional, List, Set, Tuple from datetime import datetime, timedelta import numpy as np +import networkx as nx from .base import BaseMemory from .declarative import DeclarativeMemory from .semantic import SemanticMemory diff --git a/multimind/memory/forgetting_curve.py b/multimind/memory/forgetting_curve.py index 8081060a..bb6d8e80 100644 --- a/multimind/memory/forgetting_curve.py +++ b/multimind/memory/forgetting_curve.py @@ -79,7 +79,6 @@ def __init__( self.last_adaptive = datetime.now() self.last_consolidation = datetime.now() self.last_optimization = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and initialize forgetting curve.""" @@ -337,7 +336,7 @@ async def _remove_item(self, item_id: str) -> None: if item_id in self.interference_graph: del self.interference_graph[item_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all items.""" messages = [] for item in self.items: @@ -376,7 +375,7 @@ async def save(self) -> None: "last_optimization": self.last_optimization.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load items from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/hierarchical.py b/multimind/memory/hierarchical.py index 2d59e258..489e5a9d 100644 --- a/multimind/memory/hierarchical.py +++ b/multimind/memory/hierarchical.py @@ -60,7 +60,6 @@ def __init__( self.node_map: Dict[str, Dict[str, Any]] = {"root": self.root} self.category_embeddings: Dict[str, List[float]] = {} self.semantic_index: Dict[str, Set[str]] = {} # tag -> node_ids - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message to the appropriate node in the hierarchy.""" @@ -101,7 +100,7 @@ async def add_message(self, message: Dict[str, str]) -> None: await self._maintain_hierarchy() await self.save() - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from the hierarchy.""" messages = [] self._collect_messages(self.root, messages) @@ -129,7 +128,7 @@ async def save(self) -> None: "node_map": self.node_map }, f) - def load(self) -> None: + async def load(self) -> None: """Load hierarchy from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: @@ -720,9 +719,10 @@ async def get_lifecycle_stats(self) -> Dict[str, Any]: async def get_hierarchy_stats(self) -> Dict[str, Any]: """Get statistics about the hierarchy.""" + messages = await self.get_messages() stats = { "total_nodes": len(self.node_map), - "total_messages": len(self.get_messages()), + "total_messages": len(messages), "max_depth": max(self._get_node_depth(node_id) for node_id in self.node_map), "node_distribution": {}, "message_distribution": {}, diff --git a/multimind/memory/hybrid.py b/multimind/memory/hybrid.py index 1dde5506..6a52ba4d 100644 --- a/multimind/memory/hybrid.py +++ b/multimind/memory/hybrid.py @@ -125,7 +125,6 @@ def __init__( # Initialize memories self._initialize_memories() - self.load() def _initialize_memories(self) -> None: """Initialize memory instances.""" @@ -547,7 +546,7 @@ async def _create_backup(self) -> None: except Exception as e: print(f"Error creating backup: {e}") - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all memory types.""" all_messages = [] for memory in self.memories.values(): @@ -595,7 +594,7 @@ async def save(self) -> None: "last_evolution": self.last_evolution.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load memory state from persistent storage.""" if self.storage_path and (self.storage_path / "hybrid_memory.json").exists(): with open(self.storage_path / "hybrid_memory.json", 'r') as f: diff --git a/multimind/memory/knowledge_graph.py b/multimind/memory/knowledge_graph.py index abf51668..9f3f9aa2 100644 --- a/multimind/memory/knowledge_graph.py +++ b/multimind/memory/knowledge_graph.py @@ -32,7 +32,6 @@ def __init__( ] self.graph = nx.DiGraph() self.messages: List[Dict[str, str]] = [] - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and extract entities and relationships.""" @@ -54,7 +53,7 @@ async def add_message(self, message: Dict[str, str]) -> None: await self.save() - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages.""" return self.messages @@ -74,7 +73,7 @@ async def save(self) -> None: "graph": nx.node_link_data(self.graph) }, f) - def load(self) -> None: + async def load(self) -> None: """Load messages and graph from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/novelty.py b/multimind/memory/novelty.py index 7a944b06..47950f3a 100644 --- a/multimind/memory/novelty.py +++ b/multimind/memory/novelty.py @@ -86,7 +86,6 @@ def __init__( self.last_salience = datetime.now() self.last_optimization = datetime.now() self.last_adaptation = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and calculate novelty/salience scores.""" @@ -465,7 +464,7 @@ async def _remove_item(self, item_id: str) -> None: if item_id in self.temporal_windows: del self.temporal_windows[item_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all items.""" messages = [] for item in self.items: @@ -510,7 +509,7 @@ async def save(self) -> None: "last_adaptation": self.last_adaptation.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load items from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/procedural.py b/multimind/memory/procedural.py index 08287c20..7d29c953 100644 --- a/multimind/memory/procedural.py +++ b/multimind/memory/procedural.py @@ -70,7 +70,6 @@ def __init__( self.last_optimization = datetime.now() self.last_validation = datetime.now() self.last_monitoring = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message as new procedural knowledge.""" @@ -419,7 +418,7 @@ async def _remove_procedure(self, procedure_id: str) -> None: if procedure_id in self.optimization_cache: del self.optimization_cache[procedure_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all procedures.""" messages = [] for procedure in self.procedures: @@ -468,7 +467,7 @@ async def save(self) -> None: "last_monitoring": self.last_monitoring.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load procedures from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/redis.py b/multimind/memory/redis.py index 4f1bb9e7..42b449b4 100644 --- a/multimind/memory/redis.py +++ b/multimind/memory/redis.py @@ -21,7 +21,7 @@ def __init__( self.redis_client = redis.from_url(redis_url) self.ttl = ttl # Time to live in seconds - def add_message(self, message: Dict[str, str]) -> None: + async def add_message(self, message: Dict[str, str]) -> None: """Add message to Redis.""" message_with_timestamp = { **message, @@ -38,20 +38,20 @@ def add_message(self, message: Dict[str, str]) -> None: if self.ttl: self.redis_client.expire(self.memory_key, self.ttl) - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from Redis.""" messages = self.redis_client.lrange(self.memory_key, 0, -1) return [json.loads(msg) for msg in messages] - def clear(self) -> None: + async def clear(self) -> None: """Clear all messages from Redis.""" self.redis_client.delete(self.memory_key) - def save(self) -> None: + async def save(self) -> None: """Save is handled automatically by Redis.""" pass - def load(self) -> None: + async def load(self) -> None: """Load is handled automatically by Redis.""" pass @@ -61,7 +61,8 @@ def get_message_count(self) -> int: def get_messages_since(self, timestamp: datetime) -> List[Dict[str, str]]: """Get messages since a specific timestamp.""" - all_messages = self.get_messages() + all_messages = self.redis_client.lrange(self.memory_key, 0, -1) + all_messages = [json.loads(msg) for msg in all_messages] return [ msg for msg in all_messages if datetime.fromisoformat(msg["timestamp"]) > timestamp diff --git a/multimind/memory/semantic.py b/multimind/memory/semantic.py index fe674f18..4d091bd1 100644 --- a/multimind/memory/semantic.py +++ b/multimind/memory/semantic.py @@ -45,7 +45,6 @@ def __init__( self.concept_metadata: Dict[str, Dict[str, Any]] = {} # concept_id -> metadata self.inference_cache: Dict[str, List[Dict[str, Any]]] = {} # concept_id -> inferred relationships self.last_validation = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message as new semantic knowledge.""" @@ -355,7 +354,7 @@ async def _remove_concept(self, concept_id: str) -> None: if concept_id in self.inference_cache: del self.inference_cache[concept_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all concepts.""" messages = [] for concept in self.concepts: @@ -398,7 +397,7 @@ async def save(self) -> None: "last_validation": self.last_validation.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load concepts from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/sensory.py b/multimind/memory/sensory.py index 9918e5db..95fdfef2 100644 --- a/multimind/memory/sensory.py +++ b/multimind/memory/sensory.py @@ -111,7 +111,6 @@ def __init__( self.last_cross_modal = datetime.now() self.last_fusion = datetime.now() self.last_advanced_pattern = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and analyze sensory information.""" @@ -567,7 +566,7 @@ async def _remove_experience(self, experience_id: str) -> None: if experience_id in self.validation_history: del self.validation_history[experience_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all experiences.""" messages = [] for experience in self.experiences: @@ -616,7 +615,7 @@ async def save(self) -> None: "last_advanced_pattern": self.last_advanced_pattern.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load experiences from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/spatial.py b/multimind/memory/spatial.py index ea70d180..b41e0c04 100644 --- a/multimind/memory/spatial.py +++ b/multimind/memory/spatial.py @@ -83,7 +83,6 @@ def __init__( self.last_cluster_update = datetime.now() self.last_evolution = datetime.now() self.last_validation = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and analyze spatial information.""" @@ -503,7 +502,7 @@ async def _remove_location(self, location_id: str) -> None: if location_id in self.validation_history: del self.validation_history[location_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all locations.""" messages = [] for location in self.locations: @@ -546,7 +545,7 @@ async def save(self) -> None: "last_validation": self.last_validation.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load locations from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/sqlalchemy.py b/multimind/memory/sqlalchemy.py index 68fc06cd..d5cb3aea 100644 --- a/multimind/memory/sqlalchemy.py +++ b/multimind/memory/sqlalchemy.py @@ -36,7 +36,7 @@ def __init__( self.Session = sessionmaker(bind=self.engine) Base.metadata.create_all(self.engine) - def add_message(self, message: Dict[str, str]) -> None: + async def add_message(self, message: Dict[str, str]) -> None: """Add message to database.""" session = self.Session() try: @@ -50,7 +50,7 @@ def add_message(self, message: Dict[str, str]) -> None: finally: session.close() - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from database.""" session = self.Session() try: @@ -67,7 +67,7 @@ def get_messages(self) -> List[Dict[str, str]]: finally: session.close() - def clear(self) -> None: + async def clear(self) -> None: """Clear all messages from database.""" session = self.Session() try: @@ -76,11 +76,11 @@ def clear(self) -> None: finally: session.close() - def save(self) -> None: + async def save(self) -> None: """Save is handled automatically by SQLAlchemy.""" pass - def load(self) -> None: + async def load(self) -> None: """Load is handled automatically by SQLAlchemy.""" pass diff --git a/multimind/memory/summary.py b/multimind/memory/summary.py index aaadc7ad..adb3cece 100644 --- a/multimind/memory/summary.py +++ b/multimind/memory/summary.py @@ -62,9 +62,11 @@ def __init__( self.last_backup = datetime.now() self.backup_history: List[Dict[str, Any]] = [] - # Load if storage path exists - if self.storage_path and self.storage_path.exists(): - self.load() + # Load explicitly via await memory.load() when needed. + + async def add_message(self, message: Dict[str, str]) -> None: + """Add a single message to summary memory.""" + await self.add_messages([message]) async def add_messages( self, @@ -277,7 +279,17 @@ async def _backup(self) -> None: # Save to disk if storage path exists if self.storage_path: - self.save() + await self.save() + + async def get_messages(self) -> List[Dict[str, str]]: + """Return summaries in message format.""" + return [ + { + "role": "system", + "content": summary.get("summary", ""), + } + for summary in self.summaries + ] def get_summaries(self) -> List[Dict[str, Any]]: """Get all summaries.""" @@ -293,7 +305,7 @@ def get_summaries_with_metadata(self) -> List[Dict[str, Any]]: for i, summary in enumerate(self.summaries) ] - def clear(self) -> None: + async def clear(self) -> None: """Clear all summaries.""" self.summaries = [] self.summary_metadata = {} @@ -302,7 +314,7 @@ def clear(self) -> None: if self.storage_path and self.storage_path.exists(): self.storage_path.unlink() - def save(self) -> None: + async def save(self) -> None: """Save summaries to persistent storage.""" if not self.storage_path: return @@ -320,7 +332,7 @@ def save(self) -> None: with open(self.storage_path, "w") as f: json.dump(data, f) - def load(self) -> None: + async def load(self) -> None: """Load summaries from persistent storage.""" if not self.storage_path or not self.storage_path.exists(): return @@ -337,7 +349,7 @@ def load(self) -> None: self.backup_history = data["backup_history"] except Exception as e: print(f"Error loading summaries: {e}") - self.clear() + await self.clear() def get_stats(self) -> Dict[str, Any]: """Get summary statistics.""" diff --git a/multimind/memory/temporal.py b/multimind/memory/temporal.py index f9dafa80..b7da1dbd 100644 --- a/multimind/memory/temporal.py +++ b/multimind/memory/temporal.py @@ -83,7 +83,6 @@ def __init__( self.last_pattern_update = datetime.now() self.last_evolution = datetime.now() self.last_validation = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and analyze temporal information.""" @@ -498,7 +497,7 @@ async def _remove_event(self, event_id: str) -> None: if event_id in self.validation_history: del self.validation_history[event_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all events.""" messages = [] for event in self.events: @@ -541,7 +540,7 @@ async def save(self) -> None: "last_validation": self.last_validation.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load events from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/time_weighted.py b/multimind/memory/time_weighted.py index 919699c1..0f10eee3 100644 --- a/multimind/memory/time_weighted.py +++ b/multimind/memory/time_weighted.py @@ -30,9 +30,8 @@ def __init__( self.decay_function = decay_function self.time_units = time_units self.messages: List[Dict[str, Any]] = [] - self.load() - def add_message(self, message: Dict[str, str]) -> None: + async def add_message(self, message: Dict[str, str]) -> None: """Add message with timestamp and weight.""" message_with_metadata = { **message, @@ -42,26 +41,34 @@ def add_message(self, message: Dict[str, str]) -> None: } self.messages.append(message_with_metadata) self._update_weights() - self.save() + self._save_sync() - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages with their current weights.""" self._update_weights() # Update weights before returning return self.messages - def clear(self) -> None: + async def clear(self) -> None: """Clear all messages.""" self.messages.clear() - self.save() + self._save_sync() - def save(self) -> None: + async def save(self) -> None: + """Save messages to persistent storage.""" + self._save_sync() + + def _save_sync(self) -> None: """Save messages to persistent storage.""" if self.storage_path: self.storage_path.parent.mkdir(parents=True, exist_ok=True) with open(self.storage_path, 'w') as f: json.dump(self.messages, f) - def load(self) -> None: + async def load(self) -> None: + """Load messages from persistent storage.""" + self._load_sync() + + def _load_sync(self) -> None: """Load messages from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: @@ -184,7 +191,7 @@ def set_message_importance(self, message_index: int, importance: float) -> None: if 0 <= message_index < len(self.messages): self.messages[message_index]["importance"] = max(0.0, min(1.0, importance)) self._update_weights() - self.save() + self._save_sync() def get_weight_distribution(self) -> Dict[str, float]: """Get distribution of message weights.""" diff --git a/multimind/memory/vector_store.py b/multimind/memory/vector_store.py index dc0dad75..f3635ffb 100644 --- a/multimind/memory/vector_store.py +++ b/multimind/memory/vector_store.py @@ -59,9 +59,36 @@ def __init__( self.backup_history: List[Dict[str, Any]] = [] self.last_pruning = datetime.now() - # Initialize vector store - if vector_store_config.storage_path: - self.vector_store.load(vector_store_config.storage_path) + # Load explicitly via await memory.load() when needed. + + async def add_message(self, message: Dict[str, str]) -> None: + """Add a message using an auto-generated memory ID.""" + memory_id = f"memory_{datetime.now().timestamp()}" + await self.add( + memory_id=memory_id, + content=message["content"], + metadata={"role": message.get("role", "user")}, + ) + + async def get_messages(self) -> List[Dict[str, str]]: + """Return stored message payloads.""" + results = await self.vector_store.search( + query_vector=[0] * self.vector_store.config.vector_dim, + k=self.vector_store.config.max_vectors, + ) + messages: List[Dict[str, str]] = [] + for result in results: + metadata = getattr(result, "metadata", None) or {} + content = metadata.get("content") + if content is None and isinstance(result, dict): + content = result.get("content") + if content is None: + continue + messages.append({ + "role": metadata.get("role", "user"), + "content": content, + }) + return messages async def add( self, @@ -193,13 +220,29 @@ async def _backup(self) -> None: self.last_backup = datetime.now() - def clear(self) -> None: + async def clear(self) -> None: """Clear all vectors.""" - self.vector_store.clear() + await self.vector_store.clear() self.last_backup = datetime.now() self.backup_history = [] self.last_pruning = datetime.now() + async def save(self) -> None: + """Persist the vector store when storage is configured.""" + storage_path = self.vector_store.config.storage_path + if not storage_path: + return + await self.vector_store.persist(storage_path) + + async def load(self) -> None: + """Load the vector store from persistent storage.""" + storage_path = self.vector_store.config.storage_path + if not storage_path: + return + self.vector_store = await VectorStore.load(storage_path, self.vector_store.config) + self.last_backup = datetime.now() + self.last_pruning = datetime.now() + def get_stats(self) -> Dict[str, Any]: """Get vector store statistics.""" return self.vector_store.get_stats() \ No newline at end of file diff --git a/multimind/memory/versioned.py b/multimind/memory/versioned.py index b63c18e0..850c7e54 100644 --- a/multimind/memory/versioned.py +++ b/multimind/memory/versioned.py @@ -82,7 +82,6 @@ def __init__( self.last_analysis = datetime.now() self.last_optimization = datetime.now() self.last_graph_update = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message and create version.""" @@ -512,7 +511,7 @@ async def _remove_item(self, item_id: str) -> None: for version_id in versions_to_remove: del self.version_graph[version_id] - def get_messages(self) -> List[Dict[str, str]]: + async def get_messages(self) -> List[Dict[str, str]]: """Get all messages from all items.""" messages = [] for item in self.items: @@ -560,7 +559,7 @@ async def save(self) -> None: "last_graph_update": self.last_graph_update.isoformat() }, f) - def load(self) -> None: + async def load(self) -> None: """Load items from persistent storage.""" if self.storage_path and self.storage_path.exists(): with open(self.storage_path, 'r') as f: diff --git a/multimind/memory/working.py b/multimind/memory/working.py index 2692f04c..9bfda1c4 100644 --- a/multimind/memory/working.py +++ b/multimind/memory/working.py @@ -78,7 +78,6 @@ def __init__( self.last_consolidation = datetime.now() self.last_attention_update = datetime.now() self.last_backup = datetime.now() - self.load() async def add_message(self, message: Dict[str, str]) -> None: """Add message to working memory.""" @@ -302,6 +301,72 @@ async def restore_from_backup(self, backup_index: int = -1) -> None: await self.save() + async def get_messages(self) -> List[Dict[str, str]]: + """Get all working-memory items.""" + return self.items + + async def clear(self) -> None: + """Clear all working-memory state.""" + self.items = [] + self.item_embeddings = [] + self.attention_scores = {} + self.attention_history = {} + self.consolidation_history = {} + self.priority_scores = {} + self.compression_history = {} + self.backup_history = [] + self.last_decay = datetime.now() + self.last_consolidation = datetime.now() + self.last_attention_update = datetime.now() + self.last_backup = datetime.now() + if self.storage_path and self.storage_path.exists(): + self.storage_path.unlink() + + async def save(self) -> None: + """Persist working-memory state.""" + if not self.storage_path: + return + + data = { + "items": self.items, + "item_embeddings": self.item_embeddings, + "attention_scores": self.attention_scores, + "attention_history": self.attention_history, + "consolidation_history": self.consolidation_history, + "priority_scores": self.priority_scores, + "compression_history": self.compression_history, + "backup_history": self.backup_history, + "last_decay": self.last_decay.isoformat(), + "last_consolidation": self.last_consolidation.isoformat(), + "last_attention_update": self.last_attention_update.isoformat(), + "last_backup": self.last_backup.isoformat(), + } + + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.storage_path, "w") as f: + json.dump(data, f) + + async def load(self) -> None: + """Load working-memory state from disk.""" + if not self.storage_path or not self.storage_path.exists(): + return + + with open(self.storage_path, "r") as f: + data = json.load(f) + + self.items = data.get("items", []) + self.item_embeddings = data.get("item_embeddings", []) + self.attention_scores = data.get("attention_scores", {}) + self.attention_history = data.get("attention_history", {}) + self.consolidation_history = data.get("consolidation_history", {}) + self.priority_scores = data.get("priority_scores", {}) + self.compression_history = data.get("compression_history", {}) + self.backup_history = data.get("backup_history", []) + self.last_decay = datetime.fromisoformat(data.get("last_decay", datetime.now().isoformat())) + self.last_consolidation = datetime.fromisoformat(data.get("last_consolidation", datetime.now().isoformat())) + self.last_attention_update = datetime.fromisoformat(data.get("last_attention_update", datetime.now().isoformat())) + self.last_backup = datetime.fromisoformat(data.get("last_backup", datetime.now().isoformat())) + async def get_backup_info(self) -> List[Dict[str, Any]]: """Get information about available backups.""" return [ diff --git a/multimind/models/huggingface.py b/multimind/models/huggingface.py index 41aa3cb5..377b1e59 100644 --- a/multimind/models/huggingface.py +++ b/multimind/models/huggingface.py @@ -3,6 +3,7 @@ """ import asyncio +import functools from typing import List, Dict, Any, Optional, AsyncGenerator, Union from .base import BaseLLM @@ -107,15 +108,15 @@ async def generate( **kwargs ) -> str: """Generate text from the model.""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - None, + loop = asyncio.get_running_loop() + generate_fn = functools.partial( self._generate_text, prompt, temperature, max_tokens, **kwargs ) + return await loop.run_in_executor(None, generate_fn) async def generate_stream( self, @@ -173,17 +174,46 @@ async def chat_stream( async for chunk in self.generate_stream(prompt, temperature, max_tokens, **kwargs): yield chunk + def _compute_embeddings( + self, + texts: List[str], + max_length: int = 512 + ) -> List[List[float]]: + """Compute embeddings using mean pooling over the last hidden state.""" + inputs = self.tokenizer( + texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_length + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True, return_dict=True) + + last_hidden = outputs.hidden_states[-1] + attention_mask = inputs["attention_mask"].unsqueeze(-1).type_as(last_hidden) + masked_hidden = last_hidden * attention_mask + token_counts = attention_mask.sum(dim=1).clamp(min=1) + pooled = masked_hidden.sum(dim=1) / token_counts + + return pooled.cpu().tolist() + async def embeddings( self, text: Union[str, List[str]], **kwargs ) -> Union[List[float], List[List[float]]]: """Generate embeddings for the input text.""" - # Use the model's embedding layer if available - # For now, return a placeholder - embeddings should use a dedicated embedding model + texts = [text] if isinstance(text, str) else text + max_length = kwargs.get("max_length", 512) + + loop = asyncio.get_running_loop() + compute_fn = functools.partial(self._compute_embeddings, texts, max_length) + embeddings = await loop.run_in_executor(None, compute_fn) + if isinstance(text, str): - # Return a dummy embedding vector (768 dimensions) - return [0.0] * 768 - else: - return [[0.0] * 768 for _ in text] + return embeddings[0] + return embeddings diff --git a/multimind/providers/claude.py b/multimind/providers/claude.py index f02eaa4d..759b80cb 100644 --- a/multimind/providers/claude.py +++ b/multimind/providers/claude.py @@ -3,6 +3,7 @@ """ from typing import Dict, List, Optional, Union, Any +import base64 import anthropic import logging from datetime import datetime @@ -141,7 +142,7 @@ async def analyze_image( "source": { "type": "base64", "media_type": "image/jpeg", - "data": image_data.hex() + "data": base64.b64encode(image_data).decode("utf-8") } } ] @@ -162,12 +163,14 @@ async def analyze_image( ) / 1000 # Convert to USD return ImageAnalysisResult( + objects=[], + captions=[result] if result else [], + text=result, provider_name="claude", model_name=model, - result=result, - tokens_used=tokens_used, latency_ms=latency_ms, - cost_estimate_usd=cost + cost_estimate_usd=cost, + metadata={"tokens_used": tokens_used} ) except Exception as e: diff --git a/requirements.txt b/requirements.txt index c36a1928..daa971c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -134,6 +134,7 @@ python-dateutil==2.9.0.post0 python-docx==1.1.2 python-dotenv==1.1.0 python-jose==3.5.0 +passlib[bcrypt] python-multipart==0.0.20 pytz==2025.2 pywin32-ctypes==0.2.3 diff --git a/tests/test_advanced_features.py b/tests/test_advanced_features.py index 03baca73..8c52945f 100644 --- a/tests/test_advanced_features.py +++ b/tests/test_advanced_features.py @@ -259,7 +259,7 @@ async def test_memory_edge_cases(self): """Test memory systems with edge cases.""" # Test empty memory memory = BufferMemory() - assert memory.get_messages() == [] + assert await memory.get_messages() == [] # Test memory with None values memory = BufferMemory() @@ -270,7 +270,7 @@ async def test_memory_edge_cases(self): long_content = "x" * 10000 memory = BufferMemory() await memory.add_message({"role": "user", "content": long_content}) - messages = memory.get_messages() + messages = await memory.get_messages() assert len(messages) == 1 assert messages[0]["content"] == long_content @@ -299,11 +299,12 @@ async def test_memory_persistence(self): # Create memory and add data memory = BufferMemory(storage_path=temp_path) await memory.add_message({"role": "user", "content": "test"}) - memory.save() + await memory.save() - # Create new memory instance and load + # Create new memory instance and load explicitly new_memory = BufferMemory(storage_path=temp_path) - assert new_memory.get_messages() != [] + await new_memory.load() + assert await new_memory.get_messages() != [] finally: os.unlink(temp_path) @@ -686,7 +687,7 @@ async def test_large_data_handling(self): "content": f"Message {i}" }) - messages = memory.get_messages() + messages = await memory.get_messages() assert len(messages) == 1000 @pytest.mark.asyncio @@ -705,7 +706,7 @@ async def add_message(msg): await asyncio.gather(*tasks) - messages = memory.get_messages() + messages = await memory.get_messages() assert len(messages) == 10 @pytest.mark.asyncio