logger = logging.getLogger(__name__)


class IntegrationFederatedAuthParams(BaseModel):
    """Shape of the ``federatedAuthParams`` entry of a connection dict."""

    # Integration identifier; becomes the last path segment of the
    # federated-auth-token endpoint URL.
    integrationId: str
    # Token sent in the ``UserPodAuthContextToken`` request header.
    authContextToken: str


class FederatedAuthResponseData(BaseModel):
    """Validated JSON body returned by the federated-auth-token endpoint."""

    # Decides how the token is applied: "trino", "big-query" or "snowflake".
    integrationType: str
    # Token written into the connection parameters.
    accessToken: str


def _get_federated_auth_credentials(
    integration_id: str, user_pod_auth_context_token: str
) -> FederatedAuthResponseData:
    """Get federated auth credentials for the given integration ID and user pod auth context token.

    Args:
        integration_id: Integration whose token is requested; interpolated
            into the endpoint path.
        user_pod_auth_context_token: Value sent in the
            ``UserPodAuthContextToken`` header.

    Returns:
        The response body validated as ``FederatedAuthResponseData``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status
            (via ``raise_for_status``).
        pydantic.ValidationError: If the response body does not match
            ``FederatedAuthResponseData``.
    """
    url = get_absolute_userpod_api_url(
        f"integrations/federated-auth-token/{integration_id}"
    )

    # Add project credentials in detached mode
    headers = get_project_auth_headers()
    headers["UserPodAuthContextToken"] = user_pod_auth_context_token

    response = requests.post(url, timeout=10, headers=headers)
    response.raise_for_status()

    return FederatedAuthResponseData.model_validate(response.json())


def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
    """Apply IAM credentials to the connection URL in-place.

    Does nothing when ``sql_alchemy_dict`` has no ``"iamParams"`` key.
    Otherwise temporary credentials are generated for
    ``iamParams["integrationId"]`` and substituted into
    ``sql_alchemy_dict["url"]``.
    """
    if "iamParams" not in sql_alchemy_dict:
        return

    integration_id = sql_alchemy_dict["iamParams"]["integrationId"]

    temporary_username, temporary_password = _generate_temporary_credentials(
        integration_id
    )

    sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
        sql_alchemy_dict["url"], temporary_username, temporary_password
    )


def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
    """Fetch and apply federated auth credentials to connection params in-place.

    Behaviour by case:

    * no ``"federatedAuthParams"`` key: nothing happens;
    * params fail validation: an error is logged and the dict is left as is;
    * ``trino``: ``params.connect_args.http_headers["Authorization"]`` is set
      to ``Bearer <token>``;
    * ``big-query``: ``params["access_token"]`` is set to the token;
    * ``snowflake``: a warning is logged, nothing is modified;
    * any other type: an error is logged, nothing is modified.

    Exceptions raised by ``_get_federated_auth_credentials`` are not caught
    here and propagate to the caller.
    """
    if "federatedAuthParams" not in sql_alchemy_dict:
        return

    try:
        federated_auth_params = IntegrationFederatedAuthParams.model_validate(
            sql_alchemy_dict["federatedAuthParams"]
        )
    except ValidationError as e:
        # NOTE(review): pydantic validation errors can embed the offending
        # input in their message -- confirm the auth context token cannot end
        # up in the logs through ``exc_info``.
        logger.error(
            "Invalid federated auth params, try updating toolkit version:", exc_info=e
        )
        return

    federated_auth = _get_federated_auth_credentials(
        federated_auth_params.integrationId, federated_auth_params.authContextToken
    )

    if federated_auth.integrationType == "trino":
        # Create any missing intermediate dicts so that a connection dict
        # without pre-existing headers still receives the token instead of
        # failing with KeyError.
        params = sql_alchemy_dict.setdefault("params", {})
        connect_args = params.setdefault("connect_args", {})
        http_headers = connect_args.setdefault("http_headers", {})
        http_headers["Authorization"] = f"Bearer {federated_auth.accessToken}"
    elif federated_auth.integrationType == "big-query":
        # Same reasoning as above: tolerate a missing "params" container.
        params = sql_alchemy_dict.setdefault("params", {})
        params["access_token"] = federated_auth.accessToken
    elif federated_auth.integrationType == "snowflake":
        logger.warning(
            "Snowflake federated auth is not supported yet, using the original connection URL"
        )
    else:
        logger.error(
            "Unsupported integration type: %s, try updating toolkit version",
            federated_auth.integrationType,
        )
