-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
49 lines (40 loc) · 1.43 KB
/
utils.py
File metadata and controls
49 lines (40 loc) · 1.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""This module implements functions used in multiple scripts."""
from typing import Optional

import torch
from survae.distributions import (ConditionalMeanStdNormal, ConditionalNormal,
                                  StandardNormal)
from survae.flows import Flow
from survae.transforms import VAE

from model import Decoder, Encoder
def load_model(input_dim: tuple, latent_dim: int,
               checkpoint: Optional[str] = None,
               device: str = "cpu") -> Flow:
    """Build the VAE flow model and optionally load saved weights.

    This function initializes the Encoder and Decoder networks, wraps them in
    ``ConditionalNormal`` / ``ConditionalMeanStdNormal`` respectively, and
    combines them with a ``StandardNormal`` prior over the latent space into a
    single ``Flow`` model, which is moved to ``device``. If ``checkpoint`` is
    given, the state dict stored at that path is loaded into the model.

    Args:
        - input_dim (tuple) : size of inputs (CxHxW). Only the channel count
          ``input_dim[0]`` is passed to the Encoder/Decoder; the full tuple is
          passed as the second argument of ``ConditionalMeanStdNormal``
          (presumably the shape of its learned std — confirm against survae).
        - latent_dim (int) : dimension of the latent space
        - checkpoint (str, optional) : path to a saved state dict. If None,
          the model keeps its freshly initialized weights.
        - device (str) : device the model is moved to; checkpoint tensors are
          also mapped onto this device when loading.

    Returns:
        Flow: the assembled model, placed on ``device``.
    """
    encoder = ConditionalNormal(
        Encoder(input_dim=input_dim[0], latent_dim=latent_dim)
    )
    decoder = ConditionalMeanStdNormal(
        Decoder(output_dim=input_dim[0], latent_dim=latent_dim),
        input_dim,
    )
    model = Flow(
        base_dist=StandardNormal((latent_dim,)),
        transforms=[VAE(encoder=encoder, decoder=decoder)],
    ).to(device)

    if checkpoint is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True once the supported torch version allows it.
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        state_dict = torch.load(checkpoint, map_location=torch.device(device))
        model.load_state_dict(state_dict)
    return model