
Commit 09c2052

xingyousong authored and copybara-github committed

Add different nonlinearities on std head.

PiperOrigin-RevId: 691152711

1 parent b19cba7 commit 09c2052
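For context on the diffs below: the regression head previously emitted a log-std channel and recovered std with a fixed `exp`; this commit makes that positivity transform configurable through a new `std_transform` config string. A minimal standalone sketch (not part of the commit) of the eight selectable transforms:

```python
import jax
import jax.numpy as jnp

# The eight transforms the new `std_transform` field selects between;
# names match the config strings handled in configs.py below.
STD_TRANSFORMS = {
    'exp': jnp.exp,
    'exp10': lambda x: jnp.exp(10.0 * x),
    'softplus': jax.nn.softplus,
    'softplus10': lambda x: jax.nn.softplus(10.0 * x),
    'abs': jnp.abs,
    'abs10': lambda x: jnp.abs(10.0 * x),
    'shifted_relu': lambda x: jax.nn.relu(x + 1.0),
    'shifted_relu10': lambda x: jax.nn.relu(10.0 * x + 1.0),
}

raw = jnp.linspace(-2.0, 2.0, 5)  # raw std-head outputs
for name, fn in STD_TRANSFORMS.items():
  # Each transform maps raw outputs into [0, inf); the model adds
  # EPS = 1e-7 afterwards so the final std is strictly positive.
  print(f'{name:>14}: {fn(raw)}')
```

Note that `abs` and the shifted-ReLU variants can hit exactly zero, which is why the model-side `+ EPS` (introduced in icl_transformer.py) matters.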

File tree (3 files changed: +35 -6 lines)

- optformer/embed_then_regress/configs.py
- optformer/embed_then_regress/icl_transformer.py
- optformer/embed_then_regress/train.py

optformer/embed_then_regress/configs.py

Lines changed: 29 additions & 1 deletion

```diff
@@ -20,6 +20,7 @@
 from typing import Callable
 from flax import linen as nn
 import jax
+import jaxtyping as jt
 import optax
 from optformer.embed_then_regress import icl_transformer
 from optformer.t5x import embedders
@@ -76,18 +77,45 @@ class ModelConfig:
   nhead: int = 16
   dropout: float = 0.1
   num_layers: int = 8
+  std_transform: str = 'exp'
 
   def create_model(
       self, embedder_config: T5EmbedderConfig
   ) -> icl_transformer.ICLTransformer:
 
     kwargs = dataclasses.asdict(self)
     embedder_factory = embedder_config.create_embedder_factory()
+    std_transform_fn = self.create_std_transform_fn()
 
     return icl_transformer.ICLTransformer(
-        embedder_factory=embedder_factory, **kwargs
+        std_transform_fn=std_transform_fn,
+        embedder_factory=embedder_factory,
+        **kwargs,
     )
 
+  def create_std_transform_fn(
+      self,
+  ) -> Callable[[jt.Float[jax.Array, '*A']], jt.Float[jax.Array, '*A']]:
+    """Creates std transform function."""
+    if self.std_transform == 'exp':
+      return jax.numpy.exp
+    elif self.std_transform == 'exp10':
+      return lambda x: jax.numpy.exp(10.0 * x)
+    elif self.std_transform == 'softplus':
+      return jax.nn.softplus
+    elif self.std_transform == 'softplus10':
+      return lambda x: jax.nn.softplus(10.0 * x)
+    elif self.std_transform == 'abs':
+      return jax.numpy.abs
+    elif self.std_transform == 'abs10':
+      return lambda x: jax.numpy.abs(10.0 * x)
+    elif self.std_transform == 'shifted_relu':
+      return lambda x: jax.nn.relu(x + 1.0)
+    elif self.std_transform == 'shifted_relu10':
+      return lambda x: jax.nn.relu(10.0 * x + 1.0)
+    else:
+      raise ValueError(f'Unknown std_transform: {self.std_transform}')
+
 
 @dataclasses.dataclass
 class TrainingConfig:
```
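A hedged usage sketch for the new field, assuming the remaining `ModelConfig` fields keep their defaults (the constructor call is illustrative only):

```python
import jax.numpy as jnp

# Illustrative: select a softplus std head instead of the default 'exp'.
config = ModelConfig(std_transform='softplus')
std_transform_fn = config.create_std_transform_fn()

print(std_transform_fn(jnp.array([-1.0, 0.0, 1.0])))
# softplus is non-negative everywhere: approx [0.3133, 0.6931, 1.3133]
```

The `*10` variants rescale the raw head output by 10 before the nonlinearity, so small head outputs cover a wider range of std values.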

optformer/embed_then_regress/icl_transformer.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -30,6 +30,8 @@
 # order to use the same learning rate.
 default_kernel_init = nn.initializers.truncated_normal(stddev=0.02)
 Dense = functools.partial(nn.Dense, kernel_init=default_kernel_init)
+EPS = 1e-7
+AnyTensor = jt.Float[jax.Array, '*A']
 
 
 class Block(nn.Module):
@@ -92,7 +94,7 @@ class ICLTransformer(nn.Module):
   nhead: int  # H
   dropout: float
   num_layers: int
-
+  std_transform_fn: Callable[[AnyTensor], AnyTensor]
   embedder_factory: Callable[[], nn.Module]  # __call__: [B, T] -> [B, D]
 
   def setup(self):
@@ -166,8 +168,8 @@ def __call__(
     for layer in self.encoder_layers:
      out = layer(out, mask, deterministic, rng)
 
-    mean, log_std = jnp.split(self.mean_logstd_head(out), 2, axis=-1)  # [B L 1]
-    std = jnp.exp(log_std)
+    mean, std = jnp.split(self.mean_logstd_head(out), 2, axis=-1)  # [B L 1]
+    std = self.std_transform_fn(std) + EPS
 
     mean = jnp.squeeze(mean, axis=-1)
     std = jnp.squeeze(std, axis=-1)
```
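For intuition, the changed head computation as a standalone sketch. The head output here is fabricated random data, and `softplus` stands in for whichever transform the config selected:

```python
import jax
import jax.numpy as jnp

EPS = 1e-7
std_transform_fn = jax.nn.softplus  # e.g. std_transform = 'softplus'

# Fabricated head output of shape [B=2, L=3, 2]:
# one mean channel and one raw-std channel per position.
head_out = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 2))

mean, std = jnp.split(head_out, 2, axis=-1)  # each [B, L, 1]
std = std_transform_fn(std) + EPS            # strictly positive
mean = jnp.squeeze(mean, axis=-1)            # [B, L]
std = jnp.squeeze(std, axis=-1)              # [B, L]
print(mean.shape, std.shape, bool((std > 0).all()))
```

Since the head no longer necessarily emits a log-std, the second split output is now named `std` rather than `log_std`, even though `mean_logstd_head` keeps its old name.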

optformer/embed_then_regress/train.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -31,7 +31,6 @@
 
 
 Scalar = jnp.ndarray | np.ndarray | float
-EPS = 1e-7
 
 
 def multi_gpu() -> bool:
@@ -98,7 +97,7 @@ def loss_fn(
   """Loss function with metrics."""
   # pylint: disable=invalid-name
   mean, std = model.apply(params, deterministic=not training, rng=rng, **batch)
-  nlogprob = -jax.scipy.stats.norm.logpdf(batch['y'], mean, std + EPS)  # [B, L]
+  nlogprob = -jax.scipy.stats.norm.logpdf(batch['y'], mean, std)  # [B, L]
 
   # Only compute loss over target ys. Mask is BxL where True denotes context
   # token and False otherwise.
```
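With `EPS` now added inside the model, the loss consumes `std` directly. A minimal sketch of the Gaussian negative log-likelihood computed here, on fabricated batch values:

```python
import jax.numpy as jnp
import jax.scipy.stats

# Fabricated [B=2, L=3] batch; std arrives strictly positive from the
# model, so the former `+ EPS` at the loss site is no longer needed.
y = jnp.zeros((2, 3))
mean = jnp.zeros((2, 3))
std = jnp.full((2, 3), 0.5)

nlogprob = -jax.scipy.stats.norm.logpdf(y, mean, std)  # [B, L]
print(nlogprob)  # each entry: 0.5 * log(2 * pi) + log(0.5) ~ 0.2258
```

Centralizing `EPS` in the model means any consumer of the predicted `std`, not just this training loss, sees a strictly positive value.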
