run_linen_experiments.py

"""This script runs experiments training various recurrent memory models
on different datasets using Flax Linen. It serves as a reference implementation
for training and evaluating memax modules."""
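
# Example invocation (hyperparameters left unset fall back to the per-dataset
# defaults defined below):
#   python run_linen_experiments.py --dataset-name sequential_mnist --use-wandb
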
import argparse
from functools import partial

import jax
import jax.numpy as jnp
import optax
import tqdm
import wandb
from memax.datasets.mnist_math import get_dataset as get_mnist_math
from memax.datasets.sequential_mnist import get_dataset as get_sequential_mnist
from memax.linen.train_utils import (
get_residual_memory_models,
loss_classify_terminal_output,
update_model,
)


def parse_args():
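    """Parse command-line flags; hyperparameters left unset remain None."""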
parser = argparse.ArgumentParser(description="Train recurrent memory models.")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
parser.add_argument("--batch-size", type=int, help="Batch size")
parser.add_argument(
"--recurrent-size", type=int, help="Recurrent size of the model"
)
parser.add_argument("--num-layers", type=int, help="Number of layers in the model")
parser.add_argument("--lr", type=float, help="Learning rate")
parser.add_argument(
"--use-wandb",
action="store_true",
default=False,
help="Use Weights & Biases for logging",
)
parser.add_argument(
"--project-name",
type=str,
default="memax-debug",
help="Weights & Biases project name",
)
parser.add_argument(
"--dataset_name",
type=str,
default="sequential_mnist",
help="Dataset name (e.g., mnist_math, other_dataset)",
)
parser.add_argument(
"--loss-function",
type=str,
default="loss_classify_terminal_output",
help="Loss function to use (e.g., loss_classify_terminal_output, other_loss_fn)",
)
parser.add_argument("--models", type=str, nargs="+")
return parser.parse_args()


def get_default_hyperparameters(dataset_name):
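    """Return the default hyperparameters for a supported dataset."""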
defaults = {
"mnist_math_5": {
"num_epochs": 5,
"batch_size": 16,
"recurrent_size": 256,
"num_layers": 2,
"lr": 0.0001,
},
"sequential_mnist": {
"num_epochs": 5,
"batch_size": 16,
"recurrent_size": 256,
"num_layers": 2,
"lr": 0.0001,
},
# Add more datasets and their default hyperparameters here
}
if dataset_name in defaults:
return defaults[dataset_name]
else:
raise ValueError(
f"No default hyperparameters defined for dataset: {dataset_name}"
)


def update_config_with_defaults(args):
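    """Fill any unset (None) hyperparameter args in place from the dataset defaults."""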
defaults = get_default_hyperparameters(args.dataset_name)
for key, value in defaults.items():
if getattr(args, key) is None:
setattr(args, key, value)


def run_test(config, name, model, dataset, loss_fn):
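    """Train one model on the dataset, logging per-batch metrics (and to wandb)."""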
if config.use_wandb:
wandb.init(project=config.project_name, name=name)
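    # Constant learning rate; optax.zero_nans() replaces NaN gradients with
    # zeros before AdamW applies the update.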
lr_schedule = optax.constant_schedule(config.lr)
opt = optax.chain(
optax.zero_nans(),
optax.adamw(lr_schedule),
)
key = jax.random.PRNGKey(config.seed)
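    # Trace the model once with a dummy input sequence, episode-start flags,
    # and an empty carry to initialise the Linen parameters.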
dummy_x = dataset["x_train"][0]
dummy_starts = jnp.zeros(dummy_x.shape[0], dtype=bool)
dummy_h = model.zero_carry()
params = model.init(key, dummy_h, (dummy_x, dummy_starts))
opt_state = opt.init(params)
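    # Expose the module's carry initialiser and forward pass as plain
    # functions, then bind them into the caller-supplied loss.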
initialise_carry_fn = partial(model.apply, method="initialize_carry")
model_apply_fn = model.apply
    loss_fn = partial(
        loss_fn,
        init_carry_fn=initialise_carry_fn,
        model_apply_fn=model_apply_fn,
    )
    # Compile the update step once and reuse it for every minibatch; the loss
    # function and optimizer are marked static since they are not arrays.
    jit_update = jax.jit(update_model, static_argnames=("loss_fn", "opt"))
    for epoch in range(config.num_epochs):
        # Reshuffle the training set at the start of each epoch.
        key, shuffle_key = jax.random.split(key)
        shuffle_idx = jax.random.permutation(shuffle_key, dataset["size"])
        x = dataset["x_train"][shuffle_idx]
        y = dataset["y_train"][shuffle_idx]
        pbar = tqdm.tqdm(range(x.shape[0] // config.batch_size))
        for update in pbar:
            key, subkey = jax.random.split(key)
            x_batch = x[update * config.batch_size : (update + 1) * config.batch_size]
            y_batch = y[update * config.batch_size : (update + 1) * config.batch_size]
            params, opt_state, metrics = jit_update(
                params, loss_fn, opt, opt_state, x_batch, y_batch, key=subkey
            )
mean_metrics = {k: jnp.mean(v).item() for k, v in metrics.items()}
pbar.set_description(
f"{name} epoch: {epoch}, "
+ ", ".join(f"{k}: {v:.4f}" for k, v in mean_metrics.items())
)
if config.use_wandb:
wandb.log({**mean_metrics, "epoch": epoch})
if config.use_wandb:
wandb.finish()


def main():
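    """Parse arguments, load the dataset and models, and train each model."""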
args = parse_args()
update_config_with_defaults(args)
# Dynamically load dataset
if args.dataset_name == "mnist_math_5":
dataset = get_mnist_math(5)
elif args.dataset_name == "sequential_mnist":
dataset = get_sequential_mnist()
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
# Dynamically select loss function
if args.loss_function == "loss_classify_terminal_output":
loss_fn = loss_classify_terminal_output
else:
raise ValueError(f"Unknown loss function: {args.loss_function}")
    models = get_residual_memory_models(
        hidden=args.recurrent_size,
        output=dataset["num_labels"],
        num_layers=args.num_layers,
    )
    # Honor --models if given; otherwise train every returned model.
    if args.models:
        unknown = set(args.models) - set(models)
        if unknown:
            raise ValueError(f"Unknown models: {sorted(unknown)}")
        models = {name: models[name] for name in args.models}
    for name, model in models.items():
        run_test(args, name, model, dataset, loss_fn)


if __name__ == "__main__":
main()