3939from gensim .models .base_any2vec import BaseWordEmbeddingsModel
4040from gensim .models .utils_any2vec import _compute_ngrams , _ft_hash
4141
42- from six import iteritems
4342from gensim .utils import deprecated , call_on_class_only
4443from gensim import utils
4544
@@ -93,10 +92,12 @@ def train_batch_cbow(model, sentences, alpha, work=None, neu1=None):
9392
9493 for index in word2_indices :
9594 vocab_subwords_indices += [index ]
96- word2_subwords += model .wv .ngrams_word [model .wv .index2word [index ]]
95+ word2_subwords += _compute_ngrams (model .wv .index2word [index ],
96+ model .min_n , model .max_n )
9797
9898 for subword in word2_subwords :
99- ngrams_subwords_indices .append (model .wv .ngrams [subword ])
99+ ngrams_subwords_indices .append (
100+ model .wv .hash2index [_ft_hash (subword ) % model .bucket ])
100101
101102 l1_vocab = np_sum (model .wv .syn0_vocab [vocab_subwords_indices ], axis = 0 ) # 1 x vector_size
102103 l1_ngrams = np_sum (model .wv .syn0_ngrams [ngrams_subwords_indices ], axis = 0 ) # 1 x vector_size
@@ -144,10 +145,11 @@ def train_batch_sg(model, sentences, alpha, work=None, neu1=None):
144145 start = max (0 , pos - model .window + reduced_window )
145146
146147 subwords_indices = [word .index ]
147- word2_subwords = model .wv .ngrams_word [model .wv .index2word [word .index ]]
148+ word2_subwords = _compute_ngrams (model .wv .index2word [word .index ],
149+ model .min_n , model .max_n )
148150
149151 for subword in word2_subwords :
150- subwords_indices .append (model .wv .ngrams [ subword ])
152+ subwords_indices .append (model .wv .hash2index [ _ft_hash ( subword ) % model . bucket ])
151153
152154 for pos2 , word2 in enumerate (word_vocabs [start :(pos + model .window + 1 - reduced_window )], start ):
153155 if pos2 != pos : # don't train on the `word` itself
@@ -278,6 +280,7 @@ def __init__(self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5,
278280 sorted_vocab = bool (sorted_vocab ), null_word = null_word )
279281 self .trainables = FastTextTrainables (
280282 vector_size = size , seed = seed , bucket = bucket , hashfxn = hashfxn )
283+ self .wv .bucket = self .bucket
281284
282285 super (FastText , self ).__init__ (
283286 sentences = sentences , workers = workers , vector_size = size , epochs = iter , callbacks = callbacks ,
@@ -396,6 +399,36 @@ def _clear_post_train(self):
396399 self .wv .vectors_vocab_norm = None
397400 self .wv .vectors_ngrams_norm = None
398401
def estimate_memory(self, vocab_size=None, report=None):
    """Estimate the memory required by the model, in bytes.

    Parameters
    ----------
    vocab_size : int, optional
        Number of words to size the per-word structures for. Falls back to
        ``len(self.wv.vocab)`` when omitted (or falsy).
    report : dict, optional
        Existing report to extend. Entries written here overwrite entries
        with the same key; any other entries are kept and counted in the
        total. A new dict is created when omitted.

    Returns
    -------
    dict
        Maps a component name (``'vocab'``, ``'syn0_vocab'``, ``'syn1'``,
        ``'syn1neg'``, ``'syn0_ngrams'``) to its estimated size in bytes,
        plus ``'total'`` holding the sum of all components.

    Notes
    -----
    The ngram estimate is always derived from the words currently held in
    ``self.wv.vocab``: `vocab_size` only scales the per-word structures,
    because the ngram buckets cannot be known without the actual words.

    """
    vocab_size = vocab_size or len(self.wv.vocab)
    float_size = np.dtype(np.float32).itemsize
    vec_size = self.vector_size * float_size
    l1_size = self.layer1_size * float_size
    report = {} if report is None else report

    # Per-word structures scale with the requested vocabulary size, not with
    # whatever vocabulary happens to be loaded right now.
    report['vocab'] = vocab_size * (700 if self.hs else 500)
    report['syn0_vocab'] = vocab_size * vec_size
    if self.hs:
        report['syn1'] = vocab_size * l1_size
    if self.negative:
        report['syn1neg'] = vocab_size * l1_size

    # Without a vocabulary the best available figure is the configured
    # bucket count; it is only used for the log message below.
    num_buckets = self.bucket
    if self.word_ngrams > 0 and self.wv.vocab:
        # Only buckets that some ngram actually hashes into get a vector.
        buckets = set()
        for word in self.wv.vocab:
            ngrams = _compute_ngrams(word, self.min_n, self.max_n)
            buckets.update(_ft_hash(ng) % self.bucket for ng in ngrams)
        num_buckets = len(buckets)
        report['syn0_ngrams'] = num_buckets * vec_size
    elif self.word_ngrams > 0:
        logger.warning('subword information is enabled, but no vocabulary '
                       'could be found, estimated required memory might be '
                       'inaccurate!')

    # Skip a 'total' left over from an earlier call so it is not counted twice.
    report['total'] = sum(size for key, size in report.items() if key != 'total')
    logger.info(
        "estimated required memory for %i words, %i buckets and %i dimensions: %i bytes",
        vocab_size, num_buckets, self.vector_size, report['total']
    )
    return report
431+
399432 def _do_train_job (self , sentences , alpha , inits ):
400433 """Train a single batch of sentences. Return 2-tuple `(effective word count after
401434 ignoring unknown words and sentence length trimming, total word count)`.
@@ -580,6 +613,7 @@ def _load_model_params(self, file_handle):
580613 self .hs = loss == 1
581614 self .sg = model == 2
582615 self .trainables .bucket = bucket
616+ self .wv .bucket = bucket
583617 self .wv .min_n = minn
584618 self .wv .max_n = maxn
585619 self .vocabulary .sample = t
@@ -709,18 +743,8 @@ def prepare_vocab(self, hs, negative, wv, update=False, keep_raw_vocab=False, tr
709743 report_values = super (FastTextVocab , self ).prepare_vocab (
710744 hs , negative , wv , update = update , keep_raw_vocab = keep_raw_vocab , trim_rule = trim_rule ,
711745 min_count = min_count , sample = sample , dry_run = dry_run )
712- self .build_ngrams (wv , update = update )
713746 return report_values
714747
715- def build_ngrams (self , wv , update = False ):
716- if not update :
717- wv .ngrams_word = {}
718- for w , v in iteritems (wv .vocab ):
719- wv .ngrams_word [w ] = _compute_ngrams (w , wv .min_n , wv .max_n )
720- else :
721- for w , v in iteritems (wv .vocab ):
722- wv .ngrams_word [w ] = _compute_ngrams (w , wv .min_n , wv .max_n )
723-
724748
725749class FastTextTrainables (Word2VecTrainables ):
726750 def __init__ (self , vector_size = 100 , seed = 1 , hashfxn = hash , bucket = 2000000 ):
@@ -744,54 +768,43 @@ def init_ngrams_weights(self, wv, update=False, vocabulary=None):
744768
745769 """
746770 if not update :
747- wv .ngrams = {}
748771 wv .vectors_vocab = empty ((len (wv .vocab ), wv .vector_size ), dtype = REAL )
749772 self .vectors_vocab_lockf = ones ((len (wv .vocab ), wv .vector_size ), dtype = REAL )
750773
751774 wv .vectors_ngrams = empty ((self .bucket , wv .vector_size ), dtype = REAL )
752775 self .vectors_ngrams_lockf = ones ((self .bucket , wv .vector_size ), dtype = REAL )
753776
754- all_ngrams = []
755- for w , ngrams in iteritems (wv .ngrams_word ):
756- all_ngrams += ngrams
757-
758- all_ngrams = list (set (all_ngrams ))
759- wv .num_ngram_vectors = len (all_ngrams )
760- logger .info ("Total number of ngrams is %d" , len (all_ngrams ))
761-
762777 wv .hash2index = {}
763778 ngram_indices = []
764779 new_hash_count = 0
765- for i , ngram in enumerate (all_ngrams ):
766- ngram_hash = _ft_hash (ngram ) % self .bucket
767- if ngram_hash in wv .hash2index :
768- wv .ngrams [ngram ] = wv .hash2index [ngram_hash ]
769- else :
770- ngram_indices .append (ngram_hash % self .bucket )
771- wv .hash2index [ngram_hash ] = new_hash_count
772- wv .ngrams [ngram ] = wv .hash2index [ngram_hash ]
773- new_hash_count = new_hash_count + 1
780+ wv .num_ngram_vectors = 0
781+ for word in wv .vocab .keys ():
782+ for ngram in _compute_ngrams (word , wv .min_n , wv .max_n ):
783+ ngram_hash = _ft_hash (ngram ) % self .bucket
784+ if ngram_hash not in wv .hash2index :
785+ wv .num_ngram_vectors += 1
786+ ngram_indices .append (ngram_hash )
787+ wv .hash2index [ngram_hash ] = new_hash_count
788+ new_hash_count = new_hash_count + 1
789+
790+ logger .info ("Total number of ngrams is %d" , wv .num_ngram_vectors )
774791
775792 wv .vectors_ngrams = wv .vectors_ngrams .take (ngram_indices , axis = 0 )
776793 self .vectors_ngrams_lockf = self .vectors_ngrams_lockf .take (ngram_indices , axis = 0 )
777794 self .reset_ngrams_weights (wv )
778795 else :
779- new_ngrams = []
780- for w , ngrams in iteritems (wv .ngrams_word ):
781- new_ngrams += [ng for ng in ngrams if ng not in wv .ngrams ]
782-
783- new_ngrams = list (set (new_ngrams ))
784- wv .num_ngram_vectors += len (new_ngrams )
785- logger .info ("Number of new ngrams is %d" , len (new_ngrams ))
786796 new_hash_count = 0
787- for i , ngram in enumerate (new_ngrams ):
788- ngram_hash = _ft_hash (ngram ) % self .bucket
789- if ngram_hash not in wv .hash2index :
790- wv .hash2index [ngram_hash ] = new_hash_count + self .old_hash2index_len
791- wv .ngrams [ngram ] = wv .hash2index [ngram_hash ]
792- new_hash_count = new_hash_count + 1
793- else :
794- wv .ngrams [ngram ] = wv .hash2index [ngram_hash ]
797+ num_new_ngrams = 0
798+ for word in wv .vocab .keys ():
799+ for ngram in _compute_ngrams (word , wv .min_n , wv .max_n ):
800+ ngram_hash = _ft_hash (ngram ) % self .bucket
801+ if ngram_hash not in wv .hash2index :
802+ wv .hash2index [ngram_hash ] = new_hash_count + self .old_hash2index_len
803+ new_hash_count = new_hash_count + 1
804+ num_new_ngrams += 1
805+
806+ wv .num_ngram_vectors += num_new_ngrams
807+ logger .info ("Number of new ngrams is %d" , num_new_ngrams )
795808
796809 rand_obj = np .random
797810 rand_obj .seed (self .seed )
@@ -833,10 +846,11 @@ def get_vocab_word_vecs(self, wv):
833846 """Calculate vectors for words in vocabulary and stores them in `vectors`."""
834847 for w , v in wv .vocab .items ():
835848 word_vec = np .copy (wv .vectors_vocab [v .index ])
836- ngrams = wv .ngrams_word [ w ]
849+ ngrams = _compute_ngrams ( w , wv .min_n , wv . max_n )
837850 ngram_weights = wv .vectors_ngrams
838851 for ngram in ngrams :
839- word_vec += ngram_weights [wv .ngrams [ngram ]]
852+ word_vec += ngram_weights [
853+ wv .hash2index [_ft_hash (ngram ) % self .bucket ]]
840854 word_vec /= (len (ngrams ) + 1 )
841855 wv .vectors [v .index ] = word_vec
842856
@@ -847,20 +861,21 @@ def init_ngrams_post_load(self, file_name, wv):
847861 vectors are discarded here to save space.
848862
849863 """
850- all_ngrams = []
851864 wv .vectors = np .zeros ((len (wv .vocab ), wv .vector_size ), dtype = REAL )
852865
853866 for w , vocab in wv .vocab .items ():
854- all_ngrams += _compute_ngrams (w , wv .min_n , wv .max_n )
855867 wv .vectors [vocab .index ] += np .array (wv .vectors_ngrams [vocab .index ])
856868
857- all_ngrams = set (all_ngrams )
858- wv .num_ngram_vectors = len (all_ngrams )
859869 ngram_indices = []
860- for i , ngram in enumerate (all_ngrams ):
861- ngram_hash = _ft_hash (ngram )
862- ngram_indices .append (len (wv .vocab ) + ngram_hash % self .bucket )
863- wv .ngrams [ngram ] = i
870+ wv .num_ngram_vectors = 0
871+ for word in wv .vocab .keys ():
872+ for ngram in _compute_ngrams (word , wv .min_n , wv .max_n ):
873+ ngram_hash = _ft_hash (ngram ) % self .bucket
874+ if ngram_hash in wv .hash2index :
875+ continue
876+ ngram_indices .append (len (wv .vocab ) + ngram_hash )
877+ wv .hash2index [ngram_hash ] = wv .num_ngram_vectors
878+ wv .num_ngram_vectors += 1
864879 wv .vectors_ngrams = wv .vectors_ngrams .take (ngram_indices , axis = 0 )
865880
866881 ngram_weights = wv .vectors_ngrams
@@ -873,7 +888,8 @@ def init_ngrams_post_load(self, file_name, wv):
873888 for w , vocab in wv .vocab .items ():
874889 word_ngrams = _compute_ngrams (w , wv .min_n , wv .max_n )
875890 for word_ngram in word_ngrams :
876- wv .vectors [vocab .index ] += np .array (ngram_weights [wv .ngrams [word_ngram ]])
891+ vec_idx = wv .hash2index [_ft_hash (word_ngram ) % self .bucket ]
892+ wv .vectors [vocab .index ] += np .array (ngram_weights [vec_idx ])
877893
878894 wv .vectors [vocab .index ] /= (len (word_ngrams ) + 1 )
879895 logger .info (
0 commit comments