diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d35c6fd..6091418 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,7 +38,7 @@ jobs: run: black --check src/ - name: Type check - run: mypy src/ || true + run: mypy src/ - name: Run tests run: pytest diff --git a/pyproject.toml b/pyproject.toml index 7601b78..d268df1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,13 @@ python_version = "3.11" warn_return_any = true warn_unused_configs = true +[[tool.mypy.overrides]] +module = [ + "boto3", + "botocore.exceptions", + "torch", + "httpx", + "ollama", +] +ignore_missing_imports = true + diff --git a/src/deepiri_modelkit/contracts/models.py b/src/deepiri_modelkit/contracts/models.py index bdc6568..be774ed 100644 --- a/src/deepiri_modelkit/contracts/models.py +++ b/src/deepiri_modelkit/contracts/models.py @@ -112,7 +112,7 @@ def validate_aimodel(value: Any) -> Any: return value - def serialize_aimodel(value: Any) -> Dict[str, Any]: + def serialize_aimodel(value: Any) -> Optional[Dict[str, Any]]: """Serialize AIModel instance to dict""" if value is None: return None diff --git a/src/deepiri_modelkit/contracts/services.py b/src/deepiri_modelkit/contracts/services.py index b844a46..56ce934 100644 --- a/src/deepiri_modelkit/contracts/services.py +++ b/src/deepiri_modelkit/contracts/services.py @@ -2,7 +2,7 @@ Service contracts and interfaces """ -from typing import Protocol, Dict, Any, Optional +from typing import Any, Callable, Dict, Optional, Protocol from pydantic import BaseModel @@ -38,7 +38,10 @@ def publish(self, topic: str, event: Dict[str, Any]) -> bool: ... def subscribe( - self, topic: str, callback: callable, consumer_group: Optional[str] = None + self, + topic: str, + callback: Callable[[Dict[str, Any]], Any], + consumer_group: Optional[str] = None, ) -> None: """Subscribe to topic with callback""" ... diff --git a/src/deepiri_modelkit/data/monitoring.py b/src/deepiri_modelkit/data/monitoring.py index 81b7a91..229ab57 100644 --- a/src/deepiri_modelkit/data/monitoring.py +++ b/src/deepiri_modelkit/data/monitoring.py @@ -35,7 +35,7 @@ def __init__(self, log_dir: str = "./logs/dataset_monitoring"): self.alerts_file = self.log_dir / "alerts.jsonl" # In-memory metrics for quick access - self.current_metrics = { + self.current_metrics: Dict[str, Any] = { "total_versions_created": 0, "total_datasets_tracked": 0, "average_version_creation_time": 0, @@ -159,7 +159,7 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: """Get usage analytics for the specified period.""" cutoff_date = datetime.utcnow() - timedelta(days=days) - analytics = { + analytics: Dict[str, Any] = { "period_days": days, "version_creations": [], "training_runs": [], diff --git a/src/deepiri_modelkit/data/validation.py b/src/deepiri_modelkit/data/validation.py index 60ef3c7..4b87378 100644 --- a/src/deepiri_modelkit/data/validation.py +++ b/src/deepiri_modelkit/data/validation.py @@ -93,7 +93,7 @@ def validate_dataset(self, data_path: Path) -> Dict[str, Any]: "Starting dataset validation", path=str(data_path), type=self.dataset_type ) - results = { + results: Dict[str, Any] = { "is_valid": True, "errors": [], "warnings": [], diff --git a/src/deepiri_modelkit/ml/semantic.py b/src/deepiri_modelkit/ml/semantic.py index 366af91..95b38a3 100644 --- a/src/deepiri_modelkit/ml/semantic.py +++ b/src/deepiri_modelkit/ml/semantic.py @@ -48,7 +48,7 @@ def __init__( ): self.ollama_base_url = ollama_base_url self.model = model - self._cache = {} # Cache for semantic analysis results + self._cache: dict[str, list[str]] = {} # Cache for semantic analysis results def _call_ollama( self, prompt: str, timeout: float = 15.0 @@ -66,7 +66,7 @@ def _call_ollama( "num_predict": 100, # Reduced from 200 for faster responses }, ) - return response.get("response", "").strip() + return str(response.get("response", "")).strip() except Exception: # Fall back to HTTP pass @@ -93,7 +93,7 @@ def _call_ollama( if response.status_code == 200: result = response.json() logger.debug("Ollama HTTP call succeeded") - return result.get("response", "").strip() + return str(result.get("response", "")).strip() else: logger.debug( f"Ollama HTTP call failed: HTTP {response.status_code}" @@ -118,7 +118,7 @@ def _call_ollama( if response.status_code == 200: result = response.json() logger.debug("Ollama HTTP call succeeded") - return result.get("response", "").strip() + return str(result.get("response", "")).strip() else: logger.debug( f"Ollama HTTP call failed: HTTP {response.status_code}" @@ -305,7 +305,9 @@ def analyze_semantic_structure(self, text: str) -> Dict: try: json_match = re.search(r"\{.*?\}", response, re.DOTALL) if json_match: - return json.loads(json_match.group()) + parsed = json.loads(json_match.group()) + if isinstance(parsed, dict): + return parsed except Exception: pass @@ -333,10 +335,10 @@ def check_ollama_available(self) -> bool: try: if HAS_HTTPX: response = httpx.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) - return response.status_code == 200 + return bool(response.status_code == 200) elif HAS_REQUESTS: response = requests.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) - return response.status_code == 200 + return bool(response.status_code == 200) except Exception: pass @@ -349,8 +351,10 @@ def get_semantic_analyzer( """ Factory function to get semantic analyzer """ - base_url = ollama_base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") - model_name = model or os.getenv("OLLAMA_MODEL", "llama3:8b") + base_url = ( + ollama_base_url or os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434" + ) + model_name = model or os.getenv("OLLAMA_MODEL") or "llama3:8b" analyzer = SemanticAnalyzer(ollama_base_url=base_url, model=model_name) diff --git a/src/deepiri_modelkit/rag/__init__.py b/src/deepiri_modelkit/rag/__init__.py index 746f60f..7b3cab4 100644 --- a/src/deepiri_modelkit/rag/__init__.py +++ b/src/deepiri_modelkit/rag/__init__.py @@ -1,166 +1,97 @@ -""" -Universal RAG Module for Deepiri Platform -Reusable across all industry niches: Insurance, Manufacturing, Property Management, Healthcare, etc. -""" - -from .base import ( - UniversalRAGEngine, - Document, - DocumentType, - IndustryNiche, - RAGConfig, - RAGQuery, - RetrievalResult, -) -from .processors import ( - DocumentProcessor, - RegulationProcessor, - HistoricalDataProcessor, - KnowledgeBaseProcessor, - ManualProcessor, - get_processor, -) -from .retrievers import ( - MultiModalRetriever, - HybridRetriever, - ContextualRetriever, - get_retriever, -) - -# Advanced features (optional imports) -try: - from .advanced_retrieval import ( - AdvancedRetrievalPipeline, - QueryExpander, - SynonymQueryExpander, - RephraseQueryExpander, - MultiQueryRetriever, - QueryCache, - ) - - HAS_ADVANCED_RETRIEVAL = True -except ImportError: - HAS_ADVANCED_RETRIEVAL = False - AdvancedRetrievalPipeline = None - QueryExpander = None - SynonymQueryExpander = None - RephraseQueryExpander = None - MultiQueryRetriever = None - QueryCache = None - -try: - from .caching import ( - AdvancedCacheManager, - EmbeddingCache, - QueryResultCache, - ) - - HAS_CACHING = True -except ImportError: - HAS_CACHING = False - AdvancedCacheManager = None - EmbeddingCache = None - QueryResultCache = None - -try: - from .monitoring import ( - RAGMonitor, - RetrievalMetrics, - IndexingMetrics, - SystemMetrics, - PerformanceTimer, - ) - - HAS_MONITORING = True -except ImportError: - HAS_MONITORING = False - RAGMonitor = None - RetrievalMetrics = None - IndexingMetrics = None - SystemMetrics = None - PerformanceTimer = None - -try: - from .async_processing import ( - AsyncBatchProcessor, - AsyncDocumentIndexer, - AsyncDocumentProcessor, - BatchProcessingConfig, - BatchProcessingResult, - ) - - HAS_ASYNC = True -except ImportError: - HAS_ASYNC = False - AsyncBatchProcessor = None - AsyncDocumentIndexer = None - AsyncDocumentProcessor = None - BatchProcessingConfig = None - BatchProcessingResult = None - -__all__ = [ - # Core - "UniversalRAGEngine", - "Document", - "DocumentType", - "IndustryNiche", - "RAGConfig", - "RAGQuery", - "RetrievalResult", - # Processors - "DocumentProcessor", - "RegulationProcessor", - "HistoricalDataProcessor", - "KnowledgeBaseProcessor", - "ManualProcessor", - "get_processor", - # Retrievers - "MultiModalRetriever", - "HybridRetriever", - "ContextualRetriever", - "get_retriever", -] - -# Conditionally add advanced features -if HAS_ADVANCED_RETRIEVAL: - __all__.extend( - [ - "AdvancedRetrievalPipeline", - "QueryExpander", - "SynonymQueryExpander", - "RephraseQueryExpander", - "MultiQueryRetriever", - "QueryCache", - ] - ) - -if HAS_CACHING: - __all__.extend( - [ - "AdvancedCacheManager", - "EmbeddingCache", - "QueryResultCache", - ] - ) - -if HAS_MONITORING: - __all__.extend( - [ - "RAGMonitor", - "RetrievalMetrics", - "IndexingMetrics", - "SystemMetrics", - "PerformanceTimer", - ] - ) - -if HAS_ASYNC: - __all__.extend( - [ - "AsyncBatchProcessor", - "AsyncDocumentIndexer", - "AsyncDocumentProcessor", - "BatchProcessingConfig", - "BatchProcessingResult", - ] - ) +""" +Universal RAG Module for Deepiri Platform +Reusable across all industry niches: Insurance, Manufacturing, Property Management, Healthcare, etc. +""" + +from .advanced_retrieval import ( + AdvancedRetrievalPipeline, + MultiQueryRetriever, + QueryCache, + QueryExpander, + RephraseQueryExpander, + SynonymQueryExpander, +) +from .async_processing import ( + AsyncBatchProcessor, + AsyncDocumentIndexer, + AsyncDocumentProcessor, + BatchProcessingConfig, + BatchProcessingResult, +) +from .base import ( + Document, + DocumentType, + IndustryNiche, + RAGConfig, + RAGQuery, + RetrievalResult, + UniversalRAGEngine, +) +from .caching import AdvancedCacheManager, EmbeddingCache, QueryResultCache +from .monitoring import ( + IndexingMetrics, + PerformanceTimer, + RAGMonitor, + RetrievalMetrics, + SystemMetrics, +) +from .processors import ( + DocumentProcessor, + HistoricalDataProcessor, + KnowledgeBaseProcessor, + ManualProcessor, + RegulationProcessor, + get_processor, +) +from .retrievers import ( + ContextualRetriever, + HybridRetriever, + MultiModalRetriever, + get_retriever, +) + +__all__ = [ + # Core + "UniversalRAGEngine", + "Document", + "DocumentType", + "IndustryNiche", + "RAGConfig", + "RAGQuery", + "RetrievalResult", + # Processors + "DocumentProcessor", + "RegulationProcessor", + "HistoricalDataProcessor", + "KnowledgeBaseProcessor", + "ManualProcessor", + "get_processor", + # Retrievers + "MultiModalRetriever", + "HybridRetriever", + "ContextualRetriever", + "get_retriever", + # Advanced retrieval + "AdvancedRetrievalPipeline", + "QueryExpander", + "SynonymQueryExpander", + "RephraseQueryExpander", + "MultiQueryRetriever", + "QueryCache", + # Caching + "AdvancedCacheManager", + "EmbeddingCache", + "QueryResultCache", + # Monitoring + "RAGMonitor", + "RetrievalMetrics", + "IndexingMetrics", + "SystemMetrics", + "PerformanceTimer", + # Async processing + "AsyncBatchProcessor", + "AsyncDocumentIndexer", + "AsyncDocumentProcessor", + "BatchProcessingConfig", + "BatchProcessingResult", +] diff --git a/src/deepiri_modelkit/rag/advanced_retrieval.py b/src/deepiri_modelkit/rag/advanced_retrieval.py index 4de0fbc..db74e32 100644 --- a/src/deepiri_modelkit/rag/advanced_retrieval.py +++ b/src/deepiri_modelkit/rag/advanced_retrieval.py @@ -394,12 +394,11 @@ def __init__( self.use_cache = use_cache self.query_cache = QueryCache(cache_manager) if use_cache else None + self.multi_query_retriever: Optional[MultiQueryRetriever] = None if use_multi_query: self.multi_query_retriever = MultiQueryRetriever( base_retriever=base_retriever, query_expander=query_expander ) - else: - self.multi_query_retriever = None def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve with advanced strategies""" diff --git a/src/deepiri_modelkit/rag/async_processing.py b/src/deepiri_modelkit/rag/async_processing.py index 87a0b0f..c564c98 100644 --- a/src/deepiri_modelkit/rag/async_processing.py +++ b/src/deepiri_modelkit/rag/async_processing.py @@ -41,7 +41,7 @@ class BatchProcessingResult: successful_items: int failed_items: int processing_time_seconds: float - errors: List[Dict[str, Any]] = None + errors: Optional[List[Dict[str, Any]]] = None def __post_init__(self): if self.errors is None: @@ -118,7 +118,7 @@ async def process_batch( # Aggregate results for result in batch_results: - if isinstance(result, Exception): + if isinstance(result, BaseException): failed_items += self.config.batch_size errors.append({"error": str(result), "type": type(result).__name__}) else: @@ -253,7 +253,7 @@ async def index_documents_streaming( Returns: BatchProcessingResult with statistics """ - batch = [] + batch: list[Document] = [] total_processed = 0 successful = 0 failed = 0 diff --git a/src/deepiri_modelkit/rag/caching.py b/src/deepiri_modelkit/rag/caching.py index 6d25942..0b6985b 100644 --- a/src/deepiri_modelkit/rag/caching.py +++ b/src/deepiri_modelkit/rag/caching.py @@ -23,7 +23,7 @@ class CacheEntry: expires_at: Optional[datetime] access_count: int = 0 last_accessed: Optional[datetime] = None - tags: List[str] = None + tags: Optional[List[str]] = None def __post_init__(self): if self.tags is None: @@ -109,7 +109,7 @@ def _serialize_value(self, value: Any) -> str: return json.dumps(value) return str(value) - def _deserialize_value(self, value: str, value_type: type = None) -> Any: + def _deserialize_value(self, value: str, value_type: Optional[type] = None) -> Any: """Deserialize value from storage""" try: if value_type == list or (isinstance(value, str) and value.startswith("[")): @@ -248,7 +248,7 @@ def delete(self, key: str, namespace: str = "rag") -> bool: if key in self.memory_cache: entry = self.memory_cache[key] # Remove from tag indexes - for tag in entry.tags: + for tag in entry.tags or []: if tag in self.tag_index and key in self.tag_index[tag]: self.tag_index[tag].remove(key) del self.memory_cache[key] @@ -340,7 +340,7 @@ def _update_redis_entry(self, key: str, entry: CacheEntry): def get_stats(self) -> Dict[str, Any]: """Get cache statistics""" - stats = { + stats: Dict[str, Any] = { "memory_cache_size": len(self.memory_cache), "max_size": self.max_size, "tag_index_size": len(self.tag_index), diff --git a/src/deepiri_modelkit/rag/processors.py b/src/deepiri_modelkit/rag/processors.py index 5585cde..78970ed 100644 --- a/src/deepiri_modelkit/rag/processors.py +++ b/src/deepiri_modelkit/rag/processors.py @@ -4,7 +4,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional +from typing import Any, Callable, Dict, List, Optional import re from datetime import datetime @@ -116,7 +116,7 @@ def _extract_sections(self, content: str) -> List[Dict[str, Any]]: section_pattern = r"(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)" sections = [] - current_section = {"section": None, "content": ""} + current_section: Dict[str, str] = {"section": "", "content": ""} lines = content.split("\n") for line in lines: @@ -379,7 +379,7 @@ def _extract_sections(self, content: str) -> List[Dict[str, Any]]: section_pattern = r"(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)" sections = [] - current_section = {"chapter": None, "section": None, "content": ""} + current_section: Dict[str, str] = {"chapter": "", "section": "", "content": ""} lines = content.split("\n") for line in lines: @@ -426,7 +426,7 @@ def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: Returns: Configured document processor """ - processor_map = { + processor_map: Dict[DocumentType, Callable[..., DocumentProcessor]] = { DocumentType.REGULATION: RegulationProcessor, DocumentType.POLICY: RegulationProcessor, # Similar processing DocumentType.WORK_ORDER: HistoricalDataProcessor, @@ -439,5 +439,7 @@ def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: DocumentType.PROCEDURE: ManualProcessor, # Similar processing } - processor_class = processor_map.get(doc_type, DocumentProcessor) - return processor_class(**kwargs) + factory = processor_map.get(doc_type) + if factory is None: + raise ValueError(f"No processor registered for document type: {doc_type}") + return factory(**kwargs) diff --git a/src/deepiri_modelkit/rag/retrievers.py b/src/deepiri_modelkit/rag/retrievers.py index bf40f86..08c7299 100644 --- a/src/deepiri_modelkit/rag/retrievers.py +++ b/src/deepiri_modelkit/rag/retrievers.py @@ -4,7 +4,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional +from typing import Any, Callable, Dict, List, Optional from dataclasses import dataclass from .base import Document, RetrievalResult, RAGQuery @@ -223,10 +223,9 @@ def _apply_temporal_boost( current_time = datetime.now().timestamp() for result in results: - if result.document.updated_at or result.document.created_at: - doc_time = ( - result.document.updated_at or result.document.created_at - ).timestamp() + ts = result.document.updated_at or result.document.created_at + if ts is not None: + doc_time = ts.timestamp() # Calculate age in days age_days = (current_time - doc_time) / 86400 @@ -277,11 +276,9 @@ def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever: Returns: Configured retriever """ - retriever_map = { + retriever_map: Dict[str, Callable[..., BaseRetriever]] = { "hybrid": HybridRetriever, "multimodal": MultiModalRetriever, "contextual": ContextualRetriever, } - - retriever_class = retriever_map.get(retriever_type, HybridRetriever) - return retriever_class(**kwargs) + return retriever_map.get(retriever_type, HybridRetriever)(**kwargs) diff --git a/src/deepiri_modelkit/rag/testing.py b/src/deepiri_modelkit/rag/testing.py index f118dc2..a42ffd5 100644 --- a/src/deepiri_modelkit/rag/testing.py +++ b/src/deepiri_modelkit/rag/testing.py @@ -20,7 +20,7 @@ class TestCase: expected_doc_types: Optional[List[DocumentType]] = None min_score: float = 0.7 # Minimum similarity score top_k: int = 5 - metadata: Dict[str, Any] = None + metadata: Optional[Dict[str, Any]] = None def __post_init__(self): if self.metadata is None: diff --git a/src/deepiri_modelkit/registry/model_registry.py b/src/deepiri_modelkit/registry/model_registry.py index 1d2fe3c..4df3b63 100644 --- a/src/deepiri_modelkit/registry/model_registry.py +++ b/src/deepiri_modelkit/registry/model_registry.py @@ -123,6 +123,8 @@ def register_model( return True + raise ValueError(f"Unknown registry_type: {self.registry_type}") + except Exception as e: print(f"Error registering model: {e}") return False @@ -208,6 +210,8 @@ def get_model( "type": "local", } + raise ValueError(f"Unknown registry_type: {self.registry_type}") + except Exception as e: print(f"Error getting model: {e}") raise @@ -264,9 +268,9 @@ def download_model(self, model_name: str, version: str, destination: str) -> str else: # MLflow handles loading directly - return model_info["uri"] + return str(model_info["uri"]) - def list_models(self, model_name: Optional[str] = None) -> list: + def list_models(self, model_name: Optional[str] = None) -> list[Dict[str, Any]]: """ List available models @@ -305,28 +309,30 @@ def list_models(self, model_name: Optional[str] = None) -> list: Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" ) - models = [] + s3_models: list[Dict[str, Any]] = [] for prefix_obj in response.get("CommonPrefixes", []): model_path = prefix_obj["Prefix"] parts = model_path.strip("/").split("/") if len(parts) >= 2: - models.append({"name": parts[1], "path": model_path}) + s3_models.append({"name": parts[1], "path": model_path}) - return models + return s3_models elif self.registry_type == "local": - models = [] + local_models: list[Dict[str, Any]] = [] for model_dir in self.local_path.iterdir(): if model_dir.is_dir(): versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - models.append( + local_models.append( { "name": model_dir.name, "versions": versions, "latest_version": max(versions) if versions else None, } ) - return models + return local_models + + raise ValueError(f"Unknown registry_type: {self.registry_type}") except Exception as e: print(f"Error listing models: {e}") diff --git a/src/deepiri_modelkit/streaming/event_stream.py b/src/deepiri_modelkit/streaming/event_stream.py index 904fe0b..e3e655e 100644 --- a/src/deepiri_modelkit/streaming/event_stream.py +++ b/src/deepiri_modelkit/streaming/event_stream.py @@ -43,7 +43,6 @@ def __init__( decode_responses=True, ) self._running = False - self._subscriptions = {} async def connect(self): """Connect to Redis""" @@ -72,16 +71,17 @@ async def publish( event["timestamp"] = datetime.utcnow().isoformat() # Publish to stream + # redis-py's xadd stub narrows the fields type more than the runtime accepts. message_id = await self.redis.xadd( - topic, event, maxlen=max_length, approximate=True + topic, event, maxlen=max_length, approximate=True # type: ignore[arg-type] ) - return message_id + return str(message_id) async def subscribe( self, topic: str, - callback: Callable[[Dict[str, Any]], None], + callback: Optional[Callable[[Dict[str, Any]], None]] = None, consumer_group: Optional[str] = None, consumer_name: Optional[str] = None, last_id: str = "0", @@ -191,4 +191,4 @@ async def get_stream_info(self, topic: str) -> Dict[str, Any]: async def get_stream_length(self, topic: str) -> int: """Get number of messages in stream""" - return await self.redis.xlen(topic) + return int(await self.redis.xlen(topic)) diff --git a/src/deepiri_modelkit/streaming/schemas.py b/src/deepiri_modelkit/streaming/schemas.py index 830cbae..d8a9f1f 100644 --- a/src/deepiri_modelkit/streaming/schemas.py +++ b/src/deepiri_modelkit/streaming/schemas.py @@ -2,6 +2,8 @@ Streaming event schemas and validation """ +from typing import Dict, List, Type + from .topics import StreamTopics from ..contracts.events import ( BaseEvent, @@ -14,7 +16,7 @@ ) # Map topics to event schemas -TOPIC_EVENT_SCHEMAS = { +TOPIC_EVENT_SCHEMAS: Dict[StreamTopics, List[Type[BaseEvent]]] = { StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], StreamTopics.INFERENCE_EVENTS: [InferenceEvent], StreamTopics.PLATFORM_EVENTS: [PlatformEvent], @@ -37,12 +39,13 @@ def validate_event(topic: str, event_data: dict) -> BaseEvent: Raises: ValueError: If event doesn't match schema """ - if topic not in TOPIC_EVENT_SCHEMAS: + try: + topic_enum = StreamTopics(topic) + except ValueError: # Unknown topic, return base event return BaseEvent(**event_data) - schemas = TOPIC_EVENT_SCHEMAS[topic] - event_type = event_data.get("event") + schemas = TOPIC_EVENT_SCHEMAS[topic_enum] # Try to match event type to schema for schema in schemas: diff --git a/src/deepiri_modelkit/streaming/sidecar_utils.py b/src/deepiri_modelkit/streaming/sidecar_utils.py index 1ec9faf..4e41636 100644 --- a/src/deepiri_modelkit/streaming/sidecar_utils.py +++ b/src/deepiri_modelkit/streaming/sidecar_utils.py @@ -66,7 +66,7 @@ def sidecar_payload_from_fields(fields: Dict[str, Any]) -> Dict[str, Any]: payload = json.loads(payload) except ValueError: payload = {} - elif not isinstance(payload, dict): + if not isinstance(payload, dict): payload = {} if "event" not in payload and fields.get("event_type"):