Skip to content

Commit 51a1a6e

Browse files
committed
Reduce fasttext memory usage by computing ngrams on the fly
1 parent 9f3428a commit 51a1a6e

11 files changed

Lines changed: 1843 additions & 1379 deletions

gensim/models/deprecated/fasttext_wrapper.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -70,9 +70,7 @@ def __init__(self):
7070
self.syn0_vocab_norm = None
7171
self.syn0_ngrams = None
7272
self.syn0_ngrams_norm = None
73-
self.ngrams = {}
7473
self.hash2index = {}
75-
self.ngrams_word = {}
7674
self.min_n = 0
7775
self.max_n = 0
7876

@@ -99,17 +97,18 @@ def word_vec(self, word, use_norm=False):
9997
return super(FastTextKeyedVectors, self).word_vec(word, use_norm)
10098
else:
10199
word_vec = np.zeros(self.syn0_ngrams.shape[1], dtype=np.float32)
102-
ngrams = compute_ngrams(word, self.min_n, self.max_n)
103-
ngrams = [ng for ng in ngrams if ng in self.ngrams]
100+
hashes = [ft_hash(ng) % self.bucket
101+
for ng in compute_ngrams(word, self.min_n, self.max_n)]
102+
hashes = [h for h in hashes if h in self.hash2index]
104103
if use_norm:
105104
ngram_weights = self.syn0_ngrams_norm
106105
else:
107106
ngram_weights = self.syn0_ngrams
108-
for ngram in ngrams:
109-
word_vec += ngram_weights[self.ngrams[ngram]]
107+
for ngram_hash in hashes:
108+
word_vec += ngram_weights[self.hash2index[ngram_hash]]
110109
if word_vec.any():
111-
return word_vec / len(ngrams)
112-
else: # No ngrams of the word are present in self.ngrams
110+
return word_vec / len(hashes)
111+
else: # No hashes of any ngrams of the word are present in self.hash2index
113112
raise KeyError('all ngrams for word %s absent from model' % word)
114113

115114
def init_sims(self, replace=False):
@@ -143,7 +142,8 @@ def __contains__(self, word):
143142
return True
144143
else:
145144
char_ngrams = compute_ngrams(word, self.min_n, self.max_n)
146-
return any(ng in self.ngrams for ng in char_ngrams)
145+
return any(ft_hash(ng) % self.bucket in self.hash2index
146+
for ng in char_ngrams)
147147

148148
@classmethod
149149
def load_word2vec_format(cls, *args, **kwargs):
@@ -277,6 +277,12 @@ def load(cls, *args, **kwargs):
277277
if hasattr(model.wv, 'syn0_all'):
278278
setattr(model.wv, 'syn0_ngrams', model.wv.syn0_all)
279279
delattr(model.wv, 'syn0_all')
280+
setattr(model.wv, 'bucket', model.wv.syn0_ngrams.shape[0])
281+
if not hasattr(model.wv, 'hash2index') and hasattr(model.wv, 'ngrams'):
282+
model.wv.hash2index = {}
283+
for i, ngram in enumerate(model.wv.ngrams):
284+
ngram_hash = ft_hash(ngram) % model.wv.bucket
285+
model.wv.hash2index[ngram_hash] = i
280286
return model
281287

