Skip to content

Commit 40e5453

Browse files
tosin2013 and claude committed
feat(orchestrator): Add SSH pre-flight auto-fix for DAG triggers
Runs automatic SSH connection checks before every DAG trigger via the intent parser. Validates connection existence, SSH user, key config, and sshd reachability against the Airflow REST API. Auto-creates missing connections and patches incorrect SSH users. Results are cached for 5 minutes and reported inline with the trigger response. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 39f55cd commit 40e5453

3 files changed

Lines changed: 653 additions & 1 deletion

File tree

intent_parser/handlers/dag.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,25 @@ async def handle_dag_trigger(params: Dict) -> str:
168168
dag_id = params.get("dag_id")
169169
if not dag_id:
170170
return "Error: DAG ID is required. Try: 'trigger dag <dag_id>'"
171+
172+
# Run SSH pre-flight checks with auto-fix
173+
from ..ssh_preflight import run_ssh_preflight
174+
175+
preflight = await run_ssh_preflight()
176+
preflight_report = preflight.format_report()
177+
171178
conf = params.get("conf")
172-
return await _call_with_http_fallback(
179+
trigger_result = await _call_with_http_fallback(
173180
(lambda dag_id, conf: _trigger_dag(dag_id=dag_id, conf=conf)) if _backend_available else (lambda dag_id, conf: ""),
174181
_http_trigger_dag,
175182
dag_id,
176183
conf,
177184
)
178185

186+
if preflight_report:
187+
return f"{preflight_report}\n\n{trigger_result}"
188+
return trigger_result
189+
179190

180191
register(IntentCategory.DAG_LIST, handle_dag_list)
181192
register(IntentCategory.DAG_INFO, handle_dag_info)

intent_parser/ssh_preflight.py

Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
"""
2+
SSH pre-flight checks with auto-fix for DAG triggers.
3+
4+
Validates the Airflow SSH connection (localhost_ssh) before triggering DAGs
5+
that rely on SSHOperator. Checks connection existence, SSH user, key config,
6+
and sshd reachability. Auto-fixes what it can, warns about the rest.
7+
8+
Results are cached to avoid repeated API calls on consecutive triggers.
9+
"""
10+
11+
import asyncio
12+
import json
13+
import logging
14+
import os
15+
import time
16+
from dataclasses import dataclass, field
17+
from enum import Enum
18+
from typing import Any, Dict, List, Optional
19+
20+
import httpx
21+
22+
logger = logging.getLogger("intent-parser.ssh-preflight")
23+
24+
25+
class CheckStatus(str, Enum):
    """Outcome category of a single pre-flight check.

    The class mixes in ``str``, so every member compares equal to its
    lowercase string value (``CheckStatus.OK == "ok"``).
    """

    # Nothing needed to change.
    OK = "ok"
    # A problem was found and repaired automatically.
    FIXED = "fixed"
    # A problem was found but not repaired; the user is told about it.
    WARNING = "warning"
    # The Airflow API was unreachable, answered unexpectedly, or a repair failed.
    ERROR = "error"
30+
31+
32+
@dataclass
class PreflightCheck:
    """Outcome of one individual pre-flight check."""

    # Short identifier of the check, e.g. "connection_exists" or "ssh_user".
    name: str
    # Outcome category of the check.
    status: CheckStatus
    # Human-readable description that ends up in the report.
    message: str
    # Short description of the repair that was performed, if any.
    fix_applied: Optional[str] = None
38+
39+
40+
@dataclass
class PreflightResult:
    """Aggregate outcome of one pre-flight run."""

    # Individual check outcomes, in the order they were executed.
    checks: List[PreflightCheck] = field(default_factory=list)
    # Always True — warn but never block
    can_proceed: bool = True
    # Rendered report text; filled in by the code that builds the result.
    summary: str = ""

    def format_report(self) -> str:
        """Render the checks as text suitable for prepending to trigger output.

        Returns an empty string when there are no checks, a single line when
        every check is OK, and otherwise a header followed by the auto-fixes,
        then the warnings, then the errors.
        """
        if not self.checks:
            return ""

        # Short form when nothing needs the user's attention.
        if not any(c.status != CheckStatus.OK for c in self.checks):
            return "[SSH Pre-flight] All checks passed."

        lines = ["[SSH Pre-flight]"]

        repaired = [c for c in self.checks if c.status == CheckStatus.FIXED]
        if repaired:
            descriptions = "; ".join(c.fix_applied or c.message for c in repaired)
            lines.append(f" Auto-fixed {len(repaired)} issue(s): {descriptions}")

        lines.extend(
            f" WARNING: {c.message}"
            for c in self.checks
            if c.status == CheckStatus.WARNING
        )
        lines.extend(
            f" ERROR: {c.message}"
            for c in self.checks
            if c.status == CheckStatus.ERROR
        )

        return "\n".join(lines)
74+
75+
76+
# ---------------------------------------------------------------------------
77+
# Configuration
78+
# ---------------------------------------------------------------------------
79+
80+
def _get_config() -> Dict[str, Any]:
    """Read the pre-flight settings from the environment.

    Returns a dict with the keys ``api_url``, ``user``, ``password``,
    ``ssh_user``, ``conn_id`` (all strings) and ``cache_ttl`` (int, seconds).

    A malformed ``SSH_PREFLIGHT_CACHE_TTL`` falls back to the default of 300
    seconds instead of raising, so a bad environment value cannot break the
    code path that triggers DAGs.
    """
    default_ttl = 300
    raw_ttl = os.getenv("SSH_PREFLIGHT_CACHE_TTL", str(default_ttl))
    try:
        cache_ttl = int(raw_ttl)
    except ValueError:
        logging.getLogger("intent-parser.ssh-preflight").warning(
            "Invalid SSH_PREFLIGHT_CACHE_TTL=%r; using %d seconds",
            raw_ttl,
            default_ttl,
        )
        cache_ttl = default_ttl

    return {
        "api_url": os.getenv("AIRFLOW_API_URL", "http://localhost:8888"),
        # AIRFLOW_USER / AIRFLOW_PASSWORD win over the *_API_* variants.
        "user": os.getenv("AIRFLOW_USER", os.getenv("AIRFLOW_API_USER", "admin")),
        "password": os.getenv("AIRFLOW_PASSWORD", os.getenv("AIRFLOW_API_PASSWORD", "admin")),
        # Expected SSH login: explicit override, else the current user, else root.
        "ssh_user": os.getenv("QUBINODE_SSH_USER", os.getenv("USER", "root")),
        "conn_id": os.getenv("QUBINODE_SSH_CONN_ID", "localhost_ssh"),
        "cache_ttl": cache_ttl,
    }
89+
90+
91+
# ---------------------------------------------------------------------------
92+
# Cache
93+
# ---------------------------------------------------------------------------
94+
95+
# Latest pre-flight outcome per Airflow connection id.
_cache: Dict[str, tuple] = {}  # {conn_id: (timestamp, PreflightResult)}


def clear_cache() -> None:
    """Drop every cached pre-flight result (handy in tests)."""
    _cache.clear()
101+
102+
103+
def _get_cached(conn_id: str, ttl: int) -> Optional[PreflightResult]:
    """Return the cached result for *conn_id* if it is at most *ttl* seconds old.

    An entry older than *ttl* is removed from the cache and ``None`` is
    returned, as it is when no entry exists.
    """
    entry = _cache.get(conn_id)
    if entry is not None:
        stored_at, cached_result = entry
        age = time.time() - stored_at
        if age <= ttl:
            return cached_result
        # Stale: evict so the next lookup is a plain miss.
        del _cache[conn_id]
    return None
112+
113+
114+
def _set_cached(conn_id: str, result: PreflightResult) -> None:
    """Store *result* for *conn_id*, stamped with the current wall-clock time."""
    stored_at = time.time()
    _cache[conn_id] = (stored_at, result)
116+
117+
118+
# ---------------------------------------------------------------------------
119+
# Individual checks
120+
# ---------------------------------------------------------------------------
121+
122+
async def _check_connection_exists(
    client: Any, api_url: str, auth: tuple, conn_id: str, ssh_user: str
) -> tuple:
    """Check if the SSH connection exists; create it if missing.

    Args:
        client: Async HTTP client exposing ``get`` and ``post`` (an
            ``httpx.AsyncClient`` in this module).
        api_url: Base URL of the Airflow REST API.
        auth: ``(user, password)`` pair passed through to the client.
        conn_id: Airflow connection id to look up.
        ssh_user: Login to use when the connection has to be created.

    Returns (PreflightCheck, conn_data_or_None). ``conn_data`` is the
    connection as returned by the API, or the payload that was just created;
    it is ``None`` whenever the check ends in ERROR.
    """
    check_name = "connection_exists"

    def _error(message: str) -> tuple:
        # Every failure path has the same shape: ERROR check, no connection data.
        return PreflightCheck(
            name=check_name,
            status=CheckStatus.ERROR,
            message=message,
        ), None

    try:
        resp = await client.get(
            f"{api_url}/api/v1/connections/{conn_id}",
            auth=auth,
        )
    except Exception as exc:
        return _error(f"Cannot reach Airflow API: {exc}")

    if resp.status_code == 200:
        # A 200 with a non-JSON or non-object body (e.g. a proxy page) must not
        # crash the pre-flight or leak a bogus value into the later checks.
        try:
            conn_data = resp.json()
        except ValueError as exc:
            return _error(f"Unexpected API response: {exc}")
        if not isinstance(conn_data, dict):
            return _error("Unexpected API response: connection payload is not an object")
        return PreflightCheck(
            name=check_name,
            status=CheckStatus.OK,
            message=f"Connection '{conn_id}' exists.",
        ), conn_data

    if resp.status_code == 404:
        # Auto-create the connection. root's home is /root, not /home/root;
        # this matters because the configured SSH user falls back to "root".
        home_dir = "/root" if ssh_user == "root" else f"/home/{ssh_user}"
        new_conn = {
            "connection_id": conn_id,
            "conn_type": "ssh",
            "host": "localhost",
            "login": ssh_user,
            "port": 22,
            "extra": json.dumps({"key_file": f"{home_dir}/.ssh/id_rsa"}),
        }
        try:
            create_resp = await client.post(
                f"{api_url}/api/v1/connections",
                json=new_conn,
                auth=auth,
            )
        except Exception as exc:
            return _error(f"Failed to create connection: {exc}")

        if create_resp.status_code in (200, 201):
            return PreflightCheck(
                name=check_name,
                status=CheckStatus.FIXED,
                message=f"Created missing connection '{conn_id}'.",
                fix_applied=f"created connection '{conn_id}'",
            ), new_conn
        return _error(f"Failed to create connection: HTTP {create_resp.status_code}")

    return _error(f"Unexpected API response: HTTP {resp.status_code}")
190+
191+
192+
async def _check_ssh_user(
    client: Any, api_url: str, auth: tuple, conn_id: str, conn_data: Dict, ssh_user: str
) -> PreflightCheck:
    """Verify the connection's login equals *ssh_user*, patching it when it differs.

    Returns an OK check when the login already matches, FIXED when the PATCH
    request answers HTTP 200, and WARNING when the PATCH answers anything else
    or raises.
    """
    check_name = "ssh_user"
    current_login = conn_data.get("login", "")

    if current_login == ssh_user:
        return PreflightCheck(
            name=check_name,
            status=CheckStatus.OK,
            message=f"SSH user is '{ssh_user}'.",
        )

    # Auto-fix: PATCH the connection and record how the attempt went.
    try:
        patch_resp = await client.patch(
            f"{api_url}/api/v1/connections/{conn_id}",
            json={"login": ssh_user},
            auth=auth,
        )
        patched = patch_resp.status_code == 200
        failure_detail = f"PATCH failed with HTTP {patch_resp.status_code}."
    except Exception as exc:
        patched = False
        failure_detail = f"Auto-fix failed: {exc}"

    if patched:
        return PreflightCheck(
            name=check_name,
            status=CheckStatus.FIXED,
            message=f"Updated SSH user from '{current_login}' to '{ssh_user}'.",
            fix_applied=f"updated SSH user to '{ssh_user}'",
        )

    mismatch = f"SSH user is '{current_login}' but expected '{ssh_user}'. "
    return PreflightCheck(
        name=check_name,
        status=CheckStatus.WARNING,
        message=mismatch + failure_detail,
    )
231+
232+
233+
def _check_ssh_key(conn_data: Dict) -> PreflightCheck:
    """Check if the connection has an SSH key file configured (report only).

    ``conn_data["extra"]`` may be a JSON string, an already-decoded dict,
    ``None`` (the API sends ``null`` for a connection without extras), or
    absent. Anything that does not end up as a dict is treated as "no extras",
    so this check can only return OK or WARNING and never raises on odd
    payloads.
    """
    extra_raw = conn_data.get("extra")

    extra: Any = {}
    if isinstance(extra_raw, str):
        try:
            extra = json.loads(extra_raw)
        except (json.JSONDecodeError, TypeError):
            extra = {}
    elif extra_raw is not None:
        extra = extra_raw

    # Valid JSON can still be a list, string or number; only a mapping can
    # carry a "key_file" entry.
    if not isinstance(extra, dict):
        extra = {}

    key_file = extra.get("key_file", "")
    if key_file:
        return PreflightCheck(
            name="ssh_key",
            status=CheckStatus.OK,
            message=f"SSH key configured: {key_file}",
        )
    return PreflightCheck(
        name="ssh_key",
        status=CheckStatus.WARNING,
        message="No SSH key file configured in connection extras. "
        "SSHOperator may rely on ssh-agent or password auth.",
    )
257+
258+
259+
async def _check_sshd_reachable(
    client: Any, api_url: str, auth: tuple, conn_id: str
) -> PreflightCheck:
    """Check if sshd is reachable via the Airflow connection test API or TCP fallback.

    The Airflow ``connections/test`` endpoint is tried first. If it errors,
    answers with a non-200 status, or reports a falsy ``status``, a plain TCP
    connect to ``localhost:22`` (3 second timeout) decides the outcome.

    Returns an OK check when either probe succeeds, otherwise a WARNING.
    """
    # Try Airflow's connection test endpoint first
    try:
        resp = await client.post(
            f"{api_url}/api/v1/connections/test",
            json={"connection_id": conn_id},
            auth=auth,
        )
        if resp.status_code == 200 and resp.json().get("status", False):
            return PreflightCheck(
                name="sshd_reachable",
                status=CheckStatus.OK,
                message="sshd is reachable on localhost:22.",
            )
    except Exception as exc:
        # Best effort only: record why the API probe failed, then fall through.
        logger.debug(
            "Airflow connection test for %s failed, falling back to TCP check: %s",
            conn_id,
            exc,
        )

    # TCP fallback: try connecting to port 22 directly
    try:
        _, writer = await asyncio.wait_for(
            asyncio.open_connection("localhost", 22),
            timeout=3.0,
        )
    except Exception:
        return PreflightCheck(
            name="sshd_reachable",
            status=CheckStatus.WARNING,
            message="Cannot reach sshd on localhost:22. "
            "Ensure sshd is running: 'sudo systemctl start sshd'",
        )

    # The connect succeeded, so sshd is reachable; a hiccup while closing the
    # probe socket must not turn that into a "cannot reach" warning.
    try:
        writer.close()
        await writer.wait_closed()
    except Exception as exc:
        logger.debug("Error while closing sshd probe socket: %s", exc)

    return PreflightCheck(
        name="sshd_reachable",
        status=CheckStatus.OK,
        message="sshd is reachable on localhost:22 (TCP check).",
    )
302+
303+
304+
# ---------------------------------------------------------------------------
305+
# Main entry point
306+
# ---------------------------------------------------------------------------
307+
308+
async def run_ssh_preflight(force: bool = False) -> PreflightResult:
    """Run SSH pre-flight checks, returning a PreflightResult.

    Results are cached for `SSH_PREFLIGHT_CACHE_TTL` seconds (default 300).
    Pass force=True to bypass cache.

    Only settled results are cached. A run that contains an ERROR (e.g. the
    Airflow API was briefly unreachable) or a FIXED check (state just changed)
    is not stored, so the next trigger re-checks instead of replaying a stale
    failure or a repeated "Auto-fixed" message for the whole TTL.

    The pre-flight is advisory ("warn but never block"): an unexpected
    exception while running the checks is logged and reported as an ERROR
    check rather than propagated to the caller.
    """
    cfg = _get_config()
    conn_id = cfg["conn_id"]

    if not force:
        cached = _get_cached(conn_id, cfg["cache_ttl"])
        if cached is not None:
            logger.debug("SSH preflight cache hit for %s", conn_id)
            return cached

    checks: List[PreflightCheck] = []
    auth = (cfg["user"], cfg["password"])

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            # Check 1: Connection exists
            conn_check, conn_data = await _check_connection_exists(
                client, cfg["api_url"], auth, conn_id, cfg["ssh_user"]
            )
            checks.append(conn_check)

            # The remaining checks need connection data to inspect.
            if conn_data is not None:
                # Check 2: SSH user
                user_check = await _check_ssh_user(
                    client, cfg["api_url"], auth, conn_id, conn_data, cfg["ssh_user"]
                )
                checks.append(user_check)

                # Check 3: SSH key
                key_check = _check_ssh_key(conn_data)
                checks.append(key_check)

                # Check 4: sshd reachable
                sshd_check = await _check_sshd_reachable(
                    client, cfg["api_url"], auth, conn_id
                )
                checks.append(sshd_check)
    except Exception as exc:
        logger.exception("SSH preflight aborted unexpectedly for %s", conn_id)
        checks.append(
            PreflightCheck(
                name="preflight",
                status=CheckStatus.ERROR,
                message=f"Pre-flight checks aborted unexpectedly: {exc}",
            )
        )

    result = PreflightResult(checks=checks)
    result.summary = result.format_report()

    # ERROR may be transient and FIXED describes a one-off change; neither
    # should be served from the cache on later triggers.
    uncacheable = (CheckStatus.ERROR, CheckStatus.FIXED)
    if not any(check.status in uncacheable for check in checks):
        _set_cached(conn_id, result)
    return result

0 commit comments

Comments
 (0)