-
Notifications
You must be signed in to change notification settings - Fork 140
feat: add JSON type, bindparam support #1147
Changes from 1 commit
140ca3a
d0fd734
65d676d
b168080
2a03b31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| import sqlalchemy | ||
|
|
||
|
|
||
| class JSON(sqlalchemy.sql.sqltypes.JSON): | ||
| def bind_expression(self, bindvalue): | ||
| # JSON query parameters have type STRING | ||
| # This hook ensures that the rendered expression has type JSON | ||
| return sqlalchemy.func.PARSE_JSON(bindvalue, type_=self) | ||
|
Comment on lines
+5
to
+8
Author
Note: An alternative here would be to just preserve the STRING type and let BigQuery handle the cast, but this seems hard to reason about. It feels right to have the expression type in BigQuery match the expression type here in SQLAlchemy.
Author
Looking at [1] and pondering this - it wouldn't be that surprising if a user wanted to submit invalid JSON as a query parameter at some point. This could be supported by parameterizing this behavior in the type, i.e.: |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,7 +59,7 @@ | |
| import re | ||
|
|
||
| from .parse_url import parse_url | ||
| from . import _helpers, _struct, _types | ||
| from . import _helpers, _json, _struct, _types | ||
| import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql | ||
|
|
||
| # Illegal characters is intended to be all characters that are not explicitly | ||
|
|
@@ -547,6 +547,13 @@ def visit_bindparam( | |
| bq_type = self.dialect.type_compiler.process(type_) | ||
| bq_type = self.__remove_type_parameter(bq_type) | ||
|
|
||
| if bq_type == "JSON": | ||
| # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI | ||
| # For now, we hack around this by: | ||
| # - Rewriting the bindparam type to STRING | ||
| # - Applying a bind expression that converts the parameter back to JSON | ||
| bq_type = "STRING" | ||
|
r1b marked this conversation as resolved.
Outdated
|
||
|
|
||
| assert_(param != "%s", f"Unexpected param: {param}") | ||
|
|
||
| if bindparam.expanding: # pragma: NO COVER | ||
|
|
@@ -641,6 +648,9 @@ def visit_NUMERIC(self, type_, **kw): | |
|
|
||
| visit_DECIMAL = visit_NUMERIC | ||
|
|
||
| def visit_JSON(self, type_, **kw): | ||
| return "JSON" | ||
|
|
||
|
|
||
| class BigQueryDDLCompiler(DDLCompiler): | ||
| option_datatype_mapping = { | ||
|
|
@@ -1076,6 +1086,7 @@ class BigQueryDialect(DefaultDialect): | |
| sqlalchemy.sql.sqltypes.TIMESTAMP: BQTimestamp, | ||
| sqlalchemy.sql.sqltypes.ARRAY: BQArray, | ||
| sqlalchemy.sql.sqltypes.Enum: sqlalchemy.sql.sqltypes.Enum, | ||
| sqlalchemy.sql.sqltypes.JSON: _json.JSON, | ||
| } | ||
|
|
||
| def __init__( | ||
|
|
@@ -1086,6 +1097,8 @@ def __init__( | |
| credentials_info=None, | ||
| credentials_base64=None, | ||
| list_tables_page_size=1000, | ||
| json_serializer=None, | ||
| json_deserializer=None, | ||
|
Comment on lines
+1101
to
+1102
Author
Note: See the "Customizing the JSON Serializer" section of https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.JSON |
||
| *args, | ||
| **kwargs, | ||
| ): | ||
|
|
@@ -1098,6 +1111,8 @@ def __init__( | |
| self.identifier_preparer = self.preparer(self) | ||
| self.dataset_id = None | ||
| self.list_tables_page_size = list_tables_page_size | ||
| self._json_serializer = json_serializer | ||
| self._json_deserializer = json_deserializer | ||
|
|
||
| @classmethod | ||
| def dbapi(cls): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import json | ||
| from unittest import mock | ||
|
|
||
| import pytest | ||
| import sqlalchemy | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def json_table(metadata): | ||
| from sqlalchemy_bigquery import JSON | ||
|
|
||
| return sqlalchemy.Table("json_table", metadata, sqlalchemy.Column("json", JSON)) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def json_data(): | ||
| return {"foo": "bar"} | ||
|
|
||
|
|
||
| def test_set_json_serde(faux_conn, metadata, json_table, json_data): | ||
| from sqlalchemy_bigquery import JSON | ||
|
|
||
| json_serializer = mock.Mock(side_effect=json.dumps) | ||
| json_deserializer = mock.Mock(side_effect=json.loads) | ||
|
|
||
| engine = sqlalchemy.create_engine( | ||
| "bigquery://myproject/mydataset", | ||
| json_serializer=json_serializer, | ||
| json_deserializer=json_deserializer, | ||
| ) | ||
|
|
||
| json_column = json_table.c.json | ||
|
|
||
| process_bind = json_column.type.bind_processor(engine.dialect) | ||
| process_bind(json_data) | ||
| assert json_serializer.mock_calls == [mock.call(json_data)] | ||
|
|
||
| process_result = json_column.type.result_processor(engine.dialect, JSON) | ||
| process_result(json.dumps(json_data)) | ||
| assert json_deserializer.mock_calls == [mock.call(json.dumps(json_data))] | ||
|
|
||
|
|
||
| def test_json_create(faux_conn, metadata, json_table, json_data): | ||
| expr = sqlalchemy.schema.CreateTable(json_table) | ||
| sql = expr.compile(faux_conn.engine).string | ||
| assert sql == ("\nCREATE TABLE `json_table` (\n" "\t`json` JSON\n" ") \n\n") | ||
|
r1b marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| def test_json_insert(faux_conn, metadata, json_table, json_data): | ||
| expr = sqlalchemy.insert(json_table).values(json=json_data) | ||
| sql = expr.compile(faux_conn.engine).string | ||
| assert ( | ||
| sql == "INSERT INTO `json_table` (`json`) VALUES (PARSE_JSON(%(json:STRING)s))" | ||
| ) | ||
|
|
||
|
|
||
| def test_json_where(faux_conn, metadata, json_table, json_data): | ||
| expr = sqlalchemy.select(json_table.c.json).where(json_table.c.json == json_data) | ||
| sql = expr.compile(faux_conn.engine).string | ||
| assert sql == ( | ||
| "SELECT `json_table`.`json` \n" | ||
| "FROM `json_table` \n" | ||
| "WHERE `json_table`.`json` = PARSE_JSON(%(json_1:STRING)s)" | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.