66
77import platform
88import struct
9- from typing import Tuple , Dict , Optional , Union
9+ from typing import Tuple , Dict , Optional , List
1010from mssql_python .constants import AuthType
1111
12+
1213class AADAuth :
1314 """Handles Azure Active Directory authentication"""
14-
15+
1516 @staticmethod
1617 def get_token_struct (token : str ) -> bytes :
1718 """Convert token to SQL Server compatible format"""
@@ -21,22 +22,30 @@ def get_token_struct(token: str) -> bytes:
2122 @staticmethod
2223 def get_token (auth_type : str ) -> bytes :
2324 """Get token using the specified authentication type"""
24- from azure .identity import (
25- DefaultAzureCredential ,
26- DeviceCodeCredential ,
27- InteractiveBrowserCredential
28- )
29- from azure .core .exceptions import ClientAuthenticationError
30-
25+ # Import Azure libraries inside method to support test mocking
26+ # pylint: disable=import-outside-toplevel
27+ try :
28+ from azure .identity import (
29+ DefaultAzureCredential ,
30+ DeviceCodeCredential ,
31+ InteractiveBrowserCredential ,
32+ )
33+ from azure .core .exceptions import ClientAuthenticationError
34+ except ImportError as e :
35+ raise RuntimeError (
36+ "Azure authentication libraries are not installed. "
37+ "Please install with: pip install azure-identity azure-core"
38+ ) from e
39+
3140 # Mapping of auth types to credential classes
3241 credential_map = {
3342 "default" : DefaultAzureCredential ,
3443 "devicecode" : DeviceCodeCredential ,
3544 "interactive" : InteractiveBrowserCredential ,
3645 }
37-
46+
3847 credential_class = credential_map [auth_type ]
39-
48+
4049 try :
4150 credential = credential_class ()
4251 token = credential .get_token ("https://database.windows.net/.default" ).token
@@ -50,18 +59,21 @@ def get_token(auth_type: str) -> bytes:
5059 ) from e
5160 except Exception as e :
5261 # Catch any other unexpected exceptions
53- raise RuntimeError (f"Failed to create { credential_class .__name__ } : { e } " ) from e
62+ raise RuntimeError (
63+ f"Failed to create { credential_class .__name__ } : { e } "
64+ ) from e
65+
5466
55- def process_auth_parameters (parameters : list ) -> Tuple [list , Optional [str ]]:
67+ def process_auth_parameters (parameters : List [ str ] ) -> Tuple [List [ str ] , Optional [str ]]:
5668 """
5769 Process connection parameters and extract authentication type.
58-
70+
5971 Args:
6072 parameters: List of connection string parameters
61-
73+
6274 Returns:
6375 Tuple[list, Optional[str]]: Modified parameters and authentication type
64-
76+
6577 Raises:
6678 ValueError: If an invalid authentication type is provided
6779 """
@@ -88,7 +100,7 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
88100 # Interactive authentication (browser-based); only append parameter for non-Windows
89101 if platform .system ().lower () == "windows" :
90102 auth_type = None # Let Windows handle AADInteractive natively
91-
103+
92104 elif value_lower == AuthType .DEVICE_CODE .value :
93105 # Device code authentication (for devices without browser)
94106 auth_type = "devicecode"
@@ -99,40 +111,50 @@ def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
99111
100112 return modified_parameters , auth_type
101113
102- def remove_sensitive_params (parameters : list ) -> list :
114+
115+ def remove_sensitive_params (parameters : List [str ]) -> List [str ]:
103116 """Remove sensitive parameters from connection string"""
104117 exclude_keys = [
105- "uid=" , "pwd=" , "encrypt=" , "trustservercertificate=" , "authentication="
118+ "uid=" ,
119+ "pwd=" ,
120+ "encrypt=" ,
121+ "trustservercertificate=" ,
122+ "authentication=" ,
106123 ]
107124 return [
108- param for param in parameters
125+ param
126+ for param in parameters
109127 if not any (param .lower ().startswith (exclude ) for exclude in exclude_keys )
110128 ]
111129
130+
112131def get_auth_token (auth_type : str ) -> Optional [bytes ]:
113132 """Get authentication token based on auth type"""
114133 if not auth_type :
115134 return None
116-
135+
117136 # Handle platform-specific logic for interactive auth
118137 if auth_type == "interactive" and platform .system ().lower () == "windows" :
119138 return None # Let Windows handle AADInteractive natively
120-
139+
121140 try :
122141 return AADAuth .get_token (auth_type )
123142 except (ValueError , RuntimeError ):
124143 return None
125144
126- def process_connection_string (connection_string : str ) -> Tuple [str , Optional [Dict ]]:
145+
146+ def process_connection_string (
147+ connection_string : str ,
148+ ) -> Tuple [str , Optional [Dict [int , bytes ]]]:
127149 """
128150 Process connection string and handle authentication.
129-
151+
130152 Args:
131153 connection_string: The connection string to process
132-
154+
133155 Returns:
134156 Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed
135-
157+
136158 Raises:
137159 ValueError: If the connection string is invalid or empty
138160 """
@@ -145,9 +167,9 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
145167 raise ValueError ("Connection string cannot be empty" )
146168
147169 parameters = connection_string .split (";" )
148-
170+
149171 # Validate that there's at least one valid parameter
150- if not any ('=' in param for param in parameters ):
172+ if not any ("=" in param for param in parameters ):
151173 raise ValueError ("Invalid connection string format" )
152174
153175 modified_parameters , auth_type = process_auth_parameters (parameters )
@@ -158,4 +180,4 @@ def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dic
158180 if token_struct :
159181 return ";" .join (modified_parameters ) + ";" , {1256 : token_struct }
160182
161- return ";" .join (modified_parameters ) + ";" , None
183+ return ";" .join (modified_parameters ) + ";" , None
0 commit comments