diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py index 29785c197..f3c24e532 100755 --- a/data/aligned_dataset.py +++ b/data/aligned_dataset.py @@ -39,6 +39,10 @@ def __getitem__(self, index): params = get_params(self.opt, A.size) if self.opt.label_nc == 0: transform_A = get_transform(self.opt, params) + if self.opt.input_nc == 3: + A = A.convert('RGB') + elif self.opt.input_nc == 1: + A = A.convert('L') A_tensor = transform_A(A.convert('RGB')) else: transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) @@ -47,8 +51,12 @@ def __getitem__(self, index): B_tensor = inst_tensor = feat_tensor = 0 ### input B (real images) if self.opt.isTrain or self.opt.use_encoded_image: - B_path = self.B_paths[index] - B = Image.open(B_path).convert('RGB') + B_path = self.B_paths[index] + B = Image.open(B_path) + if self.opt.output_nc == 3: + B = B.convert('RGB') + elif self.opt.output_nc == 1: + B = B.convert('L') transform_B = get_transform(self.opt, params) B_tensor = transform_B(B) @@ -73,4 +81,4 @@ def __len__(self): return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize def name(self): - return 'AlignedDataset' \ No newline at end of file + return 'AlignedDataset' diff --git a/data/base_dataset.py b/data/base_dataset.py index ece8813db..89e39a2c7 100755 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -34,7 +34,7 @@ def get_transform(opt, params, method=Image.BICUBIC, normalize=True): transform_list = [] if 'resize' in opt.resize_or_crop: osize = [opt.loadSize, opt.loadSize] - transform_list.append(transforms.Scale(osize, method)) + transform_list.append(transforms.Resize(osize, method)) elif 'scale_width' in opt.resize_or_crop: transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) @@ -53,8 +53,8 @@ def get_transform(opt, params, method=Image.BICUBIC, normalize=True): transform_list += [transforms.ToTensor()] if normalize: - transform_list += [transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] + transform_list += [transforms.Normalize((0.5,) * opt.output_nc, + (0.5,) * opt.output_nc)] return transforms.Compose(transform_list) def normalize(): diff --git a/models/networks.py b/models/networks.py index ee05d85d8..2e7ff8af4 100755 --- a/models/networks.py +++ b/models/networks.py @@ -316,7 +316,7 @@ def singleD_forward(self, model, input): else: return [model(input)] - def forward(self, input): + def forward(self, input): num_D = self.num_D result = [] input_downsampled = input @@ -407,6 +407,9 @@ def __init__(self, requires_grad=False): param.requires_grad = False def forward(self, X): + # vgg19 assumes 3 input channels. + if X.shape[1] == 1: + X = X.expand(-1, 3, -1, -1) h_relu1 = self.slice1(X) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index fafdec0b7..5de641718 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -151,7 +151,7 @@ def discriminate(self, input_label, test_image, use_pool=False): def forward(self, label, inst, image, feat, infer=False): # Encode Inputs - input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) + input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) # Fake Generation if self.use_features: @@ -166,7 +166,7 @@ def forward(self, label, inst, image, feat, infer=False): pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) loss_D_fake = self.criterionGAN(pred_fake_pool, False) - # Real Detection and Loss + # Real Detection and Loss pred_real = self.discriminate(input_label, real_image) loss_D_real = self.criterionGAN(pred_real, True) diff --git a/train.py b/train.py index acedac25b..8d2655b22 100755 --- a/train.py +++ b/train.py @@ -5,8 +5,8 @@ from torch.autograd import Variable from collections import OrderedDict from subprocess import call -import fractions -def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0 +import math +def lcm(a,b): return abs(a * b)/math.gcd(a,b) if a and b else 0 from options.train_options import TrainOptions from data.data_loader import CreateDataLoader