Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
## dbt-databricks 1.1.1 (Release TBD)

### Features
- Support for Databricks CATALOG as a DATABASE in DBT compilations ([#95](https://github.com/databricks/dbt-databricks/issues/95), [#89](https://github.com/databricks/dbt-databricks/pull/89), [#94](https://github.com/databricks/dbt-databricks/pull/94))
- Support for Databricks CATALOG as a DATABASE in DBT compilations ([#95](https://github.com/databricks/dbt-databricks/issues/95), [#89](https://github.com/databricks/dbt-databricks/pull/89), [#94](https://github.com/databricks/dbt-databricks/pull/94), [#105](https://github.com/databricks/dbt-databricks/pull/105))
- Setting an initial catalog with `session_properties` is deprecated and will not work in the future release. Please use `catalog` or `database` to set the initial catalog.
- When using catalog, `spark_build_snapshot_staging_table` macro will not be used. If trying to override the macro, `databricks_build_snapshot_staging_table` should be overridden instead.

### Fixes
Expand Down
51 changes: 49 additions & 2 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from concurrent.futures import Future
from contextlib import contextmanager
from dataclasses import dataclass
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

from agate import Row, Table

from dbt.adapters.base import AdapterConfig
from dbt.adapters.base.impl import catch_as_completed
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.spark.impl import (
SparkAdapter,
Expand All @@ -14,9 +17,11 @@
LIST_SCHEMAS_MACRO_NAME,
)
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.relation import RelationType
import dbt.exceptions
from dbt.events import AdapterLogger
from dbt.utils import executor

from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import DatabricksConnectionManager
Expand All @@ -26,6 +31,9 @@

logger = AdapterLogger("Databricks")

CURRENT_CATALOG_MACRO_NAME = "current_catalog"
USE_CATALOG_MACRO_NAME = "use_catalog"


@dataclass
class DatabricksConfig(AdapterConfig):
Expand Down Expand Up @@ -86,7 +94,9 @@ def list_relations_without_caching(
) -> List[DatabricksRelation]:
kwargs = {"schema_relation": schema_relation}
try:
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
# The catalog for `show table extended` needs to match the current catalog.
with self._catalog(schema_relation.database):
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
except dbt.exceptions.RuntimeException as e:
errmsg = getattr(e, "msg", "")
if f"Database '{schema_relation}' not found" in errmsg:
Expand Down Expand Up @@ -176,6 +186,21 @@ def parse_columns_from_information(
columns.append(column)
return columns

def get_catalog(self, manifest: Manifest) -> Tuple[Table, List[Exception]]:
schema_map = self._get_catalog_schemas(manifest)

with executor(self.config) as tpe:
futures: List[Future[Table]] = []
for info, schemas in schema_map.items():
for schema in schemas:
futures.append(
tpe.submit_connected(
self, schema, self._get_one_catalog, info, [schema], manifest
)
)
catalogs, exceptions = catch_as_completed(futures)
return catalogs, exceptions

def _get_columns_for_catalog(self, relation: DatabricksRelation) -> Iterable[Dict[str, Any]]:
columns = self.parse_columns_from_information(relation)

Expand All @@ -185,3 +210,25 @@ def _get_columns_for_catalog(self, relation: DatabricksRelation) -> Iterable[Dic
as_dict["column_name"] = as_dict.pop("column", None)
as_dict["column_type"] = as_dict.pop("dtype")
yield as_dict

@contextmanager
def _catalog(self, catalog: Optional[str]) -> Iterator[None]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can add a comment here?

"""
A context manager to make the operation work in the specified catalog,
and move back to the current catalog after the operation.

If `catalog` is None, the operation works in the current catalog.
"""
current_catalog: Optional[str] = None
try:
if catalog is not None:
current_catalog = self.execute_macro(CURRENT_CATALOG_MACRO_NAME)[0][0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious is the result of current_catalog cached somewhere?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not cached.

if current_catalog is not None:
if current_catalog != catalog:
self.execute_macro(USE_CATALOG_MACRO_NAME, kwargs=dict(catalog=catalog))
else:
current_catalog = None
yield
finally:
if current_catalog is not None:
self.execute_macro(USE_CATALOG_MACRO_NAME, kwargs=dict(catalog=current_catalog))
20 changes: 20 additions & 0 deletions dbt/include/databricks/macros/catalog.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{% macro current_catalog() -%}
{{ return(adapter.dispatch('current_catalog', 'dbt')()) }}
{% endmacro %}

{% macro databricks__current_catalog() -%}
{% call statement('current_catalog', fetch_result=True) %}
select current_catalog()
{% endcall %}
{% do return(load_result('current_catalog').table) %}
{% endmacro %}

{% macro use_catalog(catalog) -%}
{{ return(adapter.dispatch('use_catalog', 'dbt')(catalog)) }}
{% endmacro %}

{% macro databricks__use_catalog(catalog) -%}
{% call statement() %}
use catalog {{ adapter.quote(catalog) }}
{% endcall %}
{% endmacro %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{{ config(
catalog = env_var('DBT_DATABRICKS_UC_ALTERNATIVE_CATALOG', 'alternative')
) }}

select * from {{ ref('seed') }}
7 changes: 7 additions & 0 deletions tests/integration/multi_catalog/models/cross_catalog.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
select
{{ ref('alternative_catalog') }}.id,
{{ ref('alternative_catalog') }}.name,
{{ ref('alternative_catalog') }}.date
from
{{ ref('alternative_catalog') }}
inner join {{ ref('refer_alternative_catalog')}} using (id)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select * from {{ ref('alternative_catalog') }}
3 changes: 3 additions & 0 deletions tests/integration/multi_catalog/seeds/seed.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
id,name,date
1,Alice,2022-01-01
2,Bob,2022-02-01
15 changes: 15 additions & 0 deletions tests/integration/multi_catalog/snapshots/snapshot.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{% snapshot my_snapshot %}

{{
config(
check_cols=["name", "date"],
unique_key="id",
strategy="check",
target_schema=schema,
target_database=env_var('DBT_DATABRICKS_UC_ALTERNATIVE_CATALOG', 'alternative'),
)
}}

select * from {{ ref('seed') }}

{% endsnapshot %}
100 changes: 100 additions & 0 deletions tests/integration/multi_catalog/test_multi_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os

from tests.integration.base import DBTIntegrationTest, use_profile


class TestMultiCatalog(DBTIntegrationTest):
setup_alternate_db = True

@property
def schema(self):
return "multi_catalog"

@property
def models(self):
return "models"

@property
def alternative_database(self):
return os.getenv("DBT_DATABRICKS_UC_ALTERNATIVE_CATALOG", "alternative")

@property
def project_config(self):
return {
"config-version": 2,
"models": {"materialized": "table"},
}

def test_multi_catalog_run(self, seed_catalog):
self.run_dbt(["seed"])

self.assertEqual(len(self.run_dbt(["run"])), 3)
self.assertEqual(len(self.run_dbt(["run"])), 3)

self.assertManyRelationsEqual(
[
("seed", self.unique_schema(), seed_catalog),
("alternative_catalog", self.unique_schema(), self.alternative_database),
("refer_alternative_catalog", self.unique_schema(), self.default_database),
("cross_catalog", self.unique_schema(), self.default_database),
]
)

self.run_dbt(["snapshot"])
self.run_dbt(["snapshot"])

results = self.run_sql(
"select * from {database_schema}.my_snapshot",
fetch="all",
kwargs=dict(database=self.alternative_database),
)
self.assertEqual(len(results), 2)

catalog = self.run_dbt(["docs", "generate"])
assert len(catalog.nodes) == 5


class TestMultiCatalogTableModels(TestMultiCatalog):
@use_profile("databricks_uc_cluster")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would the test behave if we run it against a non-uc cluster?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will just be skipped.

def test_multi_catalog_run_databricks_uc_cluster(self):
self.test_multi_catalog_run(self.default_database)

@use_profile("databricks_uc_sql_endpoint")
def test_multi_catalog_run_databricks_uc_sql_endpoint(self):
self.test_multi_catalog_run(self.default_database)


class TestMultiCatalogViewModels(TestMultiCatalog):
@property
def project_config(self):
return {
"config-version": 2,
"models": {"materialized": "view"},
}

@use_profile("databricks_uc_cluster")
def test_multi_catalog_run_databricks_uc_cluster(self):
self.test_multi_catalog_run(self.default_database)

@use_profile("databricks_uc_sql_endpoint")
def test_multi_catalog_run_databricks_uc_sql_endpoint(self):
self.test_multi_catalog_run(self.default_database)


class TestMultiCatalogSeedsInAlternativeCatalog(TestMultiCatalog):
@property
def project_config(self):
return {
"config-version": 2,
"seeds": {
"catalog": self.alternative_database,
},
}

@use_profile("databricks_uc_cluster")
def test_multi_catalog_run_databricks_uc_cluster(self):
self.test_multi_catalog_run(self.alternative_database)

@use_profile("databricks_uc_sql_endpoint")
def test_multi_catalog_run_databricks_uc_sql_endpoint(self):
self.test_multi_catalog_run(self.alternative_database)