Commit d8cbeb6

xingyousong authored and copybara-github committed

small fix

PiperOrigin-RevId: 691467279

1 parent 09c2052 · commit d8cbeb6

File tree

4 files changed: +31 additions, -28 deletions

optformer/embed_then_regress/checkpointing.py
optformer/embed_then_regress/configs.py
optformer/embed_then_regress/icl_transformer.py
optformer/embed_then_regress/train.py

optformer/embed_then_regress/checkpointing.py

Lines changed: 5 additions & 7 deletions
@@ -21,18 +21,16 @@


 def get_checkpoint_manager(
-    workdir: epath.PathLike,
+    workdir: epath.PathLike, **options_kwargs
 ) -> orbax_checkpoint.CheckpointManager:
   """Sets up Orbax checkpointing."""
-  # The keys in this dict should match the keys in `checkpointed_state`.
-  checkpointers = dict(
-      train_state=orbax_checkpoint.PyTreeCheckpointer(),
-  )
   checkpoint_dir = epath.Path(workdir) / 'checkpoints'
   return orbax_checkpoint.CheckpointManager(
       checkpoint_dir,
-      checkpointers=checkpointers,
-      options=orbax_checkpoint.CheckpointManagerOptions(create=True),
+      checkpointers={'train_state': orbax_checkpoint.PyTreeCheckpointer()},
+      options=orbax_checkpoint.CheckpointManagerOptions(
+          create=True, **options_kwargs
+      ),
   )

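The new **options_kwargs are forwarded verbatim into orbax_checkpoint.CheckpointManagerOptions, so retention and best-checkpoint behavior can now be set at the call site. A minimal usage sketch, mirroring the call this commit adds in train.py (the workdir path below is made up):

from optformer.embed_then_regress import checkpointing as ckpt_lib

# Hypothetical workdir; all keyword arguments land in CheckpointManagerOptions.
manager = ckpt_lib.get_checkpoint_manager(
    '/tmp/etr_workdir',
    max_to_keep=5,                                 # retain at most 5 checkpoints
    best_fn=lambda metrics: metrics['eval_loss'],  # rank saved steps by eval loss
    best_mode='min',                               # lower eval loss is better
)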

optformer/embed_then_regress/configs.py

Lines changed: 9 additions & 10 deletions
@@ -20,6 +20,7 @@
 from typing import Callable
 from flax import linen as nn
 import jax
+import jax.numpy as jnp
 import jaxtyping as jt
 import optax
 from optformer.embed_then_regress import icl_transformer
@@ -84,12 +85,11 @@ def create_model(
   ) -> icl_transformer.ICLTransformer:

     kwargs = dataclasses.asdict(self)
-    embedder_factory = embedder_config.create_embedder_factory()
-    std_transform_fn = self.create_std_transform_fn()
+    kwargs.pop('std_transform')

     return icl_transformer.ICLTransformer(
-        std_transform_fn=std_transform_fn,
-        embedder_factory=embedder_factory,
+        std_transform_fn=self.create_std_transform_fn(),
+        embedder_factory=embedder_config.create_embedder_factory(),
         **kwargs,
     )

@@ -98,17 +98,17 @@ def create_std_transform_fn(
   ) -> Callable[[jt.Float[jax.Array, '*A']], jt.Float[jax.Array, '*A']]:
     """Creates std transform function."""
     if self.std_transform == 'exp':
-      return jax.numpy.exp
+      return jnp.exp
     elif self.std_transform == 'exp10':
-      return lambda x: jax.numpy.exp(10.0 * x)
+      return lambda x: jnp.exp(10.0 * x)
     elif self.std_transform == 'softplus':
       return jax.nn.softplus
     elif self.std_transform == 'softplus10':
       return lambda x: jax.nn.softplus(10.0 * x)
     elif self.std_transform == 'abs':
-      return jax.numpy.abs
+      return jnp.abs
     elif self.std_transform == 'abs10':
-      return lambda x: jax.numpy.abs(10.0 * x)
+      return lambda x: jnp.abs(10.0 * x)
     elif self.std_transform == 'shifted_relu':
       return lambda x: jax.nn.relu(x + 1.0)
     elif self.std_transform == 'shifted_relu10':
@@ -131,7 +131,7 @@ class TrainingConfig:
   seed: int = 42

   validation_interval: int = 100
-  checkpoint_interval: int = 100
+  max_to_keep_ckpts: int = 5
   workdir = '../checkpoints'

   def create_optimizer(self) -> optax.GradientTransformation:
@@ -186,7 +186,6 @@ def wrap_ds(
   ) -> tf.data.Dataset:
     """This should be used at the trainer level."""
     ds = self._tokenize_ds(ds)
-    ds = ds.shard(jax.process_count(), jax.process_index())
     ds = ds.repeat()
     ds = ds.shuffle(buffer_size=self.buffer_size)

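The kwargs.pop('std_transform') is needed because ICLTransformer now receives the resolved callable via std_transform_fn rather than the config string. Every std_transform option maps the raw log-std head output to a positive standard deviation; a self-contained sketch of the 'softplus10' branch with illustrative inputs (values are not from the repo):

import jax
import jax.numpy as jnp

raw = jnp.array([-2.0, 0.0, 0.5])   # illustrative raw head outputs
std = jax.nn.softplus(10.0 * raw)   # 'softplus10': scale by 10, then softplus
print(std)                          # every entry is strictly positive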

optformer/embed_then_regress/icl_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def __call__(
       out = layer(out, mask, deterministic, rng)

     mean, std = jnp.split(self.mean_logstd_head(out), 2, axis=-1)  # [B L 1]
-    std = self.std_transform_fn(self.std_transform)(std) + EPS
+    std = self.std_transform_fn(std) + EPS

     mean = jnp.squeeze(mean, axis=-1)
     std = jnp.squeeze(std, axis=-1)
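The fix: std_transform_fn is already the resolved callable, so it is applied directly to the raw std half of the head output instead of being called with the config string first. A standalone sketch of the surrounding lines; the shapes and the EPS value are illustrative assumptions:

import jax
import jax.numpy as jnp

EPS = 1e-6                                   # assumed small positive floor
std_transform_fn = jax.nn.softplus           # one of the configurable transforms
head_out = jnp.zeros((2, 5, 2))              # [B, L, 2]: concatenated mean and raw std
mean, std = jnp.split(head_out, 2, axis=-1)  # each [B, L, 1]
std = std_transform_fn(std) + EPS            # strictly positive std
mean = jnp.squeeze(mean, axis=-1)            # [B, L]
std = jnp.squeeze(std, axis=-1)              # [B, L]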

optformer/embed_then_regress/train.py

Lines changed: 16 additions & 10 deletions
@@ -155,7 +155,8 @@ def aggregate_metrics(
   """Aggregates metrics (possibly from multiple gradient accumulation steps)."""
   if isinstance(metrics, list):
     metrics = jax.tree.map(lambda *args: jnp.stack(args), *metrics)
-  return jax.tree.map(jnp.mean, metrics)
+  metrics = jax.tree.map(jnp.mean, metrics)
+  return {k: float(v) for k, v in metrics.items()}


 def train(
@@ -194,7 +195,12 @@ def train(
   )

   # Set up checkpointing
-  checkpoint_manager = ckpt_lib.get_checkpoint_manager(train_config.workdir)
+  checkpoint_manager = ckpt_lib.get_checkpoint_manager(
+      train_config.workdir,
+      max_to_keep=train_config.max_to_keep_ckpts,
+      best_fn=lambda metrics: metrics['eval_loss'],
+      best_mode='min',
+  )
   # Restore if available.
   train_state = ckpt_lib.restore_train_state(
       train_config.workdir, init_train_state
@@ -206,20 +212,20 @@ def train(
   eff_step = int(unreplicate(train_state.step)) // grad_accum_steps

   while eff_step < train_config.max_steps:
-    if eff_step % train_config.checkpoint_interval == 0:
+    if eff_step % train_config.validation_interval == 0:
+      valid_agg_metrics = aggregate_metrics([
+          p_eval_step(train_state, next(valid_it))
+          for _ in range(grad_accum_steps)
+      ])
+      writer.write_scalars(eff_step, valid_agg_metrics)
+
       ckpt_train_state = unreplicate(train_state)
       checkpoint_manager.save(
           eff_step,
           items=dict(train_state=jax.tree.map(np.array, ckpt_train_state)),
+          metrics=valid_agg_metrics,
       )

-    if eff_step % train_config.validation_interval == 0:
-      all_valid_metrics = [
-          p_eval_step(train_state, next(valid_it))
-          for _ in range(grad_accum_steps)
-      ]
-      writer.write_scalars(eff_step, aggregate_metrics(all_valid_metrics))
-
     all_train_metrics = []
     for _ in range(grad_accum_steps):
       train_state, train_metrics = p_train_step(train_state, next(train_it))
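aggregate_metrics now returns plain Python floats, presumably so the same dict can both be written by the metric writer and passed as metrics= to CheckpointManager.save() for best-checkpoint ranking. A small standalone sketch of that aggregation path (metric names and values are made up):

import jax
import jax.numpy as jnp

# Per-gradient-accumulation-step metrics, as device arrays.
metrics_list = [{'loss': jnp.array(0.9)}, {'loss': jnp.array(0.7)}]
stacked = jax.tree.map(lambda *args: jnp.stack(args), *metrics_list)  # stack across steps
averaged = jax.tree.map(jnp.mean, stacked)                            # mean per metric
agg = {k: float(v) for k, v in averaged.items()}                      # device scalars -> floats
print(agg)  # {'loss': ~0.8}, plain floats that are easy to log and serialize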
