@@ -71,6 +71,30 @@ def _cancel():
return timer, timeout_event


def _format_query_tag_value(value: str) -> str:
"""
Escape special characters and truncate a single query tag value.

Databricks ``QUERY_TAGS`` uses ``key:value`` pairs delimited by commas, so
backslash, comma and colon inside *values* must be escaped. Values are also
capped at 128 characters before escaping to keep the overall tag string
within reasonable bounds.
"""
value = str(value)[:128]
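# Escape the backslash first so the backslashes added for the comma and
# colon escapes below are not themselves escaped a second time.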
return value.replace("\\", "\\\\").replace(",", "\\,").replace(":", "\\:")


def _format_query_tags(tags: dict[str, str | None]) -> str:
"""
Serialize a query-tags dict to the ``key:value,key:value`` string expected by ``QUERY_TAGS``.

Entries whose value is ``None`` are omitted.
"""
return ",".join(
f"{key}:{_format_query_tag_value(value)}" for key, value in tags.items() if value is not None
)
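# Illustrative sketch of the resulting wire format (hypothetical tags):
# _format_query_tags({"airflow_dag_id": "dag:1", "airflow_map_index": None})
# returns r"airflow_dag_id:dag\:1" -- the None entry is dropped, and the colon
# inside the value is escaped so Databricks still parses one key:value pair.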


class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""
Hook to interact with Databricks SQL.
@@ -88,6 +112,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
on every request
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
:param query_tags: An optional dict of query tags to attach to every SQL statement executed by
this hook. Tags are injected via the ``QUERY_TAGS`` Databricks session parameter so they
appear in ``system.query.history``. Any existing ``QUERY_TAGS`` already present in
*session_configuration* are preserved and the new tags are appended.
:param kwargs: Additional parameters internal to Databricks SQL Connector parameters
"""

@@ -103,6 +131,7 @@ def __init__(
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
schema: str | None = None,
query_tags: dict[str, str | None] | None = None,
caller: str = "DatabricksSqlHook",
**kwargs,
) -> None:
@@ -118,6 +147,7 @@ def __init__(
self.schema = schema
self.additional_params = kwargs
self.query_ids: list[str] = []
self.query_tags = query_tags

def _get_extra_config(self) -> dict[str, Any | None]:
extra_params = copy(self.databricks_conn.extra_dejson)
@@ -172,17 +202,27 @@ def get_conn(self) -> AirflowConnection:
if not self._sql_conn or prev_token != new_token:
if self._sql_conn: # close already existing connection
self._sql_conn.close()
session_config: dict[str, str] = dict(self.session_config) if self.session_config else {}
if self.query_tags:
tags_str = _format_query_tags(self.query_tags)
existing = session_config.get("QUERY_TAGS", "")
session_config["QUERY_TAGS"] = f"{existing},{tags_str}" if existing else tags_str

connect_kwargs = {
"schema": self.schema,
"catalog": self.catalog,
"session_configuration": session_config or None,
"http_headers": self.http_headers,
"_user_agent_entry": self.user_agent_value,
**self._get_extra_config(),
**self.additional_params,
}

self._sql_conn = sql.connect(
self.host,
self._http_path,
self._token,
schema=self.schema,
catalog=self.catalog,
session_configuration=self.session_config,
http_headers=self.http_headers,
_user_agent_entry=self.user_agent_value,
**self._get_extra_config(),
**self.additional_params,
**connect_kwargs,
)

if self._sql_conn is None:
@@ -46,6 +46,22 @@
_DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")


def _get_airflow_query_tags(context: Context) -> dict[str, str | None]:
"""Return Airflow context metadata as a query-tags dict."""
task_instance = context["ti"]

def _as_str(value: Any) -> str | None:
return None if value is None else str(value)

return {
"airflow_dag_id": _as_str(task_instance.dag_id),
"airflow_task_id": _as_str(task_instance.task_id),
"airflow_run_id": _as_str(task_instance.run_id),
"airflow_try_number": _as_str(task_instance.try_number),
"airflow_map_index": _as_str(task_instance.map_index),
}
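# Illustrative output for a hypothetical unmapped task (map_index is -1,
# not None, so that tag is still emitted):
# {"airflow_dag_id": "etl", "airflow_task_id": "load",
#  "airflow_run_id": "manual__2024-01-01T00:00:00+00:00",
#  "airflow_try_number": "1", "airflow_map_index": "-1"}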


class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
Executes SQL code in a Databricks SQL endpoint or a Databricks cluster.
@@ -68,6 +84,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
:param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
If not specified, it could be specified in the Databricks connection's extra parameters.
:param client_parameters: Additional parameters internal to Databricks SQL Connector parameters
:param query_tags: Optional dictionary of query tags to attach to Databricks SQL queries.
:param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags.
:param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers on every request.
(templated)
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
@@ -93,6 +111,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
"http_headers",
"databricks_conn_id",
"_gcs_impersonation_chain",
"query_tags",
}
| set(SQLExecuteQueryOperator.template_fields)
)
@@ -115,6 +134,8 @@ def __init__(
output_format: str = "csv",
csv_params: dict[str, Any] | None = None,
client_parameters: dict[str, Any] | None = None,
query_tags: dict[str, str | None] | None = None,
include_airflow_query_tags: bool = True,
gcp_conn_id: str = "google_cloud_default",
gcs_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
@@ -132,6 +153,8 @@ def __init__(
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
self.query_tags = query_tags or {}
self.include_airflow_query_tags = include_airflow_query_tags
self._gcp_conn_id = gcp_conn_id
self._gcs_impersonation_chain = gcs_impersonation_chain

@@ -303,6 +326,20 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen

return list(zip(descriptions, results))

def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
query_tags: dict[str, str | None] = {}

if self.include_airflow_query_tags and context is not None:
query_tags.update(_get_airflow_query_tags(context))

query_tags.update(self.query_tags)
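# self.query_tags is applied last, so user-supplied tags override the
# auto-generated Airflow tags on key collisions.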

return query_tags or None

def execute(self, context: Context) -> Any:
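# The tags must be attached to the hook before the parent execute()
# opens the connection (get_conn() reads them into QUERY_TAGS).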
self.get_db_hook().query_tags = self._get_query_tags(context)
return super().execute(context)


COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]

@@ -335,6 +372,8 @@ class DatabricksCopyIntoOperator(BaseOperator):
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
:param client_parameters: Additional parameters internal to Databricks SQL Connector parameters
:param query_tags: Optional dictionary of query tags to attach to Databricks SQL queries.
:param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags.
:param files: optional list of files to import. Can't be specified together with ``pattern``. (templated)
:param pattern: optional regex string to match file names to import.
Can't be specified together with ``files``.
@@ -355,6 +394,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
"files",
"table_name",
"databricks_conn_id",
"query_tags",
)

def __init__(
@@ -381,9 +421,11 @@ def __init__(
force_copy: bool | None = None,
copy_options: dict[str, str] | None = None,
validate: bool | int | None = None,
query_tags: dict[str, str | None] | None = None,
include_airflow_query_tags: bool = True,
**kwargs,
) -> None:
"""Create a new ``DatabricksSqlOperator``."""
"""Create a new ``DatabricksCopyIntoOperator``."""
super().__init__(**kwargs)
if files is not None and pattern is not None:
raise AirflowException("Only one of 'pattern' or 'files' should be specified")
@@ -413,6 +455,8 @@ def __init__(
self._validate = validate
self._http_headers = http_headers
self._client_parameters = client_parameters or {}
self.query_tags = query_tags or {}
self.include_airflow_query_tags = include_airflow_query_tags
if force_copy is not None:
self._copy_options["force"] = "true" if force_copy else "false"
self._sql: str | None = None
@@ -514,10 +558,21 @@ def _create_sql_query(self) -> str:
"""
return sql.strip()

def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
query_tags: dict[str, str | None] = {}

if self.include_airflow_query_tags and context is not None:
query_tags.update(_get_airflow_query_tags(context))

query_tags.update(self.query_tags)
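# As above: user-supplied tags are applied last and win on key collisions.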

return query_tags or None

def execute(self, context: Context) -> Any:
self._sql = self._create_sql_query()
self.log.info("Executing: %s", self._sql)
hook = self._get_hook()
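# Set the tags before run(): the hook only reads query_tags when
# get_conn() builds the session configuration.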
hook.query_tags = self._get_query_tags(context)
hook.run(self._sql)

def on_kill(self) -> None:
@@ -130,7 +130,7 @@ def _get_hook(self) -> DatabricksSqlHook:
self.http_headers,
self.catalog,
self.schema,
self.caller,
caller=self.caller,
**self.client_parameters,
**self.hook_params,
)
@@ -109,7 +109,7 @@ def hook(self) -> DatabricksSqlHook:
self.http_headers,
self.catalog,
self.schema,
self.caller,
caller=self.caller,
**self.client_parameters,
**self.hook_params,
)
@@ -32,7 +32,12 @@
from airflow.models import Connection
from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook, create_timeout_thread
from airflow.providers.databricks.hooks.databricks_sql import (
DatabricksSqlHook,
_format_query_tag_value,
_format_query_tags,
create_timeout_thread,
)

TASK_ID = "databricks-sql-operator"
DEFAULT_CONN_ID = "databricks_default"
@@ -792,3 +797,96 @@ def test_resolve_warehouse_name_empty_response(self, mock_requests):
hook = DatabricksSqlHook(sql_endpoint_name="Test")
with pytest.raises(RuntimeError, match="Can't list Databricks SQL warehouses"):
hook._get_sql_endpoint_by_name("Test")


@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
def test_get_conn_passes_query_tags_via_session_configuration(mock_connect, mock_get_requests):
"""query_tags must be injected into session_configuration['QUERY_TAGS'], not sql.connect(query_tags=)."""
hook = DatabricksSqlHook(
databricks_conn_id=DEFAULT_CONN_ID,
http_path=HTTP_PATH,
query_tags={"airflow_dag_id": "dag_1", "airflow_task_id": "task_1"},
)

hook.get_conn()

mock_connect.assert_called_once()
session_cfg = mock_connect.call_args.kwargs["session_configuration"]
assert session_cfg is not None
assert "QUERY_TAGS" in session_cfg
query_tags_str = session_cfg["QUERY_TAGS"]
assert "airflow_dag_id:dag_1" in query_tags_str
assert "airflow_task_id:task_1" in query_tags_str
assert "query_tags" not in mock_connect.call_args.kwargs


@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
def test_get_conn_merges_query_tags_with_existing_session_configuration(mock_connect, mock_get_requests):
"""Existing QUERY_TAGS in session_configuration must be preserved and new tags appended."""
hook = DatabricksSqlHook(
databricks_conn_id=DEFAULT_CONN_ID,
http_path=HTTP_PATH,
session_configuration={"QUERY_TAGS": "existing_tag:existing_value"},
query_tags={"airflow_dag_id": "dag_1"},
)

hook.get_conn()

mock_connect.assert_called_once()
session_cfg = mock_connect.call_args.kwargs["session_configuration"]
query_tags_str = session_cfg["QUERY_TAGS"]
assert "existing_tag:existing_value" in query_tags_str
assert "airflow_dag_id:dag_1" in query_tags_str


@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
def test_get_conn_no_query_tags(mock_connect, mock_get_requests):
"""When no query_tags are provided, session_configuration should not gain a QUERY_TAGS key."""
hook = DatabricksSqlHook(
databricks_conn_id=DEFAULT_CONN_ID,
http_path=HTTP_PATH,
)

hook.get_conn()

mock_connect.assert_called_once()
session_cfg = mock_connect.call_args.kwargs.get("session_configuration")
assert session_cfg is None or "QUERY_TAGS" not in session_cfg


class TestFormatQueryTags:
def test_simple_values(self):
result = _format_query_tags({"dag_id": "my_dag", "task_id": "my_task"})
assert "dag_id:my_dag" in result
assert "task_id:my_task" in result

def test_none_values_omitted(self):
result = _format_query_tags({"dag_id": "my_dag", "map_index": None})
assert "dag_id:my_dag" in result
assert "map_index" not in result

def test_empty_dict_returns_empty_string(self):
assert _format_query_tags({}) == ""

def test_value_escaping_comma(self):
result = _format_query_tag_value("a,b")
assert result == "a\\,b"

def test_value_escaping_colon(self):
result = _format_query_tag_value("a:b")
assert result == "a\\:b"

def test_value_escaping_backslash(self):
result = _format_query_tag_value("a\\b")
assert result == "a\\\\b"

def test_value_truncated_at_128_chars(self):
long_value = "x" * 200
result = _format_query_tag_value(long_value)
assert len(result) == 128

def test_format_query_tags_roundtrip(self):
tags = {"airflow_dag_id": "dag:1", "airflow_run_id": "run,2"}
result = _format_query_tags(tags)
assert "airflow_dag_id:dag\\:1" in result
assert "airflow_run_id:run\\,2" in result