class TestFederatedAuth(unittest.TestCase):
    """Unit tests for ``_handle_federated_auth_params``."""

    # Patch targets, spelled once so the decorators below stay short.
    _CREDENTIALS_FN = (
        "deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials"
    )
    _MODULE_LOGGER = "deepnote_toolkit.sql.sql_execution.logger"

    @staticmethod
    def _token_response(integration_type, access_token):
        """Build the model that the patched credentials fetcher returns."""
        from deepnote_toolkit.sql.sql_execution import FederatedAuthResponseData

        return FederatedAuthResponseData(
            integrationType=integration_type,
            accessToken=access_token,
        )

    @staticmethod
    def _apply(connection):
        """Run the function under test against ``connection``."""
        from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params

        _handle_federated_auth_params(connection)

    @staticmethod
    def _snapshot(connection):
        """Return an independent copy of ``connection`` via a JSON round trip."""
        return json.loads(json.dumps(connection))

    @mock.patch(_CREDENTIALS_FN)
    def test_federated_auth_params_trino(self, mock_get_credentials):
        """A Trino token overwrites the Authorization header as a Bearer value."""
        mock_get_credentials.return_value = self._token_response(
            "trino", "test-trino-access-token"
        )
        connection = {
            "url": "trino://user@localhost:8080/catalog",
            "params": {
                "connect_args": {
                    "http_headers": {"Authorization": "Bearer old-token"},
                },
            },
            "federatedAuthParams": {
                "integrationId": "test-integration-id",
                "authContextToken": "test-auth-context-token",
            },
        }

        self._apply(connection)

        # The fetcher receives the integration id and the context token.
        mock_get_credentials.assert_called_once_with(
            "test-integration-id", "test-auth-context-token"
        )
        # The old header value is gone, replaced by the fresh token.
        self.assertEqual(
            connection["params"]["connect_args"]["http_headers"]["Authorization"],
            "Bearer test-trino-access-token",
        )

    @mock.patch(_CREDENTIALS_FN)
    def test_federated_auth_params_bigquery(self, mock_get_credentials):
        """A BigQuery token overwrites ``params["access_token"]``."""
        mock_get_credentials.return_value = self._token_response(
            "big-query", "test-bigquery-access-token"
        )
        connection = {
            "url": "bigquery://?user_supplied_client=true",
            "params": {
                "access_token": "old-access-token",
                "project": "test-project",
            },
            "federatedAuthParams": {
                "integrationId": "test-bigquery-integration-id",
                "authContextToken": "test-bigquery-auth-context-token",
            },
        }

        self._apply(connection)

        mock_get_credentials.assert_called_once_with(
            "test-bigquery-integration-id", "test-bigquery-auth-context-token"
        )
        self.assertEqual(
            connection["params"]["access_token"],
            "test-bigquery-access-token",
        )

    @mock.patch(_MODULE_LOGGER)
    @mock.patch(_CREDENTIALS_FN)
    def test_federated_auth_params_snowflake(self, mock_get_credentials, mock_logger):
        """Snowflake only produces a warning; the params stay as they were."""
        mock_get_credentials.return_value = self._token_response(
            "snowflake", "test-snowflake-access-token"
        )
        connection = {
            "url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
            "params": {},
            "federatedAuthParams": {
                "integrationId": "test-snowflake-integration-id",
                "authContextToken": "test-snowflake-auth-context-token",
            },
        }
        # Shallow copy taken up front to compare against after the call.
        params_before = dict(connection["params"])

        self._apply(connection)

        mock_get_credentials.assert_called_once_with(
            "test-snowflake-integration-id", "test-snowflake-auth-context-token"
        )
        mock_logger.warning.assert_called_once_with(
            "Snowflake federated auth is not supported yet, using the original connection URL"
        )
        self.assertEqual(connection["params"], params_before)

    def test_federated_auth_params_not_present(self):
        """Without ``federatedAuthParams`` the connection dict is left alone."""
        connection = {
            "url": "trino://user@localhost:8080/catalog",
            "params": {
                "connect_args": {
                    "http_headers": {"Authorization": "Bearer original-token"},
                },
            },
        }
        expected = self._snapshot(connection)

        self._apply(connection)

        self.assertEqual(connection, expected)

    @mock.patch(_MODULE_LOGGER)
    def test_federated_auth_params_invalid_params(self, mock_logger):
        """Malformed ``federatedAuthParams`` are reported and nothing is changed."""
        connection = {
            "url": "trino://user@localhost:8080/catalog",
            "params": {},
            # Neither required field is present.
            "federatedAuthParams": {"invalidField": "value"},
        }
        expected = self._snapshot(connection)

        self._apply(connection)

        mock_logger.error.assert_called_once()
        logged_args = mock_logger.error.call_args[0]
        self.assertIn("Invalid federated auth params", logged_args[0])
        self.assertEqual(connection, expected)

    @mock.patch(_MODULE_LOGGER)
    @mock.patch(_CREDENTIALS_FN)
    def test_federated_auth_params_unsupported_integration_type(
        self, mock_get_credentials, mock_logger
    ):
        """An unknown integration type is reported and nothing is changed."""
        mock_get_credentials.return_value = self._token_response(
            "unknown-database", "test-token"
        )
        connection = {
            "url": "unknown://host/db",
            "params": {},
            "federatedAuthParams": {
                "integrationId": "test-integration-id",
                "authContextToken": "test-auth-context-token",
            },
        }
        expected = self._snapshot(connection)

        self._apply(connection)

        mock_logger.error.assert_called_once_with(
            "Unsupported integration type: %s, try updating toolkit version",
            "unknown-database",
        )
        self.assertEqual(connection, expected)