Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: black --check src/

- name: Type check
run: mypy src/ || true
run: mypy src/

- name: Run tests
run: pytest
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,13 @@ python_version = "3.11"
warn_return_any = true
warn_unused_configs = true

[[tool.mypy.overrides]]
module = [
"boto3",
"botocore.exceptions",
"torch",
"httpx",
"ollama",
]
ignore_missing_imports = true

2 changes: 1 addition & 1 deletion src/deepiri_modelkit/contracts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def validate_aimodel(value: Any) -> Any:

return value

def serialize_aimodel(value: Any) -> Dict[str, Any]:
def serialize_aimodel(value: Any) -> Optional[Dict[str, Any]]:
"""Serialize AIModel instance to dict"""
if value is None:
return None
Expand Down
7 changes: 5 additions & 2 deletions src/deepiri_modelkit/contracts/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Service contracts and interfaces
"""

from typing import Protocol, Dict, Any, Optional
from typing import Any, Callable, Dict, Optional, Protocol
from pydantic import BaseModel


Expand Down Expand Up @@ -38,7 +38,10 @@ def publish(self, topic: str, event: Dict[str, Any]) -> bool:
...

def subscribe(
self, topic: str, callback: callable, consumer_group: Optional[str] = None
self,
topic: str,
callback: Callable[[Dict[str, Any]], Any],
consumer_group: Optional[str] = None,
) -> None:
"""Subscribe to topic with callback"""
...
4 changes: 2 additions & 2 deletions src/deepiri_modelkit/data/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, log_dir: str = "./logs/dataset_monitoring"):
self.alerts_file = self.log_dir / "alerts.jsonl"

# In-memory metrics for quick access
self.current_metrics = {
self.current_metrics: Dict[str, Any] = {
"total_versions_created": 0,
"total_datasets_tracked": 0,
"average_version_creation_time": 0,
Expand Down Expand Up @@ -159,7 +159,7 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]:
"""Get usage analytics for the specified period."""
cutoff_date = datetime.utcnow() - timedelta(days=days)

analytics = {
analytics: Dict[str, Any] = {
"period_days": days,
"version_creations": [],
"training_runs": [],
Expand Down
2 changes: 1 addition & 1 deletion src/deepiri_modelkit/data/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def validate_dataset(self, data_path: Path) -> Dict[str, Any]:
"Starting dataset validation", path=str(data_path), type=self.dataset_type
)

results = {
results: Dict[str, Any] = {
"is_valid": True,
"errors": [],
"warnings": [],
Expand Down
22 changes: 13 additions & 9 deletions src/deepiri_modelkit/ml/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
):
self.ollama_base_url = ollama_base_url
self.model = model
self._cache = {} # Cache for semantic analysis results
self._cache: dict[str, list[str]] = {} # Cache for semantic analysis results

def _call_ollama(
self, prompt: str, timeout: float = 15.0
Expand All @@ -66,7 +66,7 @@ def _call_ollama(
"num_predict": 100, # Reduced from 200 for faster responses
},
)
return response.get("response", "").strip()
return str(response.get("response", "")).strip()
except Exception:
# Fall back to HTTP
pass
Expand All @@ -93,7 +93,7 @@ def _call_ollama(
if response.status_code == 200:
result = response.json()
logger.debug("Ollama HTTP call succeeded")
return result.get("response", "").strip()
return str(result.get("response", "")).strip()
else:
logger.debug(
f"Ollama HTTP call failed: HTTP {response.status_code}"
Expand All @@ -118,7 +118,7 @@ def _call_ollama(
if response.status_code == 200:
result = response.json()
logger.debug("Ollama HTTP call succeeded")
return result.get("response", "").strip()
return str(result.get("response", "")).strip()
else:
logger.debug(
f"Ollama HTTP call failed: HTTP {response.status_code}"
Expand Down Expand Up @@ -305,7 +305,9 @@ def analyze_semantic_structure(self, text: str) -> Dict:
try:
json_match = re.search(r"\{.*?\}", response, re.DOTALL)
if json_match:
return json.loads(json_match.group())
parsed = json.loads(json_match.group())
if isinstance(parsed, dict):
return parsed
except Exception:
pass

Expand Down Expand Up @@ -333,10 +335,10 @@ def check_ollama_available(self) -> bool:
try:
if HAS_HTTPX:
response = httpx.get(f"{self.ollama_base_url}/api/tags", timeout=5.0)
return response.status_code == 200
return bool(response.status_code == 200)
elif HAS_REQUESTS:
response = requests.get(f"{self.ollama_base_url}/api/tags", timeout=5.0)
return response.status_code == 200
return bool(response.status_code == 200)
except Exception:
pass

Expand All @@ -349,8 +351,10 @@ def get_semantic_analyzer(
"""
Factory function to get semantic analyzer
"""
base_url = ollama_base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
model_name = model or os.getenv("OLLAMA_MODEL", "llama3:8b")
base_url = (
ollama_base_url or os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434"
)
model_name = model or os.getenv("OLLAMA_MODEL") or "llama3:8b"

analyzer = SemanticAnalyzer(ollama_base_url=base_url, model=model_name)

Expand Down
Loading
Loading