Skip to content

Commit dec0e19

Browse files
committed
Merge #1130: PSv2 NATS connection pooling & retries
2 parents a8c5e47 + fa0f84b commit dec0e19

File tree

10 files changed

+929
-284
lines changed

10 files changed

+929
-284
lines changed

.agents/AGENTS.md

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,29 +107,34 @@ docker compose restart django celeryworker
107107

108108
### Backend (Django)
109109

110-
Run tests:
110+
Run tests (use `docker-compose.ci.yml` to avoid conflicts with the local dev stack):
111+
```bash
112+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test
113+
```
114+
115+
Run a specific test module:
111116
```bash
112-
docker compose run --rm django python manage.py test
117+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.ml.orchestration.tests.test_nats_connection
113118
```
114119

115120
Run specific test pattern:
116121
```bash
117-
docker compose run --rm django python manage.py test -k pattern
122+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern
118123
```
119124

120125
Run tests with debugger on failure:
121126
```bash
122-
docker compose run --rm django python manage.py test -k pattern --failfast --pdb
127+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern --failfast --pdb
123128
```
124129

125130
Speed up test development (reuse database):
126131
```bash
127-
docker compose run --rm django python manage.py test --keepdb
132+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test --keepdb
128133
```
129134

130135
Run pytest (alternative test runner):
131136
```bash
132-
docker compose run --rm django pytest --ds=config.settings.test --reuse-db
137+
docker compose -f docker-compose.ci.yml run --rm django pytest --ds=config.settings.test --reuse-db
133138
```
134139

135140
Django shell:
@@ -654,13 +659,13 @@ images = SourceImage.objects.annotate(det_count=Count('detections'))
654659

655660
```bash
656661
# Run specific test class
657-
docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase
662+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase
658663

659664
# Run specific test method
660-
docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation
665+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation
661666

662667
# Run with pattern matching
663-
docker compose run --rm django python manage.py test -k test_detection
668+
docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k test_detection
664669
```
665670

666671
### Pre-commit Hooks

ami/jobs/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
135135
try:
136136

137137
async def ack_task():
138-
async with TaskQueueManager() as manager:
139-
return await manager.acknowledge_task(reply_subject)
138+
manager = TaskQueueManager()
139+
return await manager.acknowledge_task(reply_subject)
140140

141141
ack_success = async_to_sync(ack_task)()
142142

ami/jobs/test_tasks.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,8 @@ def tearDown(self):
7373

7474
def _setup_mock_nats(self, mock_manager_class):
7575
"""Helper to setup mock NATS manager."""
76-
mock_manager = AsyncMock()
76+
mock_manager = mock_manager_class.return_value
7777
mock_manager.acknowledge_task = AsyncMock(return_value=True)
78-
mock_manager_class.return_value.__aenter__.return_value = mock_manager
79-
mock_manager_class.return_value.__aexit__.return_value = AsyncMock()
8078
return mock_manager
8179

8280
def _create_error_result(self, image_id: str | None = None, error_msg: str = "Processing failed") -> dict:

ami/jobs/views.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,11 @@ def tasks(self, request, pk=None):
242242

243243
async def get_tasks():
244244
tasks = []
245-
async with TaskQueueManager() as manager:
246-
for _ in range(batch):
247-
task = await manager.reserve_task(job.pk, timeout=0.1)
248-
if task:
249-
tasks.append(task.dict())
245+
manager = TaskQueueManager()
246+
for _ in range(batch):
247+
task = await manager.reserve_task(job.pk, timeout=0.1)
248+
if task:
249+
tasks.append(task.dict())
250250
return tasks
251251

252252
# Use async_to_sync to properly handle the async call

ami/ml/orchestration/jobs.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def cleanup_async_job_resources(job: "Job") -> bool:
4040

4141
# Cleanup NATS resources
4242
async def cleanup():
43-
async with TaskQueueManager() as manager:
44-
return await manager.cleanup_job_resources(job.pk)
43+
manager = TaskQueueManager()
44+
return await manager.cleanup_job_resources(job.pk)
4545

4646
try:
4747
nats_success = async_to_sync(cleanup)()
@@ -96,22 +96,22 @@ async def queue_all_images():
9696
successful_queues = 0
9797
failed_queues = 0
9898

99-
async with TaskQueueManager() as manager:
100-
for image_pk, task in tasks:
101-
try:
102-
logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
103-
success = await manager.publish_task(
104-
job_id=job.pk,
105-
data=task,
106-
)
107-
except Exception as e:
108-
logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
109-
success = False
110-
111-
if success:
112-
successful_queues += 1
113-
else:
114-
failed_queues += 1
99+
manager = TaskQueueManager()
100+
for image_pk, task in tasks:
101+
try:
102+
logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
103+
success = await manager.publish_task(
104+
job_id=job.pk,
105+
data=task,
106+
)
107+
except Exception as e:
108+
logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
109+
success = False
110+
111+
if success:
112+
successful_queues += 1
113+
else:
114+
failed_queues += 1
115115

116116
return successful_queues, failed_queues
117117

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""
2+
NATS connection management for both Celery workers and Django processes.
3+
4+
Provides a ConnectionPool keyed by event loop. The pool reuses a single NATS
5+
connection for all async operations *within* one async_to_sync() boundary.
6+
It does NOT provide reuse across separate async_to_sync() calls — each call
7+
creates a new event loop, so a new connection is established.
8+
9+
Where the pool helps:
10+
The main beneficiary is queue_images_to_nats() in jobs.py, which wraps
11+
1000+ publish_task() awaits in a single async_to_sync() call. All of those
12+
awaits share one event loop and therefore one NATS connection. Without the
13+
pool, each publish would open its own TCP connection (~1500 per job).
14+
Similarly, JobViewSet.tasks() batches multiple reserve_task() calls in one
15+
async_to_sync() boundary.
16+
17+
Where it doesn't help:
18+
Single-operation boundaries like _ack_task_via_nats() (one ACK per call)
19+
get no reuse — the pool is effectively single-use there. The overhead is
20+
negligible (one dict lookup), and the retry_on_connection_error decorator
21+
provides resilience regardless.
22+
23+
Why keyed by event loop:
24+
asyncio.Lock and nats.Client are bound to the loop they were created on.
25+
Sharing them across loops causes "attached to a different loop" errors.
26+
Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops
27+
are garbage collected, so short-lived loops don't leak.
28+
29+
Archived alternative:
30+
ContextManagerConnection preserves the original pre-pool implementation
31+
(one connection per `async with` block) as a drop-in fallback.
32+
"""
33+
34+
import asyncio
35+
import logging
36+
import threading
37+
from typing import TYPE_CHECKING
38+
from weakref import WeakKeyDictionary
39+
40+
import nats
41+
from django.conf import settings
42+
from nats.js import JetStreamContext
43+
44+
if TYPE_CHECKING:
45+
from nats.aio.client import Client as NATSClient
46+
47+
logger = logging.getLogger(__name__)
48+
49+
50+
class ConnectionPool:
    """
    Manages a single persistent NATS connection for one event loop.

    One instance is created per event loop by the module-level registry, so
    the asyncio.Lock and the NATS client held here are only ever used from the
    loop that created them. This avoids "attached to a different loop" errors.

    Lifecycle:
    - get_connection() lazily connects, and transparently reconnects when the
      cached client reports that it is closed or disconnected.
    - reset() drops the cached client after an operation hit a connection
      error, so the next get_connection() starts from scratch.
    - close() shuts the cached client down explicitly.

    A replaced or reset client is always closed (best-effort) before it is
    forgotten, so a client that is merely disconnected is not left running
    unreferenced.
    """

    def __init__(self):
        self._nc: "NATSClient | None" = None
        self._js: JetStreamContext | None = None
        self._lock: asyncio.Lock | None = None  # Lazy-initialized when needed

    def _ensure_lock(self) -> asyncio.Lock:
        """Lazily create the lock that serializes (re)connection attempts."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _is_healthy(self) -> bool:
        """Return True when a client and JetStream context are cached and the client is open and connected."""
        nc = self._nc
        return nc is not None and self._js is not None and not nc.is_closed and nc.is_connected

    @staticmethod
    async def _close_client(nc: "NATSClient", context: str) -> None:
        """
        Best-effort close of a client that is being discarded.

        Errors are logged at debug level and swallowed: the client is being
        thrown away precisely because it may already be broken.
        """
        try:
            if not nc.is_closed:
                await nc.close()
                logger.debug(f"Successfully closed existing NATS connection during {context}")
        except Exception as e:
            logger.debug(f"Error closing connection during {context} (expected): {e}")

    async def get_connection(self) -> tuple["NATSClient", JetStreamContext]:
        """
        Get or create the event loop's NATS connection. Checks connection health
        and recreates if stale.

        Returns:
            Tuple of (NATS connection, JetStream context)
        Raises:
            RuntimeError: If connection cannot be established
        """
        # Fast path (no lock needed): connection exists, is open, and is connected.
        if self._is_healthy():
            return self._nc, self._js

        # Slow path: serialize reconnection so concurrent callers share one attempt.
        async with self._ensure_lock():
            # Double-check after acquiring the lock (another coroutine may have reconnected).
            if self._is_healthy():
                return self._nc, self._js

            # Discard the stale client *inside* the lock and close it, rather than
            # just dropping the reference: a disconnected-but-not-closed client
            # would otherwise be abandoned without ever being shut down.
            stale = self._nc
            self._nc = None
            self._js = None
            if stale is not None:
                logger.warning("NATS connection is closed or disconnected, will reconnect")
                await self._close_client(stale, "reconnect")

            nats_url = settings.NATS_URL
            try:
                logger.info(f"Creating NATS connection to {nats_url}")
                nc = await nats.connect(nats_url)
                js = nc.jetstream()
            except Exception as e:
                logger.error(f"Failed to connect to NATS: {e}")
                raise RuntimeError(f"Could not establish NATS connection: {e}") from e

            # Publish both references together so the pool never holds a client
            # without its JetStream context.
            self._nc = nc
            self._js = js
            logger.info(f"Successfully connected to NATS at {nats_url}")
            return nc, js

    async def close(self):
        """Close the NATS connection if it exists."""
        # Clear the references first so no other coroutine picks up a client
        # that is in the middle of closing. Errors from close() still propagate.
        nc = self._nc
        self._nc = None
        self._js = None
        if nc is not None and not nc.is_closed:
            logger.info("Closing NATS connection")
            await nc.close()

    async def reset(self):
        """
        Close the current connection and clear the cached client so the next
        call to get_connection() creates a fresh one.

        Called by retry_on_connection_error when an operation hits a connection
        error (e.g. network blip, NATS restart).

        The lock is intentionally kept: this pool belongs to exactly one event
        loop, so the lock never needs rebinding, and replacing it while another
        coroutine holds it would let two reconnection attempts run concurrently.
        """
        logger.warning("Resetting NATS connection pool due to connection error")
        nc = self._nc
        self._nc = None
        self._js = None
        if nc is not None:
            await self._close_client(nc, "reset")
143+
144+
145+
class ContextManagerConnection:
    """
    Archived pre-pool implementation: one NATS connection per `async with` block.

    Every get_connection() call opens a brand-new connection and hands it to
    the caller; nothing is cached here, so close() and reset() have nothing to
    act on. Closing the returned connection is the caller's responsibility.

    Trade-offs vs ConnectionPool (as recorded by the original author):
    - Simpler: no shared state, no locking, no event-loop keying
    - Expensive: ~1500 TCP connections per 1000-image job vs 1 with the pool
    - No automatic reconnection — caller must handle connection failures

    Kept as a fallback. To switch, change the class instantiated in
    _get_pool() below from ConnectionPool to ContextManagerConnection.

    NOTE(review): call sites in this change no longer wrap TaskQueueManager in
    `async with` — confirm something still closes the connections returned
    here before switching back to this class.
    """

    async def get_connection(self) -> tuple["NATSClient", JetStreamContext]:
        """
        Open a fresh NATS connection and derive its JetStream context.

        Returns:
            Tuple of (NATS connection, JetStream context)
        Raises:
            RuntimeError: If the connection cannot be established
        """
        url = settings.NATS_URL
        try:
            logger.debug(f"Creating per-operation NATS connection to {url}")
            client = await nats.connect(url)
            return client, client.jetstream()
        except Exception as exc:
            logger.error(f"Failed to connect to NATS: {exc}")
            raise RuntimeError(f"Could not establish NATS connection: {exc}") from exc

    async def close(self):
        """Do nothing: this class keeps no reference to the connections it creates."""

    async def reset(self):
        """Do nothing: this class keeps no reference to the connections it creates."""
181+
182+
183+
# Registry of pools, keyed by event loop: each loop gets its own ConnectionPool.
# The WeakKeyDictionary drops an entry once its event loop is garbage collected.
_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary()
# Guards lookups and insertions in _pools.
_pools_lock = threading.Lock()


def _get_pool() -> ConnectionPool:
    """
    Return the ConnectionPool belonging to the running event loop, creating
    and registering one on first use.

    Raises:
        RuntimeError: If there is no running event loop.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Replace asyncio's terse error with guidance for sync callers.
        raise RuntimeError(
            "get_connection() must be called from an async context with a running event loop. "
            "If calling from sync code, use async_to_sync() to wrap the async function."
        ) from None

    with _pools_lock:
        pool = _pools.get(running_loop)
        if pool is None:
            pool = ConnectionPool()
            _pools[running_loop] = pool
            logger.debug(f"Created NATS connection pool for event loop {id(running_loop)}")
        return pool
204+
205+
206+
async def get_connection() -> tuple["NATSClient", JetStreamContext]:
    """
    Return the NATS connection for the running event loop, connecting first
    if that loop's pool has no usable connection yet.

    Returns:
        Tuple of (NATS connection, JetStream context)
    Raises:
        RuntimeError: If called outside of an async context (no running event
            loop), or if the connection cannot be established
    """
    return await _get_pool().get_connection()
217+
218+
219+
async def reset_connection() -> None:
    """
    Discard the running event loop's NATS connection.

    Delegates to the loop's pool, which closes the current connection and
    clears its cached state so the next get_connection() call opens a new one.

    Raises:
        RuntimeError: If called outside of an async context (no running event loop)
    """
    await _get_pool().reset()

0 commit comments

Comments
 (0)