From 29c2fda94e04097fe1e678add6c969a817e1c4fe Mon Sep 17 00:00:00 2001 From: akash-vijay-kv Date: Wed, 20 May 2026 14:14:10 +0530 Subject: [PATCH 01/13] [NET-920] feat: Add utility to explicitly set exceptions on a span (#289) --- CHANGELOG.md | 2 +- netra/__init__.py | 41 ++++++++++++++++++++++++++++++++++++++++ netra/session_manager.py | 29 ++++++++++++++++++++++++++++ netra/span_wrapper.py | 29 +++++++++++++++++++++++++--- 4 files changed, 97 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13a342c5..865c9e2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -289,4 +289,4 @@ Users can be now overwrite the input and ouput attributes of spans created by in - Added utility to set input and output data for any active span in a trace -[0.1.86]: https://github.com/KeyValueSoftwareSystems/netra-sdk-py/tree/main +[0.1.87]: https://github.com/KeyValueSoftwareSystems/netra-sdk-py/tree/main diff --git a/netra/__init__.py b/netra/__init__.py index bf4efa91..5022d24d 100644 --- a/netra/__init__.py +++ b/netra/__init__.py @@ -409,6 +409,47 @@ def set_custom_event(cls, event_name: str, attributes: Any) -> None: else: logger.warning("Both event_name and attributes must be provided for custom events.") + @classmethod + def record_exception( + cls, + exception: BaseException, + attributes: Optional[Dict[str, Any]] = None, + ) -> None: + """Record a caught exception on the currently active span. + + Use this inside ``except`` blocks to attach exception details to the + current span when the exception is being handled and will not propagate + to the SDK's automatic capture logic. + + The method adds a standard OpenTelemetry exception event (with type, + message, and stacktrace), sets the span status to ERROR, and records + the ``netra.error_message`` attribute for consistency with the rest of + the SDK. + + Args: + exception: The exception instance to record. + attributes: Optional extra attributes to attach to the exception + event. + + Example:: + + @workflow + def process_order(order_id: str) -> str: + try: + result = call_payment_api(order_id) + except PaymentError as exc: + Netra.record_exception(exc) + return "fallback_result" + return result + """ + if not isinstance(exception, BaseException): + logger.error( + "record_exception: exception must be a BaseException instance, got %s", + type(exception), + ) + return + SessionManager.record_exception(exception, attributes=attributes) + @classmethod def add_conversation(cls, conversation_type: ConversationType, role: str, content: Any) -> None: """ diff --git a/netra/session_manager.py b/netra/session_manager.py index 7eee85e2..930d07ca 100644 --- a/netra/session_manager.py +++ b/netra/session_manager.py @@ -492,6 +492,35 @@ def set_attribute_on_root_span(cls, attr_key: str, attr_value: Any) -> None: except Exception: logger.exception("Failed to set attribute '%s' on root span", attr_key) + @staticmethod + def record_exception( + exception: BaseException, + attributes: Optional[Dict[str, Any]] = None, + ) -> None: + """Record a caught exception on the currently active span. + + Adds a standard OTel exception event to the span and marks its status + as ERROR. Intended to be called from within user exception-handling + blocks where the exception would otherwise not propagate to the SDK's + automatic capture logic. + + Args: + exception: The exception instance to record. + attributes: Optional extra attributes to attach to the exception + event. + """ + try: + span = trace.get_current_span() + if not (span and getattr(span, "is_recording", lambda: False)()): + logger.warning("record_exception: no active recording span to record exception on") + return + + span.record_exception(exception, attributes=attributes) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(exception))) + span.set_attribute(f"{Config.LIBRARY_NAME}.error_message", str(exception)) + except Exception: + logger.exception("Failed to record exception on active span") + @staticmethod def set_attribute_on_active_span(attr_key: str, attr_value: Any) -> None: """ diff --git a/netra/span_wrapper.py b/netra/span_wrapper.py index edb7986b..77843a44 100644 --- a/netra/span_wrapper.py +++ b/netra/span_wrapper.py @@ -160,9 +160,12 @@ def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_t # Handle status and errors if exc_type is None and self.status == "pending": - self.status = "success" - if self.span: - self.span.set_status(Status(StatusCode.OK)) + if self._span_has_error_status(): + self.status = "error" + else: + self.status = "success" + if self.span: + self.span.set_status(Status(StatusCode.OK)) elif exc_type is not None: self.status = "error" self.error_message = str(exc_val) @@ -210,6 +213,26 @@ def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_t # Don't suppress exceptions return False + def _span_has_error_status(self) -> Any: + """Check whether the underlying OTel span has already been set to ERROR. + + This handles the case where ``Netra.record_exception()`` (or any direct + OTel API call) marked the span as errored while the exception was caught + by user code, so ``__exit__`` receives ``exc_type is None``. + + Returns: + True if the span's status code is ERROR, False otherwise. + """ + if self.span is None: + return False + try: + status = getattr(self.span, "status", None) + if status is not None: + return status.status_code == StatusCode.ERROR + except Exception: + logger.exception("Failed to check span status on span '%s'", self.name) + return False + def set_attribute(self, key: str, value: str) -> "SpanWrapper": """ Set a single attribute and return self for method chaining. From 58a93c64ad8bd546c7ab49353cb690da993e52bf Mon Sep 17 00:00:00 2001 From: akash-vijay-kv Date: Mon, 8 Jun 2026 16:42:22 +0530 Subject: [PATCH 02/13] [NET-856] feat: Add file handling support in the simulation workflow (#290) --- netra/__init__.py | 6 + netra/simulation/__init__.py | 4 + netra/simulation/api.py | 113 +++-- netra/simulation/client.py | 158 ++++--- netra/simulation/constants.py | 38 ++ netra/simulation/models.py | 46 +- netra/simulation/task.py | 49 ++- netra/simulation/utils.py | 121 +++++- poetry.lock | 83 +++- pyproject.toml | 1 + tests/test_simulation.py | 783 ++++++++++++++++++++++++++++++++++ 11 files changed, 1273 insertions(+), 129 deletions(-) create mode 100644 netra/simulation/constants.py create mode 100644 tests/test_simulation.py diff --git a/netra/__init__.py b/netra/__init__.py index 5022d24d..cb544fb0 100644 --- a/netra/__init__.py +++ b/netra/__init__.py @@ -291,6 +291,12 @@ def shutdown(cls) -> None: meter_provider.shutdown() except Exception: pass + # Close simulation HTTP client + if hasattr(cls, "simulation") and cls.simulation is not None: + try: + cls.simulation.close() + except Exception: + pass @classmethod def get_meter(cls, name: str = "netra", version: Optional[str] = None) -> otel_metrics.Meter: diff --git a/netra/simulation/__init__.py b/netra/simulation/__init__.py index 79c7cfd8..efcb9d02 100644 --- a/netra/simulation/__init__.py +++ b/netra/simulation/__init__.py @@ -2,6 +2,8 @@ from netra.simulation.models import ( ConversationResponse, ConversationStatus, + FileData, + ProcessedFile, SimulationItem, TaskResult, ) @@ -12,6 +14,8 @@ "BaseTask", "ConversationResponse", "ConversationStatus", + "FileData", + "ProcessedFile", "SimulationItem", "TaskResult", ] diff --git a/netra/simulation/api.py b/netra/simulation/api.py index dfea287e..3092fd4e 100644 --- a/netra/simulation/api.py +++ b/netra/simulation/api.py @@ -1,5 +1,3 @@ -"""Public API for running multi-turn conversation simulations.""" - import asyncio import concurrent.futures import logging @@ -8,7 +6,8 @@ from netra.config import Config from netra.simulation.client import SimulationHttpClient -from netra.simulation.models import SimulationItem +from netra.simulation.constants import DEFAULT_MAX_TURNS, LOG_PREFIX, SPAN_NAME +from netra.simulation.models import ConversationStatus, FileData, SimulationItem from netra.simulation.task import BaseTask from netra.simulation.utils import ( execute_task, @@ -20,9 +19,6 @@ logger = logging.getLogger(__name__) -_LOG_PREFIX = "netra.simulation" -_SPAN_NAME = "Netra.Simulation.TestRun" - class Simulation: """Public API for running multi-turn conversation simulations. @@ -43,6 +39,10 @@ def __init__(self, config: Config) -> None: self._config = config self._client = SimulationHttpClient(config) + def close(self) -> None: + """Release resources held by the simulation client.""" + self._client.close() + def run_simulation( self, name: str, @@ -50,16 +50,18 @@ def run_simulation( task: BaseTask, context: Optional[dict[str, Any]] = None, max_concurrency: int = 5, + max_turns: int = DEFAULT_MAX_TURNS, ) -> Optional[dict[str, Any]]: """Run a multi-turn conversation simulation. Args: name: Name of the simulation run. dataset_id: Identifier of the dataset to simulate. - task: A BaseTask instance whose run() method receives (message, session_id) + task: A BaseTask instance whose run() method receives (message, session_id, files) and returns TaskResult. Can be sync or async. context: Optional context data for the simulation. max_concurrency: Maximum parallel executions (default: 5). + max_turns: Maximum conversation turns per item before aborting (default: 50). Returns: Dictionary with simulation results, or None on failure. @@ -77,51 +79,60 @@ def run_simulation( return None run_id = run_result.get("run_id") - run_items = run_result.get("simulation_items") - if not run_items: - logger.error("%s: No items returned from create_run", _LOG_PREFIX) + simulation_items = run_result.get("simulation_items") + if not simulation_items: + logger.error("%s: No items returned from create_run", LOG_PREFIX) return None - logger.info("%s: Starting simulation with %d items", _LOG_PREFIX, len(run_items)) + logger.info("%s: Starting simulation with %d items", LOG_PREFIX, len(simulation_items)) try: result = run_async_safely( - self._run_simulation_async(run_id, run_items, task, max_concurrency) # type:ignore[arg-type] + self._run_simulation_async( + run_id, simulation_items, task, max_concurrency, max_turns # type:ignore[arg-type] + ) ) elapsed_time = time.time() - start_time - logger.info("%s: Simulation completed in %.2f seconds", _LOG_PREFIX, elapsed_time) + logger.info("%s: Simulation completed in %.2f seconds", LOG_PREFIX, elapsed_time) self._client.post_run_status(run_id, "completed") # type:ignore[arg-type] return result - except BaseException: - logger.error("%s: Run simulation failed", _LOG_PREFIX) + except Exception: + logger.error("%s: Run simulation failed", LOG_PREFIX, exc_info=True) self._client.post_run_status(run_id, "failed") # type:ignore[arg-type] return None async def _run_simulation_async( self, run_id: str, - run_items: list[SimulationItem], + simulation_items: list[SimulationItem], task: BaseTask, max_concurrency: int, + max_turns: int, ) -> dict[str, Any]: - """Async implementation of run_simulation with semaphore-based concurrency. + """Orchestrate concurrent simulation execution. + + Each simulation item is dispatched to a thread via ``run_in_executor``. + Inside each thread, ``run_async_safely`` creates a **new** event loop + so that async user tasks (``BaseTask.run``) work correctly without + nesting into the orchestrator's loop. This two-level design lets us + honour ``max_concurrency`` while supporting both sync and async tasks + transparently. Args: run_id: The simulation run identifier. - run_items: List of simulation items to process. + simulation_items: List of simulation items to process. task: The BaseTask instance to execute (sync or async). max_concurrency: Maximum concurrent executions. + max_turns: Maximum conversation turns per item. Returns: Dictionary with simulation results. """ - - max_workers = min(5, max_concurrency) results: dict[str, Any] = { "success": True, "completed": [], "failed": [], - "total_items": len(run_items), + "total_items": len(simulation_items), } processed_count = 0 lock = asyncio.Lock() @@ -129,8 +140,7 @@ async def _run_simulation_async( loop = asyncio.get_running_loop() def run_item_in_thread(run_item: SimulationItem) -> dict[str, Any]: - """ - Run a single simulation item in a thread. + """Run a single simulation item in a dedicated thread/event-loop. Args: run_item: The simulation item to run. @@ -138,11 +148,10 @@ def run_item_in_thread(run_item: SimulationItem) -> dict[str, Any]: Returns: Dictionary with simulation result. """ - return run_async_safely(self._execute_conversation(run_id, run_item, task)) + return run_async_safely(self._execute_conversation(run_id, run_item, task, max_turns)) async def process_item(run_item: SimulationItem) -> None: - """ - Process a single simulation item and handle its completion. + """Process a single simulation item and record its outcome. Args: run_item: The simulation item to process. @@ -155,14 +164,14 @@ async def process_item(run_item: SimulationItem) -> None: processed_count += 1 logger.info( "%s: %d/%d processed (run_item_id=%s)", - _LOG_PREFIX, + LOG_PREFIX, processed_count, - len(run_items), + len(simulation_items), run_item.run_item_id, ) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - tasks = [asyncio.create_task(process_item(run_item)) for run_item in run_items] + with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor: + tasks = [asyncio.create_task(process_item(item)) for item in simulation_items] try: await asyncio.gather(*tasks) except (asyncio.CancelledError, KeyboardInterrupt): @@ -172,7 +181,7 @@ async def process_item(run_item: SimulationItem) -> None: executor.shutdown(wait=False, cancel_futures=True) logger.info( "%s: Completed=%d, Failed=%d", - _LOG_PREFIX, + LOG_PREFIX, len(results["completed"]), len(results["failed"]), ) @@ -183,13 +192,15 @@ async def _execute_conversation( run_id: str, run_item: SimulationItem, task: BaseTask, - ) -> Any: + max_turns: int, + ) -> dict[str, Any]: """Execute a multi-turn conversation for a single simulation item. Args: run_id: The simulation run identifier. run_item: The simulation item to process. task: The BaseTask instance to execute (sync or async). + max_turns: Safety limit on the number of conversation turns. Returns: Dictionary with execution result including success status. @@ -197,27 +208,30 @@ async def _execute_conversation( run_item_id = run_item.run_item_id message = run_item.message turn_id = run_item.turn_id + raw_files: list[FileData] = run_item.files session_id: Optional[str] = None - while True: + for turn_number in range(1, max_turns + 1): try: - with SpanWrapper(_SPAN_NAME, module_name=_LOG_PREFIX) as span: + with SpanWrapper(SPAN_NAME, module_name=LOG_PREFIX) as span: trace_id = "" otel_span = span.get_current_span() if otel_span: span_context = otel_span.get_span_context() trace_id = format_trace_id(span_context.trace_id) - response_message, task_session_id = await execute_task(task, message, session_id) + response_message, task_session_id = await execute_task( + task, message, session_id, raw_files=raw_files + ) if task_session_id: session_id = task_session_id - response = self._client.trigger_conversation( - message=response_message, - turn_id=turn_id, - session_id=session_id or "", - trace_id=trace_id, - ) + response = self._client.trigger_conversation( + message=response_message, + turn_id=turn_id, + session_id=session_id or "", + trace_id=trace_id, + ) if response is None: error_msg = "Failed to get conversation response" @@ -228,10 +242,10 @@ async def _execute_conversation( "turn_id": turn_id, } - if response.decision == "stop": + if response.decision == ConversationStatus.STOP: logger.info( "%s: Completed run_item_id=%s reason=%s", - _LOG_PREFIX, + LOG_PREFIX, run_item_id, response.reason, ) @@ -243,12 +257,13 @@ async def _execute_conversation( message = response.next_user_message # type:ignore[assignment] turn_id = response.next_turn_id # type:ignore[assignment] + raw_files = response.next_files except Exception as exc: error_msg = str(exc) logger.error( "%s: Task failed run_item_id=%s, turn_id=%s: %s", - _LOG_PREFIX, + LOG_PREFIX, run_item_id, turn_id, error_msg, @@ -260,3 +275,13 @@ async def _execute_conversation( "error": error_msg, "turn_id": turn_id, } + + error_msg = f"Exceeded maximum turns ({max_turns}) for run_item_id={run_item_id}" + logger.error("%s: %s", LOG_PREFIX, error_msg) + self._client.report_failure(run_id=run_id, run_item_id=run_item_id, error=error_msg) + return { + "run_item_id": run_item_id, + "success": False, + "error": error_msg, + "turn_id": turn_id, + } diff --git a/netra/simulation/client.py b/netra/simulation/client.py index d4951854..2a2cf17e 100644 --- a/netra/simulation/client.py +++ b/netra/simulation/client.py @@ -1,19 +1,24 @@ -"""HTTP client for simulation API endpoints.""" - import logging -import os from typing import Any, Optional import httpx from netra.config import Config -from netra.simulation.models import ConversationResponse, SimulationItem +from netra.simulation.constants import ( + DEFAULT_TIMEOUT, + ENV_TIMEOUT, + LOG_PREFIX, + TELEMETRY_SUFFIX, + URL_AGENT_RESPONSE, + URL_CREATE_RUN, + URL_RUN_ITEM_STATUS, + URL_RUN_STATUS, +) +from netra.simulation.models import ConversationResponse, ConversationStatus, FileData, SimulationItem +from netra.simulation.utils import parse_env_float logger = logging.getLogger(__name__) -_DEFAULT_TIMEOUT = 10.0 -_LOG_PREFIX = "netra.simulation" - class SimulationHttpClient: """Internal HTTP client for simulation API endpoints. @@ -32,6 +37,26 @@ def __init__(self, config: Config) -> None: """ self._client: Optional[httpx.Client] = self._create_client(config) + def close(self) -> None: + """Close the underlying HTTP client and release connection resources.""" + if self._client: + try: + self._client.close() + except Exception: + logger.debug("%s: Error closing HTTP client", LOG_PREFIX, exc_info=True) + finally: + self._client = None + + def _ensure_client(self) -> Optional[httpx.Client]: + """Return the underlying client, logging an error if it is not initialized. + + Returns: + The httpx client, or None if not available. + """ + if not self._client: + logger.error("%s: Client not initialized", LOG_PREFIX) + return self._client + def _create_client(self, config: Config) -> Optional[httpx.Client]: """Create and configure the HTTP client. @@ -43,17 +68,17 @@ def _create_client(self, config: Config) -> Optional[httpx.Client]: """ endpoint = (config.otlp_endpoint or "").strip() if not endpoint: - logger.error("%s: NETRA_OTLP_ENDPOINT is required", _LOG_PREFIX) + logger.error("%s: NETRA_OTLP_ENDPOINT is required", LOG_PREFIX) return None base_url = self._resolve_base_url(endpoint) headers = self._build_headers(config) - timeout = self._get_timeout() + timeout = parse_env_float(ENV_TIMEOUT, DEFAULT_TIMEOUT) try: return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) except Exception as exc: - logger.error("%s: Failed to create HTTP client: %s", _LOG_PREFIX, exc) + logger.error("%s: Failed to create HTTP client: %s", LOG_PREFIX, exc) return None def _resolve_base_url(self, endpoint: str) -> str: @@ -66,8 +91,8 @@ def _resolve_base_url(self, endpoint: str) -> str: The cleaned base URL. """ base_url = endpoint.rstrip("/") - if base_url.endswith("/telemetry"): - base_url = base_url[: -len("/telemetry")] + if base_url.endswith(TELEMETRY_SUFFIX): + base_url = base_url[: -len(TELEMETRY_SUFFIX)] return base_url def _build_headers(self, config: Config) -> dict[str, str]: @@ -84,26 +109,6 @@ def _build_headers(self, config: Config) -> dict[str, str]: headers["x-api-key"] = config.api_key return headers - def _get_timeout(self) -> float: - """Get timeout from environment or use default. - - Returns: - The timeout value in seconds. - """ - timeout_str = os.getenv("NETRA_SIMULATION_TIMEOUT") - if not timeout_str: - return _DEFAULT_TIMEOUT - try: - return float(timeout_str) - except ValueError: - logger.warning( - "%s: Invalid timeout '%s', using default %.1f", - _LOG_PREFIX, - timeout_str, - _DEFAULT_TIMEOUT, - ) - return _DEFAULT_TIMEOUT - def create_run( self, name: str, @@ -120,26 +125,25 @@ def create_run( Returns: Dictionary containing run_id and simulation_items, or None on failure. """ - if not self._client: - logger.error("%s: Client not initialized", _LOG_PREFIX) + if not self._ensure_client(): return None response: Optional[httpx.Response] = None try: - url = "/evaluations/test_run/multi-turn" + url = URL_CREATE_RUN payload: dict[str, Any] = { "name": name, "datasetId": dataset_id, "context": context or {}, } - response = self._client.post(url, json=payload, timeout=500) + response = self._client.post(url, json=payload) # type:ignore[union-attr] response.raise_for_status() data = response.json() response_data = data.get("data", {}) user_messages = response_data.get("userMessages", []) if not user_messages: - logger.warning("%s: No user messages returned from create_run", _LOG_PREFIX) + logger.warning("%s: No user messages returned from create_run", LOG_PREFIX) return None run_id = response_data.get("id", "") @@ -148,6 +152,7 @@ def create_run( run_item_id=msg.get("testRunItemId", ""), message=msg.get("userMessage", ""), turn_id=msg.get("turnId", ""), + files=self._parse_files(msg.get("attachments")), ) for msg in user_messages ] @@ -158,7 +163,7 @@ def create_run( except Exception as exc: error_msg = self._extract_error_message(response, exc) - logger.error("%s: Failed to create simulation run: %s", _LOG_PREFIX, error_msg) + logger.error("%s: Failed to create simulation run: %s", LOG_PREFIX, error_msg) return None def trigger_conversation( @@ -179,13 +184,12 @@ def trigger_conversation( Returns: ConversationResponse with next turn info, or None on failure. """ - if not self._client: - logger.error("%s: Client not initialized", _LOG_PREFIX) + if not self._ensure_client(): return None response: Optional[httpx.Response] = None try: - url = "/evaluations/turn/agent-response" + url = URL_AGENT_RESPONSE payload: dict[str, Any] = { "turnId": turn_id, "agentResponse": {"message": message}, @@ -193,14 +197,15 @@ def trigger_conversation( "traceId": trace_id, } - response = self._client.post(url, json=payload, timeout=500) + response = self._client.post(url, json=payload) # type:ignore[union-attr] response.raise_for_status() data = response.json() response_data = data.get("data", {}) - decision = response_data.get("decision", "continue") + raw_decision = response_data.get("decision", "continue") + decision = ConversationStatus(raw_decision) - if decision == "stop": + if decision == ConversationStatus.STOP: return ConversationResponse( decision=decision, reason=response_data.get("reason", ""), @@ -208,7 +213,7 @@ def trigger_conversation( user_messages = response_data.get("userMessages", []) if not user_messages: - logger.warning("%s: No user messages in continue response", _LOG_PREFIX) + logger.warning("%s: No user messages in continue response", LOG_PREFIX) return None next_msg = next(iter(user_messages)) @@ -216,12 +221,12 @@ def trigger_conversation( decision=decision, next_turn_id=next_msg.get("turnId", ""), next_user_message=next_msg.get("userMessage", ""), - next_run_item_id=next_msg.get("testRunItemId", ""), + next_files=self._parse_files(next_msg.get("attachments")), ) except Exception as exc: error_msg = self._extract_error_message(response, exc) - logger.error("%s: Failed to trigger conversation: %s", _LOG_PREFIX, error_msg) + logger.error("%s: Failed to trigger conversation: %s", LOG_PREFIX, error_msg) raise def report_failure(self, run_id: str, run_item_id: str, error: str) -> None: @@ -232,20 +237,19 @@ def report_failure(self, run_id: str, run_item_id: str, error: str) -> None: run_item_id: Identifier of the run item. error: Error message describing the failure. """ - if not self._client: - logger.error("%s: Client not initialized", _LOG_PREFIX) + if not self._ensure_client(): return response: Optional[httpx.Response] = None try: - url = f"/evaluations/run/{run_id}/item/{run_item_id}/status" + url = URL_RUN_ITEM_STATUS.format(run_id=run_id, run_item_id=run_item_id) payload: dict[str, Any] = {"status": "failed", "failureReason": error} - response = self._client.patch(url, json=payload) + response = self._client.patch(url, json=payload) # type:ignore[union-attr] response.raise_for_status() - logger.info("%s: Reported failure - %s", _LOG_PREFIX, error) + logger.info("%s: Reported failure - %s", LOG_PREFIX, error) except Exception as exc: error_msg = self._extract_error_message(response, exc) - logger.error("%s: Failed to report failure: %s", _LOG_PREFIX, error_msg) + logger.error("%s: Failed to report failure: %s", LOG_PREFIX, error_msg) def post_run_status(self, run_id: str, status: str) -> Any: """Submit the run status. @@ -257,26 +261,60 @@ def post_run_status(self, run_id: str, status: str) -> Any: Returns: Backend JSON response containing confirmation, or error dict. """ - if not self._client: - logger.error("%s: Client not initialized; cannot post run status", _LOG_PREFIX) + if not self._ensure_client(): return {"success": False} response: Optional[httpx.Response] = None try: - url = f"/evaluations/run/{run_id}/status" + url = URL_RUN_STATUS.format(run_id=run_id) payload: dict[str, Any] = {"status": status} - response = self._client.post(url, json=payload) + response = self._client.post(url, json=payload) # type:ignore[union-attr] response.raise_for_status() data = response.json() if isinstance(data, dict) and "data" in data: - logger.info("%s: Test run status %s", _LOG_PREFIX, status) + logger.info("%s: Test run status %s", LOG_PREFIX, status) return data.get("data", {}) return data except Exception as exc: error_msg = self._extract_error_message(response, exc) - logger.error("%s: Failed to post run status for run '%s': %s", _LOG_PREFIX, run_id, error_msg) + logger.error("%s: Failed to post run status for run '%s': %s", LOG_PREFIX, run_id, error_msg) return {"success": False} + @staticmethod + def _parse_files(raw_files: list[dict[str, str]] | None) -> list[FileData]: + """Parse raw file entries from the backend response into FileData objects. + + Args: + raw_files: List of file dictionaries from the JSON response, or None. + + Returns: + List of FileData objects. Malformed entries are skipped. + """ + if not raw_files or not isinstance(raw_files, list): + return [] + + parsed: list[FileData] = [] + for entry in raw_files: + if not isinstance(entry, dict): + continue + file_name = entry.get("fileName", "") + download_url = entry.get("downloadUrl", "") + if not file_name or not download_url: + logger.warning( + "%s: Skipping file entry with missing fileName or downloadUrl", + LOG_PREFIX, + ) + continue + parsed.append( + FileData( + file_name=file_name, + content_type=entry.get("contentType", ""), + description=entry.get("description"), + download_url=download_url, + ) + ) + return parsed + def _extract_error_message( self, response: Optional[httpx.Response], @@ -298,5 +336,5 @@ def _extract_error_message( if isinstance(error_data, dict): return error_data.get("message", str(exc)) except Exception: - pass + logger.debug("%s: Could not parse error from response body", LOG_PREFIX, exc_info=True) return str(exc) diff --git a/netra/simulation/constants.py b/netra/simulation/constants.py new file mode 100644 index 00000000..20579473 --- /dev/null +++ b/netra/simulation/constants.py @@ -0,0 +1,38 @@ +"""Shared constants for the simulation module.""" + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- +LOG_PREFIX = "netra.simulation" + +# --------------------------------------------------------------------------- +# Span / tracing +# --------------------------------------------------------------------------- +SPAN_NAME = "Netra.Simulation.TestRun" + +# --------------------------------------------------------------------------- +# Conversation limits +# --------------------------------------------------------------------------- +DEFAULT_MAX_TURNS = 50 + +# --------------------------------------------------------------------------- +# API endpoints +# --------------------------------------------------------------------------- +URL_CREATE_RUN = "/evaluations/test_run/multi-turn" +URL_AGENT_RESPONSE = "/evaluations/turn/agent-response" +URL_RUN_ITEM_STATUS = "/evaluations/run/{run_id}/item/{run_item_id}/status" +URL_RUN_STATUS = "/evaluations/run/{run_id}/status" +TELEMETRY_SUFFIX = "/telemetry" + +# --------------------------------------------------------------------------- +# HTTP client timeouts +# --------------------------------------------------------------------------- +DEFAULT_TIMEOUT = 500.0 +ENV_TIMEOUT = "NETRA_SIMULATION_TIMEOUT" + +# --------------------------------------------------------------------------- +# File download +# --------------------------------------------------------------------------- +DEFAULT_FILE_DOWNLOAD_TIMEOUT = 30.0 +ENV_FILE_DOWNLOAD_TIMEOUT = "NETRA_SIMULATION_FILE_DOWNLOAD_TIMEOUT" +MAX_FILE_DOWNLOAD_WORKERS = 8 diff --git a/netra/simulation/models.py b/netra/simulation/models.py index 258f6554..26b68e90 100644 --- a/netra/simulation/models.py +++ b/netra/simulation/models.py @@ -1,6 +1,6 @@ """Data models for the simulation module.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Optional @@ -12,6 +12,40 @@ class ConversationStatus(Enum): STOP = "stop" +@dataclass(slots=True, frozen=True) +class FileData: + """Raw file metadata received from the backend. + + Attributes: + file_name: Name of the file. + content_type: MIME type of the file content. + description: Optional description of the file. + download_url: Pre-signed URL to download the file. + """ + + file_name: str + content_type: str + description: Optional[str] + download_url: str + + +@dataclass(slots=True, frozen=True) +class ProcessedFile: + """File after download and base64 encoding, delivered to the user task. + + Attributes: + file_name: Name of the file. + content_type: MIME type of the file content. + description: Optional description of the file. + data: Base64-encoded file content. + """ + + file_name: str + content_type: str + description: Optional[str] + data: str + + @dataclass(slots=True, frozen=True) class SimulationItem: """Represents a single item in a simulation run. @@ -20,11 +54,13 @@ class SimulationItem: run_item_id: Unique identifier for the run item. message: The user message content. turn_id: Identifier for the conversation turn. + files: File metadata attached to this item. """ run_item_id: str message: str turn_id: str + files: list[FileData] = field(default_factory=list) @dataclass(slots=True) @@ -32,18 +68,18 @@ class ConversationResponse: """Response from the conversation trigger API. Attributes: - decision: The decision to continue or stop the conversation. + decision: Whether to continue or stop the conversation. reason: Optional reason for stopping the conversation. next_turn_id: Identifier for the next turn if continuing. next_user_message: The next user message if continuing. - next_run_item_id: Identifier for the next run item if continuing. + next_files: File metadata for the next turn if continuing. """ - decision: str + decision: ConversationStatus reason: Optional[str] = None next_turn_id: Optional[str] = None next_user_message: Optional[str] = None - next_run_item_id: Optional[str] = None + next_files: list[FileData] = field(default_factory=list) @dataclass(slots=True, frozen=True) diff --git a/netra/simulation/task.py b/netra/simulation/task.py index bdfc39b7..8ecb40f4 100644 --- a/netra/simulation/task.py +++ b/netra/simulation/task.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from typing import Awaitable, Optional -from netra.simulation.models import TaskResult +from netra.simulation.models import ProcessedFile, TaskResult class BaseTask(ABC): @@ -18,13 +18,18 @@ class BaseTask(ABC): Subclasses must: - Implement run(): Executes the task logic and returns a TaskResult. - The run method receives a message and optional session_id, and must return - a TaskResult containing the response message and session_id. + The framework always passes ``message``, ``session_id``, and ``files`` + to ``run()``. Tasks that don't need file attachments can simply ignore + the ``files`` parameter. Example: class MyTask(BaseTask): - def run(self, message: str, session_id: Optional[str] = None) -> TaskResult: - # Call your LLM or agent here + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: response = my_agent.chat(message, session_id=session_id) return TaskResult( message=response.text, @@ -37,10 +42,31 @@ def run(self, message: str, session_id: Optional[str] = None) -> TaskResult: task=MyTask(), ) + Example with file uploads: + class MyFileTask(BaseTask): + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: + if files: + for f in files: + print(f.file_name, f.content_type, len(f.data)) + response = my_agent.chat(message, session_id=session_id, files=files) + return TaskResult( + message=response.text, + session_id=response.session_id or session_id or "default", + ) + Async Example: class MyAsyncTask(BaseTask): - async def run(self, message: str, session_id: Optional[str] = None) -> TaskResult: - # Call your async LLM or agent here + async def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: response = await my_async_agent.chat(message, session_id=session_id) return TaskResult( message=response.text, @@ -49,7 +75,12 @@ async def run(self, message: str, session_id: Optional[str] = None) -> TaskResul """ @abstractmethod - def run(self, message: str, session_id: Optional[str] = None) -> TaskResult | Awaitable[TaskResult]: + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult | Awaitable[TaskResult]: """ Execute the task logic. @@ -60,6 +91,8 @@ def run(self, message: str, session_id: Optional[str] = None) -> TaskResult | Aw message: The input message from the simulation. session_id: Optional session identifier for conversation continuity. Will be None for the first turn of a conversation. + files: Optional list of base64-encoded file attachments from the + dataset item. Will be None when no files are attached. Returns: TaskResult: The task result containing: diff --git a/netra/simulation/utils.py b/netra/simulation/utils.py index 341cd26a..24c078bb 100644 --- a/netra/simulation/utils.py +++ b/netra/simulation/utils.py @@ -1,16 +1,53 @@ """Utility functions for the simulation module.""" import asyncio +import base64 +import concurrent.futures import logging +import os import threading -from typing import Awaitable, Optional, Tuple, TypeVar +from typing import Awaitable, Optional, TypeVar -from netra.simulation.models import TaskResult +import httpx + +from netra.simulation.constants import ( + DEFAULT_FILE_DOWNLOAD_TIMEOUT, + ENV_FILE_DOWNLOAD_TIMEOUT, + LOG_PREFIX, + MAX_FILE_DOWNLOAD_WORKERS, +) +from netra.simulation.models import FileData, ProcessedFile, TaskResult from netra.simulation.task import BaseTask logger = logging.getLogger(__name__) -T = TypeVar("T") +_T = TypeVar("_T") + + +def parse_env_float(env_var: str, default: float) -> float: + """Read an environment variable and parse it as a float. + + Args: + env_var: Name of the environment variable. + default: Value to return when the variable is unset or invalid. + + Returns: + The parsed float, or *default* on failure. + """ + raw = os.getenv(env_var) + if not raw: + return default + try: + return float(raw) + except ValueError: + logger.warning( + "%s: Invalid value '%s' for %s, using default %.1f", + LOG_PREFIX, + raw, + env_var, + default, + ) + return default def format_trace_id(trace_id: int) -> str: @@ -47,11 +84,14 @@ def validate_simulation_inputs( return True -def run_async_safely(coro: Awaitable[T]) -> T: - """Run an async coroutine from sync code. +def run_async_safely(coro: Awaitable[_T]) -> _T: + """Run an async coroutine from synchronous code. - If an event loop is already running, executes in a dedicated thread - to avoid 'asyncio.run() cannot be called from a running event loop'. + When called from a context that already has a running event loop (e.g. a + Jupyter notebook, or an async framework like FastAPI), ``asyncio.run()`` + would raise. In that case we spin up a **new daemon thread** with its own + event loop via ``asyncio.run()`` so the caller's loop is never blocked or + re-entered. Args: coro: The coroutine to execute. @@ -68,7 +108,7 @@ def run_async_safely(coro: Awaitable[T]) -> T: loop = None if loop and loop.is_running(): - result_holder: dict[str, T] = {} + result_holder: dict[str, _T] = {} error_holder: dict[str, BaseException] = {} def runner() -> None: @@ -88,17 +128,76 @@ def runner() -> None: return asyncio.run(coro) # type: ignore[arg-type] +def _download_single_file(file_data: FileData, timeout: float) -> ProcessedFile: + """Download a single file and base64-encode its content. + + Args: + file_data: Metadata for the file to download. + timeout: HTTP request timeout in seconds. + + Returns: + A ProcessedFile with the base64-encoded content. + + Raises: + RuntimeError: If the download or encoding fails. + """ + try: + response = httpx.get(file_data.download_url, timeout=timeout) + response.raise_for_status() + encoded = base64.b64encode(response.content).decode("ascii") + return ProcessedFile( + file_name=file_data.file_name, + content_type=file_data.content_type, + description=file_data.description, + data=encoded, + ) + except Exception as exc: + raise RuntimeError(f"Failed to download file '{file_data.file_name}': {exc}") from exc + + +def process_files(files: list[FileData]) -> list[ProcessedFile]: + """Download files from pre-signed URLs and base64-encode their content. + + Downloads run concurrently via a thread pool. If any file fails, the + entire batch is aborted with a ``RuntimeError`` so that file-aware tasks + never receive a partial file list. + + Args: + files: List of FileData objects containing download URLs. + + Returns: + List of ProcessedFile objects with base64-encoded data. + + Raises: + RuntimeError: If a file download or encoding fails. + """ + if not files: + return [] + + timeout = parse_env_float(ENV_FILE_DOWNLOAD_TIMEOUT, DEFAULT_FILE_DOWNLOAD_TIMEOUT) + + if len(files) == 1: + return [_download_single_file(files[0], timeout)] + + max_workers = min(MAX_FILE_DOWNLOAD_WORKERS, len(files)) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [pool.submit(_download_single_file, fd, timeout) for fd in files] + return [f.result() for f in futures] + + async def execute_task( task: BaseTask, message: str, session_id: Optional[str], -) -> Tuple[str, Optional[str]]: + raw_files: Optional[list[FileData]] = None, +) -> tuple[str, Optional[str]]: """Execute a task's run method (sync or async) and extract message and session_id. Args: task: The BaseTask instance to execute. message: The input message to pass to the task. session_id: The current session identifier. + raw_files: Raw file metadata from the backend. Returns: A tuple of (response_message, session_id). @@ -106,7 +205,9 @@ async def execute_task( Raises: ValueError: If the task returns an unsupported type. """ - result = task.run(message=message, session_id=session_id) + processed_files = process_files(raw_files) if raw_files else None + + result = task.run(message=message, session_id=session_id, files=processed_files) if asyncio.iscoroutine(result): result = await result diff --git a/poetry.lock b/poetry.lock index c0d758fb..b32eacaf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -149,6 +149,26 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] +[[package]] +name = "anyio" +version = "4.13.0" +description = "High-level concurrency and networking framework on top of asyncio or Trio" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708"}, + {file = "anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} + +[package.extras] +trio = ["trio (>=0.32.0)"] + [[package]] name = "argcomplete" version = "3.6.2" @@ -841,7 +861,7 @@ version = "1.3.0" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main", "dev"] markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"}, @@ -1126,6 +1146,18 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.73.1)"] +[[package]] +name = "h11" +version = "0.16.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"}, + {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, +] + [[package]] name = "hf-xet" version = "1.1.5" @@ -1148,6 +1180,53 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "httpcore" +version = "1.0.9" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"}, + {file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.16" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + +[[package]] +name = "httpx" +version = "0.28.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, + {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" + +[package.extras] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "huggingface-hub" version = "0.33.4" @@ -5950,4 +6029,4 @@ presidio = ["presidio-analyzer", "presidio-anonymizer", "stanza", "transformers" [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "8712f9e919a32ee9fc083e2ec7cf9721815a7e3108df884af52db399d06b3f07" +content-hash = "e3e8712b3e7fc95f4cfe59b27fca6ce7ac4e84e2d04ea419ffbce8540d0b48bd" diff --git a/pyproject.toml b/pyproject.toml index d906f769..3f3616eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ dependencies = [ "opentelemetry-instrumentation-urllib>=0.55b1,<=0.62b1", "opentelemetry-instrumentation-urllib3>=0.55b1,<=0.62b1", "json-repair==0.44.1", + "httpx>=0.27.0,<1.0.0", ] [project.urls] diff --git a/tests/test_simulation.py b/tests/test_simulation.py new file mode 100644 index 00000000..069e907b --- /dev/null +++ b/tests/test_simulation.py @@ -0,0 +1,783 @@ +""" +Unit tests for the netra/simulation/ module. + +Covers models, utils, client, api, and task layers with mocked +HTTP interactions and async helpers. +""" + +import asyncio +import base64 +from typing import Any, Optional +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from netra.simulation.models import ( + ConversationResponse, + ConversationStatus, + FileData, + ProcessedFile, + SimulationItem, + TaskResult, +) +from netra.simulation.task import BaseTask +from netra.simulation.utils import ( + execute_task, + format_trace_id, + parse_env_float, + process_files, + run_async_safely, + validate_simulation_inputs, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class SyncTask(BaseTask): + """Synchronous task that echoes the message.""" + + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: + return TaskResult(message=f"echo: {message}", session_id=session_id or "sid-1") + + +class AsyncTask(BaseTask): + """Asynchronous task that echoes the message.""" + + async def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: + return TaskResult(message=f"async-echo: {message}", session_id=session_id or "sid-async") + + +class FileAwareTask(BaseTask): + """Task that uses the files parameter.""" + + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> TaskResult: + count = len(files) if files else 0 + return TaskResult(message=f"files={count}", session_id=session_id or "sid-files") + + +# --------------------------------------------------------------------------- +# Section 1: Models +# --------------------------------------------------------------------------- + + +class TestConversationStatus: + """Tests for ConversationStatus enum.""" + + def test_continue_value(self) -> None: + assert ConversationStatus.CONTINUE.value == "continue" + + def test_stop_value(self) -> None: + assert ConversationStatus.STOP.value == "stop" + + def test_from_string(self) -> None: + assert ConversationStatus("continue") == ConversationStatus.CONTINUE + assert ConversationStatus("stop") == ConversationStatus.STOP + + +class TestFileData: + """Tests for the FileData frozen dataclass.""" + + def test_creation(self) -> None: + fd = FileData(file_name="a.txt", content_type="text/plain", description="desc", download_url="https://x") + assert fd.file_name == "a.txt" + assert fd.content_type == "text/plain" + assert fd.description == "desc" + assert fd.download_url == "https://x" + + def test_frozen(self) -> None: + fd = FileData(file_name="a.txt", content_type="text/plain", description=None, download_url="https://x") + with pytest.raises(AttributeError): + fd.file_name = "b.txt" # type: ignore[misc] + + +class TestProcessedFile: + """Tests for the ProcessedFile frozen dataclass.""" + + def test_creation(self) -> None: + pf = ProcessedFile(file_name="a.txt", content_type="text/plain", description=None, data="AAAA") + assert pf.data == "AAAA" + + def test_frozen(self) -> None: + pf = ProcessedFile(file_name="a.txt", content_type="text/plain", description=None, data="AAAA") + with pytest.raises(AttributeError): + pf.data = "BBBB" # type: ignore[misc] + + +class TestSimulationItem: + """Tests for the SimulationItem frozen dataclass.""" + + def test_defaults(self) -> None: + item = SimulationItem(run_item_id="r1", message="hi", turn_id="t1") + assert item.files == [] + + def test_with_files(self) -> None: + fd = FileData(file_name="a.txt", content_type="text/plain", description=None, download_url="https://x") + item = SimulationItem(run_item_id="r1", message="hi", turn_id="t1", files=[fd]) + assert len(item.files) == 1 + + +class TestConversationResponse: + """Tests for the ConversationResponse dataclass.""" + + def test_stop_decision(self) -> None: + resp = ConversationResponse(decision=ConversationStatus.STOP, reason="done") + assert resp.decision == ConversationStatus.STOP + assert resp.reason == "done" + + def test_continue_decision_defaults(self) -> None: + resp = ConversationResponse(decision=ConversationStatus.CONTINUE) + assert resp.next_turn_id is None + assert resp.next_user_message is None + assert resp.next_files == [] + + +class TestTaskResult: + """Tests for the TaskResult frozen dataclass.""" + + def test_creation(self) -> None: + tr = TaskResult(message="hello", session_id="s1") + assert tr.message == "hello" + assert tr.session_id == "s1" + + def test_frozen(self) -> None: + tr = TaskResult(message="hello", session_id="s1") + with pytest.raises(AttributeError): + tr.message = "bye" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# Section 2: Utils +# --------------------------------------------------------------------------- + + +class TestParseEnvFloat: + """Tests for parse_env_float.""" + + def test_returns_default_when_unset(self) -> None: + assert parse_env_float("_NETRA_TEST_NONEXISTENT_VAR_", 42.0) == 42.0 + + def test_parses_valid_value(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("_NETRA_TEST_FLOAT_", "3.14") + assert parse_env_float("_NETRA_TEST_FLOAT_", 1.0) == pytest.approx(3.14) + + def test_returns_default_on_invalid(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("_NETRA_TEST_FLOAT_", "not-a-number") + assert parse_env_float("_NETRA_TEST_FLOAT_", 7.0) == 7.0 + + +class TestFormatTraceId: + """Tests for format_trace_id.""" + + def test_zero(self) -> None: + assert format_trace_id(0) == "0" * 32 + + def test_known_value(self) -> None: + result = format_trace_id(255) + assert result == "0" * 30 + "ff" + assert len(result) == 32 + + +class TestValidateSimulationInputs: + """Tests for validate_simulation_inputs.""" + + def test_valid(self) -> None: + assert validate_simulation_inputs("ds-1", SyncTask()) is True + + def test_empty_dataset_id(self) -> None: + assert validate_simulation_inputs("", SyncTask()) is False + + def test_wrong_task_type(self) -> None: + assert validate_simulation_inputs("ds-1", "not a task") is False # type: ignore[arg-type] + + +class TestRunAsyncSafely: + """Tests for run_async_safely.""" + + def test_runs_coroutine(self) -> None: + async def coro() -> int: + return 42 + + assert run_async_safely(coro()) == 42 + + def test_propagates_exception(self) -> None: + async def coro() -> None: + raise ValueError("boom") + + with pytest.raises(ValueError, match="boom"): + run_async_safely(coro()) + + +class TestProcessFiles: + """Tests for process_files.""" + + def test_empty_list(self) -> None: + assert process_files([]) == [] + + @patch("netra.simulation.utils.httpx.get") + def test_downloads_and_encodes(self, mock_get: MagicMock) -> None: + raw_content = b"hello world" + mock_response = MagicMock(spec=httpx.Response) + mock_response.content = raw_content + mock_response.raise_for_status = MagicMock() + mock_get.return_value = mock_response + + fd = FileData(file_name="a.txt", content_type="text/plain", description=None, download_url="https://x/a.txt") + result = process_files([fd]) + + assert len(result) == 1 + assert result[0].file_name == "a.txt" + assert result[0].data == base64.b64encode(raw_content).decode("ascii") + + @patch("netra.simulation.utils.httpx.get") + def test_raises_on_download_failure(self, mock_get: MagicMock) -> None: + mock_get.side_effect = httpx.ConnectError("connection refused") + + fd = FileData(file_name="a.txt", content_type="text/plain", description=None, download_url="https://x/a.txt") + with pytest.raises(RuntimeError, match="Failed to download file 'a.txt'"): + process_files([fd]) + + @patch("netra.simulation.utils.httpx.get") + def test_concurrent_downloads(self, mock_get: MagicMock) -> None: + mock_response = MagicMock(spec=httpx.Response) + mock_response.content = b"data" + mock_response.raise_for_status = MagicMock() + mock_get.return_value = mock_response + + files = [ + FileData(file_name=f"f{i}.txt", content_type="text/plain", description=None, download_url=f"https://x/{i}") + for i in range(3) + ] + result = process_files(files) + assert len(result) == 3 + assert mock_get.call_count == 3 + + +class TestExecuteTaskFiles: + """Tests for file handling in execute_task.""" + + @patch("netra.simulation.utils.process_files", return_value=[]) + def test_files_downloaded_when_raw_files_present(self, mock_pf: MagicMock) -> None: + """Files are always downloaded and passed to the task.""" + fd = FileData(file_name="a.txt", content_type="text/plain", description=None, download_url="https://x/a.txt") + result = asyncio.run(execute_task(FileAwareTask(), "hi", None, raw_files=[fd])) + mock_pf.assert_called_once_with([fd]) + assert result[0] == "files=0" + + def test_no_files_passed_as_none(self) -> None: + """When no raw_files are provided, files=None is passed to the task.""" + msg, sid = asyncio.run(execute_task(SyncTask(), "hi", None, raw_files=None)) + assert msg == "echo: hi" + + @patch("netra.simulation.utils.process_files") + def test_empty_raw_files_skips_download(self, mock_pf: MagicMock) -> None: + """An empty raw_files list should not trigger downloads.""" + asyncio.run(execute_task(SyncTask(), "hi", None, raw_files=[])) + mock_pf.assert_not_called() + + +class TestExecuteTask: + """Tests for execute_task.""" + + def test_sync_task(self) -> None: + msg, sid = asyncio.run(execute_task(SyncTask(), "hello", None)) + assert msg == "echo: hello" + assert sid == "sid-1" + + def test_async_task(self) -> None: + msg, sid = asyncio.run(execute_task(AsyncTask(), "hello", None)) + assert msg == "async-echo: hello" + assert sid == "sid-async" + + def test_raises_on_bad_return_type(self) -> None: + class BadTask(BaseTask): + def run( + self, + message: str, + session_id: Optional[str] = None, + files: Optional[list[ProcessedFile]] = None, + ) -> Any: + return "not a TaskResult" + + with pytest.raises(ValueError, match="Task must return TaskResult"): + asyncio.run(execute_task(BadTask(), "x", None)) + + +# --------------------------------------------------------------------------- +# Section 3: Client +# --------------------------------------------------------------------------- + + +class TestSimulationHttpClient: + """Tests for SimulationHttpClient.""" + + def _make_config(self, endpoint: str = "https://api.getnetra.ai/telemetry", api_key: str = "key-1") -> MagicMock: + """Create a mock Config.""" + cfg = MagicMock() + cfg.otlp_endpoint = endpoint + cfg.api_key = api_key + cfg.headers = {} + return cfg + + def test_create_client_with_valid_config(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config()) + assert client._client is not None + client.close() + + def test_create_client_strips_telemetry_suffix(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config(endpoint="https://api.getnetra.ai/telemetry")) + assert client._client is not None + assert "/telemetry" not in str(client._client.base_url) + client.close() + + def test_create_client_returns_none_on_empty_endpoint(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config(endpoint="")) + assert client._client is None + + def test_close_sets_client_to_none(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config()) + assert client._client is not None + client.close() + assert client._client is None + + def test_close_idempotent(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config()) + client.close() + client.close() + assert client._client is None + + def test_ensure_client_returns_none_when_not_initialized(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config(endpoint="")) + assert client._ensure_client() is None + + def test_create_run_returns_none_without_client(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config(endpoint="")) + assert client.create_run(name="test", dataset_id="ds-1") is None + + @patch("netra.simulation.client.httpx.Client") + def test_create_run_success(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = { + "data": { + "id": "run-1", + "userMessages": [ + { + "testRunItemId": "item-1", + "userMessage": "hello", + "turnId": "turn-1", + "attachments": None, + } + ], + } + } + mock_response.raise_for_status = MagicMock() + + mock_instance = MagicMock() + mock_instance.post.return_value = mock_response + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + result = client.create_run(name="test", dataset_id="ds-1") + + assert result is not None + assert result["run_id"] == "run-1" + assert len(result["simulation_items"]) == 1 + assert result["simulation_items"][0].message == "hello" + + @patch("netra.simulation.client.httpx.Client") + def test_create_run_returns_none_on_http_error(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_instance = MagicMock() + mock_instance.post.side_effect = httpx.HTTPStatusError( + "Server Error", request=MagicMock(), response=MagicMock() + ) + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + result = client.create_run(name="test", dataset_id="ds-1") + assert result is None + + @patch("netra.simulation.client.httpx.Client") + def test_trigger_conversation_stop(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = { + "data": { + "decision": "stop", + "reason": "all done", + } + } + mock_response.raise_for_status = MagicMock() + + mock_instance = MagicMock() + mock_instance.post.return_value = mock_response + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + resp = client.trigger_conversation(message="hi", turn_id="t1", session_id="s1", trace_id="trace") + + assert resp is not None + assert resp.decision == ConversationStatus.STOP + assert resp.reason == "all done" + + @patch("netra.simulation.client.httpx.Client") + def test_trigger_conversation_continue(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = { + "data": { + "decision": "continue", + "userMessages": [ + { + "turnId": "turn-2", + "userMessage": "follow-up", + "testRunItemId": "item-2", + "attachments": None, + } + ], + } + } + mock_response.raise_for_status = MagicMock() + + mock_instance = MagicMock() + mock_instance.post.return_value = mock_response + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + resp = client.trigger_conversation(message="hi", turn_id="t1", session_id="s1", trace_id="trace") + + assert resp is not None + assert resp.decision == ConversationStatus.CONTINUE + assert resp.next_turn_id == "turn-2" + assert resp.next_user_message == "follow-up" + + def test_trigger_conversation_returns_none_without_client(self) -> None: + from netra.simulation.client import SimulationHttpClient + + client = SimulationHttpClient(self._make_config(endpoint="")) + resp = client.trigger_conversation(message="hi", turn_id="t1", session_id="s1", trace_id="trace") + assert resp is None + + @patch("netra.simulation.client.httpx.Client") + def test_report_failure(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_response = MagicMock(spec=httpx.Response) + mock_response.raise_for_status = MagicMock() + + mock_instance = MagicMock() + mock_instance.patch.return_value = mock_response + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + client.report_failure(run_id="run-1", run_item_id="item-1", error="boom") + mock_instance.patch.assert_called_once() + + @patch("netra.simulation.client.httpx.Client") + def test_post_run_status_success(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {"status": "completed"}} + mock_response.raise_for_status = MagicMock() + + mock_instance = MagicMock() + mock_instance.post.return_value = mock_response + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + result = client.post_run_status(run_id="run-1", status="completed") + assert result == {"status": "completed"} + + @patch("netra.simulation.client.httpx.Client") + def test_post_run_status_returns_error_on_failure(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.client import SimulationHttpClient + + mock_instance = MagicMock() + mock_instance.post.side_effect = httpx.ConnectError("timeout") + mock_client_cls.return_value = mock_instance + + client = SimulationHttpClient(self._make_config()) + result = client.post_run_status(run_id="run-1", status="completed") + assert result == {"success": False} + + def test_parse_files_none(self) -> None: + from netra.simulation.client import SimulationHttpClient + + assert SimulationHttpClient._parse_files(None) == [] + + def test_parse_files_valid(self) -> None: + from netra.simulation.client import SimulationHttpClient + + raw = [{"fileName": "a.txt", "downloadUrl": "https://x/a", "contentType": "text/plain"}] + result = SimulationHttpClient._parse_files(raw) + assert len(result) == 1 + assert result[0].file_name == "a.txt" + + def test_parse_files_skips_malformed(self) -> None: + from netra.simulation.client import SimulationHttpClient + + raw = [{"fileName": "", "downloadUrl": "https://x/a"}] + result = SimulationHttpClient._parse_files(raw) + assert result == [] + + def test_extract_error_message_from_response(self) -> None: + from netra.simulation.client import SimulationHttpClient + + cfg = MagicMock() + cfg.otlp_endpoint = "" + cfg.api_key = "" + cfg.headers = {} + client = SimulationHttpClient(cfg) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"error": {"message": "custom error"}} + result = client._extract_error_message(mock_response, ValueError("fallback")) + assert result == "custom error" + + def test_extract_error_message_fallback(self) -> None: + from netra.simulation.client import SimulationHttpClient + + cfg = MagicMock() + cfg.otlp_endpoint = "" + cfg.api_key = "" + cfg.headers = {} + client = SimulationHttpClient(cfg) + + result = client._extract_error_message(None, ValueError("fallback")) + assert result == "fallback" + + +# --------------------------------------------------------------------------- +# Section 4: API (Simulation class) +# --------------------------------------------------------------------------- + + +class TestSimulation: + """Tests for the Simulation public API.""" + + def _make_config(self) -> MagicMock: + cfg = MagicMock() + cfg.otlp_endpoint = "https://api.getnetra.ai/telemetry" + cfg.api_key = "key-1" + cfg.headers = {} + return cfg + + @patch("netra.simulation.api.SimulationHttpClient") + def test_run_simulation_returns_none_on_invalid_inputs(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.api import Simulation + + sim = Simulation(self._make_config()) + result = sim.run_simulation(name="test", dataset_id="", task=SyncTask()) + assert result is None + + @patch("netra.simulation.api.SimulationHttpClient") + def test_run_simulation_returns_none_when_create_run_fails(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.api import Simulation + + mock_client_cls.return_value.create_run.return_value = None + sim = Simulation(self._make_config()) + result = sim.run_simulation(name="test", dataset_id="ds-1", task=SyncTask()) + assert result is None + + @patch("netra.simulation.api.SpanWrapper") + @patch("netra.simulation.api.SimulationHttpClient") + def test_run_simulation_success(self, mock_client_cls: MagicMock, mock_span_wrapper: MagicMock) -> None: + from netra.simulation.api import Simulation + + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.get_current_span.return_value = None + mock_span_wrapper.return_value = mock_span + + stop_response = ConversationResponse( + decision=ConversationStatus.STOP, + reason="done", + ) + + mock_client = MagicMock() + mock_client.create_run.return_value = { + "run_id": "run-1", + "simulation_items": [ + SimulationItem(run_item_id="item-1", message="hello", turn_id="turn-1"), + ], + } + mock_client.trigger_conversation.return_value = stop_response + mock_client.post_run_status.return_value = {"status": "completed"} + mock_client_cls.return_value = mock_client + + sim = Simulation(self._make_config()) + result = sim.run_simulation(name="test", dataset_id="ds-1", task=SyncTask()) + + assert result is not None + assert result["total_items"] == 1 + assert len(result["completed"]) == 1 + assert len(result["failed"]) == 0 + + @patch("netra.simulation.api.SpanWrapper") + @patch("netra.simulation.api.SimulationHttpClient") + def test_run_simulation_marks_failed_on_exception( + self, mock_client_cls: MagicMock, mock_span_wrapper: MagicMock + ) -> None: + from netra.simulation.api import Simulation + + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.get_current_span.return_value = None + mock_span_wrapper.return_value = mock_span + + mock_client = MagicMock() + mock_client.create_run.return_value = { + "run_id": "run-1", + "simulation_items": [ + SimulationItem(run_item_id="item-1", message="hello", turn_id="turn-1"), + ], + } + mock_client.trigger_conversation.side_effect = RuntimeError("backend down") + mock_client.post_run_status.return_value = {} + mock_client_cls.return_value = mock_client + + sim = Simulation(self._make_config()) + result = sim.run_simulation(name="test", dataset_id="ds-1", task=SyncTask()) + + assert result is not None + assert len(result["failed"]) == 1 + assert result["failed"][0]["error"] == "backend down" + + @patch("netra.simulation.api.SpanWrapper") + @patch("netra.simulation.api.SimulationHttpClient") + def test_max_turns_guard(self, mock_client_cls: MagicMock, mock_span_wrapper: MagicMock) -> None: + from netra.simulation.api import Simulation + + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.get_current_span.return_value = None + mock_span_wrapper.return_value = mock_span + + continue_response = ConversationResponse( + decision=ConversationStatus.CONTINUE, + next_turn_id="turn-next", + next_user_message="keep going", + ) + + mock_client = MagicMock() + mock_client.create_run.return_value = { + "run_id": "run-1", + "simulation_items": [ + SimulationItem(run_item_id="item-1", message="hello", turn_id="turn-1"), + ], + } + mock_client.trigger_conversation.return_value = continue_response + mock_client.post_run_status.return_value = {} + mock_client_cls.return_value = mock_client + + sim = Simulation(self._make_config()) + result = sim.run_simulation(name="test", dataset_id="ds-1", task=SyncTask(), max_turns=3) + + assert result is not None + assert len(result["failed"]) == 1 + assert "Exceeded maximum turns (3)" in result["failed"][0]["error"] + + @patch("netra.simulation.api.SimulationHttpClient") + def test_close_delegates_to_client(self, mock_client_cls: MagicMock) -> None: + from netra.simulation.api import Simulation + + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + + sim = Simulation(self._make_config()) + sim.close() + mock_client.close.assert_called_once() + + @patch("netra.simulation.api.SpanWrapper") + @patch("netra.simulation.api.SimulationHttpClient") + def test_trigger_conversation_none_response(self, mock_client_cls: MagicMock, mock_span_wrapper: MagicMock) -> None: + from netra.simulation.api import Simulation + + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.get_current_span.return_value = None + mock_span_wrapper.return_value = mock_span + + mock_client = MagicMock() + mock_client.create_run.return_value = { + "run_id": "run-1", + "simulation_items": [ + SimulationItem(run_item_id="item-1", message="hello", turn_id="turn-1"), + ], + } + mock_client.trigger_conversation.return_value = None + mock_client.post_run_status.return_value = {} + mock_client_cls.return_value = mock_client + + sim = Simulation(self._make_config()) + result = sim.run_simulation(name="test", dataset_id="ds-1", task=SyncTask()) + + assert result is not None + assert len(result["failed"]) == 1 + assert "Failed to get conversation response" in result["failed"][0]["error"] + + +# --------------------------------------------------------------------------- +# Section 5: BaseTask +# --------------------------------------------------------------------------- + + +class TestBaseTask: + """Tests for BaseTask abstract class.""" + + def test_cannot_instantiate_directly(self) -> None: + with pytest.raises(TypeError): + BaseTask() # type: ignore[abstract] + + def test_sync_subclass(self) -> None: + task = SyncTask() + result = task.run(message="hi") + assert isinstance(result, TaskResult) + assert result.message == "echo: hi" + + def test_async_subclass(self) -> None: + task = AsyncTask() + result = asyncio.run(task.run(message="hi")) # type: ignore[arg-type] + assert isinstance(result, TaskResult) + assert result.message == "async-echo: hi" From e81dacfaad07315ca09c1c15042eaa3f3581fb65 Mon Sep 17 00:00:00 2001 From: akash-vijay-kv Date: Mon, 1 Jun 2026 09:57:44 +0530 Subject: [PATCH 03/13] [NET-995] feat: Add utility for adding streaming output on root span (#303) --- CHANGELOG.md | 2 +- netra/__init__.py | 22 ++ netra/instrumentation/agno/utils.py | 6 +- netra/instrumentation/agno/wrappers.py | 14 +- netra/instrumentation/cerebras/wrappers.py | 24 +++ .../instrumentation/google_genai/wrappers.py | 6 + netra/instrumentation/groq/wrappers.py | 24 +++ netra/instrumentation/litellm/wrappers.py | 24 +++ netra/instrumentation/openai/wrappers.py | 24 +++ netra/instrumentation/stream_utils.py | 201 ++++++++++++++++++ netra/session_manager.py | 68 ++++-- netra/utils.py | 15 ++ 12 files changed, 403 insertions(+), 27 deletions(-) create mode 100644 netra/instrumentation/stream_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 865c9e2a..bcfb8489 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -289,4 +289,4 @@ Users can be now overwrite the input and ouput attributes of spans created by in - Added utility to set input and output data for any active span in a trace -[0.1.87]: https://github.com/KeyValueSoftwareSystems/netra-sdk-py/tree/main +[0.1.89]: https://github.com/KeyValueSoftwareSystems/netra-sdk-py/tree/main diff --git a/netra/__init__.py b/netra/__init__.py index cb544fb0..5ced2017 100644 --- a/netra/__init__.py +++ b/netra/__init__.py @@ -508,6 +508,28 @@ def set_root_output(cls, value: Any) -> None: """ SessionManager.set_root_output(value) + @classmethod + def set_root_output_stream(cls, value: Any) -> Any: + """ + Wrap a stream so the accumulated output is set on the root span when iteration ends. + + The returned object is a transparent proxy — iterate over it instead of the original:: + + stream = Netra.set_root_output_stream(stream) + for chunk in stream: + ... + + Supports both sync and async iterables. Returns *value* unchanged if no active trace + context exists or if *value* is not iterable. + + Args: + value: The stream to wrap (Netra-instrumented or any generic iterable). + + Returns: + A wrapped stream proxy, or *value* unchanged if wrapping is not possible. + """ + return SessionManager.set_root_output_stream(value) + @classmethod def start_span( cls, diff --git a/netra/instrumentation/agno/utils.py b/netra/instrumentation/agno/utils.py index 2a3829da..26be21aa 100644 --- a/netra/instrumentation/agno/utils.py +++ b/netra/instrumentation/agno/utils.py @@ -941,7 +941,7 @@ def set_request_attributes( span.set_attribute("input", input_content) -def set_response_attributes(span: Span, response: Any) -> None: +def set_response_attributes(span: Span, response: Any) -> Optional[str]: """Set response-side span attributes from an Agno response object. Writes token usage, output content, response ID, and output type. @@ -951,7 +951,7 @@ def set_response_attributes(span: Span, response: Any) -> None: response: The Agno response object (RunResponse, TeamRunOutput, etc.). """ if not span.is_recording(): - return + return None usage = extract_token_usage(response) if usage: @@ -965,6 +965,8 @@ def set_response_attributes(span: Span, response: Any) -> None: if response_id: span.set_attribute(ATTR_RESPONSE_ID, response_id) + return output + def sanitize_headers(raw_headers: List[Tuple[bytes, bytes]]) -> Dict[str, str]: """Convert ASGI raw header pairs to a dict with sensitive values redacted. diff --git a/netra/instrumentation/agno/wrappers.py b/netra/instrumentation/agno/wrappers.py index 8aab90d9..91dacbb9 100644 --- a/netra/instrumentation/agno/wrappers.py +++ b/netra/instrumentation/agno/wrappers.py @@ -157,6 +157,8 @@ def _set_common_span_attributes(span: Span, entity_type: str) -> None: class _BaseStreamWrapper: """Shared base for all span streaming wrappers.""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: Any, ctx_token: Any = None) -> None: """Initialise the streaming wrapper. @@ -222,10 +224,14 @@ class _AgentStreamOutputMixin: def _set_output_on_success(self) -> None: """Write accumulated run content and token usage to the span before it closes.""" + self._netra_output = "" if self._last_response is not None: - set_response_attributes(self._span, self._last_response) + output = set_response_attributes(self._span, self._last_response) + self._netra_output = output if output else "" if self._content_chunks: - self._span.set_attribute("output", "".join(self._content_chunks)) + output = "".join(self._content_chunks) + self._span.set_attribute("output", output) + self._netra_output = output class _LlmStreamOutputMixin: @@ -239,9 +245,11 @@ class _LlmStreamOutputMixin: def _set_output_on_success(self) -> None: """Write accumulated LLM content, token usage, and timing metrics to the span.""" output_str = None + self._netra_output = "" if self._content_chunks: content = "".join(self._content_chunks) output_str = json.dumps([{"role": "assistant", "content": content}]) + self._netra_output = content elif self._tool_calls: try: tc_serialized = serialize_value(self._tool_calls, clean=True) @@ -251,10 +259,12 @@ def _set_output_on_success(self) -> None: except (json.JSONDecodeError, ValueError): tc_data = tc_serialized output_str = json.dumps([{"role": "assistant", "tool_calls": tc_data}]) + self._netra_output = tc_serialized except Exception as e: logger.debug("netra.instrumentation.agno: failed to serialize tool_calls for LLM output: %s", e) elif self._last_response is not None: output_str = format_response_as_output(self._last_response) + self._netra_output = output_str if output_str else "" if output_str: self._span.set_attribute("output", output_str) set_llm_completion_attributes(self._span, output_str) diff --git a/netra/instrumentation/cerebras/wrappers.py b/netra/instrumentation/cerebras/wrappers.py index e0c2f065..2ebd2863 100644 --- a/netra/instrumentation/cerebras/wrappers.py +++ b/netra/instrumentation/cerebras/wrappers.py @@ -40,6 +40,8 @@ def _detect_streaming(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> bool: class StreamingWrapper(ObjectProxy): # type: ignore[misc] """Wrapper for streaming responses""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: Iterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -59,6 +61,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + def __iter__(self) -> Iterator[Any]: return self @@ -129,6 +140,7 @@ def _finalize_span(self) -> None: """Finalize span when streaming is complete""" record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() @@ -136,6 +148,8 @@ def _finalize_span(self) -> None: class AsyncStreamingWrapper(ObjectProxy): # type: ignore[misc] """Async wrapper for streaming responses""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: AsyncIterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -155,6 +169,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + def __aiter__(self) -> AsyncIterator[Any]: return self @@ -227,6 +250,7 @@ def _finalize_span(self) -> None: """Finalize span when streaming is complete""" record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() diff --git a/netra/instrumentation/google_genai/wrappers.py b/netra/instrumentation/google_genai/wrappers.py index 21472d9d..934b9e91 100644 --- a/netra/instrumentation/google_genai/wrappers.py +++ b/netra/instrumentation/google_genai/wrappers.py @@ -235,6 +235,8 @@ async def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, . class StreamingWrapper: + _netra_stream_wrapper = True + def __init__(self, span: Span, response: Iterator[Any]) -> None: self._span = span self._buffer: dict[Any, Any] = {"chunk": None, "content": ""} @@ -272,11 +274,14 @@ def _process_chunk(self, chunk: Any) -> None: def _finalize_span(self) -> None: record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._buffer) + self._netra_output = self._buffer.get("content", "") if isinstance(self._buffer, dict) else "" self._span.set_status(Status(StatusCode.OK)) self._span.end() class AsyncStreamingWrapper: + _netra_stream_wrapper = True + def __init__(self, span: Span, response: AsyncIterator[Any]) -> None: self._span = span self._buffer: dict[Any, Any] = {"chunk": None, "content": ""} @@ -313,5 +318,6 @@ def _process_chunk(self, chunk: Any) -> None: def _finalize_span(self) -> None: record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._buffer) + self._netra_output = self._buffer.get("content", "") if isinstance(self._buffer, dict) else "" self._span.set_status(Status(StatusCode.OK)) self._span.end() diff --git a/netra/instrumentation/groq/wrappers.py b/netra/instrumentation/groq/wrappers.py index e64241cb..872e7f42 100644 --- a/netra/instrumentation/groq/wrappers.py +++ b/netra/instrumentation/groq/wrappers.py @@ -26,6 +26,8 @@ class StreamingWrapper(ObjectProxy): # type: ignore[misc] """Wrapper for streaming responses (OpenAI-style).""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: Iterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -43,6 +45,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + def __iter__(self) -> Iterator[Any]: return self @@ -98,6 +109,7 @@ def _process_chunk(self, chunk: Any) -> None: def _finalize_span(self) -> None: record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() @@ -105,6 +117,8 @@ def _finalize_span(self) -> None: class AsyncStreamingWrapper(ObjectProxy): # type: ignore[misc] """Async wrapper for streaming responses (OpenAI-style).""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: AsyncIterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -122,6 +136,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + def __aiter__(self) -> AsyncIterator[Any]: return self @@ -177,6 +200,7 @@ def _process_chunk(self, chunk: Any) -> None: def _finalize_span(self) -> None: record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() diff --git a/netra/instrumentation/litellm/wrappers.py b/netra/instrumentation/litellm/wrappers.py index 659acfd9..7ff657fc 100644 --- a/netra/instrumentation/litellm/wrappers.py +++ b/netra/instrumentation/litellm/wrappers.py @@ -335,6 +335,8 @@ async def wrapper( class StreamingWrapper(ObjectProxy): # type: ignore[misc] """Wrapper for streaming responses""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: Iterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -354,6 +356,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + def __enter__(self) -> "StreamingWrapper": if hasattr(self.__wrapped__, "__enter__"): self.__wrapped__.__enter__() @@ -444,6 +455,7 @@ def _finalize_span(self) -> None: """Finalize span when streaming is complete""" record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() @@ -451,6 +463,8 @@ def _finalize_span(self) -> None: class AsyncStreamingWrapper(ObjectProxy): # type: ignore[misc] """Async wrapper for streaming responses""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: AsyncIterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -470,6 +484,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + async def __aenter__(self) -> "AsyncStreamingWrapper": if hasattr(self.__wrapped__, "__aenter__"): await self.__wrapped__.__aenter__() @@ -560,5 +583,6 @@ def _finalize_span(self) -> None: """Finalize span when streaming is complete""" record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() diff --git a/netra/instrumentation/openai/wrappers.py b/netra/instrumentation/openai/wrappers.py index e32e82cf..3d013f5f 100644 --- a/netra/instrumentation/openai/wrappers.py +++ b/netra/instrumentation/openai/wrappers.py @@ -281,6 +281,8 @@ async def wrapper(wrapped: Callable[..., Awaitable[Any]], instance: Any, args: A class StreamingWrapper(ObjectProxy): # type: ignore[misc] """Wrapper for streaming responses""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: Iterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -300,6 +302,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + def __enter__(self) -> "StreamingWrapper": if hasattr(self.__wrapped__, "__enter__"): self.__wrapped__.__enter__() @@ -412,6 +423,7 @@ def _finalize_span(self) -> None: msg["tool_calls"] = [msg["tool_calls"][i] for i in sorted(msg["tool_calls"].keys())] record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() @@ -419,6 +431,8 @@ def _finalize_span(self) -> None: class AsyncStreamingWrapper(ObjectProxy): # type: ignore[misc] """Async wrapper for streaming responses""" + _netra_stream_wrapper = True + def __init__(self, span: Span, response: AsyncIterator[Any], request_kwargs: Dict[str, Any]) -> None: super().__init__(response) self._span = span @@ -438,6 +452,15 @@ def _ensure_choice(self, index: int) -> None: else: self._complete_response["choices"].append({"text": ""}) + def _extract_content_text(self) -> str: + """Extract the plain text content from the accumulated response.""" + parts = [] + for choice in self._complete_response.get("choices", []): + msg = choice.get("message", {}) + if content := msg.get("content"): + parts.append(content) + return "".join(parts) + async def __aenter__(self) -> "AsyncStreamingWrapper": if hasattr(self.__wrapped__, "__aenter__"): await self.__wrapped__.__aenter__() @@ -550,5 +573,6 @@ def _finalize_span(self) -> None: msg["tool_calls"] = [msg["tool_calls"][i] for i in sorted(msg["tool_calls"].keys())] record_span_timing(self._span, LLM_RESPONSE_DURATION) set_response_attributes(self._span, self._complete_response) + self._netra_output = self._extract_content_text() self._span.set_status(Status(StatusCode.OK)) self._span.end() diff --git a/netra/instrumentation/stream_utils.py b/netra/instrumentation/stream_utils.py new file mode 100644 index 00000000..0a81440b --- /dev/null +++ b/netra/instrumentation/stream_utils.py @@ -0,0 +1,201 @@ +""" +Utilities for wrapping stream objects so that when iteration completes, the +accumulated output is automatically set on the root span of the current trace. + +Three flows are supported: + + 1. Netra-wrapped stream (``_netra_stream_wrapper = True``) + The inner instrumentation wrapper has already accumulated the output + in ``_netra_output``. The outer tap simply delegates iteration and + reads that attribute once the inner wrapper signals exhaustion. + + 2. Generic / unknown stream + Any iterable whose type Netra does not know about. Chunks are + converted to strings via ``str(chunk)`` and concatenated. + + 3. Objects that carry no iterator protocol are returned unchanged with a + warning log. +""" + +import logging +from typing import Any, Callable, List, Union + +from opentelemetry.trace import Span + +from netra.session_manager import NETRA_USER_OUTPUT +from netra.utils import serialize_value + +logger = logging.getLogger(__name__) + + +def _set_output_on_root(root_span: Span, output: Any) -> None: + """Write serialized *output* to *root_span* as ``NETRA_USER_OUTPUT``.""" + try: + serialized = serialize_value(output) + if serialized: + root_span.set_attribute(NETRA_USER_OUTPUT, serialized) + except Exception: + logger.warning("root_output_stream: failed to set output on root span", exc_info=True) + + +# Extractors — injected at construction time, kept stateless +def _netra_extractor(wrapper: Union["RootOutputSyncStreamWrapper", "RootOutputAsyncStreamWrapper"]) -> Any: + """Read accumulated output from the inner Netra wrapper.""" + inner = wrapper._stream + output = getattr(inner, "_netra_output", None) + if output is not None: + return output + # Nested wrapping: inner is another RootOutput* wrapper with _chunks + chunks = getattr(inner, "_chunks", None) + if chunks is not None: + return "".join(chunks) + return None + + +def _generic_extractor(wrapper: Union["RootOutputSyncStreamWrapper", "RootOutputAsyncStreamWrapper"]) -> Any: + """Return the concatenated stringified chunks.""" + return "".join(wrapper._chunks) + + +# Sync wrapper +class RootOutputSyncStreamWrapper: + """Wraps a sync iterable; on exhaustion sets the output on the root span.""" + + _netra_stream_wrapper = True + + def __init__(self, stream: Any, root_span: Span, extractor: Callable[[Any], Any]) -> None: + self._stream = stream + self._root_span = root_span + self._extractor = extractor + self._chunks: List[str] = [] + self._track_chunks: bool = extractor is _generic_extractor + self._committed = False + + def __iter__(self) -> "RootOutputSyncStreamWrapper": + return self + + def __next__(self) -> Any: + try: + chunk = next(self._stream) + if self._track_chunks: + self._chunks.append(str(chunk)) + return chunk + except StopIteration: + self._commit() + raise + + def __enter__(self) -> "RootOutputSyncStreamWrapper": + if hasattr(self._stream, "__enter__"): + self._stream.__enter__() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if hasattr(self._stream, "__exit__"): + self._stream.__exit__(exc_type, exc_val, exc_tb) + if exc_type is None: + self._commit() + + def __getattr__(self, name: str) -> Any: + return getattr(self._stream, name) + + def __del__(self) -> None: + if not self._committed: + self._commit() + + def _commit(self) -> None: + if self._committed: + return + self._committed = True + try: + _set_output_on_root(self._root_span, self._extractor(self)) + except Exception: + logger.debug("RootOutputSyncWrapper: failed to commit output to root span", exc_info=True) + + +# Async wrapper +class RootOutputAsyncStreamWrapper: + """Wraps an async iterable; on exhaustion sets the output on the root span.""" + + _netra_stream_wrapper = True + + def __init__(self, stream: Any, root_span: Span, extractor: Callable[[Any], Any]) -> None: + self._stream = stream + self._root_span = root_span + self._extractor = extractor + self._chunks: List[str] = [] + self._track_chunks: bool = extractor is _generic_extractor + self._committed = False + + def __aiter__(self) -> "RootOutputAsyncStreamWrapper": + return self + + async def __anext__(self) -> Any: + try: + chunk = await self._stream.__anext__() + if self._track_chunks: + self._chunks.append(str(chunk)) + return chunk + except StopAsyncIteration: + self._commit() + raise + + async def __aenter__(self) -> "RootOutputAsyncStreamWrapper": + if hasattr(self._stream, "__aenter__"): + await self._stream.__aenter__() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if hasattr(self._stream, "__aexit__"): + await self._stream.__aexit__(exc_type, exc_val, exc_tb) + if exc_type is None: + self._commit() + + def __getattr__(self, name: str) -> Any: + return getattr(self._stream, name) + + def __del__(self) -> None: + if not self._committed: + self._commit() + + def _commit(self) -> None: + if self._committed: + return + self._committed = True + try: + _set_output_on_root(self._root_span, self._extractor(self)) + except Exception: + logger.debug("RootOutputAsyncWrapper: failed to commit output to root span", exc_info=True) + + +def wrap_stream_for_root_output(stream: Any, root_span: Span) -> Any: + """Wrap *stream* so the accumulated output is set on *root_span* when iteration ends. + + Detection order: + 1. ``_netra_stream_wrapper`` attribute present (Netra-wrapped) + 2. Has ``__aiter__`` or ``__iter__`` (generic) + 3. Not iterable (return unchanged) + + Args: + stream: The stream to wrap. May be sync or async. + root_span: The root OTel span that will receive the ``NETRA_USER_OUTPUT`` attribute. + + Returns: + A :class:`RootOutputSyncWrapper`, :class:`RootOutputAsyncWrapper`, or the + original *stream* unchanged if it is not iterable. + """ + is_netra = getattr(stream, "_netra_stream_wrapper", False) + extractor: Callable[[Union["RootOutputSyncStreamWrapper", "RootOutputAsyncStreamWrapper"]], Any] = ( + _netra_extractor if is_netra else _generic_extractor + ) + + if hasattr(stream, "__aiter__"): + return RootOutputAsyncStreamWrapper(stream, root_span, extractor) + + if hasattr(stream, "__iter__"): + return RootOutputSyncStreamWrapper(stream, root_span, extractor) + + logger.warning( + "set_root_output_stream: passed object of type %s is not iterable; returning unchanged", + type(stream).__name__, + ) + return stream diff --git a/netra/session_manager.py b/netra/session_manager.py index 930d07ca..62543279 100644 --- a/netra/session_manager.py +++ b/netra/session_manager.py @@ -1,4 +1,3 @@ -import json import logging from datetime import datetime from enum import Enum @@ -9,11 +8,15 @@ from opentelemetry import trace from netra.config import Config -from netra.utils import process_content_for_max_len +from netra.utils import process_content_for_max_len, serialize_value logger = logging.getLogger(__name__) +NETRA_USER_INPUT = "netra.user.input" +NETRA_USER_OUTPUT = "netra.user.output" + + class ConversationType(str, Enum): INPUT = "input" OUTPUT = "output" @@ -399,11 +402,8 @@ def set_input(cls, value: Any) -> None: value: The input value to record. """ try: - if isinstance(value, (dict, list)): - serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] - else: - serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] - cls.set_attribute_on_active_span("netra.user.input", serialized) + serialized = serialize_value(value) + cls.set_attribute_on_active_span(NETRA_USER_INPUT, serialized) except Exception: logger.exception("SessionManager.set_input: failed to set input attribute") @@ -419,11 +419,8 @@ def set_output(cls, value: Any) -> None: value: The output value to record. """ try: - if isinstance(value, (dict, list)): - serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] - else: - serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] - cls.set_attribute_on_active_span("netra.user.output", serialized) + serialized = serialize_value(value) + cls.set_attribute_on_active_span(NETRA_USER_OUTPUT, serialized) except Exception: logger.exception("SessionManager.set_output: failed to set output attribute") @@ -438,11 +435,8 @@ def set_root_input(cls, value: Any) -> None: value: The input value to record. """ try: - if isinstance(value, (dict, list)): - serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] - else: - serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] - cls.set_attribute_on_root_span("netra.user.input", serialized) + serialized = serialize_value(value) + cls.set_attribute_on_root_span(NETRA_USER_INPUT, serialized) except Exception: logger.exception("SessionManager.set_root_input: failed to set input attribute") @@ -457,14 +451,44 @@ def set_root_output(cls, value: Any) -> None: value: The output value to record. """ try: - if isinstance(value, (dict, list)): - serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] - else: - serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] - cls.set_attribute_on_root_span("netra.user.output", serialized) + serialized = serialize_value(value) + cls.set_attribute_on_root_span(NETRA_USER_OUTPUT, serialized) except Exception: logger.exception("SessionManager.set_root_output: failed to set output attribute") + @classmethod + def set_root_output_stream(cls, value: Any) -> Any: + """Wrap a stream so that the accumulated output is set on the root span when iteration ends. + + The stream is wrapped transparently — the user should iterate over the returned object + instead of the original stream. On exhaustion (or garbage collection), the output is + automatically written to the ``netra.user.output`` attribute of the root span for the + current trace, which is then promoted to ``output`` by the export pipeline. + + Supports both sync iterables and async iterables. + + Args: + value: The stream to wrap. May be a Netra-instrumented wrapper or any generic iterable. + + Returns: + A wrapped stream proxy. Returns *value* unchanged if no active trace context + exists or if *value* is not iterable, so callers can always reassign safely:: + + stream = Netra.set_root_output_stream(stream) + """ + try: + from netra.instrumentation.stream_utils import wrap_stream_for_root_output + from netra.processors.root_span_processor import RootSpanProcessor + + root_span = RootSpanProcessor.get_root_span(trace.get_current_span()) + if not root_span: + logger.warning("SessionManager.set_root_output_stream: no root span found for current trace") + return value + return wrap_stream_for_root_output(value, root_span) + except Exception: + logger.exception("SessionManager.set_root_output_stream: failed to wrap stream") + return value + @classmethod def set_attribute_on_root_span(cls, attr_key: str, attr_value: Any) -> None: """Set an attribute on the root span of the current trace. diff --git a/netra/utils.py b/netra/utils.py index 7d6d69b8..79265bf9 100644 --- a/netra/utils.py +++ b/netra/utils.py @@ -8,6 +8,7 @@ import logging from typing import AbstractSet, Any, Optional, Set +from netra.config import Config from netra.instrumentation.instruments import ( DEFAULT_INSTRUMENTS_FOR_ROOT, InstrumentSet, @@ -89,6 +90,20 @@ def process_content_for_max_len(content: Any, max_len: int) -> Any: return content +def serialize_value(value: Any) -> str: + """Serialize *value* to a string capped at ``Config.ATTRIBUTE_MAX_LEN``.""" + if value is None: + return "" + try: + import json + + serialized = json.dumps(value) if isinstance(value, (dict, list)) else str(value) + return truncate_string(serialized, Config.ATTRIBUTE_MAX_LEN) + except Exception: + logger.debug("utils: failed to serialize value", exc_info=True) + return "" + + def resolve_root_instruments( root_instruments: Optional[AbstractSet[NetraInstruments]], block_instruments: Optional[AbstractSet[NetraInstruments]], From 57aa2aec3a4b628c2169e23b4a5029e99bf8a9f1 Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Wed, 10 Jun 2026 16:16:48 +0530 Subject: [PATCH 04/13] test: Write test for SpanIOProcessor (PRD 4.3) --- tests/test_SpanIOProcessor.py | 39 +++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/test_SpanIOProcessor.py diff --git a/tests/test_SpanIOProcessor.py b/tests/test_SpanIOProcessor.py new file mode 100644 index 00000000..097935b9 --- /dev/null +++ b/tests/test_SpanIOProcessor.py @@ -0,0 +1,39 @@ +import unittest +from unittest.mock import MagicMock + +from netra.processors.span_io_processor import SpanIOProcessor + + +class TestSpanIOProcessor(unittest.TestCase): + """Test cases for `SpanIOProcessor` as per `4.3` in PRD""" + + _USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + _USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + _USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens" + _USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens" + + def __get_mocks(self, original_value: str, desired_value: str): + mock_span = MagicMock() + mock_original_set_attribute = MagicMock() + mock_span.set_attribute = mock_original_set_attribute + # creating fake span + + SpanIOProcessor._wrap_set_attribute(mock_span) + # adding set_attribute method to span + + mock_span.set_attribute(original_value, 100) + # setting value + + mock_original_set_attribute.assert_called_with(desired_value, 100) + + def test_alias_to_prompt_tokens(self): + self.__get_mocks(self._USAGE_INPUT_TOKENS, self._USAGE_PROMPT_TOKENS) + + def test_alias_to_completion_tokens(self): + self.__get_mocks(self._USAGE_OUTPUT_TOKENS, self._USAGE_COMPLETION_TOKENS) + + def test_pass_through_prompt_tokens(self): + self.__get_mocks(self._USAGE_PROMPT_TOKENS, self._USAGE_PROMPT_TOKENS) + + def test_pass_through_completion_tokens(self): + self.__get_mocks(self._USAGE_COMPLETION_TOKENS, self._USAGE_COMPLETION_TOKENS) From bba20c4378815697bae5fa69778d81f01d2b2cd9 Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Wed, 10 Jun 2026 17:33:41 +0530 Subject: [PATCH 05/13] test: Write test for groq provider utils --- tests/test_groq_provider_utils.py | 77 +++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 tests/test_groq_provider_utils.py diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py new file mode 100644 index 00000000..564be287 --- /dev/null +++ b/tests/test_groq_provider_utils.py @@ -0,0 +1,77 @@ +import itertools +import unittest +from unittest.mock import MagicMock + +from opentelemetry.semconv_ai import SpanAttributes + +from netra.instrumentation.groq.utils import _set_usage_attributes + + +class TestGroqProviderUtils(unittest.TestCase): + """Tests `_set_usage_attributes` from `groq.utils`""" + + ALIASES = { + "prompt_tokens": ["prompt_tokens", "input_tokens"], + "completion_tokens": ["completion_tokens", "output_tokens"], + "prompt_tokens_details": ["prompt_tokens_details", "input_tokens_details"], + } + + P_TOKEN = 100 + C_TOKEN = 50 + P_DETAIL = 10 + T_TOKEN = 160 + + def __build_input_data(self): + keys_groups = [ + self.ALIASES["prompt_tokens"], + self.ALIASES["completion_tokens"], + self.ALIASES["prompt_tokens_details"], + ] + + for p_token, c_token, p_detail in itertools.product(*keys_groups): + data = dict() + + data[p_token] = self.P_TOKEN + data[c_token] = self.C_TOKEN + data[p_detail] = {"cached_tokens": self.P_DETAIL} + data["total_tokens"] = self.T_TOKEN + + yield data + + def __build_no_details_data(self): + keys_group = [self.ALIASES["prompt_tokens"], self.ALIASES["completion_tokens"]] + + for p_token, c_token in itertools.product(*keys_group): + data = dict() + + data[p_token] = self.P_TOKEN + data[c_token] = self.C_TOKEN + data["total_tokens"] = self.T_TOKEN + + yield data + + def test_set_usage_attributes(self): + for dummy_data in self.__build_input_data(): + mock_span = MagicMock() + _set_usage_attributes(mock_span, dummy_data) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, self.P_DETAIL) + + def test_set_usage_attributes_no_prompt_tokens_details(self): + for dummy_data in self.__build_no_details_data(): + mock_span = MagicMock() + _set_usage_attributes(mock_span, dummy_data) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) + + def test_empty_dict(self): + mock_span = MagicMock() + _set_usage_attributes(mock_span, dict()) + + called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] + self.assertEqual(len(called_keys), 0) From 8e8bc796266cfbb9a9d84fcfe5a1a5a0ec543f3e Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Thu, 11 Jun 2026 10:53:49 +0530 Subject: [PATCH 06/13] test: Write more tests for groq provider covering every methods --- tests/test_groq_provider_utils.py | 226 +++++++++++++++++++++++++++++- 1 file changed, 224 insertions(+), 2 deletions(-) diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py index 564be287..3626b485 100644 --- a/tests/test_groq_provider_utils.py +++ b/tests/test_groq_provider_utils.py @@ -1,14 +1,29 @@ import itertools +import random import unittest +from collections.abc import Iterable +from typing import Any from unittest.mock import MagicMock from opentelemetry.semconv_ai import SpanAttributes -from netra.instrumentation.groq.utils import _set_usage_attributes +from netra.instrumentation.groq.utils import ( + _set_chat_input, + _set_response_message_attributes, + _set_usage_attributes, + set_request_attributes, + set_response_attributes, +) + + +class MockMessageObject: + def __init__(self, role: str, content: any): + self.role = role + self.content = content class TestGroqProviderUtils(unittest.TestCase): - """Tests `_set_usage_attributes` from `groq.utils`""" + """Tests `_set_usage_attributes`, `set_request_attributes`, `_set_chat_input`, `_set_response_message_attributes`, `set_response_attributes` from `groq.utils`""" ALIASES = { "prompt_tokens": ["prompt_tokens", "input_tokens"], @@ -21,6 +36,20 @@ class TestGroqProviderUtils(unittest.TestCase): P_DETAIL = 10 T_TOKEN = 160 + ATTRIBUTE_MAPPINGS = { + "model": SpanAttributes.LLM_REQUEST_MODEL, + "temperature": SpanAttributes.LLM_REQUEST_TEMPERATURE, + "max_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, + "max_completion_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, + "max_tokens_to_sample": SpanAttributes.LLM_REQUEST_MAX_TOKENS, + "reasoning_effort": SpanAttributes.LLM_REQUEST_REASONING_EFFORT, + "frequency_penalty": SpanAttributes.LLM_FREQUENCY_PENALTY, + "presence_penalty": SpanAttributes.LLM_PRESENCE_PENALTY, + "stop": SpanAttributes.LLM_CHAT_STOP_SEQUENCES, + "stream": SpanAttributes.LLM_IS_STREAMING, + "top_p": SpanAttributes.LLM_REQUEST_TOP_P, + } + def __build_input_data(self): keys_groups = [ self.ALIASES["prompt_tokens"], @@ -75,3 +104,196 @@ def test_empty_dict(self): called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] self.assertEqual(len(called_keys), 0) + + def test_set_request_attributes(self): + OP_TYPE = "OP_TYPE" + mock_span = MagicMock() + samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) + kwargs = {sample: "mock" for sample in samples} + # picking a random sample from kwargs + + kwargs["messages"] = [{"role": "system", "content": "Test"}, {"role": "user", "content": "Test"}] + + kwargs["prompt"] = "Test Prompt" + + set_request_attributes(mock_span, kwargs, OP_TYPE) + + for key, value in kwargs.items(): + if key in self.ATTRIBUTE_MAPPINGS: + mock_span.set_attribute.assert_any_call(self.ATTRIBUTE_MAPPINGS[key], value) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_REQUEST_TYPE, OP_TYPE) + + self.__set_chat_input_check(kwargs["messages"], kwargs["prompt"]) + + def __set_chat_input_check(self, messages: list[str], prompt: str): + mock_span = MagicMock() + + _set_chat_input(mock_span, messages, prompt) + + for i, message in enumerate(messages): + if isinstance(message, MockMessageObject): + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message.role) + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.content", message.content) + else: + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message["role"] if "role" in message else "user" + ) + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) + ) + + def test_set_chat_input_check(self): + messages_object = [ + MockMessageObject(role="system", content="Initialize core instructions."), + MockMessageObject(role="user", content="Explain quantum computing simply."), + MockMessageObject(role="assistant", content="Quantum computing uses qubits..."), + ] + prompt_dummy = "Test message" + self.__set_chat_input_check(messages_object, prompt_dummy) + + messages_dict = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + { + # Missing 'role' key completely -> tests the fallback to "user" + "content": "This should default to a user role." + }, + { + "role": "assistant", + # Non-string content -> tests the str(content) conversion block + "content": ["Nested list content", 12345], + }, + ] + prompt_dummy = None + self.__set_chat_input_check(messages_dict, prompt_dummy) + + def __set_response_message_attributes_check(self, response_dict: dict[str, Any]): + mock_span = MagicMock() + _set_response_message_attributes(mock_span, response_dict) + + if choices := response_dict.get("choices"): + self.assertTrue(isinstance(choices, Iterable)) + + message_index = 0 + for choice in choices: + message = None + if _message := choice.get("message"): + message = _message + elif delta := choice.get("delta"): + message = delta + + if message is not None: + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.role", message.get("role", "assistant") + ) + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.content", message.get("content", "") + ) + + message_index += 1 + + if finish_reason := choice.get("finish_reason"): + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.finish_reason", finish_reason + ) + + def test_set_response_message_attributes(self): + # Test Case 1: Standard Complete Response (Unary Block) + unary_success_data = { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "The capital of France is Paris."}, + "finish_reason": "stop", + } + ] + } + + # Test Case 2: Streaming Chunk Response (Delta Block) + streaming_success_data = { + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "Par"}, + "finish_reason": None, # Streams often pass null finish reasons mid-flight + } + ] + } + + # Test Case 3: Multiple Choices Response (n > 1) + # Tests that message_index tracks and increases across separate array objects + multiple_choices_data = { + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": "Option A text"}, "finish_reason": "stop"}, + { + "index": 1, + "message": {"role": "assistant", "content": "Option B alternative text"}, + "finish_reason": "length", + }, + ] + } + + max_length_response = { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "This sentence is cut off mid-way because the"}, + "finish_reason": "length", + } + ] + } + + tool_call_response = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Kochi"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + + multi_stream_response = { + "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": "Running option one"}, "finish_reason": None}, + { + "index": 1, + "delta": {"role": "assistant", "content": "Alternative route processing"}, + "finish_reason": "stop", + }, + ] + } + + cases = [ + unary_success_data, + streaming_success_data, + multiple_choices_data, + max_length_response, + tool_call_response, + multi_stream_response, + ] + + for case in cases: + self.__set_response_message_attributes_check(case) + + def test_set_response_attributes(self): + mock_span_1 = MagicMock() + mock_span_1.is_recording = lambda: False + set_response_attributes(mock_span_1, dict()) + self.assertEqual(0, mock_span_1.set_attribute.call_count) + + mock_span_2 = MagicMock() + mock_span_1.is_recording = lambda: True + set_response_attributes(mock_span_2, {"model": "test_model_name"}) + mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_RESPONSE_MODEL, "test_model_name") From 54d851b4f92a2293255406a195601e1779aa035b Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Thu, 11 Jun 2026 12:16:13 +0530 Subject: [PATCH 07/13] refactor: Create a base provider class for all utils.py related test cases --- tests/fixtures/__init__.py | 1 + tests/fixtures/base_provider_utils.py | 296 ++++++++++++++++++++++++++ tests/test_groq_provider_utils.py | 294 +------------------------ 3 files changed, 304 insertions(+), 287 deletions(-) create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/base_provider_utils.py diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..25b98afe --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +from .base_provider_utils import BaseProviderUtils, MockMessageObject diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py new file mode 100644 index 00000000..c89fd873 --- /dev/null +++ b/tests/fixtures/base_provider_utils.py @@ -0,0 +1,296 @@ +import itertools +import random +from collections.abc import Iterable +from typing import Any, Callable +from unittest.mock import MagicMock + +from opentelemetry.semconv_ai import SpanAttributes + + +class MockMessageObject: + def __init__(self, role: str, content: any): + self.role = role + self.content = content + + +class BaseProviderUtils: + """Base class to be inherited by test cases that are going to test `utils.py` from `netra/instrumentation/openai/`, `netra/instrumentation/groq/` etc.""" + + ALIASES = { + "prompt_tokens": ["prompt_tokens", "input_tokens"], + "completion_tokens": ["completion_tokens", "output_tokens"], + "prompt_tokens_details": ["prompt_tokens_details", "input_tokens_details"], + } + + P_TOKEN = 100 + C_TOKEN = 50 + P_DETAIL = 10 + T_TOKEN = 160 + + ATTRIBUTE_MAPPINGS = { + "model": SpanAttributes.LLM_REQUEST_MODEL, + "temperature": SpanAttributes.LLM_REQUEST_TEMPERATURE, + "max_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, + "max_completion_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, + "max_tokens_to_sample": SpanAttributes.LLM_REQUEST_MAX_TOKENS, + "reasoning_effort": SpanAttributes.LLM_REQUEST_REASONING_EFFORT, + "frequency_penalty": SpanAttributes.LLM_FREQUENCY_PENALTY, + "presence_penalty": SpanAttributes.LLM_PRESENCE_PENALTY, + "stop": SpanAttributes.LLM_CHAT_STOP_SEQUENCES, + "stream": SpanAttributes.LLM_IS_STREAMING, + "top_p": SpanAttributes.LLM_REQUEST_TOP_P, + } + + set_request_attributes_method: Callable = None + set_response_attributes_method: Callable = None + _set_usage_attributes_method: Callable = None + _set_chat_input_method: Callable = None + _set_response_message_attributes_method: Callable = None + + def __build_input_data(self): + keys_groups = [ + self.ALIASES["prompt_tokens"], + self.ALIASES["completion_tokens"], + self.ALIASES["prompt_tokens_details"], + ] + + for p_token, c_token, p_detail in itertools.product(*keys_groups): + data = dict() + + data[p_token] = self.P_TOKEN + data[c_token] = self.C_TOKEN + data[p_detail] = {"cached_tokens": self.P_DETAIL} + data["total_tokens"] = self.T_TOKEN + + yield data + + def __build_no_details_data(self): + keys_group = [self.ALIASES["prompt_tokens"], self.ALIASES["completion_tokens"]] + + for p_token, c_token in itertools.product(*keys_group): + data = dict() + + data[p_token] = self.P_TOKEN + data[c_token] = self.C_TOKEN + data["total_tokens"] = self.T_TOKEN + + yield data + + def test_set_usage_attributes(self): + for dummy_data in self.__build_input_data(): + mock_span = MagicMock() + self._set_usage_attributes_method(mock_span, dummy_data) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, self.P_DETAIL) + + def test_set_usage_attributes_no_prompt_tokens_details(self): + for dummy_data in self.__build_no_details_data(): + mock_span = MagicMock() + self._set_usage_attributes_method(mock_span, dummy_data) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) + + def test_empty_dict(self): + mock_span = MagicMock() + self._set_usage_attributes_method(mock_span, dict()) + + called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] + self.assertEqual(len(called_keys), 0) + + def test_set_request_attributes(self): + OP_TYPE = "OP_TYPE" + mock_span = MagicMock() + samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) + kwargs = {sample: "mock" for sample in samples} + # picking a random sample from kwargs + + kwargs["messages"] = [{"role": "system", "content": "Test"}, {"role": "user", "content": "Test"}] + + kwargs["prompt"] = "Test Prompt" + + self.set_request_attributes_method(mock_span, kwargs, OP_TYPE) + + for key, value in kwargs.items(): + if key in self.ATTRIBUTE_MAPPINGS: + mock_span.set_attribute.assert_any_call(self.ATTRIBUTE_MAPPINGS[key], value) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_REQUEST_TYPE, OP_TYPE) + + self.__set_chat_input_check(kwargs["messages"], kwargs["prompt"]) + + def __set_chat_input_check(self, messages: list[str], prompt: str): + mock_span = MagicMock() + + self._set_chat_input_method(mock_span, messages, prompt) + + for i, message in enumerate(messages): + if isinstance(message, MockMessageObject): + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message.role) + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.content", message.content) + else: + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message["role"] if "role" in message else "user" + ) + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) + ) + + def test_set_chat_input_check(self): + messages_object = [ + MockMessageObject(role="system", content="Initialize core instructions."), + MockMessageObject(role="user", content="Explain quantum computing simply."), + MockMessageObject(role="assistant", content="Quantum computing uses qubits..."), + ] + prompt_dummy = "Test message" + self.__set_chat_input_check(messages_object, prompt_dummy) + + messages_dict = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + { + # Missing 'role' key completely -> tests the fallback to "user" + "content": "This should default to a user role." + }, + { + "role": "assistant", + # Non-string content -> tests the str(content) conversion block + "content": ["Nested list content", 12345], + }, + ] + prompt_dummy = None + self.__set_chat_input_check(messages_dict, prompt_dummy) + + def __set_response_message_attributes_check(self, response_dict: dict[str, Any]): + mock_span = MagicMock() + self._set_response_message_attributes_method(mock_span, response_dict) + + if choices := response_dict.get("choices"): + self.assertTrue(isinstance(choices, Iterable)) + + message_index = 0 + for choice in choices: + message = None + if _message := choice.get("message"): + message = _message + elif delta := choice.get("delta"): + message = delta + + if message is not None: + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.role", message.get("role", "assistant") + ) + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.content", message.get("content", "") + ) + + message_index += 1 + + if finish_reason := choice.get("finish_reason"): + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.finish_reason", finish_reason + ) + + def test_set_response_message_attributes(self): + # Test Case 1: Standard Complete Response (Unary Block) + unary_success_data = { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "The capital of France is Paris."}, + "finish_reason": "stop", + } + ] + } + + # Test Case 2: Streaming Chunk Response (Delta Block) + streaming_success_data = { + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "Par"}, + "finish_reason": None, # Streams often pass null finish reasons mid-flight + } + ] + } + + # Test Case 3: Multiple Choices Response (n > 1) + # Tests that message_index tracks and increases across separate array objects + multiple_choices_data = { + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": "Option A text"}, "finish_reason": "stop"}, + { + "index": 1, + "message": {"role": "assistant", "content": "Option B alternative text"}, + "finish_reason": "length", + }, + ] + } + + max_length_response = { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "This sentence is cut off mid-way because the"}, + "finish_reason": "length", + } + ] + } + + tool_call_response = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Kochi"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + + multi_stream_response = { + "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": "Running option one"}, "finish_reason": None}, + { + "index": 1, + "delta": {"role": "assistant", "content": "Alternative route processing"}, + "finish_reason": "stop", + }, + ] + } + + cases = [ + unary_success_data, + streaming_success_data, + multiple_choices_data, + max_length_response, + tool_call_response, + multi_stream_response, + ] + + for case in cases: + self.__set_response_message_attributes_check(case) + + def test_set_response_attributes(self): + mock_span_1 = MagicMock() + mock_span_1.is_recording = lambda: False + self.set_response_attributes_method(mock_span_1, dict()) + self.assertEqual(0, mock_span_1.set_attribute.call_count) + + mock_span_2 = MagicMock() + mock_span_1.is_recording = lambda: True + self.set_response_attributes_method(mock_span_2, {"model": "test_model_name"}) + mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_RESPONSE_MODEL, "test_model_name") diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py index 3626b485..2e48cec5 100644 --- a/tests/test_groq_provider_utils.py +++ b/tests/test_groq_provider_utils.py @@ -1,11 +1,4 @@ -import itertools -import random import unittest -from collections.abc import Iterable -from typing import Any -from unittest.mock import MagicMock - -from opentelemetry.semconv_ai import SpanAttributes from netra.instrumentation.groq.utils import ( _set_chat_input, @@ -15,285 +8,12 @@ set_response_attributes, ) +from .fixtures.base_provider_utils import BaseProviderUtils -class MockMessageObject: - def __init__(self, role: str, content: any): - self.role = role - self.content = content - - -class TestGroqProviderUtils(unittest.TestCase): - """Tests `_set_usage_attributes`, `set_request_attributes`, `_set_chat_input`, `_set_response_message_attributes`, `set_response_attributes` from `groq.utils`""" - - ALIASES = { - "prompt_tokens": ["prompt_tokens", "input_tokens"], - "completion_tokens": ["completion_tokens", "output_tokens"], - "prompt_tokens_details": ["prompt_tokens_details", "input_tokens_details"], - } - - P_TOKEN = 100 - C_TOKEN = 50 - P_DETAIL = 10 - T_TOKEN = 160 - - ATTRIBUTE_MAPPINGS = { - "model": SpanAttributes.LLM_REQUEST_MODEL, - "temperature": SpanAttributes.LLM_REQUEST_TEMPERATURE, - "max_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, - "max_completion_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, - "max_tokens_to_sample": SpanAttributes.LLM_REQUEST_MAX_TOKENS, - "reasoning_effort": SpanAttributes.LLM_REQUEST_REASONING_EFFORT, - "frequency_penalty": SpanAttributes.LLM_FREQUENCY_PENALTY, - "presence_penalty": SpanAttributes.LLM_PRESENCE_PENALTY, - "stop": SpanAttributes.LLM_CHAT_STOP_SEQUENCES, - "stream": SpanAttributes.LLM_IS_STREAMING, - "top_p": SpanAttributes.LLM_REQUEST_TOP_P, - } - - def __build_input_data(self): - keys_groups = [ - self.ALIASES["prompt_tokens"], - self.ALIASES["completion_tokens"], - self.ALIASES["prompt_tokens_details"], - ] - - for p_token, c_token, p_detail in itertools.product(*keys_groups): - data = dict() - - data[p_token] = self.P_TOKEN - data[c_token] = self.C_TOKEN - data[p_detail] = {"cached_tokens": self.P_DETAIL} - data["total_tokens"] = self.T_TOKEN - - yield data - - def __build_no_details_data(self): - keys_group = [self.ALIASES["prompt_tokens"], self.ALIASES["completion_tokens"]] - - for p_token, c_token in itertools.product(*keys_group): - data = dict() - - data[p_token] = self.P_TOKEN - data[c_token] = self.C_TOKEN - data["total_tokens"] = self.T_TOKEN - - yield data - - def test_set_usage_attributes(self): - for dummy_data in self.__build_input_data(): - mock_span = MagicMock() - _set_usage_attributes(mock_span, dummy_data) - - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, self.P_DETAIL) - - def test_set_usage_attributes_no_prompt_tokens_details(self): - for dummy_data in self.__build_no_details_data(): - mock_span = MagicMock() - _set_usage_attributes(mock_span, dummy_data) - - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) - - def test_empty_dict(self): - mock_span = MagicMock() - _set_usage_attributes(mock_span, dict()) - - called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] - self.assertEqual(len(called_keys), 0) - - def test_set_request_attributes(self): - OP_TYPE = "OP_TYPE" - mock_span = MagicMock() - samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) - kwargs = {sample: "mock" for sample in samples} - # picking a random sample from kwargs - - kwargs["messages"] = [{"role": "system", "content": "Test"}, {"role": "user", "content": "Test"}] - - kwargs["prompt"] = "Test Prompt" - - set_request_attributes(mock_span, kwargs, OP_TYPE) - - for key, value in kwargs.items(): - if key in self.ATTRIBUTE_MAPPINGS: - mock_span.set_attribute.assert_any_call(self.ATTRIBUTE_MAPPINGS[key], value) - - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_REQUEST_TYPE, OP_TYPE) - - self.__set_chat_input_check(kwargs["messages"], kwargs["prompt"]) - - def __set_chat_input_check(self, messages: list[str], prompt: str): - mock_span = MagicMock() - - _set_chat_input(mock_span, messages, prompt) - - for i, message in enumerate(messages): - if isinstance(message, MockMessageObject): - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message.role) - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.content", message.content) - else: - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message["role"] if "role" in message else "user" - ) - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) - ) - - def test_set_chat_input_check(self): - messages_object = [ - MockMessageObject(role="system", content="Initialize core instructions."), - MockMessageObject(role="user", content="Explain quantum computing simply."), - MockMessageObject(role="assistant", content="Quantum computing uses qubits..."), - ] - prompt_dummy = "Test message" - self.__set_chat_input_check(messages_object, prompt_dummy) - - messages_dict = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - { - # Missing 'role' key completely -> tests the fallback to "user" - "content": "This should default to a user role." - }, - { - "role": "assistant", - # Non-string content -> tests the str(content) conversion block - "content": ["Nested list content", 12345], - }, - ] - prompt_dummy = None - self.__set_chat_input_check(messages_dict, prompt_dummy) - - def __set_response_message_attributes_check(self, response_dict: dict[str, Any]): - mock_span = MagicMock() - _set_response_message_attributes(mock_span, response_dict) - - if choices := response_dict.get("choices"): - self.assertTrue(isinstance(choices, Iterable)) - - message_index = 0 - for choice in choices: - message = None - if _message := choice.get("message"): - message = _message - elif delta := choice.get("delta"): - message = delta - - if message is not None: - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.role", message.get("role", "assistant") - ) - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.content", message.get("content", "") - ) - - message_index += 1 - - if finish_reason := choice.get("finish_reason"): - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.finish_reason", finish_reason - ) - - def test_set_response_message_attributes(self): - # Test Case 1: Standard Complete Response (Unary Block) - unary_success_data = { - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "The capital of France is Paris."}, - "finish_reason": "stop", - } - ] - } - - # Test Case 2: Streaming Chunk Response (Delta Block) - streaming_success_data = { - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": "Par"}, - "finish_reason": None, # Streams often pass null finish reasons mid-flight - } - ] - } - - # Test Case 3: Multiple Choices Response (n > 1) - # Tests that message_index tracks and increases across separate array objects - multiple_choices_data = { - "choices": [ - {"index": 0, "message": {"role": "assistant", "content": "Option A text"}, "finish_reason": "stop"}, - { - "index": 1, - "message": {"role": "assistant", "content": "Option B alternative text"}, - "finish_reason": "length", - }, - ] - } - - max_length_response = { - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "This sentence is cut off mid-way because the"}, - "finish_reason": "length", - } - ] - } - - tool_call_response = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_abc123", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"location": "Kochi"}'}, - } - ], - }, - "finish_reason": "tool_calls", - } - ] - } - - multi_stream_response = { - "choices": [ - {"index": 0, "delta": {"role": "assistant", "content": "Running option one"}, "finish_reason": None}, - { - "index": 1, - "delta": {"role": "assistant", "content": "Alternative route processing"}, - "finish_reason": "stop", - }, - ] - } - - cases = [ - unary_success_data, - streaming_success_data, - multiple_choices_data, - max_length_response, - tool_call_response, - multi_stream_response, - ] - - for case in cases: - self.__set_response_message_attributes_check(case) - - def test_set_response_attributes(self): - mock_span_1 = MagicMock() - mock_span_1.is_recording = lambda: False - set_response_attributes(mock_span_1, dict()) - self.assertEqual(0, mock_span_1.set_attribute.call_count) - mock_span_2 = MagicMock() - mock_span_1.is_recording = lambda: True - set_response_attributes(mock_span_2, {"model": "test_model_name"}) - mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_RESPONSE_MODEL, "test_model_name") +class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): + set_request_attributes_method = staticmethod(set_request_attributes) + set_response_attributes_method = staticmethod(set_response_attributes) + _set_chat_input_method = staticmethod(_set_chat_input) + _set_response_message_attributes_method = staticmethod(_set_response_message_attributes) + _set_usage_attributes_method = staticmethod(_set_usage_attributes) From 4a80f5996afc64c480b77c8eadfb49170d07ccf0 Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Thu, 11 Jun 2026 12:32:37 +0530 Subject: [PATCH 08/13] refactor: Move some responsibility back to TestGroqProviderUtils class --- tests/fixtures/base_provider_utils.py | 140 +------------------------ tests/test_groq_provider_utils.py | 141 +++++++++++++++++++++++++- 2 files changed, 142 insertions(+), 139 deletions(-) diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py index c89fd873..59e0ce1a 100644 --- a/tests/fixtures/base_provider_utils.py +++ b/tests/fixtures/base_provider_utils.py @@ -1,5 +1,4 @@ import itertools -import random from collections.abc import Iterable from typing import Any, Callable from unittest.mock import MagicMock @@ -102,28 +101,7 @@ def test_empty_dict(self): called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] self.assertEqual(len(called_keys), 0) - def test_set_request_attributes(self): - OP_TYPE = "OP_TYPE" - mock_span = MagicMock() - samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) - kwargs = {sample: "mock" for sample in samples} - # picking a random sample from kwargs - - kwargs["messages"] = [{"role": "system", "content": "Test"}, {"role": "user", "content": "Test"}] - - kwargs["prompt"] = "Test Prompt" - - self.set_request_attributes_method(mock_span, kwargs, OP_TYPE) - - for key, value in kwargs.items(): - if key in self.ATTRIBUTE_MAPPINGS: - mock_span.set_attribute.assert_any_call(self.ATTRIBUTE_MAPPINGS[key], value) - - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_REQUEST_TYPE, OP_TYPE) - - self.__set_chat_input_check(kwargs["messages"], kwargs["prompt"]) - - def __set_chat_input_check(self, messages: list[str], prompt: str): + def _set_chat_input_check(self, messages: list[str], prompt: str): mock_span = MagicMock() self._set_chat_input_method(mock_span, messages, prompt) @@ -140,32 +118,7 @@ def __set_chat_input_check(self, messages: list[str], prompt: str): f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) ) - def test_set_chat_input_check(self): - messages_object = [ - MockMessageObject(role="system", content="Initialize core instructions."), - MockMessageObject(role="user", content="Explain quantum computing simply."), - MockMessageObject(role="assistant", content="Quantum computing uses qubits..."), - ] - prompt_dummy = "Test message" - self.__set_chat_input_check(messages_object, prompt_dummy) - - messages_dict = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - { - # Missing 'role' key completely -> tests the fallback to "user" - "content": "This should default to a user role." - }, - { - "role": "assistant", - # Non-string content -> tests the str(content) conversion block - "content": ["Nested list content", 12345], - }, - ] - prompt_dummy = None - self.__set_chat_input_check(messages_dict, prompt_dummy) - - def __set_response_message_attributes_check(self, response_dict: dict[str, Any]): + def _set_response_message_attributes_check(self, response_dict: dict[str, Any]): mock_span = MagicMock() self._set_response_message_attributes_method(mock_span, response_dict) @@ -195,95 +148,6 @@ def __set_response_message_attributes_check(self, response_dict: dict[str, Any]) f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.finish_reason", finish_reason ) - def test_set_response_message_attributes(self): - # Test Case 1: Standard Complete Response (Unary Block) - unary_success_data = { - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "The capital of France is Paris."}, - "finish_reason": "stop", - } - ] - } - - # Test Case 2: Streaming Chunk Response (Delta Block) - streaming_success_data = { - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": "Par"}, - "finish_reason": None, # Streams often pass null finish reasons mid-flight - } - ] - } - - # Test Case 3: Multiple Choices Response (n > 1) - # Tests that message_index tracks and increases across separate array objects - multiple_choices_data = { - "choices": [ - {"index": 0, "message": {"role": "assistant", "content": "Option A text"}, "finish_reason": "stop"}, - { - "index": 1, - "message": {"role": "assistant", "content": "Option B alternative text"}, - "finish_reason": "length", - }, - ] - } - - max_length_response = { - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "This sentence is cut off mid-way because the"}, - "finish_reason": "length", - } - ] - } - - tool_call_response = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_abc123", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"location": "Kochi"}'}, - } - ], - }, - "finish_reason": "tool_calls", - } - ] - } - - multi_stream_response = { - "choices": [ - {"index": 0, "delta": {"role": "assistant", "content": "Running option one"}, "finish_reason": None}, - { - "index": 1, - "delta": {"role": "assistant", "content": "Alternative route processing"}, - "finish_reason": "stop", - }, - ] - } - - cases = [ - unary_success_data, - streaming_success_data, - multiple_choices_data, - max_length_response, - tool_call_response, - multi_stream_response, - ] - - for case in cases: - self.__set_response_message_attributes_check(case) - def test_set_response_attributes(self): mock_span_1 = MagicMock() mock_span_1.is_recording = lambda: False diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py index 2e48cec5..55260608 100644 --- a/tests/test_groq_provider_utils.py +++ b/tests/test_groq_provider_utils.py @@ -1,4 +1,8 @@ +import random import unittest +from unittest.mock import MagicMock + +from opentelemetry.semconv_ai import SpanAttributes from netra.instrumentation.groq.utils import ( _set_chat_input, @@ -8,7 +12,7 @@ set_response_attributes, ) -from .fixtures.base_provider_utils import BaseProviderUtils +from .fixtures.base_provider_utils import BaseProviderUtils, MockMessageObject class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): @@ -17,3 +21,138 @@ class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): _set_chat_input_method = staticmethod(_set_chat_input) _set_response_message_attributes_method = staticmethod(_set_response_message_attributes) _set_usage_attributes_method = staticmethod(_set_usage_attributes) + + def test_set_chat_input_check(self): + messages_object = [ + MockMessageObject(role="system", content="Initialize core instructions."), + MockMessageObject(role="user", content="Explain quantum computing simply."), + MockMessageObject(role="assistant", content="Quantum computing uses qubits..."), + ] + prompt_dummy = "Test message" + self._set_chat_input_check(messages_object, prompt_dummy) + + messages_dict = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + { + # Missing 'role' key completely -> tests the fallback to "user" + "content": "This should default to a user role." + }, + { + "role": "assistant", + # Non-string content -> tests the str(content) conversion block + "content": ["Nested list content", 12345], + }, + ] + prompt_dummy = None + self._set_chat_input_check(messages_dict, prompt_dummy) + + def test_set_request_attributes(self): + OP_TYPE = "OP_TYPE" + mock_span = MagicMock() + samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) + kwargs = {sample: "mock" for sample in samples} + # picking a random sample from kwargs + + kwargs["messages"] = [{"role": "system", "content": "Test"}, {"role": "user", "content": "Test"}] + + kwargs["prompt"] = "Test Prompt" + + self.set_request_attributes_method(mock_span, kwargs, OP_TYPE) + + for key, value in kwargs.items(): + if key in self.ATTRIBUTE_MAPPINGS: + mock_span.set_attribute.assert_any_call(self.ATTRIBUTE_MAPPINGS[key], value) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_REQUEST_TYPE, OP_TYPE) + + self._set_chat_input_check(kwargs["messages"], kwargs["prompt"]) + + def test_set_response_message_attributes(self): + # Test Case 1: Standard Complete Response (Unary Block) + unary_success_data = { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "The capital of France is Paris."}, + "finish_reason": "stop", + } + ] + } + + # Test Case 2: Streaming Chunk Response (Delta Block) + streaming_success_data = { + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "Par"}, + "finish_reason": None, # Streams often pass null finish reasons mid-flight + } + ] + } + + # Test Case 3: Multiple Choices Response (n > 1) + # Tests that message_index tracks and increases across separate array objects + multiple_choices_data = { + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": "Option A text"}, "finish_reason": "stop"}, + { + "index": 1, + "message": {"role": "assistant", "content": "Option B alternative text"}, + "finish_reason": "length", + }, + ] + } + + max_length_response = { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "This sentence is cut off mid-way because the"}, + "finish_reason": "length", + } + ] + } + + tool_call_response = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Kochi"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + + multi_stream_response = { + "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": "Running option one"}, "finish_reason": None}, + { + "index": 1, + "delta": {"role": "assistant", "content": "Alternative route processing"}, + "finish_reason": "stop", + }, + ] + } + + cases = [ + unary_success_data, + streaming_success_data, + multiple_choices_data, + max_length_response, + tool_call_response, + multi_stream_response, + ] + + for case in cases: + self._set_response_message_attributes_check(case) From c3d0043067746ff8926c53abddb6f6cb5c25890c Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Thu, 11 Jun 2026 14:33:44 +0530 Subject: [PATCH 09/13] test: Refactor provider utility tests and fix inheritance discovery --- tests/fixtures/base_provider_utils.py | 46 +++++++++++++++++++-------- tests/test_groq_provider_utils.py | 7 ++++ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py index 59e0ce1a..ff593319 100644 --- a/tests/fixtures/base_provider_utils.py +++ b/tests/fixtures/base_provider_utils.py @@ -76,6 +76,7 @@ def __build_no_details_data(self): yield data def test_set_usage_attributes(self): + """Tests _set_usage_attributes""" for dummy_data in self.__build_input_data(): mock_span = MagicMock() self._set_usage_attributes_method(mock_span, dummy_data) @@ -86,6 +87,7 @@ def test_set_usage_attributes(self): mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, self.P_DETAIL) def test_set_usage_attributes_no_prompt_tokens_details(self): + """Tests _set_usage_attributes without prompt token details""" for dummy_data in self.__build_no_details_data(): mock_span = MagicMock() self._set_usage_attributes_method(mock_span, dummy_data) @@ -95,6 +97,7 @@ def test_set_usage_attributes_no_prompt_tokens_details(self): mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) def test_empty_dict(self): + """Tests _set_usage_attributes with empty dictionary""" mock_span = MagicMock() self._set_usage_attributes_method(mock_span, dict()) @@ -106,17 +109,23 @@ def _set_chat_input_check(self, messages: list[str], prompt: str): self._set_chat_input_method(mock_span, messages, prompt) - for i, message in enumerate(messages): - if isinstance(message, MockMessageObject): - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message.role) - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.content", message.content) - else: - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message["role"] if "role" in message else "user" - ) - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) - ) + if messages: + for i, message in enumerate(messages): + if isinstance(message, MockMessageObject): + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message.role) + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.content", message.content + ) + else: + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message["role"] if "role" in message else "user" + ) + mock_span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) + ) + elif prompt: + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.0.role", "user") + mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.0.content", prompt) def _set_response_message_attributes_check(self, response_dict: dict[str, Any]): mock_span = MagicMock() @@ -149,12 +158,21 @@ def _set_response_message_attributes_check(self, response_dict: dict[str, Any]): ) def test_set_response_attributes(self): + """Tests set_response_attributes""" mock_span_1 = MagicMock() - mock_span_1.is_recording = lambda: False + mock_span_1.is_recording.return_value = False self.set_response_attributes_method(mock_span_1, dict()) self.assertEqual(0, mock_span_1.set_attribute.call_count) mock_span_2 = MagicMock() - mock_span_1.is_recording = lambda: True - self.set_response_attributes_method(mock_span_2, {"model": "test_model_name"}) + mock_span_2.is_recording.return_value = True + full_data = { + "model": "test_model_name", + "usage": {"total_tokens": 100}, + "choices": [{"message": {"role": "assistant", "content": "test"}}], + } + self.set_response_attributes_method(mock_span_2, full_data) + mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_RESPONSE_MODEL, "test_model_name") + mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, 100) + mock_span_2.set_attribute.assert_any_call(f"{SpanAttributes.LLM_COMPLETIONS}.0.role", "assistant") diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py index 55260608..c6505975 100644 --- a/tests/test_groq_provider_utils.py +++ b/tests/test_groq_provider_utils.py @@ -23,6 +23,7 @@ class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): _set_usage_attributes_method = staticmethod(_set_usage_attributes) def test_set_chat_input_check(self): + """Tests _set_chat_input""" messages_object = [ MockMessageObject(role="system", content="Initialize core instructions."), MockMessageObject(role="user", content="Explain quantum computing simply."), @@ -47,7 +48,12 @@ def test_set_chat_input_check(self): prompt_dummy = None self._set_chat_input_check(messages_dict, prompt_dummy) + # Test branch where messages is empty/None and prompt is provided + prompt_only = "Legacy completion prompt" + self._set_chat_input_check(messages=[], prompt=prompt_only) + def test_set_request_attributes(self): + """Tests set_request_attributes""" OP_TYPE = "OP_TYPE" mock_span = MagicMock() samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) @@ -69,6 +75,7 @@ def test_set_request_attributes(self): self._set_chat_input_check(kwargs["messages"], kwargs["prompt"]) def test_set_response_message_attributes(self): + """Tests _set_response_message_attributes""" # Test Case 1: Standard Complete Response (Unary Block) unary_success_data = { "choices": [ From c709203cce45bc7e093d07d24463fe80ec215741 Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Fri, 12 Jun 2026 10:15:32 +0530 Subject: [PATCH 10/13] fix: Naming issues as per PR#311 review --- tests/fixtures/base_provider_utils.py | 8 ++++---- ...st_SpanIOProcessor.py => test_span_io_processor.py} | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) rename tests/{test_SpanIOProcessor.py => test_span_io_processor.py} (74%) diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py index ff593319..9ffa2123 100644 --- a/tests/fixtures/base_provider_utils.py +++ b/tests/fixtures/base_provider_utils.py @@ -46,7 +46,7 @@ class BaseProviderUtils: _set_chat_input_method: Callable = None _set_response_message_attributes_method: Callable = None - def __build_input_data(self): + def _build_input_data(self): keys_groups = [ self.ALIASES["prompt_tokens"], self.ALIASES["completion_tokens"], @@ -63,7 +63,7 @@ def __build_input_data(self): yield data - def __build_no_details_data(self): + def _build_no_details_data(self): keys_group = [self.ALIASES["prompt_tokens"], self.ALIASES["completion_tokens"]] for p_token, c_token in itertools.product(*keys_group): @@ -77,7 +77,7 @@ def __build_no_details_data(self): def test_set_usage_attributes(self): """Tests _set_usage_attributes""" - for dummy_data in self.__build_input_data(): + for dummy_data in self._build_input_data(): mock_span = MagicMock() self._set_usage_attributes_method(mock_span, dummy_data) @@ -88,7 +88,7 @@ def test_set_usage_attributes(self): def test_set_usage_attributes_no_prompt_tokens_details(self): """Tests _set_usage_attributes without prompt token details""" - for dummy_data in self.__build_no_details_data(): + for dummy_data in self._build_no_details_data(): mock_span = MagicMock() self._set_usage_attributes_method(mock_span, dummy_data) diff --git a/tests/test_SpanIOProcessor.py b/tests/test_span_io_processor.py similarity index 74% rename from tests/test_SpanIOProcessor.py rename to tests/test_span_io_processor.py index 097935b9..5cce86a1 100644 --- a/tests/test_SpanIOProcessor.py +++ b/tests/test_span_io_processor.py @@ -12,7 +12,7 @@ class TestSpanIOProcessor(unittest.TestCase): _USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens" _USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens" - def __get_mocks(self, original_value: str, desired_value: str): + def _get_mocks(self, original_value: str, desired_value: str): mock_span = MagicMock() mock_original_set_attribute = MagicMock() mock_span.set_attribute = mock_original_set_attribute @@ -27,13 +27,13 @@ def __get_mocks(self, original_value: str, desired_value: str): mock_original_set_attribute.assert_called_with(desired_value, 100) def test_alias_to_prompt_tokens(self): - self.__get_mocks(self._USAGE_INPUT_TOKENS, self._USAGE_PROMPT_TOKENS) + self._get_mocks(self._USAGE_INPUT_TOKENS, self._USAGE_PROMPT_TOKENS) def test_alias_to_completion_tokens(self): - self.__get_mocks(self._USAGE_OUTPUT_TOKENS, self._USAGE_COMPLETION_TOKENS) + self._get_mocks(self._USAGE_OUTPUT_TOKENS, self._USAGE_COMPLETION_TOKENS) def test_pass_through_prompt_tokens(self): - self.__get_mocks(self._USAGE_PROMPT_TOKENS, self._USAGE_PROMPT_TOKENS) + self._get_mocks(self._USAGE_PROMPT_TOKENS, self._USAGE_PROMPT_TOKENS) def test_pass_through_completion_tokens(self): - self.__get_mocks(self._USAGE_COMPLETION_TOKENS, self._USAGE_COMPLETION_TOKENS) + self._get_mocks(self._USAGE_COMPLETION_TOKENS, self._USAGE_COMPLETION_TOKENS) From 375a2876951a0a1d2793815d7ee69e9bd5375f42 Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Fri, 12 Jun 2026 11:06:44 +0530 Subject: [PATCH 11/13] refactor: Remove unnecessary test cases and helper functions --- tests/fixtures/base_provider_utils.py | 76 +------------ tests/test_groq_provider_utils.py | 148 +------------------------- 2 files changed, 2 insertions(+), 222 deletions(-) diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py index 9ffa2123..de3b429d 100644 --- a/tests/fixtures/base_provider_utils.py +++ b/tests/fixtures/base_provider_utils.py @@ -1,6 +1,5 @@ import itertools -from collections.abc import Iterable -from typing import Any, Callable +from typing import Callable from unittest.mock import MagicMock from opentelemetry.semconv_ai import SpanAttributes @@ -103,76 +102,3 @@ def test_empty_dict(self): called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] self.assertEqual(len(called_keys), 0) - - def _set_chat_input_check(self, messages: list[str], prompt: str): - mock_span = MagicMock() - - self._set_chat_input_method(mock_span, messages, prompt) - - if messages: - for i, message in enumerate(messages): - if isinstance(message, MockMessageObject): - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message.role) - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.content", message.content - ) - else: - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.role", message["role"] if "role" in message else "user" - ) - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_PROMPTS}.{i}.content", str(message["content"]) - ) - elif prompt: - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.0.role", "user") - mock_span.set_attribute.assert_any_call(f"{SpanAttributes.LLM_PROMPTS}.0.content", prompt) - - def _set_response_message_attributes_check(self, response_dict: dict[str, Any]): - mock_span = MagicMock() - self._set_response_message_attributes_method(mock_span, response_dict) - - if choices := response_dict.get("choices"): - self.assertTrue(isinstance(choices, Iterable)) - - message_index = 0 - for choice in choices: - message = None - if _message := choice.get("message"): - message = _message - elif delta := choice.get("delta"): - message = delta - - if message is not None: - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.role", message.get("role", "assistant") - ) - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.content", message.get("content", "") - ) - - message_index += 1 - - if finish_reason := choice.get("finish_reason"): - mock_span.set_attribute.assert_any_call( - f"{SpanAttributes.LLM_COMPLETIONS}.{message_index}.finish_reason", finish_reason - ) - - def test_set_response_attributes(self): - """Tests set_response_attributes""" - mock_span_1 = MagicMock() - mock_span_1.is_recording.return_value = False - self.set_response_attributes_method(mock_span_1, dict()) - self.assertEqual(0, mock_span_1.set_attribute.call_count) - - mock_span_2 = MagicMock() - mock_span_2.is_recording.return_value = True - full_data = { - "model": "test_model_name", - "usage": {"total_tokens": 100}, - "choices": [{"message": {"role": "assistant", "content": "test"}}], - } - self.set_response_attributes_method(mock_span_2, full_data) - - mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_RESPONSE_MODEL, "test_model_name") - mock_span_2.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, 100) - mock_span_2.set_attribute.assert_any_call(f"{SpanAttributes.LLM_COMPLETIONS}.0.role", "assistant") diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py index c6505975..2e48cec5 100644 --- a/tests/test_groq_provider_utils.py +++ b/tests/test_groq_provider_utils.py @@ -1,8 +1,4 @@ -import random import unittest -from unittest.mock import MagicMock - -from opentelemetry.semconv_ai import SpanAttributes from netra.instrumentation.groq.utils import ( _set_chat_input, @@ -12,7 +8,7 @@ set_response_attributes, ) -from .fixtures.base_provider_utils import BaseProviderUtils, MockMessageObject +from .fixtures.base_provider_utils import BaseProviderUtils class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): @@ -21,145 +17,3 @@ class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): _set_chat_input_method = staticmethod(_set_chat_input) _set_response_message_attributes_method = staticmethod(_set_response_message_attributes) _set_usage_attributes_method = staticmethod(_set_usage_attributes) - - def test_set_chat_input_check(self): - """Tests _set_chat_input""" - messages_object = [ - MockMessageObject(role="system", content="Initialize core instructions."), - MockMessageObject(role="user", content="Explain quantum computing simply."), - MockMessageObject(role="assistant", content="Quantum computing uses qubits..."), - ] - prompt_dummy = "Test message" - self._set_chat_input_check(messages_object, prompt_dummy) - - messages_dict = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - { - # Missing 'role' key completely -> tests the fallback to "user" - "content": "This should default to a user role." - }, - { - "role": "assistant", - # Non-string content -> tests the str(content) conversion block - "content": ["Nested list content", 12345], - }, - ] - prompt_dummy = None - self._set_chat_input_check(messages_dict, prompt_dummy) - - # Test branch where messages is empty/None and prompt is provided - prompt_only = "Legacy completion prompt" - self._set_chat_input_check(messages=[], prompt=prompt_only) - - def test_set_request_attributes(self): - """Tests set_request_attributes""" - OP_TYPE = "OP_TYPE" - mock_span = MagicMock() - samples = random.sample(list(self.ATTRIBUTE_MAPPINGS.keys()), k=random.randint(1, len(self.ATTRIBUTE_MAPPINGS))) - kwargs = {sample: "mock" for sample in samples} - # picking a random sample from kwargs - - kwargs["messages"] = [{"role": "system", "content": "Test"}, {"role": "user", "content": "Test"}] - - kwargs["prompt"] = "Test Prompt" - - self.set_request_attributes_method(mock_span, kwargs, OP_TYPE) - - for key, value in kwargs.items(): - if key in self.ATTRIBUTE_MAPPINGS: - mock_span.set_attribute.assert_any_call(self.ATTRIBUTE_MAPPINGS[key], value) - - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_REQUEST_TYPE, OP_TYPE) - - self._set_chat_input_check(kwargs["messages"], kwargs["prompt"]) - - def test_set_response_message_attributes(self): - """Tests _set_response_message_attributes""" - # Test Case 1: Standard Complete Response (Unary Block) - unary_success_data = { - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "The capital of France is Paris."}, - "finish_reason": "stop", - } - ] - } - - # Test Case 2: Streaming Chunk Response (Delta Block) - streaming_success_data = { - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": "Par"}, - "finish_reason": None, # Streams often pass null finish reasons mid-flight - } - ] - } - - # Test Case 3: Multiple Choices Response (n > 1) - # Tests that message_index tracks and increases across separate array objects - multiple_choices_data = { - "choices": [ - {"index": 0, "message": {"role": "assistant", "content": "Option A text"}, "finish_reason": "stop"}, - { - "index": 1, - "message": {"role": "assistant", "content": "Option B alternative text"}, - "finish_reason": "length", - }, - ] - } - - max_length_response = { - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": "This sentence is cut off mid-way because the"}, - "finish_reason": "length", - } - ] - } - - tool_call_response = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_abc123", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"location": "Kochi"}'}, - } - ], - }, - "finish_reason": "tool_calls", - } - ] - } - - multi_stream_response = { - "choices": [ - {"index": 0, "delta": {"role": "assistant", "content": "Running option one"}, "finish_reason": None}, - { - "index": 1, - "delta": {"role": "assistant", "content": "Alternative route processing"}, - "finish_reason": "stop", - }, - ] - } - - cases = [ - unary_success_data, - streaming_success_data, - multiple_choices_data, - max_length_response, - tool_call_response, - multi_stream_response, - ] - - for case in cases: - self._set_response_message_attributes_check(case) From b3ed8f9cd0561e87da4318f6c6ec4bceca781d5d Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Fri, 12 Jun 2026 11:21:12 +0530 Subject: [PATCH 12/13] refactor: Remove unnecessary variables --- tests/fixtures/base_provider_utils.py | 54 ++++++++------------------- tests/test_groq_provider_utils.py | 12 +----- 2 files changed, 17 insertions(+), 49 deletions(-) diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py index de3b429d..6a54f3d8 100644 --- a/tests/fixtures/base_provider_utils.py +++ b/tests/fixtures/base_provider_utils.py @@ -5,12 +5,6 @@ from opentelemetry.semconv_ai import SpanAttributes -class MockMessageObject: - def __init__(self, role: str, content: any): - self.role = role - self.content = content - - class BaseProviderUtils: """Base class to be inherited by test cases that are going to test `utils.py` from `netra/instrumentation/openai/`, `netra/instrumentation/groq/` etc.""" @@ -25,25 +19,7 @@ class BaseProviderUtils: P_DETAIL = 10 T_TOKEN = 160 - ATTRIBUTE_MAPPINGS = { - "model": SpanAttributes.LLM_REQUEST_MODEL, - "temperature": SpanAttributes.LLM_REQUEST_TEMPERATURE, - "max_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, - "max_completion_tokens": SpanAttributes.LLM_REQUEST_MAX_TOKENS, - "max_tokens_to_sample": SpanAttributes.LLM_REQUEST_MAX_TOKENS, - "reasoning_effort": SpanAttributes.LLM_REQUEST_REASONING_EFFORT, - "frequency_penalty": SpanAttributes.LLM_FREQUENCY_PENALTY, - "presence_penalty": SpanAttributes.LLM_PRESENCE_PENALTY, - "stop": SpanAttributes.LLM_CHAT_STOP_SEQUENCES, - "stream": SpanAttributes.LLM_IS_STREAMING, - "top_p": SpanAttributes.LLM_REQUEST_TOP_P, - } - - set_request_attributes_method: Callable = None - set_response_attributes_method: Callable = None _set_usage_attributes_method: Callable = None - _set_chat_input_method: Callable = None - _set_response_message_attributes_method: Callable = None def _build_input_data(self): keys_groups = [ @@ -76,24 +52,26 @@ def _build_no_details_data(self): def test_set_usage_attributes(self): """Tests _set_usage_attributes""" - for dummy_data in self._build_input_data(): - mock_span = MagicMock() - self._set_usage_attributes_method(mock_span, dummy_data) + for i, dummy_data in enumerate(self._build_input_data()): + with self.subTest(scenario=f"Combination {i}", payload=dummy_data): + mock_span = MagicMock() + self._set_usage_attributes_method(mock_span, dummy_data) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, self.P_DETAIL) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, self.P_DETAIL) def test_set_usage_attributes_no_prompt_tokens_details(self): """Tests _set_usage_attributes without prompt token details""" - for dummy_data in self._build_no_details_data(): - mock_span = MagicMock() - self._set_usage_attributes_method(mock_span, dummy_data) - - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) - mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) + for i, dummy_data in enumerate(self._build_no_details_data()): + with self.subTest(scenario=f"Combination {i}", payload=dummy_data): + mock_span = MagicMock() + self._set_usage_attributes_method(mock_span, dummy_data) + + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, self.P_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, self.C_TOKEN) + mock_span.set_attribute.assert_any_call(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, self.T_TOKEN) def test_empty_dict(self): """Tests _set_usage_attributes with empty dictionary""" diff --git a/tests/test_groq_provider_utils.py b/tests/test_groq_provider_utils.py index 2e48cec5..729c2f10 100644 --- a/tests/test_groq_provider_utils.py +++ b/tests/test_groq_provider_utils.py @@ -1,19 +1,9 @@ import unittest -from netra.instrumentation.groq.utils import ( - _set_chat_input, - _set_response_message_attributes, - _set_usage_attributes, - set_request_attributes, - set_response_attributes, -) +from netra.instrumentation.groq.utils import _set_usage_attributes from .fixtures.base_provider_utils import BaseProviderUtils class TestGroqProviderUtils(unittest.TestCase, BaseProviderUtils): - set_request_attributes_method = staticmethod(set_request_attributes) - set_response_attributes_method = staticmethod(set_response_attributes) - _set_chat_input_method = staticmethod(_set_chat_input) - _set_response_message_attributes_method = staticmethod(_set_response_message_attributes) _set_usage_attributes_method = staticmethod(_set_usage_attributes) From 33378823c61453eaa233850dc6bf14c1d3b4c87f Mon Sep 17 00:00:00 2001 From: ivanrj7j Date: Fri, 12 Jun 2026 11:30:52 +0530 Subject: [PATCH 13/13] fix: Switch to using plain assert statement instead of assertEqual --- tests/fixtures/__init__.py | 2 +- tests/fixtures/base_provider_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 25b98afe..cb01520a 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1 +1 @@ -from .base_provider_utils import BaseProviderUtils, MockMessageObject +from .base_provider_utils import BaseProviderUtils diff --git a/tests/fixtures/base_provider_utils.py b/tests/fixtures/base_provider_utils.py index 6a54f3d8..abf6906b 100644 --- a/tests/fixtures/base_provider_utils.py +++ b/tests/fixtures/base_provider_utils.py @@ -79,4 +79,4 @@ def test_empty_dict(self): self._set_usage_attributes_method(mock_span, dict()) called_keys = [call[0][0] for call in mock_span.set_attribute.call_args_list] - self.assertEqual(len(called_keys), 0) + assert len(called_keys) == 0