Skip to content

Commit a3c6200

Browse files
committed
fix: Address Cursor Bugbot review feedback
- Fix multiple embedding types getting wrong indices by tracking used_batch_indices per embedding type instead of a shared set
- Fix fallback parser to use batch_texts when the API doesn't return texts
- Remove unused variables (current_path, in_embeddings) and dead code
- Remove unused stream_embed_response convenience function
1 parent d8bb1e7 commit a3c6200

3 files changed

Lines changed: 27 additions & 34 deletions

File tree

src/cohere/base_client.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,20 +1221,28 @@ def embed_stream(
12211221

12221222
# Parse embeddings from response incrementally
12231223
parser = StreamingEmbedParser(response._response, batch_texts)
1224-
# Track used indices to handle duplicate texts correctly
1225-
used_batch_indices = set()
1224+
# Track used indices per embedding type to handle:
1225+
# 1. Duplicate texts within a batch
1226+
# 2. Multiple embedding types (float, int8, etc.) for the same texts
1227+
used_batch_indices_by_type: dict[str, set[int]] = {}
12261228

12271229
for embedding in parser.iter_embeddings():
12281230
# The parser sets embedding.text correctly for multiple embedding types
12291231
# Adjust the global index based on text position in batch
12301232
if embedding.text and embedding.text in batch_texts:
1233+
# Get or create the set of used indices for this embedding type
1234+
emb_type = embedding.embedding_type
1235+
if emb_type not in used_batch_indices_by_type:
1236+
used_batch_indices_by_type[emb_type] = set()
1237+
used_indices = used_batch_indices_by_type[emb_type]
1238+
12311239
# Find the next unused occurrence of this text in the batch
12321240
# This handles duplicate texts correctly
12331241
text_idx_in_batch = None
12341242
for idx, text in enumerate(batch_texts):
1235-
if text == embedding.text and idx not in used_batch_indices:
1243+
if text == embedding.text and idx not in used_indices:
12361244
text_idx_in_batch = idx
1237-
used_batch_indices.add(idx)
1245+
used_indices.add(idx)
12381246
break
12391247

12401248
if text_idx_in_batch is not None:

src/cohere/streaming_utils.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,13 @@ def iter_embeddings(self) -> Iterator[StreamedEmbedding]:
8080

8181
def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
8282
"""Parse embeddings using ijson incremental parser."""
83-
current_path: List[str] = []
8483
current_embedding = []
8584
# Track text index separately per embedding type
8685
# When multiple types requested, each text gets multiple embeddings
8786
type_text_indices: dict = {}
88-
embedding_type = "float"
8987
response_type = None
90-
in_embeddings = False
9188

9289
for prefix, event, value in parser:
93-
# Track current path
94-
if event == 'map_key':
95-
if current_path and current_path[-1] == 'embeddings':
96-
# This is an embedding type key (float_, int8, etc.)
97-
embedding_type = value.rstrip('_')
98-
9990
# Detect response type
10091
if prefix == 'response_type':
10192
response_type = value
@@ -170,10 +161,11 @@ def _iter_embeddings_fallback(self) -> Iterator[StreamedEmbedding]:
170161
def _iter_embeddings_fallback_from_dict(self, data: dict) -> Iterator[StreamedEmbedding]:
171162
"""Parse embeddings from a dictionary (used by fallback methods)."""
172163
response_type = data.get('response_type', '')
164+
# Use batch_texts from constructor as fallback if API doesn't return texts
165+
texts = data.get('texts') or self.batch_texts
173166

174167
if response_type == 'embeddings_floats':
175168
embeddings = data.get('embeddings', [])
176-
texts = data.get('texts', [])
177169
for i, embedding in enumerate(embeddings):
178170
yield StreamedEmbedding(
179171
index=self.embeddings_yielded + i,
@@ -184,7 +176,6 @@ def _iter_embeddings_fallback_from_dict(self, data: dict) -> Iterator[StreamedEm
184176

185177
elif response_type == 'embeddings_by_type':
186178
embeddings_obj = data.get('embeddings', {})
187-
texts = data.get('texts', [])
188179

189180
# Iterate through each embedding type
190181
for emb_type, embeddings_list in embeddings_obj.items():
@@ -198,18 +189,3 @@ def _iter_embeddings_fallback_from_dict(self, data: dict) -> Iterator[StreamedEm
198189
text=texts[i] if i < len(texts) else None
199190
)
200191
self.embeddings_yielded += 1
201-
202-
203-
def stream_embed_response(response: httpx.Response, texts: List[str]) -> Iterator[StreamedEmbedding]:
204-
"""
205-
Convenience function to stream embeddings from a response.
206-
207-
Args:
208-
response: The httpx response containing embeddings
209-
texts: The original texts that were embedded
210-
211-
Yields:
212-
StreamedEmbedding objects
213-
"""
214-
parser = StreamingEmbedParser(response, texts)
215-
yield from parser.iter_embeddings()

src/cohere/v2/client.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,21 +602,30 @@ def embed_stream(
602602

603603
# Parse embeddings from response incrementally
604604
parser = StreamingEmbedParser(response._response, batch_texts)
605-
# Track used indices to handle duplicate texts correctly
606-
used_batch_indices: set[int] = set()
605+
# Track used indices per embedding type to handle:
606+
# 1. Duplicate texts within a batch
607+
# 2. Multiple embedding types (float, int8, etc.) for the same texts
608+
used_batch_indices_by_type: dict[str, set[int]] = {}
607609

608610
for embedding in parser.iter_embeddings():
609611
# The parser sets embedding.text correctly for multiple embedding types
610612
# Adjust the global index based on text position in batch
611613
if embedding.text and embedding.text in batch_texts:
614+
# Get or create the set of used indices for this embedding type
615+
emb_type = embedding.embedding_type
616+
if emb_type not in used_batch_indices_by_type:
617+
used_batch_indices_by_type[emb_type] = set()
618+
used_indices = used_batch_indices_by_type[emb_type]
619+
612620
# Find the next unused occurrence of this text in the batch
613621
# This handles duplicate texts correctly
614622
text_idx_in_batch = None
615623
for idx, text in enumerate(batch_texts):
616-
if text == embedding.text and idx not in used_batch_indices:
624+
if text == embedding.text and idx not in used_indices:
617625
text_idx_in_batch = idx
618-
used_batch_indices.add(idx)
626+
used_indices.add(idx)
619627
break
628+
620629
if text_idx_in_batch is not None:
621630
embedding.index = batch_start + text_idx_in_batch
622631
yield embedding

0 commit comments

Comments (0)