-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
19 lines (13 loc) · 790 Bytes
/
utils.py
File metadata and controls
19 lines (13 loc) · 790 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def get_loaders(train_img_dir, train_mak_dir, val_img_dir, val_mask_dir,
                train_transform, val_transform,
                batch_size, num_workers, pin_memory=True,
                *, shuffle_train=True, dataset_cls=None):
    """Build the training and validation DataLoaders.

    Args:
        train_img_dir: Directory of training images, passed to the dataset
            as ``image_dir``.
        train_mak_dir: Directory of training masks, passed to the dataset
            as ``mask_dir``. (The misspelled name is kept so existing
            keyword callers keep working.)
        val_img_dir: Directory of validation images.
        val_mask_dir: Directory of validation masks.
        train_transform: Transform handed to the training dataset.
        val_transform: Transform handed to the validation dataset.
        batch_size: Batch size used by both loaders.
        num_workers: Number of DataLoader worker processes for both loaders.
        pin_memory: Forwarded unchanged to both DataLoaders.
        shuffle_train: If True (default), the training loader reshuffles its
            samples every epoch. The validation loader is never shuffled, so
            evaluation order is deterministic.
        dataset_cls: Callable accepting ``image_dir=``, ``mask_dir=`` and
            ``transform=`` keywords and returning a map-style dataset.
            Defaults to ``CarvanaDataset``.

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
    """
    if dataset_cls is None:
        dataset_cls = CarvanaDataset

    train_ds = dataset_cls(image_dir=train_img_dir, mask_dir=train_mak_dir,
                           transform=train_transform)
    val_ds = dataset_cls(image_dir=val_img_dir, mask_dir=val_mask_dir,
                         transform=val_transform)

    # Without shuffling, every training epoch would visit samples in the
    # dataset's fixed index order, producing correlated mini-batches.
    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=shuffle_train,
                              num_workers=num_workers,
                              pin_memory=pin_memory)
    # Validation order does not affect metrics; keep it deterministic.
    val_loader = DataLoader(val_ds, batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=pin_memory)
    return train_loader, val_loader