diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index cd045e1f9487..4a23797cc38f 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -3137,6 +3137,26 @@ dependencies = [ "wiremock", ] +[[package]] +name = "codex-memories-extension" +version = "0.0.0" +dependencies = [ + "codex-core", + "codex-extension-api", + "codex-features", + "codex-memories-read", + "codex-tools", + "codex-utils-absolute-path", + "codex-utils-output-truncation", + "pretty_assertions", + "schemars 0.8.22", + "serde", + "serde_json", + "tempfile", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "codex-memories-mcp" version = "0.0.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 233edcdee858..bd0bcec66a94 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -47,6 +47,7 @@ members = [ "ext/extension-api", "ext/guardian", "ext/git-attribution", + "ext/memories", "external-agent-migration", "external-agent-sessions", "keyring-store", @@ -179,6 +180,7 @@ codex-linux-sandbox = { path = "linux-sandbox" } codex-lmstudio = { path = "lmstudio" } codex-login = { path = "login" } codex-message-history = { path = "message-history" } +codex-memories-extension = { path = "ext/memories" } codex-memories-read = { path = "memories/read" } codex-memories-write = { path = "memories/write" } codex-mcp = { path = "codex-mcp" } @@ -469,6 +471,7 @@ unwrap_used = "deny" [workspace.metadata.cargo-shear] ignored = [ "codex-agent-graph-store", + "codex-memories-extension", "icu_provider", "openssl-sys", "codex-v8-poc", diff --git a/codex-rs/ext/memories/BUILD.bazel b/codex-rs/ext/memories/BUILD.bazel new file mode 100644 index 000000000000..5da7a952ab3b --- /dev/null +++ b/codex-rs/ext/memories/BUILD.bazel @@ -0,0 +1,6 @@ +load("//:defs.bzl", "codex_rust_crate") + +codex_rust_crate( + name = "memories", + crate_name = "codex_memories_extension", +) diff --git a/codex-rs/ext/memories/Cargo.toml b/codex-rs/ext/memories/Cargo.toml new file mode 100644 index 000000000000..542eaa6bc5d5 --- /dev/null +++ b/codex-rs/ext/memories/Cargo.toml @@ -0,0 +1,32 @@ +[package] +edition.workspace = true +license.workspace = true +name = "codex-memories-extension" +version.workspace = true + +[lib] +name = "codex_memories_extension" +path = "src/lib.rs" +doctest = false + +[lints] +workspace = true + +[dependencies] +codex-core = { workspace = true } +codex-extension-api = { workspace = true } +codex-features = { workspace = true } +codex-memories-read = { workspace = true } +codex-utils-absolute-path = { workspace = true } +codex-utils-output-truncation = { workspace = true } +schemars = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["fs"] } + +[dev-dependencies] +codex-tools = { workspace = true } +pretty_assertions = { workspace = true } +tempfile = { workspace = true } +tokio = { workspace = true, features = ["fs", "macros"] } diff --git a/codex-rs/ext/memories/src/backend.rs b/codex-rs/ext/memories/src/backend.rs new file mode 100644 index 000000000000..929852f1b1fe --- /dev/null +++ b/codex-rs/ext/memories/src/backend.rs @@ -0,0 +1,164 @@ +use schemars::JsonSchema; +use serde::Deserialize; +use serde::Serialize; +use std::future::Future; + +pub const DEFAULT_LIST_MAX_RESULTS: usize = 2_000; +pub const MAX_LIST_RESULTS: usize = 2_000; +pub const DEFAULT_SEARCH_MAX_RESULTS: usize = 200; +pub const MAX_SEARCH_RESULTS: usize = 200; +pub const DEFAULT_READ_MAX_TOKENS: usize = 20_000; + +/// Storage interface behind the memories MCP tools. +/// +/// Implementations should return paths relative to the memory store and enforce +/// their own storage-specific access rules. The local implementation uses the +/// filesystem today; a later implementation can satisfy the same contract from a +/// remote backend. +pub trait MemoriesBackend: Clone + Send + Sync + 'static { + fn list( + &self, + request: ListMemoriesRequest, + ) -> impl Future> + Send; + + fn read( + &self, + request: ReadMemoryRequest, + ) -> impl Future> + Send; + + fn search( + &self, + request: SearchMemoriesRequest, + ) -> impl Future> + Send; +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ListMemoriesRequest { + pub path: Option, + pub cursor: Option, + pub max_results: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] +#[schemars(deny_unknown_fields)] +pub struct ListMemoriesResponse { + pub path: Option, + pub entries: Vec, + pub next_cursor: Option, + pub truncated: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReadMemoryRequest { + pub path: String, + pub line_offset: usize, + pub max_lines: Option, + pub max_tokens: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] +#[schemars(deny_unknown_fields)] +pub struct ReadMemoryResponse { + pub path: String, + pub start_line_number: usize, + pub content: String, + pub truncated: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SearchMemoriesRequest { + pub queries: Vec, + pub match_mode: SearchMatchMode, + pub path: Option, + pub cursor: Option, + pub context_lines: usize, + pub case_sensitive: bool, + pub normalized: bool, + pub max_results: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] +#[schemars(deny_unknown_fields)] +pub struct SearchMemoriesResponse { + pub queries: Vec, + pub match_mode: SearchMatchMode, + pub path: Option, + pub matches: Vec, + pub next_cursor: Option, + pub truncated: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum SearchMatchMode { + Any, + AllOnSameLine, + AllWithinLines { + #[schemars(range(min = 1))] + line_count: usize, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] +#[schemars(deny_unknown_fields)] +pub struct MemoryEntry { + pub path: String, + pub entry_type: MemoryEntryType, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum MemoryEntryType { + File, + Directory, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] +#[schemars(deny_unknown_fields)] +pub struct MemorySearchMatch { + pub path: String, + pub match_line_number: usize, + pub content_start_line_number: usize, + pub content: String, + pub matched_queries: Vec, +} + +#[derive(Debug, thiserror::Error)] +pub enum MemoriesBackendError { + #[error("path '{path}' {reason}")] + InvalidPath { path: String, reason: String }, + #[error("cursor '{cursor}' {reason}")] + InvalidCursor { cursor: String, reason: String }, + #[error("path '{path}' was not found")] + NotFound { path: String }, + #[error("line_offset must be a 1-indexed line number")] + InvalidLineOffset, + #[error("max_lines must be a positive integer")] + InvalidMaxLines, + #[error("line_offset exceeds file length")] + LineOffsetExceedsFileLength, + #[error("path '{path}' is not a file")] + NotFile { path: String }, + #[error("queries must not be empty or contain empty strings")] + EmptyQuery, + #[error("all_within_lines.line_count must be a positive integer")] + InvalidMatchWindow, + #[error("I/O error while reading memories: {0}")] + Io(#[from] std::io::Error), +} + +impl MemoriesBackendError { + pub fn invalid_path(path: impl Into, reason: impl Into) -> Self { + Self::InvalidPath { + path: path.into(), + reason: reason.into(), + } + } + + pub fn invalid_cursor(cursor: impl Into, reason: impl Into) -> Self { + Self::InvalidCursor { + cursor: cursor.into(), + reason: reason.into(), + } + } +} diff --git a/codex-rs/ext/memories/src/lib.rs b/codex-rs/ext/memories/src/lib.rs new file mode 100644 index 000000000000..41143bef0f3b --- /dev/null +++ b/codex-rs/ext/memories/src/lib.rs @@ -0,0 +1,192 @@ +use std::sync::Arc; + +use codex_core::config::Config; +use codex_extension_api::ContextContributor; +use codex_extension_api::ExtensionData; +use codex_extension_api::ExtensionRegistryBuilder; +use codex_extension_api::PromptFragment; +use codex_extension_api::ThreadLifecycleContributor; +use codex_extension_api::ThreadStartInput; +use codex_extension_api::ToolContributor; +use codex_features::Feature; +use codex_memories_read::build_memory_tool_developer_instructions; +use codex_utils_absolute_path::AbsolutePathBuf; + +mod backend; +mod local; +mod schema; +mod tools; + +use local::LocalMemoriesBackend; + +/// Contributes Codex memory read-path prompt context and memory read tools. +#[derive(Clone, Copy, Debug, Default)] +pub struct MemoriesExtension; + +#[derive(Clone, Debug)] +struct MemoriesExtensionConfig { + enabled: bool, + codex_home: AbsolutePathBuf, +} + +impl ContextContributor for MemoriesExtension { + fn contribute<'a>( + &'a self, + _session_store: &'a ExtensionData, + thread_store: &'a ExtensionData, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let Some(config) = thread_store.get::() else { + return Vec::new(); + }; + if !config.enabled { + return Vec::new(); + } + + build_memory_tool_developer_instructions(&config.codex_home) + .await + .map(PromptFragment::developer_policy) + .into_iter() + .collect() + }) + } +} + +impl ThreadLifecycleContributor for MemoriesExtension { + fn on_thread_start(&self, input: ThreadStartInput<'_, Config>) { + input.thread_store.insert(MemoriesExtensionConfig { + enabled: input.config.features.enabled(Feature::MemoryTool) + && input.config.memories.use_memories, + codex_home: input.config.codex_home.clone(), + }); + } +} + +impl ToolContributor for MemoriesExtension { + fn tools( + &self, + _session_store: &ExtensionData, + thread_store: &ExtensionData, + ) -> Vec> { + let Some(config) = thread_store.get::() else { + return Vec::new(); + }; + if !config.enabled { + return Vec::new(); + } + + tools::memory_tools(LocalMemoriesBackend::from_codex_home(&config.codex_home)) + } +} + +/// Installs the memories extension contributors into the extension registry. +pub fn install(registry: &mut ExtensionRegistryBuilder) { + let extension = Arc::new(MemoriesExtension); + registry.thread_lifecycle_contributor(extension.clone()); + registry.prompt_contributor(extension.clone()); + registry.tool_contributor(extension); +} + +#[cfg(test)] +mod tests { + use codex_extension_api::ContextContributor; + use codex_extension_api::ExtensionData; + use codex_extension_api::PromptSlot; + use codex_extension_api::ToolContributor; + use codex_utils_absolute_path::test_support::PathBufExt; + use codex_utils_absolute_path::test_support::PathExt; + use codex_utils_absolute_path::test_support::test_path_buf; + use pretty_assertions::assert_eq; + + use super::MemoriesExtension; + use super::MemoriesExtensionConfig; + + #[test] + fn tools_are_not_contributed_without_thread_config() { + let extension = MemoriesExtension; + + assert!( + extension + .tools( + &ExtensionData::new("session"), + &ExtensionData::new("thread") + ) + .is_empty() + ); + } + + #[test] + fn tools_are_not_contributed_when_disabled() { + let extension = MemoriesExtension; + let thread_store = ExtensionData::new("thread"); + thread_store.insert(MemoriesExtensionConfig { + enabled: false, + codex_home: test_path_buf("/tmp/codex-home").abs(), + }); + + assert!( + extension + .tools(&ExtensionData::new("session"), &thread_store) + .is_empty() + ); + } + + #[test] + fn tools_are_contributed_when_enabled() { + let extension = MemoriesExtension; + let thread_store = ExtensionData::new("thread"); + thread_store.insert(MemoriesExtensionConfig { + enabled: true, + codex_home: test_path_buf("/tmp/codex-home").abs(), + }); + + let tool_names = extension + .tools(&ExtensionData::new("session"), &thread_store) + .into_iter() + .map(|tool| tool.tool_name().to_string()) + .collect::>(); + + assert_eq!( + tool_names, + vec![ + "memory_list".to_string(), + "memory_read".to_string(), + "memory_search".to_string() + ] + ); + } + + #[tokio::test] + async fn prompt_contribution_uses_memory_summary_when_enabled() { + let tempdir = tempfile::tempdir().expect("tempdir"); + let memories_dir = tempdir.path().join("memories"); + tokio::fs::create_dir_all(&memories_dir) + .await + .expect("create memories dir"); + tokio::fs::write( + memories_dir.join("memory_summary.md"), + "Remember repository-specific implementation preferences.", + ) + .await + .expect("write memory summary"); + + let extension = MemoriesExtension; + let thread_store = ExtensionData::new("thread"); + thread_store.insert(MemoriesExtensionConfig { + enabled: true, + codex_home: tempdir.path().abs(), + }); + + let fragments = extension + .contribute(&ExtensionData::new("session"), &thread_store) + .await; + + assert_eq!(fragments.len(), 1); + assert_eq!(fragments[0].slot(), PromptSlot::DeveloperPolicy); + assert!( + fragments[0] + .text() + .contains("Remember repository-specific implementation preferences.") + ); + } +} diff --git a/codex-rs/ext/memories/src/local.rs b/codex-rs/ext/memories/src/local.rs new file mode 100644 index 000000000000..47c09b60d90b --- /dev/null +++ b/codex-rs/ext/memories/src/local.rs @@ -0,0 +1,616 @@ +use crate::backend::DEFAULT_READ_MAX_TOKENS; +use crate::backend::ListMemoriesRequest; +use crate::backend::ListMemoriesResponse; +use crate::backend::MAX_LIST_RESULTS; +use crate::backend::MAX_SEARCH_RESULTS; +use crate::backend::MemoriesBackend; +use crate::backend::MemoriesBackendError; +use crate::backend::MemoryEntry; +use crate::backend::MemoryEntryType; +use crate::backend::MemorySearchMatch; +use crate::backend::ReadMemoryRequest; +use crate::backend::ReadMemoryResponse; +use crate::backend::SearchMatchMode; +use crate::backend::SearchMemoriesRequest; +use crate::backend::SearchMemoriesResponse; +use codex_utils_absolute_path::AbsolutePathBuf; +use codex_utils_output_truncation::TruncationPolicy; +use codex_utils_output_truncation::truncate_text; +use std::borrow::Cow; +use std::path::Component; +use std::path::Path; +use std::path::PathBuf; + +#[derive(Debug, Clone)] +pub struct LocalMemoriesBackend { + root: PathBuf, +} + +impl LocalMemoriesBackend { + pub fn from_codex_home(codex_home: &AbsolutePathBuf) -> Self { + Self::from_memory_root(codex_home.join("memories").to_path_buf()) + } + + pub fn from_memory_root(root: impl Into) -> Self { + Self { root: root.into() } + } + + async fn resolve_scoped_path( + &self, + relative_path: Option<&str>, + ) -> Result { + let Some(relative_path) = relative_path else { + return Ok(self.root.clone()); + }; + let relative = Path::new(relative_path); + if relative.components().any(|component| { + matches!( + component, + Component::ParentDir | Component::RootDir | Component::Prefix(_) + ) + }) { + return Err(MemoriesBackendError::invalid_path( + relative_path, + "must stay within the memories root", + )); + } + if relative.components().any(is_hidden_component) { + return Err(MemoriesBackendError::NotFound { + path: relative_path.to_string(), + }); + } + + let components = relative.components().collect::>(); + let mut scoped_path = self.root.clone(); + for (idx, component) in components.iter().enumerate() { + scoped_path.push(component.as_os_str()); + + let Some(metadata) = Self::metadata_or_none(&scoped_path).await? else { + for remaining_component in components.iter().skip(idx + 1) { + scoped_path.push(remaining_component.as_os_str()); + } + return Ok(scoped_path); + }; + + reject_symlink(&display_relative_path(&self.root, &scoped_path), &metadata)?; + if idx + 1 < components.len() && !metadata.is_dir() { + return Err(MemoriesBackendError::invalid_path( + relative_path, + "traverses through a non-directory path component", + )); + } + } + + Ok(scoped_path) + } + + async fn metadata_or_none( + path: &Path, + ) -> Result, MemoriesBackendError> { + match tokio::fs::symlink_metadata(path).await { + Ok(metadata) => Ok(Some(metadata)), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(err) => Err(err.into()), + } + } +} + +impl MemoriesBackend for LocalMemoriesBackend { + async fn list( + &self, + request: ListMemoriesRequest, + ) -> Result { + let max_results = request.max_results.min(MAX_LIST_RESULTS); + let start = self.resolve_scoped_path(request.path.as_deref()).await?; + let start_index = match request.cursor.as_deref() { + Some(cursor) => cursor.parse::().map_err(|_| { + MemoriesBackendError::invalid_cursor(cursor, "must be a non-negative integer") + })?, + None => 0, + }; + let Some(metadata) = Self::metadata_or_none(&start).await? else { + return Err(MemoriesBackendError::NotFound { + path: request.path.unwrap_or_default(), + }); + }; + reject_symlink(&display_relative_path(&self.root, &start), &metadata)?; + + let mut entries = if metadata.is_file() { + vec![MemoryEntry { + path: display_relative_path(&self.root, &start), + entry_type: MemoryEntryType::File, + }] + } else if metadata.is_dir() { + let mut entries = Vec::new(); + for path in read_sorted_dir_paths(&start).await? { + if is_hidden_path(&path) { + continue; + } + let Some(metadata) = Self::metadata_or_none(&path).await? else { + continue; + }; + if metadata.file_type().is_symlink() { + continue; + } + + let entry_type = if metadata.is_dir() { + MemoryEntryType::Directory + } else if metadata.is_file() { + MemoryEntryType::File + } else { + continue; + }; + entries.push(MemoryEntry { + path: display_relative_path(&self.root, &path), + entry_type, + }); + } + entries + } else { + Vec::new() + }; + if start_index > entries.len() { + return Err(MemoriesBackendError::invalid_cursor( + start_index.to_string(), + "exceeds result count", + )); + } + + let end_index = start_index.saturating_add(max_results).min(entries.len()); + let next_cursor = (end_index < entries.len()).then(|| end_index.to_string()); + let truncated = next_cursor.is_some(); + Ok(ListMemoriesResponse { + path: request.path, + entries: entries.drain(start_index..end_index).collect(), + next_cursor, + truncated, + }) + } + + async fn read( + &self, + request: ReadMemoryRequest, + ) -> Result { + if request.line_offset == 0 { + return Err(MemoriesBackendError::InvalidLineOffset); + } + if request.max_lines == Some(0) { + return Err(MemoriesBackendError::InvalidMaxLines); + } + + let path = self + .resolve_scoped_path(Some(request.path.as_str())) + .await?; + let Some(metadata) = Self::metadata_or_none(&path).await? else { + return Err(MemoriesBackendError::NotFound { path: request.path }); + }; + reject_symlink(&request.path, &metadata)?; + if !metadata.is_file() { + return Err(MemoriesBackendError::NotFile { path: request.path }); + } + + let original_content = tokio::fs::read_to_string(&path).await?; + let start_byte = line_start_byte_offset(&original_content, request.line_offset)?; + let end_byte = line_end_byte_offset(&original_content, start_byte, request.max_lines); + let content_from_offset = &original_content[start_byte..end_byte]; + let max_tokens = if request.max_tokens == 0 { + DEFAULT_READ_MAX_TOKENS + } else { + request.max_tokens + }; + let content = truncate_text(content_from_offset, TruncationPolicy::Tokens(max_tokens)); + let truncated = end_byte < original_content.len() || content != content_from_offset; + Ok(ReadMemoryResponse { + path: request.path, + start_line_number: request.line_offset, + content, + truncated, + }) + } + + async fn search( + &self, + request: SearchMemoriesRequest, + ) -> Result { + let queries = request + .queries + .iter() + .map(|query| query.trim().to_string()) + .collect::>(); + if queries.is_empty() || queries.iter().any(std::string::String::is_empty) { + return Err(MemoriesBackendError::EmptyQuery); + } + if matches!( + request.match_mode, + SearchMatchMode::AllWithinLines { line_count: 0 } + ) { + return Err(MemoriesBackendError::InvalidMatchWindow); + } + + let max_results = request.max_results.min(MAX_SEARCH_RESULTS); + let start = self.resolve_scoped_path(request.path.as_deref()).await?; + let start_index = match request.cursor.as_deref() { + Some(cursor) => cursor.parse::().map_err(|_| { + MemoriesBackendError::invalid_cursor(cursor, "must be a non-negative integer") + })?, + None => 0, + }; + let Some(metadata) = Self::metadata_or_none(&start).await? else { + return Err(MemoriesBackendError::NotFound { + path: request.path.unwrap_or_default(), + }); + }; + reject_symlink(&display_relative_path(&self.root, &start), &metadata)?; + + let matcher = SearchMatcher::new( + queries.clone(), + request.match_mode.clone(), + request.case_sensitive, + request.normalized, + )?; + let mut matches = Vec::new(); + search_entries( + &self.root, + &start, + &metadata, + &matcher, + request.context_lines, + &mut matches, + ) + .await?; + matches.sort_by(|left, right| { + left.path + .cmp(&right.path) + .then(left.match_line_number.cmp(&right.match_line_number)) + }); + if start_index > matches.len() { + return Err(MemoriesBackendError::invalid_cursor( + start_index.to_string(), + "exceeds result count", + )); + } + let end_index = start_index.saturating_add(max_results).min(matches.len()); + let next_cursor = (end_index < matches.len()).then(|| end_index.to_string()); + let truncated = next_cursor.is_some(); + Ok(SearchMemoriesResponse { + queries, + match_mode: request.match_mode, + path: request.path, + matches: matches.drain(start_index..end_index).collect(), + next_cursor, + truncated, + }) + } +} + +async fn search_entries( + root: &Path, + current: &Path, + current_metadata: &std::fs::Metadata, + matcher: &SearchMatcher, + context_lines: usize, + matches: &mut Vec, +) -> Result<(), MemoriesBackendError> { + if current_metadata.is_file() { + search_file(root, current, matcher, context_lines, matches).await?; + return Ok(()); + } + if !current_metadata.is_dir() { + return Ok(()); + } + + let mut pending = vec![current.to_path_buf()]; + while let Some(dir_path) = pending.pop() { + for path in read_sorted_dir_paths(&dir_path).await? { + if is_hidden_path(&path) { + continue; + } + let Some(metadata) = LocalMemoriesBackend::metadata_or_none(&path).await? else { + continue; + }; + if metadata.file_type().is_symlink() { + continue; + } + if metadata.is_dir() { + pending.push(path); + } else if metadata.is_file() { + search_file(root, &path, matcher, context_lines, matches).await?; + } + } + } + + Ok(()) +} + +async fn search_file( + root: &Path, + path: &Path, + matcher: &SearchMatcher, + context_lines: usize, + matches: &mut Vec, +) -> Result<(), MemoriesBackendError> { + let content = match tokio::fs::read_to_string(path).await { + Ok(content) => content, + Err(err) if err.kind() == std::io::ErrorKind::InvalidData => return Ok(()), + Err(err) => return Err(err.into()), + }; + let lines = content.lines().collect::>(); + let line_matches = lines + .iter() + .map(|line| matcher.matched_query_flags(line)) + .collect::>(); + match &matcher.match_mode { + SearchMatchMode::Any => { + for (idx, matched_query_flags) in line_matches.iter().enumerate() { + if matched_query_flags.iter().any(|matched| *matched) { + matches.push(build_search_match( + root, + path, + &lines, + idx, + idx, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + SearchMatchMode::AllOnSameLine => { + for (idx, matched_query_flags) in line_matches.iter().enumerate() { + if matched_query_flags.iter().all(|matched| *matched) { + matches.push(build_search_match( + root, + path, + &lines, + idx, + idx, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + SearchMatchMode::AllWithinLines { line_count } => { + let mut windows = Vec::new(); + for start_index in 0..lines.len() { + if !line_matches[start_index].iter().any(|matched| *matched) { + continue; + } + let last_allowed_index = start_index + .saturating_add(line_count.saturating_sub(1)) + .min(lines.len().saturating_sub(1)); + let mut matched_query_flags = vec![false; matcher.queries.len()]; + for (end_index, line_match_flags) in line_matches + .iter() + .enumerate() + .take(last_allowed_index + 1) + .skip(start_index) + { + for (idx, matched) in line_match_flags.iter().enumerate() { + matched_query_flags[idx] |= matched; + } + if matched_query_flags.iter().all(|matched| *matched) { + windows.push((start_index, end_index, matched_query_flags)); + break; + } + } + } + for (idx, (start_index, end_index, matched_query_flags)) in windows.iter().enumerate() { + let strictly_contains_another_window = windows.iter().enumerate().any( + |(other_idx, (other_start_index, other_end_index, _))| { + idx != other_idx + && start_index <= other_start_index + && end_index >= other_end_index + && (start_index != other_start_index || end_index != other_end_index) + }, + ); + if strictly_contains_another_window { + continue; + } + matches.push(build_search_match( + root, + path, + &lines, + *start_index, + *end_index, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + Ok(()) +} + +fn build_search_match( + root: &Path, + path: &Path, + lines: &[&str], + match_start_index: usize, + match_end_index: usize, + context_lines: usize, + matched_queries: Vec, +) -> MemorySearchMatch { + let content_start_index = match_start_index.saturating_sub(context_lines); + let content_end_index = match_end_index + .saturating_add(context_lines) + .saturating_add(1) + .min(lines.len()); + MemorySearchMatch { + path: display_relative_path(root, path), + match_line_number: match_start_index + 1, + content_start_line_number: content_start_index + 1, + content: lines[content_start_index..content_end_index].join("\n"), + matched_queries, + } +} + +struct SearchMatcher { + queries: Vec, + prepared_queries: Vec, + comparison: SearchComparison, + match_mode: SearchMatchMode, +} + +impl SearchMatcher { + fn new( + queries: Vec, + match_mode: SearchMatchMode, + case_sensitive: bool, + normalized: bool, + ) -> Result { + let comparison = SearchComparison::new(case_sensitive, normalized); + let prepared_queries = queries + .iter() + .map(|query| comparison.prepare(query)) + .map(Cow::into_owned) + .collect::>(); + if prepared_queries.iter().any(std::string::String::is_empty) { + return Err(MemoriesBackendError::EmptyQuery); + } + Ok(Self { + queries, + prepared_queries, + comparison, + match_mode, + }) + } + + fn matched_query_flags(&self, line: &str) -> Vec { + let line = self.comparison.prepare(line); + self.prepared_queries + .iter() + .map(|query| line.as_ref().contains(query)) + .collect() + } + + fn matched_queries(&self, matched_query_flags: &[bool]) -> Vec { + self.queries + .iter() + .zip(matched_query_flags) + .filter_map(|(query, matched)| matched.then_some(query.clone())) + .collect() + } +} + +#[derive(Clone, Copy)] +struct SearchComparison { + case_sensitive: bool, + normalized: bool, +} + +impl SearchComparison { + fn new(case_sensitive: bool, normalized: bool) -> Self { + Self { + case_sensitive, + normalized, + } + } + + fn prepare<'a>(self, value: &'a str) -> Cow<'a, str> { + if self.case_sensitive && !self.normalized { + return Cow::Borrowed(value); + } + + let value = if self.case_sensitive { + Cow::Borrowed(value) + } else { + Cow::Owned(value.to_lowercase()) + }; + if !self.normalized { + return value; + } + + Cow::Owned( + value + .chars() + .filter(|ch| ch.is_alphanumeric()) + .collect::(), + ) + } +} + +async fn read_sorted_dir_paths(dir_path: &Path) -> Result, MemoriesBackendError> { + let mut dir = match tokio::fs::read_dir(dir_path).await { + Ok(dir) => dir, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()), + Err(err) => return Err(err.into()), + }; + let mut paths = Vec::new(); + while let Some(entry) = dir.next_entry().await? { + paths.push(entry.path()); + } + paths.sort(); + Ok(paths) +} + +fn reject_symlink(path: &str, metadata: &std::fs::Metadata) -> Result<(), MemoriesBackendError> { + if metadata.file_type().is_symlink() { + return Err(MemoriesBackendError::invalid_path( + path, + "must not be a symlink", + )); + } + Ok(()) +} + +fn is_hidden_component(component: Component<'_>) -> bool { + matches!( + component, + Component::Normal(name) if name.to_string_lossy().starts_with('.') + ) +} + +fn is_hidden_path(path: &Path) -> bool { + path.file_name() + .is_some_and(|name| name.to_string_lossy().starts_with('.')) +} + +fn display_relative_path(root: &Path, path: &Path) -> String { + path.strip_prefix(root) + .unwrap_or(path) + .components() + .map(|component| component.as_os_str().to_string_lossy()) + .filter(|component| !component.is_empty()) + .collect::>() + .join("/") +} + +fn line_start_byte_offset( + content: &str, + line_offset: usize, +) -> Result { + if line_offset == 1 { + return Ok(0); + } + + let mut current_line = 1; + for (idx, ch) in content.char_indices() { + if ch == '\n' { + current_line += 1; + if current_line == line_offset { + return Ok(idx + 1); + } + } + } + + Err(MemoriesBackendError::LineOffsetExceedsFileLength) +} + +fn line_end_byte_offset(content: &str, start_byte: usize, max_lines: Option) -> usize { + let Some(max_lines) = max_lines else { + return content.len(); + }; + + let mut lines_seen = 1; + for (relative_idx, ch) in content[start_byte..].char_indices() { + if ch == '\n' { + if lines_seen == max_lines { + return start_byte + relative_idx + 1; + } + lines_seen += 1; + } + } + + content.len() +} diff --git a/codex-rs/ext/memories/src/schema.rs b/codex-rs/ext/memories/src/schema.rs new file mode 100644 index 000000000000..977cd5679a80 --- /dev/null +++ b/codex-rs/ext/memories/src/schema.rs @@ -0,0 +1,42 @@ +use schemars::JsonSchema; +use schemars::r#gen::SchemaSettings; +use serde_json::Map; +use serde_json::Value; + +pub(crate) fn input_schema_for() -> Value { + schema_for::(/*option_add_null_type*/ false) +} + +pub(crate) fn output_schema_for() -> Value { + schema_for::(/*option_add_null_type*/ true) +} + +fn schema_for(option_add_null_type: bool) -> Value { + let schema = SchemaSettings::draft2019_09() + .with(|settings| { + settings.inline_subschemas = true; + settings.option_add_null_type = option_add_null_type; + }) + .into_generator() + .into_root_schema_for::(); + let schema_value = serde_json::to_value(schema) + .unwrap_or_else(|err| panic!("generated tool schema should serialize: {err}")); + let Value::Object(mut schema_object) = schema_value else { + unreachable!("root tool schema must be an object"); + }; + + let mut tool_schema = Map::new(); + for key in [ + "properties", + "required", + "type", + "additionalProperties", + "$defs", + "definitions", + ] { + if let Some(value) = schema_object.remove(key) { + tool_schema.insert(key.to_string(), value); + } + } + Value::Object(tool_schema) +} diff --git a/codex-rs/ext/memories/src/tools/list.rs b/codex-rs/ext/memories/src/tools/list.rs new file mode 100644 index 000000000000..e9674f704b60 --- /dev/null +++ b/codex-rs/ext/memories/src/tools/list.rs @@ -0,0 +1,69 @@ +use codex_extension_api::ExtensionToolExecutor; +use codex_extension_api::ExtensionToolFuture; +use codex_extension_api::JsonToolOutput; +use codex_extension_api::ToolCall; +use codex_extension_api::ToolName; +use codex_extension_api::ToolSpec; +use schemars::JsonSchema; +use serde::Deserialize; +use serde_json::json; + +use crate::backend::DEFAULT_LIST_MAX_RESULTS; +use crate::backend::ListMemoriesRequest; +use crate::backend::ListMemoriesResponse; +use crate::backend::MAX_LIST_RESULTS; +use crate::backend::MemoriesBackend; +use crate::local::LocalMemoriesBackend; + +use super::LIST_TOOL_NAME; +use super::backend_error_to_function_call; +use super::clamp_max_results; +use super::function_tool; +use super::parse_args; + +#[derive(Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct ListArgs { + path: Option, + cursor: Option, + #[schemars(range(min = 1))] + max_results: Option, +} + +#[derive(Clone)] +pub(super) struct ListTool { + pub(super) backend: LocalMemoriesBackend, +} + +impl ExtensionToolExecutor for ListTool { + fn tool_name(&self) -> ToolName { + ToolName::plain(LIST_TOOL_NAME) + } + + fn spec(&self) -> Option { + Some(function_tool::( + LIST_TOOL_NAME, + "List immediate files and directories under a path in the Codex memories store.", + )) + } + + fn handle(&self, call: ToolCall) -> ExtensionToolFuture<'_> { + let backend = self.backend.clone(); + Box::pin(async move { + let args: ListArgs = parse_args(&call)?; + let response = backend + .list(ListMemoriesRequest { + path: args.path, + cursor: args.cursor, + max_results: clamp_max_results( + args.max_results, + DEFAULT_LIST_MAX_RESULTS, + MAX_LIST_RESULTS, + ), + }) + .await + .map_err(backend_error_to_function_call)?; + Ok(JsonToolOutput::new(json!(response))) + }) + } +} diff --git a/codex-rs/ext/memories/src/tools/mod.rs b/codex-rs/ext/memories/src/tools/mod.rs new file mode 100644 index 000000000000..735126349e5b --- /dev/null +++ b/codex-rs/ext/memories/src/tools/mod.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use codex_extension_api::ExtensionToolExecutor; +use codex_extension_api::FunctionCallError; +use codex_extension_api::ResponsesApiTool; +use codex_extension_api::ToolCall; +use codex_extension_api::ToolSpec; +use codex_extension_api::parse_tool_input_schema; +use schemars::JsonSchema; +use serde::Deserialize; +use serde_json::Value; + +use crate::backend::MemoriesBackendError; +use crate::local::LocalMemoriesBackend; +use crate::schema; + +mod list; +mod read; +mod search; + +const LIST_TOOL_NAME: &str = "memory_list"; +const READ_TOOL_NAME: &str = "memory_read"; +const SEARCH_TOOL_NAME: &str = "memory_search"; + +pub(crate) fn memory_tools(backend: LocalMemoriesBackend) -> Vec> { + vec![ + Arc::new(list::ListTool { + backend: backend.clone(), + }), + Arc::new(read::ReadTool { + backend: backend.clone(), + }), + Arc::new(search::SearchTool { backend }), + ] +} + +fn function_tool(name: &str, description: &str) -> ToolSpec { + ToolSpec::Function(ResponsesApiTool { + name: name.to_string(), + description: description.to_string(), + strict: false, + defer_loading: None, + parameters: parse_tool_input_schema(&schema::input_schema_for::()) + .unwrap_or_else(|err| panic!("generated input schema for {name} should parse: {err}")), + output_schema: Some(schema::output_schema_for::()), + }) +} + +fn parse_args Deserialize<'de>>(call: &ToolCall) -> Result { + let arguments = call.function_arguments()?; + let value = if arguments.trim().is_empty() { + Value::Object(serde_json::Map::new()) + } else { + serde_json::from_str(arguments) + .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))? + }; + serde_json::from_value(value).map_err(|err| FunctionCallError::RespondToModel(err.to_string())) +} + +fn clamp_max_results(requested: Option, default: usize, max: usize) -> usize { + requested.unwrap_or(default).clamp(1, max) +} + +fn backend_error_to_function_call(err: MemoriesBackendError) -> FunctionCallError { + match err { + MemoriesBackendError::InvalidPath { .. } + | MemoriesBackendError::InvalidCursor { .. } + | MemoriesBackendError::NotFound { .. } + | MemoriesBackendError::InvalidLineOffset + | MemoriesBackendError::InvalidMaxLines + | MemoriesBackendError::LineOffsetExceedsFileLength + | MemoriesBackendError::NotFile { .. } + | MemoriesBackendError::EmptyQuery + | MemoriesBackendError::InvalidMatchWindow => { + FunctionCallError::RespondToModel(err.to_string()) + } + MemoriesBackendError::Io(_) => FunctionCallError::Fatal(err.to_string()), + } +} diff --git a/codex-rs/ext/memories/src/tools/read.rs b/codex-rs/ext/memories/src/tools/read.rs new file mode 100644 index 000000000000..96129ec0f915 --- /dev/null +++ b/codex-rs/ext/memories/src/tools/read.rs @@ -0,0 +1,120 @@ +use codex_extension_api::ExtensionToolExecutor; +use codex_extension_api::ExtensionToolFuture; +use codex_extension_api::JsonToolOutput; +use codex_extension_api::ToolCall; +use codex_extension_api::ToolName; +use codex_extension_api::ToolSpec; +use schemars::JsonSchema; +use serde::Deserialize; +use serde_json::json; + +use crate::backend::DEFAULT_READ_MAX_TOKENS; +use crate::backend::MemoriesBackend; +use crate::backend::ReadMemoryRequest; +use crate::backend::ReadMemoryResponse; +use crate::local::LocalMemoriesBackend; + +use super::READ_TOOL_NAME; +use super::backend_error_to_function_call; +use super::function_tool; +use super::parse_args; + +#[derive(Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct ReadArgs { + path: String, + #[schemars(range(min = 1))] + line_offset: Option, + #[schemars(range(min = 1))] + max_lines: Option, +} + +#[derive(Clone)] +pub(super) struct ReadTool { + pub(super) backend: LocalMemoriesBackend, +} + +impl ExtensionToolExecutor for ReadTool { + fn tool_name(&self) -> ToolName { + ToolName::plain(READ_TOOL_NAME) + } + + fn spec(&self) -> Option { + Some(function_tool::( + READ_TOOL_NAME, + "Read a Codex memory file by relative path, optionally starting at a 1-indexed line offset and limiting the number of lines returned.", + )) + } + + fn handle(&self, call: ToolCall) -> ExtensionToolFuture<'_> { + let backend = self.backend.clone(); + Box::pin(async move { + let args: ReadArgs = parse_args(&call)?; + let response = backend + .read(ReadMemoryRequest { + path: args.path, + line_offset: args.line_offset.unwrap_or(1), + max_lines: args.max_lines, + max_tokens: DEFAULT_READ_MAX_TOKENS, + }) + .await + .map_err(backend_error_to_function_call)?; + Ok(JsonToolOutput::new(json!(response))) + }) + } +} + +#[cfg(test)] +mod tests { + use codex_extension_api::ToolPayload; + use codex_tools::ToolOutput; + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + #[tokio::test] + async fn read_tool_reads_memory_file() { + let tempdir = tempfile::tempdir().expect("tempdir"); + let memory_root = tempdir.path().join("memories"); + tokio::fs::create_dir_all(&memory_root) + .await + .expect("create memories dir"); + tokio::fs::write( + memory_root.join("MEMORY.md"), + "first line\nsecond needle line\nthird line\n", + ) + .await + .expect("write memory"); + let tool = ReadTool { + backend: LocalMemoriesBackend::from_memory_root(&memory_root), + }; + let payload = ToolPayload::Function { + arguments: json!({ + "path": "MEMORY.md", + "line_offset": 2, + "max_lines": 1 + }) + .to_string(), + }; + + let output = tool + .handle(ToolCall { + call_id: "call-1".to_string(), + tool_name: ToolName::plain(READ_TOOL_NAME), + payload: payload.clone(), + }) + .await + .expect("read should succeed"); + + assert_eq!( + output.post_tool_use_response("call-1", &payload), + Some(json!({ + "path": "MEMORY.md", + "content": "second needle line\n", + "start_line_number": 2, + "truncated": true + })) + ); + } +} diff --git a/codex-rs/ext/memories/src/tools/search.rs b/codex-rs/ext/memories/src/tools/search.rs new file mode 100644 index 000000000000..f437df0040a9 --- /dev/null +++ b/codex-rs/ext/memories/src/tools/search.rs @@ -0,0 +1,160 @@ +use codex_extension_api::ExtensionToolExecutor; +use codex_extension_api::ExtensionToolFuture; +use codex_extension_api::JsonToolOutput; +use codex_extension_api::ToolCall; +use codex_extension_api::ToolName; +use codex_extension_api::ToolSpec; +use schemars::JsonSchema; +use serde::Deserialize; +use serde_json::json; + +use crate::backend::DEFAULT_SEARCH_MAX_RESULTS; +use crate::backend::MAX_SEARCH_RESULTS; +use crate::backend::MemoriesBackend; +use crate::backend::SearchMatchMode; +use crate::backend::SearchMemoriesRequest; +use crate::backend::SearchMemoriesResponse; +use crate::local::LocalMemoriesBackend; + +use super::SEARCH_TOOL_NAME; +use super::backend_error_to_function_call; +use super::clamp_max_results; +use super::function_tool; +use super::parse_args; + +#[derive(Debug, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct SearchArgs { + #[schemars(length(min = 1))] + queries: Vec, + match_mode: Option, + path: Option, + cursor: Option, + #[schemars(range(min = 0))] + context_lines: Option, + case_sensitive: Option, + normalized: Option, + #[schemars(range(min = 1))] + max_results: Option, +} + +#[derive(Clone)] +pub(super) struct SearchTool { + pub(super) backend: LocalMemoriesBackend, +} + +impl ExtensionToolExecutor for SearchTool { + fn tool_name(&self) -> ToolName { + ToolName::plain(SEARCH_TOOL_NAME) + } + + fn spec(&self) -> Option { + Some(function_tool::( + SEARCH_TOOL_NAME, + "Search Codex memory files for substring matches, optionally normalizing separators or requiring all query substrings on the same line or within a line window.", + )) + } + + fn handle(&self, call: ToolCall) -> ExtensionToolFuture<'_> { + let backend = self.backend.clone(); + Box::pin(async move { + let args: SearchArgs = parse_args(&call)?; + let response = backend + .search(args.into_request()) + .await + .map_err(backend_error_to_function_call)?; + Ok(JsonToolOutput::new(json!(response))) + }) + } +} + +impl SearchArgs { + fn into_request(self) -> SearchMemoriesRequest { + SearchMemoriesRequest { + queries: self.queries, + match_mode: self.match_mode.unwrap_or(SearchMatchMode::Any), + path: self.path, + cursor: self.cursor, + context_lines: self.context_lines.unwrap_or(0), + case_sensitive: self.case_sensitive.unwrap_or(true), + normalized: self.normalized.unwrap_or(false), + max_results: clamp_max_results( + self.max_results, + DEFAULT_SEARCH_MAX_RESULTS, + MAX_SEARCH_RESULTS, + ), + } + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + #[test] + fn search_args_accept_multiple_queries() { + let args: SearchArgs = serde_json::from_value(json!({ + "queries": ["alpha", "needle"], + "case_sensitive": false + })) + .expect("multi-query args should parse"); + + let request = args.into_request(); + + assert_eq!( + request, + SearchMemoriesRequest { + queries: vec!["alpha".to_string(), "needle".to_string()], + match_mode: SearchMatchMode::Any, + path: None, + cursor: None, + context_lines: 0, + case_sensitive: false, + normalized: false, + max_results: DEFAULT_SEARCH_MAX_RESULTS, + } + ); + } + + #[test] + fn search_args_accept_windowed_all_match_mode() { + let args: SearchArgs = serde_json::from_value(json!({ + "queries": ["alpha", "needle"], + "match_mode": { + "type": "all_within_lines", + "line_count": 3 + } + })) + .expect("windowed all args should parse"); + + let request = args.into_request(); + + assert_eq!( + request, + SearchMemoriesRequest { + queries: vec!["alpha".to_string(), "needle".to_string()], + match_mode: SearchMatchMode::AllWithinLines { line_count: 3 }, + path: None, + cursor: None, + context_lines: 0, + case_sensitive: true, + normalized: false, + max_results: DEFAULT_SEARCH_MAX_RESULTS, + } + ); + } + + #[test] + fn search_args_reject_legacy_single_query() { + let err = serde_json::from_value::(json!({ + "query": "needle", + })) + .expect_err("legacy query field should be rejected"); + + assert!(err.to_string().contains("unknown field")); + assert!(err.to_string().contains("query")); + } +}