|
@@ -14,9 +14,11 @@

 """Transformer model for ICL regression."""

+import dataclasses
 import functools
 from typing import Callable
 from flax import linen as nn
+from flax import struct
 import jax
 import jax.numpy as jnp
 import jaxtyping as jt
|
@@ -86,14 +88,12 @@ def __call__(
     return x


-class EmbeddingCache(dict[str, AnyTensor]):
+@struct.dataclass
+class EmbeddingCache:
+  """Cache for storing previously computed embeddings."""

-  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
-    value = self.get(key)
-    if value is None:
-      value = fn()
-      self.update({key: value})
-    return value
+  x_emb: jt.Float[jax.Array, 'L E'] | None = None
+  metadata_emb: jt.Float[jax.Array, 'E'] | None = None


 class ICLTransformer(nn.Module):
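Note on the replacement above: `flax.struct.dataclass` produces a frozen Python dataclass that is also registered as a JAX pytree, so an `EmbeddingCache` can cross `jax.jit` boundaries and is updated functionally with `dataclasses.replace` rather than mutated in place like the old `dict` subclass. A minimal standalone sketch of that behavior (the `ToyCache` class and shapes are illustrative, not part of this change):

import dataclasses

import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class ToyCache:
  # None marks "not computed yet"; arrays are filled in functionally.
  x_emb: jax.Array | None = None


cache = ToyCache()
# Frozen dataclass: no in-place mutation; replace() returns a new instance.
cache = dataclasses.replace(cache, x_emb=jnp.ones((4, 8)))
assert jax.tree_util.tree_leaves(cache)[0].shape == (4, 8)  # Fields flatten as pytree leaves.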
|
@@ -206,36 +206,33 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: EmbeddingCache | None = None,  # For caching embeddings.
+      cache: EmbeddingCache,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
-      dict[str, jax.Array],
+      EmbeddingCache,
   ]:
     """Friendly for inference, no batch dimension."""
-    if cache is None:
-      cache = EmbeddingCache()
-
-    # [L, E]
-    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
+    if cache.x_emb is None:
+      cache = dataclasses.replace(cache, x_emb=self.embed(x_padded))
+    x_pad_emb = cache.x_emb  # [L, E]
     x_targ_emb = self.embed(x_targ)  # [Q, E]
-    L, E = x_pad_emb.shape  # pylint: disable=invalid-name

     # Combine target and historical (padded) embeddings.
     target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
-    padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
+    padded_target_emb = jnp.zeros_like(x_pad_emb)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
     )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)  # [L, E]

     if self.use_metadata:  # Attach metadata embeddings too.
-      metadata_emb = cache.get_or_set(
-          'metadata_emb', lambda: self.embed(metadata)
-      )
+      if cache.metadata_emb is None:
+        cache = dataclasses.replace(cache, metadata_emb=self.embed(metadata))
+      metadata_emb = cache.metadata_emb  # [E]
       metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      metadata_emb = jnp.repeat(metadata_emb, x_emb.shape[0], axis=0)  # [L, E]
       x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]

     mean, std = self.__call__(
|