Skip to content

Commit d79a98c

Browse files
mfeurer and PGijsbers
authored and committed
add better error message for too-long URI (#881)
* add better error message for too-long URI
* improve error handling
* improve data download function, fix bugs
* stricter API, more private methods
* incorporate Pieter's feedback
1 parent 69d443f commit d79a98c

File tree

8 files changed

+127
-97
lines changed

8 files changed

+127
-97
lines changed

openml/_api_calls.py

Lines changed: 105 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# License: BSD 3-Clause
22

33
import time
4+
import hashlib
45
import logging
56
import requests
6-
import warnings
77
import xmltodict
8-
from typing import Dict
8+
from typing import Dict, Optional
99

1010
from . import config
1111
from .exceptions import (OpenMLServerError, OpenMLServerException,
12-
OpenMLServerNoResult)
12+
OpenMLServerNoResult, OpenMLHashException)
1313

1414

1515
def _perform_api_call(call, request_method, data=None, file_elements=None):
@@ -47,20 +47,105 @@ def _perform_api_call(call, request_method, data=None, file_elements=None):
4747
url = url.replace('=', '%3d')
4848
logging.info('Starting [%s] request for the URL %s', request_method, url)
4949
start = time.time()
50+
5051
if file_elements is not None:
5152
if request_method != 'post':
52-
raise ValueError('request method must be post when file elements '
53-
'are present')
54-
response = _read_url_files(url, data=data, file_elements=file_elements)
53+
raise ValueError('request method must be post when file elements are present')
54+
response = __read_url_files(url, data=data, file_elements=file_elements)
5555
else:
56-
response = _read_url(url, request_method, data)
56+
response = __read_url(url, request_method, data)
57+
58+
__check_response(response, url, file_elements)
59+
5760
logging.info(
5861
'%.7fs taken for [%s] request for the URL %s',
5962
time.time() - start,
6063
request_method,
6164
url,
6265
)
63-
return response
66+
return response.text
67+
68+
69+
def _download_text_file(source: str,
                        output_path: Optional[str] = None,
                        md5_checksum: Optional[str] = None,
                        exists_ok: bool = True,
                        encoding: str = 'utf8',
                        ) -> Optional[str]:
    """ Download the text file at `source` and store it in `output_path`.

    By default, do nothing if a file already exists in `output_path`.
    The downloaded file can be checked against an expected md5 checksum.

    Parameters
    ----------
    source : str
        url of the file to be downloaded
    output_path : str, (optional)
        full path, including filename, of where the file should be stored. If ``None``,
        this function returns the downloaded file as string.
    md5_checksum : str, optional (default=None)
        If not None, should be a string of hexadecimal digits of the expected digest value.
    exists_ok : bool, optional (default=True)
        If False, raise an FileExistsError if there already exists a file at `output_path`.
    encoding : str, optional (default='utf8')
        The encoding with which the file should be stored.

    Returns
    -------
    str or None
        The downloaded text when `output_path` is None, otherwise None.

    Raises
    ------
    FileExistsError
        If `exists_ok` is False and a file already exists at `output_path`.
    OpenMLHashException
        If `md5_checksum` is given and does not match the downloaded content.
    """
    if output_path is not None:
        # EAFP existence check: opening (without reading) raises FileNotFoundError
        # when the file is absent, in which case we fall through to the download.
        try:
            with open(output_path, encoding=encoding):
                if exists_ok:
                    return None
                raise FileExistsError('File already exists at {}'.format(output_path))
        except FileNotFoundError:
            pass

    logging.info('Starting [%s] request for the URL %s', 'get', source)
    start = time.time()
    response = __read_url(source, request_method='get')
    __check_response(response, source, None)
    downloaded_file = response.text

    if md5_checksum is not None:
        # NOTE(review): the checksum is computed over the utf-8 encoding of the
        # text, independent of the `encoding` used for storage on disk.
        md5 = hashlib.md5()
        md5.update(downloaded_file.encode('utf-8'))
        md5_checksum_download = md5.hexdigest()
        if md5_checksum != md5_checksum_download:
            raise OpenMLHashException(
                'Checksum {} of downloaded file is unequal to the expected checksum {}.'
                .format(md5_checksum_download, md5_checksum))

    # Single completion log instead of one duplicated copy per branch.
    logging.info(
        '%.7fs taken for [%s] request for the URL %s',
        time.time() - start,
        'get',
        source,
    )

    if output_path is None:
        return downloaded_file

    with open(output_path, "w", encoding=encoding) as fh:
        fh.write(downloaded_file)
    return None
