@@ -86,6 +86,16 @@ def __call__(
     return x


+class EmbeddingCache(dict[str, AnyTensor]):
+
+  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
+    value = self.get(key)
+    if value is None:
+      value = fn()
+      self.update({key: value})
+    return value
+
+
 class ICLTransformer(nn.Module):
   """ICL Transformer model for regression."""
@@ -94,6 +104,7 @@ class ICLTransformer(nn.Module):
   nhead: int  # H
   dropout: float
   num_layers: int
+  use_metadata: bool
   std_transform_fn: Callable[[AnyTensor], AnyTensor]
   embedder_factory: Callable[[], nn.Module]  # __call__: [B, T] -> [B, D]
@@ -102,7 +113,7 @@ def setup(self):
     self.embedder = self.embedder_factory()

     # X, Y, and concatenated X,Y embedders.
-    self.xm_proj = nn.Sequential(
+    self.x_proj = nn.Sequential(
         [Dense(self.d_model), nn.relu, Dense(self.d_model)]
     )
     self.y_proj = nn.Sequential(
@@ -132,7 +143,6 @@ def __call__(
       self,
       x_emb: jt.Float[jax.Array, 'B L E'],
       y: jt.Float[jax.Array, 'B L'],
-      metadata_emb: jt.Float[jax.Array, 'B E'],
       mask: jt.Bool[jax.Array, 'B L'],
       deterministic: bool | None = None,
       rng: jax.Array | None = None,
@@ -178,16 +188,16 @@ def fit(
       rng: jax.Array | None = None,
   ) -> tuple[jt.Float[jax.Array, 'B L'], jt.Float[jax.Array, 'B L']]:
     """For training / eval loss metrics only."""
-    # pylint: disable=invalid-name
-    L = x.shape[1]
-
     x_emb = self.embed(x)  # [B, L, E]
-    metadata_emb = self.embed(metadata)  # [B, E]

-    metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
-    metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
-    xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
-    return self.__call__(xm_emb, y, metadata_emb, mask, deterministic, rng)
+    if self.use_metadata:
+      L = x_emb.shape[1]  # pylint: disable=invalid-name
+      metadata_emb = self.embed(metadata)  # [B, E]
+      metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
+      metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
+      x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
+
+    return self.__call__(x_emb, y, mask, deterministic, rng)

   def infer(
       self,
@@ -196,49 +206,44 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: dict[str, jax.Array] | None = None,  # For caching embeddings.
+      cache: EmbeddingCache | None = None,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
       dict[str, jax.Array],
   ]:
     """Friendly for inference, no batch dimension."""
     if cache is None:
-      x_padded_emb = self.embed(x_padded)  # [L, E]
-      metadata_emb = self.embed(metadata)  # [E]
-      cache = {'x_padded_emb': x_padded_emb, 'metadata_emb': metadata_emb}
-    else:
-      x_padded_emb = cache['x_padded_emb']  # [L, E]
-      metadata_emb = cache['metadata_emb']  # [E]
+      cache = EmbeddingCache()

+    # [L, E]
+    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
     x_targ_emb = self.embed(x_targ)  # [Q, E]
-
-    L, E = x_padded_emb.shape  # pylint: disable=invalid-name
-
-    target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
+    L, E = x_pad_emb.shape  # pylint: disable=invalid-name

     # Combine target and historical (padded) embeddings.
+    target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
     padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
     )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_padded_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)

-    # Attach metadata embeddings too.
-    metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-    metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
-    xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]
+    if self.use_metadata:  # Attach metadata embeddings too.
+      metadata_emb = cache.get_or_set(
+          'metadata_emb', lambda: self.embed(metadata)
+      )
+      metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
+      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]

-    # TODO: Are these batch=1 expands necessary?
     mean, std = self.__call__(
-        x_emb=jnp.expand_dims(xm_emb, axis=0),
+        x_emb=jnp.expand_dims(x_emb, axis=0),
         y=jnp.expand_dims(y_padded, axis=0),
-        metadata_emb=jnp.expand_dims(metadata_emb, axis=0),
         mask=jnp.expand_dims(mask, axis=0),
         deterministic=True,
     )
-
     return jnp.squeeze(mean, axis=0), jnp.squeeze(std, axis=0), cache

 @nn.remat  # Reduce memory consumption during backward pass.
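Not part of the diff, but as a quick illustration of the caching pattern the new `EmbeddingCache` enables in `infer`: the first lookup under a key runs the (potentially expensive) embedding function, and later lookups reuse the stored array. The sketch below is standalone and hedged: `jax.Array` stands in for `AnyTensor`, and `fake_embed` is a hypothetical placeholder for the model's `self.embed`.

```python
from typing import Callable

import jax
import jax.numpy as jnp


class EmbeddingCache(dict[str, jax.Array]):
  """Mirror of the class added in the diff, with jax.Array for AnyTensor."""

  def get_or_set(self, key: str, fn: Callable[[], jax.Array]):
    value = self.get(key)
    if value is None:
      value = fn()  # Compute only on a cache miss.
      self.update({key: value})
    return value


calls = 0


def fake_embed(tokens: jax.Array) -> jax.Array:
  """Hypothetical stand-in for the model's embed(); counts invocations."""
  global calls
  calls += 1
  return jnp.ones((tokens.shape[0], 8))


x_padded = jnp.zeros((16, 4), dtype=jnp.int32)  # [L, T] tokenized history.
cache = EmbeddingCache()

# As in infer(): embed the padded history once, then reuse it on later calls.
emb_first = cache.get_or_set('x_pad_emb', lambda: fake_embed(x_padded))
emb_again = cache.get_or_set('x_pad_emb', lambda: fake_embed(x_padded))
assert calls == 1 and emb_first is emb_again
```

In the diff itself, `infer` returns the populated cache alongside `mean` and `std`, so a caller can pass it back in on subsequent calls and skip re-embedding `x_padded` (and `metadata` when `use_metadata` is set).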