diff --git a/rig/rig-core/src/embeddings/builder.rs b/rig/rig-core/src/embeddings/builder.rs index 63d739f98..f2d251bf6 100644 --- a/rig/rig-core/src/embeddings/builder.rs +++ b/rig/rig-core/src/embeddings/builder.rs @@ -111,15 +111,48 @@ where texts.push((i, doc_texts)); } + // Flatten the texts while keeping track of the document index. + let mut flat_texts = Vec::new(); + for (i, doc_texts) in texts.into_iter() { + for text in doc_texts { + flat_texts.push((i, text)); + } + } + + let max_documents = M::MAX_DOCUMENTS; + let max_tokens = self.model.max_tokens_per_request().unwrap_or(usize::MAX); + + // Group them into batches. + let mut batches = Vec::new(); + let mut current_batch = Vec::new(); + let mut current_tokens = 0; + + for (i, text) in flat_texts { + // Simple KISS estimate: bytes = tokens (upper bound) + let text_tokens = text.len(); + + // Check if adding this text would exceed the limit + if !current_batch.is_empty() + && (current_batch.len() >= max_documents + || current_tokens + text_tokens > max_tokens) + { + batches.push(current_batch); + current_batch = Vec::new(); + current_tokens = 0; + } + + current_tokens += text_tokens; + current_batch.push((i, text)); + } + if !current_batch.is_empty() { + batches.push(current_batch); + } + // Compute the embeddings. - let mut embeddings = stream::iter(texts.into_iter()) - // Merge the texts of each document into a single list of texts. - .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text)))) - // Chunk them into batches. Each batch size is at most the embedding API limit per request. - .chunks(M::MAX_DOCUMENTS) + let mut embeddings = stream::iter(batches.into_iter()) // Generate the embeddings for each batch. - .map(|text| async { - let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip(); + .map(|batch| async { + let (ids, docs): (Vec<_>, Vec<_>) = batch.into_iter().unzip(); let embeddings = self.model.embed_texts(docs).await?; Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::>()) @@ -407,4 +440,71 @@ mod tests { second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ) } + + #[derive(Clone)] + struct LimitModel; + + impl EmbeddingModel for LimitModel { + const MAX_DOCUMENTS: usize = 100; + + type Client = Nothing; + + fn make(_: &Self::Client, _: impl Into, _: Option) -> Self { + Self + } + + fn max_tokens_per_request(&self) -> Option { + Some(10) + } + + fn ndims(&self) -> usize { + 10 + } + + async fn embed_texts( + &self, + documents: impl IntoIterator + Send, + ) -> Result, crate::embeddings::EmbeddingError> { + let docs: Vec = documents.into_iter().collect(); + let total_len: usize = docs.iter().map(|s| s.len()).sum(); + if total_len > 10 { + return Err(crate::embeddings::EmbeddingError::ProviderError( + "Too many tokens".to_string(), + )); + } + Ok(docs + .iter() + .map(|d| Embedding { + document: d.clone(), + vec: vec![0.0; 10], + }) + .collect()) + } + } + + #[tokio::test] + async fn test_build_respects_token_limit() { + let docs = vec![ + WordDefinitionSingle { + id: "1".into(), + definition: "hello".into(), + }, + WordDefinitionSingle { + id: "2".into(), + definition: "world!".into(), + }, + ]; + + let model = LimitModel; + // This should pass if batching splits "hello" and "world!" + let result = EmbeddingsBuilder::new(model) + .documents(docs) + .unwrap() + .build() + .await; + + assert!(result.is_ok(), "Build failed: {:?}", result.err()); + let result = result.unwrap(); + assert_eq!(result.len(), 2); + } } diff --git a/rig/rig-core/src/embeddings/embedding.rs b/rig/rig-core/src/embeddings/embedding.rs index cfa1257f1..a656e6afc 100644 --- a/rig/rig-core/src/embeddings/embedding.rs +++ b/rig/rig-core/src/embeddings/embedding.rs @@ -47,6 +47,12 @@ pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync { /// The maximum number of documents that can be embedded in a single request. const MAX_DOCUMENTS: usize; + /// The maximum number of tokens that can be embedded in a single request. + /// If None, the limit is assumed to be infinite (or unknown). + fn max_tokens_per_request(&self) -> Option { + None + } + type Client; fn make(client: &Self::Client, model: impl Into, dims: Option) -> Self; diff --git a/rig/rig-core/src/providers/openai/embedding.rs b/rig/rig-core/src/providers/openai/embedding.rs index 54f923708..c6e4448bc 100644 --- a/rig/rig-core/src/providers/openai/embedding.rs +++ b/rig/rig-core/src/providers/openai/embedding.rs @@ -70,6 +70,10 @@ where { const MAX_DOCUMENTS: usize = 1024; + fn max_tokens_per_request(&self) -> Option { + Some(300_000) + } + type Client = Client; fn make(client: &Self::Client, model: impl Into, ndims: Option) -> Self {