diff --git a/openml/_api_calls.py b/openml/_api_calls.py index 57599b912..67e57d60a 100644 --- a/openml/_api_calls.py +++ b/openml/_api_calls.py @@ -55,7 +55,7 @@ def _perform_api_call(call, request_method, data=None, file_elements=None): if file_elements is not None: if request_method != "post": raise ValueError("request method must be post when file elements are present") - response = __read_url_files(url, data=data, file_elements=file_elements) + response = _read_url_files(url, data=data, file_elements=file_elements) else: response = __read_url(url, request_method, data) @@ -106,7 +106,6 @@ def _download_text_file( logging.info("Starting [%s] request for the URL %s", "get", source) start = time.time() response = __read_url(source, request_method="get") - __check_response(response, source, None) downloaded_file = response.text if md5_checksum is not None: @@ -138,15 +137,6 @@ def _download_text_file( return None -def __check_response(response, url, file_elements): - if response.status_code != 200: - raise __parse_server_exception(response, url, file_elements=file_elements) - elif ( - "Content-Encoding" not in response.headers or response.headers["Content-Encoding"] != "gzip" - ): - logging.warning("Received uncompressed content from OpenML for {}.".format(url)) - - def _file_id_to_url(file_id, filename=None): """ Presents the URL how to download a given file id @@ -159,7 +149,7 @@ def _file_id_to_url(file_id, filename=None): return url -def __read_url_files(url, data=None, file_elements=None): +def _read_url_files(url, data=None, file_elements=None): """do a post request to url with data and sending file_elements as files""" @@ -169,7 +159,7 @@ def __read_url_files(url, data=None, file_elements=None): file_elements = {} # Using requests.post sets header 'Accept-encoding' automatically to # 'gzip,deflate' - response = __send_request(request_method="post", url=url, data=data, files=file_elements,) + response = _send_request(request_method="post", url=url, 
data=data, files=file_elements,) return response @@ -178,10 +168,10 @@ def __read_url(url, request_method, data=None): if config.apikey is not None: data["api_key"] = config.apikey - return __send_request(request_method=request_method, url=url, data=data) + return _send_request(request_method=request_method, url=url, data=data) -def __send_request( +def _send_request( request_method, url, data, files=None, ): n_retries = config.connection_n_retries @@ -198,17 +188,40 @@ def __send_request( response = session.post(url, data=data, files=files) else: raise NotImplementedError() + __check_response(response=response, url=url, file_elements=files) break - except (requests.exceptions.ConnectionError, requests.exceptions.SSLError,) as e: + except ( + requests.exceptions.ConnectionError, + requests.exceptions.SSLError, + OpenMLServerException, + ) as e: + if isinstance(e, OpenMLServerException): + if e.code != 107: + # 107 is a database connection error - only then do retries + raise + else: + wait_time = 0.3 + else: + wait_time = 0.1 if i == n_retries: raise e else: - time.sleep(0.1 * i) + time.sleep(wait_time * i) + continue if response is None: raise ValueError("This should never happen!") return response +def __check_response(response, url, file_elements): + if response.status_code != 200: + raise __parse_server_exception(response, url, file_elements=file_elements) + elif ( + "Content-Encoding" not in response.headers or response.headers["Content-Encoding"] != "gzip" + ): + logging.warning("Received uncompressed content from OpenML for {}.".format(url)) + + def __parse_server_exception( response: requests.Response, url: str, file_elements: Dict, ) -> OpenMLServerError: diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py index 26c705eca..1ddf94796 100644 --- a/openml/datasets/functions.py +++ b/openml/datasets/functions.py @@ -183,7 +183,7 @@ def list_datasets( status: Optional[str] = None, tag: Optional[str] = None, output_format: str = "dict", - 
**kwargs + **kwargs, ) -> Union[Dict, pd.DataFrame]: """ @@ -251,7 +251,7 @@ def list_datasets( size=size, status=status, tag=tag, - **kwargs + **kwargs, ) @@ -357,8 +357,7 @@ def _validated_data_attributes( def check_datasets_active( - dataset_ids: List[int], - raise_error_if_not_exist: bool = True, + dataset_ids: List[int], raise_error_if_not_exist: bool = True, ) -> Dict[int, bool]: """ Check if the dataset ids provided are active. @@ -386,7 +385,7 @@ def check_datasets_active( dataset = dataset_list.get(did, None) if dataset is None: if raise_error_if_not_exist: - raise ValueError(f'Could not find dataset {did} in OpenML dataset list.') + raise ValueError(f"Could not find dataset {did} in OpenML dataset list.") else: active[did] = dataset["status"] == "active" diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py index 38b035fcf..9a87b96b4 100644 --- a/tests/test_datasets/test_dataset_functions.py +++ b/tests/test_datasets/test_dataset_functions.py @@ -227,10 +227,7 @@ def test_list_datasets_empty(self): def test_check_datasets_active(self): # Have to test on live because there is no deactivated dataset on the test server. 
openml.config.server = self.production_server - active = openml.datasets.check_datasets_active( - [2, 17, 79], - raise_error_if_not_exist=False, - ) + active = openml.datasets.check_datasets_active([2, 17, 79], raise_error_if_not_exist=False,) self.assertTrue(active[2]) self.assertFalse(active[17]) self.assertIsNone(active.get(79)) diff --git a/tests/test_openml/test_api_calls.py b/tests/test_openml/test_api_calls.py index 8b470a45b..459a0cdf5 100644 --- a/tests/test_openml/test_api_calls.py +++ b/tests/test_openml/test_api_calls.py @@ -1,3 +1,5 @@ +import unittest.mock + import openml import openml.testing @@ -8,3 +10,23 @@ def test_too_long_uri(self): openml.exceptions.OpenMLServerError, "URI too long!", ): openml.datasets.list_datasets(data_id=list(range(10000))) + + @unittest.mock.patch("time.sleep") + @unittest.mock.patch("requests.Session") + def test_retry_on_database_error(self, Session_class_mock, _): + response_mock = unittest.mock.Mock() + response_mock.text = ( + "<oml:error>\n" + "<oml:code>107</oml:code>" + "<oml:message>Database connection error. " + "Usually due to high server load. " + "Please wait for N seconds and try again.</oml:message>\n" + "</oml:error>" + ) + Session_class_mock.return_value.__enter__.return_value.get.return_value = response_mock + with self.assertRaisesRegex( + openml.exceptions.OpenMLServerException, "/abc returned code 107" + ): + openml._api_calls._send_request("get", "/abc", {}) + + self.assertEqual(Session_class_mock.return_value.__enter__.return_value.get.call_count, 10)