-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgpu_manager.py
More file actions
69 lines (59 loc) · 2.22 KB
/
gpu_manager.py
File metadata and controls
69 lines (59 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gc
import os
import threading
import time
from typing import Callable, Optional
import torch
from loguru import logger
class GPUManager:
    """Hold a lazily loaded model and drop it after a period of inactivity.

    The model is created on demand by a caller-supplied loader and released
    once ``idle_timeout`` seconds have passed without a ``get_model`` call.
    All state changes happen under ``self.lock``.
    """

    def __init__(self, idle_timeout: int = 60):
        """Create an empty manager.

        Args:
            idle_timeout: Seconds of inactivity after which the model is
                released.
        """
        self.idle_timeout = idle_timeout
        self.last_used = time.time()
        self.model = None
        self.lock = threading.Lock()
        self.timer: Optional[threading.Timer] = None
        # Remembered from the most recent load; not read inside this class.
        self.load_func: Optional[Callable] = None

    def get_model(self, load_func: Callable):
        """Return the managed model, loading it first when none is held.

        Args:
            load_func: Zero-argument callable that builds and returns the
                model. Only invoked when no model is currently held.

        Returns:
            The held model object.
        """
        with self.lock:
            self.last_used = time.time()
            if self.model is None:
                logger.info("Loading model to GPU...")
                self.model = load_func()
                self.load_func = load_func
            # Every access pushes the idle deadline back by a full timeout.
            self._reset_timer()
            return self.model

    def _reset_timer(self, delay: Optional[float] = None):
        """Replace any pending idle timer with a fresh one.

        Args:
            delay: Seconds until the idle check runs; defaults to
                ``idle_timeout``.
        """
        if self.timer:
            self.timer.cancel()
        wait = self.idle_timeout if delay is None else delay
        self.timer = threading.Timer(wait, self._offload)
        # Daemon thread so a pending timer never keeps the process alive.
        self.timer.daemon = True
        self.timer.start()

    def _release_locked(self) -> None:
        """Drop the model and reclaim memory; the caller holds ``self.lock``."""
        self.model = None
        # Collect first so tensors kept alive only by reference cycles are
        # freed before the CUDA cache is emptied; the reverse order leaves
        # their memory cached.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _offload(self):
        """Timer callback: release the model if it has been idle long enough."""
        with self.lock:
            if self.model is None:
                # Nothing to release (never loaded, or already offloaded).
                return
            idle_for = time.time() - self.last_used
            if idle_for >= self.idle_timeout:
                logger.info("Offloading model from GPU...")
                self._release_locked()
                return
            # Fired early (a get_model call raced with this callback, or the
            # wall clock disagrees with the timer's clock). Re-arm for the
            # remaining time so the model is still offloaded eventually.
            # The min() guards against a wall clock that moved backwards.
            remaining = min(self.idle_timeout - idle_for, self.idle_timeout)
            self._reset_timer(remaining)

    def force_offload(self):
        """Release the model immediately, regardless of idle time."""
        with self.lock:
            # A pending idle check is pointless once the model is dropped.
            if self.timer:
                self.timer.cancel()
                self.timer = None
            if self.model is not None:
                logger.info("Force offloading model...")
                self._release_locked()

    def get_status(self) -> dict:
        """Report whether a model is held and, with CUDA, memory figures.

        Returns:
            With CUDA available: ``loaded``, ``gpu_memory_used`` and
            ``gpu_memory_total`` (formatted in GB, the total taken from
            device index 0) and ``last_used`` (seconds since the last
            ``get_model`` call or construction). Otherwise: ``loaded`` and
            ``device`` set to ``"cpu"``.
        """
        with self.lock:
            if torch.cuda.is_available():
                mem_used = torch.cuda.memory_allocated() / 1024**3
                mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                return {
                    "loaded": self.model is not None,
                    "gpu_memory_used": f"{mem_used:.2f}GB",
                    "gpu_memory_total": f"{mem_total:.2f}GB",
                    "last_used": time.time() - self.last_used,
                }
            return {"loaded": self.model is not None, "device": "cpu"}
# Module-level shared manager. The idle timeout (seconds) is read from the
# GPU_IDLE_TIMEOUT environment variable and falls back to 60 when unset.
gpu_manager = GPUManager(
    idle_timeout=int(os.environ.get("GPU_IDLE_TIMEOUT", "60")),
)