Skip to content

Commit 281e931

Browse files
mihow and claude committed
refactor: extract check_stale_jobs() for reuse by periodic task
Move core stale-job logic from management command into check_stale_jobs() in tasks.py. The management command is now a thin wrapper. Add tests for the extracted function. This prepares for #1025 which will call check_stale_jobs() from a Celery Beat periodic task. Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 703400b commit 281e931

3 files changed

Lines changed: 127 additions & 37 deletions

File tree

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
1-
from celery import states
2-
from celery.result import AsyncResult
31
from django.core.management.base import BaseCommand
4-
from django.utils import timezone
52

6-
from ami.jobs.models import Job, JobState
7-
from ami.jobs.tasks import cleanup_async_job_if_needed
8-
9-
# Celery returns PENDING for tasks it has no record of.
10-
# These are the states that indicate a real, known task status.
11-
KNOWN_CELERY_STATES = frozenset(states.ALL_STATES) - {states.PENDING}
3+
from ami.jobs.models import Job
4+
from ami.jobs.tasks import check_stale_jobs
125

136

147
class Command(BaseCommand):
@@ -28,36 +21,17 @@ def add_arguments(self, parser):
2821
)
2922

3023
def handle(self, *args, **options):
    """
    Run the stale-job check and print one line per affected job.

    All of the detection and state-changing logic lives in
    ``check_stale_jobs()``; this method only formats its results. Reads the
    ``hours`` and ``dry_run`` entries of ``options``.
    """
    dry_run = options["dry_run"]
    results = check_stale_jobs(hours=options["hours"], dry_run=dry_run)

    if not results:
        self.stdout.write("No stale jobs found.")
        return

    for r in results:
        job_id = r["job_id"]
        if r["action"] == "updated":
            # In dry-run mode nothing was written, so word the message
            # conditionally instead of claiming the job was updated.
            if dry_run:
                message = f"[dry-run] Job {job_id}: would update to {r['state']} (from Celery)"
            else:
                message = f"Job {job_id}: updated to {r['state']} (from Celery)"
            self.stdout.write(self.style.SUCCESS(message))
        else:
            # Any action other than "updated" is reported as a revocation.
            if dry_run:
                message = f"[dry-run] Job {job_id} ({r['previous_status']}): would revoke and clean up"
            else:
                message = f"Job {job_id}: revoked (no known Celery state)"
            self.stdout.write(self.style.WARNING(message))

