Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 95 additions & 9 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import contextlib
import json
import logging
import re
import uuid
import warnings
from typing import Any
from urllib.parse import quote

import google.oauth2.credentials
Expand All @@ -14,6 +16,7 @@
from google.api_core.client_info import ClientInfo
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel, ValidationError
from sqlalchemy.engine import URL, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

Expand All @@ -33,6 +36,18 @@
from deepnote_toolkit.sql.sql_utils import is_single_select_query
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url

logger = logging.getLogger(__name__)


class IntegrationFederatedAuthParams(BaseModel):
    """Parameters identifying a federated-auth request for one integration.

    Field names are deliberately camelCase so the model validates the
    ``federatedAuthParams`` payload verbatim; renaming them (even with
    aliases) would break attribute access in existing callers.
    """

    # ID of the integration to fetch a federated auth token for.
    integrationId: str
    # Per-user token forwarded to the userpod API as an auth context.
    authContextToken: str


class FederatedAuthResponseData(BaseModel):
    """Response payload of the federated-auth-token userpod API endpoint.

    Field names are camelCase to match the API response JSON verbatim.
    """

    # Integration kind the token applies to, e.g. "trino" or "big-query".
    integrationType: str
    # Short-lived access token to inject into the connection parameters.
    accessToken: str

Comment on lines +42 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Consider snake_case attributes with camelCase aliases.

The model fields use camelCase, which matches the API but violates Python conventions. For better ergonomics, use snake_case attributes with Pydantic aliases:

from pydantic import BaseModel, Field

class IntegrationFederatedAuthParams(BaseModel):
    integration_id: str = Field(alias="integrationId")
    auth_context_token: str = Field(alias="authContextToken")

class FederatedAuthResponseData(BaseModel):
    integration_type: str = Field(alias="integrationType")
    access_token: str = Field(alias="accessToken")

This allows idiomatic Python access (federated_auth.access_token) while preserving API compatibility.

🤖 Prompt for AI Agents
In deepnote_toolkit/sql/sql_execution.py around lines 42 to 50, the Pydantic
models use camelCase field names which break Python naming conventions; change
fields to snake_case attributes and add Pydantic aliases that map to the
existing camelCase keys (use Field(alias="...") for each field) so external API
compatibility remains while callers can use idiomatic snake_case attribute
access; ensure models still accept population by field name and by alias if
needed (set Config allow_population_by_field_name=True or equivalent) and update
any downstream code that references the old attribute names.


def compile_sql_query(
skip_jinja_template_render,
Expand Down Expand Up @@ -242,11 +257,89 @@ def _generate_temporary_credentials(integration_id):

response = requests.post(url, timeout=10, headers=headers)

response.raise_for_status()

data = response.json()

return quote(data["username"]), quote(data["password"])


def _get_federated_auth_credentials(
    integration_id: str, user_pod_auth_context_token: str
) -> FederatedAuthResponseData:
    """Fetch federated auth credentials for the given integration.

    Args:
        integration_id: ID of the integration to authenticate against.
        user_pod_auth_context_token: Per-user auth context token, forwarded
            to the userpod API via the ``UserPodAuthContextToken`` header.

    Returns:
        The validated ``FederatedAuthResponseData`` payload.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
        ValueError: If the response body is not valid JSON.
        ValidationError: If the JSON does not match the expected schema.
    """
    url = get_absolute_userpod_api_url(
        f"integrations/federated-auth-token/{integration_id}"
    )

    # Add project credentials in detached mode
    headers = get_project_auth_headers()
    headers["UserPodAuthContextToken"] = user_pod_auth_context_token

    response = requests.post(url, timeout=10, headers=headers)

    response.raise_for_status()

    # response.json() raises ValueError on malformed JSON and
    # model_validate() raises ValidationError on schema mismatch; log the
    # raw body for debuggability before re-raising for the caller.
    try:
        data = FederatedAuthResponseData.model_validate(response.json())
    except (ValueError, ValidationError) as e:
        logger.error(
            "Failed to parse federated auth response from %s: %s",
            url,
            response.text,
            exc_info=e,
        )
        raise

    return data

Comment on lines +267 to +287
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add error handling for JSON decode and validation.

response.json() and model_validate() can raise exceptions that should be caught and logged with context.

🔎 Proposed fix
 def _get_federated_auth_credentials(
     integration_id: str, user_pod_auth_context_token: str
 ) -> FederatedAuthResponseData:
     """Get federated auth credentials for the given integration ID and user pod auth context token."""
 
     url = get_absolute_userpod_api_url(
         f"integrations/federated-auth-token/{integration_id}"
     )
 
     # Add project credentials in detached mode
     headers = get_project_auth_headers()
     headers["UserPodAuthContextToken"] = user_pod_auth_context_token
 
     response = requests.post(url, timeout=10, headers=headers)
 
     response.raise_for_status()
 
+    try:
-        data = FederatedAuthResponseData.model_validate(response.json())
+        data = FederatedAuthResponseData.model_validate(response.json())
+    except (ValueError, ValidationError) as e:
+        logger.error(
+            "Failed to parse federated auth response from %s: %s",
+            url,
+            response.text,
+            exc_info=e,
+        )
+        raise
 
     return data

Committable suggestion skipped: line range outside the PR's diff.


def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
    """Swap the URL's user/password for freshly minted IAM credentials.

    No-op when the connection dict carries no ``iamParams`` entry; otherwise
    the ``url`` value is rewritten in-place with temporary credentials.
    """
    if "iamParams" not in sql_alchemy_dict:
        return

    username, password = _generate_temporary_credentials(
        sql_alchemy_dict["iamParams"]["integrationId"]
    )

    sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
        sql_alchemy_dict["url"], username, password
    )


