Skip to content

Commit ea03966

Browse files
committed
feat: Add OpenLineage support for CloudSQLExecuteQueryOperator
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
1 parent b32fd1a commit ea03966

File tree

6 files changed

+219
-10
lines changed

6 files changed

+219
-10
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Compatibility shim for ``get_openlineage_facets_with_sql``.

Re-exports the canonical implementation from the OpenLineage provider when it
is available; otherwise defines a behavior-compatible fallback so callers can
import this name regardless of the installed OpenLineage provider version.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
else:
    try:
        # Prefer the canonical implementation shipped by newer OpenLineage
        # provider releases.
        from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
    except ImportError:

        def get_openlineage_facets_with_sql(
            hook,
            sql: str | list[str],
            conn_id: str,
            database: str | None,
        ):
            """Parse ``sql`` via the OpenLineage SQLParser and build lineage metadata.

            Fallback used when the installed OpenLineage provider predates the
            canonical ``get_openlineage_facets_with_sql``.

            :param hook: DB-API hook used to reach the database (must expose the
                ``get_openlineage_*`` helpers for lineage to be produced).
            :param sql: single statement or list of statements that were executed.
            :param conn_id: Airflow connection id resolved through ``hook``.
            :param database: database name override, or None for the connection default.
            :return: an ``OperatorLineage`` or None when lineage cannot be produced.
            """
            try:
                from airflow.providers.openlineage.sqlparser import SQLParser
            except ImportError:
                log.debug("SQLParser could not be imported from OpenLineage provider.")
                return None

            try:
                from airflow.providers.openlineage.utils.utils import should_use_external_connection

                use_external_connection = should_use_external_connection(hook)
            except ImportError:
                # OpenLineage provider release < 1.8.0 - we always use connection
                use_external_connection = True

            connection = hook.get_connection(conn_id)
            try:
                database_info = hook.get_openlineage_database_info(connection)
            except AttributeError:
                database_info = None

            # Log in the None branch (not only on AttributeError) so a hook that
            # returns None is reported too — matches the canonical implementation
            # in airflow.providers.openlineage.sqlparser.
            if database_info is None:
                log.debug("%s has no database info provided", hook)
                return None

            try:
                sql_parser = SQLParser(
                    dialect=hook.get_openlineage_database_dialect(connection),
                    default_schema=hook.get_openlineage_default_schema(),
                )
            except AttributeError:
                log.debug("%s failed to get database dialect", hook)
                return None

            operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
                sql=sql,
                hook=hook,
                database_info=database_info,
                database=database,
                sqlalchemy_engine=hook.get_sqlalchemy_engine(),
                use_connection=use_external_connection,
            )

            return operator_lineage


__all__ = ["get_openlineage_facets_with_sql"]

providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
6868
from requests import Session
6969

70+
from airflow.providers.common.sql.hooks.sql import DbApiHook
71+
7072
UNIX_PATH_MAX = 108
7173

7274
# Time to sleep between active checks of the operation results
@@ -1146,7 +1148,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
11461148
gcp_conn_id=self.gcp_conn_id,
11471149
)
11481150