ami/jobs/tasks.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,56 @@ def _update_job_progress(
316316
cleanup_async_job_if_needed(job)
317317

318318

319+
def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[dict]:
    """
    Find jobs stuck in a running state past the cutoff and reconcile them.

    A job is considered stale when its status is one of
    ``JobState.running_states()`` and its ``updated_at`` is older than
    ``hours`` ago. For each stale job, Celery is asked for the task state:

    - If Celery reports a known state (anything except PENDING, which Celery
      also returns for task ids it has no record of), the job is updated to
      that state.
    - Otherwise (no ``task_id``, or Celery only reports PENDING) the job is
      marked REVOKED, ``finished_at`` is set, and
      ``cleanup_async_job_if_needed()`` is called for it.

    Args:
        hours: Age threshold in hours. Defaults to ``Job.FAILED_CUTOFF_HOURS``
            when ``None``.
        dry_run: When true, nothing is saved or cleaned up; the returned list
            describes what would have been done.

    Returns:
        One dict per stale job, either
        ``{"job_id", "action": "updated", "state"}`` or
        ``{"job_id", "action": "revoked", "previous_status"}``.
    """
    import datetime

    from celery import states
    from celery.result import AsyncResult
    from django.utils import timezone

    from ami.jobs.models import Job, JobState

    if hours is None:
        hours = Job.FAILED_CUTOFF_HOURS

    known_celery_states = frozenset(states.ALL_STATES) - {states.PENDING}

    # timezone.now() rather than datetime.now(): a naive local-time value
    # would be compared against (and stored into) Django datetime fields,
    # which warns under USE_TZ=True and can shift the cutoff by the local
    # UTC offset.
    now = timezone.now()
    cutoff = now - datetime.timedelta(hours=hours)
    stale_jobs = Job.objects.filter(
        status__in=JobState.running_states(),
        updated_at__lt=cutoff,
    )

    results = []
    for job in stale_jobs:
        celery_state = None
        if job.task_id:
            celery_state = AsyncResult(job.task_id).state

        if celery_state in known_celery_states:
            if not dry_run:
                job.update_status(celery_state, save=False)
                job.save()
            results.append({"job_id": job.pk, "action": "updated", "state": celery_state})
        else:
            # Capture the status before it is overwritten so the report
            # reflects what the job was, not what it was just changed to.
            previous_status = job.status
            if not dry_run:
                job.update_status(JobState.REVOKED, save=False)
                job.finished_at = timezone.now()
                job.save()
                cleanup_async_job_if_needed(job)
            results.append({"job_id": job.pk, "action": "revoked", "previous_status": previous_status})

    return results
367+
368+
319369
def cleanup_async_job_if_needed(job) -> None:
320370
"""
321371
Clean up async resources (NATS/Redis) if this job uses them.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from datetime import timedelta
2+
from unittest.mock import patch
3+
4+
from django.test import TestCase
5+
from django.utils import timezone
6+
7+
from ami.jobs.models import Job, JobState
8+
from ami.jobs.tasks import check_stale_jobs
9+
from ami.main.models import Project
10+
11+
12+
@patch("ami.jobs.tasks.cleanup_async_job_if_needed")
class CheckStaleJobsTest(TestCase):
    """
    Tests for check_stale_jobs().

    The async cleanup helper is patched for every test method (class-level
    patch), so each test receives the mock as its extra argument.
    """

    def setUp(self):
        self.project = Project.objects.create(name="Stale jobs test project")

    def _create_job(self, status=JobState.STARTED, hours_ago=100, task_id=None):
        """Create a job in ``status`` whose ``updated_at`` is ``hours_ago`` hours old."""
        created = Job.objects.create(
            project=self.project,
            name=f"Test job {status}",
            status=status,
        )
        # Written through a queryset update() rather than save();
        # presumably updated_at is managed automatically on save — confirm
        # against the Job model.
        overrides = {"updated_at": timezone.now() - timedelta(hours=hours_ago)}
        if task_id is not None:
            overrides["task_id"] = task_id
        Job.objects.filter(pk=created.pk).update(**overrides)
        created.refresh_from_db()
        return created

    def _assert_single_revoked(self, results):
        """Assert that exactly one job was reported, with the "revoked" action."""
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["action"], "revoked")

    def test_dry_run(self, mock_cleanup):
        """Dry run returns results without modifying jobs."""
        stale = self._create_job(status=JobState.STARTED)

        outcome = check_stale_jobs(dry_run=True)

        self._assert_single_revoked(outcome)
        stale.refresh_from_db()
        self.assertEqual(stale.status, JobState.STARTED.value)
        mock_cleanup.assert_not_called()

    def test_revokes_stale_job(self, mock_cleanup):
        """Stale job without a known Celery state is revoked and cleaned up."""
        stale = self._create_job(status=JobState.STARTED)

        outcome = check_stale_jobs()

        self._assert_single_revoked(outcome)
        stale.refresh_from_db()
        self.assertEqual(stale.status, JobState.REVOKED.value)
        self.assertIsNotNone(stale.finished_at)
        mock_cleanup.assert_called_once_with(stale)

    def test_skips_recent_and_final_state_jobs(self, mock_cleanup):
        """Recent jobs and jobs in final states are not touched."""
        self._create_job(status=JobState.STARTED, hours_ago=1)  # recent
        self._create_job(status=JobState.SUCCESS, hours_ago=200)  # final state

        outcome = check_stale_jobs()

        self.assertEqual(outcome, [])
        mock_cleanup.assert_not_called()

0 commit comments

Comments
 (0)