282288
@classmethod
@@ -316,6 +322,7 @@ def load_model_params(self, file_handle):
316322
self.hs = loss == 1
317323
self.sg = model == 2
318324
self.bucket = bucket
325+
self.wv.bucket = bucket
319326
self.wv.min_n = minn
320327
self.wv.max_n = maxn
321328
self.sample = t
@@ -394,7 +401,6 @@ def init_ngrams(self):
394401
vectors are discarded here to save space.
395402
396403
"""
397-
self.wv.ngrams = {}
398404
all_ngrams = []
399405
self.wv.syn0 = np.zeros((len(self.wv.vocab), self.vector_size), dtype=REAL)
400406

@@ -406,9 +412,9 @@ def init_ngrams(self):
406412
self.num_ngram_vectors = len(all_ngrams)
407413
ngram_indices = []
408414
for i, ngram in enumerate(all_ngrams):
409-
ngram_hash = ft_hash(ngram)
415+
ngram_hash = ft_hash(ngram) % self.bucket
410416
ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
411-
self.wv.ngrams[ngram] = i
417+
self.wv.hash2index[ngram_hash] = i
412418
self.wv.syn0_ngrams = self.wv.syn0_ngrams.take(ngram_indices, axis=0)
413419

414420
ngram_weights = self.wv.syn0_ngrams
@@ -421,7 +427,9 @@ def init_ngrams(self):
421427
for w, vocab in self.wv.vocab.items():
422428
word_ngrams = compute_ngrams(w, self.wv.min_n, self.wv.max_n)
423429
for word_ngram in word_ngrams:
424-
self.wv.syn0[vocab.index] += np.array(ngram_weights[self.wv.ngrams[word_ngram]])
430+
ng_hash = ft_hash(word_ngram) % self.bucket
431+
self.wv.syn0[vocab.index] += np.array(
432+
ngram_weights[self.wv.hash2index[ng_hash]])
425433

426434
self.wv.syn0[vocab.index] /= (len(word_ngrams) + 1)
427435
logger.info(

gensim/models/fasttext.py

Lines changed: 75 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from gensim.models.base_any2vec import BaseWordEmbeddingsModel
4040
from gensim.models.utils_any2vec import _compute_ngrams, _ft_hash
4141

42-
from six import iteritems
4342
from gensim.utils import deprecated, call_on_class_only
4443
from gensim import utils
4544

@@ -93,10 +92,12 @@ def train_batch_cbow(model, sentences, alpha, work=None, neu1=None):
9392

9493
for index in word2_indices:
9594
vocab_subwords_indices += [index]
96-
word2_subwords += model.wv.ngrams_word[model.wv.index2word[index]]
95+
word2_subwords += _compute_ngrams(model.wv.index2word[index],
96+
model.min_n, model.max_n)
9797

9898
for subword in word2_subwords:
99-
ngrams_subwords_indices.append(model.wv.ngrams[subword])
99+
ngrams_subwords_indices.append(
100+
model.wv.hash2index[_ft_hash(subword) % model.bucket])
100101

101102
l1_vocab = np_sum(model.wv.syn0_vocab[vocab_subwords_indices], axis=0) # 1 x vector_size
102103
l1_ngrams = np_sum(model.wv.syn0_ngrams[ngrams_subwords_indices], axis=0) # 1 x vector_size
@@ -144,10 +145,11 @@ def train_batch_sg(model, sentences, alpha, work=None, neu1=None):
144145
start = max(0, pos - model.window + reduced_window)
145146

146147
subwords_indices = [word.index]
147-
word2_subwords = model.wv.ngrams_word[model.wv.index2word[word.index]]
148+
word2_subwords = _compute_ngrams(model.wv.index2word[word.index],
149+
model.min_n, model.max_n)
148150

149151
for subword in word2_subwords:
150-
subwords_indices.append(model.wv.ngrams[subword])
152+
subwords_indices.append(model.wv.hash2index[_ft_hash(subword) % model.bucket])
151153

152154
for pos2, word2 in enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start):
153155
if pos2 != pos: # don't train on the `word` itself
@@ -278,6 +280,7 @@ def __init__(self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5,
278280
sorted_vocab=bool(sorted_vocab), null_word=null_word)
279281
self.trainables = FastTextTrainables(
280282
vector_size=size, seed=seed, bucket=bucket, hashfxn=hashfxn)
283+
self.wv.bucket = self.bucket
281284

282285
super(FastText, self).__init__(
283286
sentences=sentences, workers=workers, vector_size=size, epochs=iter, callbacks=callbacks,
@@ -396,6 +399,36 @@ def _clear_post_train(self):
396399
self.wv.vectors_vocab_norm = None
397400
self.wv.vectors_ngrams_norm = None
398401

402+
def estimate_memory(self, vocab_size=None, report=None):
403+
vocab_size = vocab_size or len(self.wv.vocab)
404+
vec_size = self.vector_size * np.dtype(np.float32).itemsize
405+
l1_size = self.layer1_size * np.dtype(np.float32).itemsize
406+
report = report or {}
407+
report['vocab'] = len(self.wv.vocab) * (700 if self.hs else 500)
408+
report['syn0_vocab'] = len(self.wv.vocab) * vec_size
409+
num_buckets = self.bucket
410+
if self.hs:
411+
report['syn1'] = len(self.wv.vocab) * l1_size
412+
if self.negative:
413+
report['syn1neg'] = len(self.wv.vocab) * l1_size
414+
if self.word_ngrams > 0 and self.wv.vocab:
415+
buckets = set()
416+
for word in self.wv.vocab:
417+
ngrams = _compute_ngrams(word, self.min_n, self.max_n)
418+
buckets.update(_ft_hash(ng) % self.bucket for ng in ngrams)
419+
num_buckets = len(buckets)
420+
report['syn0_ngrams'] = len(buckets) * vec_size
421+
elif self.word_ngrams > 0:
422+
logger.warn('subword information is enabled, but no vocabulary '
423+
'could be found, estimated required memory might be '
424+
'inaccurate!')
425+
report['total'] = sum(report.values())
426+
logger.info(
427+
"estimated required memory for %i words, %i buckets and %i dimensions: %i bytes",
428+
len(self.wv.vocab), num_buckets, self.vector_size, report['total']
429+
)
430+
return report
431+
399432
def _do_train_job(self, sentences, alpha, inits):
400433
"""Train a single batch of sentences. Return 2-tuple `(effective word count after
401434
ignoring unknown words and sentence length trimming, total word count)`.
@@ -580,6 +613,7 @@ def _load_model_params(self, file_handle):
580613
self.hs = loss == 1
581614
self.sg = model == 2
582615
self.trainables.bucket = bucket
616+
self.wv.bucket = bucket
583617
self.wv.min_n = minn
584618
self.wv.max_n = maxn
585619
self.vocabulary.sample = t
@@ -709,18 +743,8 @@ def prepare_vocab(self, hs, negative, wv, update=False, keep_raw_vocab=False, tr
709743
report_values = super(FastTextVocab, self).prepare_vocab(
710744
hs, negative, wv, update=update, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule,
711745
min_count=min_count, sample=sample, dry_run=dry_run)
712-
self.build_ngrams(wv, update=update)
713746
return report_values
714747

715-
def build_ngrams(self, wv, update=False):
716-
if not update:
717-
wv.ngrams_word = {}
718-
for w, v in iteritems(wv.vocab):
719-
wv.ngrams_word[w] = _compute_ngrams(w, wv.min_n, wv.max_n)
720-
else:
721-
for w, v in iteritems(wv.vocab):
722-
wv.ngrams_word[w] = _compute_ngrams(w, wv.min_n, wv.max_n)
723-
724748

725749
class FastTextTrainables(Word2VecTrainables):
726750
def __init__(self, vector_size=100, seed=1, hashfxn=hash, bucket=2000000):
@@ -744,54 +768,43 @@ def init_ngrams_weights(self, wv, update=False, vocabulary=None):
744768
745769
"""
746770
if not update:
747-
wv.ngrams = {}
748771
wv.vectors_vocab = empty((len(wv.vocab), wv.vector_size), dtype=REAL)
749772
self.vectors_vocab_lockf = ones((len(wv.vocab), wv.vector_size), dtype=REAL)
750773

751774
wv.vectors_ngrams = empty((self.bucket, wv.vector_size), dtype=REAL)
752775
self.vectors_ngrams_lockf = ones((self.bucket, wv.vector_size), dtype=REAL)
753776

754-
all_ngrams = []
755-
for w, ngrams in iteritems(wv.ngrams_word):
756-
all_ngrams += ngrams
757-
758-
all_ngrams = list(set(all_ngrams))
759-
wv.num_ngram_vectors = len(all_ngrams)
760-
logger.info("Total number of ngrams is %d", len(all_ngrams))
761-
762777
wv.hash2index = {}
763778
ngram_indices = []
764779
new_hash_count = 0
765-
for i, ngram in enumerate(all_ngrams):
766-
ngram_hash = _ft_hash(ngram) % self.bucket
767-
if ngram_hash in wv.hash2index:
768-
wv.ngrams[ngram] = wv.hash2index[ngram_hash]
769-
else:
770-
ngram_indices.append(ngram_hash % self.bucket)
771-
wv.hash2index[ngram_hash] = new_hash_count
772-
wv.ngrams[ngram] = wv.hash2index[ngram_hash]
773-
new_hash_count = new_hash_count + 1
780+
wv.num_ngram_vectors = 0
781+
for word in wv.vocab.keys():
782+
for ngram in _compute_ngrams(word, wv.min_n, wv.max_n):
783+
ngram_hash = _ft_hash(ngram) % self.bucket
784+
if ngram_hash not in wv.hash2index:
785+
wv.num_ngram_vectors += 1
786+
ngram_indices.append(ngram_hash)
787+
wv.hash2index[ngram_hash] = new_hash_count
788+
new_hash_count = new_hash_count + 1
789+
790+
logger.info("Total number of ngrams is %d", wv.num_ngram_vectors)
774791

775792
wv.vectors_ngrams = wv.vectors_ngrams.take(ngram_indices, axis=0)
776793
self.vectors_ngrams_lockf = self.vectors_ngrams_lockf.take(ngram_indices, axis=0)
777794
self.reset_ngrams_weights(wv)
778795
else:
779-
new_ngrams = []
780-
for w, ngrams in iteritems(wv.ngrams_word):
781-
new_ngrams += [ng for ng in ngrams if ng not in wv.ngrams]
782-
783-
new_ngrams = list(set(new_ngrams))
784-
wv.num_ngram_vectors += len(new_ngrams)
785-
logger.info("Number of new ngrams is %d", len(new_ngrams))
786796
new_hash_count = 0
787-
for i, ngram in enumerate(new_ngrams):
788-
ngram_hash = _ft_hash(ngram) % self.bucket
789-
if ngram_hash not in wv.hash2index:
790-
wv.hash2index[ngram_hash] = new_hash_count + self.old_hash2index_len
791-
wv.ngrams[ngram] = wv.hash2index[ngram_hash]
792-
new_hash_count = new_hash_count + 1
793-
else:
794-
wv.ngrams[ngram] = wv.hash2index[ngram_hash]
797+
num_new_ngrams = 0
798+
for word in wv.vocab.keys():
799+
for ngram in _compute_ngrams(word, wv.min_n, wv.max_n):
800+
ngram_hash = _ft_hash(ngram) % self.bucket
801+
if ngram_hash not in wv.hash2index:
802+
wv.hash2index[ngram_hash] = new_hash_count + self.old_hash2index_len
803+
new_hash_count = new_hash_count + 1
804+
num_new_ngrams += 1
805+
806+
wv.num_ngram_vectors += num_new_ngrams
807+
logger.info("Number of new ngrams is %d", num_new_ngrams)
795808

796809
rand_obj = np.random
797810
rand_obj.seed(self.seed)
@@ -833,10 +846,11 @@ def get_vocab_word_vecs(self, wv):
833846
"""Calculate vectors for words in vocabulary and stores them in `vectors`."""
834847
for w, v in wv.vocab.items():
835848
word_vec = np.copy(wv.vectors_vocab[v.index])
836-
ngrams = wv.ngrams_word[w]
849+
ngrams = _compute_ngrams(w, wv.min_n, wv.max_n)
837850
ngram_weights = wv.vectors_ngrams
838851
for ngram in ngrams:
839-
word_vec += ngram_weights[wv.ngrams[ngram]]
852+
word_vec += ngram_weights[
853+
wv.hash2index[_ft_hash(ngram) % self.bucket]]
840854
word_vec /= (len(ngrams) + 1)
841855
wv.vectors[v.index] = word_vec
842856

@@ -847,20 +861,21 @@ def init_ngrams_post_load(self, file_name, wv):
847861
vectors are discarded here to save space.
848862
849863
"""
850-
all_ngrams = []
851864
wv.vectors = np.zeros((len(wv.vocab), wv.vector_size), dtype=REAL)
852865

853866
for w, vocab in wv.vocab.items():
854-
all_ngrams += _compute_ngrams(w, wv.min_n, wv.max_n)
855867
wv.vectors[vocab.index] += np.array(wv.vectors_ngrams[vocab.index])
856868

857-
all_ngrams = set(all_ngrams)
858-
wv.num_ngram_vectors = len(all_ngrams)
859869
ngram_indices = []
860-
for i, ngram in enumerate(all_ngrams):
861-
ngram_hash = _ft_hash(ngram)
862-
ngram_indices.append(len(wv.vocab) + ngram_hash % self.bucket)
863-
wv.ngrams[ngram] = i
870+
wv.num_ngram_vectors = 0
871+
for word in wv.vocab.keys():
872+
for ngram in _compute_ngrams(word, wv.min_n, wv.max_n):
873+
ngram_hash = _ft_hash(ngram) % self.bucket
874+
if ngram_hash in wv.hash2index:
875+
continue
876+
ngram_indices.append(len(wv.vocab) + ngram_hash)
877+
wv.hash2index[ngram_hash] = wv.num_ngram_vectors
878+
wv.num_ngram_vectors += 1
864879
wv.vectors_ngrams = wv.vectors_ngrams.take(ngram_indices, axis=0)
865880

866881
ngram_weights = wv.vectors_ngrams
@@ -873,7 +888,8 @@ def init_ngrams_post_load(self, file_name, wv):
873888
for w, vocab in wv.vocab.items():
874889
word_ngrams = _compute_ngrams(w, wv.min_n, wv.max_n)
875890
for word_ngram in word_ngrams:
876-
wv.vectors[vocab.index] += np.array(ngram_weights[wv.ngrams[word_ngram]])
891+
vec_idx = wv.hash2index[_ft_hash(word_ngram) % self.bucket]
892+
wv.vectors[vocab.index] += np.array(ngram_weights[vec_idx])
877893

878894
wv.vectors[vocab.index] /= (len(word_ngrams) + 1)
879895
logger.info(

0 commit comments

Comments
 (0)