diff --git a/crates/openfang-api/src/lib.rs b/crates/openfang-api/src/lib.rs index a4653917a..243da788b 100644 --- a/crates/openfang-api/src/lib.rs +++ b/crates/openfang-api/src/lib.rs @@ -11,10 +11,7 @@ pub(crate) fn percent_decode(input: &str) -> String { let mut i = 0; while i < bytes.len() { if bytes[i] == b'%' && i + 2 < bytes.len() { - if let (Some(hi), Some(lo)) = ( - hex_val(bytes[i + 1]), - hex_val(bytes[i + 2]), - ) { + if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) { out.push(hi << 4 | lo); i += 3; continue; diff --git a/crates/openfang-channels/src/bridge.rs b/crates/openfang-channels/src/bridge.rs index c0cb89710..56643eaa8 100644 --- a/crates/openfang-channels/src/bridge.rs +++ b/crates/openfang-channels/src/bridge.rs @@ -643,12 +643,20 @@ async fn dispatch_message( .as_ref() .map(|o| o.lifecycle_reactions) .unwrap_or(true); - let thread_id = if threading_enabled { - message.thread_id.as_deref() + + // --- Auto-thread: decide intent now, but create AFTER all policy guards --- + let auto_thread_name = if !threading_enabled && message.thread_id.is_none() { + adapter.should_auto_thread(message).await } else { None }; + // thread_id is resolved later, after all guards pass. + // Always propagate an existing thread_id (message arrived inside a thread), + // regardless of threading_enabled — that flag controls explicit threading config, + // not auto-detected thread context. + let mut effective_thread_id: Option = message.thread_id.clone(); + // --- DM/Group policy check --- if let Some(ref ov) = overrides { if message.is_group { @@ -709,12 +717,42 @@ async fn dispatch_message( if let Err(msg) = rate_limiter.check(ct_str, sender_user_id(message), ov.rate_limit_per_user) { - send_response(adapter, &message.sender, msg, thread_id, output_format).await; + // Rate-limit rejection: don't create a thread, use existing thread if any + send_response( + adapter, + &message.sender, + msg, + message.thread_id.as_deref(), + output_format, + ) + .await; return; } } } + // --- Create auto-thread NOW (after all policy guards have passed) --- + if let Some(ref thread_name) = auto_thread_name { + match adapter + .create_thread(&message.sender, &message.platform_message_id, thread_name) + .await + { + Ok(new_thread_id) => { + info!( + "Created auto-thread {} for message {}", + thread_name, message.platform_message_id + ); + effective_thread_id = Some(new_thread_id); + } + Err(e) => { + warn!("Failed to create auto-thread: {}", e); + } + } + } + + // Resolve final thread_id reference used by all downstream send_response calls + let thread_id = effective_thread_id.as_deref(); + // Handle commands first (early return) if let ChannelContent::Command { ref name, ref args } = message.content { let result = handle_command(name, args, handle, router, &message.sender).await; diff --git a/crates/openfang-channels/src/discord.rs b/crates/openfang-channels/src/discord.rs index cffa63a87..7e0f17edf 100644 --- a/crates/openfang-channels/src/discord.rs +++ b/crates/openfang-channels/src/discord.rs @@ -8,7 +8,7 @@ use crate::types::{ }; use async_trait::async_trait; use futures::{SinkExt, Stream, StreamExt}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -56,6 +56,8 @@ pub struct DiscordAdapter { allowed_users: Vec, ignore_bots: bool, intents: u64, + /// Auto-thread behavior: "true", "false", or "smart" + auto_thread: String, shutdown_tx: Arc>, shutdown_rx: watch::Receiver, /// Bot's own user ID (populated after READY event). @@ -64,6 +66,13 @@ pub struct DiscordAdapter { session_id: Arc>>, /// Resume gateway URL. resume_gateway_url: Arc>>, + /// Thread channel IDs created by this bot (thread_id → parent_channel_id). + /// Used to detect when incoming messages are inside a bot-created thread. + created_thread_ids: Arc>>, + /// Message IDs seen via MESSAGE_CREATE (used to drop duplicate MESSAGE_UPDATE events). + /// Populated immediately when MESSAGE_CREATE is forwarded — before bridge processing — + /// to eliminate the race window where MESSAGE_UPDATE arrives before thread creation completes. + threaded_message_ids: Arc>>, } impl DiscordAdapter { @@ -73,6 +82,7 @@ impl DiscordAdapter { allowed_users: Vec, ignore_bots: bool, intents: u64, + auto_thread: String, ) -> Self { let (shutdown_tx, shutdown_rx) = watch::channel(false); Self { @@ -82,11 +92,14 @@ impl DiscordAdapter { allowed_users, ignore_bots, intents, + auto_thread, shutdown_tx: Arc::new(shutdown_tx), shutdown_rx, bot_user_id: Arc::new(RwLock::new(None)), session_id: Arc::new(RwLock::new(None)), resume_gateway_url: Arc::new(RwLock::new(None)), + created_thread_ids: Arc::new(RwLock::new(HashMap::new())), + threaded_message_ids: Arc::new(RwLock::new(HashSet::new())), } } @@ -147,6 +160,79 @@ impl DiscordAdapter { .await?; Ok(()) } + + /// Create a thread from a message in a Discord channel. + async fn api_create_thread( + &self, + channel_id: &str, + message_id: &str, + name: &str, + ) -> Result> { + let url = format!( + "{DISCORD_API_BASE}/channels/{channel_id}/messages/{message_id}/threads", + channel_id = channel_id, + message_id = message_id + ); + let body = serde_json::json!({ + "name": name, + "auto_archive_duration": 1440 // 24 hours + }); + let resp = self + .client + .post(&url) + .header("Authorization", format!("Bot {}", self.token.as_str())) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let body_text = resp.text().await.unwrap_or_default(); + return Err(format!("Discord createThread failed: {}", body_text).into()); + } + + let response: serde_json::Value = resp.json().await?; + let thread_id = response["id"].as_str().unwrap_or("").to_string(); + + // Track thread_id → parent channel_id so we can recognise messages + // that arrive inside this thread. + if !thread_id.is_empty() { + self.created_thread_ids + .write() + .await + .insert(thread_id.clone(), channel_id.to_string()); + } + + Ok(thread_id) + } + + /// Send a message to an existing thread. + /// Discord threads are channels — post directly to channels/{thread_id}/messages. + async fn api_send_thread_message( + &self, + _channel_id: &str, + thread_id: &str, + text: &str, + ) -> Result<(), Box> { + let url = format!("{DISCORD_API_BASE}/channels/{thread_id}/messages"); + let chunks = split_message(text, DISCORD_MSG_LIMIT); + + for chunk in chunks { + let body = serde_json::json!({ "content": chunk }); + let resp = self + .client + .post(&url) + .header("Authorization", format!("Bot {}", self.token.as_str())) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let body_text = resp.text().await.unwrap_or_default(); + warn!("Discord sendThreadMessage failed: {body_text}"); + } + } + Ok(()) + } } #[async_trait] @@ -159,6 +245,33 @@ impl ChannelAdapter for DiscordAdapter { ChannelType::Discord } + async fn should_auto_thread(&self, message: &ChannelMessage) -> Option { + // Only auto-thread in group channels (servers), not DMs + if !message.is_group { + return None; + } + + // Check auto_thread mode + match self.auto_thread.as_str() { + "true" => Some(thread_name_from_message(message)), + "false" => None, + "smart" => { + // Only create thread if bot was @mentioned + let was_mentioned = message + .metadata + .get("was_mentioned") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + if was_mentioned { + Some(thread_name_from_message(message)) + } else { + None + } + } + _ => None, + } + } + async fn start( &self, ) -> Result + Send>>, Box> @@ -176,6 +289,8 @@ impl ChannelAdapter for DiscordAdapter { let bot_user_id = self.bot_user_id.clone(); let session_id_store = self.session_id.clone(); let resume_url_store = self.resume_gateway_url.clone(); + let created_thread_ids = self.created_thread_ids.clone(); + let threaded_message_ids = self.threaded_message_ids.clone(); let mut shutdown = self.shutdown_rx.clone(); tokio::spawn(async move { @@ -414,19 +529,60 @@ impl ChannelAdapter for DiscordAdapter { &allowed_guilds, &allowed_users, ignore_bots, + &created_thread_ids, ) .await { + // MESSAGE_UPDATE must be suppressed if we already + // forwarded a MESSAGE_CREATE for this message ID. + // The check uses `seen_message_ids` (tracked below) + // which is populated the moment MESSAGE_CREATE is + // forwarded — before the bridge even processes it. + // This closes the race window where MESSAGE_UPDATE + // arrives before adapter.create_thread() completes. + if event_name == "MESSAGE_UPDATE" + && threaded_message_ids + .read() + .await + .contains(&msg.platform_message_id) + { + debug!( + "Discord MESSAGE_UPDATE skipped (already seen {})", + msg.platform_message_id + ); + continue; + } + debug!( "Discord {event_name} from {}: {:?}", msg.sender.display_name, msg.content ); + + // Mark this message as seen immediately so any + // concurrent or subsequent MESSAGE_UPDATE is dropped. + if event_name == "MESSAGE_CREATE" { + threaded_message_ids + .write() + .await + .insert(msg.platform_message_id.clone()); + } + if tx.send(msg).await.is_err() { return; } } } + "THREAD_DELETE" | "CHANNEL_DELETE" => { + // Clean up tracking when a thread is deleted so the + // next message in the parent channel is treated fresh. + if let Some(tid) = d["id"].as_str() { + created_thread_ids.write().await.remove(tid); + threaded_message_ids.write().await.retain(|_| true); // keep others + debug!("Discord thread/channel deleted: {tid}"); + } + } + "RESUMED" => { info!("Discord session resumed successfully"); } @@ -532,6 +688,46 @@ impl ChannelAdapter for DiscordAdapter { self.api_send_typing(&user.platform_id).await } + async fn send_in_thread( + &self, + user: &ChannelUser, + content: ChannelContent, + thread_id: &str, + ) -> Result<(), Box> { + let channel_id = &user.platform_id; + match content { + ChannelContent::Text(text) => { + self.api_send_thread_message(channel_id, thread_id, &text) + .await?; + } + _ => { + self.api_send_thread_message(channel_id, thread_id, "(Unsupported content type)") + .await?; + } + } + Ok(()) + } + + async fn create_thread( + &self, + user: &ChannelUser, + message_id: &str, + thread_name: &str, + ) -> Result> { + let channel_id = &user.platform_id; + let thread_id = self + .api_create_thread(channel_id, message_id, thread_name) + .await?; + // Also ensure the message_id is marked as seen (belt-and-suspenders: + // the gateway loop already inserts on MESSAGE_CREATE, but keep this + // in case create_thread is ever called from another path). + self.threaded_message_ids + .write() + .await + .insert(message_id.to_string()); + Ok(thread_id) + } + async fn stop(&self) -> Result<(), Box> { let _ = self.shutdown_tx.send(true); Ok(()) @@ -545,6 +741,7 @@ async fn parse_discord_message( allowed_guilds: &[String], allowed_users: &[String], ignore_bots: bool, + created_thread_ids: &Arc>>, ) -> Option { let author = d.get("author")?; let author_id = author["id"].as_str()?; @@ -583,6 +780,20 @@ async fn parse_discord_message( let channel_id = d["channel_id"].as_str()?; let message_id = d["id"].as_str().unwrap_or("0"); + + // Detect if this message is inside a bot-created thread. + // In Discord, a thread is its own channel — channel_id will be the thread's ID. + // If so, use the parent channel as platform_id and set thread_id so that: + // (a) auto-thread logic is skipped (message.thread_id.is_some()) + // (b) responses are sent back into the same thread + let (effective_channel_id, parsed_thread_id) = { + let threads = created_thread_ids.read().await; + if let Some(parent_channel_id) = threads.get(channel_id) { + (parent_channel_id.clone(), Some(channel_id.to_string())) + } else { + (channel_id.to_string(), None) + } + }; let username = author["username"].as_str().unwrap_or("Unknown"); let discriminator = author["discriminator"].as_str().unwrap_or("0000"); let display_name = if discriminator == "0" { @@ -641,7 +852,7 @@ async fn parse_discord_message( channel: ChannelType::Discord, platform_message_id: message_id.to_string(), sender: ChannelUser { - platform_id: channel_id.to_string(), + platform_id: effective_channel_id, display_name, openfang_user: None, }, @@ -649,15 +860,50 @@ async fn parse_discord_message( target_agent: None, timestamp, is_group, - thread_id: None, + thread_id: parsed_thread_id, metadata, }) } -#[cfg(test)] +/// Build a Discord thread name from the message content. +/// Strips @mention prefixes (`<@...>`), trims whitespace, and truncates to +/// Discord's 100-character thread name limit. Falls back to the sender's +/// display name if the message has no usable text (e.g. image-only). +fn thread_name_from_message(message: &ChannelMessage) -> String { + let raw = match &message.content { + ChannelContent::Text(t) => t.clone(), + ChannelContent::Image { caption, .. } => caption.clone().unwrap_or_default(), + _ => String::new(), + }; + + // Strip leading Discord mention tokens (<@id> / <@!id>) + let stripped = regex_lite::Regex::new(r"^(<@!?\d+>\s*)+") + .map(|re| re.replace(&raw, "").into_owned()) + .unwrap_or(raw); + + let trimmed = stripped.trim().to_string(); + + if trimmed.is_empty() { + return message.sender.display_name.clone(); + } + + // Truncate to Discord's 100-char limit + if trimmed.chars().count() <= 100 { + trimmed + } else { + trimmed.chars().take(97).collect::() + "…" + } +} + mod tests { use super::*; + /// Convenience helper: empty thread-tracking map for tests that don't exercise threading. + #[allow(dead_code)] + fn empty_threads() -> Arc>> { + Arc::new(RwLock::new(HashMap::new())) + } + #[tokio::test] async fn test_parse_discord_message_basic() { let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); @@ -674,7 +920,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); assert_eq!(msg.channel, ChannelType::Discord); @@ -698,7 +944,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true).await; + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()).await; assert!(msg.is_none()); } @@ -718,7 +964,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true).await; + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()).await; assert!(msg.is_none()); } @@ -739,7 +985,7 @@ mod tests { }); // With ignore_bots=false, other bots' messages should be allowed - let msg = parse_discord_message(&d, &bot_id, &[], &[], false).await; + let msg = parse_discord_message(&d, &bot_id, &[], &[], false, &empty_threads()).await; assert!(msg.is_some()); let msg = msg.unwrap(); assert_eq!(msg.sender.display_name, "somebot"); @@ -763,7 +1009,7 @@ mod tests { }); // Even with ignore_bots=false, the bot's own messages must still be filtered - let msg = parse_discord_message(&d, &bot_id, &[], &[], false).await; + let msg = parse_discord_message(&d, &bot_id, &[], &[], false, &empty_threads()).await; assert!(msg.is_none()); } @@ -784,12 +1030,20 @@ mod tests { }); // Not in allowed guilds - let msg = - parse_discord_message(&d, &bot_id, &["111".into(), "222".into()], &[], true).await; + let msg = parse_discord_message( + &d, + &bot_id, + &["111".into(), "222".into()], + &[], + true, + &empty_threads(), + ) + .await; assert!(msg.is_none()); // In allowed guilds - let msg = parse_discord_message(&d, &bot_id, &["999".into()], &[], true).await; + let msg = + parse_discord_message(&d, &bot_id, &["999".into()], &[], true, &empty_threads()).await; assert!(msg.is_some()); } @@ -808,7 +1062,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); match &msg.content { @@ -835,7 +1089,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true).await; + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()).await; assert!(msg.is_none()); } @@ -854,7 +1108,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); assert_eq!(msg.sender.display_name, "alice#1234"); @@ -878,7 +1132,7 @@ mod tests { }); // MESSAGE_UPDATE uses the same parse function as MESSAGE_CREATE - let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); assert_eq!(msg.channel, ChannelType::Discord); @@ -909,16 +1163,25 @@ mod tests { &[], &["user111".into(), "user222".into()], true, + &empty_threads(), ) .await; assert!(msg.is_none()); // In allowed users - let msg = parse_discord_message(&d, &bot_id, &[], &["user999".into()], true).await; + let msg = parse_discord_message( + &d, + &bot_id, + &[], + &["user999".into()], + true, + &empty_threads(), + ) + .await; assert!(msg.is_some()); // Empty allowed_users = allow all - let msg = parse_discord_message(&d, &bot_id, &[], &[], true).await; + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()).await; assert!(msg.is_some()); } @@ -941,7 +1204,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); assert!(msg.is_group); @@ -964,7 +1227,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg2 = parse_discord_message(&d2, &bot_id, &[], &[], true) + let msg2 = parse_discord_message(&d2, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); assert!(msg2.is_group); @@ -986,7 +1249,7 @@ mod tests { "timestamp": "2024-01-01T00:00:00+00:00" }); - let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + let msg = parse_discord_message(&d, &bot_id, &[], &[], true, &empty_threads()) .await .unwrap(); assert!(!msg.is_group); @@ -1025,6 +1288,7 @@ mod tests { vec![], true, 37376, + "true".to_string(), ); assert_eq!(adapter.name(), "discord"); assert_eq!(adapter.channel_type(), ChannelType::Discord); diff --git a/crates/openfang-cli/src/main.rs b/crates/openfang-cli/src/main.rs index 5817017c0..8aa4b7f0b 100644 --- a/crates/openfang-cli/src/main.rs +++ b/crates/openfang-cli/src/main.rs @@ -2466,7 +2466,9 @@ decay_rate = 0.05 if !json { ui::check_ok("GitHub Copilot (authenticated via device flow)"); } - checks.push(serde_json::json!({"check": "provider", "name": "GitHub Copilot", "status": "ok"})); + checks.push( + serde_json::json!({"check": "provider", "name": "GitHub Copilot", "status": "ok"}), + ); } } @@ -4992,7 +4994,9 @@ fn cmd_config_set_key(provider: &str) { ui::error(&format!("Failed to create async runtime: {e}")); std::process::exit(1); }); - match rt.block_on(openfang_runtime::drivers::copilot::run_interactive_setup(&openfang_dir)) { + match rt.block_on(openfang_runtime::drivers::copilot::run_interactive_setup( + &openfang_dir, + )) { Ok(_) => { ui::success("GitHub Copilot configured successfully"); ui::hint("Restart the daemon: openfang stop && openfang start"); diff --git a/crates/openfang-cli/src/tui/screens/init_wizard.rs b/crates/openfang-cli/src/tui/screens/init_wizard.rs index 54071b28e..4d708f49e 100644 --- a/crates/openfang-cli/src/tui/screens/init_wizard.rs +++ b/crates/openfang-cli/src/tui/screens/init_wizard.rs @@ -312,7 +312,10 @@ enum CopilotAuthStatus { } enum CopilotAuthEvent { - DeviceCode { user_code: String, verification_uri: String }, + DeviceCode { + user_code: String, + verification_uri: String, + }, Authenticated, Models(Vec), } @@ -648,8 +651,7 @@ pub fn run() -> InitResult { let (test_tx, test_rx) = std::sync::mpsc::channel::(); let (migrate_tx, migrate_rx) = std::sync::mpsc::channel::>(); - let (copilot_tx, copilot_rx) = - std::sync::mpsc::channel::>(); + let (copilot_tx, copilot_rx) = std::sync::mpsc::channel::>(); let result = loop { terminal @@ -660,7 +662,10 @@ pub fn run() -> InitResult { if state.step == Step::CopilotAuth { while let Ok(event) = copilot_rx.try_recv() { match event { - Ok(CopilotAuthEvent::DeviceCode { user_code, verification_uri }) => { + Ok(CopilotAuthEvent::DeviceCode { + user_code, + verification_uri, + }) => { state.copilot_user_code = user_code; state.copilot_verification_uri = verification_uri; state.copilot_auth_status = CopilotAuthStatus::WaitingForUser; @@ -839,7 +844,8 @@ pub fn run() -> InitResult { let rt = match tokio::runtime::Runtime::new() { Ok(rt) => rt, Err(e) => { - let _ = copilot_tx.send(Err(format!("Runtime error: {e}"))); + let _ = copilot_tx + .send(Err(format!("Runtime error: {e}"))); return; } }; @@ -851,21 +857,31 @@ pub fn run() -> InitResult { .map_err(|e| format!("HTTP error: {e}")); let http = match http { Ok(h) => h, - Err(e) => { let _ = copilot_tx.send(Err(e)); return; } + Err(e) => { + let _ = copilot_tx.send(Err(e)); + return; + } }; // Step 1: request device code use openfang_runtime::drivers::copilot; - let device = match copilot::request_device_code(&http).await { - Ok(d) => d, - Err(e) => { let _ = copilot_tx.send(Err(e)); return; } - }; + let device = + match copilot::request_device_code(&http).await { + Ok(d) => d, + Err(e) => { + let _ = copilot_tx.send(Err(e)); + return; + } + }; // Send device code to TUI for display - let _ = copilot_tx.send(Ok(CopilotAuthEvent::DeviceCode { - user_code: device.user_code.clone(), - verification_uri: device.verification_uri.clone(), - })); + let _ = + copilot_tx.send(Ok(CopilotAuthEvent::DeviceCode { + user_code: device.user_code.clone(), + verification_uri: device + .verification_uri + .clone(), + })); // Browser will be opened by user pressing Enter in TUI @@ -874,9 +890,14 @@ pub fn run() -> InitResult { &http, &device.device_code, device.interval, - ).await { + ) + .await + { Ok(t) => t, - Err(e) => { let _ = copilot_tx.send(Err(e)); return; } + Err(e) => { + let _ = copilot_tx.send(Err(e)); + return; + } }; // Save tokens @@ -885,16 +906,38 @@ pub fn run() -> InitResult { return; } - let _ = copilot_tx.send(Ok(CopilotAuthEvent::Authenticated)); + let _ = copilot_tx + .send(Ok(CopilotAuthEvent::Authenticated)); // Step 3: fetch models - let ct = match copilot::exchange_copilot_token(&http, &tokens.access_token).await { + let ct = match copilot::exchange_copilot_token( + &http, + &tokens.access_token, + ) + .await + { Ok(ct) => ct, - Err(e) => { let _ = copilot_tx.send(Err(format!("Token exchange: {e}"))); return; } + Err(e) => { + let _ = copilot_tx + .send(Err(format!("Token exchange: {e}"))); + return; + } }; - match copilot::fetch_models(&http, &ct.base_url, &ct.token).await { - Ok(models) => { let _ = copilot_tx.send(Ok(CopilotAuthEvent::Models(models))); } - Err(e) => { let _ = copilot_tx.send(Err(format!("Model fetch: {e}"))); } + match copilot::fetch_models( + &http, + &ct.base_url, + &ct.token, + ) + .await + { + Ok(models) => { + let _ = copilot_tx + .send(Ok(CopilotAuthEvent::Models(models))); + } + Err(e) => { + let _ = copilot_tx + .send(Err(format!("Model fetch: {e}"))); + } } }); }); @@ -924,11 +967,15 @@ pub fn run() -> InitResult { } } KeyCode::Enter => { - if matches!(state.copilot_auth_status, CopilotAuthStatus::WaitingForUser) { + if matches!( + state.copilot_auth_status, + CopilotAuthStatus::WaitingForUser + ) { if !state.copilot_verification_uri.is_empty() { - let _ = openfang_runtime::drivers::copilot::open_verification_url( - &state.copilot_verification_uri, - ); + let _ = + openfang_runtime::drivers::copilot::open_verification_url( + &state.copilot_verification_uri, + ); } } } @@ -1956,14 +2003,15 @@ fn draw_copilot_auth(f: &mut Frame, area: Rect, state: &mut State) { Constraint::Length(1), // code value Constraint::Length(1), // blank Constraint::Length(1), // url - Constraint::Min(0), // spacer + Constraint::Min(0), // spacer Constraint::Length(1), // hint ]) .split(area); - let title = Paragraph::new(Line::from(vec![ - Span::styled(" GitHub Copilot Authentication", Style::default().fg(theme::ACCENT)), - ])); + let title = Paragraph::new(Line::from(vec![Span::styled( + " GitHub Copilot Authentication", + Style::default().fg(theme::ACCENT), + )])); f.render_widget(title, chunks[0]); let spinner = theme::SPINNER_FRAMES[state.tick % theme::SPINNER_FRAMES.len()]; @@ -1985,9 +2033,7 @@ fn draw_copilot_auth(f: &mut Frame, area: Rect, state: &mut State) { ])); f.render_widget(line1, chunks[2]); - let code_label = Paragraph::new(Line::from(vec![ - Span::raw(" Enter this code:"), - ])); + let code_label = Paragraph::new(Line::from(vec![Span::raw(" Enter this code:")])); f.render_widget(code_label, chunks[5]); let code_value = Paragraph::new(Line::from(vec![ @@ -2007,9 +2053,10 @@ fn draw_copilot_auth(f: &mut Frame, area: Rect, state: &mut State) { ])); f.render_widget(url, chunks[8]); - let hint = Paragraph::new(Line::from(vec![ - Span::styled(" [Enter] Open browser", theme::dim_style()), - ])); + let hint = Paragraph::new(Line::from(vec![Span::styled( + " [Enter] Open browser", + theme::dim_style(), + )])); f.render_widget(hint, chunks[10]); } CopilotAuthStatus::FetchingModels => { @@ -2040,9 +2087,10 @@ fn draw_copilot_auth(f: &mut Frame, area: Rect, state: &mut State) { ])); f.render_widget(line, chunks[2]); - let hint = Paragraph::new(Line::from(vec![ - Span::styled(" Esc to go back", theme::dim_style()), - ])); + let hint = Paragraph::new(Line::from(vec![Span::styled( + " Esc to go back", + theme::dim_style(), + )])); f.render_widget(hint, chunks[10]); } } diff --git a/crates/openfang-kernel/src/kernel.rs b/crates/openfang-kernel/src/kernel.rs index 599541ee9..7f67fbc20 100644 --- a/crates/openfang-kernel/src/kernel.rs +++ b/crates/openfang-kernel/src/kernel.rs @@ -531,9 +531,9 @@ impl OpenFangKernel { // Otherwise (CLI commands), create a new one. if let Ok(handle) = tokio::runtime::Handle::try_current() { std::thread::scope(|s| { - s.spawn(|| { - handle.block_on(fetch) - }).join().unwrap_or(Err("Thread panicked".to_string())) + s.spawn(|| handle.block_on(fetch)) + .join() + .unwrap_or(Err("Thread panicked".to_string())) }) } else { let rt = tokio::runtime::Runtime::new() diff --git a/crates/openfang-runtime/src/compactor.rs b/crates/openfang-runtime/src/compactor.rs index de120f21d..fef90c815 100644 --- a/crates/openfang-runtime/src/compactor.rs +++ b/crates/openfang-runtime/src/compactor.rs @@ -435,10 +435,10 @@ async fn summarize_messages( let safe_start = if conversation_text.is_char_boundary(start) { start } else { - // Find the nearest valid character boundary moving upward - (start..conversation_text.len()) - .find(|&i| conversation_text.is_char_boundary(i)) - .unwrap_or(conversation_text.len()) + // Find the nearest valid character boundary moving upward + (start..conversation_text.len()) + .find(|&i| conversation_text.is_char_boundary(i)) + .unwrap_or(conversation_text.len()) }; conversation_text = conversation_text[safe_start..].to_string(); } diff --git a/crates/openfang-runtime/src/drivers/copilot.rs b/crates/openfang-runtime/src/drivers/copilot.rs index 8df53c64e..5ddf59127 100644 --- a/crates/openfang-runtime/src/drivers/copilot.rs +++ b/crates/openfang-runtime/src/drivers/copilot.rs @@ -169,6 +169,7 @@ struct OAuthTokenResponse { #[serde(default)] expires_in: Option, #[serde(default)] + #[allow(dead_code)] refresh_token_expires_in: Option, #[serde(default)] error: Option, @@ -177,9 +178,7 @@ struct OAuthTokenResponse { } /// Request a device code from GitHub using the Copilot client ID. -pub async fn request_device_code( - client: &reqwest::Client, -) -> Result { +pub async fn request_device_code(client: &reqwest::Client) -> Result { let resp = client .post(GITHUB_DEVICE_CODE_URL) .header("Accept", "application/json") @@ -259,9 +258,7 @@ pub async fn poll_for_token( let access_token = token_resp .access_token .ok_or("Missing access_token in response")?; - let refresh_token = token_resp - .refresh_token - .unwrap_or_default(); // Empty if token expiration is disabled on the OAuth App + let refresh_token = token_resp.refresh_token.unwrap_or_default(); // Empty if token expiration is disabled on the OAuth App let expires_in = token_resp.expires_in.unwrap_or(0); // 0 = non-expiring return Ok(PersistedTokens { @@ -444,7 +441,11 @@ pub async fn fetch_models( .and_then(|v| v.as_array()) .map(|arr| { arr.iter() - .filter_map(|m| m.get("id").and_then(|id| id.as_str()).map(|s| s.to_string())) + .filter_map(|m| { + m.get("id") + .and_then(|id| id.as_str()) + .map(|s| s.to_string()) + }) .collect() }) .unwrap_or_default(); @@ -514,12 +515,7 @@ impl CopilotDriver { }; if let Some(ref rt) = refresh_token { - match refresh_access_token( - &self.http_client, - rt, - ) - .await - { + match refresh_access_token(&self.http_client, rt).await { Ok(new_tokens) => { info!("Copilot access token refreshed successfully"); if let Err(e) = new_tokens.save(&self.openfang_dir) { @@ -547,8 +543,9 @@ impl CopilotDriver { } /// Ensure we have a valid Copilot API token (tid=…). - async fn ensure_copilot_token(&self) -> Result - { + async fn ensure_copilot_token( + &self, + ) -> Result { // Check cache. { let lock = self.copilot_token.lock().unwrap_or_else(|e| e.into_inner()); @@ -595,13 +592,16 @@ impl CopilotDriver { &self, copilot_token: &CachedCopilotToken, ) -> Result, crate::llm_driver::LlmError> { - let models = - fetch_models(&self.http_client, &copilot_token.base_url, &copilot_token.token) - .await - .map_err(|e| crate::llm_driver::LlmError::Api { - status: 500, - message: format!("Failed to fetch model list: {e}"), - })?; + let models = fetch_models( + &self.http_client, + &copilot_token.base_url, + &copilot_token.token, + ) + .await + .map_err(|e| crate::llm_driver::LlmError::Api { + status: 500, + message: format!("Failed to fetch model list: {e}"), + })?; let mut lock = self.models.lock().unwrap_or_else(|e| e.into_inner()); *lock = Some(CachedModels { @@ -638,10 +638,7 @@ impl CopilotDriver { execute: F, ) -> Result where - F: Fn( - super::openai::OpenAIDriver, - crate::llm_driver::CompletionRequest, - ) -> Fut, + F: Fn(super::openai::OpenAIDriver, crate::llm_driver::CompletionRequest) -> Fut, Fut: std::future::Future< Output = Result, >, @@ -652,9 +649,10 @@ impl CopilotDriver { match execute(driver, request.clone()).await { Ok(resp) => Ok(resp), - Err(crate::llm_driver::LlmError::Api { status, ref message }) - if status == 400 && message.contains("model_not_supported") => - { + Err(crate::llm_driver::LlmError::Api { + status, + ref message, + }) if status == 400 && message.contains("model_not_supported") => { // Refresh model list so subsequent calls have updated info. warn!( model = %request.model, @@ -683,9 +681,10 @@ impl crate::llm_driver::LlmDriver for CopilotDriver { &self, request: crate::llm_driver::CompletionRequest, ) -> Result { - self.execute_with_model_retry(request, |driver, req| async move { - driver.complete(req).await - }) + self.execute_with_model_retry( + request, + |driver, req| async move { driver.complete(req).await }, + ) .await } @@ -700,9 +699,10 @@ impl crate::llm_driver::LlmDriver for CopilotDriver { match driver.stream(request.clone(), tx.clone()).await { Ok(resp) => Ok(resp), - Err(crate::llm_driver::LlmError::Api { status, ref message }) - if status == 400 && message.contains("model_not_supported") => - { + Err(crate::llm_driver::LlmError::Api { + status, + ref message, + }) if status == 400 && message.contains("model_not_supported") => { warn!( model = %request.model, "Model not supported — refreshing model catalog" @@ -732,9 +732,7 @@ impl crate::llm_driver::LlmDriver for CopilotDriver { /// /// Called from `openfang config set-key github-copilot`, `openfang init`, /// `openfang onboard`, and `openfang configure`. -pub async fn run_interactive_setup( - openfang_dir: &PathBuf, -) -> Result { +pub async fn run_interactive_setup(openfang_dir: &PathBuf) -> Result { run_device_flow(openfang_dir).await } @@ -742,9 +740,7 @@ pub async fn run_interactive_setup( /// /// Prints the user code and verification URL, attempts to open the browser, /// then polls until the user authorizes. -pub async fn run_device_flow( - openfang_dir: &PathBuf, -) -> Result { +pub async fn run_device_flow(openfang_dir: &PathBuf) -> Result { let client = reqwest::Client::builder() .timeout(Duration::from_secs(30)) .build() @@ -767,12 +763,7 @@ pub async fn run_device_flow( println!(" Waiting for authorization..."); // Step 3: Poll for authorization. - let tokens = poll_for_token( - &client, - &device.device_code, - device.interval, - ) - .await?; + let tokens = poll_for_token(&client, &device.device_code, device.interval).await?; // Step 4: Persist. tokens.save(openfang_dir)?; @@ -782,6 +773,7 @@ pub async fn run_device_flow( } /// Read a line from stdin with a prompt. Used during interactive setup. +#[allow(dead_code)] fn prompt_line(prompt: &str) -> Result { use std::io::{self, BufRead, Write}; print!("{prompt}"); diff --git a/crates/openfang-runtime/src/drivers/mod.rs b/crates/openfang-runtime/src/drivers/mod.rs index b69ca82b0..50359216f 100644 --- a/crates/openfang-runtime/src/drivers/mod.rs +++ b/crates/openfang-runtime/src/drivers/mod.rs @@ -19,7 +19,7 @@ use openfang_types::model_catalog::{ AI21_BASE_URL, ANTHROPIC_BASE_URL, AZURE_OPENAI_BASE_URL, CEREBRAS_BASE_URL, CHUTES_BASE_URL, COHERE_BASE_URL, DEEPSEEK_BASE_URL, FIREWORKS_BASE_URL, GEMINI_BASE_URL, GROQ_BASE_URL, HUGGINGFACE_BASE_URL, KIMI_CODING_BASE_URL, LEMONADE_BASE_URL, LMSTUDIO_BASE_URL, - MINIMAX_BASE_URL, MISTRAL_BASE_URL, MOONSHOT_BASE_URL, NVIDIA_NIM_BASE_URL, NOVITA_BASE_URL, + MINIMAX_BASE_URL, MISTRAL_BASE_URL, MOONSHOT_BASE_URL, NOVITA_BASE_URL, NVIDIA_NIM_BASE_URL, OLLAMA_BASE_URL, OPENAI_BASE_URL, OPENROUTER_BASE_URL, PERPLEXITY_BASE_URL, QIANFAN_BASE_URL, QWEN_BASE_URL, REPLICATE_BASE_URL, SAMBANOVA_BASE_URL, TOGETHER_BASE_URL, VENICE_BASE_URL, VLLM_BASE_URL, VOLCENGINE_BASE_URL, VOLCENGINE_CODING_BASE_URL, XAI_BASE_URL, ZAI_BASE_URL, diff --git a/crates/openfang-runtime/src/subprocess_sandbox.rs b/crates/openfang-runtime/src/subprocess_sandbox.rs index 54b754304..7d6fd249f 100644 --- a/crates/openfang-runtime/src/subprocess_sandbox.rs +++ b/crates/openfang-runtime/src/subprocess_sandbox.rs @@ -192,13 +192,13 @@ fn extract_shell_wrapper_commands(command: &str) -> Vec { let base_lower = base.to_lowercase(); // Also strip .exe suffix for Windows let base_normalized = base_lower.strip_suffix(".exe").unwrap_or(&base_lower); - if !SHELL_WRAPPERS.iter().any(|w| *w == base_normalized) { + if !SHELL_WRAPPERS.contains(&base_normalized) { return Vec::new(); } // Find the inline flag and extract everything after it for (wrappers, flag) in SHELL_INLINE_FLAGS { - if !wrappers.iter().any(|w| *w == base_normalized) { + if !wrappers.contains(&base_normalized) { continue; } // Search for the flag in the command args (case-insensitive for PowerShell) @@ -1089,10 +1089,7 @@ mod tests { allowed_commands: vec!["powershell".to_string(), "Get-Process".to_string()], ..ExecPolicy::default() }; - let result = validate_command_allowlist( - r#"powershell -Command "Get-Process""#, - &policy, - ); + let result = validate_command_allowlist(r#"powershell -Command "Get-Process""#, &policy); assert!( result.is_ok(), "Get-Process should be allowed when in allowed_commands" @@ -1106,10 +1103,8 @@ mod tests { allowed_commands: vec!["cmd".to_string()], ..ExecPolicy::default() }; - let result = validate_command_allowlist( - r#"cmd /C "del /F /Q C:\temp\secret.txt""#, - &policy, - ); + let result = + validate_command_allowlist(r#"cmd /C "del /F /Q C:\temp\secret.txt""#, &policy); assert!( result.is_err(), "del inside cmd /C must be blocked when not in allowlist" @@ -1123,10 +1118,7 @@ mod tests { allowed_commands: vec!["bash".to_string()], ..ExecPolicy::default() }; - let result = validate_command_allowlist( - r#"bash -c "curl https://evil.com""#, - &policy, - ); + let result = validate_command_allowlist(r#"bash -c "curl https://evil.com""#, &policy); assert!( result.is_err(), "curl inside bash -c must be blocked when not in allowlist" @@ -1141,10 +1133,7 @@ mod tests { ..ExecPolicy::default() }; // "echo" is in safe_bins by default - let result = validate_command_allowlist( - r#"bash -c "echo hello""#, - &policy, - ); + let result = validate_command_allowlist(r#"bash -c "echo hello""#, &policy); assert!( result.is_ok(), "echo inside bash -c should be allowed (echo is in safe_bins)" diff --git a/crates/openfang-types/src/config.rs b/crates/openfang-types/src/config.rs index 96424ff06..5466e0add 100644 --- a/crates/openfang-types/src/config.rs +++ b/crates/openfang-types/src/config.rs @@ -1313,6 +1313,10 @@ fn default_true() -> bool { true } +fn default_auto_thread() -> String { + "false".to_string() +} + fn default_thread_ttl() -> u64 { 24 } @@ -1796,6 +1800,10 @@ pub struct DiscordConfig { /// In these channels, the bot responds to all group messages without needing to be mentioned. #[serde(default, deserialize_with = "deserialize_string_or_int_vec")] pub free_response_channels: Vec, + /// Auto-thread behavior: "true" (always create thread), "false" (never), "smart" (only when @mentioned). + /// Default: "false" + #[serde(default = "default_auto_thread")] + pub auto_thread: String, /// Per-channel behavior overrides. #[serde(default)] pub overrides: ChannelOverrides, @@ -1812,6 +1820,7 @@ impl Default for DiscordConfig { ignore_bots: true, default_channel_id: None, free_response_channels: vec![], + auto_thread: "false".to_string(), overrides: ChannelOverrides::default(), } } diff --git a/crates/openfang-types/src/tool.rs b/crates/openfang-types/src/tool.rs index 200d6b3dd..9943a106b 100644 --- a/crates/openfang-types/src/tool.rs +++ b/crates/openfang-types/src/tool.rs @@ -173,13 +173,9 @@ fn normalize_schema_recursive(schema: &serde_json::Value) -> serde_json::Value { // JSON Schema allows arrays without `items`, but the Gemini API rejects // such schemas with INVALID_ARGUMENT. Inject a default string items schema // so MCP tools (and any other source) don't break Gemini requests. - if result.get("type").and_then(|t| t.as_str()) == Some("array") - && !result.contains_key("items") + if result.get("type").and_then(|t| t.as_str()) == Some("array") && !result.contains_key("items") { - result.insert( - "items".to_string(), - serde_json::json!({"type": "string"}), - ); + result.insert("items".to_string(), serde_json::json!({"type": "string"})); } serde_json::Value::Object(result)