def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
    """Fetch and apply federated auth credentials to connection params in-place.

    No-op when ``federatedAuthParams`` is absent. On malformed params the
    error is logged and the connection dict is left untouched, so execution
    falls back to the original connection settings.
    """
    if "federatedAuthParams" not in sql_alchemy_dict:
        return

    try:
        federated_auth_params = IntegrationFederatedAuthParams.model_validate(
            sql_alchemy_dict["federatedAuthParams"]
        )
    except ValidationError as e:
        logger.error(
            "Invalid federated auth params, try updating toolkit version:", exc_info=e
        )
        return

    federated_auth = _get_federated_auth_credentials(
        federated_auth_params.integrationId, federated_auth_params.authContextToken
    )

    if federated_auth.integrationType == "trino":
        # Build the nested structure defensively so an incomplete dict
        # cannot raise KeyError on the chained lookups.
        connect_args = sql_alchemy_dict.setdefault("params", {}).setdefault(
            "connect_args", {}
        )
        http_headers = connect_args.setdefault("http_headers", {})
        http_headers["Authorization"] = f"Bearer {federated_auth.accessToken}"
    elif federated_auth.integrationType == "big-query":
        sql_alchemy_dict.setdefault("params", {})[
            "access_token"
        ] = federated_auth.accessToken
    elif federated_auth.integrationType == "snowflake":
        logger.warning(
            "Snowflake federated auth is not supported yet, using the original connection URL"
        )
    else:
        logger.error(
            "Unsupported integration type: %s, try updating toolkit version",
            federated_auth.integrationType,
        )


@contextlib.contextmanager
def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict):
server = None
Expand Down Expand Up @@ -346,16 +439,9 @@ def _query_data_source(
):
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)

if "iamParams" in sql_alchemy_dict:
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
_handle_iam_params(sql_alchemy_dict)

temporaryUsername, temporaryPassword = _generate_temporary_credentials(
integration_id
)

sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
)
_handle_federated_auth_params(sql_alchemy_dict)

with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url:
if url is None:
Expand Down
219 changes: 219 additions & 0 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,222 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
df_cleaned.to_parquet(in_memory_file)
except: # noqa: E722
self.fail(f"serializing to parquet failed for {key}")


