diff --git a/CHANGELOG.md b/CHANGELOG.md index a8fce97e0..cbdff8010 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## dbt-databricks 1.0.4 (Release TBD) +### Fixes +- Block taking jinja2.runtime.Undefined into DatabricksAdapter ([#98](https://github.com/databricks/dbt-databricks/pull/98)) + ## dbt-databricks 1.0.3 (April 26, 2022) ### Fixes diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 7f273dceb..fd9c886af 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -203,8 +203,6 @@ def _execute_cursor( return self.get_result_from_cursor(cursor) def list_schemas(self, database: Optional[str], schema: Optional[str] = None) -> Table: - database = database if isinstance(database, str) else None - schema = schema if isinstance(schema, str) else None return self._execute_cursor( f"GetSchemas(database={database}, schema={schema})", lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema), diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index df8a2dff1..8b5fbee06 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -2,12 +2,13 @@ from typing import Optional, List, Dict, Union from dbt.adapters.base import AdapterConfig -from dbt.adapters.databricks import DatabricksConnectionManager -from dbt.adapters.databricks.relation import DatabricksRelation -from dbt.adapters.databricks.column import DatabricksColumn - from dbt.adapters.spark.impl import SparkAdapter +from dbt.adapters.databricks.column import DatabricksColumn +from dbt.adapters.databricks.connections import DatabricksConnectionManager +from dbt.adapters.databricks.relation import DatabricksRelation +from dbt.adapters.databricks.utils import undefined_proof + @dataclass class DatabricksConfig(AdapterConfig): @@ -21,6 +22,7 @@ class DatabricksConfig(AdapterConfig): tblproperties: Optional[Dict[str, str]] = None +@undefined_proof class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 5c4e6b961..118390f8e 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,12 +1,23 @@ from dataclasses import dataclass +from typing import Any, Dict +from dbt.adapters.spark.relation import SparkRelation from dbt.exceptions import RuntimeException -from dbt.adapters.spark.relation import SparkRelation +from dbt.adapters.databricks.utils import remove_undefined @dataclass(frozen=True, eq=False, repr=False) class DatabricksRelation(SparkRelation): + @classmethod + def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]: + data = super().__pre_deserialize__(data) + if "database" not in data["path"]: + data["path"]["database"] = None + else: + data["path"]["database"] = remove_undefined(data["path"]["database"]) + return data + def __post_init__(self) -> None: if self.database != self.schema and self.database: raise RuntimeException("Cannot set database in Databricks!") diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py new file mode 100644 index 000000000..e46397a63 --- /dev/null +++ b/dbt/adapters/databricks/utils.py @@ -0,0 +1,51 @@ +import functools +import inspect +from typing import Any, Callable, Type, TypeVar + +from dbt.adapters.base import BaseAdapter +from jinja2.runtime import Undefined + + +A = TypeVar("A", bound=BaseAdapter) + + +def remove_undefined(v: Any) -> Any: + return None if isinstance(v, Undefined) else v + + +def undefined_proof(cls: Type[A]) -> Type[A]: + for name in cls._available_: + func = getattr(cls, name) + if not callable(func): + continue + try: + static_attr = inspect.getattr_static(cls, name) + isstatic = isinstance(static_attr, staticmethod) + isclass = isinstance(static_attr, classmethod) + except AttributeError: + isstatic = False + isclass = False + wrapped_function = _wrap_function(func.__func__ if isclass else func) + setattr( + cls, + name, + ( + staticmethod(wrapped_function) + if isstatic + else classmethod(wrapped_function) + if isclass + else wrapped_function + ), + ) + + return cls + + +def _wrap_function(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + new_args = [remove_undefined(arg) for arg in args] + new_kwargs = {key: remove_undefined(value) for key, value in kwargs.items()} + return func(*new_args, **new_kwargs) + + return wrapper diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py new file mode 100644 index 000000000..ff7f380f9 --- /dev/null +++ b/tests/unit/test_relation.py @@ -0,0 +1,67 @@ +import unittest + +from jinja2.runtime import Undefined + +from dbt.adapters.databricks.relation import DatabricksRelation + + +class TestDatabricksRelation(unittest.TestCase): + def test_pre_deserialize(self): + data = { + "quote_policy": {"database": False, "schema": False, "identifier": False}, + "path": { + "database": "some_schema", + "schema": "some_schema", + "identifier": "some_table", + }, + "type": None, + } + + relation = DatabricksRelation.from_dict(data) + self.assertEqual(relation.database, "some_schema") + self.assertEqual(relation.schema, "some_schema") + self.assertEqual(relation.identifier, "some_table") + + data = { + "quote_policy": {"database": False, "schema": False, "identifier": False}, + "path": { + "database": None, + "schema": "some_schema", + "identifier": "some_table", + }, + "type": None, + } + + relation = DatabricksRelation.from_dict(data) + self.assertIsNone(relation.database) + self.assertEqual(relation.schema, "some_schema") + self.assertEqual(relation.identifier, "some_table") + + data = { + "quote_policy": {"database": False, "schema": False, "identifier": False}, + "path": { + "schema": "some_schema", + "identifier": "some_table", + }, + "type": None, + } + + relation = DatabricksRelation.from_dict(data) + self.assertIsNone(relation.database) + self.assertEqual(relation.schema, "some_schema") + self.assertEqual(relation.identifier, "some_table") + + data = { + "quote_policy": {"database": False, "schema": False, "identifier": False}, + "path": { + "database": Undefined(), + "schema": "some_schema", + "identifier": "some_table", + }, + "type": None, + } + + relation = DatabricksRelation.from_dict(data) + self.assertIsNone(relation.database) + self.assertEqual(relation.schema, "some_schema") + self.assertEqual(relation.identifier, "some_table")