This repository was archived by the owner on Jun 6, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: model.py
More file actions
200 lines (170 loc) · 8.4 KB
/
model.py
File metadata and controls
200 lines (170 loc) · 8.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import torch # Only to disable autograd
from datetime import datetime
import time as time
import pathlib
import pickle
from .others.src.sgd import SGD
from .others.src.Layers.convolution import Conv2d
from .others.src.Layers.upsampling import Upsampling
from .others.src.Layers.relu import ReLU
from .others.src.Layers.sigmoid import Sigmoid
from .others.src.sequential import Sequential
from .others.src.Loss_functions.mse import MSE
from .others.src.Layers.nearest_neighbor_upsample import NNUpsampling
from .others.src.utils import waiting_bar
torch.set_grad_enabled(False)
class Model():
    def __init__(self):
        """
        Instantiates the model class.

        Builds the encoder/decoder network, the SGD optimizer and the MSE
        loss, and prepares the training-log storage.
        :return: None
        """
        # float64 everywhere avoids precision problems, as well as conversions.
        torch.set_default_dtype(torch.float64)
        self.SGD = SGD(lr=3e-1)
        # Expose the layer classes on the instance for convenience.
        self.Conv2d = Conv2d
        self.ReLU = ReLU
        self.Sigmoid = Sigmoid
        self.NNUpsampling = NNUpsampling
        self.Upsampling = Upsampling
        # Encoder: two strided convolutions; decoder: two upsampling blocks,
        # with a final Sigmoid mapping the output into [0, 1].
        self.Sequential = Sequential(
            Conv2d(in_channels=3, out_channels=6, stride=2, padding=1, dilation=1, kernel_size=3), ReLU(),
            Conv2d(in_channels=6, out_channels=9, stride=2, padding=1, dilation=1, kernel_size=3), ReLU(),
            Upsampling(scale_factor=2, in_channels=9, out_channels=6, kernel_size=3, transposeconvargs=False), ReLU(),
            Upsampling(scale_factor=2, in_channels=6, out_channels=3, kernel_size=3, transposeconvargs=False), Sigmoid())
        self.MSE = MSE()
        # Evaluate (and log) the model every `eval_step` epochs.
        self.eval_step = 1
        # Directory containing this file; used to resolve model/log paths.
        self.path = str(pathlib.Path(__file__).parent.resolve())
        # Training logs:
        #   row 0: the epoch number
        #   row 1: the training error
        #   row 2: the validation error
        self.logs = [[], [], []]

    def load_pretrained_model(self):
        """
        Loads best model from file bestmodel.pth.

        NOTE(review): the checkpoint is deserialized with pickle; only load
        trusted files, since unpickling can execute arbitrary code.
        :return: None
        """
        with open(self.path + "/bestmodel.pth", "rb") as fp:  # Unpickling
            params_gradient = pickle.load(fp)
        # The file stores (parameter, gradient) pairs per layer; keep only
        # the parameters of the best model.
        best_params = []
        for layer_param in params_gradient:
            # Check whether the layer has parameters at all
            if layer_param:
                # Keep the parameter values, drop the stored gradients.
                best_params.append([param for param, _gradient in layer_param])
            else:
                # Parameter-free layer (e.g. an activation): empty list.
                best_params.append([])
        # Assign the best parameters to the network.
        self.Sequential.update_param(best_params)

    def train(self, train_input, train_target, num_epochs=20, batch_size=32, validation=0.2):
        """
        Trains the model with mini-batch SGD.

        Every `self.eval_step` epochs the train/validation errors are logged;
        at the end the parameters and the logs are pickled to disk.

        :param train_input: Training data (N x C x H x W, values in [0, 255]).
        :param train_target: Training targets (same shape and range).
        :param num_epochs: Number of epochs to train for.
        :param batch_size: Mini-batch size.
        :param validation: Fraction of the data held out for validation.
        :return: None
        """
        # Custom train/validation split - start by shuffling.
        idx = torch.randperm(train_input.size()[0])
        train_input = train_input[idx, :, :, :].to(torch.float64)
        train_target = train_target[idx, :, :, :].to(torch.float64)
        # Take the first images (w.r.t. the requested proportion) as the
        # validation set; rescale everything from [0, 255] to [0, 1].
        split = int(validation * train_input.size(0))
        val_input = train_input[0:split] / 255
        val_target = train_target[0:split] / 255
        # Bug fix: the original used [split:-1], which silently dropped the
        # last training sample; [split:] keeps it.
        train_input = train_input[split:] / 255
        train_target = train_target[split:] / 255
        nb_images_train = len(train_input)
        nb_images_val = len(val_input)
        # Monitor the time taken by the whole training run.
        start = time.time()
        print("Training started!")
        # The loop on the epochs
        for epoch in range(num_epochs):
            # Reshuffle the training data at every epoch.
            idx = torch.randperm(nb_images_train)
            for train_img, target_img in zip(torch.split(train_input[idx], batch_size),
                                             torch.split(train_target[idx], batch_size)):
                # Compute the predictions from the model.
                output = self.Sequential(train_img)
                # Compute the loss and its gradient w.r.t. the network output.
                loss = self.MSE.forward(output, target_img)
                loss_grad = self.MSE.backward()
                # Zero the accumulated gradients, then backpropagate.
                self.Sequential.zero_grad()
                self.Sequential.backward(loss_grad)
                # One SGD step, then write the new parameters back.
                updated_params = self.SGD.step(self.Sequential.param())
                self.Sequential.update_param(updated_params)
            # Evaluate the model every eval_step epochs.
            if (epoch + 1) % self.eval_step == 0:
                eva_batch_size = 1000
                train_error = 0.
                val_error = 0.
                # Number of evaluation batches (ceiling division), used to
                # average the per-batch errors.
                nb_split_train = (nb_images_train + eva_batch_size - 1) // eva_batch_size
                nb_split_val = (nb_images_val + eva_batch_size - 1) // eva_batch_size
                train_zip = zip(torch.split(train_input, eva_batch_size),
                                torch.split(train_target, eva_batch_size))
                val_zip = zip(torch.split(val_input, eva_batch_size),
                              torch.split(val_target, eva_batch_size))
                for train_img, target_img in train_zip:
                    train_error += self.MSE.forward(self.Sequential(train_img), target_img)
                for val_img, val_img_target in val_zip:
                    val_error += self.MSE.forward(self.Sequential(val_img), val_img_target)
                self.logs[0].append(epoch + 1)
                self.logs[1].append(train_error / nb_split_train)
                self.logs[2].append(val_error / nb_split_val)
                waiting_bar(i=epoch + 1, length=num_epochs, loss=(self.logs[1][-1], self.logs[2][-1]))
        # Save the model - the file name encodes lr, batch size and date.
        date = datetime.now().strftime("%d%m%Y_%H%M%S")
        path = str(self.SGD.lr) + "_" + str(batch_size) + "_" + date + ".pth"
        with open(self.path + "/others/outputs/trained_models/" + path, "wb") as fp:
            pickle.dump(self.Sequential.param(), fp)
        # Save the logs as well.
        self.logs = torch.tensor(self.logs)
        with open(self.path + "/others/outputs/logs/" + path, "wb") as fp:
            pickle.dump(self.logs, fp)
        # Record and print the elapsed time (renamed from `min` to avoid
        # shadowing the builtin).
        elapsed = time.time() - start
        print("\nTime taken for training: {:.0f} min {:.0f} s".format(elapsed // 60, elapsed % 60))
        del train_input, train_target

    def predict(self, test_input):
        """
        Predicts with the model on the provided input.

        :param test_input: Test input (values in [0, 255]).
        :return: The prediction, rescaled to [0, 255] (torch.Tensor).
        """
        # The network expects inputs standardized to [0, 1].
        out = self.Sequential(test_input.double() / 255.0)
        # Rescale the output between 0 and 255.
        return out.double() * 255

    def psnr(self, denoised, ground_truth):
        """
        Computes the Peak Signal-to-Noise Ratio of a denoised image compared to the ground truth.

        :param denoised: Denoised image. Must be in range [0, 1].
        :param ground_truth: Ground truth image. Must be in range [0, 1].
        :return: PSNR (0-dimensional torch.Tensor)
        """
        # (The original repeated this assert twice; once is enough.)
        assert denoised.shape == ground_truth.shape, "Denoised image and ground truth must have the same shape!"
        # Per-image MSE over (C, H, W), then the PSNR averaged over the batch.
        return - 10 * torch.log10(((denoised - ground_truth) ** 2).mean((1, 2, 3))).mean()