|
| 1 | +""" |
| 2 | +SSH pre-flight checks with auto-fix for DAG triggers. |
| 3 | +
|
| 4 | +Validates the Airflow SSH connection (localhost_ssh) before triggering DAGs |
| 5 | +that rely on SSHOperator. Checks connection existence, SSH user, key config, |
| 6 | +and sshd reachability. Auto-fixes what it can, warns about the rest. |
| 7 | +
|
| 8 | +Results are cached to avoid repeated API calls on consecutive triggers. |
| 9 | +""" |
| 10 | + |
| 11 | +import asyncio |
| 12 | +import json |
| 13 | +import logging |
| 14 | +import os |
| 15 | +import time |
| 16 | +from dataclasses import dataclass, field |
| 17 | +from enum import Enum |
| 18 | +from typing import Any, Dict, List, Optional |
| 19 | + |
| 20 | +import httpx |
| 21 | + |
| 22 | +logger = logging.getLogger("intent-parser.ssh-preflight") |
| 23 | + |
| 24 | + |
| 25 | +class CheckStatus(str, Enum): |
| 26 | + OK = "ok" |
| 27 | + FIXED = "fixed" |
| 28 | + WARNING = "warning" |
| 29 | + ERROR = "error" |
| 30 | + |
| 31 | + |
| 32 | +@dataclass |
| 33 | +class PreflightCheck: |
| 34 | + name: str |
| 35 | + status: CheckStatus |
| 36 | + message: str |
| 37 | + fix_applied: Optional[str] = None |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class PreflightResult: |
| 42 | + checks: List[PreflightCheck] = field(default_factory=list) |
| 43 | + can_proceed: bool = True # Always True — warn but never block |
| 44 | + summary: str = "" |
| 45 | + |
| 46 | + def format_report(self) -> str: |
| 47 | + """Format checks into a human-readable report for prepending to trigger output.""" |
| 48 | + if not self.checks: |
| 49 | + return "" |
| 50 | + |
| 51 | + # If everything is OK, keep it brief |
| 52 | + if all(c.status == CheckStatus.OK for c in self.checks): |
| 53 | + return "[SSH Pre-flight] All checks passed." |
| 54 | + |
| 55 | + fixed = [c for c in self.checks if c.status == CheckStatus.FIXED] |
| 56 | + warnings = [c for c in self.checks if c.status == CheckStatus.WARNING] |
| 57 | + errors = [c for c in self.checks if c.status == CheckStatus.ERROR] |
| 58 | + |
| 59 | + parts = ["[SSH Pre-flight]"] |
| 60 | + |
| 61 | + if fixed: |
| 62 | + fixes = "; ".join(c.fix_applied or c.message for c in fixed) |
| 63 | + parts.append(f" Auto-fixed {len(fixed)} issue(s): {fixes}") |
| 64 | + |
| 65 | + if warnings: |
| 66 | + for w in warnings: |
| 67 | + parts.append(f" WARNING: {w.message}") |
| 68 | + |
| 69 | + if errors: |
| 70 | + for e in errors: |
| 71 | + parts.append(f" ERROR: {e.message}") |
| 72 | + |
| 73 | + return "\n".join(parts) |
| 74 | + |
| 75 | + |
| 76 | +# --------------------------------------------------------------------------- |
| 77 | +# Configuration |
| 78 | +# --------------------------------------------------------------------------- |
| 79 | + |
| 80 | +def _get_config() -> Dict[str, str]: |
| 81 | + return { |
| 82 | + "api_url": os.getenv("AIRFLOW_API_URL", "http://localhost:8888"), |
| 83 | + "user": os.getenv("AIRFLOW_USER", os.getenv("AIRFLOW_API_USER", "admin")), |
| 84 | + "password": os.getenv("AIRFLOW_PASSWORD", os.getenv("AIRFLOW_API_PASSWORD", "admin")), |
| 85 | + "ssh_user": os.getenv("QUBINODE_SSH_USER", os.getenv("USER", "root")), |
| 86 | + "conn_id": os.getenv("QUBINODE_SSH_CONN_ID", "localhost_ssh"), |
| 87 | + "cache_ttl": int(os.getenv("SSH_PREFLIGHT_CACHE_TTL", "300")), |
| 88 | + } |
| 89 | + |
| 90 | + |
| 91 | +# --------------------------------------------------------------------------- |
| 92 | +# Cache |
| 93 | +# --------------------------------------------------------------------------- |
| 94 | + |
| 95 | +_cache: Dict[str, tuple] = {} # {conn_id: (timestamp, PreflightResult)} |
| 96 | + |
| 97 | + |
| 98 | +def clear_cache() -> None: |
| 99 | + """Clear the preflight cache (useful for testing).""" |
| 100 | + _cache.clear() |
| 101 | + |
| 102 | + |
| 103 | +def _get_cached(conn_id: str, ttl: int) -> Optional[PreflightResult]: |
| 104 | + entry = _cache.get(conn_id) |
| 105 | + if entry is None: |
| 106 | + return None |
| 107 | + ts, result = entry |
| 108 | + if time.time() - ts > ttl: |
| 109 | + del _cache[conn_id] |
| 110 | + return None |
| 111 | + return result |
| 112 | + |
| 113 | + |
| 114 | +def _set_cached(conn_id: str, result: PreflightResult) -> None: |
| 115 | + _cache[conn_id] = (time.time(), result) |
| 116 | + |
| 117 | + |
| 118 | +# --------------------------------------------------------------------------- |
| 119 | +# Individual checks |
| 120 | +# --------------------------------------------------------------------------- |
| 121 | + |
| 122 | +async def _check_connection_exists( |
| 123 | + client: Any, api_url: str, auth: tuple, conn_id: str, ssh_user: str |
| 124 | +) -> tuple: |
| 125 | + """Check if the SSH connection exists; create it if missing. |
| 126 | +
|
| 127 | + Returns (PreflightCheck, conn_data_or_None). |
| 128 | + """ |
| 129 | + try: |
| 130 | + resp = await client.get( |
| 131 | + f"{api_url}/api/v1/connections/{conn_id}", |
| 132 | + auth=auth, |
| 133 | + ) |
| 134 | + except Exception as exc: |
| 135 | + return PreflightCheck( |
| 136 | + name="connection_exists", |
| 137 | + status=CheckStatus.ERROR, |
| 138 | + message=f"Cannot reach Airflow API: {exc}", |
| 139 | + ), None |
| 140 | + |
| 141 | + if resp.status_code == 200: |
| 142 | + conn_data = resp.json() |
| 143 | + return PreflightCheck( |
| 144 | + name="connection_exists", |
| 145 | + status=CheckStatus.OK, |
| 146 | + message=f"Connection '{conn_id}' exists.", |
| 147 | + ), conn_data |
| 148 | + |
| 149 | + if resp.status_code == 404: |
| 150 | + # Auto-create the connection |
| 151 | + new_conn = { |
| 152 | + "connection_id": conn_id, |
| 153 | + "conn_type": "ssh", |
| 154 | + "host": "localhost", |
| 155 | + "login": ssh_user, |
| 156 | + "port": 22, |
| 157 | + "extra": json.dumps({"key_file": f"/home/{ssh_user}/.ssh/id_rsa"}), |
| 158 | + } |
| 159 | + try: |
| 160 | + create_resp = await client.post( |
| 161 | + f"{api_url}/api/v1/connections", |
| 162 | + json=new_conn, |
| 163 | + auth=auth, |
| 164 | + ) |
| 165 | + if create_resp.status_code in (200, 201): |
| 166 | + return PreflightCheck( |
| 167 | + name="connection_exists", |
| 168 | + status=CheckStatus.FIXED, |
| 169 | + message=f"Created missing connection '{conn_id}'.", |
| 170 | + fix_applied=f"created connection '{conn_id}'", |
| 171 | + ), new_conn |
| 172 | + else: |
| 173 | + return PreflightCheck( |
| 174 | + name="connection_exists", |
| 175 | + status=CheckStatus.ERROR, |
| 176 | + message=f"Failed to create connection: HTTP {create_resp.status_code}", |
| 177 | + ), None |
| 178 | + except Exception as exc: |
| 179 | + return PreflightCheck( |
| 180 | + name="connection_exists", |
| 181 | + status=CheckStatus.ERROR, |
| 182 | + message=f"Failed to create connection: {exc}", |
| 183 | + ), None |
| 184 | + |
| 185 | + return PreflightCheck( |
| 186 | + name="connection_exists", |
| 187 | + status=CheckStatus.ERROR, |
| 188 | + message=f"Unexpected API response: HTTP {resp.status_code}", |
| 189 | + ), None |
| 190 | + |
| 191 | + |
| 192 | +async def _check_ssh_user( |
| 193 | + client: Any, api_url: str, auth: tuple, conn_id: str, conn_data: Dict, ssh_user: str |
| 194 | +) -> PreflightCheck: |
| 195 | + """Check that the connection login matches the expected SSH user; patch if wrong.""" |
| 196 | + current_login = conn_data.get("login", "") |
| 197 | + if current_login == ssh_user: |
| 198 | + return PreflightCheck( |
| 199 | + name="ssh_user", |
| 200 | + status=CheckStatus.OK, |
| 201 | + message=f"SSH user is '{ssh_user}'.", |
| 202 | + ) |
| 203 | + |
| 204 | + # Auto-fix: PATCH the connection |
| 205 | + try: |
| 206 | + patch_resp = await client.patch( |
| 207 | + f"{api_url}/api/v1/connections/{conn_id}", |
| 208 | + json={"login": ssh_user}, |
| 209 | + auth=auth, |
| 210 | + ) |
| 211 | + if patch_resp.status_code == 200: |
| 212 | + return PreflightCheck( |
| 213 | + name="ssh_user", |
| 214 | + status=CheckStatus.FIXED, |
| 215 | + message=f"Updated SSH user from '{current_login}' to '{ssh_user}'.", |
| 216 | + fix_applied=f"updated SSH user to '{ssh_user}'", |
| 217 | + ) |
| 218 | + return PreflightCheck( |
| 219 | + name="ssh_user", |
| 220 | + status=CheckStatus.WARNING, |
| 221 | + message=f"SSH user is '{current_login}' but expected '{ssh_user}'. " |
| 222 | + f"PATCH failed with HTTP {patch_resp.status_code}.", |
| 223 | + ) |
| 224 | + except Exception as exc: |
| 225 | + return PreflightCheck( |
| 226 | + name="ssh_user", |
| 227 | + status=CheckStatus.WARNING, |
| 228 | + message=f"SSH user is '{current_login}' but expected '{ssh_user}'. " |
| 229 | + f"Auto-fix failed: {exc}", |
| 230 | + ) |
| 231 | + |
| 232 | + |
| 233 | +def _check_ssh_key(conn_data: Dict) -> PreflightCheck: |
| 234 | + """Check if the connection has an SSH key file configured (report only).""" |
| 235 | + extra_raw = conn_data.get("extra", "{}") |
| 236 | + if isinstance(extra_raw, str): |
| 237 | + try: |
| 238 | + extra = json.loads(extra_raw) |
| 239 | + except (json.JSONDecodeError, TypeError): |
| 240 | + extra = {} |
| 241 | + else: |
| 242 | + extra = extra_raw |
| 243 | + |
| 244 | + key_file = extra.get("key_file", "") |
| 245 | + if key_file: |
| 246 | + return PreflightCheck( |
| 247 | + name="ssh_key", |
| 248 | + status=CheckStatus.OK, |
| 249 | + message=f"SSH key configured: {key_file}", |
| 250 | + ) |
| 251 | + return PreflightCheck( |
| 252 | + name="ssh_key", |
| 253 | + status=CheckStatus.WARNING, |
| 254 | + message="No SSH key file configured in connection extras. " |
| 255 | + "SSHOperator may rely on ssh-agent or password auth.", |
| 256 | + ) |
| 257 | + |
| 258 | + |
| 259 | +async def _check_sshd_reachable( |
| 260 | + client: Any, api_url: str, auth: tuple, conn_id: str |
| 261 | +) -> PreflightCheck: |
| 262 | + """Check if sshd is reachable via the Airflow connection test API or TCP fallback.""" |
| 263 | + # Try Airflow's connection test endpoint first |
| 264 | + try: |
| 265 | + resp = await client.post( |
| 266 | + f"{api_url}/api/v1/connections/test", |
| 267 | + json={"connection_id": conn_id}, |
| 268 | + auth=auth, |
| 269 | + ) |
| 270 | + if resp.status_code == 200: |
| 271 | + data = resp.json() |
| 272 | + if data.get("status", False): |
| 273 | + return PreflightCheck( |
| 274 | + name="sshd_reachable", |
| 275 | + status=CheckStatus.OK, |
| 276 | + message="sshd is reachable on localhost:22.", |
| 277 | + ) |
| 278 | + except Exception: |
| 279 | + pass # Fall through to TCP check |
| 280 | + |
| 281 | + # TCP fallback: try connecting to port 22 directly |
| 282 | + |
| 283 | + try: |
| 284 | + _, writer = await asyncio.wait_for( |
| 285 | + asyncio.open_connection("localhost", 22), |
| 286 | + timeout=3.0, |
| 287 | + ) |
| 288 | + writer.close() |
| 289 | + await writer.wait_closed() |
| 290 | + return PreflightCheck( |
| 291 | + name="sshd_reachable", |
| 292 | + status=CheckStatus.OK, |
| 293 | + message="sshd is reachable on localhost:22 (TCP check).", |
| 294 | + ) |
| 295 | + except Exception: |
| 296 | + return PreflightCheck( |
| 297 | + name="sshd_reachable", |
| 298 | + status=CheckStatus.WARNING, |
| 299 | + message="Cannot reach sshd on localhost:22. " |
| 300 | + "Ensure sshd is running: 'sudo systemctl start sshd'", |
| 301 | + ) |
| 302 | + |
| 303 | + |
| 304 | +# --------------------------------------------------------------------------- |
| 305 | +# Main entry point |
| 306 | +# --------------------------------------------------------------------------- |
| 307 | + |
| 308 | +async def run_ssh_preflight(force: bool = False) -> PreflightResult: |
| 309 | + """Run SSH pre-flight checks, returning a PreflightResult. |
| 310 | +
|
| 311 | + Results are cached for `SSH_PREFLIGHT_CACHE_TTL` seconds (default 300). |
| 312 | + Pass force=True to bypass cache. |
| 313 | + """ |
| 314 | + cfg = _get_config() |
| 315 | + conn_id = cfg["conn_id"] |
| 316 | + |
| 317 | + if not force: |
| 318 | + cached = _get_cached(conn_id, cfg["cache_ttl"]) |
| 319 | + if cached is not None: |
| 320 | + logger.debug("SSH preflight cache hit for %s", conn_id) |
| 321 | + return cached |
| 322 | + |
| 323 | + checks: List[PreflightCheck] = [] |
| 324 | + auth = (cfg["user"], cfg["password"]) |
| 325 | + |
| 326 | + async with httpx.AsyncClient(timeout=10.0) as client: |
| 327 | + # Check 1: Connection exists |
| 328 | + conn_check, conn_data = await _check_connection_exists( |
| 329 | + client, cfg["api_url"], auth, conn_id, cfg["ssh_user"] |
| 330 | + ) |
| 331 | + checks.append(conn_check) |
| 332 | + |
| 333 | + if conn_data is not None: |
| 334 | + # Check 2: SSH user |
| 335 | + user_check = await _check_ssh_user( |
| 336 | + client, cfg["api_url"], auth, conn_id, conn_data, cfg["ssh_user"] |
| 337 | + ) |
| 338 | + checks.append(user_check) |
| 339 | + |
| 340 | + # Check 3: SSH key |
| 341 | + key_check = _check_ssh_key(conn_data) |
| 342 | + checks.append(key_check) |
| 343 | + |
| 344 | + # Check 4: sshd reachable |
| 345 | + sshd_check = await _check_sshd_reachable( |
| 346 | + client, cfg["api_url"], auth, conn_id |
| 347 | + ) |
| 348 | + checks.append(sshd_check) |
| 349 | + |
| 350 | + result = PreflightResult(checks=checks) |
| 351 | + result.summary = result.format_report() |
| 352 | + _set_cached(conn_id, result) |
| 353 | + return result |
0 commit comments