88
99@pytest .fixture
1010def ex ():
11- return Experiment (' tensorflow_tests' )
11+ return Experiment (" tensorflow_tests" )
1212
1313
1414@pytest .fixture ()
@@ -18,6 +18,7 @@ def tf():
1818 so `tensorflow` is not required during the tests.
1919 """
2020 from sacred .optional import has_tensorflow
21+
2122 if has_tensorflow :
2223 return opt .get_tensorflow ()
2324 else :
@@ -28,7 +29,10 @@ class FileWriter:
2829 def __init__ (self , logdir , graph ):
2930 self .logdir = logdir
3031 self .graph = graph
31- print ("Mocked FileWriter got logdir=%s, graph=%s" % (logdir , graph ))
32+ print (
33+ "Mocked FileWriter got logdir=%s, graph=%s"
34+ % (logdir , graph )
35+ )
3236
3337 class Session :
3438 def __init__ (self ):
@@ -42,6 +46,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
4246
4347 # Set stflow to use the mock as the test
4448 import sacred .stflow .method_interception
49+
4550 sacred .stflow .method_interception .tf = tensorflow
4651 return tensorflow
4752
@@ -90,7 +95,10 @@ def run_experiment(_run):
9095 assert swr is not None
9196 assert _run .info ["tensorflow" ]["logdirs" ] == [TEST_LOG_DIR ]
9297 tf .summary .FileWriter (TEST_LOG_DIR2 , s .graph )
93- assert _run .info ["tensorflow" ]["logdirs" ] == [TEST_LOG_DIR , TEST_LOG_DIR2 ]
98+ assert _run .info ["tensorflow" ]["logdirs" ] == [
99+ TEST_LOG_DIR ,
100+ TEST_LOG_DIR2 ,
101+ ]
94102
95103 # This should not be captured:
96104 tf .summary .FileWriter ("/tmp/whatever" , s .graph )
@@ -133,7 +141,7 @@ def test_log_summary_writer_class(ex, tf):
133141 TEST_LOG_DIR = "/dev/null"
134142 TEST_LOG_DIR2 = "/tmp/sacred_test"
135143
136- class FooClass () :
144+ class FooClass :
137145 def __init__ (self ):
138146 pass
139147
0 commit comments