@@ -656,7 +656,7 @@ def test_bind(self):
656656 for session in SESSIONS :
657657 session .create .assert_not_called ()
658658 txn = session ._transaction
659- self . assertTrue ( txn ._begun )
659+ txn .begin . assert_called_once_with ( )
660660
661661 self .assertTrue (pool ._pending_sessions .empty ())
662662
@@ -685,7 +685,7 @@ def test_bind_w_timestamp_race(self):
685685 for session in SESSIONS :
686686 session .create .assert_not_called ()
687687 txn = session ._transaction
688- self . assertTrue ( txn ._begun )
688+ txn .begin . assert_called_once_with ( )
689689
690690 self .assertTrue (pool ._pending_sessions .empty ())
691691
@@ -718,7 +718,7 @@ def test_put_non_full_w_active_txn(self):
718718 self .assertIs (queued , session )
719719
720720 self .assertEqual (len (pending ._items ), 0 )
721- self . assertFalse ( txn ._begun )
721+ txn .begin . assert_not_called ( )
722722
723723 def test_put_non_full_w_committed_txn (self ):
724724 pool = self ._make_one (size = 1 )
@@ -727,7 +727,7 @@ def test_put_non_full_w_committed_txn(self):
727727 database = _Database ("name" )
728728 session = _Session (database )
729729 committed = session .transaction ()
730- committed ._committed = True
730+ committed .committed = True
731731
732732 pool .put (session )
733733
@@ -736,7 +736,7 @@ def test_put_non_full_w_committed_txn(self):
736736 self .assertEqual (len (pending ._items ), 1 )
737737 self .assertIs (pending ._items [0 ], session )
738738 self .assertIsNot (session ._transaction , committed )
739- self . assertFalse ( session ._transaction ._begun )
739+ session ._transaction .begin . assert_not_called ( )
740740
741741 def test_put_non_full (self ):
742742 pool = self ._make_one (size = 1 )
@@ -762,7 +762,7 @@ def test_begin_pending_transactions_non_empty(self):
762762 pool ._sessions = _Queue ()
763763
764764 database = _Database ("name" )
765- TRANSACTIONS = [_Transaction ( )]
765+ TRANSACTIONS = [_make_transaction ( object () )]
766766 PENDING_SESSIONS = [_Session (database , transaction = txn ) for txn in TRANSACTIONS ]
767767
768768 pending = pool ._pending_sessions = _Queue (* PENDING_SESSIONS )
@@ -771,7 +771,7 @@ def test_begin_pending_transactions_non_empty(self):
771771 pool .begin_pending_transactions () # no raise
772772
773773 for txn in TRANSACTIONS :
774- self . assertTrue ( txn ._begun )
774+ txn .begin . assert_called_once_with ( )
775775
776776 self .assertTrue (pending .empty ())
777777
@@ -832,17 +832,13 @@ def test_context_manager_w_kwargs(self):
832832 self .assertEqual (pool ._got , {"foo" : "bar" })
833833
834834
835- class _Transaction (object ):
835+ def _make_transaction (* args , ** kw ):
836+ from google .cloud .spanner_v1 .transaction import Transaction
836837
837- _begun = False
838- _committed = False
839- _rolled_back = False
840-
841- def begin (self ):
842- self ._begun = True
843-
844- def committed (self ):
845- return self ._committed
838+ txn = mock .create_autospec (Transaction )(* args , ** kw )
839+ txn .committed = None
840+ txn ._rolled_back = False
841+ return txn
846842
847843
848844@total_ordering
@@ -873,7 +869,7 @@ def delete(self):
873869 raise NotFound ("unknown session" )
874870
875871 def transaction (self ):
876- txn = self ._transaction = _Transaction ( )
872+ txn = self ._transaction = _make_transaction ( self )
877873 return txn
878874
879875
0 commit comments