Skip to content

Commit 686e975

Browse files
authored
Merge pull request #1318 from quole/develop
Add KeyedVectors support to AnnoyIndexer
2 parents 76d9861 + dbfea83 commit 686e975

4 files changed

Lines changed: 275 additions & 86 deletions

File tree

docs/notebooks/annoytutorial.ipynb

Lines changed: 241 additions & 76 deletions
Large diffs are not rendered by default.

gensim/models/keyedvectors.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,10 @@ def __init__(self):
118118
self.index2word = []
119119
self.vector_size = None
120120

121+
@property
122+
def wv(self):
123+
return self
124+
121125
def save(self, *args, **kwargs):
122126
# don't bother storing the cached normalized vectors
123127
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm'])

gensim/similarities/index.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313

1414
from gensim.models.doc2vec import Doc2Vec
1515
from gensim.models.word2vec import Word2Vec
16+
from gensim.models.keyedvectors import KeyedVectors
1617
try:
1718
from annoy import AnnoyIndex
1819
except ImportError:
@@ -32,8 +33,10 @@ def __init__(self, model=None, num_trees=None):
3233
self.build_from_doc2vec()
3334
elif isinstance(self.model, Word2Vec):
3435
self.build_from_word2vec()
36+
elif isinstance(self.model, KeyedVectors):
37+
self.build_from_keyedvectors()
3538
else:
36-
raise ValueError("Only a Word2Vec or Doc2Vec instance can be used")
39+
raise ValueError("Only a Word2Vec, Doc2Vec or KeyedVectors instance can be used")
3740

3841
def save(self, fname, protocol=2):
3942
fname_dict = fname + '.d'
@@ -70,6 +73,12 @@ def build_from_doc2vec(self):
7073
labels = [docvecs.index_to_doctag(i) for i in range(0, docvecs.count)]
7174
return self._build_from_model(docvecs.doctag_syn0norm, labels, self.model.vector_size)
7275

76+
def build_from_keyedvectors(self):
77+
"""Build an Annoy index using word vectors from a KeyedVectors model"""
78+
79+
self.model.init_sims()
80+
return self._build_from_model(self.model.syn0norm, self.model.index2word, self.model.vector_size)
81+
7382
def _build_from_model(self, vectors, labels, num_features):
7483
index = AnnoyIndex(num_features)
7584

gensim/test/test_similarities.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from gensim.corpora import mmcorpus, Dictionary
2121
from gensim.models import word2vec
2222
from gensim.models import doc2vec
23+
from gensim.models import KeyedVectors
2324
from gensim.models.wrappers import fasttext
2425
from gensim import matutils, utils, similarities
2526
from gensim.models import Word2Vec
@@ -458,8 +459,8 @@ def testWord2Vec(self):
458459
model.init_sims()
459460
index = self.indexer(model, 10)
460461

461-
self.assertVectorIsSimilarToItself(model, index)
462-
self.assertApproxNeighborsMatchExact(model, index)
462+
self.assertVectorIsSimilarToItself(model.wv, index)
463+
self.assertApproxNeighborsMatchExact(model, model.wv, index)
463464
self.assertIndexSaved(index)
464465
self.assertLoadedIndexEqual(index, model)
465466

@@ -473,28 +474,38 @@ def testFastText(self):
473474
model.init_sims()
474475
index = self.indexer(model, 10)
475476

476-
self.assertVectorIsSimilarToItself(model, index)
477-
self.assertApproxNeighborsMatchExact(model, index)
477+
self.assertVectorIsSimilarToItself(model.wv, index)
478+
self.assertApproxNeighborsMatchExact(model, model.wv, index)
478479
self.assertIndexSaved(index)
479480
self.assertLoadedIndexEqual(index, model)
480481

482+
def testAnnoyIndexingOfKeyedVectors(self):
483+
from gensim.similarities.index import AnnoyIndexer
484+
keyVectors_file = datapath('lee_fasttext.vec')
485+
model = KeyedVectors.load_word2vec_format(keyVectors_file)
486+
index = AnnoyIndexer(model, 10)
487+
488+
self.assertEqual(index.num_trees, 10)
489+
self.assertVectorIsSimilarToItself(model, index)
490+
self.assertApproxNeighborsMatchExact(model, model, index)
491+
481492
def testLoadMissingRaisesError(self):
482493
from gensim.similarities.index import AnnoyIndexer
483494
test_index = AnnoyIndexer()
484495

485496
self.assertRaises(IOError, test_index.load, fname='test-index')
486497

487-
def assertVectorIsSimilarToItself(self, model, index):
488-
vector = model.wv.syn0norm[0]
489-
label = model.wv.index2word[0]
498+
def assertVectorIsSimilarToItself(self, wv, index):
499+
vector = wv.syn0norm[0]
500+
label = wv.index2word[0]
490501
approx_neighbors = index.most_similar(vector, 1)
491502
word, similarity = approx_neighbors[0]
492503

493504
self.assertEqual(word, label)
494505
self.assertEqual(similarity, 1.0)
495506

496-
def assertApproxNeighborsMatchExact(self, model, index):
497-
vector = model.wv.syn0norm[0]
507+
def assertApproxNeighborsMatchExact(self, model, wv, index):
508+
vector = wv.syn0norm[0]
498509
approx_neighbors = model.most_similar([vector], topn=5, indexer=index)
499510
exact_neighbors = model.most_similar(positive=[vector], topn=5)
500511

0 commit comments

Comments
 (0)