diff --git a/docs/superpowers/plans/2026-06-08-adrf-report-views.md b/docs/superpowers/plans/2026-06-08-adrf-report-views.md new file mode 100644 index 000000000..227ba1883 --- /dev/null +++ b/docs/superpowers/plans/2026-06-08-adrf-report-views.md @@ -0,0 +1,939 @@ +# ADRF Report Views Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace the sync DRF `ReportViewSet` with one `adrf.viewsets.GenericViewSet` subclass (plus the create / retrieve / update / destroy async mixins from `adrf.mixins` and a `@action` for `bulk_upsert`) so the report-upload endpoints can `await` the async embedding client from inside the view in a follow-up PR. No client-visible API change in this PR. + +**Architecture:** Minimum-diff conversion of the legacy class: same mixin lineup, same `GenericViewSet` base, routing through `adrf.routers.DefaultRouter` (not DRF's — see Task 4 for why). The only structural change is `mixins.* → adrf.mixins.*` and the async-method overrides (`acreate`, `aretrieve`, `aupdate`, `adestroy`, `bulk_upsert`). Use native async ORM (`.aget`) for simple lookups and `channels.db.database_sync_to_async` to wrap DRF serializer + `transaction.atomic()` blocks. The `_bulk_upsert_reports` helper stays in `viewsets.py` (renamed `bulk_upsert_reports` — no separate `bulk.py` file). + +**Tech Stack:** Django 5.1+ (CI runs 6.0.1), DRF, ADRF (`adrf.viewsets.GenericViewSet` + `adrf.mixins`), Channels (`database_sync_to_async`), PostgreSQL, Procrastinate, pytest-django. 
+ +**Spec:** `docs/superpowers/specs/2026-06-08-adrf-report-views-design.md` + +--- + +## File Structure + +| Action | Path | Responsibility | +| --- | --- | --- | +| Modify | `radis/reports/api/viewsets.py` | Single `ReportViewSet` rewritten on top of `adrf.viewsets.GenericViewSet` + the four async mixins from `adrf.mixins` + an `@action` for `bulk_upsert`. The bulk-upsert helper (`bulk_upsert_reports`, renamed from `_bulk_upsert_reports`) stays in the same module — this is a DRF-viewset → ADRF-viewset conversion, not a file restructure. | +| Modify | `radis/reports/api/urls.py` | Keep `DefaultRouter`; register `ReportViewSet` with `basename="report"`. (No real diff vs. legacy.) | +| Modify | `radis/reports/tests/test_bulk_upsert.py` | Update import (`from radis.reports.api.viewsets import _bulk_upsert_reports` → `from radis.reports.api.viewsets import bulk_upsert_reports`). | +| Create | `radis/reports/tests/test_report_api.py` | End-to-end coverage for all five endpoints via Django's test `Client` (so the same tests run against both the legacy and the ADRF viewset); plus `inspect.iscoroutinefunction` shape guards on the viewset's async method set. | + +Unchanged: `radis/reports/api/serializers.py`, `radis/reports/api/__init__.py`, `radis/urls.py` (mount stays `path("api/reports/", include("radis.reports.api.urls"))`). + +The legacy file `radis/reports/api/viewsets.py` is rewritten in place (not renamed to `views.py` or split into `bulk.py` + `views.py`). The file name matches the framework convention (`viewsets.py` for viewset classes) and the diff reads as a sync→async conversion of the same module. + +--- + +## Prerequisites (run once before Task 1) + +The test suite runs inside the `web` container via `uv run cli test`, which gates on `helper.check_compose_up()`. 
Bring the dev stack up first: + +```bash +cd /Users/samuelkwong/adit-radis-workspace/projects/radis/.claude/worktrees/feat+adrf-views +uv run cli compose-up -d +``` + +Confirm a green baseline: + +```bash +uv run cli test +``` + +If the baseline is not green, **stop and report** — do not proceed to Task 1. + +--- + +## Task 1: Rename `_bulk_upsert_reports` → `bulk_upsert_reports` inside `viewsets.py` + +A one-line touch-up before the async conversion. The helper currently lives at module scope in `radis/reports/api/viewsets.py` with a leading-underscore name. After the conversion it stays in the same module and is called from `ReportViewSet.bulk_upsert`, so the underscore is misleading — it's the module's de-facto public bulk-upsert entry point. + +**Files:** +- Modify: `radis/reports/api/viewsets.py` +- Modify: `radis/reports/tests/test_bulk_upsert.py` + +- [ ] **Step 1.1: Rename the function in `radis/reports/api/viewsets.py`** + +`def _bulk_upsert_reports(...)` → `def bulk_upsert_reports(...)`. Update the single internal call site (inside the legacy `bulk_upsert` action) the same way. + +- [ ] **Step 1.2: Update the test import** + +In `radis/reports/tests/test_bulk_upsert.py`, change + +```python +from radis.reports.api.viewsets import _bulk_upsert_reports +``` + +to + +```python +from radis.reports.api.viewsets import bulk_upsert_reports +``` + +and rename the one call site in the test body. + +- [ ] **Step 1.3: Lint and commit** + +```bash +uv run cli lint +git add radis/reports/api/viewsets.py radis/reports/tests/test_bulk_upsert.py +git commit -m "refactor(reports): drop leading underscore from bulk_upsert_reports" +``` + +--- + +## Task 2: Add new test file with regression + async-shape guards + +Write the end-to-end coverage that proves the new ADRF views preserve the API contract, plus shape guards that fail until the new view classes exist. 
The regression tests **already pass** against the current DRF viewset (since the contract is byte-for-byte preserved) — that is the entire point: they lock the contract before the rewrite, then prove it survived after. + +**Files:** +- Create: `radis/reports/tests/test_report_api.py` + +- [ ] **Step 2.1: Write the test file** + +Create `radis/reports/tests/test_report_api.py`: + +```python +"""End-to-end tests for the report HTTP API. + +These tests intentionally exercise behavior through Django's `Client`, +so they pass against both the legacy DRF viewset and the ADRF rewrite. +They lock the wire contract before the swap and prove it survives after. + +The `_is_async` shape guards at the bottom fail until +`radis.reports.api.views` exists with `async def` handlers — they drive +the rewrite TDD-style. +""" +import asyncio +import json +from datetime import date + +import pytest +from adit_radis_shared.accounts.factories import GroupFactory, UserFactory +from adit_radis_shared.token_authentication.models import Token +from django.test import Client +from django.urls import reverse + +from radis.reports.models import Report +from radis.reports.site import ( + DocumentFetcher, + ReportsCreatedHandler, + ReportsDeletedHandler, + document_fetchers, + reports_created_handlers, + reports_deleted_handlers, +) + + +def _make_payload(document_id: str = "DOC-1", body: str = "Report body") -> dict: + return { + "document_id": document_id, + "language": "en", + "groups": [], # populated by tests after group is known + "pacs_aet": "PACS", + "pacs_name": "Test PACS", + "pacs_link": "", + "patient_id": "P1", + "patient_birth_date": "1980-01-01", + "patient_sex": "M", + "study_description": "Study 1", + "study_datetime": "2024-01-01T00:00:00Z", + "study_instance_uid": "1.2.3.4", + "accession_number": "ACC1", + "modalities": ["CT"], + "metadata": {"ris_filename": "file1"}, + "body": body, + } + + +def _staff_user_and_token() -> tuple[object, object, str]: + user = 
UserFactory.create(is_active=True, is_staff=True) + group = GroupFactory.create() + user.groups.add(group) + _, token = Token.objects.create_token(user, "report api test", None) + return user, group, token + + +def _non_staff_user_and_token() -> tuple[object, str]: + user = UserFactory.create(is_active=True, is_staff=False) + _, token = Token.objects.create_token(user, "non staff report api test", None) + return user, token + + +# --------------------------------------------------------------------------- +# URL resolution +# --------------------------------------------------------------------------- + +def test_report_list_url_resolves(): + assert reverse("report-list") == "/api/reports/" + + +def test_report_bulk_upsert_url_resolves(): + assert reverse("report-bulk-upsert") == "/api/reports/bulk-upsert/" + + +def test_report_detail_url_resolves(): + assert reverse("report-detail", args=["DOC-1"]) == "/api/reports/DOC-1/" + + +# --------------------------------------------------------------------------- +# POST /api/reports/ (create) +# --------------------------------------------------------------------------- + +@pytest.mark.django_db +def test_post_creates_report_and_fires_created_handler(client: Client): + _, group, token = _staff_user_and_token() + captured: list[Report] = [] + handler = ReportsCreatedHandler( + name="test-created", handle=lambda reports: captured.extend(reports) + ) + reports_created_handlers.append(handler) + try: + payload = _make_payload(document_id="DOC-CREATE") + payload["groups"] = [group.pk] + + response = client.post( + "/api/reports/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + assert response.status_code == 201 + body = response.json() + assert body["document_id"] == "DOC-CREATE" + assert body["language"] == "en" + assert body["modalities"] == ["CT"] + assert body["metadata"] == {"ris_filename": "file1"} + assert 
Report.objects.filter(document_id="DOC-CREATE").exists() + assert [r.document_id for r in captured] == ["DOC-CREATE"] + finally: + reports_created_handlers.remove(handler) + + +# --------------------------------------------------------------------------- +# GET /api/reports/{document_id}/ +# --------------------------------------------------------------------------- + +@pytest.mark.django_db +def test_get_returns_existing_report(client: Client): + _, group, token = _staff_user_and_token() + payload = _make_payload(document_id="DOC-GET") + payload["groups"] = [group.pk] + client.post( + "/api/reports/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + response = client.get( + "/api/reports/DOC-GET/", + headers={"Authorization": f"Token {token}"}, + ) + + assert response.status_code == 200 + assert response.json()["document_id"] == "DOC-GET" + + +@pytest.mark.django_db +def test_get_missing_report_returns_404(client: Client): + _, _, token = _staff_user_and_token() + response = client.get( + "/api/reports/DOES-NOT-EXIST/", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 404 + + +@pytest.mark.django_db +def test_get_full_includes_documents_from_fetchers(client: Client): + _, group, token = _staff_user_and_token() + payload = _make_payload(document_id="DOC-FULL") + payload["groups"] = [group.pk] + client.post( + "/api/reports/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + fetcher = DocumentFetcher( + source="stub-fetcher", + fetch=lambda report: {"source_id": report.document_id, "extra": "ok"}, + ) + document_fetchers["stub-fetcher"] = fetcher + try: + response = client.get( + "/api/reports/DOC-FULL/?full=true", + headers={"Authorization": f"Token {token}"}, + ) + finally: + document_fetchers.pop("stub-fetcher", None) + + assert response.status_code == 200 + body = response.json() + assert 
body["documents"]["stub-fetcher"] == { + "source_id": "DOC-FULL", + "extra": "ok", + } + + +# --------------------------------------------------------------------------- +# PUT /api/reports/{document_id}/ +# --------------------------------------------------------------------------- + +@pytest.mark.django_db +def test_put_updates_existing_report(client: Client): + _, group, token = _staff_user_and_token() + payload = _make_payload(document_id="DOC-PUT") + payload["groups"] = [group.pk] + client.post( + "/api/reports/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + payload["body"] = "Updated body" + response = client.put( + "/api/reports/DOC-PUT/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + assert response.status_code == 200 + assert response.json()["body"] == "Updated body" + assert Report.objects.get(document_id="DOC-PUT").body == "Updated body" + + +@pytest.mark.django_db +def test_put_upsert_creates_when_missing(client: Client): + _, group, token = _staff_user_and_token() + payload = _make_payload(document_id="DOC-UPSERT-NEW") + payload["groups"] = [group.pk] + + response = client.put( + "/api/reports/DOC-UPSERT-NEW/?upsert=true", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + assert response.status_code == 201 + assert Report.objects.filter(document_id="DOC-UPSERT-NEW").exists() + + +@pytest.mark.django_db +def test_put_upsert_missing_as_non_staff_returns_403(client: Client): + """When a PUT?upsert=true hits an unknown id, DRF re-checks permissions + as if it were a POST. 
IsAdminUser must reject the non-staff caller.""" + _, token = _non_staff_user_and_token() + payload = _make_payload(document_id="DOC-FORBIDDEN") + + response = client.put( + "/api/reports/DOC-FORBIDDEN/?upsert=true", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + assert response.status_code == 403 + assert not Report.objects.filter(document_id="DOC-FORBIDDEN").exists() + + +@pytest.mark.django_db +def test_patch_returns_405(client: Client): + _, _, token = _staff_user_and_token() + response = client.patch( + "/api/reports/DOC-NA/", + data=json.dumps({"body": "irrelevant"}), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 405 + + +# --------------------------------------------------------------------------- +# DELETE /api/reports/{document_id}/ +# --------------------------------------------------------------------------- + +@pytest.mark.django_db +def test_delete_removes_report_and_fires_deleted_handler(client: Client): + _, group, token = _staff_user_and_token() + payload = _make_payload(document_id="DOC-DEL") + payload["groups"] = [group.pk] + client.post( + "/api/reports/", + data=json.dumps(payload), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + + captured: list[Report] = [] + handler = ReportsDeletedHandler( + name="test-deleted", handle=lambda reports: captured.extend(reports) + ) + reports_deleted_handlers.append(handler) + try: + response = client.delete( + "/api/reports/DOC-DEL/", + headers={"Authorization": f"Token {token}"}, + ) + finally: + reports_deleted_handlers.remove(handler) + + assert response.status_code == 204 + assert not Report.objects.filter(document_id="DOC-DEL").exists() + assert [r.document_id for r in captured] == ["DOC-DEL"] + + +# --------------------------------------------------------------------------- +# POST /api/reports/bulk-upsert/ +# 
--------------------------------------------------------------------------- + +@pytest.mark.django_db +def test_bulk_upsert_rejects_replace_false(client: Client): + _, _, token = _staff_user_and_token() + response = client.post( + "/api/reports/bulk-upsert/?replace=false", + data=json.dumps([]), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 400 + + +@pytest.mark.django_db +def test_bulk_upsert_rejects_non_list_payload(client: Client): + _, _, token = _staff_user_and_token() + response = client.post( + "/api/reports/bulk-upsert/", + data=json.dumps({"document_id": "DOC-NOT-A-LIST"}), + content_type="application/json", + headers={"Authorization": f"Token {token}"}, + ) + assert response.status_code == 400 + + +# --------------------------------------------------------------------------- +# Async-shape guards — fail until radis.reports.api.views exists with +# async handlers; prevent silent regressions to sync in the future. 
+# --------------------------------------------------------------------------- + +def test_report_list_post_is_coroutine(): + from radis.reports.api.views import ReportListAPIView + assert asyncio.iscoroutinefunction(ReportListAPIView.post) + + +def test_report_detail_methods_are_coroutines(): + from radis.reports.api.views import ReportDetailAPIView + assert asyncio.iscoroutinefunction(ReportDetailAPIView.get) + assert asyncio.iscoroutinefunction(ReportDetailAPIView.put) + assert asyncio.iscoroutinefunction(ReportDetailAPIView.delete) + + +def test_report_bulk_upsert_post_is_coroutine(): + from radis.reports.api.views import ReportBulkUpsertAPIView + assert asyncio.iscoroutinefunction(ReportBulkUpsertAPIView.post) +``` + +- [ ] **Step 2.2: Run the new file and confirm the expected mixed-result baseline** + +```bash +uv run cli test -- radis/reports/tests/test_report_api.py -v +``` + +Expected result: +- All endpoint tests (URL resolution, POST, GET, PUT, DELETE, bulk-upsert behavior) **PASS** — they run against the current DRF viewset which already implements this contract. +- The three async-shape guards (`test_report_list_post_is_coroutine`, `test_report_detail_methods_are_coroutines`, `test_report_bulk_upsert_post_is_coroutine`) **FAIL with `ModuleNotFoundError: No module named 'radis.reports.api.views'`**. + +If any endpoint test fails, **stop and report** — that means the test does not actually match the existing contract and needs fixing before the rewrite. + +- [ ] **Step 2.3: Commit** + +```bash +git add radis/reports/tests/test_report_api.py +git commit -m "$(cat <<'EOF' +test(reports): add end-to-end report API tests + async-shape guards + +Lock the wire-level contract for all five report endpoints before the +ADRF rewrite. The three iscoroutinefunction guards fail today and will +go green once the new ADRF view classes land. 
+ +Co-Authored-By: Claude Opus 4.7 (1M context) +EOF +)" +``` + +--- + +## Task 3: Convert `viewsets.py` from sync DRF to async ADRF + +Rewrite `radis/reports/api/viewsets.py` in place. `ReportViewSet` keeps its name and module location but now subclasses `adrf.viewsets.GenericViewSet` + the four create / retrieve / update / destroy async mixins from `adrf.mixins`, plus an `@action` for `bulk_upsert`. The `bulk_upsert_reports` helper renamed in Task 1 stays in the same module — there is no `bulk.py`, no `views.py`, and no rename. The legacy module-level `from rest_framework import mixins, status, viewsets` imports are swapped for `from adrf import mixins as amixins; from adrf.viewsets import GenericViewSet`, and every mixin method override becomes `async def acreate` / `aretrieve` / `aupdate` / `adestroy`. + +**Files:** +- Modify (rewrite): `radis/reports/api/viewsets.py` + +- [ ] **Step 3.1: Rewrite `radis/reports/api/viewsets.py`** + +```python +"""ADRF report viewset. + +Single async ViewSet that mirrors the shape of the legacy DRF ReportViewSet: +GenericViewSet + selected adrf mixins, dispatched via DefaultRouter. Custom +behaviour is added by overriding the async mixin methods (acreate / +aretrieve / aupdate / adestroy) and the @action for bulk-upsert. + +Strategy: + - Native async ORM (`.aget`) for single-call lookups. + - `channels.db.database_sync_to_async` for serializer + transaction blocks + (DRF serializers and `transaction.atomic()` are sync-only). + - Request body materialised on the async thread before entering any sync + wrapper, so the ASGI body stream is never touched from a worker thread. + - For mutating handlers, the ORM write and `transaction.on_commit` + registration share one atomic block on the same DB connection so the + callback is correctly bound to the write's transaction. + +See the design doc at +docs/superpowers/specs/2026-06-08-adrf-report-views-design.md. 
+""" +import asyncio +import logging +from typing import Any + +from adrf import mixins as amixins +from adrf.viewsets import GenericViewSet +from channels.db import database_sync_to_async +from django.db import transaction +from django.http import Http404 +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError +from rest_framework.permissions import IsAdminUser +from rest_framework.request import Request, clone_request +from rest_framework.response import Response + +from ..models import Report +from ..site import ( + document_fetchers, + reports_created_handlers, + reports_deleted_handlers, + reports_updated_handlers, +) +from .serializers import ReportSerializer + +logger = logging.getLogger(__name__) + +# NOTE: `bulk_upsert_reports` (renamed in Task 1) stays defined at module +# scope in this file, so there is no `.bulk` import. + + +class ReportViewSet( + amixins.CreateModelMixin, + amixins.RetrieveModelMixin, + amixins.UpdateModelMixin, + amixins.DestroyModelMixin, + GenericViewSet, +): + queryset = Report.objects.all() + serializer_class = ReportSerializer + lookup_field = "document_id" + permission_classes = [IsAdminUser] + # Block PATCH at the dispatcher level (returns 405). We never define + # `partial_update` / `apartial_update` for the same effect. 
+ http_method_names = ["get", "post", "put", "delete", "head", "options"] + + async def acreate(self, request: Request, *args: Any, **kwargs: Any) -> Response: + data = request.data + + @database_sync_to_async + def _create() -> dict[str, Any]: + serializer = self.get_serializer(data=data) + serializer.is_valid(raise_exception=True) + report = serializer.save() + + def on_commit(): + for handler in reports_created_handlers: + logger.debug( + f"{handler.name} - handle newly created reports: " + f"{[report.document_id]}" + ) + handler.handle([report]) + + transaction.on_commit(on_commit) + return serializer.data + + return Response(await _create(), status=status.HTTP_201_CREATED) + + async def aretrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response: + try: + report = await Report.objects.select_related("language").aget( + document_id=kwargs[self.lookup_field] + ) + except Report.DoesNotExist: + raise Http404 + + data = await database_sync_to_async( + lambda: self.get_serializer(report).data + )() + + full = request.GET.get("full", "").lower() in ("true", "1", "yes") + if full: + async def _fetch(fetcher): + return fetcher.source, await database_sync_to_async(fetcher.fetch)(report) + + results = await asyncio.gather( + *(_fetch(f) for f in document_fetchers.values()) + ) + data["documents"] = { + source: doc for source, doc in results if doc is not None + } + + return Response(data) + + async def aupdate(self, request: Request, *args: Any, **kwargs: Any) -> Response: + document_id = kwargs[self.lookup_field] + upsert = request.GET.get("upsert", "").lower() in ("true", "1", "yes") + data = request.data + + try: + report = await Report.objects.aget(document_id=document_id) + except Report.DoesNotExist: + report = None + + if report is None and not upsert: + raise Http404 + if report is None and upsert: + # Replicates DRF's `get_object_or_none` + `clone_request("POST")` + # permission re-check: a non-staff PUT?upsert=true on a missing + # id must come 
back as 403, not 404. + await database_sync_to_async(self.check_permissions)( + clone_request(request, "POST") + ) + + @database_sync_to_async + def _save() -> tuple[dict[str, Any], int]: + serializer = self.get_serializer(report, data=data) + serializer.is_valid(raise_exception=True) + saved = serializer.save() + + def on_commit(): + handlers = ( + reports_created_handlers + if report is None + else reports_updated_handlers + ) + event = "newly created" if report is None else "updated" + for handler in handlers: + logger.debug( + f"{handler.name} - handle {event} reports: " + f"{[saved.document_id]}" + ) + handler.handle([saved]) + + transaction.on_commit(on_commit) + return serializer.data, ( + status.HTTP_201_CREATED if report is None else status.HTTP_200_OK + ) + + body, http_status = await _save() + return Response(body, status=http_status) + + async def adestroy(self, request: Request, *args: Any, **kwargs: Any) -> Response: + try: + report = await Report.objects.aget(document_id=kwargs[self.lookup_field]) + except Report.DoesNotExist: + raise Http404 + + @database_sync_to_async + def _delete_and_schedule() -> None: + with transaction.atomic(): + report.delete() + + def on_commit(): + for handler in reports_deleted_handlers: + logger.debug( + f"{handler.name} - handle deleted report: " + f"{report.document_id}" + ) + handler.handle([report]) + + transaction.on_commit(on_commit) + + await _delete_and_schedule() + return Response(status=status.HTTP_204_NO_CONTENT) + + @action(detail=False, methods=["post"], url_path="bulk-upsert") + async def bulk_upsert(self, request: Request) -> Response: + payloads = request.data + if not isinstance(payloads, list): + return Response( + {"detail": "Expected a list of report objects."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + replace = request.GET.get("replace", "true").lower() in ("true", "1", "yes") + if not replace: + return Response( + { + "detail": ( + "replace=false is not supported for bulk upsert. 
" + "Use replace=true." + ) + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + @database_sync_to_async + def _do() -> dict[str, Any]: + valid_payloads: list[dict[str, Any]] = [] + errors: list[dict[str, Any]] = [] + for index, payload in enumerate(payloads): + serializer = self.get_serializer( + data=payload, + context={ + **self.get_serializer_context(), + "skip_document_id_unique": True, + }, + ) + try: + serializer.is_valid(raise_exception=True) + except ValidationError as exc: + document_id = ( + payload.get("document_id") + if isinstance(payload, dict) + else None + ) + logger.error( + "Bulk upsert validation failed (index=%s document_id=%s): %s", + index, document_id, exc.detail, + ) + errors.append({ + "index": index, + "document_id": document_id, + "errors": exc.detail, + }) + continue + valid_payloads.append(serializer.validated_data) + + created_ids: list[str] = [] + updated_ids: list[str] = [] + if valid_payloads: + created_ids, updated_ids = bulk_upsert_reports(valid_payloads) + + body: dict[str, Any] = { + "created": len(created_ids), + "updated": len(updated_ids), + "invalid": len(errors), + } + if errors: + max_errors = 50 + body["errors"] = errors[:max_errors] + body["errors_truncated"] = len(errors) > max_errors + return body + + return Response(await _do()) +``` + +- [ ] **Step 3.2: Update the async-shape guard tests in `radis/reports/tests/test_report_api.py`** + +The three guards from Task 2 (which currently look up `ReportListAPIView`, `ReportDetailAPIView`, `ReportBulkUpsertAPIView` in the non-existent `radis.reports.api.views` module) need to point at the viewset's async methods in `radis.reports.api.viewsets`. 
Add `import importlib` and `import inspect` to the file's import block: + +```python +import importlib +import inspect + + +def test_report_viewset_methods_are_coroutines(): + # The viewset is rewritten in place, so it still lives in `viewsets.py`. + viewsets = importlib.import_module("radis.reports.api.viewsets") + vs = viewsets.ReportViewSet + for name in ("acreate", "aretrieve", "aupdate", "adestroy", "bulk_upsert"): + assert inspect.iscoroutinefunction(getattr(vs, name)), f"{name} is not async" +``` + +Replace the previous `test_report_list_post_is_coroutine`, `test_report_detail_methods_are_coroutines`, 
and `test_report_bulk_upsert_post_is_coroutine` with this single test. + +- [ ] **Step 3.3: Lint and commit** + +```bash +uv run cli lint +git add radis/reports/api/viewsets.py radis/reports/tests/test_report_api.py +git commit -m "feat(reports): convert ReportViewSet to async ADRF viewset" +``` + +--- + +## Task 4: Sanity-check `urls.py` and run the report tests + +The URL config in `radis/reports/api/urls.py` already registers `ReportViewSet` on a `DefaultRouter`. Since Task 3 rewrites `viewsets.py` in place (no rename, no new module), the import in `urls.py` (`from .viewsets import ReportViewSet`) does not change. This task is essentially a verification pass. + +**Files:** +- Read-only: `radis/reports/api/urls.py` + +- [ ] **Step 4.1: Confirm `urls.py` contents** + +```python +from adrf.routers import DefaultRouter +from django.urls import include, path + +from .viewsets import ReportViewSet + +router = DefaultRouter() +router.register("", ReportViewSet, basename="report") + +urlpatterns = [ + path("", include(router.urls)), +] +``` + +Important: use `adrf.routers.DefaultRouter`, **not** `rest_framework.routers.DefaultRouter`. DRF's router maps HTTP methods to sync action names (`create`/`retrieve`/`update`/`destroy`), which `adrf.mixins.*` inherit from DRF's sync mixins — so dispatch would silently call the inherited sync methods instead of our async overrides. ADRF's router remaps to `acreate`/`aretrieve`/`aupdate`/`adestroy` when `view_is_async=True`. 
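The router note above comes down to which method name the HTTP verb resolves to. Below is a toy model of that lookup, not ADRF's or DRF's real implementation: the two mixin classes, the `SYNC_ACTIONS` / `ASYNC_ACTIONS` maps, and `resolve_handler` are all hypothetical names invented for illustration.

```python
import asyncio
import inspect


class SyncCreateMixin:
    """Stand-in for a sync DRF-style mixin (hypothetical)."""

    def create(self, payload: dict) -> str:
        return f"sync:{payload['document_id']}"


class AsyncCreateMixin(SyncCreateMixin):
    """Stand-in for an async mixin that still inherits the sync sibling."""

    async def acreate(self, payload: dict) -> str:
        return f"async:{payload['document_id']}"


# Hypothetical verb-to-action maps: one sync-style, one async-style.
SYNC_ACTIONS = {"post": "create"}
ASYNC_ACTIONS = {"post": "acreate"}


def resolve_handler(view: object, method: str, actions: dict[str, str]):
    """Look up the bound method that a given action map selects."""
    return getattr(view, actions[method])


view = AsyncCreateMixin()
sync_handler = resolve_handler(view, "post", SYNC_ACTIONS)
async_handler = resolve_handler(view, "post", ASYNC_ACTIONS)

print(inspect.iscoroutinefunction(sync_handler))   # False: inherited sync path
print(inspect.iscoroutinefunction(async_handler))  # True: async override
print(asyncio.run(async_handler({"document_id": "DOC-1"})))
```

With the sync-style map the inherited `create` is selected even though `acreate` exists on the class, which is the silent fallback the plan warns about and the reason the shape guards check `iscoroutinefunction`.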
+ +The router auto-generates the same URL patterns and names the legacy code emitted: + +| Pattern | Method(s) | Viewset method | Route name | +| --- | --- | --- | --- | +| `/api/reports/` | POST | `acreate` | `report-list` | +| `/api/reports/bulk-upsert/` | POST | `bulk_upsert` (the `@action`) | `report-bulk-upsert` | +| `/api/reports/{document_id}/` | GET/PUT/DELETE | `aretrieve` / `aupdate` / `adestroy` | `report-detail` | + +`lookup_value_regex` defaults to `[^/.]+`, which forbids `.` in `document_id` — the legacy behaviour. + +- [ ] **Step 4.2: Run the report test files** + +```bash +uv run cli test -- radis/reports/tests/test_report_api.py -v +uv run cli test -- radis/reports/tests/test_bulk_upsert.py -v +``` + +Expected: all tests pass. + +--- + +## Task 5: Pre-PR verification + +No code changes — just confirm the project is healthy end-to-end before opening the PR. + +- [ ] **Step 5.1: Lint** + +```bash +uv run cli lint +``` + +Expected: zero issues. If anything fails, fix it (likely import ordering or unused imports — leftover `from rest_framework import ...` in unrelated files won't be touched). + +- [ ] **Step 5.2: Full test suite** + +```bash +uv run cli test +``` + +Expected: full green. Pay attention to any failure outside the reports app — that signals an unintended coupling we missed. + +- [ ] **Step 5.3: Manual smoke test against the running stack** + +The dev stack should still be up (`uv run cli compose-up -d` from prereqs). 
Use a fresh token to confirm each endpoint at the wire level: + +```bash +# Create an admin user + token in the running container if you don't have one: +uv run cli shell <<'PY' +from adit_radis_shared.accounts.factories import UserFactory, GroupFactory +from adit_radis_shared.token_authentication.models import Token +user = UserFactory.create(is_staff=True, is_active=True) +group = GroupFactory.create() +user.groups.add(group) +_, token = Token.objects.create_token(user, "smoke test", None) +print(f"TOKEN={token}") +print(f"GROUP_ID={group.pk}") +PY +``` + +Then exercise each endpoint: + +```bash +export TOKEN= +export GROUP= +BASE=http://localhost:8000/api/reports + +# CREATE +curl -sf -X POST "$BASE/" \ + -H "Authorization: Token $TOKEN" -H "Content-Type: application/json" \ + -d "$(cat <)` for each of `acreate`, `aretrieve`, `aupdate`, `adestroy`, and `bulk_upsert`. This guards against a future contributor inadvertently overriding the sync `create`/`retrieve`/`update`/`destroy` siblings inherited from the sync mixins — the dispatcher would silently switch to the sync path and break the inline-embedding follow-up. + +## Risks and mitigations + +| Risk | Mitigation | +| --- | --- | +| In-repo callers (e.g. `radis-client/`, other apps) `reverse()` route names that the old `DefaultRouter` produced. | Keep `name=` values identical (`report-list`, `report-detail`, `report-bulk-upsert`). Grep `radis-client/` and the rest of `radis/` for `reverse(` and `redirect(` referencing the old names before merge. | +| `transaction.on_commit` outside an atomic block runs immediately. | Same behavior as today's `perform_destroy`. Test asserts the deleted-handler runs after the delete returns. | +| `serializer.data` access lazy-loads related fields on the thread pool. | Already happens on the request thread today; not a regression. Re-use `select_related("language")` where present. 
|
+| Sync mixin sibling methods (`create`, `retrieve`, `update`, `destroy`) remain on the class because the `adrf.mixins` inherit from the sync DRF mixins. A contributor could accidentally override the sync one. | Async-shape guard tests pin every entry point to `iscoroutinefunction` — a sync override flips the guard red. |
+| Procrastinate worker tests (`radis/pgsearch/tests/test_process_embedding_*.py`) might appear affected. | They are not — `enqueue_bulk_index_reports` / `process_embedding_*` are unchanged. Confirm `uv run cli test` green before opening the PR. |
+
+## Rollout
+
+- Worktree already created: `.claude/worktrees/feat+adrf-views`, branch `feat/adrf-views` based on `origin/main` (commit `3e6f7540`).
+- Single PR scoped to `radis/reports/api/` + `radis/reports/tests/test_report_api.py`. No migrations, no settings changes, no env vars.
+- Verification before opening the PR:
+  - `uv run cli lint`
+  - `uv run cli test`
+  - Manual smoke: `uv run cli compose-up -- --watch`, then `curl` each endpoint with a token and confirm responses match the contract.
+- PR description must state explicitly: (a) no API contract change, (b) inline embedding is **not** added in this PR — that's the follow-up.
diff --git a/radis-client/tests/test_client.py b/radis-client/tests/test_client.py
index 17c7d84ee..fb1a813be 100644
--- a/radis-client/tests/test_client.py
+++ b/radis-client/tests/test_client.py
@@ -15,7 +15,7 @@ def test_report_data_valid():
     assert report.is_valid()
 
 
-@pytest.mark.django_db
+@pytest.mark.django_db(transaction=True)
 def test_report_data_post(live_server: LiveServer, mocker: MockerFixture):
     # Make sure it won't try to save created reports to any full text search database
     # as those are not available during test
diff --git a/radis/reports/api/operations.py b/radis/reports/api/operations.py
new file mode 100644
index 000000000..36bf92aed
--- /dev/null
+++ b/radis/reports/api/operations.py
@@ -0,0 +1,73 @@
+"""Async write operations for Report. Callers own atomicity."""
+import logging
+from typing import Any
+
+from ..models import Language, Metadata, Modality, Report
+
+logger = logging.getLogger(__name__)
+
+
+async def create_report_from_validated(
+    validated_data: dict[str, Any],
+) -> Report:
+    language = validated_data.pop("language")
+    groups = validated_data.pop("groups")
+    metadata = validated_data.pop("metadata")
+    modalities = validated_data.pop("modalities")
+
+    language_instance, _ = await Language.objects.aget_or_create(**language)
+    report = await Report.objects.acreate(
+        **validated_data, language=language_instance
+    )
+
+    await report.groups.aset(groups)
+
+    for item in metadata:
+        await Metadata.objects.acreate(report=report, **item)
+
+    modality_instances: list[Modality] = []
+    for modality in modalities:
+        instance, _ = await Modality.objects.aget_or_create(**modality)
+        modality_instances.append(instance)
+    await report.modalities.aset(modality_instances)
+
+    return report
+
+
+async def update_report_from_validated(
+    report: Report, validated_data: dict[str, Any]
+) -> Report:
+    """Replace all mutable fields and nested associations on an existing Report.
+
+    Metadata is fully replaced (delete + recreate); modalities and groups
+    are reset to the provided sets.
+    """
+    language = validated_data.pop("language")
+    groups = validated_data.pop("groups")
+    metadata = validated_data.pop("metadata")
+    modalities = validated_data.pop("modalities")
+
+    language_instance = await Language.objects.aget(**language)
+    report.language = language_instance
+    for attr, value in validated_data.items():
+        setattr(report, attr, value)
+    await report.asave()
+
+    await report.groups.aset(groups)
+
+    await report.metadata.all().adelete()
+    for item in metadata:
+        await Metadata.objects.acreate(report=report, **item)
+
+    await report.modalities.aclear()
+    modality_instances: list[Modality] = []
+    for modality in modalities:
+        instance, _ = await Modality.objects.aget_or_create(**modality)
+        modality_instances.append(instance)
+    await report.modalities.aset(modality_instances)
+
+    return report
+
+
+async def delete_report(report: Report) -> None:
+    await report.adelete()
diff --git a/radis/reports/api/serializers.py b/radis/reports/api/serializers.py
index 6d3f03f6a..97877e575 100644
--- a/radis/reports/api/serializers.py
+++ b/radis/reports/api/serializers.py
@@ -1,11 +1,14 @@
 from typing import Any
 
+from adrf.serializers import ModelSerializer as AsyncModelSerializer
+from asgiref.sync import async_to_sync, sync_to_async
 from django.db import transaction
 from rest_framework import serializers, validators
 from rest_framework.exceptions import ValidationError
 from rest_framework.relations import PrimaryKeyRelatedField
 
 from ..models import Language, Metadata, Modality, Report
+from . import operations
 
 
 class MetadataSerializer(serializers.ModelSerializer):
@@ -20,8 +23,7 @@ class Meta:
         fields = ("code",)
 
     def run_validation(self, data: dict[str, Any]) -> Any:
-        # We don't want to check if this modality already exists in the database
-        # as we later use get_or_create.
+        # Strip the UniqueValidator; `acreate`/`aupdate` use `get_or_create`.
for validator in self.fields["code"].validators: if isinstance(validator, validators.UniqueValidator): self.fields["code"].validators.remove(validator) @@ -34,15 +36,17 @@ class Meta: fields = ("code",) def run_validation(self, data: dict[str, Any]) -> Any: - # We don't want to check if this modality already exists in the database - # as we later use get_or_create. + # Strip the UniqueValidator; `acreate`/`aupdate` use `get_or_create`. for validator in self.fields["code"].validators: if isinstance(validator, validators.UniqueValidator): self.fields["code"].validators.remove(validator) return super().run_validation(data) -class ReportSerializer(serializers.ModelSerializer): +class ReportSerializer(AsyncModelSerializer): + """`acreate`/`aupdate` own the atomic block that bounds the multi-step + write of Language → Report → groups → Metadata → Modalities.""" + language = LanguageSerializer() metadata = MetadataSerializer(many=True) modalities = ModalitySerializer(many=True) @@ -75,60 +79,25 @@ def _strip_unique_validator(self, field_name: str) -> None: if not isinstance(validator, validators.UniqueValidator) ] - def create(self, validated_data: Any) -> Any: - language = validated_data.pop("language") - groups = validated_data.pop("groups") - metadata = validated_data.pop("metadata") - modalities = validated_data.pop("modalities") - - with transaction.atomic(): - language_instance, _ = Language.objects.get_or_create(**language) - - report = Report.objects.create(**validated_data, language=language_instance) - - report.groups.set(groups) - - for metadata in metadata: - Metadata.objects.create(report=report, **metadata) - - modality_instances: list[Modality] = [] - for modality in modalities: - modality_instance, _ = Modality.objects.get_or_create(**modality) - modality_instances.append(modality_instance) - - report.modalities.set(modality_instances) - - return report - - def update(self, report: Report, validated_data: Any) -> Any: - language = 
validated_data.pop("language") - groups = validated_data.pop("groups") - metadata = validated_data.pop("metadata") - modalities = validated_data.pop("modalities") - - with transaction.atomic(): - language_instance = Language.objects.get(**language) - report.language = language_instance - - for attr, value in validated_data.items(): - setattr(report, attr, value) - - report.save() - - report.groups.set(groups) - - report.metadata.all().delete() - for metadata in metadata: - Metadata.objects.create(report=report, **metadata) - - report.modalities.clear() - modality_instances: list[Modality] = [] - for modality in modalities: - modality_instance, _ = Modality.objects.get_or_create(**modality) - modality_instances.append(modality_instance) - report.modalities.set(modality_instances) - - return report + async def acreate(self, validated_data: Any) -> Report: + @sync_to_async(thread_sensitive=True) + @transaction.atomic + def _atomic() -> Report: + return async_to_sync(operations.create_report_from_validated)( + validated_data + ) + + return await _atomic() + + async def aupdate(self, report: Report, validated_data: Any) -> Report: + @sync_to_async(thread_sensitive=True) + @transaction.atomic + def _atomic() -> Report: + return async_to_sync(operations.update_report_from_validated)( + report, validated_data + ) + + return await _atomic() def to_internal_value(self, data: Any) -> Any: if "language" in data: diff --git a/radis/reports/api/urls.py b/radis/reports/api/urls.py index f136d3173..8598a6f01 100644 --- a/radis/reports/api/urls.py +++ b/radis/reports/api/urls.py @@ -1,10 +1,10 @@ +from adrf.routers import DefaultRouter from django.urls import include, path -from rest_framework.routers import DefaultRouter from .viewsets import ReportViewSet router = DefaultRouter() -router.register(r"", ReportViewSet) +router.register("", ReportViewSet, basename="report") urlpatterns = [ path("", include(router.urls)), diff --git a/radis/reports/api/viewsets.py 
b/radis/reports/api/viewsets.py index bb684b154..42fd30825 100644 --- a/radis/reports/api/viewsets.py +++ b/radis/reports/api/viewsets.py @@ -1,17 +1,29 @@ +"""ADRF report viewset. + +URLs must be wired through `adrf.routers.DefaultRouter` (not DRF's). DRF's +router dispatches HTTP methods to the sync action names (`create`/`retrieve`/ +`update`/`destroy`) which `adrf.mixins.*` inherits as fully-functional sync +methods from DRF — so DRF-router dispatch silently bypasses the async +overrides on this class. `adrf.routers.DefaultRouter` remaps to the +`a`-prefixed names whenever `view_is_async=True`. +""" +import asyncio import logging -from typing import Any +from typing import Any, cast +from adrf import mixins as amixins +from adrf.viewsets import GenericViewSet +from asgiref.sync import async_to_sync, sync_to_async from django.conf import settings from django.db import transaction from django.http import Http404 from django.utils import timezone -from rest_framework import mixins, status, viewsets +from rest_framework import status from rest_framework.decorators import action -from rest_framework.exceptions import MethodNotAllowed, ValidationError +from rest_framework.exceptions import ValidationError from rest_framework.permissions import IsAdminUser from rest_framework.request import Request, clone_request from rest_framework.response import Response -from rest_framework.serializers import BaseSerializer from radis.pgsearch.tasks import enqueue_bulk_index_reports from radis.pgsearch.utils.indexing import bulk_upsert_report_search_vectors @@ -23,80 +35,81 @@ reports_deleted_handlers, reports_updated_handlers, ) +from . import operations from .serializers import ReportSerializer logger = logging.getLogger(__name__) -BULK_DB_BATCH_SIZE = 1000 - -def _bulk_upsert_reports( +async def bulk_upsert_reports( validated_reports: list[dict[str, Any]], ) -> tuple[list[str], list[str]]: + """Bulk-upsert validated report payloads. + + Four phases: + 1. 
Dedupe input by document_id (CPU) + 2. Preflight Language/Modality/existing-Report reads + 3. Build new_reports / updated_reports lists (CPU) + 4. Atomic writes — bulk_create/bulk_update + through-table churn + + The CPU phases run off the event loop via `@sync_to_async` helpers. + """ if not validated_reports: return [], [] - deduped_reports: dict[str, dict[str, Any]] = {} - duplicate_count = 0 - for report in validated_reports: - document_id = report["document_id"] - if document_id in deduped_reports: - duplicate_count += 1 - deduped_reports[document_id] = report - if duplicate_count: - logger.warning( - "Bulk upsert payload contained %s duplicate document_ids; keeping last occurrence.", - duplicate_count, - ) - validated_reports = list(deduped_reports.values()) - - def _dedupe_by_key( - items: list[dict[str, Any]], key_name: str - ) -> tuple[list[dict[str, Any]], int]: - if not items: - return [], 0 - by_key: dict[str, dict[str, Any]] = {} - for item in items: - key = item[key_name] - by_key[key] = item - return list(by_key.values()), len(items) - len(by_key) - - def _dedupe_metadata(items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], int]: - if not items: - return [], 0 - by_key: dict[str, dict[str, Any]] = {} - duplicates = 0 - for item in items: - key = item["key"] - if key in by_key: - duplicates += 1 - by_key[key] = item - return list(by_key.values()), duplicates - - def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]: - if not items: - return [], 0 - by_id: dict[int, int] = {} - for group in items: - group_id = int(getattr(group, "pk", group)) - by_id[group_id] = group_id - return list(by_id.values()), len(items) - len(by_id) + report_field_names = ( + "document_id", + "pacs_aet", + "pacs_name", + "pacs_link", + "patient_id", + "patient_birth_date", + "patient_sex", + "study_description", + "study_datetime", + "study_instance_uid", + "accession_number", + "body", + ) + + # ── Phase 1: CPU-only dedupe of incoming payload (off-loop) ── + 
@sync_to_async(thread_sensitive=True) + def _dedupe_payload() -> list[dict[str, Any]]: + deduped_reports: dict[str, dict[str, Any]] = {} + duplicate_count = 0 + for report in validated_reports: + document_id = report["document_id"] + if document_id in deduped_reports: + duplicate_count += 1 + deduped_reports[document_id] = report + if duplicate_count: + logger.warning( + "Bulk upsert payload contained %s duplicate document_ids; " + "keeping last occurrence.", + duplicate_count, + ) + return list(deduped_reports.values()) + return validated_reports + validated_reports = await _dedupe_payload() document_ids = [report["document_id"] for report in validated_reports] + # ── Phase 2: preflight reads/writes that do NOT need atomicity ── language_codes = {report["language"]["code"] for report in validated_reports} language_by_code = { - lang.code: lang for lang in Language.objects.filter(code__in=language_codes) + lang.code: lang + async for lang in Language.objects.filter(code__in=language_codes) } missing_language_codes = language_codes - language_by_code.keys() if missing_language_codes: - Language.objects.bulk_create( + await Language.objects.abulk_create( [Language(code=code) for code in missing_language_codes], ignore_conflicts=True, - batch_size=BULK_DB_BATCH_SIZE, + batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE, ) language_by_code = { - lang.code: lang for lang in Language.objects.filter(code__in=language_codes) + lang.code: lang + async for lang in Language.objects.filter(code__in=language_codes) } modality_codes = { @@ -104,75 +117,111 @@ def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]: for report in validated_reports for modality in report.get("modalities", []) } - modality_by_code = {mod.code: mod for mod in Modality.objects.filter(code__in=modality_codes)} + modality_by_code = { + mod.code: mod + async for mod in Modality.objects.filter(code__in=modality_codes) + } missing_modality_codes = modality_codes - modality_by_code.keys() if 
missing_modality_codes: - Modality.objects.bulk_create( + await Modality.objects.abulk_create( [Modality(code=code) for code in missing_modality_codes], ignore_conflicts=True, - batch_size=BULK_DB_BATCH_SIZE, + batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE, ) modality_by_code = { - mod.code: mod for mod in Modality.objects.filter(code__in=modality_codes) + mod.code: mod + async for mod in Modality.objects.filter(code__in=modality_codes) } - existing_reports = Report.objects.filter(document_id__in=document_ids) - existing_by_document_id = {report.document_id: report for report in existing_reports} - - now = timezone.now() - created_ids: list[str] = [] - updated_ids: list[str] = [] - new_reports: list[Report] = [] - updated_reports: list[Report] = [] - - report_field_names = ( - "document_id", - "pacs_aet", - "pacs_name", - "pacs_link", - "patient_id", - "patient_birth_date", - "patient_sex", - "study_description", - "study_datetime", - "study_instance_uid", - "accession_number", - "body", - ) + existing_by_document_id = { + report.document_id: report + async for report in Report.objects.filter(document_id__in=document_ids) + } - for report_data in validated_reports: - document_id = report_data["document_id"] - language = language_by_code[report_data["language"]["code"]] - report_fields = {field: report_data[field] for field in report_field_names} - - existing = existing_by_document_id.get(document_id) - if existing: - for field, value in report_fields.items(): - setattr(existing, field, value) - existing.language = language - existing.updated_at = now - updated_reports.append(existing) - updated_ids.append(document_id) - else: - new_reports.append( - Report( - **report_fields, - language=language, - created_at=now, - updated_at=now, + # ── Phase 3: CPU-only build of new_reports / updated_reports lists (off-loop) ── + @sync_to_async(thread_sensitive=True) + def _build_report_lists() -> tuple[ + list[Report], list[Report], list[str], list[str] + ]: + now = 
timezone.now() + created_ids: list[str] = [] + updated_ids: list[str] = [] + new_reports: list[Report] = [] + updated_reports: list[Report] = [] + + for report_data in validated_reports: + document_id = report_data["document_id"] + language = language_by_code[report_data["language"]["code"]] + report_fields = {field: report_data[field] for field in report_field_names} + + existing = existing_by_document_id.get(document_id) + if existing: + for field, value in report_fields.items(): + setattr(existing, field, value) + existing.language = language + existing.updated_at = now + updated_reports.append(existing) + updated_ids.append(document_id) + else: + new_reports.append( + Report( + **report_fields, + language=language, + created_at=now, + updated_at=now, + ) ) - ) - created_ids.append(document_id) + created_ids.append(document_id) + + return new_reports, updated_reports, created_ids, updated_ids + + new_reports, updated_reports, created_ids, updated_ids = await _build_report_lists() + + # ── Phase 4: atomic writes ── + @sync_to_async(thread_sensitive=True) + @transaction.atomic + def _do_atomic_writes() -> None: + def _dedupe_by_key( + items: list[dict[str, Any]], key_name: str + ) -> tuple[list[dict[str, Any]], int]: + if not items: + return [], 0 + by_key: dict[str, dict[str, Any]] = {} + for item in items: + by_key[item[key_name]] = item + return list(by_key.values()), len(items) - len(by_key) + + def _dedupe_metadata( + items: list[dict[str, Any]] + ) -> tuple[list[dict[str, Any]], int]: + if not items: + return [], 0 + by_key: dict[str, dict[str, Any]] = {} + duplicates = 0 + for item in items: + key = item["key"] + if key in by_key: + duplicates += 1 + by_key[key] = item + return list(by_key.values()), duplicates + + def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]: + if not items: + return [], 0 + by_id: dict[int, int] = {} + for group in items: + group_id = int(getattr(group, "pk", group)) + by_id[group_id] = group_id + return 
list(by_id.values()), len(items) - len(by_id) - with transaction.atomic(): - if new_reports: - Report.objects.bulk_create(new_reports, batch_size=BULK_DB_BATCH_SIZE) + if new_reports: + Report.objects.bulk_create(new_reports, batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE) if updated_reports: Report.objects.bulk_update( updated_reports, fields=[*report_field_names, "language", "updated_at"], - batch_size=BULK_DB_BATCH_SIZE, + batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE, ) report_id_by_document_id = { @@ -197,7 +246,9 @@ def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]: Metadata(report_id=report_id, key=item["key"], value=item["value"]) ) if metadata_rows: - Metadata.objects.bulk_create(metadata_rows, batch_size=BULK_DB_BATCH_SIZE) + Metadata.objects.bulk_create( + metadata_rows, batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE + ) modality_through = Report.modalities.through modality_through.objects.filter(report_id__in=report_ids).delete() @@ -216,7 +267,9 @@ def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]: modality_through(report_id=report_id, modality_id=modality_id) ) if modality_rows: - modality_through.objects.bulk_create(modality_rows, batch_size=BULK_DB_BATCH_SIZE) + modality_through.objects.bulk_create( + modality_rows, batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE + ) group_through = Report.groups.through group_through.objects.filter(report_id__in=report_ids).delete() @@ -230,7 +283,9 @@ def _dedupe_groups(items: list[Any]) -> tuple[list[int], int]: for group_id in group_items: group_rows.append(group_through(report_id=report_id, group_id=group_id)) if group_rows: - group_through.objects.bulk_create(group_rows, batch_size=BULK_DB_BATCH_SIZE) + group_through.objects.bulk_create( + group_rows, batch_size=settings.REPORTS_BULK_DB_BATCH_SIZE + ) if metadata_duplicate_count or modality_duplicate_count or group_duplicate_count: logger.warning( @@ -253,9 +308,11 @@ def on_commit(): for handler in reports_created_handlers: 
handler.handle(created_reports) if updated_ids: - updated_reports = list(Report.objects.filter(document_id__in=updated_ids)) + updated_reports_after_commit = list( + Report.objects.filter(document_id__in=updated_ids) + ) for handler in reports_updated_handlers: - handler.handle(updated_reports) + handler.handle(updated_reports_after_commit) if touched_report_ids: if settings.PGSEARCH_SYNC_INDEXING: bulk_upsert_report_search_vectors(touched_report_ids) @@ -264,185 +321,230 @@ def on_commit(): transaction.on_commit(on_commit) + await _do_atomic_writes() return created_ids, updated_ids class ReportViewSet( - mixins.CreateModelMixin, - mixins.DestroyModelMixin, - mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - viewsets.GenericViewSet, + amixins.CreateModelMixin, + amixins.RetrieveModelMixin, + amixins.UpdateModelMixin, + amixins.DestroyModelMixin, + GenericViewSet, ): - """ViewSet for fetch, creating, updating, and deleting Reports. - - Only admins (staff users) can do that. - """ - - serializer_class = ReportSerializer queryset = Report.objects.all() + serializer_class = ReportSerializer lookup_field = "document_id" permission_classes = [IsAdminUser] + http_method_names = ["get", "post", "put", "delete", "head", "options"] + + async def acreate(self, request: Request, *args: Any, **kwargs: Any) -> Response: + serializer = cast(ReportSerializer, self.get_serializer(data=request.data)) + await sync_to_async(serializer.is_valid, thread_sensitive=True)( + raise_exception=True + ) + + report = await serializer.asave() + + def on_commit(): + for handler in reports_created_handlers: + logger.debug( + f"{handler.name} - handle newly created reports: " + f"{[report.document_id]}" + ) + handler.handle([report]) - def get_serializer(self, *args: Any, **kwargs: Any) -> BaseSerializer: - if isinstance(kwargs.get("data", {}), list): - kwargs["many"] = True - return super().get_serializer(*args, **kwargs) + # `transaction.on_commit` hits `ensure_connection()`, which is + # 
sync-only; we're in an async handler, so wrap. + await sync_to_async(transaction.on_commit, thread_sensitive=True)(on_commit) - def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response: - """Retrieve a single Report. + response_data = await sync_to_async( + lambda: serializer.data, thread_sensitive=True + )() + return Response(response_data, status=status.HTTP_201_CREATED) - It also fetches the associated documents from all external databases. - """ - full = request.GET.get("full", "").lower() in ["true", "1", "yes"] + async def aretrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response: + try: + report = await Report.objects.select_related("language").aget( + document_id=kwargs[self.lookup_field] + ) + except Report.DoesNotExist: + raise Http404 - instance: Report = self.get_object() - serializer = self.get_serializer(instance) - data = serializer.data + data = await sync_to_async( + lambda: self.get_serializer(report).data, + thread_sensitive=True, + )() + full = request.GET.get("full", "").lower() in ("true", "1", "yes") if full: - documents = {} - for fetcher in document_fetchers.values(): - document = fetcher.fetch(instance) - if document: - documents[fetcher.source] = document - data["documents"] = documents + async def _fetch(fetcher): + return fetcher.source, await sync_to_async( + fetcher.fetch, thread_sensitive=True + )(report) + + results = await asyncio.gather( + *(_fetch(f) for f in document_fetchers.values()) + ) + data["documents"] = { + source: doc for source, doc in results if doc is not None + } return Response(data) - def perform_create(self, serializer: BaseSerializer) -> None: - super().perform_create(serializer) - assert serializer.instance - reports: list[Report] | Report = serializer.instance - if not isinstance(reports, list): - reports = [reports] + async def aupdate(self, request: Request, *args: Any, **kwargs: Any) -> Response: + document_id = kwargs[self.lookup_field] + upsert = request.GET.get("upsert", 
"").lower() in ("true", "1", "yes") + data = request.data + + try: + report = await Report.objects.aget(document_id=document_id) + except Report.DoesNotExist: + report = None + + if report is None and not upsert: + raise Http404 + if report is None and upsert: + # A non-staff PUT?upsert=true on a missing id must return 403, + # not 404 — re-check permissions against a synthetic POST. + await sync_to_async(self.check_permissions, thread_sensitive=True)( + clone_request(request, "POST") + ) + + serializer = cast( + ReportSerializer, self.get_serializer(report, data=data) + ) + await sync_to_async(serializer.is_valid, thread_sensitive=True)( + raise_exception=True + ) + + saved = await serializer.asave() def on_commit(): - for handler in reports_created_handlers: - document_ids = [report.document_id for report in reports] - logger.debug(f"{handler.name} - handle newly created reports: {document_ids}") - handler.handle(reports) + handlers = ( + reports_created_handlers + if report is None + else reports_updated_handlers + ) + event = "newly created" if report is None else "updated" + for handler in handlers: + logger.debug( + f"{handler.name} - handle {event} reports: " + f"{[saved.document_id]}" + ) + handler.handle([saved]) - transaction.on_commit(on_commit) + await sync_to_async(transaction.on_commit, thread_sensitive=True)(on_commit) + + response_data = await sync_to_async( + lambda: serializer.data, thread_sensitive=True + )() + http_status = ( + status.HTTP_201_CREATED if report is None else status.HTTP_200_OK + ) + return Response(response_data, status=http_status) + + async def adestroy(self, request: Request, *args: Any, **kwargs: Any) -> Response: + try: + report = await Report.objects.aget(document_id=kwargs[self.lookup_field]) + except Report.DoesNotExist: + raise Http404 + + # The helper holds the transaction across the delete and the + # `transaction.on_commit` registration so the callback is bound + # to the delete's transaction. 
+ @sync_to_async(thread_sensitive=True) + @transaction.atomic + def _delete_and_schedule() -> None: + async_to_sync(operations.delete_report)(report) + + def on_commit(): + for handler in reports_deleted_handlers: + logger.debug( + f"{handler.name} - handle deleted report: " + f"{report.document_id}" + ) + handler.handle([report]) + + transaction.on_commit(on_commit) - def update(self, request: Request, *args: Any, **kwargs: Any) -> Response: - # DRF itself does not support upsert. - # Workaround adapted from https://gist.github.com/tomchristie/a2ace4577eff2c603b1b - upsert = request.GET.get("upsert", "").lower() in ["true", "1", "yes"] - if not upsert: - return super().update(request, *args, **kwargs) - else: - instance = self.get_object_or_none() - serializer = self.get_serializer(instance, data=request.data) - serializer.is_valid(raise_exception=True) - - if instance is None: - self.perform_create(serializer) - return Response(serializer.data, status=status.HTTP_201_CREATED) - - self.perform_update(serializer) - return Response(serializer.data) - - @action(detail=False, methods=["post"], url_path="bulk-upsert") - def bulk_upsert(self, request: Request) -> Response: - if not isinstance(request.data, list): + await _delete_and_schedule() + return Response(status=status.HTTP_204_NO_CONTENT) + + # DRF's `@action` stub types its arg as a sync view returning + # HttpResponseBase, but ADRF dispatches `async def` actions fine. 
+ @action(detail=False, methods=["post"], url_path="bulk-upsert") # pyright: ignore[reportArgumentType] + async def bulk_upsert(self, request: Request) -> Response: + payloads = request.data + if not isinstance(payloads, list): return Response( {"detail": "Expected a list of report objects."}, status=status.HTTP_400_BAD_REQUEST, ) - replace = request.GET.get("replace", "true").lower() in ["true", "1", "yes"] + replace = request.GET.get("replace", "true").lower() in ("true", "1", "yes") if not replace: return Response( - {"detail": "replace=false is not supported for bulk upsert. Use replace=true."}, + { + "detail": ( + "replace=false is not supported for bulk upsert. " + "Use replace=true." + ) + }, status=status.HTTP_400_BAD_REQUEST, ) - valid_payloads: list[dict[str, Any]] = [] - errors: list[dict[str, Any]] = [] - for index, payload in enumerate(request.data): - serializer = self.get_serializer( - data=payload, - context={ - **self.get_serializer_context(), - "skip_document_id_unique": True, - }, - ) - try: - serializer.is_valid(raise_exception=True) - except ValidationError as exc: - document_id = ( - payload.get("document_id") - if isinstance(payload, dict) - else None + @sync_to_async(thread_sensitive=True) + def _validate() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + valid_payloads: list[dict[str, Any]] = [] + errors: list[dict[str, Any]] = [] + for index, payload in enumerate(payloads): + serializer = self.get_serializer( + data=payload, + context={ + **self.get_serializer_context(), + "skip_document_id_unique": True, + }, ) - logger.error( - "Bulk upsert validation failed (index=%s document_id=%s): %s", - index, - document_id, - exc.detail, - ) - errors.append( - { - "index": index, - "document_id": document_id, - "errors": exc.detail, - } - ) - continue - valid_payloads.append(serializer.validated_data) + try: + serializer.is_valid(raise_exception=True) + except ValidationError as exc: + document_id = ( + payload.get("document_id") + if 
isinstance(payload, dict)
+                            else None
+                        )
+                        logger.error(
+                            "Bulk upsert validation failed (index=%s document_id=%s): %s",
+                            index,
+                            document_id,
+                            exc.detail,
+                        )
+                        errors.append(
+                            {
+                                "index": index,
+                                "document_id": document_id,
+                                "errors": exc.detail,
+                            }
+                        )
+                        continue
+                    valid_payloads.append(serializer.validated_data)
+                return valid_payloads, errors
+
+        valid_payloads, errors = await _validate()
         created_ids: list[str] = []
         updated_ids: list[str] = []
         if valid_payloads:
-            created_ids, updated_ids = _bulk_upsert_reports(valid_payloads)
+            created_ids, updated_ids = await bulk_upsert_reports(valid_payloads)
-        response_body: dict[str, Any] = {
+        body: dict[str, Any] = {
             "created": len(created_ids),
             "updated": len(updated_ids),
             "invalid": len(errors),
         }
         if errors:
             max_errors = 50
-            response_body["errors"] = errors[:max_errors]
-            response_body["errors_truncated"] = len(errors) > max_errors
-        return Response(response_body)
-
-    def get_object_or_none(self) -> Report | None:
-        try:
-            return self.get_object()
-        except Http404:
-            if self.request.method == "PUT":
-                self.check_permissions(clone_request(self.request, "POST"))
-            else:
-                raise
-
-    def perform_update(self, serializer: BaseSerializer) -> None:
-        super().perform_update(serializer)
-        assert serializer.instance
-        reports: list[Report] | Report = serializer.instance
-        if not isinstance(reports, list):
-            reports = [reports]
-
-        def on_commit():
-            for handler in reports_updated_handlers:
-                document_ids = [report.document_id for report in reports]
-                logger.debug(f"{handler.name} - handle updated reports: {document_ids}")
-                handler.handle(reports)
-
-        transaction.on_commit(on_commit)
-
-    def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response:
-        # Disallow partial updates
-        assert request.method
-        raise MethodNotAllowed(request.method)
-
-    def perform_destroy(self, instance: Report) -> None:
-        super().perform_destroy(instance)
-
-        def on_commit():
-            for handler in reports_deleted_handlers:
-                logger.debug(f"{handler.name} - handle deleted report: {instance.document_id}")
-                handler.handle([instance])
-
-        transaction.on_commit(on_commit)
+            body["errors"] = errors[:max_errors]
+            body["errors_truncated"] = len(errors) > max_errors
+        return Response(body)
diff --git a/radis/reports/tests/test_bulk_upsert.py b/radis/reports/tests/test_bulk_upsert.py
index dcd4ebde3..1f5948186 100644
--- a/radis/reports/tests/test_bulk_upsert.py
+++ b/radis/reports/tests/test_bulk_upsert.py
@@ -4,18 +4,26 @@
 import pytest
 from adit_radis_shared.accounts.factories import GroupFactory, UserFactory
 from adit_radis_shared.token_authentication.models import Token
-from django.test import Client
+from asgiref.sync import sync_to_async
+from django.contrib.auth.models import Group
+from django.test import AsyncClient
 
-from radis.reports.api.viewsets import _bulk_upsert_reports
+from radis.reports.api.viewsets import bulk_upsert_reports
 from radis.reports.models import Language, Metadata, Modality, Report
 
 
-@pytest.mark.django_db
-def test_bulk_upsert_creates_and_updates_reports(client: Client):
+def _create_staff_user_group_token(label: str) -> tuple[Group, str]:
     user = UserFactory.create(is_active=True, is_staff=True)
     group = GroupFactory.create()
     user.groups.add(group)
-    _, token = Token.objects.create_token(user, "bulk upsert test", None)
+    _, token = Token.objects.create_token(user, label, None)
+    return group, token
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_bulk_upsert_creates_and_updates_reports(async_client: AsyncClient):
+    group, token = await sync_to_async(_create_staff_user_group_token)("bulk upsert test")
     payload = [
         {
             "document_id": "DOC-1",
@@ -55,7 +63,7 @@ def test_bulk_upsert_creates_and_updates_reports(client: Client):
         },
     ]
 
-    response = client.post(
+    response = await async_client.post(
         "/api/reports/bulk-upsert/",
         data=json.dumps(payload),
         content_type="application/json",
@@ -64,16 +72,16 @@ def test_bulk_upsert_creates_and_updates_reports(client: Client):
 
     assert response.status_code == 200
     assert response.json() == {"created": 2, "updated": 0, "invalid": 0}
-    assert Report.objects.count() == 2
-    assert Language.objects.filter(code="en").exists()
-    assert Language.objects.filter(code="de").exists()
-    assert Modality.objects.filter(code="CT").exists()
-    assert Modality.objects.filter(code="MR").exists()
+    assert await Report.objects.acount() == 2
+    assert await Language.objects.filter(code="en").aexists()
+    assert await Language.objects.filter(code="de").aexists()
+    assert await Modality.objects.filter(code="CT").aexists()
+    assert await Modality.objects.filter(code="MR").aexists()
 
     payload[0]["body"] = "Updated body"
     payload[0]["metadata"] = {"ris_filename": "file1", "extra": "value"}
 
-    response = client.post(
+    response = await async_client.post(
         "/api/reports/bulk-upsert/",
         data=json.dumps(payload),
         content_type="application/json",
@@ -82,17 +90,17 @@ def test_bulk_upsert_creates_and_updates_reports(client: Client):
 
     assert response.status_code == 200
     assert response.json() == {"created": 0, "updated": 2, "invalid": 0}
-    report = Report.objects.get(document_id="DOC-1")
+    report = await Report.objects.aget(document_id="DOC-1")
     assert report.body == "Updated body"
-    assert Metadata.objects.filter(report=report).count() == 2
+    assert await Metadata.objects.filter(report=report).acount() == 2
 
 
-@pytest.mark.django_db
-def test_bulk_upsert_dedupes_payload_entries(client: Client):
-    user = UserFactory.create(is_active=True, is_staff=True)
-    group = GroupFactory.create()
-    user.groups.add(group)
-    _, token = Token.objects.create_token(user, "bulk upsert dedupe test", None)
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_bulk_upsert_dedupes_payload_entries(async_client: AsyncClient):
+    group, token = await sync_to_async(_create_staff_user_group_token)(
+        "bulk upsert dedupe test"
+    )
 
     payload = [
         {
@@ -133,7 +141,7 @@ def test_bulk_upsert_dedupes_payload_entries(client: Client):
         },
     ]
 
-    response = client.post(
+    response = await async_client.post(
         "/api/reports/bulk-upsert/",
         data=json.dumps(payload),
         content_type="application/json",
@@ -142,16 +150,17 @@ def test_bulk_upsert_dedupes_payload_entries(client: Client):
 
     assert response.status_code == 200
     assert response.json() == {"created": 1, "updated": 0, "invalid": 0}
-    report = Report.objects.get(document_id="DOC-1")
+    report = await Report.objects.aget(document_id="DOC-1")
     assert report.body == "Second version"
-    assert report.modalities.count() == 1
-    assert report.groups.count() == 1
-    assert Metadata.objects.filter(report=report).count() == 2
+    assert await report.modalities.acount() == 1
+    assert await report.groups.acount() == 1
+    assert await Metadata.objects.filter(report=report).acount() == 2
 
 
-@pytest.mark.django_db
-def test_bulk_upsert_dedupes_metadata_keys():
-    group = GroupFactory.create()
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_bulk_upsert_dedupes_metadata_keys():
+    group = await sync_to_async(GroupFactory.create, thread_sensitive=True)()
 
     validated_reports = [
         {
@@ -177,10 +186,10 @@ def test_bulk_upsert_dedupes_metadata_keys():
         },
     ]
 
-    created_ids, updated_ids = _bulk_upsert_reports(validated_reports)
+    created_ids, updated_ids = await bulk_upsert_reports(validated_reports)
 
     assert created_ids == ["DOC-1"]
     assert updated_ids == []
-    report = Report.objects.get(document_id="DOC-1")
-    metadata = Metadata.objects.get(report=report, key="ris_filename")
+    report = await Report.objects.aget(document_id="DOC-1")
+    metadata = await Metadata.objects.aget(report=report, key="ris_filename")
     assert metadata.value == "file2"
diff --git a/radis/reports/tests/test_report_api.py b/radis/reports/tests/test_report_api.py
new file mode 100644
index 000000000..2fd1b11d4
--- /dev/null
+++ b/radis/reports/tests/test_report_api.py
@@ -0,0 +1,354 @@
+"""End-to-end tests for the report HTTP API."""
+import importlib
+import inspect
+import json
+from typing import Any
+
+import pytest
+from adit_radis_shared.accounts.factories import GroupFactory, UserFactory
+from adit_radis_shared.accounts.models import User
+from adit_radis_shared.token_authentication.models import Token
+from asgiref.sync import sync_to_async
+from django.contrib.auth.models import Group
+from django.test import AsyncClient
+from django.urls import reverse
+
+from radis.reports.models import Report
+from radis.reports.site import (
+    DocumentFetcher,
+    ReportsCreatedHandler,
+    ReportsDeletedHandler,
+    document_fetchers,
+    reports_created_handlers,
+    reports_deleted_handlers,
+)
+
+
+def _make_payload(document_id: str = "DOC-1", body: str = "Report body") -> dict[str, Any]:
+    return {
+        "document_id": document_id,
+        "language": "en",
+        "groups": [],  # populated by tests after group is known
+        "pacs_aet": "PACS",
+        "pacs_name": "Test PACS",
+        "pacs_link": "",
+        "patient_id": "P1",
+        "patient_birth_date": "1980-01-01",
+        "patient_sex": "M",
+        "study_description": "Study 1",
+        "study_datetime": "2024-01-01T00:00:00Z",
+        "study_instance_uid": "1.2.3.4",
+        "accession_number": "ACC1",
+        "modalities": ["CT"],
+        "metadata": {"ris_filename": "file1"},
+        "body": body,
+    }
+
+
+def _staff_user_and_token() -> tuple[User, Group, str]:
+    user = UserFactory.create(is_active=True, is_staff=True)
+    group = GroupFactory.create()
+    user.groups.add(group)
+    _, token = Token.objects.create_token(user, "report api test", None)
+    return user, group, token
+
+
+def _non_staff_user_and_token() -> tuple[User, str]:
+    user = UserFactory.create(is_active=True, is_staff=False)
+    _, token = Token.objects.create_token(user, "non staff report api test", None)
+    return user, token
+
+
+# ---------------------------------------------------------------------------
+# URL resolution
+# ---------------------------------------------------------------------------
+
+def test_report_list_url_resolves():
+    assert reverse("report-list") == "/api/reports/"
+
+
+def test_report_bulk_upsert_url_resolves():
+    assert reverse("report-bulk-upsert") == "/api/reports/bulk-upsert/"
+
+
+def test_report_detail_url_resolves():
+    assert reverse("report-detail", args=["DOC-1"]) == "/api/reports/DOC-1/"
+
+
+# ---------------------------------------------------------------------------
+# POST /api/reports/ (create)
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_post_creates_report_and_fires_created_handler(
+    async_client: AsyncClient, django_capture_on_commit_callbacks
+):
+    _, group, token = await sync_to_async(_staff_user_and_token)()
+    captured: list[Report] = []
+    handler = ReportsCreatedHandler(
+        name="test-created", handle=lambda reports: captured.extend(reports)
+    )
+    reports_created_handlers.append(handler)
+    try:
+        payload = _make_payload(document_id="DOC-CREATE")
+        payload["groups"] = [group.pk]
+
+        with django_capture_on_commit_callbacks(execute=True):
+            response = await async_client.post(
+                "/api/reports/",
+                data=json.dumps(payload),
+                content_type="application/json",
+                headers={"Authorization": f"Token {token}"},
+            )
+
+        assert response.status_code == 201
+        body = response.json()
+        assert body["document_id"] == "DOC-CREATE"
+        assert body["language"] == "en"
+        assert body["modalities"] == ["CT"]
+        assert body["metadata"] == {"ris_filename": "file1"}
+        assert await Report.objects.filter(document_id="DOC-CREATE").aexists()
+        assert [r.document_id for r in captured] == ["DOC-CREATE"]
+    finally:
+        reports_created_handlers.remove(handler)
+
+
+# ---------------------------------------------------------------------------
+# GET /api/reports/{document_id}/
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_get_returns_existing_report(async_client: AsyncClient):
+    _, group, token = await sync_to_async(_staff_user_and_token)()
+    payload = _make_payload(document_id="DOC-GET")
+    payload["groups"] = [group.pk]
+    await async_client.post(
+        "/api/reports/",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    response = await async_client.get(
+        "/api/reports/DOC-GET/",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    assert response.status_code == 200
+    assert response.json()["document_id"] == "DOC-GET"
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_get_missing_report_returns_404(async_client: AsyncClient):
+    _, _, token = await sync_to_async(_staff_user_and_token)()
+    response = await async_client.get(
+        "/api/reports/DOES-NOT-EXIST/",
+        headers={"Authorization": f"Token {token}"},
+    )
+    assert response.status_code == 404
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_get_full_includes_documents_from_fetchers(async_client: AsyncClient):
+    _, group, token = await sync_to_async(_staff_user_and_token)()
+    payload = _make_payload(document_id="DOC-FULL")
+    payload["groups"] = [group.pk]
+    await async_client.post(
+        "/api/reports/",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    fetcher = DocumentFetcher(
+        source="stub-fetcher",
+        fetch=lambda report: {"source_id": report.document_id, "extra": "ok"},
+    )
+    document_fetchers["stub-fetcher"] = fetcher
+    try:
+        response = await async_client.get(
+            "/api/reports/DOC-FULL/?full=true",
+            headers={"Authorization": f"Token {token}"},
+        )
+    finally:
+        document_fetchers.pop("stub-fetcher", None)
+
+    assert response.status_code == 200
+    body = response.json()
+    assert body["documents"]["stub-fetcher"] == {
+        "source_id": "DOC-FULL",
+        "extra": "ok",
+    }
+
+
+# ---------------------------------------------------------------------------
+# PUT /api/reports/{document_id}/
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_put_updates_existing_report(async_client: AsyncClient):
+    _, group, token = await sync_to_async(_staff_user_and_token)()
+    payload = _make_payload(document_id="DOC-PUT")
+    payload["groups"] = [group.pk]
+    await async_client.post(
+        "/api/reports/",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    payload["body"] = "Updated body"
+    response = await async_client.put(
+        "/api/reports/DOC-PUT/",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    assert response.status_code == 200
+    assert response.json()["body"] == "Updated body"
+    updated = await Report.objects.aget(document_id="DOC-PUT")
+    assert updated.body == "Updated body"
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_put_upsert_creates_when_missing(async_client: AsyncClient):
+    _, group, token = await sync_to_async(_staff_user_and_token)()
+    payload = _make_payload(document_id="DOC-UPSERT-NEW")
+    payload["groups"] = [group.pk]
+
+    response = await async_client.put(
+        "/api/reports/DOC-UPSERT-NEW/?upsert=true",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    assert response.status_code == 201
+    assert await Report.objects.filter(document_id="DOC-UPSERT-NEW").aexists()
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_put_upsert_missing_as_non_staff_returns_403(async_client: AsyncClient):
+    """When a PUT?upsert=true hits an unknown id, DRF re-checks permissions
+    as if it were a POST. IsAdminUser must reject the non-staff caller."""
+    _, token = await sync_to_async(_non_staff_user_and_token)()
+    payload = _make_payload(document_id="DOC-FORBIDDEN")
+
+    response = await async_client.put(
+        "/api/reports/DOC-FORBIDDEN/?upsert=true",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    assert response.status_code == 403
+    assert not await Report.objects.filter(document_id="DOC-FORBIDDEN").aexists()
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_patch_returns_405(async_client: AsyncClient):
+    _, _, token = await sync_to_async(_staff_user_and_token)()
+    response = await async_client.patch(
+        "/api/reports/DOC-NA/",
+        data=json.dumps({"body": "irrelevant"}),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+    assert response.status_code == 405
+
+
+# ---------------------------------------------------------------------------
+# DELETE /api/reports/{document_id}/
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_delete_removes_report_and_fires_deleted_handler(
+    async_client: AsyncClient, django_capture_on_commit_callbacks
+):
+    _, group, token = await sync_to_async(_staff_user_and_token)()
+    payload = _make_payload(document_id="DOC-DEL")
+    payload["groups"] = [group.pk]
+    await async_client.post(
+        "/api/reports/",
+        data=json.dumps(payload),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+
+    captured: list[Report] = []
+    handler = ReportsDeletedHandler(
+        name="test-deleted", handle=lambda reports: captured.extend(reports)
+    )
+    reports_deleted_handlers.append(handler)
+    try:
+        with django_capture_on_commit_callbacks(execute=True):
+            response = await async_client.delete(
+                "/api/reports/DOC-DEL/",
+                headers={"Authorization": f"Token {token}"},
+            )
+    finally:
+        reports_deleted_handlers.remove(handler)
+
+    assert response.status_code == 204
+    assert not await Report.objects.filter(document_id="DOC-DEL").aexists()
+    assert [r.document_id for r in captured] == ["DOC-DEL"]
+
+
+# ---------------------------------------------------------------------------
+# POST /api/reports/bulk-upsert/
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_bulk_upsert_rejects_replace_false(async_client: AsyncClient):
+    _, _, token = await sync_to_async(_staff_user_and_token)()
+    response = await async_client.post(
+        "/api/reports/bulk-upsert/?replace=false",
+        data=json.dumps([]),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+    assert response.status_code == 400
+
+
+@pytest.mark.asyncio
+@pytest.mark.django_db(transaction=True)
+async def test_bulk_upsert_rejects_non_list_payload(async_client: AsyncClient):
+    _, _, token = await sync_to_async(_staff_user_and_token)()
+    response = await async_client.post(
+        "/api/reports/bulk-upsert/",
+        data=json.dumps({"document_id": "DOC-NOT-A-LIST"}),
+        content_type="application/json",
+        headers={"Authorization": f"Token {token}"},
+    )
+    assert response.status_code == 400
+
+
+# ---------------------------------------------------------------------------
+# Async-shape guards — prevent silent regressions to sync handlers.
+# ---------------------------------------------------------------------------
+
+def test_report_viewset_methods_are_coroutines():
+    """Every dispatched method on ReportViewSet must be `async def`.
+
+    `adrf.mixins.*ModelMixin` inherits from DRF's sync mixins, so the class
+    has both sync `create` and async `acreate` (etc.) on the MRO. ADRF's
+    `view_is_async` only flips the dispatcher to the async path when *all*
+    overrides are coroutines.
+    """
+    views = importlib.import_module("radis.reports.api.viewsets")
+    vs = views.ReportViewSet
+    for name in ("acreate", "aretrieve", "aupdate", "adestroy", "bulk_upsert"):
+        assert inspect.iscoroutinefunction(getattr(vs, name)), (
+            f"ReportViewSet.{name} must be async"
+        )
diff --git a/radis/settings/base.py b/radis/settings/base.py
index 319f24853..dceef11d6 100644
--- a/radis/settings/base.py
+++ b/radis/settings/base.py
@@ -164,6 +164,9 @@
 PGSEARCH_BULK_INSERT_BATCH_SIZE = env.int("PGSEARCH_BULK_INSERT_BATCH_SIZE", default=1000)
 PGSEARCH_SYNC_INDEXING = env.bool("PGSEARCH_SYNC_INDEXING", default=False)
 
+# Report API bulk-upsert batch size (used by radis.reports.api.viewsets.bulk_upsert_reports)
+REPORTS_BULK_DB_BATCH_SIZE = 1000
+
 # Default primary key field type
 # https://docs.djangoproject.com/en/5.0/ref/settings/#default-auto-field
 DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
@@ -319,9 +322,7 @@
     },
     "dbbackup": {
         "BACKEND": "django.core.files.storage.FileSystemStorage",
-        "OPTIONS": {
-            "location": env.str("DBBACKUP_STORAGE_LOCATION", default="/tmp/backups-radis")
-        },
+        "OPTIONS": {"location": env.str("DBBACKUP_STORAGE_LOCATION", default="/tmp/backups-radis")},
     },
 }
 DBBACKUP_CLEANUP_KEEP = 30