class TestFederatedAuth(unittest.TestCase):
    """Unit tests for `_handle_federated_auth_params` covering each supported
    integration type (trino, big-query, snowflake), the absent-params no-op,
    invalid params, and unknown integration types."""

    @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
    def test_federated_auth_params_trino(self, mock_get_credentials):
        """Test that Trino federated auth updates the Authorization header with Bearer token."""
        from deepnote_toolkit.sql.sql_execution import (
            FederatedAuthResponseData,
            _handle_federated_auth_params,
        )

        # Setup mock to return Trino credentials
        mock_get_credentials.return_value = FederatedAuthResponseData(
            integrationType="trino",
            accessToken="test-trino-access-token",
        )

        # Create a sql_alchemy_dict with federatedAuthParams and the expected structure
        sql_alchemy_dict = {
            "url": "trino://user@localhost:8080/catalog",
            "params": {
                "connect_args": {
                    "http_headers": {
                        "Authorization": "Bearer old-token",
                    }
                }
            },
            "federatedAuthParams": {
                "integrationId": "test-integration-id",
                "authContextToken": "test-auth-context-token",
            },
        }

        # Call the function
        _handle_federated_auth_params(sql_alchemy_dict)

        # Verify the API was called with correct params
        mock_get_credentials.assert_called_once_with(
            "test-integration-id", "test-auth-context-token"
        )

        # Verify the Authorization header was updated with the new token
        self.assertEqual(
            sql_alchemy_dict["params"]["connect_args"]["http_headers"]["Authorization"],
            "Bearer test-trino-access-token",
        )

    @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
    def test_federated_auth_params_bigquery(self, mock_get_credentials):
        """Test that BigQuery federated auth updates the access_token in params."""
        from deepnote_toolkit.sql.sql_execution import (
            FederatedAuthResponseData,
            _handle_federated_auth_params,
        )

        # Setup mock to return BigQuery credentials
        mock_get_credentials.return_value = FederatedAuthResponseData(
            integrationType="big-query",
            accessToken="test-bigquery-access-token",
        )

        # Create a sql_alchemy_dict with federatedAuthParams
        sql_alchemy_dict = {
            "url": "bigquery://?user_supplied_client=true",
            "params": {
                "access_token": "old-access-token",
                "project": "test-project",
            },
            "federatedAuthParams": {
                "integrationId": "test-bigquery-integration-id",
                "authContextToken": "test-bigquery-auth-context-token",
            },
        }

        # Call the function
        _handle_federated_auth_params(sql_alchemy_dict)

        # Verify the API was called with correct params
        mock_get_credentials.assert_called_once_with(
            "test-bigquery-integration-id", "test-bigquery-auth-context-token"
        )

        # Verify the access_token was updated with the new token
        self.assertEqual(
            sql_alchemy_dict["params"]["access_token"],
            "test-bigquery-access-token",
        )

    @mock.patch("deepnote_toolkit.sql.sql_execution.logger")
    @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
    def test_federated_auth_params_snowflake(self, mock_get_credentials, mock_logger):
        """Test that Snowflake federated auth logs a warning since it's not supported yet."""
        from deepnote_toolkit.sql.sql_execution import (
            FederatedAuthResponseData,
            _handle_federated_auth_params,
        )

        # Setup mock to return Snowflake credentials
        mock_get_credentials.return_value = FederatedAuthResponseData(
            integrationType="snowflake",
            accessToken="test-snowflake-access-token",
        )

        # Create a sql_alchemy_dict with federatedAuthParams
        sql_alchemy_dict = {
            "url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
            "params": {},
            "federatedAuthParams": {
                "integrationId": "test-snowflake-integration-id",
                "authContextToken": "test-snowflake-auth-context-token",
            },
        }

        # Store original params to verify they remain unchanged
        original_params = sql_alchemy_dict["params"].copy()

        # Call the function
        _handle_federated_auth_params(sql_alchemy_dict)

        # Verify the API was called with correct params
        mock_get_credentials.assert_called_once_with(
            "test-snowflake-integration-id", "test-snowflake-auth-context-token"
        )

        # Verify a warning was logged
        mock_logger.warning.assert_called_once_with(
            "Snowflake federated auth is not supported yet, using the original connection URL"
        )

        # Verify params were NOT modified (snowflake is not supported yet)
        self.assertEqual(sql_alchemy_dict["params"], original_params)

    def test_federated_auth_params_not_present(self):
        """Test that no action is taken when federatedAuthParams is not present."""
        from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params

        # Create a sql_alchemy_dict without federatedAuthParams
        sql_alchemy_dict = {
            "url": "trino://user@localhost:8080/catalog",
            "params": {
                "connect_args": {
                    "http_headers": {"Authorization": "Bearer original-token"}
                }
            },
        }

        # Deep-copy via JSON round-trip so nested dicts are compared by value
        original_dict = json.loads(json.dumps(sql_alchemy_dict))

        # Call the function
        _handle_federated_auth_params(sql_alchemy_dict)

        # Verify the dict was not modified
        self.assertEqual(sql_alchemy_dict, original_dict)

    @mock.patch("deepnote_toolkit.sql.sql_execution.logger")
    def test_federated_auth_params_invalid_params(self, mock_logger):
        """Test that invalid federated auth params logs an error and returns early."""
        from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params

        # Create a sql_alchemy_dict with invalid federatedAuthParams (missing required fields)
        sql_alchemy_dict = {
            "url": "trino://user@localhost:8080/catalog",
            "params": {},
            "federatedAuthParams": {
                "invalidField": "value",
            },
        }

        # Deep-copy via JSON round-trip so nested dicts are compared by value
        original_dict = json.loads(json.dumps(sql_alchemy_dict))

        # Call the function
        _handle_federated_auth_params(sql_alchemy_dict)

        # Verify an error was logged (substring match keeps the assertion
        # robust to the exact log-message punctuation)
        mock_logger.error.assert_called_once()
        call_args = mock_logger.error.call_args
        self.assertIn("Invalid federated auth params", call_args[0][0])

        self.assertEqual(sql_alchemy_dict, original_dict)

    @mock.patch("deepnote_toolkit.sql.sql_execution.logger")
    @mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
    def test_federated_auth_params_unsupported_integration_type(
        self, mock_get_credentials, mock_logger
    ):
        """Test that unsupported integration type logs an error."""
        from deepnote_toolkit.sql.sql_execution import (
            FederatedAuthResponseData,
            _handle_federated_auth_params,
        )

        # Setup mock to return unknown integration type
        mock_get_credentials.return_value = FederatedAuthResponseData(
            integrationType="unknown-database",
            accessToken="test-token",
        )

        # Create a sql_alchemy_dict with federatedAuthParams
        sql_alchemy_dict = {
            "url": "unknown://host/db",
            "params": {},
            "federatedAuthParams": {
                "integrationId": "test-integration-id",
                "authContextToken": "test-auth-context-token",
            },
        }

        # Deep-copy via JSON round-trip so nested dicts are compared by value
        original_dict = json.loads(json.dumps(sql_alchemy_dict))

        # Call the function
        _handle_federated_auth_params(sql_alchemy_dict)

        # Verify an error was logged for unsupported integration type
        mock_logger.error.assert_called_once_with(
            "Unsupported integration type: %s, try updating toolkit version",
            "unknown-database",
        )

        self.assertEqual(sql_alchemy_dict, original_dict)
Loading