Add Pivot Normalization for gensim.models.TfidfModel. Fix #220
#1780
Merged
62 commits
efb7e3c pivot normalization (markroxor)
b7d07d4 Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim… (markroxor)
e8a3f16 verify weights (markroxor)
648bf21 verify weights (markroxor)
a6f1afb smartirs ready (markroxor)
d091138 change old tests (markroxor)
951c549 remove lambdas (markroxor)
40c0558 address suggestions (markroxor)
b35344c minor fix (markroxor)
634d595 pep8 fix (markroxor)
0917e75 pep8 fix (markroxor)
bef79cc numpy style doc strings (markroxor)
d3d431c fix pickle problem (menshikh-iv)
0e6f21e flake8 fix (markroxor)
7ee7560 fix bug in docstring (menshikh-iv)
b2def84 added few tests (markroxor)
5b2d37a fix normalize issue for pickling (markroxor)
ac4b154 fix normalize issue for pickling (markroxor)
0bacc08 test without sklearn api (markroxor)
51e0eb9 Merge branch 'smartirs' of github.com:markroxor/gensim into smartirs (markroxor)
3039732 hanging idents and new tests (markroxor)
99e6a6f Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim… (markroxor)
7d63d9c Merge branch 'smartirs' of github.com:markroxor/gensim into smartirs (markroxor)
e5140f8 add docstring (markroxor)
4afbadd add docstring (markroxor)
d2fe235 Merge branch 'smartirs' of github.com:markroxor/gensim into smartirs (markroxor)
5565c78 Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim… (markroxor)
099dbdf merge conflicts fix (markroxor)
ef67f63 pivotized normalization (markroxor)
52ee3c4 better way cmparing floats (markroxor)
3087030 pass tests (markroxor)
62bba1b pass tests (markroxor)
0a9f816 merging (markroxor)
dc63ab9 Merge branch 'pivot_norm' of github.com:markroxor/gensim into pivot_norm (markroxor)
035c8c5 merge develop (markroxor)
dc4ca52 added benchmarks (markroxor)
1ee449d address comments (markroxor)
4ea6caa benchmarking (markroxor)
b3cead6 testing pipeline (markroxor)
044332b pivoted normalisation (markroxor)
1c2196c taking overall norm (markroxor)
309b4e8 Update tfidfmodel.py (markroxor)
3866a9c Update sklearn_api.ipynb (markroxor)
12b42e6 tests for pivoted normalization (markroxor)
0ff6ad7 results (markroxor)
65c651b adding visualizations (markroxor)
4a947ba minor nb changes (markroxor)
619bb33 minor nb changes (markroxor)
f105190 removed self.pivoted_normalisation (markroxor)
6410f21 Update test_tfidfmodel.py (markroxor)
2eb6fc2 Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim… (markroxor)
a65dccf minor suggestions (markroxor)
8717350 added description (markroxor)
95cb630 added description (markroxor)
5f46d2f Merge branch 'pivot_norm' of github.com:markroxor/gensim into pivot_norm (markroxor)
2c7115d last commit (markroxor)
1fe46f8 Merge remote-tracking branch 'upstream/develop' into pivot_norm (menshikh-iv)
63c8385 cleanup (menshikh-iv)
5e87229 cosmetic fixes (menshikh-iv)
9f2b02c changed pivot (markroxor)
fc701a1 changed pivot (markroxor)
1868da5 fixed comments (markroxor)
docs/notebooks/pivoted_document_length_normalisation.ipynb (422 changes: 422 additions & 0 deletions)
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -27,7 +27,8 @@ def resolve_weights(smartirs):` |
| 27 | 27 | `         Information Retrieval System, a mnemonic scheme for denoting tf-idf weighting` |
| 28 | 28 | `         variants in the vector space model. The mnemonic for representing a combination` |
| 29 | 29 | `         of weights takes the form ddd, where the letters represents the term weighting of the document vector.` |
| 30 | | `-        for more information visit [1]_.` |
| | 30 | `` +        for more information visit `SMART Information Retrieval System `` |
| | 31 | `` +        <https://en.wikipedia.org/wiki/SMART_Information_Retrieval_System>`_. `` |
| 31 | 32 | |
| 32 | 33 | `     Returns` |
| 33 | 34 | `     -------` |
| | | `@@ -54,10 +55,6 @@ def resolve_weights(smartirs):` |
| 54 | 55 | ``          If `smartirs` is not a string of length 3 or one of the decomposed value `` |
| 55 | 56 | `         doesn't fit the list of permissible values` |
| 56 | 57 | |
| 57 | | `-    References` |
| 58 | | `-    ----------` |
| 59 | | `-    .. [1] https://en.wikipedia.org/wiki/SMART_Information_Retrieval_System` |
| 60 | | `-` |
| 61 | 58 | `     """` |
| 62 | 59 | `     if not isinstance(smartirs, str) or len(smartirs) != 3:` |
| 63 | 60 | `         raise ValueError("Expected a string of length 3 except got " + smartirs)` |
| | | `@@ -70,7 +67,7 @@ def resolve_weights(smartirs):` |
| 70 | 67 | `     if w_df not in 'ntp':` |
| 71 | 68 | `         raise ValueError("Expected inverse document frequency weight to be one of 'ntp', except got {}".format(w_df))` |
| 72 | 69 | |
| 73 | | `-    if w_n not in 'ncb':` |
| | 70 | `+    if w_n not in 'nc':` |
| 74 | 71 | `         raise ValueError("Expected normalization weight to be one of 'ncb', except got {}".format(w_n))` |
| 75 | 72 | |
| 76 | 73 | `     return w_tf, w_df, w_n` |
| | | `@@ -177,7 +174,7 @@ def updated_wglobal(docfreq, totaldocs, n_df):` |
| 177 | 174 | `         return np.log((1.0 * totaldocs - docfreq) / docfreq) / np.log(2)` |
| 178 | 175 | |
| 179 | 176 | |
| 180 | | `-def updated_normalize(x, n_n):` |
| | 177 | `+def updated_normalize(x, n_n, return_norm=False):` |
| | | mpenkov marked this conversation as resolved. |
| 181 | 178 | ``      """Normalizes the final tf-idf value according to the value of `n_n`. `` |
| 182 | 179 | |
| 183 | 180 | `     Parameters` |
| | | `@@ -186,17 +183,24 @@ def updated_normalize(x, n_n):` |
| 186 | 183 | `         Input array` |
| 187 | 184 | `     n_n : {'n', 'c'}` |
| 188 | 185 | `         Parameter that decides the normalizing function to be used.` |
| | 186 | `+    return_norm : bool, optional` |
| | 187 | `` +        If True - returns the length of vector `x`. `` |
| 189 | 188 | |
| 190 | 189 | `     Returns` |
| 191 | 190 | `     -------` |
| 192 | 191 | `     numpy.ndarray` |
| 193 | 192 | `         Normalized array.` |
| | 193 | `+    float` |
| | 194 | `+        Vector length.` |
| 194 | 195 | |
| 195 | 196 | `     """` |
| 196 | 197 | `     if n_n == "n":` |
| 197 | | `-        return x` |
| | 198 | `+        if return_norm:` |
| | 199 | `+            return x, 1.` |
| | 200 | `+        else:` |
| | 201 | `+            return x` |
| 198 | 202 | `     elif n_n == "c":` |
| 199 | | `-        return matutils.unitvec(x)` |
| | 203 | `+        return matutils.unitvec(x, return_norm=return_norm)` |
| 200 | 204 | |
| 201 | 205 | |
| 202 | 206 | ` class TfidfModel(interfaces.TransformationABC):` |
| | | `@@ -219,7 +223,7 @@ class TfidfModel(interfaces.TransformationABC):` |
| 219 | 223 | `     """` |
| 220 | 224 | |
| 221 | 225 | `     def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.identity,` |
| 222 | | `-                 wglobal=df2idf, normalize=True, smartirs=None):` |
| | 226 | `+                 wglobal=df2idf, normalize=True, smartirs=None, pivot=None, slope=0.65):` |
| 223 | 227 | `         """Compute tf-idf by multiplying a local component (term frequency) with a global component` |
| 224 | 228 | `         (inverse document frequency), and normalizing the resulting documents to unit length.` |
| 225 | 229 | `         Formula for non-normalized weight of term :math:`i` in document :math:`j` in a corpus of :math:`D` documents` |
| | | `@@ -272,22 +276,41 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden` |
| 272 | 276 | ``                  * `n` - none, `` |
| 273 | 277 | ``                  * `c` - cosine. `` |
| 274 | 278 | |
| 275 | | `-            For more information visit [1]_.` |
| 276 | | `-` |
| | 279 | `` +            For more information visit `SMART Information Retrieval System `` |
| | 280 | `` +            <https://en.wikipedia.org/wiki/SMART_Information_Retrieval_System>`_. `` |
| | 281 | `+        pivot : float, optional` |
| | 282 | `` +            It is the point around which the regular normalization curve is `tilted` to get the new pivoted `` |
| | 283 | `` +            normalization curve. In the paper `Amit Singhal, Chris Buckley, Mandar Mitra: `` |
| | 284 | `` +            "Pivoted Document Length Normalization" <http://singhal.info/pivoted-dln.pdf>`_ it is the point where the `` |
| | 285 | `+            retrieval and relevance curves intersect.` |
| | 286 | `+            This parameter along with slope is used for pivoted document length normalization.` |
| | 287 | `` +            Only when `pivot` is not None pivoted document length normalization will be applied else regular TfIdf `` |
| | 288 | `+            is used.` |
| | 289 | `+        slope : float, optional` |
| | 290 | `+            It is the parameter required by pivoted document length normalization which determines the slope to which` |
| | 291 | `` +            the `old normalization` can be tilted. This parameter only works when pivot is defined by user and is not `` |
| | 292 | `+            None.` |
| 277 | 293 | `         """` |
| 278 | 294 | |
| 279 | 295 | `         self.id2word = id2word` |
| 280 | 296 | `         self.wlocal, self.wglobal, self.normalize = wlocal, wglobal, normalize` |
| 281 | 297 | `         self.num_docs, self.num_nnz, self.idfs = None, None, None` |
| 282 | 298 | `         self.smartirs = smartirs` |
| | 299 | `+        self.slope = slope` |
| | 300 | `+        self.pivot = pivot` |
| | 301 | `+        self.eps = 1e-12` |
| 283 | 302 | |
| 284 | 303 | `         # If smartirs is not None, override wlocal, wglobal and normalize` |
| 285 | 304 | `         if smartirs is not None:` |
| 286 | 305 | `             n_tf, n_df, n_n = resolve_weights(smartirs)` |
| 287 | 306 | |
| 288 | 307 | `             self.wlocal = partial(updated_wlocal, n_tf=n_tf)` |
| 289 | 308 | `             self.wglobal = partial(updated_wglobal, n_df=n_df)` |
| 290 | | `-            self.normalize = partial(updated_normalize, n_n=n_n)` |
| | 309 | `+            # also return norm factor if pivot is not none` |
| | 310 | `+            if self.pivot is None:` |
| | 311 | `+                self.normalize = partial(updated_normalize, n_n=n_n)` |
| | 312 | `+            else:` |
| | 313 | `+                self.normalize = partial(updated_normalize, n_n=n_n, return_norm=True)` |
| 291 | 314 | |
| 292 | 315 | `         if dictionary is not None:` |
| 293 | 316 | `             # user supplied a Dictionary object, which already contains all the` |
| | | `@@ -309,6 +332,23 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden` |
| 309 | 332 | `             # be initialized in some other way` |
| 310 | 333 | `             pass` |
| 311 | 334 | |
| | 335 | `+    @classmethod` |
| | 336 | `+    def load(cls, *args, **kwargs):` |
| | 337 | `+        """` |
| | 338 | `+        Load a previously saved TfidfModel class. Handles backwards compatibility from` |
| | 339 | `+        older TfidfModel versions which did not use pivoted document normalization.` |
| | 340 | `+        """` |
| | 341 | `+        model = super(TfidfModel, cls).load(*args, **kwargs)` |
| | 342 | `+        if not hasattr(model, 'pivot'):` |
| | 343 | `+            logger.info('older version of %s loaded without pivot arg', cls.__name__)` |
| | 344 | `+            logger.info('Setting pivot to None.')` |
| | 345 | `+            model.pivot = None` |
| | | Contributor: I forgot, what's about |
| | 346 | `+        if not hasattr(model, 'slope'):` |
| | 347 | `+            logger.info('older version of %s loaded without slope arg', cls.__name__)` |
| | 348 | `+            logger.info('Setting slope to 0.65.')` |
| | 349 | `+            model.slope = 0.65` |
| | 350 | `+        return model` |
| | 351 | `+` |
| 312 | 352 | `     def __str__(self):` |
| 313 | 353 | `         return "TfidfModel(num_docs=%s, num_nnz=%s)" % (self.num_docs, self.num_nnz)` |
| 314 | 354 | |
| | | `@@ -360,6 +400,7 @@ def __getitem__(self, bow, eps=1e-12):` |
| 360 | 400 | ``              TfIdf corpus, if `bow` is corpus. `` |
| 361 | 401 | |
| 362 | 402 | `         """` |
| | 403 | `+        self.eps = eps` |
| 363 | 404 | `         # if the input vector is in fact a corpus, return a transformed corpus as a result` |
| 364 | 405 | `         is_corpus, bow = utils.is_corpus(bow)` |
| 365 | 406 | `         if is_corpus:` |
| | | `@@ -377,7 +418,7 @@ def __getitem__(self, bow, eps=1e-12):` |
| 377 | 418 | |
| 378 | 419 | `         vector = [` |
| 379 | 420 | `             (termid, tf * self.idfs.get(termid))` |
| 380 | | `-            for termid, tf in zip(termid_array, tf_array) if abs(self.idfs.get(termid, 0.0)) > eps` |
| | 421 | `+            for termid, tf in zip(termid_array, tf_array) if abs(self.idfs.get(termid, 0.0)) > self.eps` |
| 381 | 422 | `         ]` |
| 382 | 423 | |
| 383 | 424 | `         if self.normalize is True:` |
| | | `@@ -387,8 +428,15 @@ def __getitem__(self, bow, eps=1e-12):` |
| 387 | 428 | |
| 388 | 429 | `         # and finally, normalize the vector either to unit length, or use a` |
| 389 | 430 | `         # user-defined normalization function` |
| 390 | | `-        vector = self.normalize(vector)` |
| 391 | | `-` |
| 392 | | `-        # make sure there are no explicit zeroes in the vector (must be sparse)` |
| 393 | | `-        vector = [(termid, weight) for termid, weight in vector if abs(weight) > eps]` |
| 394 | | `-        return vector` |
| | 431 | `+        if self.pivot is None:` |
| | 432 | `+            norm_vector = self.normalize(vector)` |
| | 433 | `+            norm_vector = [(termid, weight) for termid, weight in norm_vector if abs(weight) > self.eps]` |
| | 434 | `+        else:` |
| | 435 | `+            _, old_norm = self.normalize(vector, return_norm=True)` |
| | 436 | `+            pivoted_norm = (1 - self.slope) * self.pivot + self.slope * old_norm` |
| | 437 | `+            norm_vector = [` |
| | 438 | `+                (termid, weight / float(pivoted_norm))` |
| | 439 | `+                for termid, weight in vector` |
| | 440 | `+                if abs(weight / float(pivoted_norm)) > self.eps` |
| | 441 | `+            ]` |
| | 442 | `+        return norm_vector` |
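
The last hunk divides each tf-idf weight by `(1 - slope) * pivot + slope * old_norm` instead of by the plain vector length. Below is a minimal standalone sketch of that arithmetic, not the gensim implementation. It assumes the old norm is the Euclidean length of the vector, and the function name `pivoted_normalize` and the sample values are hypothetical.

```python
import math


def pivoted_normalize(vector, pivot, slope=0.65, eps=1e-12):
    """Scale (term_id, weight) pairs by a pivoted norm.

    Hypothetical helper for illustration; assumes a positive `pivot`
    and that the "old" norm is the Euclidean length of `vector`.
    """
    # Plain Euclidean length of the document vector (the old norm).
    old_norm = math.sqrt(sum(weight ** 2 for _, weight in vector))
    # Tilt the old norm around `pivot`, following the formula in the diff.
    pivoted_norm = (1 - slope) * pivot + slope * old_norm
    # Drop entries that become (numerically) zero, keeping the vector sparse.
    return [
        (term_id, weight / pivoted_norm)
        for term_id, weight in vector
        if abs(weight / pivoted_norm) > eps
    ]


# Sample document with Euclidean length 5.0 (illustrative values only).
example = pivoted_normalize([(0, 3.0), (1, 4.0)], pivot=10.0, slope=0.5)
```

With `slope=1.0` the pivot term vanishes and the result is ordinary cosine normalization. A document whose old norm equals `pivot` is scaled identically under both schemes.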