-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathload.py
More file actions
99 lines (92 loc) · 4.32 KB
/
Copy pathload.py
File metadata and controls
99 lines (92 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import time
import os
from os import listdir
from os.path import join, isfile, isdir, expanduser
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import functional as F
from redunet import *
def load_architecture(data, arch, seed=0):
    """Build the network registered under a dataset key and architecture name.

    Args:
        data: Dataset key. ``'mnist2d'`` and ``'mnistvector'`` build 10-class
            models; ``'mnist2d+2class'`` and ``'mnistvector_2class'`` build
            2-class models. ``'mnist2d_2class'`` is also accepted as an alias
            of ``'mnist2d+2class'`` so the key used by ``load_dataset`` works
            here too.
        arch: Architecture name. For the 2D keys one of
            ``'lift2d_channels{35,55}_layers{5,10,20}'``; for the vector keys
            one of ``'layers50'``, ``'layers20'``, ``'layers10'``,
            ``'layers5'``.
        seed: Forwarded to ``lift2d`` only; ``flatten`` is called without it.

    Returns:
        Whatever ``architectures.mnist.lift2d.lift2d`` or
        ``architectures.mnist.flatten.flatten`` returns for the selected
        configuration.

    Raises:
        NameError: If the ``data``/``arch`` combination is not registered.
    """
    # Architecture name -> (channels, layers). The layer count always matches
    # the number in the name (the "layers10" entries previously built 5 layers).
    lift2d_configs = {
        'lift2d_channels35_layers5': (35, 5),
        'lift2d_channels35_layers10': (35, 10),
        'lift2d_channels35_layers20': (35, 20),
        'lift2d_channels55_layers5': (55, 5),
        'lift2d_channels55_layers10': (55, 10),
        'lift2d_channels55_layers20': (55, 20),
    }
    # Architecture name -> number of layers for the flattened-input models.
    flatten_configs = {
        'layers50': 50,
        'layers20': 20,
        'layers10': 10,
        'layers5': 5,
    }
    # Dataset key -> number of classes, per model family.
    lift2d_classes = {
        'mnist2d': 10,
        'mnist2d+2class': 2,
        'mnist2d_2class': 2,  # alias matching the key used by load_dataset
    }
    flatten_classes = {
        'mnistvector': 10,
        'mnistvector_2class': 2,
    }

    # Project modules are imported lazily so only the selected family is loaded.
    if data in lift2d_classes and arch in lift2d_configs:
        from architectures.mnist.lift2d import lift2d
        channels, layers = lift2d_configs[arch]
        return lift2d(channels=channels, layers=layers,
                      num_classes=lift2d_classes[data], seed=seed)
    if data in flatten_classes and arch in flatten_configs:
        from architectures.mnist.flatten import flatten
        return flatten(layers=flatten_configs[arch],
                       num_classes=flatten_classes[data])
    raise NameError(f'Cannot find architecture: {arch} for data: {data}.')
def load_dataset(choice, data_dir='./data/'):
    """Return the dataset registered under ``choice``.

    Args:
        choice: One of ``'mnist2d'``, ``'mnist2d_2class'`` or
            ``'mnistvector'``.
        data_dir: Passed unchanged to the selected loader in
            ``datasets.mnist``.

    Returns:
        The value returned by the selected ``datasets.mnist`` loader.

    Raises:
        NameError: If ``choice`` is not one of the supported keys.
    """
    # Pick the loader first (importing it lazily), then call it once below.
    if choice == 'mnist2d':
        from datasets.mnist import mnist2d_10class as build_dataset
    elif choice == 'mnist2d_2class':
        from datasets.mnist import mnist2d_2class as build_dataset
    elif choice == 'mnistvector':
        from datasets.mnist import mnistvector_10class as build_dataset
    else:
        raise NameError(f'Dataset {choice} not found.')
    return build_dataset(data_dir)