
Commit fe334a7

xingyousong authored and copybara-github committed
small fix
PiperOrigin-RevId: 691269779
1 parent 09c2052 commit fe334a7

File tree

4 files changed: +28 −21 lines changed

optformer/embed_then_regress/checkpointing.py

Lines changed: 4 additions & 2 deletions

@@ -21,7 +21,7 @@
 
 
 def get_checkpoint_manager(
-    workdir: epath.PathLike,
+    workdir: epath.PathLike, **options_kwargs
 ) -> orbax_checkpoint.CheckpointManager:
   """Sets up Orbax checkpointing."""
   # The keys in this dict should match the keys in `checkpointed_state`.

@@ -32,7 +32,9 @@ def get_checkpoint_manager(
   return orbax_checkpoint.CheckpointManager(
       checkpoint_dir,
       checkpointers=checkpointers,
-      options=orbax_checkpoint.CheckpointManagerOptions(create=True),
+      options=orbax_checkpoint.CheckpointManagerOptions(
+          create=True, **options_kwargs
+      ),
   )
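The new **options_kwargs passthrough lets callers tune Orbax retention and best-checkpoint selection without editing this helper. A minimal sketch of the resulting call, using the same keyword options that train.py passes below (the workdir path here is a placeholder):

# Sketch: extra keyword arguments are forwarded by get_checkpoint_manager
# into orbax_checkpoint.CheckpointManagerOptions.
from optformer.embed_then_regress import checkpointing as ckpt_lib

manager = ckpt_lib.get_checkpoint_manager(
    '/tmp/ckpts',  # placeholder workdir
    max_to_keep=5,  # keep at most 5 checkpoints on disk
    best_fn=lambda metrics: metrics['eval_loss'],  # rank checkpoints by eval loss
    best_mode='min',  # lower eval loss is better
)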

optformer/embed_then_regress/configs.py

Lines changed: 9 additions & 9 deletions

@@ -20,6 +20,7 @@
 from typing import Callable
 from flax import linen as nn
 import jax
+import jax.numpy as jnp
 import jaxtyping as jt
 import optax
 from optformer.embed_then_regress import icl_transformer

@@ -84,12 +85,11 @@ def create_model(
   ) -> icl_transformer.ICLTransformer:
 
     kwargs = dataclasses.asdict(self)
-    embedder_factory = embedder_config.create_embedder_factory()
-    std_transform_fn = self.create_std_transform_fn()
+    kwargs.pop('std_transform')
 
     return icl_transformer.ICLTransformer(
-        std_transform_fn=std_transform_fn,
-        embedder_factory=embedder_factory,
+        std_transform_fn=self.create_std_transform_fn(),
+        embedder_factory=embedder_config.create_embedder_factory(),
         **kwargs,
     )

@@ -98,17 +98,17 @@ def create_std_transform_fn(
   ) -> Callable[[jt.Float[jax.Array, '*A']], jt.Float[jax.Array, '*A']]:
     """Creates std transform function."""
     if self.std_transform == 'exp':
-      return jax.numpy.exp
+      return jnp.exp
     elif self.std_transform == 'exp10':
-      return lambda x: jax.numpy.exp(10.0 * x)
+      return lambda x: jnp.exp(10.0 * x)
     elif self.std_transform == 'softplus':
       return jax.nn.softplus
     elif self.std_transform == 'softplus10':
       return lambda x: jax.nn.softplus(10.0 * x)
     elif self.std_transform == 'abs':
-      return jax.numpy.abs
+      return jnp.abs
     elif self.std_transform == 'abs10':
-      return lambda x: jax.numpy.abs(10.0 * x)
+      return lambda x: jnp.abs(10.0 * x)
     elif self.std_transform == 'shifted_relu':
       return lambda x: jax.nn.relu(x + 1.0)
     elif self.std_transform == 'shifted_relu10':

@@ -131,7 +131,7 @@ class TrainingConfig:
   seed: int = 42
 
   validation_interval: int = 100
-  checkpoint_interval: int = 100
+  max_to_keep_ckpts: int = 5
   workdir = '../checkpoints'
 
   def create_optimizer(self) -> optax.GradientTransformation:
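Because dataclasses.asdict(self) also captures the raw std_transform string, create_model now pops it before splatting the remaining fields into the transformer, which receives only the constructed callable. A small self-contained illustration of that pattern (the _DemoConfig fields are invented for the example):

import dataclasses

@dataclasses.dataclass
class _DemoConfig:
  # Hypothetical stand-in for the real model config dataclass.
  d_model: int = 64
  std_transform: str = 'softplus'

  def model_kwargs(self) -> dict:
    kwargs = dataclasses.asdict(self)
    kwargs.pop('std_transform')  # the model takes a callable, not the raw string
    return kwargs

print(_DemoConfig().model_kwargs())  # {'d_model': 64}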

optformer/embed_then_regress/icl_transformer.py

Lines changed: 1 addition & 1 deletion

@@ -169,7 +169,7 @@ def __call__(
       out = layer(out, mask, deterministic, rng)
 
     mean, std = jnp.split(self.mean_logstd_head(out), 2, axis=-1)  # [B L 1]
-    std = self.std_transform_fn(self.std_transform)(std) + EPS
+    std = self.std_transform_fn(std) + EPS
 
     mean = jnp.squeeze(mean, axis=-1)
     std = jnp.squeeze(std, axis=-1)
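With the fix, std_transform_fn is applied directly to the log-std head output; it is already the callable returned by create_std_transform_fn, not a name to be resolved. A rough sketch of that call shape, using softplus as the chosen transform and an illustrative EPS:

import jax
import jax.numpy as jnp

EPS = 1e-7  # illustrative only; the real constant is defined in icl_transformer.py
std_transform_fn = jax.nn.softplus  # one of the transforms selectable in configs.py

raw_logstd = jnp.array([[-2.0], [0.5]])   # stand-in for the head output, shape [L, 1]
std = std_transform_fn(raw_logstd) + EPS  # strictly positive standard deviations
print(std)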

optformer/embed_then_regress/train.py

Lines changed: 14 additions & 9 deletions

@@ -194,7 +194,12 @@ def train(
   )
 
   # Set up checkpointing
-  checkpoint_manager = ckpt_lib.get_checkpoint_manager(train_config.workdir)
+  checkpoint_manager = ckpt_lib.get_checkpoint_manager(
+      train_config.workdir,
+      max_to_keep=train_config.max_to_keep_ckpts,
+      best_fn=lambda metrics: metrics['eval_loss'],
+      best_mode='min',
+  )
   # Restore if available.
   train_state = ckpt_lib.restore_train_state(
       train_config.workdir, init_train_state

@@ -206,20 +211,20 @@
   eff_step = int(unreplicate(train_state.step)) // grad_accum_steps
 
   while eff_step < train_config.max_steps:
-    if eff_step % train_config.checkpoint_interval == 0:
+    if eff_step % train_config.validation_interval == 0:
+      valid_agg_metrics = aggregate_metrics([
+          p_eval_step(train_state, next(valid_it))
+          for _ in range(grad_accum_steps)
+      ])
+      writer.write_scalars(eff_step, valid_agg_metrics)
+
       ckpt_train_state = unreplicate(train_state)
       checkpoint_manager.save(
           eff_step,
           items=dict(train_state=jax.tree.map(np.array, ckpt_train_state)),
+          metrics=valid_agg_metrics,
       )
 
-    if eff_step % train_config.validation_interval == 0:
-      all_valid_metrics = [
-          p_eval_step(train_state, next(valid_it))
-          for _ in range(grad_accum_steps)
-      ]
-      writer.write_scalars(eff_step, aggregate_metrics(all_valid_metrics))
-
     all_train_metrics = []
     for _ in range(grad_accum_steps):
      train_state, train_metrics = p_train_step(train_state, next(train_it))
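Checkpointing is now tied to the validation step: each save attaches the aggregated validation metrics, so the Orbax manager can rank checkpoints by eval_loss and prune down to max_to_keep_ckpts. A sketch of the options this configuration produces on the checkpointing side (the metric value and the commented-out manager/save calls are illustrative only):

from orbax import checkpoint as orbax_checkpoint

# Keep only the 5 checkpoints with the lowest eval_loss; older or worse ones are deleted.
options = orbax_checkpoint.CheckpointManagerOptions(
    create=True,
    max_to_keep=5,
    best_fn=lambda metrics: metrics['eval_loss'],
    best_mode='min',
)
# manager = orbax_checkpoint.CheckpointManager(checkpoint_dir, checkpointers=checkpointers, options=options)
# manager.save(eff_step, items=dict(train_state=...), metrics={'eval_loss': 0.42})  # toy value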
