Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 107 additions & 7 deletions rig/rig-core/src/embeddings/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,48 @@ where
texts.push((i, doc_texts));
}

// Flatten the texts while keeping track of the document index.
let mut flat_texts = Vec::new();
for (i, doc_texts) in texts.into_iter() {
for text in doc_texts {
flat_texts.push((i, text));
}
}

let max_documents = M::MAX_DOCUMENTS;
let max_tokens = self.model.max_tokens_per_request().unwrap_or(usize::MAX);

// Group them into batches.
let mut batches = Vec::new();
let mut current_batch = Vec::new();
let mut current_tokens = 0;

for (i, text) in flat_texts {
// Simple KISS estimate: bytes = tokens (upper bound)
let text_tokens = text.len();

// Check if adding this text would exceed the limit
if !current_batch.is_empty()
&& (current_batch.len() >= max_documents
|| current_tokens + text_tokens > max_tokens)
{
batches.push(current_batch);
current_batch = Vec::new();
current_tokens = 0;
}

current_tokens += text_tokens;
current_batch.push((i, text));
}
if !current_batch.is_empty() {
batches.push(current_batch);
}

// Compute the embeddings.
let mut embeddings = stream::iter(texts.into_iter())
// Merge the texts of each document into a single list of texts.
.flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
// Chunk them into batches. Each batch size is at most the embedding API limit per request.
.chunks(M::MAX_DOCUMENTS)
let mut embeddings = stream::iter(batches.into_iter())
// Generate the embeddings for each batch.
.map(|text| async {
let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
.map(|batch| async {
let (ids, docs): (Vec<_>, Vec<_>) = batch.into_iter().unzip();

let embeddings = self.model.embed_texts(docs).await?;
Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
Expand Down Expand Up @@ -407,4 +440,71 @@ mod tests {
second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
)
}

#[derive(Clone)]
struct LimitModel;

impl EmbeddingModel for LimitModel {
const MAX_DOCUMENTS: usize = 100;

type Client = Nothing;

fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
Self
}

fn max_tokens_per_request(&self) -> Option<usize> {
Some(10)
}

fn ndims(&self) -> usize {
10
}

async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
let docs: Vec<String> = documents.into_iter().collect();
let total_len: usize = docs.iter().map(|s| s.len()).sum();
if total_len > 10 {
return Err(crate::embeddings::EmbeddingError::ProviderError(
"Too many tokens".to_string(),
));
}
Ok(docs
.iter()
.map(|d| Embedding {
document: d.clone(),
vec: vec![0.0; 10],
})
.collect())
}
}

#[tokio::test]
async fn test_build_respects_token_limit() {
let docs = vec![
WordDefinitionSingle {
id: "1".into(),
definition: "hello".into(),
},
WordDefinitionSingle {
id: "2".into(),
definition: "world!".into(),
},
];

let model = LimitModel;
// This should pass if batching splits "hello" and "world!"
let result = EmbeddingsBuilder::new(model)
.documents(docs)
.unwrap()
.build()
.await;

assert!(result.is_ok(), "Build failed: {:?}", result.err());
let result = result.unwrap();
assert_eq!(result.len(), 2);
}
}
6 changes: 6 additions & 0 deletions rig/rig-core/src/embeddings/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
/// The maximum number of documents that can be embedded in a single request.
const MAX_DOCUMENTS: usize;

/// The maximum number of tokens that can be embedded in a single request.
/// If None, the limit is assumed to be infinite (or unknown).
fn max_tokens_per_request(&self) -> Option<usize> {
None
}

type Client;

fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
Expand Down
4 changes: 4 additions & 0 deletions rig/rig-core/src/providers/openai/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ where
{
const MAX_DOCUMENTS: usize = 1024;

fn max_tokens_per_request(&self) -> Option<usize> {
Some(300_000)
}

type Client = Client<T>;

fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
Expand Down