2929try :
3030 from pyemd import emd # noqa:F401
3131 PYEMD_EXT = True
32- except ImportError :
32+ except ( ImportError , ValueError ) :
3333 PYEMD_EXT = False
3434
3535sentences = [doc2vec .TaggedDocument (words , [i ]) for i , words in enumerate (texts )]
@@ -78,9 +78,8 @@ def testFull(self, num_best=None, shardsize=100):
7878 index .destroy ()
7979
8080 def testNumBest (self ):
81-
8281 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
83- return
82+ self . skipTest ( "pyemd not installed or have some issues" )
8483
8584 for num_best in [None , 0 , 1 , 9 , 1000 ]:
8685 self .testFull (num_best = num_best )
@@ -110,6 +109,9 @@ def test_scipy2scipy_clipped(self):
110109
111110 def testEmptyQuery (self ):
112111 index = self .factoryMethod ()
112+ if isinstance (index , similarities .WmdSimilarity ) and not PYEMD_EXT :
113+ self .skipTest ("pyemd not installed or have some issues" )
114+
113115 query = []
114116 try :
115117 sims = index [query ]
@@ -166,7 +168,7 @@ def testIter(self):
166168
167169 def testPersistency (self ):
168170 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
169- return
171+ self . skipTest ( "pyemd not installed or have some issues" )
170172
171173 fname = get_tmpfile ('gensim_similarities.tst.pkl' )
172174 index = self .factoryMethod ()
@@ -186,7 +188,7 @@ def testPersistency(self):
186188
187189 def testPersistencyCompressed (self ):
188190 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
189- return
191+ self . skipTest ( "pyemd not installed or have some issues" )
190192
191193 fname = get_tmpfile ('gensim_similarities.tst.pkl.gz' )
192194 index = self .factoryMethod ()
@@ -206,7 +208,7 @@ def testPersistencyCompressed(self):
206208
207209 def testLarge (self ):
208210 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
209- return
211+ self . skipTest ( "pyemd not installed or have some issues" )
210212
211213 fname = get_tmpfile ('gensim_similarities.tst.pkl' )
212214 index = self .factoryMethod ()
@@ -228,7 +230,7 @@ def testLarge(self):
228230
229231 def testLargeCompressed (self ):
230232 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
231- return
233+ self . skipTest ( "pyemd not installed or have some issues" )
232234
233235 fname = get_tmpfile ('gensim_similarities.tst.pkl.gz' )
234236 index = self .factoryMethod ()
@@ -250,7 +252,7 @@ def testLargeCompressed(self):
250252
251253 def testMmap (self ):
252254 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
253- return
255+ self . skipTest ( "pyemd not installed or have some issues" )
254256
255257 fname = get_tmpfile ('gensim_similarities.tst.pkl' )
256258 index = self .factoryMethod ()
@@ -273,7 +275,7 @@ def testMmap(self):
273275
274276 def testMmapCompressed (self ):
275277 if self .cls == similarities .WmdSimilarity and not PYEMD_EXT :
276- return
278+ self . skipTest ( "pyemd not installed or have some issues" )
277279
278280 fname = get_tmpfile ('gensim_similarities.tst.pkl.gz' )
279281 index = self .factoryMethod ()
@@ -298,12 +300,10 @@ def factoryMethod(self):
298300 # Override factoryMethod.
299301 return self .cls (texts , self .w2v_model )
300302
303+ @unittest .skipIf (PYEMD_EXT is False , "pyemd not installed or have some issues" )
301304 def testFull (self , num_best = None ):
302305 # Override testFull.
303306
304- if not PYEMD_EXT :
305- return
306-
307307 index = self .cls (texts , self .w2v_model )
308308 index .num_best = num_best
309309 query = texts [0 ]
@@ -319,15 +319,13 @@ def testFull(self, num_best=None):
319319 self .assertTrue (numpy .alltrue (sims [1 :] > 0.0 ))
320320 self .assertTrue (numpy .alltrue (sims [1 :] < 1.0 ))
321321
322+ @unittest .skipIf (PYEMD_EXT is False , "pyemd not installed or have some issues" )
322323 def testNonIncreasing (self ):
323324 ''' Check that similarities are non-increasing when `num_best` is not
324325 `None`.'''
325326 # NOTE: this could be implemented for other similarities as well (i.e.
326327 # in _TestSimilarityABC).
327328
328- if not PYEMD_EXT :
329- return
330-
331329 index = self .cls (texts , self .w2v_model , num_best = 3 )
332330 query = texts [0 ]
333331 sims = index [query ]
@@ -337,12 +335,10 @@ def testNonIncreasing(self):
337335 cond = sum (numpy .diff (sims2 ) < 0 ) == len (sims2 ) - 1
338336 self .assertTrue (cond )
339337
338+ @unittest .skipIf (PYEMD_EXT is False , "pyemd not installed or have some issues" )
340339 def testChunking (self ):
341340 # Override testChunking.
342341
343- if not PYEMD_EXT :
344- return
345-
346342 index = self .cls (texts , self .w2v_model )
347343 query = texts [:3 ]
348344 sims = index [query ]
@@ -358,12 +354,10 @@ def testChunking(self):
358354 self .assertTrue (numpy .alltrue (sim > 0.0 ))
359355 self .assertTrue (numpy .alltrue (sim <= 1.0 ))
360356
357+ @unittest .skipIf (PYEMD_EXT is False , "pyemd not installed or have some issues" )
361358 def testIter (self ):
362359 # Override testIter.
363360
364- if not PYEMD_EXT :
365- return
366-
367361 index = self .cls (texts , self .w2v_model )
368362 for sims in index :
369363 self .assertTrue (numpy .alltrue (sims >= 0.0 ))
0 commit comments