From abdff7e7184cf7279d390ca572e1171018bd0fc5 Mon Sep 17 00:00:00 2001 From: Andre Furlan Date: Fri, 28 Apr 2023 15:18:10 -0700 Subject: [PATCH 1/2] OAuth support (#307) * oauth support --------- Signed-off-by: Andre Furlan --- .gitignore | 3 +- dbt/adapters/databricks/auth.py | 78 ++++++++++ dbt/adapters/databricks/connections.py | 133 +++++++++++++++++- dbt/adapters/databricks/python_submissions.py | 7 +- dbt/include/databricks/profile_template.yml | 8 +- dev-requirements.txt | 4 +- docs/oauth.md | 26 ++++ requirements.txt | 4 +- setup.py | 2 + tests/profiles.py | 3 + tests/unit/test_adapter.py | 81 ++++++++++- tests/unit/test_auth.py | 73 ++++++++++ 12 files changed, 408 insertions(+), 14 deletions(-) create mode 100644 dbt/adapters/databricks/auth.py create mode 100644 docs/oauth.md create mode 100644 tests/unit/test_auth.py diff --git a/.gitignore b/.gitignore index dc3c25ac5..8c1b1b2bc 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ test.env .vscode *.log logs/ -.venv \ No newline at end of file +.venv +.venv2 \ No newline at end of file diff --git a/dbt/adapters/databricks/auth.py b/dbt/adapters/databricks/auth.py new file mode 100644 index 000000000..48dce0e83 --- /dev/null +++ b/dbt/adapters/databricks/auth.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, Optional +from databricks.sdk.oauth import ClientCredentials, Token, TokenSource +from databricks.sdk.core import CredentialsProvider, HeaderFactory, Config, credentials_provider + + +class token_auth(CredentialsProvider): + _token: str + + def __init__(self, token: str) -> None: + self._token = token + + def auth_type(self) -> str: + return "token" + + def as_dict(self) -> dict: + return {"token": self._token} + + @staticmethod + def from_dict(raw: Optional[dict]) -> CredentialsProvider: + if not raw: + return None + return token_auth(raw["token"]) + + def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory: + static_credentials = {"Authorization": f"Bearer 
{self._token}"} + + def inner() -> Dict[str, str]: + return static_credentials + + return inner + + +class m2m_auth(CredentialsProvider): + _token_source: TokenSource = None + + def __init__(self, host: str, client_id: str, client_secret: str) -> None: + @credentials_provider("noop", []) + def noop_credentials(_: Any): # type: ignore + return lambda: {} + + config = Config(host=host, credentials_provider=noop_credentials) + oidc = config.oidc_endpoints + scopes = ["offline_access", "all-apis"] + if not oidc: + raise ValueError(f"{host} does not support OAuth") + if config.is_azure: + # Azure AD only supports full access to Azure Databricks. + scopes = [f"{config.effective_azure_login_app_id}/.default", "offline_access"] + self._token_source = ClientCredentials( + client_id=client_id, + client_secret=client_secret, + token_url=oidc.token_endpoint, + scopes=scopes, + use_header="microsoft" not in oidc.token_endpoint, + use_params="microsoft" in oidc.token_endpoint, + ) + + def auth_type(self) -> str: + return "oauth" + + def as_dict(self) -> dict: + if self._token_source: + return {"token": self._token_source.token().as_dict()} + else: + return {"token": {}} + + @staticmethod + def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> CredentialsProvider: + c = m2m_auth(host=host, client_id=client_id, client_secret=client_secret) + c._token_source._token = Token.from_dict(raw["token"]) + return c + + def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory: + def inner() -> Dict[str, str]: + token = self._token_source.token() + return {"Authorization": f"{token.token_type} {token.access_token}"} + + return inner diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 76a62aaff..ba936e694 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -6,6 +6,7 @@ import os import re import sys +import threading import time from typing import ( Any, @@ 
-49,6 +50,12 @@ from dbt.adapters.databricks.__version__ import version as __version__ from dbt.adapters.databricks.utils import redact_credentials +from databricks.sdk.core import CredentialsProvider +from databricks.sdk.oauth import OAuthClient, RefreshableCredentials +from dbt.adapters.databricks.auth import token_auth, m2m_auth + +import keyring + logger = AdapterLogger("Databricks") CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" @@ -58,6 +65,10 @@ EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" +REDIRECT_URL = "http://localhost:8020" +CLIENT_ID = "dbt-databricks" +SCOPES = ["all-apis", "offline_access"] + @dataclass class DatabricksCredentials(Credentials): @@ -65,13 +76,19 @@ class DatabricksCredentials(Credentials): host: Optional[str] = None http_path: Optional[str] = None token: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None session_properties: Optional[Dict[str, Any]] = None connection_parameters: Optional[Dict[str, Any]] = None + auth_type: Optional[str] = None connect_retries: int = 1 connect_timeout: Optional[int] = None retry_all: bool = False + _credentials_provider: Optional[Dict[str, Any]] = None + _lock = threading.Lock() # to avoid concurrent auth + _ALIASES = { "catalog": "database", "target_catalog": "target_database", @@ -116,6 +133,8 @@ def __post_init__(self) -> None: "server_hostname", "http_path", "access_token", + "client_id", + "client_secret", "session_configuration", "catalog", "schema", @@ -138,11 +157,23 @@ def __post_init__(self) -> None: self.connection_parameters = connection_parameters def validate_creds(self) -> None: - for key in ["host", "http_path", "token"]: + for key in ["host", "http_path"]: if not getattr(self, key): raise dbt.exceptions.DbtProfileError( "The config '{}' is required to connect to Databricks".format(key) ) + if not self.token and 
self.auth_type != "oauth": + raise dbt.exceptions.DbtProfileError( + ("The config `auth_type: oauth` is required when not using access token") + ) + + if not self.client_id and self.client_secret: + raise dbt.exceptions.DbtProfileError( + ( + "The config 'client_id' is required to connect " + "to Databricks when 'client_secret' is present" + ) + ) @classmethod def get_invocation_env(cls) -> Optional[str]: @@ -232,6 +263,100 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]: def cluster_id(self) -> Optional[str]: return self.extract_cluster_id(self.http_path) # type: ignore[arg-type] + def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider: + self.validate_creds() + host: str = self.host or "" + if self._credentials_provider: + return self._provider_from_dict() + if in_provider: + self._credentials_provider = in_provider.as_dict() + return in_provider + + # dbt will spin up multiple threads. This has to be sync. So lock here + self._lock.acquire() + try: + if self.token: + provider = token_auth(self.token) + self._credentials_provider = provider.as_dict() + return provider + + if self.client_id and self.client_secret: + provider = m2m_auth( + host=host, + client_id=self.client_id or "", + client_secret=self.client_secret or "", + ) + self._credentials_provider = provider.as_dict() + return provider + + oauth_client = OAuthClient( + host=host, + client_id=self.client_id if self.client_id else CLIENT_ID, + client_secret=None, + redirect_url=REDIRECT_URL, + scopes=SCOPES, + ) + # optional branch. 
Try and keep going if it does not work + try: + # try to get cached credentials + credsdict = keyring.get_password("dbt-databricks", host) + + if credsdict: + provider = RefreshableCredentials.from_dict(oauth_client, json.loads(credsdict)) + # if refresh token is expired, this will throw + try: + if provider.token().valid: + return provider + except Exception as e: + logger.debug(e) + # whatever it is, get rid of the cache + keyring.delete_password("dbt-databricks", host) + + # error with keyring. Maybe machine has no password persistency + except Exception as e: + logger.debug(e) + logger.info("could not retrieved saved token") + + # no token, go fetch one + consent = oauth_client.initiate_consent() + + provider = consent.launch_external_browser() + # save for later + self._credentials_provider = provider.as_dict() + try: + keyring.set_password("dbt-databricks", host, json.dumps(self._credentials_provider)) + # error with keyring. Maybe machine has no password persistency + except Exception as e: + logger.debug(e) + logger.info("could not save token") + + return provider + + finally: + self._lock.release() + + def _provider_from_dict(self) -> CredentialsProvider: + if self.token: + return token_auth.from_dict(self._credentials_provider) + + if self.client_id and self.client_secret: + return m2m_auth.from_dict( + host=self.host or "", + client_id=self.client_id or "", + client_secret=self.client_secret or "", + raw=self._credentials_provider or {"token": {}}, + ) + + oauth_client = OAuthClient( + host=self.host, + client_id=CLIENT_ID, + client_secret=None, + redirect_url=REDIRECT_URL, + scopes=SCOPES, + ) + + return RefreshableCredentials.from_dict(client=oauth_client, raw=self._credentials_provider) + class DatabricksSQLConnectionWrapper: """Wrap a Databricks SQL connector in a way that no-ops transactions""" @@ -404,6 +529,7 @@ def _get_comment_macro(self) -> Optional[str]: class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" + 
credentials_provider: CredentialsProvider = None def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) @@ -549,7 +675,8 @@ def open(cls, connection: Connection) -> Connection: creds: DatabricksCredentials = connection.credentials timeout = creds.connect_timeout - creds.validate_creds() + # gotta keep this so we don't prompt users many times + cls.credentials_provider = creds.authenticate(cls.credentials_provider) user_agent_entry = f"dbt-databricks/{__version__}" @@ -569,7 +696,7 @@ def connect() -> DatabricksSQLConnectionWrapper: conn: DatabricksSQLConnection = dbsql.connect( server_hostname=creds.host, http_path=creds.http_path, - access_token=creds.token, + credentials_provider=cls.credentials_provider, http_headers=http_headers if http_headers else None, session_configuration=creds.session_properties, catalog=creds.database, diff --git a/dbt/adapters/databricks/python_submissions.py b/dbt/adapters/databricks/python_submissions.py index 75dc603de..45aa088d3 100644 --- a/dbt/adapters/databricks/python_submissions.py +++ b/dbt/adapters/databricks/python_submissions.py @@ -12,6 +12,7 @@ import dbt.exceptions from dbt.adapters.base import PythonJobHelper from dbt.adapters.spark import __version__ +from databricks.sdk.core import CredentialsProvider logger = AdapterLogger("Databricks") @@ -381,6 +382,7 @@ def submit(self, compiled_code: str) -> None: class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper): credentials: DatabricksCredentials # type: ignore[assignment] + _credentials_provider: CredentialsProvider = None def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: super().__init__( @@ -400,8 +402,11 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No http_headers: Dict[str, str] = credentials.get_all_http_headers( connection_parameters.pop("http_headers", {}) ) + self._credentials_provider = credentials.authenticate(self._credentials_provider) + header_factory = 
self._credentials_provider() + headers = header_factory() - self.auth_header.update({"User-Agent": user_agent, **http_headers}) + self.auth_header.update({"User-Agent": user_agent, **http_headers, **headers}) @property def cluster_id(self) -> Optional[str]: # type: ignore[override] diff --git a/dbt/include/databricks/profile_template.yml b/dbt/include/databricks/profile_template.yml index db8d8bb85..6ce610541 100644 --- a/dbt/include/databricks/profile_template.yml +++ b/dbt/include/databricks/profile_template.yml @@ -5,9 +5,11 @@ prompts: hint: yourorg.databricks.com http_path: hint: 'HTTP Path' - token: - hint: 'dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX' - hide_input: true + _choose_access_token: + 'use access token': + token: + hint: 'dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX' + hide_input: true _choose_unity_catalog: 'use Unity Catalog': catalog: diff --git a/dev-requirements.txt b/dev-requirements.txt index 6567167a0..b13d124fb 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -23,5 +23,5 @@ tox>=3.2.0 types-requests dbt-spark==1.4.* -dbt-core==1.4.* -dbt-tests-adapter==1.4.* \ No newline at end of file +# dbt-core==1.4.* +dbt-tests-adapter>=1.4.0 \ No newline at end of file diff --git a/docs/oauth.md b/docs/oauth.md new file mode 100644 index 000000000..f56379573 --- /dev/null +++ b/docs/oauth.md @@ -0,0 +1,26 @@ +# Configure OAuth for DBT Databricks + +This feature is in [Public Preview](https://docs.databricks.com/release-notes/release-types.html). + +Databricks DBT adapter now supports authentication via OAuth in AWS and Azure. This is a much safer method as it enables you to generate short-lived (one hour) OAuth access tokens, which eliminates the risk of accidentally exposing longer-lived tokens such as Databricks personal access tokens through version control checkins or other means. OAuth also enables better server-side session invalidation and scoping. 
+ +Once an admin has correctly configured OAuth in Databricks, you can simply add the config `auth_type` and set it to `oauth`. Config `token` is no longer necessary. + +For Azure, your admin needs to create a Public AD application for dbt and provide you with its client_id. + +``` YAML +jaffle_shop: + outputs: + dev: + host: + http_path: + catalog: + schema: + auth_type: oauth # new + client_id: # only necessary for Azure + type: databricks + target: dev +``` + + + diff --git a/requirements.txt b/requirements.txt index 726694017..478bd90b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ databricks-sql-connector>=2.5.0 -dbt-spark==1.4.* +dbt-spark>=1.4.0 +databricks-sdk>=0.1.1 +keyring>=23.13.* \ No newline at end of file diff --git a/setup.py b/setup.py index e7288e9aa..e4fa51dc9 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,8 @@ def _get_plugin_version(): install_requires=[ "dbt-spark~={}".format(dbt_spark_version), "databricks-sql-connector>=2.5.0", + "databricks-sdk>=0.1.1", + "keyring>=23.13.0" ], zip_safe=False, classifiers=[ diff --git a/tests/profiles.py b/tests/profiles.py index 14a0e7be7..3cfdac9cf 100644 --- a/tests/profiles.py +++ b/tests/profiles.py @@ -25,9 +25,12 @@ def _build_databricks_cluster_target( "host": os.getenv("DBT_DATABRICKS_HOST_NAME"), "http_path": http_path, "token": os.getenv("DBT_DATABRICKS_TOKEN"), + "client_id": os.getenv("DBT_DATABRICKS_CLIENT_ID"), + "client_secret": os.getenv("DBT_DATABRICKS_CLIENT_SECRET"), "connect_retries": 3, "connect_timeout": 5, "retry_all": True, + "auth_type": "oauth", } if catalog is not None: profile["catalog"] = catalog diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index f4abe0a7f..b02520233 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -50,6 +50,42 @@ def _get_target_databricks_sql_connector(self, project): }, ) + def _get_target_databricks_sql_connector_no_token(self, project): + return config_from_parts_or_dicts( + project, + 
{ + "outputs": { + "test": { + "type": "databricks", + "schema": "analytics", + "host": "yourorg.databricks.com", + "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123", + "session_properties": {"spark.sql.ansi.enabled": "true"}, + } + }, + "target": "test", + }, + ) + + def _get_target_databricks_sql_connector_client_creds(self, project): + return config_from_parts_or_dicts( + project, + { + "outputs": { + "test": { + "type": "databricks", + "schema": "analytics", + "host": "yourorg.databricks.com", + "http_path": "sql/protocolv1/o/1234567890123456/1234-567890-test123", + "client_id": "foo", + "client_secret": "bar", + "session_properties": {"spark.sql.ansi.enabled": "true"}, + } + }, + "target": "test", + }, + ) + def _get_target_databricks_sql_connector_catalog(self, project): return config_from_parts_or_dicts( project, @@ -264,21 +300,60 @@ def _test_environment_http_headers( connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load + @unittest.skip("not ready") + def test_oauth_settings(self): + config = self._get_target_databricks_sql_connector_no_token(self.project_cfg) + + adapter = DatabricksAdapter(config) + + with mock.patch( + "dbt.adapters.databricks.connections.dbsql.connect", + new=self._connect_func(expected_no_token=True), + ): + connection = adapter.acquire_connection("dummy") + connection.handle # trigger lazy-load + + @unittest.skip("not ready") + def test_client_creds_settings(self): + config = self._get_target_databricks_sql_connector_client_creds(self.project_cfg) + + adapter = DatabricksAdapter(config) + + with mock.patch( + "dbt.adapters.databricks.connections.dbsql.connect", + new=self._connect_func(expected_client_creds=True), + ): + connection = adapter.acquire_connection("dummy") + connection.handle # trigger lazy-load + def _connect_func( - self, *, expected_catalog=None, expected_invocation_env=None, expected_http_headers=None + self, + *, + expected_catalog=None, + 
expected_invocation_env=None, + expected_http_headers=None, + expected_no_token=None, + expected_client_creds=None, ): def connect( server_hostname, http_path, - access_token, + credentials_provider, http_headers, session_configuration, catalog, _user_agent_entry, + **kwargs, ): self.assertEqual(server_hostname, "yourorg.databricks.com") self.assertEqual(http_path, "sql/protocolv1/o/1234567890123456/1234-567890-test123") - self.assertEqual(access_token, "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX") + if not (expected_no_token or expected_client_creds): + self.assertEqual( + credentials_provider._token, "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + ) + if expected_client_creds: + self.assertEqual(kwargs.get("client_id"), "foo") + self.assertEqual(kwargs.get("client_secret"), "bar") self.assertEqual(session_configuration["spark.sql.ansi.enabled"], "true") if expected_catalog is None: self.assertIsNone(catalog) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 000000000..c887b547c --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,73 @@ +import unittest +from dbt.adapters.databricks.connections import DatabricksCredentials +import pytest + + +@pytest.mark.skip(reason="Need to mock requests to OIDC") +class TestM2MAuth(unittest.TestCase): + def test_m2m(self): + host = "my.cloud.databricks.com" + creds = DatabricksCredentials( + host=host, + http_path="http://foo", + client_id="my-client-id", + client_secret="my-client-secret", + database="andre", + schema="dbt", + ) + provider = creds.authenticate(None) + self.assertIsNotNone(provider) + headers_fn = provider() + headers = headers_fn() + self.assertIsNotNone(headers) + + raw = provider.as_dict() + self.assertIsNotNone(raw) + + provider_b = creds._provider_from_dict() + headers_fn2 = provider_b() + headers2 = headers_fn2() + self.assertEqual(headers, headers2) + + +@pytest.mark.skip(reason="Need to mock requests to OIDC and mock opening browser") +class TestU2MAuth(unittest.TestCase): + 
def test_u2m(self): + host = "my.cloud.databricks.com" + creds = DatabricksCredentials( + host=host, database="andre", http_path="http://foo", schema="dbt" + ) + provider = creds.authenticate(None) + self.assertIsNotNone(provider) + headers_fn = provider() + headers = headers_fn() + self.assertIsNotNone(headers) + + raw = provider.as_dict() + self.assertIsNotNone(raw) + + provider_b = creds._provider_from_dict() + headers_fn2 = provider_b() + headers2 = headers_fn2() + self.assertEqual(headers, headers2) + + +class TestTokenAuth(unittest.TestCase): + def test_token(self): + host = "my.cloud.databricks.com" + creds = DatabricksCredentials( + host=host, token="foo", database="andre", http_path="http://foo", schema="dbt" + ) + provider = creds.authenticate(None) + self.assertIsNotNone(provider) + headers_fn = provider() + headers = headers_fn() + self.assertIsNotNone(headers) + + raw = provider.as_dict() + self.assertIsNotNone(raw) + + provider_b = creds._provider_from_dict() + headers_fn2 = provider_b() + headers2 = headers_fn2() + self.assertEqual(headers, headers2) From d435c1c54614aea361d1bc9a1121ee255f94230f Mon Sep 17 00:00:00 2001 From: Jesse Date: Mon, 1 May 2023 13:59:12 -0500 Subject: [PATCH 2/2] Update .gitignore Signed-off-by: Jesse Whitehouse --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 8c1b1b2bc..de15912ce 100644 --- a/.gitignore +++ b/.gitignore @@ -16,5 +16,4 @@ test.env .vscode *.log logs/ -.venv -.venv2 \ No newline at end of file +.venv* \ No newline at end of file