diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 4d8990c44..e4affbde6 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -22,7 +22,7 @@ from dbt.adapters.base import AdapterConfig, PythonJobHelper from dbt.adapters.base.impl import catch_as_completed from dbt.adapters.base.meta import available -from dbt.adapters.base.relation import BaseRelation +from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.capability import CapabilityDict, CapabilitySupport, Support, Capability from dbt.adapters.spark.impl import ( SparkAdapter, @@ -59,7 +59,7 @@ CURRENT_CATALOG_MACRO_NAME = "current_catalog" USE_CATALOG_MACRO_NAME = "use_catalog" - +GET_CATALOG_MACRO_NAME = "get_catalog" SHOW_TABLE_EXTENDED_MACRO_NAME = "show_table_extended" SHOW_TABLES_MACRO_NAME = "show_tables" SHOW_VIEWS_MACRO_NAME = "show_views" @@ -449,16 +449,54 @@ def parse_columns_from_information( # type: ignore[override] columns.append(column) return columns - def get_catalog( - self, manifest: Manifest, selected_nodes: Optional[Set[Any]] = None - ) -> Tuple[Table, List[Exception]]: - if selected_nodes: - relations: Set[BaseRelation] = { - self.Relation.create_from(self.config, n) for n in selected_nodes - } - else: - relations = set(self._get_catalog_relations(manifest)) - return self.get_catalog_by_relations(manifest, relations) + def get_catalog(self, manifest: Manifest) -> Tuple[Table, List[Exception]]: # type: ignore + schema_map = self._get_catalog_schemas(manifest) + + with executor(self.config) as tpe: + futures: List[Future[Table]] = [] + for info, schemas in schema_map.items(): + if is_hive_metastore(info.database): + for schema in schemas: + futures.append( + tpe.submit_connected( + self, + "hive_metastore", + self._get_hive_catalog, + schema, + "*", + ) + ) + else: + name = ".".join([str(info.database), "information_schema"]) + fut = tpe.submit_connected( + self, + name, + self._get_one_unity_catalog, + info, + schemas, + manifest, + ) + futures.append(fut) + catalogs, exceptions = catch_as_completed(futures) + return catalogs, exceptions + + def _get_one_unity_catalog( + self, info: InformationSchema, schemas: Set[str], manifest: Manifest + ) -> Table: + kwargs = { + "information_schema": info, + "schemas": schemas, + } + table = self.execute_macro( + GET_CATALOG_MACRO_NAME, + kwargs=kwargs, + # pass in the full manifest, so we get any local project + # overrides + manifest=manifest, + ) + + results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type] + return results def get_catalog_by_relations( self, manifest: Manifest, relations: Set[BaseRelation] @@ -475,13 +513,14 @@ def get_catalog_by_relations( schema_map[relation.schema].append(relation) for schema, schema_relations in schema_map.items(): + table_names = extract_identifiers(schema_relations) futures.append( tpe.submit_connected( self, "hive_metastore", self._get_hive_catalog, schema, - schema_relations, + get_identifier_list_string(table_names), ) ) else: @@ -502,16 +541,15 @@ def get_catalog_by_relations( def _get_hive_catalog( self, schema: str, - relations: Set[BaseRelation], + identifier: str, ) -> Table: - table_names = extract_identifiers(relations) columns: List[Dict[str, Any]] = [] - if len(table_names) > 0: + if identifier: schema_relation = self.Relation.create( database="hive_metastore", schema=schema, - identifier=get_identifier_list_string(table_names), + identifier=identifier, quote_policy=self.config.quoting, ) for relation, information in self._list_relations_with_information(schema_relation): diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 29dd60634..07046f98b 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Set, Type +from typing import Any, Dict, Iterable, Optional, Set, Type from dbt.contracts.relation import ( ComponentName, ) @@ -141,5 +141,5 @@ def is_hive_metastore(database: Optional[str]) -> bool: return database is None or database.lower() == "hive_metastore" -def extract_identifiers(relations: Set[BaseRelation]) -> Set[str]: +def extract_identifiers(relations: Iterable[BaseRelation]) -> Set[str]: return {r.identifier for r in relations if r.identifier is not None} diff --git a/dbt/include/databricks/macros/catalog.sql b/dbt/include/databricks/macros/catalog.sql index 0749bca95..144122466 100644 --- a/dbt/include/databricks/macros/catalog.sql +++ b/dbt/include/databricks/macros/catalog.sql @@ -19,6 +19,27 @@ {% endcall %} {% endmacro %} +{% macro get_catalog(information_schema, schemas) -%} + {{ return(adapter.dispatch('get_catalog', 'dbt')(information_schema, schemas)) }} +{% endmacro %} + +{% macro databricks__get_catalog(information_schema, schemas) -%} + + {% set query %} + with tables as ( + {{ databricks__get_catalog_tables_sql(information_schema) }} + {{ databricks__get_catalog_schemas_where_clause_sql(schemas) }} + ), + columns as ( + {{ databricks__get_catalog_columns_sql(information_schema) }} + {{ databricks__get_catalog_schemas_where_clause_sql(schemas) }} + ) + {{ databricks__get_catalog_results_sql() }} + {%- endset -%} + + {{ return(run_query(query)) }} +{%- endmacro %} + {% macro databricks__get_catalog_relations(information_schema, relations) -%} {% set query %} @@ -72,7 +93,7 @@ {% macro databricks__get_catalog_schemas_where_clause_sql(schemas) -%} where ({%- for schema in schemas -%} - upper(table_schema) = upper('{{ schema }}'){%- if not loop.last %} or {% endif -%} + table_schema = lower('{{ schema }}'){%- if not loop.last %} or {% endif -%} {%- endfor -%}) {%- endmacro %} @@ -82,12 +103,12 @@ {%- for relation in relations -%} {% if relation.schema and relation.identifier %} ( - upper(table_schema) = upper('{{ relation.schema }}') - and upper(table_name) = upper('{{ relation.identifier }}') + table_schema = lower('{{ relation.schema }}') + and table_name = lower('{{ relation.identifier }}') ) {% elif relation.schema %} ( - upper(table_schema) = upper('{{ relation.schema }}') + table_schema = lower('{{ relation.schema }}') ) {% else %} {% do exceptions.raise_compiler_error(