
Commit ee59ac5

xingyousong authored and copybara-github committed
Allow metadata embedding to be optional.
PiperOrigin-RevId: 696662925
1 parent e76d7d4 commit ee59ac5

File tree: 3 files changed (+41 -33 lines)

optformer/embed_then_regress/configs.py
optformer/embed_then_regress/icl_transformer.py
optformer/embed_then_regress/regressor.py

optformer/embed_then_regress/configs.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ class ModelConfig:
   nhead: int = 16
   dropout: float = 0.1
   num_layers: int = 8
+  use_metadata: bool = True
   std_transform: str = 'exp'
 
 
 def create_model(
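
Usage sketch (hypothetical, not the repository's construction path): the ModelConfigSketch dataclass below mirrors only the fields visible in this hunk, and any ModelConfig fields not shown are assumed to have defaults or be supplied elsewhere.

# Hypothetical mirror of the fields shown above; the real ModelConfig lives in
# optformer/embed_then_regress/configs.py and may have additional fields.
import dataclasses


@dataclasses.dataclass
class ModelConfigSketch:
  nhead: int = 16
  dropout: float = 0.1
  num_layers: int = 8
  use_metadata: bool = True  # New flag: skip metadata embeddings when False.
  std_transform: str = 'exp'


# A metadata-free variant, e.g. for ablations or datasets without metadata.
no_metadata = dataclasses.replace(ModelConfigSketch(), use_metadata=False)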

optformer/embed_then_regress/icl_transformer.py

Lines changed: 35 additions & 30 deletions
@@ -86,6 +86,16 @@ def __call__(
     return x
 
 
+class EmbeddingCache(dict[str, AnyTensor]):
+
+  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
+    value = self.get(key)
+    if value is None:
+      value = fn()
+      self.update({key: value})
+    return value
+
+
 class ICLTransformer(nn.Module):
   """ICL Transformer model for regression."""
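
Aside: a self-contained sketch of the get_or_set semantics added above. The callable runs only on a cache miss, so later lookups reuse the stored tensor. In this sketch, jax.Array stands in for AnyTensor and dummy arrays stand in for embedder outputs; EmbeddingCacheSketch is illustrative, not the repository class.

# Illustrative only; mirrors the EmbeddingCache pattern with dummy arrays.
from typing import Callable

import jax
import jax.numpy as jnp


class EmbeddingCacheSketch(dict[str, jax.Array]):

  def get_or_set(self, key: str, fn: Callable[[], jax.Array]) -> jax.Array:
    value = self.get(key)
    if value is None:
      value = fn()  # Computed only when the key is absent.
      self.update({key: value})
    return value


cache = EmbeddingCacheSketch()
first = cache.get_or_set('x_pad_emb', lambda: jnp.zeros((4, 8)))  # Miss: computed.
second = cache.get_or_set('x_pad_emb', lambda: jnp.ones((4, 8)))  # Hit: reused.
assert second is first

In the infer() hunk further down, the keys used are 'x_pad_emb' and 'metadata_emb', and the cache is only rebuilt when the regressor absorbs new data or resets (see regressor.py below).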

@@ -94,6 +104,7 @@ class ICLTransformer(nn.Module):
   nhead: int  # H
   dropout: float
   num_layers: int
+  use_metadata: bool
   std_transform_fn: Callable[[AnyTensor], AnyTensor]
   embedder_factory: Callable[[], nn.Module]  # __call__: [B, T] -> [B, D]
 
@@ -102,7 +113,7 @@ def setup(self):
     self.embedder = self.embedder_factory()
 
     # X, Y, and concatenated X,Y embedders.
-    self.xm_proj = nn.Sequential(
+    self.x_proj = nn.Sequential(
         [Dense(self.d_model), nn.relu, Dense(self.d_model)]
     )
     self.y_proj = nn.Sequential(
@@ -132,7 +143,6 @@ def __call__(
       self,
       x_emb: jt.Float[jax.Array, 'B L E'],
       y: jt.Float[jax.Array, 'B L'],
-      metadata_emb: jt.Float[jax.Array, 'B E'],
       mask: jt.Bool[jax.Array, 'B L'],
       deterministic: bool | None = None,
       rng: jax.Array | None = None,
@@ -178,16 +188,16 @@ def fit(
       rng: jax.Array | None = None,
   ) -> tuple[jt.Float[jax.Array, 'B L'], jt.Float[jax.Array, 'B L']]:
     """For training / eval loss metrics only."""
-    # pylint: disable=invalid-name
-    L = x.shape[1]
-
     x_emb = self.embed(x)  # [B, L, E]
-    metadata_emb = self.embed(metadata)  # [B, E]
 
-    metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
-    metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
-    xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
-    return self.__call__(xm_emb, y, metadata_emb, mask, deterministic, rng)
+    if self.use_metadata:
+      L = x_emb.shape[1]  # pylint: disable=invalid-name
+      metadata_emb = self.embed(metadata)  # [B, E]
+      metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
+      metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
+      x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
+
+    return self.__call__(x_emb, y, mask, deterministic, rng)
 
   def infer(
       self,
196206
x_targ: jt.Int[jax.Array, 'Q T'], # Q is fixed to avoid re-jitting.
197207
metadata: jt.Int[jax.Array, 'T'],
198208
mask: jt.Bool[jax.Array, 'L'],
199-
cache: dict[str, jax.Array] | None = None, # For caching embeddings.
209+
cache: EmbeddingCache | None = None, # For caching embeddings.
200210
) -> tuple[
201211
jt.Float[jax.Array, 'L'],
202212
jt.Float[jax.Array, 'L'],
203213
dict[str, jax.Array],
204214
]:
205215
"""Friendly for inference, no batch dimension."""
206216
if cache is None:
207-
x_padded_emb = self.embed(x_padded) # [L, E]
208-
metadata_emb = self.embed(metadata) # [E]
209-
cache = {'x_padded_emb': x_padded_emb, 'metadata_emb': metadata_emb}
210-
else:
211-
x_padded_emb = cache['x_padded_emb'] # [L, E]
212-
metadata_emb = cache['metadata_emb'] # [E]
217+
cache = EmbeddingCache()
213218

219+
# [L, E]
220+
x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
214221
x_targ_emb = self.embed(x_targ) # [Q, E]
215-
216-
L, E = x_padded_emb.shape # pylint: disable=invalid-name
217-
218-
target_index = jnp.sum(mask, dtype=jnp.int32) # [1]
222+
L, E = x_pad_emb.shape # pylint: disable=invalid-name
219223

220224
# Combine target and historical (padded) embeddings.
225+
target_index = jnp.sum(mask, dtype=jnp.int32) # [1]
221226
padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
222227
padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
223228
padded_target_emb, x_targ_emb, start_index=target_index, axis=0
224229
)
225230
w_mask = jnp.expand_dims(mask, axis=-1) # [L, 1]
226-
x_emb = x_padded_emb * w_mask + padded_target_emb * (1 - w_mask)
231+
x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
227232

228-
# Attach metadata embeddings too.
229-
metadata_emb = jnp.expand_dims(metadata_emb, axis=0) # [1, E]
230-
metadata_emb = jnp.repeat(metadata_emb, L, axis=0) # [L, E]
231-
xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1) # [L, 2E]
233+
if self.use_metadata: # Attach metadata embeddings too.
234+
metadata_emb = cache.get_or_set(
235+
'metadata_emb', lambda: self.embed(metadata)
236+
)
237+
metadata_emb = jnp.expand_dims(metadata_emb, axis=0) # [1, E]
238+
metadata_emb = jnp.repeat(metadata_emb, L, axis=0) # [L, E]
239+
x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1) # [L, 2E]
232240

233-
# TODO: Are these batch=1 expands necessary?
234241
mean, std = self.__call__(
235-
x_emb=jnp.expand_dims(xm_emb, axis=0),
242+
x_emb=jnp.expand_dims(x_emb, axis=0),
236243
y=jnp.expand_dims(y_padded, axis=0),
237-
metadata_emb=jnp.expand_dims(metadata_emb, axis=0),
238244
mask=jnp.expand_dims(mask, axis=0),
239245
deterministic=True,
240246
)
241-
242247
return jnp.squeeze(mean, axis=0), jnp.squeeze(std, axis=0), cache
243248

244249
@nn.remat # Reduce memory consumption during backward pass.
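
The xm_proj -> x_proj rename lines up with the projection input no longer always being the concatenated [x, metadata] embedding: with use_metadata=False the features stay at width E, and with use_metadata=True they widen to 2E (presumably fine because Flax Dense layers infer their input width at initialization). A toy JAX sketch of the broadcast-and-concatenate step from fit(), with dummy shapes B, L, E and dummy data:

# Dummy data; only the broadcast-and-concat pattern mirrors the diff above.
import jax.numpy as jnp

B, L, E = 2, 5, 4  # Batch size, trial-sequence length, embedding width.
x_emb = jnp.ones((B, L, E))
metadata_emb = jnp.ones((B, E))
use_metadata = True

if use_metadata:
  m = jnp.expand_dims(metadata_emb, axis=1)     # [B, 1, E]
  m = jnp.repeat(m, L, axis=1)                  # [B, L, E]
  x_emb = jnp.concatenate((x_emb, m), axis=-1)  # [B, L, 2E]

assert x_emb.shape == (B, L, 2 * E if use_metadata else E)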

optformer/embed_then_regress/regressor.py

Lines changed: 5 additions & 3 deletions
@@ -31,6 +31,8 @@
 
 tfd = tfp.distributions
 
+EmbeddingCache = icl_transformer.EmbeddingCache
+
 
 # TODO: Maybe refactor omnipred2 regressor base class.
 @attrs.define

@@ -54,7 +56,7 @@ class StatefulICLRegressor:
   _all_yt: jt.Float[np.ndarray, 'L'] = attrs.field(init=False)
   _mt: jt.Int[np.ndarray, 'T'] = attrs.field(init=False)
   _num_prev: int = attrs.field(init=False)
-  _cache: dict[str, jax.Array] | None = attrs.field(init=False)
+  _cache: EmbeddingCache = attrs.field(init=False)
 
   # Jitted function.
   _jit_apply: Callable[..., Any] = attrs.field(init=False)

@@ -95,7 +97,7 @@ def absorb(self, xs: Sequence[str], ys: Sequence[float]):
     self._all_xt[self._num_prev : self._num_prev + num_pts] = self._tokenize(xs)
     self._all_yt[self._num_prev : self._num_prev + num_pts] = np.array(ys)
     self._num_prev += num_pts
-    self._cache = None  # Need to recompute historical embeddings.
+    self._cache = EmbeddingCache()  # Need to recompute historical embeddings.
 
     self.warper.train(self._all_yt[: self._num_prev])
 

@@ -109,7 +111,7 @@ def reset(self) -> None:
     self._all_yt = np.zeros(self.max_trial_length, dtype=np.float32)
     self._mt = np.zeros(self.max_token_length, dtype=np.int32)
     self._num_prev = 0
-    self._cache = None
+    self._cache = EmbeddingCache()
 
   def _tokenize(self, ss: Sequence[str]) -> jt.Int[np.ndarray, 'S T']:
     """Converts ss (strings) to tokens."""
