16 changes: 16 additions & 0 deletions .env.example
@@ -34,6 +34,22 @@ LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
# MultiMind Gateway Settings
MULTIMIND_LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL

# Gateway Authentication & Security
# Comma-separated API keys for X-API-Key auth (optional)
# API_KEYS=key_one,key_two

# JWT secret used to sign/verify tokens (required if using /token JWT auth)
# Use a long random value (32+ chars)
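# Generate one with, e.g.:
# python -c "import secrets; print(secrets.token_urlsafe(48))"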
# JWT_SECRET=replace_with_a_long_random_secret

# JSON object mapping username -> bcrypt hashed password (required for /token)
# If passlib is missing, install it:
# pip install "passlib[bcrypt]"

# Example hash generation:
# python -c "from passlib.context import CryptContext; c=CryptContext(schemes=['bcrypt'], deprecated='auto'); print(c.hash('YourStrongPassword123!'))"
# JWT_USERS_JSON={"admin":"$2b$12$replace_with_bcrypt_hash","testuser":"$2b$12$replace_with_bcrypt_hash"}

# Optional: Model-specific Settings
# Uncomment and modify these if you need custom settings for specific models

62 changes: 62 additions & 0 deletions docs/api_reference/rag_api_server_setup.md
@@ -72,6 +72,68 @@ echo %ANTHROPIC_API_KEY%

You should see your API key printed. If nothing shows, the server will use local HuggingFace models.

## Step 2.1: Configure JWT/Auth Security Variables (`.env`)

If you want to use `/token` JWT auth (production-style), add these values to your `.env`.
This matches `.env.example` (`API_KEYS`, `JWT_SECRET`, `JWT_USERS_JSON`).

### Required/Optional fields

- `API_KEYS` (optional): comma-separated API keys for `X-API-Key` auth
- `JWT_SECRET` (required for `/token`): long random secret (32+ chars)
- `JWT_USERS_JSON` (required for `/token`): JSON mapping username -> bcrypt hash

### 1) Generate `JWT_SECRET`

PowerShell:
```powershell
python -c "import secrets; print(secrets.token_urlsafe(48))"
```

Copy the output into your `.env`:
```env
JWT_SECRET=your_generated_secret_here
```

### 2) Generate bcrypt password hash

Install passlib+bcrypt (if needed):
```bash
pip install "passlib[bcrypt]"
```

Generate hash:
```powershell
python -c "from passlib.context import CryptContext; c=CryptContext(schemes=['bcrypt'], deprecated='auto'); print(c.hash('YourStrongPassword123!'))"
```

### 3) Set `JWT_USERS_JSON`

Use the generated hash in JSON (single line):
```env
JWT_USERS_JSON={"admin":"$2b$12$replace_with_bcrypt_hash","testuser":"$2b$12$replace_with_bcrypt_hash"}
```
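
To confirm the hash verifies against the password you used, here is a quick sanity check (assumes passlib is installed as above; substitute your real hash for the placeholder):

```python
from passlib.context import CryptContext

ctx = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Prints True when the hash matches the password.
print(ctx.verify("YourStrongPassword123!", "$2b$12$replace_with_bcrypt_hash"))
```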

### 4) Optional `API_KEYS`

```env
API_KEYS=key_one,key_two,key_three
```

### Example secure `.env` block

```env
OPENAI_API_KEY=your-openai-key
API_KEYS=internal_key_1,internal_key_2
JWT_SECRET=replace_with_long_random_secret_32plus_chars
JWT_USERS_JSON={"admin":"$2b$12$replace_with_bcrypt_hash"}
```

**Important:**
- Do not commit real secrets to git.
- Do not use default/fallback secrets in production.
- If `JWT_SECRET` or `JWT_USERS_JSON` is missing, treat `/token` auth as not securely configured.
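
### Optional: quick token check

Once the server is running (Step 3 below), you can verify the token flow end to end. This is a minimal sketch that assumes the gateway exposes an OAuth2-style password form at `/token` and accepts `Authorization: Bearer` tokens on protected routes; adjust the host, port, and field names to your deployment.

```bash
# Request a token (endpoint shape assumed; adjust to your gateway)
curl -X POST http://localhost:8000/token \
  -d 'username=admin' \
  -d 'password=YourStrongPassword123!'

# Call a protected endpoint with the returned token
curl http://localhost:8000/v1/models \
  -H "Authorization: Bearer <paste_token_here>"
```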

## Step 3: Start the RAG API Server

### Method 1: Run as module (RECOMMENDED)
9 changes: 7 additions & 2 deletions multimind/api/unified_api.py
@@ -5,10 +5,12 @@
from fastapi import FastAPI, HTTPException
from typing import Dict, Any
import asyncio
import logging
from ..models.moe import MoEFactory
from ..types import UnifiedRequest, UnifiedResponse

app = FastAPI(title="Unified Multi-Modal API")
logger = logging.getLogger(__name__)

# Initialize components
try:
@@ -87,10 +89,13 @@ async def process_request(request: UnifiedRequest):
}
)

except Exception as e:
except HTTPException:
raise
except Exception:
logger.exception("Error processing unified API request")
raise HTTPException(
status_code=500,
detail=f"Error processing request: {str(e)}"
detail="Internal server error"
)

@app.get("/v1/models")
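
The change above keeps deliberate `HTTPException`s intact while mapping unexpected errors to a logged, generic 500, so internal details never reach clients. A minimal standalone sketch of the same pattern (illustrative endpoint, not the actual multimind route):

```python
import logging

from fastapi import FastAPI, HTTPException

app = FastAPI()
logger = logging.getLogger(__name__)

@app.get("/items/{item_id}")
async def read_item(item_id: int):
    try:
        if item_id < 0:
            # Intentional client errors keep their status code and detail.
            raise HTTPException(status_code=400, detail="item_id must be non-negative")
        return {"item_id": item_id}
    except HTTPException:
        # Re-raise unchanged so FastAPI returns the intended response.
        raise
    except Exception:
        # Log the traceback server-side; send only a generic message to the client.
        logger.exception("Error processing request")
        raise HTTPException(status_code=500, detail="Internal server error")
```
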
2 changes: 1 addition & 1 deletion multimind/compliance/audit.py
@@ -3,7 +3,7 @@
"""

from typing import List, Dict, Any, Optional
from datetime import datetime
from datetime import datetime, timedelta
from pydantic import BaseModel, Field
from .governance import GovernanceConfig
import json
2 changes: 1 addition & 1 deletion multimind/compliance/data_protection.py
@@ -8,6 +8,7 @@
import hashlib
import hmac
import json
import os
from cryptography.fernet import Fernet
from .governance import GovernanceConfig, DataCategory

@@ -34,7 +35,6 @@ async def protect_data(
) -> Dict[str, Any]:
"""Protect data according to its category."""
protected_data = {
"original_data": data,
"category": category,
"metadata": metadata or {},
"protection_applied": []
2 changes: 1 addition & 1 deletion multimind/compliance/gdpr.py
@@ -5,7 +5,7 @@
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from pydantic import BaseModel, Field
from .governance import GovernanceConfig, ComplianceMetadata, DataCategory
from .governance import GovernanceConfig, ComplianceMetadata, DataCategory, Regulation

class GDPRCompliance(BaseModel):
"""GDPR compliance manager."""
4 changes: 2 additions & 2 deletions multimind/compliance/model_training.py
@@ -8,7 +8,7 @@
from datetime import datetime
from pydantic import BaseModel, Field
import numpy as np
from dataclasses import dataclass
from dataclasses import dataclass, field
try:
import torch
from torch.utils.data import Dataset, DataLoader
@@ -27,7 +27,7 @@ class ComplianceMetrics:
privacy_score: float
transparency_score: float
fairness_score: float
timestamp: datetime = Field(default_factory=datetime.utcnow)
timestamp: datetime = field(default_factory=datetime.utcnow)

class ComplianceDataset:
"""Dataset wrapper that ensures compliance during training."""
22 changes: 11 additions & 11 deletions multimind/core/chat.py
@@ -8,8 +8,7 @@
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, asdict
from pydantic import BaseModel
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

@@ -18,17 +17,17 @@ class ChatMessage(BaseModel):
role: str
content: str
model: str
timestamp: datetime = datetime.now()
metadata: Dict = {}
timestamp: datetime = Field(default_factory=datetime.now)
metadata: Dict = Field(default_factory=dict)

class ChatSession(BaseModel):
"""A chat session with history and metadata"""
session_id: str
model: str
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
messages: List[ChatMessage] = []
metadata: Dict = {}
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
messages: List[ChatMessage] = Field(default_factory=list)
metadata: Dict = Field(default_factory=dict)
system_prompt: Optional[str] = None

def add_message(self, role: str, content: str, model: str, metadata: Optional[Dict[str, Union[str, int, float]]] = None) -> None:
@@ -50,10 +49,11 @@ def get_context(self, max_messages: int = 10) -> List[Dict[str, str]]:

def export(self, format: str = "json") -> Union[str, Dict]:
"""Export session to different formats"""
session_data = self.model_dump(mode="json")
if format == "json":
return json.dumps(asdict(self), default=str)
return json.dumps(session_data)
elif format == "dict":
return asdict(self)
return session_data
else:
raise ValueError(f"Unsupported export format: {format}")

@@ -75,7 +75,7 @@ def save(self, directory: Union[str, Path]) -> Path:
directory.mkdir(parents=True, exist_ok=True)
file_path = directory / f"chat_{self.session_id}.json"
with open(file_path, "w") as f:
json.dump(asdict(self), f, default=str, indent=2)
json.dump(self.model_dump(mode="json"), f, indent=2)
return file_path

class ChatManager:
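
The `Field(default_factory=...)` changes fix a classic Pydantic pitfall: a bare `datetime.now()` or `{}` default is evaluated once at class-definition time and then shared by every instance. A small illustration (hypothetical model, not from this codebase):

```python
import time
from datetime import datetime

from pydantic import BaseModel, Field

class Stamped(BaseModel):  # hypothetical example model
    shared_ts: datetime = datetime.now()                      # evaluated once, at import
    fresh_ts: datetime = Field(default_factory=datetime.now)  # evaluated per instance

a = Stamped()
time.sleep(0.01)
b = Stamped()
print(a.shared_ts == b.shared_ts)  # True: both reuse the import-time timestamp
print(a.fresh_ts == b.fresh_ts)    # False: each instance gets its own timestamp
```
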
4 changes: 2 additions & 2 deletions multimind/core/models.py
@@ -6,7 +6,7 @@
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime

logger = logging.getLogger(__name__)
@@ -18,7 +18,7 @@ class ModelResponse:
model: str
usage: Optional[Dict[str, int]] = None
finish_reason: Optional[str] = None
timestamp: str = datetime.now().isoformat()
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

class ModelHandler(ABC):
"""Abstract base class for model handlers"""
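
This file and `model_training.py` apply the matching fix for dataclasses: `pydantic.Field` has no special meaning inside a plain `@dataclass` (its `FieldInfo` object would simply become the default value), and a bare `datetime.now().isoformat()` default is computed once at import. The `dataclasses.field` idiom, sketched on a hypothetical class:

```python
from dataclasses import dataclass, field
from datetime import datetime

@dataclass
class Event:  # hypothetical example class
    name: str
    # Wrong: `timestamp: str = datetime.now().isoformat()` freezes one value at import.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

print(Event("a").timestamp)  # a fresh timestamp for each new instance
```
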
10 changes: 5 additions & 5 deletions multimind/core/provider.py
@@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, Any
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field
from datetime import datetime

class ProviderCapability(str, Enum):
@@ -49,8 +49,8 @@ class GenerationResult(BaseModel):
model_name: str
latency_ms: float
cost_estimate_usd: float
metadata: Dict[str, Any] = {}
created_at: datetime = datetime.now()
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: datetime = Field(default_factory=datetime.now)

class EmbeddingResult(BaseModel):
"""Standardized result from embeddings generation."""
@@ -60,7 +60,7 @@
model_name: str
latency_ms: float
cost_estimate_usd: float
metadata: Dict[str, Any] = {}
metadata: Dict[str, Any] = Field(default_factory=dict)

class ImageAnalysisResult(BaseModel):
"""Standardized result from image analysis."""
@@ -71,7 +71,7 @@
model_name: str
latency_ms: float
cost_estimate_usd: float
metadata: Dict[str, Any] = {}
metadata: Dict[str, Any] = Field(default_factory=dict)

class ProviderAdapter(ABC):
"""Base class for provider adapters."""
93 changes: 48 additions & 45 deletions multimind/core/router.py
@@ -234,52 +234,55 @@ async def _handle_single_provider(
call_kwargs = dict(kwargs)
call_kwargs.pop("provider", None)
model_arg = call_kwargs.pop("model", None)
start = time.time()
try:
if task_type == TaskType.TEXT_GENERATION:
if model_arg is not None:
result = await provider.generate_text(model=model_arg, prompt=input_data, **call_kwargs)
else:
result = await provider.generate_text(prompt=input_data, **call_kwargs)
elif task_type == TaskType.EMBEDDINGS:
if model_arg is not None:
result = await provider.generate_embeddings(text=input_data, model=model_arg, **call_kwargs)
else:
result = await provider.generate_embeddings(text=input_data, **call_kwargs)
elif task_type == TaskType.IMAGE_ANALYSIS:
if model_arg is not None:
result = await provider.analyze_image(image_data=input_data, model=model_arg, **call_kwargs)
max_attempts = self.fallback_policy.max_retries + 1 if self.fallback_policy.strategy == "retry" else 1
last_error = None

for _ in range(max_attempts):
start = time.time()
try:
if task_type == TaskType.TEXT_GENERATION:
if model_arg is not None:
result = await provider.generate_text(model=model_arg, prompt=input_data, **call_kwargs)
else:
result = await provider.generate_text(prompt=input_data, **call_kwargs)
elif task_type == TaskType.EMBEDDINGS:
if model_arg is not None:
result = await provider.generate_embeddings(text=input_data, model=model_arg, **call_kwargs)
else:
result = await provider.generate_embeddings(text=input_data, **call_kwargs)
elif task_type == TaskType.IMAGE_ANALYSIS:
if model_arg is not None:
result = await provider.analyze_image(image_data=input_data, model=model_arg, **call_kwargs)
else:
result = await provider.analyze_image(image_data=input_data, **call_kwargs)
else:
result = await provider.analyze_image(image_data=input_data, **call_kwargs)
else:
raise ValueError(f"Unsupported task type: {task_type}")
latency = time.time() - start
quality = getattr(result, 'quality', None) or (result.metadata.get('quality') if hasattr(result, 'metadata') else None)
feedback = getattr(result, 'feedback', None) or (result.metadata.get('feedback') if hasattr(result, 'metadata') else None)
self.performance_tracker.record(provider_name, success=True, latency=latency, quality=quality, feedback=feedback)
return result
except Exception as e:
latency = time.time() - start
self.performance_tracker.record(provider_name, success=False, latency=latency)
self.fallback_policy.record_failure(provider_name)
# Centralized fallback logic
if self.fallback_policy.strategy == "retry" and self.fallback_policy.failure_counts[provider_name] <= self.fallback_policy.max_retries:
# Retry the same provider
return await self._handle_single_provider(task_type, input_data, config, use_adaptive_routing, **kwargs)
elif self.fallback_policy.strategy == "switch_provider" and len(config.preferred_providers) > 1:
# Switch to next best provider
remaining = [p for p in config.preferred_providers if p != provider_name]
if remaining:
next_provider = self.performance_tracker.get_best_provider(remaining)
if self.fallback_policy.notify_user:
print(self.fallback_policy.get_fallback_message(provider_name, e))
# Try next provider
config_copy = config.copy()
config_copy.preferred_providers = remaining
return await self._handle_single_provider(task_type, input_data, config_copy, use_adaptive_routing, **kwargs)
if self.fallback_policy.notify_user:
print(self.fallback_policy.get_fallback_message(provider_name, e))
raise
raise ValueError(f"Unsupported task type: {task_type}")
latency = time.time() - start
quality = getattr(result, 'quality', None) or (result.metadata.get('quality') if hasattr(result, 'metadata') else None)
feedback = getattr(result, 'feedback', None) or (result.metadata.get('feedback') if hasattr(result, 'metadata') else None)
self.performance_tracker.record(provider_name, success=True, latency=latency, quality=quality, feedback=feedback)
return result
except Exception as e:
latency = time.time() - start
self.performance_tracker.record(provider_name, success=False, latency=latency)
self.fallback_policy.record_failure(provider_name)
last_error = e

# Centralized fallback logic after retry attempts are exhausted.
if self.fallback_policy.strategy == "switch_provider" and len(config.preferred_providers) > 1:
# Switch to next best provider
remaining = [p for p in config.preferred_providers if p != provider_name]
if remaining:
next_provider = self.performance_tracker.get_best_provider(remaining)
if self.fallback_policy.notify_user:
print(self.fallback_policy.get_fallback_message(provider_name, last_error))
# Try next provider
config_copy = config.copy()
config_copy.preferred_providers = remaining
return await self._handle_single_provider(task_type, input_data, config_copy, use_adaptive_routing, **kwargs)
if self.fallback_policy.notify_user:
print(self.fallback_policy.get_fallback_message(provider_name, last_error))
raise last_error

async def _handle_ensemble(
self,
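
The refactor above replaces recursive retries with a bounded loop: up to `max_retries + 1` attempts against the current provider, then a single fallback decision driven by `last_error`. The control flow, reduced to a sketch (names are illustrative; the real router also records per-provider stats and picks the next provider by performance):

```python
from typing import Awaitable, Callable, Optional, Sequence

async def call_with_retry(
    providers: Sequence[str],
    call: Callable[[str], Awaitable[str]],
    max_retries: int = 2,
) -> str:
    """Try each provider in order, retrying each a bounded number of times."""
    last_error: Optional[Exception] = None
    for name in providers:                # "switch_provider" fallback order
        for _ in range(max_retries + 1):  # bounded retries, no recursion
            try:
                return await call(name)
            except Exception as e:        # record the failure, then retry or move on
                last_error = e
    assert last_error is not None
    raise last_error
```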