diff --git a/optformer/embed_then_regress/icl_transformer.py b/optformer/embed_then_regress/icl_transformer.py
index f4fe3a1..17897c1 100644
--- a/optformer/embed_then_regress/icl_transformer.py
+++ b/optformer/embed_then_regress/icl_transformer.py
@@ -102,7 +102,7 @@ def setup(self):
     self.embedder = self.embedder_factory()

     # X, Y, and concatenated X,Y embedders.
-    self.x_proj = nn.Sequential(
+    self.xm_proj = nn.Sequential(
         [Dense(self.d_model), nn.relu, Dense(self.d_model)]
     )
     self.y_proj = nn.Sequential(
@@ -136,20 +136,43 @@ def __call__(
       mask: jt.Bool[jax.Array, 'B L'],
       deterministic: bool | None = None,
       rng: jax.Array | None = None,
-  ) -> tuple[jt.Float[jax.Array, 'B L'], jt.Float[jax.Array, 'B L']]:
+      embedding_cache: dict[str, jax.Array] | None = None,
+  ) -> tuple[
+      jt.Float[jax.Array, 'B L'],
+      jt.Float[jax.Array, 'B L'],
+      dict[str, jax.Array],
+  ]:
     # pylint: disable=invalid-name
-
-    B, L, T = x.shape
-    x = jnp.reshape(x, (B * L, T))
-    x = self.embed(x)  # [B*L, E]
-    x = jnp.reshape(x, (B, L, -1))  # [B, L, E]
-
-    metadata = self.embed(metadata)  # [B, E]
-    metadata = jnp.expand_dims(metadata, axis=1)  # [B, 1, E]
-    metadata = jnp.repeat(metadata, L, axis=1)  # [B, L, E]
-    x = jnp.concatenate((x, metadata), axis=-1)  # [B, L, 2E]
-
-    xt_emb = self.x_proj(x)  # [B, L, D]
+    L = x.shape[1]
+
+    if embedding_cache is None:
+      x_emb = self.embed(x)  # [B, L, E]
+      metadata_emb = self.embed(metadata)  # [B, E]
+      embedding_cache = {'x': x_emb, 'metadata': metadata_emb}
+    else:
+      # Find the starting index of the target. Raise a ValueError if the masks
+      # are not all the same, since dynamic_update_slice wouldn't work.
+      target_indices = jnp.sum(mask, axis=-1, dtype=jnp.int32)
+      if not jnp.all(target_indices == target_indices[0]):
+        raise ValueError('At inference, all masks must be the same in batch.')
+      target_index = target_indices[0]
+
+      # Embed only the new tokens.
+      target_x = x[:, target_index:, :]  # [B=1, L - target_index, T]
+      target_x_emb = self.embed(target_x)  # [B=1, L - target_index, E]
+
+      x_emb = jax.lax.dynamic_update_slice(
+          embedding_cache['x'],
+          target_x_emb,
+          start_indices=(0, target_index, 0),
+      )
+      metadata_emb = embedding_cache['metadata']
+
+    metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
+    metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
+    xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
+
+    xt_emb = self.xm_proj(xm_emb)  # [B, L, D]

     # Force 0.0 values for target points using the mask.
     y = y * mask  # [B, L], element-wise multiplication
@@ -173,8 +196,12 @@ def __call__(
     mean = jnp.squeeze(mean, axis=-1)
     std = jnp.squeeze(std, axis=-1)

-    return mean, std
+    return mean, std, embedding_cache

   @nn.remat  # Reduce memory consumption during backward pass.
-  def embed(self, tokens: jt.Int[jax.Array, 'X T']):
-    return self.embedder(tokens)
+  def embed(
+      self, tokens: jt.Int[jax.Array, '*X T']
+  ) -> jt.Float[jax.Array, '*X E']:
+    reshaped_tokens = jnp.reshape(tokens, (-1, tokens.shape[-1]))
+    embeddings = self.embedder(reshaped_tokens)  # [-1, E]
+    return jnp.reshape(embeddings, tokens.shape[:-1] + (embeddings.shape[-1],))
diff --git a/optformer/embed_then_regress/regressor.py b/optformer/embed_then_regress/regressor.py
index eb3abc0..b272149 100644
--- a/optformer/embed_then_regress/regressor.py
+++ b/optformer/embed_then_regress/regressor.py
@@ -40,18 +40,21 @@ class StatefulICLRegressor:
   params: flax_typing.FrozenVariableDict = attrs.field()
   vocab: seqio.Vocabulary = attrs.field()

-  max_trial_length: int = attrs.field(default=300, kw_only=True)  # L
+  max_memory_length: int = attrs.field(default=10000, kw_only=True)  # M >> L
   max_token_length: int = attrs.field(default=256, kw_only=True)  # T
   warper: normalization.StatefulWarper = attrs.field(
       factory=normalization.default_warper, kw_only=True
   )

-  # Internal state containing tokens.
-  _all_xt: jt.Int[np.ndarray, 'L T'] = attrs.field(init=False)
-  _all_yt: jt.Float[np.ndarray, 'L'] = attrs.field(init=False)
+  # Internal state containing history.
+  _all_xt: jt.Int[np.ndarray, 'M T'] = attrs.field(init=False)
+  _all_yt: jt.Float[np.ndarray, 'M'] = attrs.field(init=False)
   _mt: jt.Int[np.ndarray, 'T'] = attrs.field(init=False)
   _num_prev: int = attrs.field(init=False)
+  _embedding_cache: dict[str, jax.Array] | None = attrs.field(init=False)
+
+  # Jitted function.
   _jit_apply: Callable[..., Any] = attrs.field(init=False)

   def __attrs_post_init__(self):
@@ -64,25 +67,29 @@ def predict(self, xs: Sequence[str]) -> tfd.Distribution:
     """Returns prediction in normalized/warped space."""
     num_query = len(xs)

-    temp_xt = np.copy(self._all_xt)
-    temp_xt[self._num_prev : self._num_prev + num_query] = self._tokenize(xs)
+    # Use instead of max_memory_length to reduce embedding costs.
+    max_trial_length = self._num_prev + num_query  # L
+
+    temp_xt = np.copy(self._all_xt)[:max_trial_length]
+    temp_xt[self._num_prev :] = self._tokenize(xs)

-    temp_yt = np.copy(self._all_yt)
+    temp_yt = np.copy(self._all_yt)[:max_trial_length]
     temp_yt = self.warper.warp(temp_yt)

     temp_mt = np.copy(self._mt)

-    mask = np.ones(self.max_trial_length, dtype=bool)
+    mask = np.ones(max_trial_length, dtype=bool)
     mask[self._num_prev :] = False

     # Need to add batch dimension to all inputs.
-    mean, std = self._jit_apply(
+    mean, std, self._embedding_cache = self._jit_apply(
         self.params,
         x=np.expand_dims(temp_xt, axis=0),  # [B=1, L, T],
         y=np.expand_dims(temp_yt, axis=0),  # [B=1, L],
         metadata=np.expand_dims(temp_mt, axis=0),  # [B=1, T],
         mask=np.expand_dims(mask, axis=0),  # [B=1, L],
         deterministic=True,
+        embedding_cache=self._embedding_cache,
     )
     mean, std = np.squeeze(mean, axis=0), np.squeeze(std, axis=0)

@@ -97,6 +104,7 @@ def absorb(self, xs: Sequence[str], ys: Sequence[float]):
     self._all_xt[self._num_prev : self._num_prev + num_pts] = self._tokenize(xs)
     self._all_yt[self._num_prev : self._num_prev + num_pts] = np.array(ys)
     self._num_prev += num_pts
+    self._embedding_cache = None  # Need to recompute historical embeddings.

     self.warper.train(self._all_yt[: self._num_prev])

@@ -105,11 +113,12 @@ def set_metadata(self, metadata: str) -> None:

   def reset(self) -> None:
     self._all_xt = np.zeros(
-        (self.max_trial_length, self.max_token_length), dtype=np.int32
+        (self.max_memory_length, self.max_token_length), dtype=np.int32
     )
-    self._all_yt = np.zeros(self.max_trial_length, dtype=np.float32)
+    self._all_yt = np.zeros(self.max_memory_length, dtype=np.float32)
     self._mt = np.zeros(self.max_token_length, dtype=np.int32)
     self._num_prev = 0
+    self._embedding_cache = None

   def _tokenize(self, ss: Sequence[str]) -> jt.Int[np.ndarray, 'S T']:
     """Converts ss (strings) to tokens."""
diff --git a/optformer/embed_then_regress/train.py b/optformer/embed_then_regress/train.py
index 5a97b80..fe96266 100644
--- a/optformer/embed_then_regress/train.py
+++ b/optformer/embed_then_regress/train.py
@@ -97,7 +97,9 @@ def loss_fn(
 ) -> tuple[jax.Array, Mapping[str, Scalar]]:
   """Loss function with metrics."""
   # pylint: disable=invalid-name
-  mean, std = model.apply(params, deterministic=not training, rng=rng, **batch)
+  mean, std, _ = model.apply(
+      params, deterministic=not training, rng=rng, **batch
+  )
   nlogprob = -jax.scipy.stats.norm.logpdf(batch['y'], mean, std)  # [B, L]

   # Only compute loss over target ys. Mask is BxL where True denotes context
diff --git a/optformer/embed_then_regress/vizier/designer.py b/optformer/embed_then_regress/vizier/designer.py
index e51c868..a701d2f 100644
--- a/optformer/embed_then_regress/vizier/designer.py
+++ b/optformer/embed_then_regress/vizier/designer.py
@@ -40,10 +40,6 @@
     use_fori=False,
 )

-default_scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
-    lambda _: acq_lib.UCB()
-)
-

 @attrs.define
 class TransformerICLOptDesigner(vza.Designer):
@@ -57,7 +53,7 @@ class TransformerICLOptDesigner(vza.Designer):
       default=default_optimizer_factory, kw_only=True
   )
   _acq_fn: acq_lib.AcquisitionFunction = attrs.field(
-      default=acq_lib.UCB(), kw_only=True
+      factory=acq_lib.UCB, kw_only=True
   )
   _rng: jax.Array = attrs.field(
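
Reviewer sketch (not part of the patch): the incremental path in icl_transformer.py embeds only the rows past the mask boundary and splices them into the cached history embeddings with jax.lax.dynamic_update_slice. The self-contained toy below illustrates that pattern; toy_embed is a made-up stand-in for ICLTransformer.embed, and all names and values here are illustrative assumptions rather than code from this repository.

import jax
import jax.numpy as jnp

E = 4  # toy embedding width


def toy_embed(tokens: jax.Array) -> jax.Array:
  """Stand-in for ICLTransformer.embed: maps [..., T] token ids to [..., E]."""
  t = tokens.astype(jnp.float32)
  flat = jnp.reshape(t, (-1, t.shape[-1]))
  feats = jnp.stack(
      [flat.sum(-1), flat.max(-1), flat.min(-1), flat.mean(-1)], axis=-1
  )
  return jnp.reshape(feats, t.shape[:-1] + (E,))


x = jnp.arange(2 * 6 * 3).reshape(2, 6, 3)  # [B=2, L=6, T=3] token ids
num_prev = 4  # context length, i.e. jnp.sum(mask) for every row

# Previous predict() call: the target slots held different (now stale) points.
x_prev = x.at[:, num_prev:, :].set(0)
cache = toy_embed(x_prev)  # what embedding_cache['x'] would hold, [B, L, E]

# Current predict() call: embed only the new points and splice them in.
suffix_emb = toy_embed(x[:, num_prev:, :])  # [B, L - num_prev, E]
spliced = jax.lax.dynamic_update_slice(cache, suffix_emb, (0, num_prev, 0))

assert jnp.allclose(spliced, toy_embed(x))  # matches re-embedding everything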