11from __future__ import annotations
2+
23import contextlib
34import re
4- from typing import Any , TypeVar , Callable , Awaitable , Iterator
5+ from typing import Any , Awaitable , Callable , Iterator , TypeVar , Union
56
67import sentry_sdk
78from sentry_sdk .consts import OP , SPANDATA
8- from sentry_sdk .integrations import _check_minimum_version , Integration , DidNotEnable
9+ from sentry_sdk .integrations import DidNotEnable , Integration , _check_minimum_version
10+ from sentry_sdk .traces import StreamedSpan
911from sentry_sdk .tracing import Span
10- from sentry_sdk .tracing_utils import add_query_source , record_sql_queries
12+ from sentry_sdk .tracing_utils import (
13+ add_query_source ,
14+ has_span_streaming_enabled ,
15+ record_sql_queries_supporting_streaming ,
16+ )
1117from sentry_sdk .utils import (
18+ capture_internal_exceptions ,
1219 ensure_integration_enabled ,
1320 parse_version ,
14- capture_internal_exceptions ,
1521)
1622
1723try :
@@ -62,7 +68,8 @@ def _normalize_query(query: str) -> str:
6268
6369def _wrap_execute (f : "Callable[..., Awaitable[T]]" ) -> "Callable[..., Awaitable[T]]" :
6470 async def _inner (* args : "Any" , ** kwargs : "Any" ) -> "T" :
65- if sentry_sdk .get_client ().get_integration (AsyncPGIntegration ) is None :
71+ client = sentry_sdk .get_client ()
72+ if client .get_integration (AsyncPGIntegration ) is None :
6673 return await f (* args , ** kwargs )
6774
6875 # Avoid recording calls to _execute twice.
@@ -73,7 +80,7 @@ async def _inner(*args: "Any", **kwargs: "Any") -> "T":
7380 return await f (* args , ** kwargs )
7481
7582 query = _normalize_query (args [1 ])
76- with record_sql_queries (
83+ with record_sql_queries_supporting_streaming (
7784 cursor = None ,
7885 query = query ,
7986 params_list = None ,
@@ -82,9 +89,13 @@ async def _inner(*args: "Any", **kwargs: "Any") -> "T":
8289 span_origin = AsyncPGIntegration .origin ,
8390 ) as span :
8491 res = await f (* args , ** kwargs )
92+ if isinstance (span , StreamedSpan ):
93+ with capture_internal_exceptions ():
94+ add_query_source (span )
8595
86- with capture_internal_exceptions ():
87- add_query_source (span )
96+ if not isinstance (span , StreamedSpan ):
97+ with capture_internal_exceptions ():
98+ add_query_source (span )
8899
89100 return res
90101
@@ -101,15 +112,16 @@ def _record(
101112 params_list : "tuple[Any, ...] | None" ,
102113 * ,
103114 executemany : bool = False ,
104- ) -> "Iterator[Span]" :
105- integration = sentry_sdk .get_client ().get_integration (AsyncPGIntegration )
115+ ) -> "Iterator[Union[Span, StreamedSpan]]" :
116+ client = sentry_sdk .get_client ()
117+ integration = client .get_integration (AsyncPGIntegration )
106118 if integration is not None and not integration ._record_params :
107119 params_list = None
108120
109121 param_style = "pyformat" if params_list else None
110122
111123 query = _normalize_query (query )
112- with record_sql_queries (
124+ with record_sql_queries_supporting_streaming (
113125 cursor = cursor ,
114126 query = query ,
115127 params_list = params_list ,
@@ -152,7 +164,6 @@ def _inner(*args: "Any", **kwargs: "Any") -> "T": # noqa: N807
152164 ) as span :
153165 _set_db_data (span , args [0 ])
154166 res = f (* args , ** kwargs )
155- span .set_data ("db.cursor" , res )
156167
157168 return res
158169
@@ -163,56 +174,85 @@ def _wrap_connect_addr(
163174 f : "Callable[..., Awaitable[T]]" ,
164175) -> "Callable[..., Awaitable[T]]" :
165176 async def _inner (* args : "Any" , ** kwargs : "Any" ) -> "T" :
166- if sentry_sdk .get_client ().get_integration (AsyncPGIntegration ) is None :
177+ client = sentry_sdk .get_client ()
178+ if client .get_integration (AsyncPGIntegration ) is None :
167179 return await f (* args , ** kwargs )
168180
169181 user = kwargs ["params" ].user
170182 database = kwargs ["params" ].database
171-
172- with sentry_sdk .start_span (
173- op = OP .DB ,
174- name = "connect" ,
175- origin = AsyncPGIntegration .origin ,
176- ) as span :
177- span .set_data (SPANDATA .DB_SYSTEM , "postgresql" )
178- addr = kwargs .get ("addr" )
183+ addr = kwargs .get ("addr" )
184+
185+ if has_span_streaming_enabled (client .options ):
186+ span_attributes = {
187+ "sentry.op" : OP .DB ,
188+ "sentry.origin" : AsyncPGIntegration .origin ,
189+ SPANDATA .DB_SYSTEM : "postgresql" ,
190+ SPANDATA .DB_USER : user ,
191+ SPANDATA .DB_NAME : database ,
192+ SPANDATA .DB_DRIVER_NAME : "asyncpg" ,
193+ }
179194 if addr :
180195 try :
181- span . set_data ( SPANDATA .SERVER_ADDRESS , addr [0 ])
182- span . set_data ( SPANDATA .SERVER_PORT , addr [1 ])
196+ span_attributes [ SPANDATA .SERVER_ADDRESS ] = addr [0 ]
197+ span_attributes [ SPANDATA .SERVER_PORT ] = addr [1 ]
183198 except IndexError :
184199 pass
185- span .set_data (SPANDATA .DB_NAME , database )
186- span .set_data (SPANDATA .DB_USER , user )
187- span .set_data (SPANDATA .DB_DRIVER_NAME , "asyncpg" )
188200
189- with capture_internal_exceptions ():
190- sentry_sdk .add_breadcrumb (
191- message = "connect" , category = "query" , data = span ._data
192- )
193- res = await f (* args , ** kwargs )
201+ with sentry_sdk .traces .start_span (name = "connect" ) as span :
202+ span .set_attributes (span_attributes )
203+
204+ with capture_internal_exceptions ():
205+ sentry_sdk .add_breadcrumb (
206+ message = "connect" , category = "query" , data = span_attributes
207+ )
208+ res = await f (* args , ** kwargs )
209+
210+ else :
211+ with sentry_sdk .start_span (
212+ op = OP .DB ,
213+ name = "connect" ,
214+ origin = AsyncPGIntegration .origin ,
215+ ) as span :
216+ span .set_data (SPANDATA .DB_SYSTEM , "postgresql" )
217+ if addr :
218+ try :
219+ span .set_data (SPANDATA .SERVER_ADDRESS , addr [0 ])
220+ span .set_data (SPANDATA .SERVER_PORT , addr [1 ])
221+ except IndexError :
222+ pass
223+ span .set_data (SPANDATA .DB_NAME , database )
224+ span .set_data (SPANDATA .DB_USER , user )
225+ span .set_data (SPANDATA .DB_DRIVER_NAME , "asyncpg" )
226+
227+ with capture_internal_exceptions ():
228+ sentry_sdk .add_breadcrumb (
229+ message = "connect" , category = "query" , data = span ._data
230+ )
231+ res = await f (* args , ** kwargs )
194232
195233 return res
196234
197235 return _inner
198236
199237
200- def _set_db_data (span : "Span" , conn : "Any" ) -> None :
201- span .set_data (SPANDATA .DB_SYSTEM , "postgresql" )
202- span .set_data (SPANDATA .DB_DRIVER_NAME , "asyncpg" )
238+ def _set_db_data (span : "Union[Span, StreamedSpan]" , conn : "Any" ) -> None :
239+ set_value = span .set_attribute if isinstance (span , StreamedSpan ) else span .set_data
240+
241+ set_value (SPANDATA .DB_SYSTEM , "postgresql" )
242+ set_value (SPANDATA .DB_DRIVER_NAME , "asyncpg" )
203243
204244 addr = conn ._addr
205245 if addr :
206246 try :
207- span . set_data (SPANDATA .SERVER_ADDRESS , addr [0 ])
208- span . set_data (SPANDATA .SERVER_PORT , addr [1 ])
247+ set_value (SPANDATA .SERVER_ADDRESS , addr [0 ])
248+ set_value (SPANDATA .SERVER_PORT , addr [1 ])
209249 except IndexError :
210250 pass
211251
212252 database = conn ._params .database
213253 if database :
214- span . set_data (SPANDATA .DB_NAME , database )
254+ set_value (SPANDATA .DB_NAME , database )
215255
216256 user = conn ._params .user
217257 if user :
218- span . set_data (SPANDATA .DB_USER , user )
258+ set_value (SPANDATA .DB_USER , user )
0 commit comments