
Commit e115be2

xingyousong authored and copybara-github committed
Fix caching
PiperOrigin-RevId: 697005664
1 parent: 72ab8db

File tree: 1 file changed (+18, -21 lines)


optformer/embed_then_regress/icl_transformer.py

Lines changed: 18 additions & 21 deletions
@@ -14,9 +14,11 @@
 
 """Transformer model for ICL regression."""
 
+import dataclasses
 import functools
 from typing import Callable
 from flax import linen as nn
+from flax import struct
 import jax
 import jax.numpy as jnp
 import jaxtyping as jt
@@ -86,14 +88,12 @@ def __call__(
     return x
 
 
-class EmbeddingCache(dict[str, AnyTensor]):
+@struct.dataclass
+class EmbeddingCache:
+  """Cache for storing previously computed embeddings."""
 
-  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
-    value = self.get(key)
-    if value is None:
-      value = fn()
-    self.update({key: value})
-    return value
+  x_emb: jt.Float[jax.Array, 'L E'] | None = None
+  metadata_emb: jt.Float[jax.Array, 'E'] | None = None
 
 
 class ICLTransformer(nn.Module):
@@ -206,36 +206,33 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: EmbeddingCache | None = None,  # For caching embeddings.
+      cache: EmbeddingCache,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
-      dict[str, jax.Array],
+      EmbeddingCache,
   ]:
     """Friendly for inference, no batch dimension."""
-    if cache is None:
-      cache = EmbeddingCache()
-
-    # [L, E]
-    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
+    if cache.x_emb is None:
+      cache = dataclasses.replace(cache, x_emb=self.embed(x_padded))
+    x_pad_emb = cache.x_emb  # [L, E]
     x_targ_emb = self.embed(x_targ)  # [Q, E]
-    L, E = x_pad_emb.shape  # pylint: disable=invalid-name
 
     # Combine target and historical (padded) embeddings.
     target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
-    padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
+    padded_target_emb = jnp.zeros_like(x_pad_emb)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
     )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)  # [L, E]
 
     if self.use_metadata:  # Attach metadata embeddings too.
-      metadata_emb = cache.get_or_set(
-          'metadata_emb', lambda: self.embed(metadata)
-      )
+      if cache.metadata_emb is None:
+        cache = dataclasses.replace(cache, metadata_emb=self.embed(metadata))
+      metadata_emb = cache.metadata_emb  # [E]
       metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      metadata_emb = jnp.repeat(metadata_emb, x_emb.shape[0], axis=0)  # [L, E]
       x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]
 
     mean, std = self.__call__(