Skip to content

Commit aa3ef62

Browse files
committed
Merge branch 'dev'
2 parents 8bcebd6 + 6be02ad commit aa3ef62

3 files changed

Lines changed: 11 additions & 8 deletions

File tree

docs/history.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22
2025.12.23 Add code of PS-Seg
33

44
To do:
5-
(2025.8.25) data augmentation in Crop4vox2vec should be implemented in external transforms
5+
(2025.8.25) data augmentation in Crop4vox2vec should be implemented in external transforms
6+
2026.4.28 In windows, the multiple process has some problem --> use a new pytorch version
7+
dependency should be added in requirement: imops, monai, ml_collections, einops, timm

pymic/net_run/agent_abstract.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ def seed_torch(seed=1):
2525
torch.backends.cudnn.benchmark = False
2626
torch.backends.cudnn.deterministic = True
2727

28+
def worker_init_fn(worker_id):
29+
# workder_seed = self.random_seed+worker_id
30+
workder_seed = torch.initial_seed() % 2 ** 32
31+
np.random.seed(workder_seed)
32+
random.seed(workder_seed)
33+
2834
class NetRunAgent(object):
2935
"""
3036
The abstract class for medical image segmentation.
@@ -273,12 +279,7 @@ def create_dataset(self):
273279
self.valid_set = self.get_stage_dataset_from_config('valid')
274280
else:
275281
logging.warning("Dataset for validation is not created, as valid_dir is not provided.")
276-
if(self.deterministic):
277-
def worker_init_fn(worker_id):
278-
# workder_seed = self.random_seed+worker_id
279-
workder_seed = torch.initial_seed() % 2 ** 32
280-
np.random.seed(workder_seed)
281-
random.seed(workder_seed)
282+
if(self.deterministic):
282283
worker_init = worker_init_fn
283284
else:
284285
worker_init = None

pymic/net_run/agent_seg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def test_time_dropout(m):
481481

482482
# load network parameters and set the network as evaluation mode
483483
print("ckpt name", ckpt_name)
484-
checkpoint = torch.load(ckpt_name, map_location = device)
484+
checkpoint = torch.load(ckpt_name, map_location = device, weights_only = False)
485485
self.net.load_state_dict(checkpoint['model_state_dict'])
486486

487487
if(self.inferer is None):

0 commit comments

Comments
 (0)