[WIP] Added sklearn wrapper for LDASeq model#1405
Conversation
| """ | ||
| Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel | ||
| """ | ||
| self.corpus = None |
There was a problem hiding this comment.
Why you needed a field for a corpus?
There was a problem hiding this comment.
@menshikh-iv In my opinion, the user might be interested to know about the corpus used for training the model (using the get_params function). Should we continue to store this value?
There was a problem hiding this comment.
@chinmayapancholi13 No, sklearn does not store X, so we should not
There was a problem hiding this comment.
@menshikh-iv Yes, that is true for sklearn. Removing corpus attribute from all the wrappers then.
| Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel | ||
| """ | ||
| self.corpus = None | ||
| self.model = None |
There was a problem hiding this comment.
Please do this field "private" (start with underscores)
| initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10, | ||
| random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100) | ||
| """ | ||
| self.corpus = X |
| """ | ||
| Fit the model according to the given training data. | ||
| Calls gensim.models.LdaSeqModel: | ||
| >>> gensim.models.LdaSeqModel(corpus=None, time_slice=None, id2word=None, alphas=0.01, num_topics=10, |
There was a problem hiding this comment.
Please remove this block >>> ... , this example does not help for a new user.
There was a problem hiding this comment.
@menshikh-iv Should we remove this >>> .... statement in all the model wrappers? This line basically tells us how the associated Gensim model is actually called.
There was a problem hiding this comment.
You just need to specify the class that is used (you have already done above) and write where a user can read the documentation.
| em_min_iter=self.em_min_iter, em_max_iter=self.em_max_iter, chunksize=self.chunksize) | ||
| return self | ||
|
|
||
| def transform(self, docs): |
There was a problem hiding this comment.
Chek case, when you create instance and call transform immediately (without fit), you need to raise exception like sklearn
There was a problem hiding this comment.
Also, please add an example of docs param in docstring.
There was a problem hiding this comment.
@menshikh-iv For checking if the model has been fitted, would it be a good idea to check if self.gensim_model is None or not? This approach would clearly give an error when fit hasn't been called before calling transform but this also allows the user to set the value of self.gensim_model through set_params function (or even as wrapper.gensim_model=...) and then call transform function, which makes sense for us to allow.
There was a problem hiding this comment.
I completely forgot about set_param, so, I think if you disable gensim_model in set_param, you can check model is None (it does not cover all cases, but covers the most obvious)
There was a problem hiding this comment.
Could you elaborate the meaning of "disabling" gensim_model param from the function set_params?
Actually, gensim_model is a public attribute of the model so it can be set like ldaseq_wrapper.gensim_model = some_model, which is almost the same as using set_params function to set this value. So, checking whether self.gensim_model is None should be enough, right?
This would be like :
def transform(self, docs):
"""
Return the topic proportions for the documents passed.
"""
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
# The input as array of array
check = lambda x: [x] if isinstance(x[0], tuple) else x
..........................................................................
..........................................................................
..........................................................................
..........................................................................
There was a problem hiding this comment.
Ok, as a temporary option.
| return np.reshape(np.array(X), (len(docs), self.num_topics)) | ||
|
|
||
| def partial_fit(self, X): | ||
| raise NotImplementedError("'partial_fit' has not been implemented for the LDA Seq model") |
There was a problem hiding this comment.
LDA Seq model -> SklLdaSeqModel
| for key in param_dict.keys(): | ||
| self.assertEqual(model_params[key], param_dict[key]) | ||
|
|
||
|
|
There was a problem hiding this comment.
Add persistence test with pickle
There was a problem hiding this comment.
And add test with pipeline
| score = text_ldaseq.score(corpus, test_target) | ||
| self.assertGreater(score, 0.50) | ||
|
|
||
| def testPersistence(self): |
There was a problem hiding this comment.
It's sanity check only.
For persistence, you need to compare current and loaded models. For this purpose, you need to compare current and loaded inner matrices OR get corpus, transform it with both variant and compare results
There was a problem hiding this comment.
Thanks. I have now added code for comparing the vectors transformed from original and loaded models, in addition to this sanity check. :)
| text_ldaseq = Pipeline((('features', model,), ('classifier', clf))) | ||
| text_ldaseq.fit(corpus, test_target) | ||
| score = text_ldaseq.score(corpus, test_target) | ||
| self.assertGreater(score, 0.50) |
There was a problem hiding this comment.
It's will be correct every time? No needed to fix seeds for reproducibility?
There was a problem hiding this comment.
We now have a fixed seed which is set before the test testPipeline to ensure that we get similar values.
|
Thank you @chinmayapancholi13 👍 |
This PR adds a scikit-learn wrapper for Gensim's LDASeq model.