1149-
def get_database_hook(self, connection: Connection) -> BaseHook:
1151+
def get_database_hook(self, connection: Connection) -> DbApiHook:
11501152
"""
11511153
Retrieve database hook.
11521154
@@ -1156,7 +1158,7 @@ def get_database_hook(self, connection: Connection) -> BaseHook:
11561158
if self.database_type == "postgres":
11571159
from airflow.providers.postgres.hooks.postgres import PostgresHook
11581160

1159-
db_hook: BaseHook = PostgresHook(connection=connection, database=self.database)
1161+
db_hook: DbApiHook = PostgresHook(connection=connection, database=self.database)
11601162
else:
11611163
from airflow.providers.mysql.hooks.mysql import MySqlHook
11621164

providers/src/airflow/providers/google/cloud/operators/cloud_sql.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from __future__ import annotations
2121

2222
from collections.abc import Iterable, Mapping, Sequence
23+
from contextlib import contextmanager
2324
from functools import cached_property
2425
from typing import TYPE_CHECKING, Any
2526

@@ -38,8 +39,7 @@
3839

3940
if TYPE_CHECKING:
4041
from airflow.models import Connection
41-
from airflow.providers.mysql.hooks.mysql import MySqlHook
42-
from airflow.providers.postgres.hooks.postgres import PostgresHook
42+
from airflow.providers.openlineage.extractors import OperatorLineage
4343
from airflow.utils.context import Context
4444

4545

@@ -1256,7 +1256,8 @@ def __init__(
12561256
self.ssl_client_key = ssl_client_key
12571257
self.ssl_secret_id = ssl_secret_id
12581258

1259-
def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
1259+
@contextmanager
1260+
def cloud_sql_proxy_context(self, hook: CloudSQLDatabaseHook):
12601261
cloud_sql_proxy_runner = None
12611262
try:
12621263
if hook.use_proxy:
@@ -1266,27 +1267,27 @@ def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook
12661267
# be taken over here by another bind(0).
12671268
# It's quite unlikely to happen though!
12681269
cloud_sql_proxy_runner.start_proxy()
1269-
self.log.info('Executing: "%s"', self.sql)
1270-
database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
1270+
yield
12711271
finally:
12721272
if cloud_sql_proxy_runner:
12731273
cloud_sql_proxy_runner.stop_proxy()
12741274

12751275
def execute(self, context: Context):
    """Run ``self.sql`` on the Cloud SQL database, via the proxy when configured."""
    cloudsql_hook = self.hook
    cloudsql_hook.validate_ssl_certs()
    conn = cloudsql_hook.create_connection()
    cloudsql_hook.validate_socket_path_length()
    db_hook = cloudsql_hook.get_database_hook(connection=conn)
    try:
        # The proxy (if any) is only needed while the statements run.
        with self.cloud_sql_proxy_context(cloudsql_hook):
            self.log.info('Executing: "%s"', self.sql)
            db_hook.run(self.sql, self.autocommit, parameters=self.parameters)
    finally:
        cloudsql_hook.cleanup_database_hook()

12881288
@cached_property
12891289
def hook(self):
1290+
self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
12901291
return CloudSQLDatabaseHook(
12911292
gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
12921293
gcp_conn_id=self.gcp_conn_id,
@@ -1297,3 +1298,14 @@ def hook(self):
12971298
ssl_key=self.ssl_client_key,
12981299
ssl_secret_id=self.ssl_secret_id,
12991300
)
1301+
1302+
def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
1303+
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
1304+
1305+
with self.cloud_sql_proxy_context(self.hook):
1306+
return get_openlineage_facets_with_sql(
1307+
hook=self.hook.db_hook,
1308+
sql=self.sql, # type:ignore[arg-type] # Iterable[str] instead of list[str]
1309+
conn_id=self.gcp_cloudsql_conn_id,
1310+
database=self.hook.database,
1311+
)

providers/src/airflow/providers/google/provider.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ versions:
101101

102102
dependencies:
103103
- apache-airflow>=2.9.0
104+
# - apache-airflow-providers-common-compat>=1.4.0
104105
- apache-airflow-providers-common-compat>=1.3.0
105106
- apache-airflow-providers-common-sql>=1.20.0
106107
- asgiref>=3.5.2

providers/src/airflow/providers/openlineage/sqlparser.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import logging
1920
from typing import TYPE_CHECKING, Callable
2021

2122
import sqlparse
@@ -30,6 +31,7 @@
3031
create_information_schema_query,
3132
get_table_schemas,
3233
)
34+
from airflow.providers.openlineage.utils.utils import should_use_external_connection
3335
from airflow.typing_compat import TypedDict
3436
from airflow.utils.log.logging_mixin import LoggingMixin
3537

@@ -38,6 +40,9 @@
3840
from sqlalchemy.engine import Engine
3941

4042
from airflow.hooks.base import BaseHook
43+
from airflow.providers.common.sql.hooks.sql import DbApiHook
44+
45+
log = logging.getLogger(__name__)
4146

4247
DEFAULT_NAMESPACE = "default"
4348
DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
@@ -397,3 +402,37 @@ def _get_tables_hierarchy(
397402
tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
398403
tables.append(table.name)
399404
return hierarchy
405+
406+
407+
def get_openlineage_facets_with_sql(
    hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None
) -> OperatorLineage | None:
    """Parse ``sql`` with :class:`SQLParser` and return lineage metadata.

    :param hook: DB-API hook used to reach the database; must expose the
        ``get_openlineage_*`` helpers, otherwise None is returned.
    :param sql: statement(s) to parse.
    :param conn_id: Airflow connection id resolved through ``hook``.
    :param database: database name override, or None for the connection default.
    :return: an ``OperatorLineage`` or None when lineage cannot be produced.
    """
    conn = hook.get_connection(conn_id)

    database_info = None
    try:
        database_info = hook.get_openlineage_database_info(conn)
    except AttributeError:
        # Hook does not implement the OpenLineage helpers.
        pass

    if database_info is None:
        log.debug("%s has no database info provided", hook)
        return None

    try:
        parser = SQLParser(
            dialect=hook.get_openlineage_database_dialect(conn),
            default_schema=hook.get_openlineage_default_schema(),
        )
    except AttributeError:
        log.debug("%s failed to get database dialect", hook)
        return None

    return parser.generate_openlineage_metadata_from_sql(
        sql=sql,
        hook=hook,
        database_info=database_info,
        database=database,
        sqlalchemy_engine=hook.get_sqlalchemy_engine(),
        use_connection=should_use_external_connection(hook),
    )

providers/tests/google/cloud/operators/test_cloud_sql.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,19 @@
1919

2020
import os
2121
from unittest import mock
22+
from unittest.mock import MagicMock
2223

2324
import pytest
2425

2526
from airflow.exceptions import AirflowException, TaskDeferred
2627
from airflow.models import Connection
28+
from airflow.providers.common.compat.openlineage.facet import (
29+
Dataset,
30+
SchemaDatasetFacet,
31+
SchemaDatasetFacetFields,
32+
SQLJobFacet,
33+
)
34+
from airflow.providers.common.sql.hooks.sql import DbApiHook
2735
from airflow.providers.google.cloud.operators.cloud_sql import (
2836
CloudSQLCloneInstanceOperator,
2937
CloudSQLCreateInstanceDatabaseOperator,
@@ -822,3 +830,66 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
822830
operator.execute(None)
823831
err = ctx.value
824832
assert "The UNIX socket path length cannot exceed" in str(err)
833+
834+
@pytest.mark.parametrize(
    "connection_port, default_port, expected_port",
    [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
)
def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
    """Lineage from get_openlineage_facets_on_complete: namespace port resolution,
    output dataset schema, SQL job facet, and extraction-error reporting."""

    # Concrete DbApiHook with mocked DB access; the real OpenLineage helper
    # methods on DbApiHook (e.g. authority part) still run.
    class DBApiHookForTests(DbApiHook):
        conn_name_attr = "sql_default"
        get_conn = MagicMock(name="conn")
        get_connection = MagicMock()

        def get_openlineage_database_info(self, connection):
            from airflow.providers.openlineage.sqlparser import DatabaseInfo

            return DatabaseInfo(
                scheme="sqlscheme",
                authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
            )

    dbapi_hook = DBApiHookForTests()

    # Operator variant whose CloudSQLDatabaseHook is replaced by a mock that
    # hands lineage collection the dbapi_hook above.
    class CloudSQLExecuteQueryOperatorForTest(CloudSQLExecuteQueryOperator):
        @property
        def hook(self):
            return MagicMock(db_hook=dbapi_hook, database="")

    # Second statement is deliberately invalid SQL so the extractionError
    # run facet gets populated.
    sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
        order_day_of_week VARCHAR(64) NOT NULL,
        order_placed_on   TIMESTAMP NOT NULL,
        orders_placed     INTEGER NOT NULL
    );
    FORGOT TO COMMENT"""
    op = CloudSQLExecuteQueryOperatorForTest(task_id="task_id", sql=sql)
    DB_SCHEMA_NAME = "PUBLIC"
    # Rows served by the mocked information-schema query:
    # (schema, table, column, ordinal position, type).
    rows = [
        (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
        (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
        (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
    ]
    dbapi_hook.get_connection.return_value = Connection(
        conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port
    )
    # First fetchall() returns the output-table schema; second returns nothing.
    dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

    lineage = op.get_openlineage_facets_on_complete(None)
    assert len(lineage.inputs) == 0
    assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)}
    # Exactly one of the two statements failed to parse.
    assert lineage.run_facets["extractionError"].failedTasks == 1
    assert lineage.outputs == [
        Dataset(
            namespace=f"sqlscheme://host:{expected_port}",
            name="PUBLIC.popular_orders_day_of_week",
            facets={
                "schema": SchemaDatasetFacet(
                    fields=[
                        SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
                        SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
                        SchemaDatasetFacetFields(name="orders_placed", type="int4"),
                    ]
                )
            },
        )
    ]

0 commit comments

Comments
 (0)