Skip to content

Commit a1d539f

Browse files
chinmayapancholi13fabriciorsf
authored andcommitted
Add Sklearn API for Gensim models (piskvorky#1462)
* created sklearn wrapper for Doc2Vec * PEP8 fix * added 'transform' function and refactored code * updated d2v skl api code * added unittests for sklearn api for d2v model * fixed flake8 errors * added skl api class for Text2Bow model * updated docstring for d2vmodel api * updated text2bow skl api code * added unittests for text2bow skl api class * updated 'testPipeline' and 'testTransform' for text2bow * added 'tokenizer' param to text2bow skl api * updated unittests for text2bow * removed get_params and set_params functions from existing classes * added tfidf api class * added unittests for tfidf api class * flake8 fixes * added skl api for hdpmodel * added unittests for hdp model api class * flake8 fixes * updated hdp api class * added 'testPartialFit' and 'testPipeline' tests for hdp api class * flake8 fixes * added skl API class for phrases * added unit tests for phrases API class * flake8 fixes * added 'testPartialFit' function for 'TestPhrasesTransformer' * updated 'testPipeline' function for 'TestText2BowTransformer' * updated code for transform function for HDP transformer * updated tests as discussed in PR 1473 * added examples for new models in ipynb * unpinned sklearn version for running unit-tests * updated 'Pipeline' initialization format * updated 'Pipeline' initialization format in ipynb
1 parent b4bd541 commit a1d539f

9 files changed

Lines changed: 1217 additions & 45 deletions

File tree

docs/notebooks/sklearn_api.ipynb

Lines changed: 447 additions & 36 deletions
Large diffs are not rendered by default.

gensim/sklearn_api/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,8 @@
1717
from .ldaseqmodel import LdaSeqTransformer # noqa: F401
1818
from .w2vmodel import W2VTransformer # noqa: F401
1919
from .atmodel import AuthorTopicTransformer # noqa: F401
20+
from .d2vmodel import D2VTransformer # noqa: F401
21+
from .text2bow import Text2BowTransformer # noqa: F401
22+
from .tfidf import TfIdfTransformer # noqa: F401
23+
from .hdp import HdpTransformer # noqa: F401
24+
from .phrases import PhrasesTransformer # noqa: F401

gensim/sklearn_api/d2vmodel.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
5+
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
6+
7+
"""
8+
Scikit learn interface for gensim for easy use of gensim with scikit-learn
9+
Follows scikit-learn API conventions
10+
"""
11+
12+
import numpy as np
13+
from six import string_types
14+
from sklearn.base import TransformerMixin, BaseEstimator
15+
from sklearn.exceptions import NotFittedError
16+
17+
from gensim import models
18+
19+
20+
class D2VTransformer(TransformerMixin, BaseEstimator):
21+
"""
22+
Base Doc2Vec module
23+
"""
24+
25+
def __init__(self, dm_mean=None, dm=1, dbow_words=0, dm_concat=0,
26+
dm_tag_count=1, docvecs=None, docvecs_mapfile=None,
27+
comment=None, trim_rule=None, size=100, alpha=0.025,
28+
window=5, min_count=5, max_vocab_size=None, sample=1e-3,
29+
seed=1, workers=3, min_alpha=0.0001, hs=0, negative=5,
30+
cbow_mean=1, hashfxn=hash, iter=5, sorted_vocab=1,
31+
batch_words=10000):
32+
"""
33+
Sklearn api for Doc2Vec model. See gensim.models.Doc2Vec and gensim.models.Word2Vec for parameter details.
34+
"""
35+
self.gensim_model = None
36+
self.dm_mean = dm_mean
37+
self.dm = dm
38+
self.dbow_words = dbow_words
39+
self.dm_concat = dm_concat
40+
self.dm_tag_count = dm_tag_count
41+
self.docvecs = docvecs
42+
self.docvecs_mapfile = docvecs_mapfile
43+
self.comment = comment
44+
self.trim_rule = trim_rule
45+
46+
# attributes associated with gensim.models.Word2Vec
47+
self.size = size
48+
self.alpha = alpha
49+
self.window = window
50+
self.min_count = min_count
51+
self.max_vocab_size = max_vocab_size
52+
self.sample = sample
53+
self.seed = seed
54+
self.workers = workers
55+
self.min_alpha = min_alpha
56+
self.hs = hs
57+
self.negative = negative
58+
self.cbow_mean = int(cbow_mean)
59+
self.hashfxn = hashfxn
60+
self.iter = iter
61+
self.sorted_vocab = sorted_vocab
62+
self.batch_words = batch_words
63+
64+
def fit(self, X, y=None):
65+
"""
66+
Fit the model according to the given training data.
67+
Calls gensim.models.Doc2Vec
68+
"""
69+
self.gensim_model = models.Doc2Vec(documents=X, dm_mean=self.dm_mean, dm=self.dm,
70+
dbow_words=self.dbow_words, dm_concat=self.dm_concat, dm_tag_count=self.dm_tag_count,
71+
docvecs=self.docvecs, docvecs_mapfile=self.docvecs_mapfile, comment=self.comment,
72+
trim_rule=self.trim_rule, size=self.size, alpha=self.alpha, window=self.window,
73+
min_count=self.min_count, max_vocab_size=self.max_vocab_size, sample=self.sample,
74+
seed=self.seed, workers=self.workers, min_alpha=self.min_alpha, hs=self.hs,
75+
negative=self.negative, cbow_mean=self.cbow_mean, hashfxn=self.hashfxn,
76+
iter=self.iter, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words)
77+
return self
78+
79+
def transform(self, docs):
80+
"""
81+
Return the vector representations for the input documents.
82+
The input `docs` should be a list of lists like : [ ['calculus', 'mathematical'], ['geometry', 'operations', 'curves'] ]
83+
or a single document like : ['calculus', 'mathematical']
84+
"""
85+
if self.gensim_model is None:
86+
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
87+
88+
# The input as array of array
89+
check = lambda x: [x] if isinstance(x[0], string_types) else x
90+
docs = check(docs)
91+
X = [[] for _ in range(0, len(docs))]
92+
93+
for k, v in enumerate(docs):
94+
doc_vec = self.gensim_model.infer_vector(v)
95+
X[k] = doc_vec
96+
97+
return np.reshape(np.array(X), (len(docs), self.gensim_model.vector_size))

gensim/sklearn_api/hdp.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
5+
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
6+
7+
"""
8+
Scikit learn interface for gensim for easy use of gensim with scikit-learn
9+
Follows scikit-learn API conventions
10+
"""
11+
12+
import numpy as np
13+
from scipy import sparse
14+
from sklearn.base import TransformerMixin, BaseEstimator
15+
from sklearn.exceptions import NotFittedError
16+
17+
from gensim import models
18+
from gensim import matutils
19+
20+
21+
class HdpTransformer(TransformerMixin, BaseEstimator):
22+
"""
23+
Base HDP module
24+
"""
25+
26+
def __init__(self, id2word, max_chunks=None, max_time=None,
27+
chunksize=256, kappa=1.0, tau=64.0, K=15, T=150, alpha=1,
28+
gamma=1, eta=0.01, scale=1.0, var_converge=0.0001,
29+
outputdir=None, random_state=None):
30+
"""
31+
Sklearn api for HDP model. See gensim.models.HdpModel for parameter details.
32+
"""
33+
self.gensim_model = None
34+
self.id2word = id2word
35+
self.max_chunks = max_chunks
36+
self.max_time = max_time
37+
self.chunksize = chunksize
38+
self.kappa = kappa
39+
self.tau = tau
40+
self.K = K
41+
self.T = T
42+
self.alpha = alpha
43+
self.gamma = gamma
44+
self.eta = eta
45+
self.scale = scale
46+
self.var_converge = var_converge
47+
self.outputdir = outputdir
48+
self.random_state = random_state
49+
50+
def fit(self, X, y=None):
51+
"""
52+
Fit the model according to the given training data.
53+
Calls gensim.models.HdpModel
54+
"""
55+
if sparse.issparse(X):
56+
corpus = matutils.Sparse2Corpus(X)
57+
else:
58+
corpus = X
59+
60+
self.gensim_model = models.HdpModel(corpus=corpus, id2word=self.id2word, max_chunks=self.max_chunks,
61+
max_time=self.max_time, chunksize=self.chunksize, kappa=self.kappa, tau=self.tau,
62+
K=self.K, T=self.T, alpha=self.alpha, gamma=self.gamma, eta=self.eta, scale=self.scale,
63+
var_converge=self.var_converge, outputdir=self.outputdir, random_state=self.random_state)
64+
return self
65+
66+
def transform(self, docs):
67+
"""
68+
Takes a list of documents as input ('docs').
69+
Returns a matrix of topic distribution for the given document bow, where a_ij
70+
indicates (topic_i, topic_probability_j).
71+
The input `docs` should be in BOW format and can be a list of documents like : [ [(4, 1), (7, 1)], [(9, 1), (13, 1)], [(2, 1), (6, 1)] ]
72+
or a single document like : [(4, 1), (7, 1)]
73+
"""
74+
if self.gensim_model is None:
75+
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
76+
77+
# The input as array of array
78+
check = lambda x: [x] if isinstance(x[0], tuple) else x
79+
docs = check(docs)
80+
X = [[] for _ in range(0, len(docs))]
81+
82+
max_num_topics = 0
83+
for k, v in enumerate(docs):
84+
X[k] = self.gensim_model[v]
85+
max_num_topics = max(max_num_topics, max(list(map(lambda x: x[0], X[k]))) + 1)
86+
87+
for k, v in enumerate(X):
88+
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
89+
dense_vec = matutils.sparse2full(v, max_num_topics)
90+
X[k] = dense_vec
91+
92+
return np.reshape(np.array(X), (len(docs), max_num_topics))
93+
94+
def partial_fit(self, X):
95+
"""
96+
Train model over X.
97+
"""
98+
if sparse.issparse(X):
99+
X = matutils.Sparse2Corpus(X)
100+
101+
if self.gensim_model is None:
102+
self.gensim_model = models.HdpModel(id2word=self.id2word, max_chunks=self.max_chunks,
103+
max_time=self.max_time, chunksize=self.chunksize, kappa=self.kappa, tau=self.tau,
104+
K=self.K, T=self.T, alpha=self.alpha, gamma=self.gamma, eta=self.eta, scale=self.scale,
105+
var_converge=self.var_converge, outputdir=self.outputdir, random_state=self.random_state)
106+
107+
self.gensim_model.update(corpus=X)
108+
return self

gensim/sklearn_api/phrases.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
5+
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
6+
7+
"""
8+
Scikit learn interface for gensim for easy use of gensim with scikit-learn
9+
Follows scikit-learn API conventions
10+
"""
11+
12+
from six import string_types
13+
from sklearn.base import TransformerMixin, BaseEstimator
14+
from sklearn.exceptions import NotFittedError
15+
16+
from gensim import models
17+
18+
19+
class PhrasesTransformer(TransformerMixin, BaseEstimator):
20+
"""
21+
Base Phrases module
22+
"""
23+
24+
def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
25+
delimiter=b'_', progress_per=10000):
26+
"""
27+
Sklearn wrapper for Phrases model.
28+
"""
29+
self.gensim_model = None
30+
self.min_count = min_count
31+
self.threshold = threshold
32+
self.max_vocab_size = max_vocab_size
33+
self.delimiter = delimiter
34+
self.progress_per = progress_per
35+
36+
def fit(self, X, y=None):
37+
"""
38+
Fit the model according to the given training data.
39+
"""
40+
self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold,
41+
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per)
42+
return self
43+
44+
def transform(self, docs):
45+
"""
46+
Return the input documents to return phrase tokens.
47+
"""
48+
if self.gensim_model is None:
49+
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
50+
51+
# input as python lists
52+
check = lambda x: [x] if isinstance(x[0], string_types) else x
53+
docs = check(docs)
54+
X = [[] for _ in range(0, len(docs))]
55+
56+
for k, v in enumerate(docs):
57+
phrase_tokens = self.gensim_model[v]
58+
X[k] = phrase_tokens
59+
60+
return X
61+
62+
def partial_fit(self, X):
63+
if self.gensim_model is None:
64+
self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold,
65+
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per)
66+
67+
self.gensim_model.add_vocab(X)
68+
return self

