Commit e462037

fix(bigquery): use pyarrow fallback in schema autodetect
1 parent 7c9c0cb commit e462037

3 files changed: +126 -9 lines changed

bigquery/google/cloud/bigquery/_pandas_helpers.py

Lines changed: 51 additions & 4 deletions
@@ -110,8 +110,13 @@ def pyarrow_timestamp():
         "TIME": pyarrow_time,
         "TIMESTAMP": pyarrow_timestamp,
     }
+    ARROW_SCALARS_TO_BQ = {
+        arrow_type(): bq_type  # instantiate: DataType instances hash by value
+        for bq_type, arrow_type in BQ_TO_ARROW_SCALARS.items()
+    }
 else:  # pragma: NO COVER
     BQ_TO_ARROW_SCALARS = {}  # pragma: NO COVER
+    ARROW_SCALARS_TO_BQ = {}  # pragma: NO COVER
 
 
 def bq_to_arrow_struct_data_type(field):
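A note on the reverse mapping above: the comprehension calls arrow_type() because BQ_TO_ARROW_SCALARS maps BigQuery type names to pyarrow type constructors, while later lookups use DataType instances read back from a generated Arrow table. pyarrow DataType objects hash and compare by value, so the instantiated keys match. A minimal sketch of the idea (the arrow_to_bq map here is illustrative, not the module's full table):

import pyarrow

# DataType instances compare (and hash) by value, not identity:
assert pyarrow.int64() == pyarrow.int64()

# Illustrative reverse map in the spirit of ARROW_SCALARS_TO_BQ:
arrow_to_bq = {pyarrow.int64(): "INTEGER", pyarrow.string(): "STRING"}

table = pyarrow.Table.from_pydict({"id": [1, 2], "name": ["a", "b"]})
for field in table.schema:
    print(field.name, arrow_to_bq.get(field.type))  # -> id INTEGER, name STRING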
@@ -140,10 +145,11 @@ def bq_to_arrow_data_type(field):
             return pyarrow.list_(inner_type)
         return None
 
-    if field.field_type.upper() in schema._STRUCT_TYPES:
+    field_type_upper = field.field_type.upper() if field.field_type else ""
+    if field_type_upper in schema._STRUCT_TYPES:
         return bq_to_arrow_struct_data_type(field)
 
-    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper())
+    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field_type_upper)
     if data_type_constructor is None:
         return None
     return data_type_constructor()
@@ -180,9 +186,12 @@ def bq_to_arrow_schema(bq_schema):
 
 def bq_to_arrow_array(series, bq_field):
     arrow_type = bq_to_arrow_data_type(bq_field)
+
+    field_type_upper = bq_field.field_type.upper() if bq_field.field_type else ""
+
     if bq_field.mode.upper() == "REPEATED":
         return pyarrow.ListArray.from_pandas(series, type=arrow_type)
-    if bq_field.field_type.upper() in schema._STRUCT_TYPES:
+    if field_type_upper in schema._STRUCT_TYPES:
         return pyarrow.StructArray.from_pandas(series, type=arrow_type)
     return pyarrow.array(series, type=arrow_type)
 
@@ -273,7 +282,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
         bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
         if not bq_type:
             warnings.warn(u"Unable to determine type of column '{}'.".format(column))
-            return None
+
         bq_field = schema.SchemaField(column, bq_type)
         bq_schema_out.append(bq_field)
 
@@ -285,6 +294,44 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
                 bq_schema_unused
             )
         )
+
+    # If schema detection was not successful for all columns, also try with
+    # pyarrow, if available.
+    if any(field.field_type is None for field in bq_schema_out):
+        if not pyarrow:
+            return None  # We cannot detect the schema in full.
+
+        arrow_table = dataframe_to_arrow(dataframe, bq_schema_out)
+        arrow_schema_index = {field.name: field.type for field in arrow_table.schema}
+
+        curated_schema = []
+        for schema_field in bq_schema_out:
+            if schema_field.field_type is not None:
+                curated_schema.append(schema_field)
+                continue
+
+            detected_type = ARROW_SCALARS_TO_BQ.get(
+                arrow_schema_index.get(schema_field.name)
+            )
+            if detected_type is None:
+                warnings.warn(
+                    u"Pyarrow could not determine the type of column '{}'.".format(
+                        schema_field.name
+                    )
+                )
+                return None
+
+            new_field = schema.SchemaField(
+                name=schema_field.name,
+                field_type=detected_type,
+                mode=schema_field.mode,
+                description=schema_field.description,
+                fields=schema_field.fields,
+            )
+            curated_schema.append(new_field)
+
+        bq_schema_out = curated_schema
+
     return tuple(bq_schema_out)
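The net effect: columns whose pandas dtype is "object" (Python dates, for example) no longer abort schema autodetection outright; the partially detected schema is handed to pyarrow and the gaps are filled in from the resulting Arrow types. A hedged usage sketch (_pandas_helpers is a private module, so this import path is an assumption matching this commit's layout):

import datetime

import pandas

from google.cloud.bigquery import _pandas_helpers

frame = pandas.DataFrame(
    {
        "id": [1, 2],  # int64 dtype: detected directly as INTEGER
        # object dtype: only the pyarrow fallback can resolve this to DATE
        "when": [datetime.date(2019, 5, 10), datetime.date(2018, 9, 12)],
    }
)

# Per the unit tests below, this should yield INTEGER and DATE fields.
print(_pandas_helpers.dataframe_to_bq_schema(frame, bq_schema=[]))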
bigquery/tests/unit/test__pandas_helpers.py

Lines changed: 66 additions & 0 deletions
@@ -905,3 +905,69 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
     call_args = fake_write_table.call_args
     assert call_args is not None
     assert call_args.kwargs.get("compression") == "ZSTD"
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
+    dataframe = pandas.DataFrame(
+        data=[
+            {"id": 10, "status": "FOO", "execution_date": datetime.date(2019, 5, 10)},
+            {"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
+        ]
+    )
+
+    no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)
+
+    with no_pyarrow_patch:
+        detected_schema = module_under_test.dataframe_to_bq_schema(
+            dataframe, bq_schema=[]
+        )
+
+    assert detected_schema is None
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
+    dataframe = pandas.DataFrame(
+        data=[
+            {"id": 10, "status": "FOO", "created_at": datetime.date(2019, 5, 10)},
+            {"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
+        ]
+    )
+
+    detected_schema = module_under_test.dataframe_to_bq_schema(dataframe, bq_schema=[])
+    expected_schema = (
+        schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
+        schema.SchemaField("status", "STRING", mode="NULLABLE"),
+        schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
+    )
+    assert detected_schema == expected_schema
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
+    dataframe = pandas.DataFrame(
+        data=[
+            {"id": 10, "status": "FOO", "all_items": [10.1, 10.2]},
+            {"id": 20, "status": "BAR", "all_items": [20.1, 20.2]},
+        ]
+    )
+
+    with warnings.catch_warnings(record=True) as warned:
+        detected_schema = module_under_test.dataframe_to_bq_schema(
+            dataframe, bq_schema=[]
+        )
+
+    assert detected_schema is None
+
+    expected_warnings = []
+    for warning in warned:
+        if "Pyarrow could not" in str(warning):
+            expected_warnings.append(warning)
+
+    assert len(expected_warnings) == 1
+    warning_msg = str(expected_warnings[0])
+    assert "all_items" in warning_msg
+    assert "could not determine the type" in warning_msg

bigquery/tests/unit/test_client.py

Lines changed: 9 additions & 5 deletions
@@ -5704,8 +5704,7 @@ def test_load_table_from_dataframe_unknown_table(self):
         )
 
     @unittest.skipIf(pandas is None, "Requires `pandas`")
-    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
-    def test_load_table_from_dataframe_no_schema_warning(self):
+    def test_load_table_from_dataframe_no_schema_warning_wo_pyarrow(self):
         client = self._make_client()
 
         # Pick at least one column type that translates to Pandas dtype
@@ -5722,9 +5721,12 @@ def test_load_table_from_dataframe_no_schema_warning(self):
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
         pyarrow_patch = mock.patch("google.cloud.bigquery.client.pyarrow", None)
+        pyarrow_patch_helpers = mock.patch(
+            "google.cloud.bigquery._pandas_helpers.pyarrow", None
+        )
         catch_warnings = warnings.catch_warnings(record=True)
 
-        with get_table_patch, load_patch, pyarrow_patch, catch_warnings as warned:
+        with get_table_patch, load_patch, pyarrow_patch, pyarrow_patch_helpers, catch_warnings as warned:
             client.load_table_from_dataframe(
                 dataframe, self.TABLE_REF, location=self.LOCATION
             )
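Both patches are needed because each module binds pyarrow into its own namespace at import time; patching google.cloud.bigquery.client.pyarrow leaves the reference held by _pandas_helpers intact. A minimal sketch of why module-level names must be patched independently (assumes pyarrow is actually installed):

from unittest import mock

import google.cloud.bigquery._pandas_helpers as helpers
import google.cloud.bigquery.client as client_mod

# Each module resolved `import pyarrow` into its own globals, so one
# patch does not affect the other module's reference.
with mock.patch.object(client_mod, "pyarrow", None):
    assert client_mod.pyarrow is None
    assert helpers.pyarrow is not None  # still the real module here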
@@ -5892,7 +5894,6 @@ def test_load_table_from_dataframe_w_partial_schema_extra_types(self):
         assert "unknown_col" in message
 
     @unittest.skipIf(pandas is None, "Requires `pandas`")
-    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_w_partial_schema_missing_types(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
@@ -5909,10 +5910,13 @@ def test_load_table_from_dataframe_w_partial_schema_missing_types(self):
         load_patch = mock.patch(
            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
+        pyarrow_patch = mock.patch(
+            "google.cloud.bigquery._pandas_helpers.pyarrow", None
+        )
 
         schema = (SchemaField("string_col", "STRING"),)
         job_config = job.LoadJobConfig(schema=schema)
-        with load_patch as load_table_from_file, warnings.catch_warnings(
+        with pyarrow_patch, load_patch as load_table_from_file, warnings.catch_warnings(
             record=True
         ) as warned:
             client.load_table_from_dataframe(