141+
142+
143+
def __check_response(response, url, file_elements):
    # Non-200 responses carry an OpenML error payload; convert it into the
    # appropriate exception and raise.
    if response.status_code != 200:
        raise __parse_server_exception(response, url, file_elements=file_elements)
    # Successful response: warn when the server did not gzip-compress the body
    # (a missing Content-Encoding header counts as uncompressed).
    if response.headers.get('Content-Encoding') != 'gzip':
        logging.warning('Received uncompressed content from OpenML for {}.'.format(url))
64149

65150

66151
def _file_id_to_url(file_id, filename=None):
@@ -75,7 +160,7 @@ def _file_id_to_url(file_id, filename=None):
75160
return url
76161

77162

78-
def _read_url_files(url, data=None, file_elements=None):
163+
def __read_url_files(url, data=None, file_elements=None):
79164
"""do a post request to url with data
80165
and sending file_elements as files"""
81166

@@ -85,37 +170,24 @@ def _read_url_files(url, data=None, file_elements=None):
85170
file_elements = {}
86171
# Using requests.post sets header 'Accept-encoding' automatically to
87172
# 'gzip,deflate'
88-
response = send_request(
173+
response = __send_request(
89174
request_method='post',
90175
url=url,
91176
data=data,
92177
files=file_elements,
93178
)
94-
if response.status_code != 200:
95-
raise _parse_server_exception(response, url, file_elements=file_elements)
96-
if 'Content-Encoding' not in response.headers or \
97-
response.headers['Content-Encoding'] != 'gzip':
98-
warnings.warn('Received uncompressed content from OpenML for {}.'
99-
.format(url))
100-
return response.text
179+
return response
101180

102181

103-
def _read_url(url, request_method, data=None):
182+
def __read_url(url, request_method, data=None):
    """ Perform an HTTP request against `url`, attaching the configured API key.

    Parameters
    ----------
    url : str
        The URL to request.
    request_method : str
        HTTP method passed through to ``__send_request`` (e.g. 'get', 'post').
    data : dict, optional (default=None)
        Request payload. The caller's dict is not modified.

    Returns
    -------
    requests.Response
        The raw response as returned by ``__send_request``.
    """
    # Copy the payload so the caller's dict is not mutated (and the API key
    # does not leak back into the caller's data structure).
    data = {} if data is None else dict(data)
    if config.apikey is not None:
        data['api_key'] = config.apikey

    return __send_request(request_method=request_method, url=url, data=data)
116188

117189

118-
def send_request(
190+
def __send_request(
119191
request_method,
120192
url,
121193
data,
@@ -149,16 +221,19 @@ def send_request(
149221
return response
150222

151223

152-
def _parse_server_exception(
224+
def __parse_server_exception(
153225
response: requests.Response,
154226
url: str,
155227
file_elements: Dict,
156228
) -> OpenMLServerError:
157-
# OpenML has a sophisticated error system
158-
# where information about failures is provided. try to parse this
229+
230+
if response.status_code == 414:
231+
raise OpenMLServerError('URI too long! ({})'.format(url))
159232
try:
160233
server_exception = xmltodict.parse(response.text)
161234
except Exception:
235+
# OpenML has a sophisticated error system
236+
# where information about failures is provided. try to parse this
162237
raise OpenMLServerError(
163238
'Unexpected server error when calling {}. Please contact the developers!\n'
164239
'Status code: {}\n{}'.format(url, response.status_code, response.text))

openml/datasets/functions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,7 @@ def _get_dataset_arff(description: Union[Dict, OpenMLDataset],
886886
output_file_path = os.path.join(cache_directory, "dataset.arff")
887887

888888
try:
889-
openml.utils._download_text_file(
889+
openml._api_calls._download_text_file(
890890
source=url,
891891
output_path=output_file_path,
892892
md5_checksum=md5_checksum_fixture
@@ -1038,13 +1038,11 @@ def _get_online_dataset_arff(dataset_id):
10381038
str
10391039
A string representation of an ARFF file.
10401040
"""
1041-
dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id,
1042-
'get')
1041+
dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id, 'get')
10431042
# build a dict from the xml.
10441043
# use the url from the dataset description and return the ARFF string
1045-
return openml._api_calls._read_url(
1044+
return openml._api_calls._download_text_file(
10461045
xmltodict.parse(dataset_xml)['oml:data_set_description']['oml:url'],
1047-
request_method='get'
10481046
)
10491047

10501048

openml/runs/run.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,7 @@ def get_metric_fn(self, sklearn_fn, kwargs=None):
327327
predictions_file_url = openml._api_calls._file_id_to_url(
328328
self.output_files['predictions'], 'predictions.arff',
329329
)
330-
response = openml._api_calls._read_url(predictions_file_url,
331-
request_method='get')
330+
response = openml._api_calls._download_text_file(predictions_file_url)
332331
predictions_arff = arff.loads(response)
333332
# TODO: make this a stream reader
334333
else:

openml/tasks/task.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,10 @@ def _download_split(self, cache_file: str):
116116
pass
117117
except (OSError, IOError):
118118
split_url = self.estimation_procedure["data_splits_url"]
119-
split_arff = openml._api_calls._read_url(split_url,
120-
request_method='get')
121-
122-
with io.open(cache_file, "w", encoding='utf8') as fh:
123-
fh.write(split_arff)
124-
del split_arff
119+
openml._api_calls._download_text_file(
120+
source=str(split_url),
121+
output_path=cache_file,
122+
)
125123

126124
def download_split(self) -> OpenMLSplit:
127125
"""Download the OpenML split for a given task.

openml/utils.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# License: BSD 3-Clause
22

33
import os
4-
import hashlib
54
import xmltodict
65
import shutil
76
from typing import TYPE_CHECKING, List, Tuple, Union, Type
@@ -366,53 +365,3 @@ def _create_lockfiles_dir():
366365
except OSError:
367366
pass
368367
return dir
369-
370-
371-
def _download_text_file(source: str,
372-
output_path: str,
373-
md5_checksum: str = None,
374-
exists_ok: bool = True,
375-
encoding: str = 'utf8',
376-
) -> None:
377-
""" Download the text file at `source` and store it in `output_path`.
378-
379-
By default, do nothing if a file already exists in `output_path`.
380-
The downloaded file can be checked against an expected md5 checksum.
381-
382-
Parameters
383-
----------
384-
source : str
385-
url of the file to be downloaded
386-
output_path : str
387-
full path, including filename, of where the file should be stored.
388-
md5_checksum : str, optional (default=None)
389-
If not None, should be a string of hexidecimal digits of the expected digest value.
390-
exists_ok : bool, optional (default=True)
391-
If False, raise an FileExistsError if there already exists a file at `output_path`.
392-
encoding : str, optional (default='utf8')
393-
The encoding with which the file should be stored.
394-
"""
395-
try:
396-
with open(output_path, encoding=encoding):
397-
if exists_ok:
398-
return
399-
else:
400-
raise FileExistsError
401-
except FileNotFoundError:
402-
pass
403-
404-
downloaded_file = openml._api_calls._read_url(source, request_method='get')
405-
406-
if md5_checksum is not None:
407-
md5 = hashlib.md5()
408-
md5.update(downloaded_file.encode('utf-8'))
409-
md5_checksum_download = md5.hexdigest()
410-
if md5_checksum != md5_checksum_download:
411-
raise openml.exceptions.OpenMLHashException(
412-
'Checksum {} of downloaded file is unequal to the expected checksum {}.'
413-
.format(md5_checksum_download, md5_checksum))
414-
415-
with open(output_path, "w", encoding=encoding) as fh:
416-
fh.write(downloaded_file)
417-
418-
del downloaded_file
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import openml
2+
import openml.testing
3+
4+
5+
class TestConfig(openml.testing.TestBase):
    """Regression tests for API-call error handling (see _api_calls.py)."""

    def test_too_long_uri(self):
        # Requesting 10000 dataset ids at once presumably produces a GET URI
        # longer than the server accepts (HTTP 414) -- TODO confirm against the
        # live test server. The client should surface this as an
        # OpenMLServerError whose message contains 'URI too long!'.
        with self.assertRaisesRegex(
            openml.exceptions.OpenMLServerError,
            'URI too long!',
        ):
            openml.datasets.list_datasets(data_id=list(range(10000)))

tests/test_runs/test_run_functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
119119
# downloads the predictions of the old task
120120
file_id = run.output_files['predictions']
121121
predictions_url = openml._api_calls._file_id_to_url(file_id)
122-
response = openml._api_calls._read_url(predictions_url,
123-
request_method='get')
122+
response = openml._api_calls._download_text_file(predictions_url)
124123
predictions = arff.loads(response)
125124
run_prime = openml.runs.run_model_on_task(
126125
model=model_prime,

tests/test_utils/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class OpenMLTaskTest(TestBase):
1616
def mocked_perform_api_call(call, request_method):
1717
# TODO: JvR: Why is this not a staticmethod?
1818
url = openml.config.server + '/' + call
19-
return openml._api_calls._read_url(url, request_method=request_method)
19+
return openml._api_calls._download_text_file(url)
2020

2121
def test_list_all(self):
2222
openml.utils._list_all(listing_call=openml.tasks.functions._list_tasks)

0 commit comments

Comments
 (0)