gensim/sklearn_api/text2bow.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
5+
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
6+
7+
"""
8+
Scikit learn interface for gensim for easy use of gensim with scikit-learn
9+
Follows scikit-learn API conventions
10+
"""
11+
12+
from six import string_types
13+
from sklearn.base import TransformerMixin, BaseEstimator
14+
from sklearn.exceptions import NotFittedError
15+
16+
from gensim.corpora import Dictionary
17+
from gensim.utils import tokenize
18+
19+
20+
class Text2BowTransformer(TransformerMixin, BaseEstimator):
21+
"""
22+
Base Text2Bow module
23+
"""
24+
25+
def __init__(self, prune_at=2000000, tokenizer=tokenize):
26+
"""
27+
Sklearn wrapper for Text2Bow model.
28+
"""
29+
self.gensim_model = None
30+
self.prune_at = prune_at
31+
self.tokenizer = tokenizer
32+
33+
def fit(self, X, y=None):
34+
"""
35+
Fit the model according to the given training data.
36+
"""
37+
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), X))
38+
self.gensim_model = Dictionary(documents=tokenized_docs, prune_at=self.prune_at)
39+
return self
40+
41+
def transform(self, docs):
42+
"""
43+
Return the BOW format for the input documents.
44+
"""
45+
if self.gensim_model is None:
46+
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
47+
48+
# input as python lists
49+
check = lambda x: [x] if isinstance(x, string_types) else x
50+
docs = check(docs)
51+
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), docs))
52+
X = [[] for _ in range(0, len(tokenized_docs))]
53+
54+
for k, v in enumerate(tokenized_docs):
55+
bow_val = self.gensim_model.doc2bow(v)
56+
X[k] = bow_val
57+
58+
return X
59+
60+
def partial_fit(self, X):
61+
if self.gensim_model is None:
62+
self.gensim_model = Dictionary(prune_at=self.prune_at)
63+
64+
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), X))
65+
self.gensim_model.add_documents(tokenized_docs)
66+
return self

0 commit comments

Comments
 (0)