Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import google.oauth2.credentials
import numpy as np
import requests
from requests.adapters import HTTPAdapter, Retry
import wrapt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand Down Expand Up @@ -263,13 +264,28 @@ class ExecuteSqlError(Exception):
)


def _create_retry_session() -> requests.Session:
"""Create a requests session with retry on 5xx for POST requests."""
session = requests.Session()
retries = Retry(
total=3,
backoff_factor=0.5,
status_forcelist=[500, 502, 503, 504],
allowed_methods=["POST"],
)
Comment thread
tkislan marked this conversation as resolved.
session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))
return session
Comment thread
tkislan marked this conversation as resolved.


def _generate_temporary_credentials(integration_id):
Comment thread
tkislan marked this conversation as resolved.
Outdated
url = get_absolute_userpod_api_url(f"integrations/credentials/{integration_id}")

# Add project credentials in detached mode
headers = get_project_auth_headers()

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand All @@ -291,7 +307,8 @@ def _get_federated_auth_credentials(
headers = get_project_auth_headers()
headers["UserPodAuthContextToken"] = user_pod_auth_context_token

response = requests.post(url, timeout=10, headers=headers)
session = _create_retry_session()
response = session.post(url, timeout=10, headers=headers)

response.raise_for_status()

Expand Down
94 changes: 90 additions & 4 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,9 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
class TestFederatedAuth(unittest.TestCase):
@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
@mock.patch("deepnote_toolkit.sql.sql_execution.requests.post")
@mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session")
def test_get_federated_auth_credentials_returns_validated_response(
self, mock_post, mock_get_url, mock_get_headers
self, mock_create_session, mock_get_url, mock_get_headers
):
"""Test that _get_federated_auth_credentials properly validates and returns response data."""
from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials
Expand All @@ -603,12 +603,14 @@ def test_get_federated_auth_credentials_returns_validated_response(
mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-integration-id"
mock_get_headers.return_value = {"Authorization": "Bearer project-token"}

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"integrationType": "trino",
"accessToken": "test-access-token-123",
}
mock_post.return_value = mock_response
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

# Call the function
result = _get_federated_auth_credentials(
Expand All @@ -621,7 +623,7 @@ def test_get_federated_auth_credentials_returns_validated_response(
)

# Verify headers include both project auth and user pod auth context token
mock_post.assert_called_once_with(
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/federated-auth-token/test-integration-id",
timeout=10,
headers={
Expand Down Expand Up @@ -1019,3 +1021,87 @@ def test_databricks_connector_dialect_alias_is_registered(self):

self.assertEqual(url.drivername, "databricks+connector")
self.assertIsNotNone(dialect_cls)


class TestCreateRetrySession(unittest.TestCase):
def test_retry_session_has_correct_config(self):
"""Test that _create_retry_session configures retries correctly."""
from deepnote_toolkit.sql.sql_execution import _create_retry_session

session = _create_retry_session()

# Check that both http and https adapters are mounted with retry config
for prefix in ("http://", "https://"):
adapter = session.get_adapter(prefix)
retries = adapter.max_retries
self.assertEqual(retries.total, 3)
self.assertEqual(retries.backoff_factor, 0.5)
self.assertEqual(list(retries.status_forcelist), [500, 502, 503, 504])
self.assertIn("POST", retries.allowed_methods)
Comment thread
tkislan marked this conversation as resolved.
Outdated

@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_generate_temporary_credentials_uses_retry_session(
self, mock_get_url, mock_get_headers
):
"""Test that _generate_temporary_credentials uses a retry session."""
from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials

mock_get_url.return_value = "https://api.example.com/integrations/credentials/test-id"
mock_get_headers.return_value = {"Authorization": "Bearer token"}

with mock.patch(
"deepnote_toolkit.sql.sql_execution._create_retry_session"
) as mock_create_session:
mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"username": "user",
"password": "pass",
}
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

_generate_temporary_credentials("test-id")

mock_create_session.assert_called_once()
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/credentials/test-id",
timeout=10,
headers={"Authorization": "Bearer token"},
)

@mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers")
@mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url")
def test_get_federated_auth_credentials_uses_retry_session(
self, mock_get_url, mock_get_headers
):
"""Test that _get_federated_auth_credentials uses a retry session."""
from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials

mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-id"
mock_get_headers.return_value = {"Authorization": "Bearer token"}

with mock.patch(
"deepnote_toolkit.sql.sql_execution._create_retry_session"
) as mock_create_session:
mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.json.return_value = {
"integrationType": "trino",
"accessToken": "test-token",
}
mock_session.post.return_value = mock_response
mock_create_session.return_value = mock_session

_get_federated_auth_credentials("test-id", "auth-context-token")

mock_create_session.assert_called_once()
mock_session.post.assert_called_once_with(
"https://api.example.com/integrations/federated-auth-token/test-id",
timeout=10,
headers={
"Authorization": "Bearer token",
"UserPodAuthContextToken": "auth-context-token",
},
)
Loading