@@ -2114,6 +2114,7 @@ def test_generate_read_batches_w_max_partitions(self):
21142114 "columns" : self .COLUMNS ,
21152115 "keyset" : {"all" : True },
21162116 "index" : "" ,
2117+ "data_boost_enabled" : False ,
21172118 }
21182119 self .assertEqual (len (batches ), len (self .TOKENS ))
21192120 for batch , token in zip (batches , self .TOKENS ):
@@ -2155,6 +2156,7 @@ def test_generate_read_batches_w_retry_and_timeout_params(self):
21552156 "columns" : self .COLUMNS ,
21562157 "keyset" : {"all" : True },
21572158 "index" : "" ,
2159+ "data_boost_enabled" : False ,
21582160 }
21592161 self .assertEqual (len (batches ), len (self .TOKENS ))
21602162 for batch , token in zip (batches , self .TOKENS ):
@@ -2195,6 +2197,7 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self):
21952197 "columns" : self .COLUMNS ,
21962198 "keyset" : {"all" : True },
21972199 "index" : self .INDEX ,
2200+ "data_boost_enabled" : False ,
21982201 }
21992202 self .assertEqual (len (batches ), len (self .TOKENS ))
22002203 for batch , token in zip (batches , self .TOKENS ):
@@ -2212,6 +2215,47 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self):
22122215 timeout = gapic_v1 .method .DEFAULT ,
22132216 )
22142217
2218+ def test_generate_read_batches_w_data_boost_enabled (self ):
2219+ data_boost_enabled = True
2220+ keyset = self ._make_keyset ()
2221+ database = self ._make_database ()
2222+ batch_txn = self ._make_one (database )
2223+ snapshot = batch_txn ._snapshot = self ._make_snapshot ()
2224+ snapshot .partition_read .return_value = self .TOKENS
2225+
2226+ batches = list (
2227+ batch_txn .generate_read_batches (
2228+ self .TABLE ,
2229+ self .COLUMNS ,
2230+ keyset ,
2231+ index = self .INDEX ,
2232+ data_boost_enabled = data_boost_enabled ,
2233+ )
2234+ )
2235+
2236+ expected_read = {
2237+ "table" : self .TABLE ,
2238+ "columns" : self .COLUMNS ,
2239+ "keyset" : {"all" : True },
2240+ "index" : self .INDEX ,
2241+ "data_boost_enabled" : True ,
2242+ }
2243+ self .assertEqual (len (batches ), len (self .TOKENS ))
2244+ for batch , token in zip (batches , self .TOKENS ):
2245+ self .assertEqual (batch ["partition" ], token )
2246+ self .assertEqual (batch ["read" ], expected_read )
2247+
2248+ snapshot .partition_read .assert_called_once_with (
2249+ table = self .TABLE ,
2250+ columns = self .COLUMNS ,
2251+ keyset = keyset ,
2252+ index = self .INDEX ,
2253+ partition_size_bytes = None ,
2254+ max_partitions = None ,
2255+ retry = gapic_v1 .method .DEFAULT ,
2256+ timeout = gapic_v1 .method .DEFAULT ,
2257+ )
2258+
22152259 def test_process_read_batch (self ):
22162260 keyset = self ._make_keyset ()
22172261 token = b"TOKEN"
@@ -2288,7 +2332,11 @@ def test_generate_query_batches_w_max_partitions(self):
22882332 batch_txn .generate_query_batches (sql , max_partitions = max_partitions )
22892333 )
22902334
2291- expected_query = {"sql" : sql , "query_options" : client ._query_options }
2335+ expected_query = {
2336+ "sql" : sql ,
2337+ "data_boost_enabled" : False ,
2338+ "query_options" : client ._query_options ,
2339+ }
22922340 self .assertEqual (len (batches ), len (self .TOKENS ))
22932341 for batch , token in zip (batches , self .TOKENS ):
22942342 self .assertEqual (batch ["partition" ], token )
@@ -2326,6 +2374,7 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self):
23262374
23272375 expected_query = {
23282376 "sql" : sql ,
2377+ "data_boost_enabled" : False ,
23292378 "params" : params ,
23302379 "param_types" : param_types ,
23312380 "query_options" : client ._query_options ,
@@ -2372,6 +2421,7 @@ def test_generate_query_batches_w_retry_and_timeout_params(self):
23722421
23732422 expected_query = {
23742423 "sql" : sql ,
2424+ "data_boost_enabled" : False ,
23752425 "params" : params ,
23762426 "param_types" : param_types ,
23772427 "query_options" : client ._query_options ,
@@ -2391,6 +2441,37 @@ def test_generate_query_batches_w_retry_and_timeout_params(self):
23912441 timeout = 2.0 ,
23922442 )
23932443
2444+ def test_generate_query_batches_w_data_boost_enabled (self ):
2445+ sql = "SELECT COUNT(*) FROM table_name"
2446+ client = _Client (self .PROJECT_ID )
2447+ instance = _Instance (self .INSTANCE_NAME , client = client )
2448+ database = _Database (self .DATABASE_NAME , instance = instance )
2449+ batch_txn = self ._make_one (database )
2450+ snapshot = batch_txn ._snapshot = self ._make_snapshot ()
2451+ snapshot .partition_query .return_value = self .TOKENS
2452+
2453+ batches = list (batch_txn .generate_query_batches (sql , data_boost_enabled = True ))
2454+
2455+ expected_query = {
2456+ "sql" : sql ,
2457+ "data_boost_enabled" : True ,
2458+ "query_options" : client ._query_options ,
2459+ }
2460+ self .assertEqual (len (batches ), len (self .TOKENS ))
2461+ for batch , token in zip (batches , self .TOKENS ):
2462+ self .assertEqual (batch ["partition" ], token )
2463+ self .assertEqual (batch ["query" ], expected_query )
2464+
2465+ snapshot .partition_query .assert_called_once_with (
2466+ sql = sql ,
2467+ params = None ,
2468+ param_types = None ,
2469+ partition_size_bytes = None ,
2470+ max_partitions = None ,
2471+ retry = gapic_v1 .method .DEFAULT ,
2472+ timeout = gapic_v1 .method .DEFAULT ,
2473+ )
2474+
23942475 def test_process_query_batch (self ):
23952476 sql = (
23962477 "SELECT first_name, last_name, email FROM citizens " "WHERE age <= @max_age"
0 commit comments