Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions example.env
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ BACKUP_ENABLED=true
# Cron schedule for the shared backup_db periodic task.
BACKUP_CRON="0 3 * * *"

# Enable the daily database backup periodic task.
# Set to "false" to no-op the shared backup_db task (e.g. in test environments).
BACKUP_ENABLED=true

# Cron schedule for the shared backup_db periodic task.
BACKUP_CRON="0 3 * * *"

# Site information that is synced to the database and used by the sites framework.
SITE_NAME="RADIS"
SITE_DOMAIN=localhost
Expand Down
7 changes: 6 additions & 1 deletion radis-client/radis_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,17 @@ def update_report(
return response.json()

def update_reports_bulk(
self, reports: list[ReportData], upsert: bool = True
self,
reports: list[ReportData],
upsert: bool = True,
timeout: float | tuple[float, float] | None = None,
) -> dict[str, Any]:
"""Bulk upsert reports using a single request.

Args:
reports: The report payloads to upsert.
upsert: Whether to perform upsert behavior when a report is missing.
timeout: Optional requests timeout (seconds).

Returns:
The response as JSON.
Expand All @@ -119,6 +123,7 @@ def update_reports_bulk(
json=payload,
headers=self._headers,
params={"upsert": upsert},
timeout=timeout,
)
response.raise_for_status()
return response.json()
Expand Down
30 changes: 30 additions & 0 deletions radis/pgsearch/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import logging

from procrastinate.contrib.django import app
from procrastinate.types import JSONValue

from .utils.indexing import bulk_upsert_report_search_vectors

logger = logging.getLogger(__name__)


@app.task
def bulk_index_reports(report_ids: list[int]) -> None:
if not report_ids:
return
logger.info("Indexing %s reports in bulk.", len(report_ids))
bulk_upsert_report_search_vectors(report_ids)


def enqueue_bulk_index_reports(report_ids: list[int]) -> int | None:
if not report_ids:
return None
try:
payload: list[JSONValue] = [int(report_id) for report_id in report_ids]
except (TypeError, ValueError) as exc:
logger.error("Invalid report_id in bulk index request: %s", exc)
return None
Comment on lines +19 to +26

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t drop valid IDs when one entry is invalid.
Right now a single bad report_id prevents the entire batch from enqueuing, which can leave valid reports unindexed. Consider per-item validation (skip/log invalids) or explicitly raise so the caller can retry.

✅ Suggested fix (skip invalid IDs, keep valid ones)
 def enqueue_bulk_index_reports(report_ids: list[int]) -> int | None:
     if not report_ids:
         return None
-    try:
-        payload: list[int] = [int(report_id) for report_id in report_ids]
-    except (TypeError, ValueError) as exc:
-        logger.error("Invalid report_id in bulk index request: %s", exc)
-        return None
+    payload: list[int] = []
+    for report_id in report_ids:
+        try:
+            payload.append(int(report_id))
+        except (TypeError, ValueError):
+            logger.exception(
+                "Invalid report_id in bulk index request: %r",
+                report_id,
+            )
+    if not payload:
+        return None
     return app.configure_task(
         "radis.pgsearch.tasks.bulk_index_reports",
         allow_unknown=False,
     ).defer(report_ids=payload)
🧰 Tools
🪛 Ruff (0.14.14)

24-24: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

🤖 Prompt for AI Agents
In `@radis/pgsearch/tasks.py` around lines 18 - 25, The current
enqueue_bulk_index_reports function aborts the whole batch if any report_id
fails conversion; change it to validate per-item instead: iterate over
report_ids, attempt int(report_id) for each inside a try/except, append
successfully converted ids to payload (or valid_ids) and logger.warning/error
the specific invalid value on exception, and after the loop return None if
payload is empty else continue with the original enqueue logic; update
references to payload (or rename to valid_ids) and remove the single try/except
around the whole list so valid IDs are not dropped.

return app.configure_task(
"radis.pgsearch.tasks.bulk_index_reports",
allow_unknown=False,
).defer(report_ids=payload)
33 changes: 33 additions & 0 deletions radis/pgsearch/tests/test_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from radis.pgsearch.models import ReportSearchVector
from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors
from radis.reports.models import Language, Report


@pytest.mark.django_db
def test_bulk_index_matches_signal_vector() -> None:
language = Language.objects.create(code="en")
report = Report.objects.create(
document_id="DOC-INDEX",
pacs_aet="PACS",
pacs_name="PACS",
pacs_link="",
patient_id="P1",
patient_birth_date="1980-01-01",
patient_sex="M",
study_description="Study",
study_datetime="2024-01-01T00:00:00Z",
study_instance_uid="1.2.3.4",
accession_number="ACC1",
body="Findings: No acute abnormality.",
language=language,
)

signal_vector = ReportSearchVector.objects.get(report=report).search_vector
ReportSearchVector.objects.filter(report=report).delete()

bulk_upsert_report_search_vectors([report.pk])
bulk_vector = ReportSearchVector.objects.get(report=report).search_vector

assert signal_vector == bulk_vector
74 changes: 74 additions & 0 deletions radis/pgsearch/utils/indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

import logging
from collections.abc import Iterable

from django.conf import settings
from django.db import connection

from radis.reports.models import Report

from ..models import ReportSearchVector
from .language_utils import code_to_language

logger = logging.getLogger(__name__)


def _chunked(items: list[int], size: int) -> Iterable[list[int]]:
for index in range(0, len(items), size):
yield items[index : index + size]


def bulk_upsert_report_search_vectors(
report_ids: Iterable[int],
chunk_size: int | None = None,
) -> None:
ids = sorted({int(report_id) for report_id in report_ids if report_id is not None})
if not ids:
return
resolved_chunk_size = (
settings.PGSEARCH_BULK_INDEX_CHUNK_SIZE if chunk_size is None else chunk_size
)

for chunk in _chunked(ids, resolved_chunk_size):
reports = (
Report.objects.filter(id__in=chunk)
.select_related("language")
.only("id", "language__code")
)
report_ids_found: set[int] = set()
config_to_ids: dict[str, list[int]] = {}
config_cache: dict[str, str] = {}
for report in reports:
report_ids_found.add(report.pk)
language_code = report.language.code
config = config_cache.get(language_code)
if config is None:
config = code_to_language(language_code)
config_cache[language_code] = config
config_to_ids.setdefault(config, []).append(report.pk)
missing_ids = set(chunk) - report_ids_found
if missing_ids:
logger.warning(
"Skipping %s missing reports during bulk index (ids=%s).",
len(missing_ids),
sorted(missing_ids)[:10],
)

for config, config_ids in config_to_ids.items():
ReportSearchVector.objects.bulk_create(
[ReportSearchVector(report_id=report_id) for report_id in config_ids],
ignore_conflicts=True,
batch_size=settings.PGSEARCH_BULK_INSERT_BATCH_SIZE,
)

with connection.cursor() as cursor:
cursor.execute(
"""
UPDATE pgsearch_reportsearchvector v
SET search_vector = to_tsvector(%s::regconfig, r.body)
FROM reports_report r
WHERE v.report_id = r.id AND r.id = ANY(%s)
""",
[config, config_ids],
)
6 changes: 5 additions & 1 deletion radis/pgsearch/utils/language_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def get_available_search_configs() -> set[str]:
try:
return _get_available_search_configs_cached()
except DatabaseError as exc:
logger.warning("Failed to read pg_ts_config; falling back to simple. %s", exc)
logger.error(
"Failed to read pg_ts_config; falling back to simple. %s",
exc,
exc_info=True,
)
return set()


Expand Down
10 changes: 10 additions & 0 deletions radis/reports/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import transaction
from rest_framework import serializers, validators
from rest_framework.exceptions import ValidationError
from rest_framework.relations import PrimaryKeyRelatedField

from ..models import Language, Metadata, Modality, Report

Expand Down Expand Up @@ -50,6 +51,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if self.context.get("skip_document_id_unique"):
self._strip_unique_validator("document_id")
request = self.context.get("request")
if request is not None and "groups" in self.fields:
groups_field = self.fields["groups"]
if isinstance(groups_field, PrimaryKeyRelatedField):
if groups_field.queryset is not None:
if request.user.is_superuser:
groups_field.queryset = groups_field.queryset.all()
else:
groups_field.queryset = request.user.groups.all()

class Meta:
model = Report
Expand Down
100 changes: 95 additions & 5 deletions radis/reports/api/viewsets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Any

from django.conf import settings
from django.db import transaction
from django.http import Http404
from django.utils import timezone
Expand All @@ -12,6 +13,9 @@
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer

from radis.pgsearch.tasks import enqueue_bulk_index_reports
from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors

from ..models import Language, Metadata, Modality, Report
from ..site import (
document_fetchers,
Expand All @@ -26,10 +30,58 @@
BULK_DB_BATCH_SIZE = 1000


def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[str], list[str]]:
def _bulk_upsert_reports(
validated_reports: list[dict[str, Any]],
) -> tuple[list[str], list[str]]:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if not validated_reports:
return [], []

deduped_reports: dict[str, dict[str, Any]] = {}
duplicate_count = 0
for report in validated_reports:
document_id = report["document_id"]
if document_id in deduped_reports:
duplicate_count += 1
deduped_reports[document_id] = report
if duplicate_count:
logger.warning(
"Bulk upsert payload contained %s duplicate document_ids; keeping last occurrence.",
duplicate_count,
)
validated_reports = list(deduped_reports.values())

def _dedupe_by_key(
items: list[dict[str, Any]], key_name: str
) -> tuple[list[dict[str, Any]], int]:
if not items:
return [], 0
by_key: dict[str, dict[str, Any]] = {}
for item in items:
key = item[key_name]
by_key[key] = item
return list(by_key.values()), len(items) - len(by_key)

def _dedupe_metadata(items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], int]:
if not items:
return [], 0
by_key: dict[str, dict[str, Any]] = {}
duplicates = 0
for item in items:
key = item["key"]
if key in by_key:
duplicates += 1
by_key[key] = item
return list(by_key.values()), duplicates

def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]:
if not items:
return [], 0
by_id: dict[int, int] = {}
for group in items:
group_id = int(getattr(group, "pk", group))
by_id[group_id] = group_id
return list(by_id.values()), len(items) - len(by_id)

document_ids = [report["document_id"] for report in validated_reports]

language_codes = {report["language"]["code"] for report in validated_reports}
Expand Down Expand Up @@ -135,9 +187,12 @@ def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[
Metadata.objects.filter(report_id__in=report_ids).delete()

metadata_rows: list[Metadata] = []
metadata_duplicate_count = 0
for report_data in validated_reports:
report_id = report_id_by_document_id[report_data["document_id"]]
for item in report_data.get("metadata", []):
metadata_items, duplicates = _dedupe_metadata(report_data.get("metadata", []))
metadata_duplicate_count += duplicates
for item in metadata_items:
metadata_rows.append(
Metadata(report_id=report_id, key=item["key"], value=item["value"])
)
Expand All @@ -148,9 +203,14 @@ def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[
modality_through.objects.filter(report_id__in=report_ids).delete()

modality_rows = []
modality_duplicate_count = 0
for report_data in validated_reports:
report_id = report_id_by_document_id[report_data["document_id"]]
for modality in report_data.get("modalities", []):
modality_items, duplicates = _dedupe_by_key(
report_data.get("modalities", []), "code"
)
modality_duplicate_count += duplicates
for modality in modality_items:
modality_id = modality_by_code[modality["code"]].pk
modality_rows.append(
modality_through(report_id=report_id, modality_id=modality_id)
Expand All @@ -162,13 +222,31 @@ def _bulk_upsert_reports(validated_reports: list[dict[str, Any]]) -> tuple[list[
group_through.objects.filter(report_id__in=report_ids).delete()

group_rows = []
group_duplicate_count = 0
for report_data in validated_reports:
report_id = report_id_by_document_id[report_data["document_id"]]
for group in report_data.get("groups", []):
group_rows.append(group_through(report_id=report_id, group_id=group.pk))
group_items, duplicates = _dedupe_groups(report_data.get("groups", []))
group_duplicate_count += duplicates
for group_id in group_items:
group_rows.append(group_through(report_id=report_id, group_id=group_id))
if group_rows:
group_through.objects.bulk_create(group_rows, batch_size=BULK_DB_BATCH_SIZE)

if metadata_duplicate_count or modality_duplicate_count or group_duplicate_count:
logger.warning(
"Bulk upsert payload contained duplicate metadata/modality/group entries "
"(metadata=%s modalities=%s groups=%s); duplicates were dropped.",
metadata_duplicate_count,
modality_duplicate_count,
group_duplicate_count,
)

touched_report_ids = [
report_id_by_document_id[document_id]
for document_id in [*created_ids, *updated_ids]
if document_id in report_id_by_document_id
]

def on_commit():
if created_ids:
created_reports = list(Report.objects.filter(document_id__in=created_ids))
Expand All @@ -178,6 +256,11 @@ def on_commit():
updated_reports = list(Report.objects.filter(document_id__in=updated_ids))
for handler in reports_updated_handlers:
handler.handle(updated_reports)
if touched_report_ids:
if settings.PGSEARCH_SYNC_INDEXING:
bulk_upsert_report_search_vectors(touched_report_ids)
else:
enqueue_bulk_index_reports(touched_report_ids)

transaction.on_commit(on_commit)

Expand Down Expand Up @@ -268,6 +351,13 @@ def bulk_upsert(self, request: Request) -> Response:
status=status.HTTP_400_BAD_REQUEST,
)

replace = request.GET.get("replace", "true").lower() in ["true", "1", "yes"]
if not replace:
return Response(
{"detail": "replace=false is not supported for bulk upsert. Use replace=true."},
status=status.HTTP_400_BAD_REQUEST,
)

valid_payloads: list[dict[str, Any]] = []
errors: list[dict[str, Any]] = []
for index, payload in enumerate(request.data):
Expand Down
1 change: 1 addition & 0 deletions radis/reports/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading