|
| 1 | +""" |
| 2 | +NATS connection management for both Celery workers and Django processes. |
| 3 | +
|
| 4 | +Provides a ConnectionPool keyed by event loop. The pool reuses a single NATS |
| 5 | +connection for all async operations *within* one async_to_sync() boundary. |
| 6 | +It does NOT provide reuse across separate async_to_sync() calls — each call |
| 7 | +creates a new event loop, so a new connection is established. |
| 8 | +
|
| 9 | +Where the pool helps: |
| 10 | + The main beneficiary is queue_images_to_nats() in jobs.py, which wraps |
| 11 | + 1000+ publish_task() awaits in a single async_to_sync() call. All of those |
| 12 | + awaits share one event loop and therefore one NATS connection. Without the |
| 13 | + pool, each publish would open its own TCP connection (~1500 per job). |
| 14 | + Similarly, JobViewSet.tasks() batches multiple reserve_task() calls in one |
| 15 | + async_to_sync() boundary. |
| 16 | +
|
| 17 | +Where it doesn't help: |
| 18 | + Single-operation boundaries like _ack_task_via_nats() (one ACK per call) |
| 19 | + get no reuse — the pool is effectively single-use there. The overhead is |
| 20 | + negligible (one dict lookup), and the retry_on_connection_error decorator |
| 21 | + provides resilience regardless. |
| 22 | +
|
| 23 | +Why keyed by event loop: |
| 24 | + asyncio.Lock and nats.Client are bound to the loop they were created on. |
| 25 | + Sharing them across loops causes "attached to a different loop" errors. |
| 26 | + Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops |
| 27 | + are garbage collected, so short-lived loops don't leak. |
| 28 | +
|
| 29 | +Archived alternative: |
| 30 | + ContextManagerConnection preserves the original pre-pool implementation |
| 31 | + (one connection per `async with` block) as a drop-in fallback. |
| 32 | +""" |
| 33 | + |
| 34 | +import asyncio |
| 35 | +import logging |
| 36 | +import threading |
| 37 | +from typing import TYPE_CHECKING |
| 38 | +from weakref import WeakKeyDictionary |
| 39 | + |
| 40 | +import nats |
| 41 | +from django.conf import settings |
| 42 | +from nats.js import JetStreamContext |
| 43 | + |
| 44 | +if TYPE_CHECKING: |
| 45 | + from nats.aio.client import Client as NATSClient |
| 46 | + |
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
| 48 | + |
| 49 | + |
class ConnectionPool:
    """
    Manages a single persistent NATS connection per event loop.

    This is safe because:
    - asyncio.Lock and NATS Client are bound to the event loop they were created on
    - Each event loop gets its own isolated connection and lock
    - Works correctly with async_to_sync() which creates per-thread event loops
    - Prevents "attached to a different loop" errors in Celery tasks and Django views

    Instantiating TaskQueueManager() is cheap — multiple instances share the same
    underlying connection via this pool.
    """

    def __init__(self):
        # All state is created lazily on first use so constructing a pool is free.
        self._nc: "NATSClient | None" = None
        self._js: JetStreamContext | None = None
        self._lock: asyncio.Lock | None = None  # Lazy-initialized when needed

    def _ensure_lock(self) -> asyncio.Lock:
        """Lazily create lock bound to current event loop."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _is_healthy(self) -> bool:
        """Return True if a usable (open and connected) connection is cached."""
        return (
            self._nc is not None
            and self._js is not None
            and not self._nc.is_closed
            and self._nc.is_connected
        )

    async def get_connection(self) -> tuple["NATSClient", JetStreamContext]:
        """
        Get or create the event loop's NATS connection. Checks connection health
        and recreates if stale.

        Returns:
            Tuple of (NATS connection, JetStream context)
        Raises:
            RuntimeError: If connection cannot be established
        """
        # Fast path (no lock needed): connection exists, is open, and is connected.
        # This is the hot path — most calls hit this and return immediately.
        if self._is_healthy():
            return self._nc, self._js

        # Slow path: serialize (re)connection so concurrent coroutines on this
        # loop don't each open their own connection.
        lock = self._ensure_lock()
        async with lock:
            # Double-check after acquiring lock (another coroutine may have reconnected)
            if self._is_healthy():
                return self._nc, self._js

            # Stale connection: close it best-effort before dropping the reference.
            # Previously the reference was simply cleared, leaking the old client's
            # socket and background reader/reconnect tasks. Doing this inside the
            # lock also prevents two coroutines from racing on the teardown.
            if self._nc is not None:
                logger.warning("NATS connection is closed or disconnected, will reconnect")
                try:
                    if not self._nc.is_closed:
                        await self._nc.close()
                except Exception as e:
                    # Connection is already broken — closing is best-effort only.
                    logger.debug(f"Error closing stale NATS connection: {e}")
                self._nc = None
                self._js = None

            nats_url = settings.NATS_URL
            try:
                logger.info(f"Creating NATS connection to {nats_url}")
                self._nc = await nats.connect(nats_url)
                self._js = self._nc.jetstream()
                logger.info(f"Successfully connected to NATS at {nats_url}")
                return self._nc, self._js
            except Exception as e:
                # Don't leave a half-initialized client behind on failure; the
                # next call must start from a clean slate.
                self._nc = None
                self._js = None
                logger.error(f"Failed to connect to NATS: {e}")
                raise RuntimeError(f"Could not establish NATS connection: {e}") from e

    async def close(self):
        """Close the NATS connection if it exists."""
        if self._nc is not None and not self._nc.is_closed:
            logger.info("Closing NATS connection")
            await self._nc.close()
        self._nc = None
        self._js = None

    async def reset(self):
        """
        Close the current connection and clear all state so the next call to
        get_connection() creates a fresh one.

        Called by retry_on_connection_error when an operation hits a connection
        error (e.g. network blip, NATS restart). The lock is also cleared so it
        gets recreated bound to the current event loop.
        """
        logger.warning("Resetting NATS connection pool due to connection error")
        if self._nc is not None:
            try:
                if not self._nc.is_closed:
                    await self._nc.close()
                    logger.debug("Successfully closed existing NATS connection during reset")
            except Exception as e:
                # Swallow errors - connection may already be broken
                logger.debug(f"Error closing connection during reset (expected): {e}")
        self._nc = None
        self._js = None
        self._lock = None  # Clear lock so new one is created for fresh connection
| 143 | + |
| 144 | + |
class ContextManagerConnection:
    """
    Archived pre-pool strategy: a brand-new NATS connection per `async with` block.

    This preserves the implementation that predates ConnectionPool. Every call to
    get_connection() dials NATS from scratch, and the caller is responsible for
    closing what it receives; nothing is cached, retried, or shared at this layer.

    Trade-offs vs ConnectionPool:
    - Simpler: no shared state, no locking, no event-loop keying
    - Expensive: ~1500 TCP connections per 1000-image job vs 1 with the pool
    - No automatic reconnection — caller must handle connection failures

    Kept as a drop-in fallback. To switch, change the class used in
    _create_pool() below from ConnectionPool to ContextManagerConnection.
    """

    async def get_connection(self) -> tuple["NATSClient", JetStreamContext]:
        """Dial NATS and return a fresh (connection, JetStream context) pair."""
        nats_url = settings.NATS_URL
        try:
            logger.debug(f"Creating per-operation NATS connection to {nats_url}")
            connection = await nats.connect(nats_url)
            return connection, connection.jetstream()
        except Exception as e:
            logger.error(f"Failed to connect to NATS: {e}")
            raise RuntimeError(f"Could not establish NATS connection: {e}") from e

    async def close(self):
        """No-op: this strategy never retains a connection to close."""

    async def reset(self):
        """No-op: there is no pooled state to reset."""
| 181 | + |
| 182 | + |
# Event-loop-keyed pools: one ConnectionPool per event loop.
# WeakKeyDictionary automatically cleans up when event loops are garbage collected,
# so short-lived loops (e.g. the per-call loops async_to_sync creates) don't leak.
_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary()
# Guards _pools itself across threads: each async_to_sync() loop runs in its own
# thread, so concurrent _get_pool() calls can race on the dictionary.
_pools_lock = threading.Lock()
| 187 | + |
| 188 | + |
| 189 | +def _get_pool() -> ConnectionPool: |
| 190 | + """Get or create the ConnectionPool for the current event loop.""" |
| 191 | + try: |
| 192 | + loop = asyncio.get_running_loop() |
| 193 | + except RuntimeError: |
| 194 | + raise RuntimeError( |
| 195 | + "get_connection() must be called from an async context with a running event loop. " |
| 196 | + "If calling from sync code, use async_to_sync() to wrap the async function." |
| 197 | + ) from None |
| 198 | + |
| 199 | + with _pools_lock: |
| 200 | + if loop not in _pools: |
| 201 | + _pools[loop] = ConnectionPool() |
| 202 | + logger.debug(f"Created NATS connection pool for event loop {id(loop)}") |
| 203 | + return _pools[loop] |
| 204 | + |
| 205 | + |
async def get_connection() -> tuple["NATSClient", JetStreamContext]:
    """
    Return the NATS connection belonging to the current event loop, dialing one
    if none exists yet.

    Returns:
        Tuple of (NATS connection, JetStream context)
    Raises:
        RuntimeError: If called outside of an async context (no running event loop)
    """
    return await _get_pool().get_connection()
| 217 | + |
| 218 | + |
async def reset_connection() -> None:
    """
    Tear down the current event loop's NATS connection.

    After this, the pool holds no state for this loop, so the next call to
    get_connection() establishes a fresh connection.
    """
    await _get_pool().reset()
0 commit comments