Commit 3316888

feat: 1D time-series support in TransformerEmbedding (#1703)
* ENH: add support for scalar time series to TransformerEmbedding and add corresponding tests
* support explicit kwargs in TransformerEmbedding
* fix docstring indentation and add class level docstring
* Removed redundant line-breaks
1 parent e186adb commit 3316888
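The switch to explicit keyword arguments replaces the previous config-dict constructor (compare the transformer.py and test diffs below). A minimal construction sketch under the new interface; the argument values here are illustrative and not taken from the commit:

from sbi.neural_nets.embedding_nets import TransformerEmbedding

# Old interface (removed in this commit): defaults lived in an internal dict
# and were overridden through a single `config` argument, e.g.
# net = TransformerEmbedding(config={"feature_space_dim": 64, "num_hidden_layers": 2})

# New interface: keyword-only arguments; `feature_space_dim` must be given explicitly.
net = TransformerEmbedding(
    feature_space_dim=64,      # model dimension of the token embeddings
    num_attention_heads=4,     # illustrative; the default in the diff is 12
    num_key_value_heads=4,
    num_hidden_layers=2,
)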

3 files changed: +251 / -78 lines changed


docs/sbi.rst

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ Embedding nets
 sbi.neural_nets.embedding_nets.PermutationInvariantEmbedding
 sbi.neural_nets.embedding_nets.ResNetEmbedding1D
 sbi.neural_nets.embedding_nets.ResNetEmbedding2D
+sbi.neural_nets.embedding_nets.TransformerEmbedding


 Training

sbi/neural_nets/embedding_nets/transformer.py

Lines changed: 224 additions & 76 deletions

@@ -628,91 +628,234 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
 
 
 class TransformerEmbedding(nn.Module):
-    def __init__(self, config):
+    r"""
+    Transformer-based embedding network for **time series** and **image** data.
+
+    This module provides a flexible embedding architecture that supports both
+    (1) 1D / multivariate time series (e.g., experimental trials, temporal signals),
+    and
+    (2) image inputs via a lightweight Vision Transformer (ViT)-style patch embedding.
+
+    It is designed for simulation-based inference (SBI) workflows where raw
+    observations must be encoded into fixed-dimensional embeddings before passing
+    them to a neural density estimator.
+
+    Parameters
+    ----------
+    pos_emb :
+        Positional embedding type. One of ``{"rotary", "positional", "none"}``.
+    pos_emb_base :
+        Base frequency for rotary positional embeddings.
+    rms_norm_eps :
+        Epsilon for RMSNorm layers.
+    router_jitter_noise :
+        Noise added when routing tokens to MoE experts.
+    vit_dropout :
+        Dropout applied inside ViT patch embedding layers.
+    mlp_activation :
+        Activation used inside the feedforward blocks.
+    is_causal :
+        If ``True``, applies a causal mask during attention (useful for time-series).
+    vit :
+        If ``True``, enables Vision Transformer mode for 2D image inputs.
+    num_hidden_layers :
+        Number of transformer encoder blocks.
+    num_attention_heads :
+        Number of self-attention heads.
+    num_key_value_heads :
+        Number of KV heads (for multi-query attention).
+    intermediate_size :
+        Hidden dimension of feedforward network (or MoE experts).
+    ffn :
+        Feedforward type. One of ``{"mlp", "moe"}``.
+    head_dim :
+        Per-head embedding dimension. If ``None``, inferred as
+        ``feature_space_dim // num_attention_heads``.
+    attention_dropout :
+        Dropout used inside the attention mechanism.
+    feature_space_dim :
+        Dimensionality of the token embeddings flowing through the transformer.
+        - For time-series, this is the model dimension.
+        - For images (``vit=True``), this is the post-patch-projection embedding size.
+    final_emb_dimension :
+        Output embedding dimension. Defaults to ``feature_space_dim // 2``.
+    image_size :
+        Input image height/width (only if ``vit=True``).
+    patch_size :
+        ViT patch size (only if ``vit=True``).
+    num_channels :
+        Number of image channels for ViT mode.
+    num_local_experts :
+        Number of MoE experts (only relevant when ``ffn="moe"``).
+    num_experts_per_tok :
+        How many experts each token is routed to in MoE mode.
+
+    Notes
+    -----
+    **Time-series mode (``vit=False``)**
+    - Inputs of shape ``(batch, seq_len)`` (scalar series) are automatically
+      projected to ``(batch, seq_len, feature_space_dim)``.
+    - Inputs of shape ``(batch, seq_len, features)`` are used as-is.
+    - Causal masking is applied if ``is_causal=True`` (default).
+    - Suitable for experimental trials, temporal dynamics, or sets of sequential
+      observations.
+
+    **Image mode (``vit=True``)**
+    - Inputs must have shape ``(batch, channels, height, width)``.
+    - Images are patchified, linearly projected, and fed to the transformer.
+    - Causal masking is disabled in this mode.
+
+    **Output**
+    The embedding is obtained by selecting the final token and applying a linear
+    projection, resulting in a tensor of shape:
+
+    ``(batch, final_emb_dimension)``
+
+    Example
+    -------
+    **1D time-series (default mode)**::
+
+        from sbi.neural_nets.embedding_nets import TransformerEmbedding
+        import torch
+
+        x = torch.randn(16, 100)  # (batch, seq_len)
+        emb = TransformerEmbedding(feature_space_dim=64)
+        z = emb(x)
+
+    **Image input (ViT-style)**::
+
+        from sbi.neural_nets.embedding_nets import TransformerEmbedding
+        import torch
+
+        x = torch.randn(8, 3, 64, 64)  # (batch, C, H, W)
+        emb = TransformerEmbedding(
+            vit=True,
+            image_size=64,
+            patch_size=8,
+            num_channels=3,
+            feature_space_dim=128,
+        )
+        z = emb(x)
+    """
+
+    def __init__(
+        self,
+        *,
+        pos_emb: str = "rotary",
+        pos_emb_base: float = 10e4,
+        rms_norm_eps: float = 1e-05,
+        router_jitter_noise: float = 0.0,
+        vit_dropout: float = 0.5,
+        mlp_activation: str = "gelu",
+        is_causal: bool = True,
+        vit: bool = False,
+        num_hidden_layers: int = 4,
+        num_attention_heads: int = 12,
+        num_key_value_heads: int = 12,
+        intermediate_size: int = 256,
+        ffn: str = "mlp",
+        head_dim: Optional[int] = None,
+        attention_dropout: float = 0.5,
+        feature_space_dim: int,
+        final_emb_dimension: Optional[int] = None,
+        image_size: Optional[int] = None,
+        patch_size: Optional[int] = None,
+        num_channels: Optional[int] = None,
+        num_local_experts: Optional[int] = None,
+        num_experts_per_tok: Optional[int] = None,
+    ):
         super().__init__()
         """
         Main class for constructing a transformer embedding
-        Basic configuration parameters:
-            pos_emb (string): position encoding to be used, currently available:
-                {"rotary", "positional", "none"}
-            pos_emb_base (float): base used to construct the positinal encoding
-            rms_norm_eps (float): noise added to the rms variance computation
-            ffn (string): feedforward layer after used after computing the attention:
-                {"mlp", "moe"}
-            mlp_activation (string): activation function to be used within the ffn
-                layer
-            is_causal (bool): specifies whether causal mask should be created
-            vit (bool): specifies the whether a convolutional layer should be used for
-                processing images, inspired by the vision transformer
-            num_hidden_layer (int): number of transformer blocks
-            num_attention_heads (int): number of attention heads
-            num_key_value_heads (int): number of key/value heads
-            feature_space_dim (int): dimension of the feature vectors
-            intermediate_size (int): hidden size of the feedforward layer
-            head_dim (int): dimension key/query vectors
-            attention_dropout (float): value for the dropout of the attention layer
+
+        Args:
+            pos_emb: position encoding to be used, currently available:
+                {"rotary", "positional", "none"}
+            pos_emb_base: base used to construct the positinal encoding
+            rms_norm_eps: noise added to the rms variance computation
+            ffn: feedforward layer after used after computing the attention:
+                {"mlp", "moe"}
+            mlp_activation: activation function to be used within the ffn
+                layer
+            is_causal: specifies whether causal mask should be created
+            vit: specifies the whether a convolutional layer should be used for
+                processing images, inspired by the vision transformer
+            num_hidden_layers: number of transformer blocks
+            num_attention_heads: number of attention heads
+            num_key_value_heads: number of key/value heads
+            feature_space_dim: dimension of the feature vectors
+            intermediate_size: hidden size of the feedforward layer
+            head_dim: dimension key/query vectors
+            attention_dropout: value for the dropout of the attention layer
 
         MoE:
-            router_jitter_noise (float): noise added before routing the input vectors
-                to the experts
-            num_local_experts (int): total number of experts
-            num_experts_per_tok (int): number of experts each token is assigned to
+            router_jitter_noise: noise added before routing the input vectors
+                to the experts
+            num_local_experts: total number of experts
+            num_experts_per_tok: number of experts each token is assigned to
 
         ViT
-            feature_space_dim (int): dimension of the feature vectors after
-                preprocessing the images
-            image_size (int): dimension of the squared image used to created
-                the positional encoders
-                a rectagular image can be used at training/inference time by
-                resampling the encoders
-            patch_size (int): size of the square patches used to create the
-                positional encoders
-            num_channels (int): number of channels of the input image
-            vit_dropout (float): value for the dropout of the attention layer
-        """
-        self.config = {
-            "pos_emb": "rotary",
-            "pos_emb_base": 10e4,
-            "rms_norm_eps": 1e-05,
-            "router_jitter_noise": 0.0,
-            "vit_dropout": 0.5,
-            "mlp_activation": "gelu",
-            "is_causal": True,
-            "vit": False,
-            "num_hidden_layers": 4,
-            "num_attention_heads": 12,
-            "num_key_value_heads": 12,
-            "intermediate_size": 256,
-            "ffn": "mlp",
-            "head_dim": None,
-            "attention_dropout": 0.5,
-        }
+            feature_space_dim: dimension of the feature vectors after
+                preprocessing the images
+            image_size: dimension of the squared image used to created
+                the positional encoders
+                a rectagular image can be used at training/inference time by
+                resampling the encoders
+            patch_size: size of the square patches used to create the
+                positional encoders
+            num_channels: number of channels of the input image
+            vit_dropout: value for the dropout of the attention layer
+        """
+
+        self.config = dict(
+            pos_emb=pos_emb,
+            pos_emb_base=pos_emb_base,
+            rms_norm_eps=rms_norm_eps,
+            router_jitter_noise=router_jitter_noise,
+            vit_dropout=vit_dropout,
+            mlp_activation=mlp_activation,
+            is_causal=is_causal,
+            vit=vit,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            intermediate_size=intermediate_size,
+            ffn=ffn,
+            head_dim=head_dim,
+            attention_dropout=attention_dropout,
+            feature_space_dim=feature_space_dim,
+            image_size=image_size,
+            patch_size=patch_size,
+            num_channels=num_channels,
+            num_local_experts=num_local_experts,
+            num_experts_per_tok=num_experts_per_tok,
+        )
 
-        self.config.update(config)
+        self.preprocess = ViTEmbeddings(self.config) if vit else IdentityEncoder()
 
-        self.preprocess = (
-            ViTEmbeddings(self.config) if self.config["vit"] else IdentityEncoder()
-        )
+        self._supports_scalar_series = not vit
+        if self._supports_scalar_series:
+            self.scalar_projection = nn.Linear(
+                1, feature_space_dim
+            )  # proj 1D → model dim
 
         self.layers = nn.ModuleList([
-            TransformerBlock(self.config)
-            for _ in range(self.config["num_hidden_layers"])
+            TransformerBlock(self.config) for _ in range(num_hidden_layers)
         ])
-        self.is_causal = self.config["is_causal"] and not self.config["vit"]
+        self.is_causal = is_causal and not vit
 
-        self.norm = RMSNorm(
-            self.config["feature_space_dim"], eps=self.config["rms_norm_eps"]
-        )
-        final_emb_dimension = self.config.get(
-            "final_emb_dimension", self.config["feature_space_dim"] // 2
-        )
-        if not config["vit"] and final_emb_dimension > self.config["feature_space_dim"]:
+        self.norm = RMSNorm(feature_space_dim, eps=rms_norm_eps)
+
+        if final_emb_dimension is None:
+            final_emb_dimension = feature_space_dim // 2
+
+        if not vit and final_emb_dimension > feature_space_dim:
             raise ValueError(
-                "The final embedding dimension should be equal or smaller than "
-                "the input dimension"
+                "The final embedding dimension should be "
+                "equal or smaller than the input dimension"
            )
         self.aggregator = nn.Linear(
-            self.config["feature_space_dim"],
+            feature_space_dim,
            final_emb_dimension,
        )
        self.causal_mask_cache_ = (None, None, None)
@@ -764,21 +907,26 @@ def forward(
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
-            input (`torch.Tensor`): input of shape `(batch, seq_len,
-                feature_space_dim)`
-                or `(batch, num_channels, height, width)` if using ViT
-            attention_mask (`torch.Tensor`, *optional*):
+            input:
+                input of shape `(batch, seq_len, feature_space_dim)`
+                or `(batch, num_channels, height, width)` if using ViT
+            attention_mask:
                attention mask of size `(batch_size, sequence_length)`
-            output_attentions (`bool`, *optional*):
+            output_attentions:
                Whether or not to return the attention tensors
-            cache_attention_mask (`bool`, *optional*):
+            cache_attention_mask:
                Whether or not to cache the expanded attention mask, useful if using
                multiple batched with identical input shapes
-            kwargs (`dict`, *optional*):
+            kwargs:
                Arbitrary kwargs
        """
 
        input = self.preprocess(input)
+
+        if self._supports_scalar_series and input.ndim == 2:
+            input = input.unsqueeze(-1)  # (B, T, 1)
+            input = self.scalar_projection(input)  # (B, T, feature_space_dim)
+
        if self.is_causal:
            dtype, device = input.dtype, input.device
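For readers skimming the diff, the new scalar-series branch in forward() boils down to adding a trailing feature axis and projecting it to the model dimension. A self-contained sketch of just that step (the standalone names below are illustrative, not part of the module):

import torch
import torch.nn as nn

# Illustrative stand-in for the new scalar-series handling in forward():
# a (batch, seq_len) input gains a trailing feature axis and is linearly
# projected to (batch, seq_len, feature_space_dim) before the transformer blocks.
feature_space_dim = 64
scalar_projection = nn.Linear(1, feature_space_dim)  # mirrors self.scalar_projection

x = torch.randn(16, 100)        # (batch, seq_len), a scalar time series
if x.ndim == 2:
    x = x.unsqueeze(-1)         # (batch, seq_len, 1)
    x = scalar_projection(x)    # (batch, seq_len, feature_space_dim)

print(x.shape)                  # torch.Size([16, 100, 64])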

tests/embedding_net_test.py

Lines changed: 26 additions & 2 deletions

@@ -260,7 +260,7 @@ def repeat_to_match_shape(x, input_shape):
 )
 @pytest.mark.parametrize("seq_length", (24, 13, 5))
 def test_transformer_embedding(config, seq_length):
-    net = TransformerEmbedding(config=config)
+    net = TransformerEmbedding(**config)
 
     def simulator(theta):
        x = MultivariateNormal(
@@ -291,7 +291,7 @@ def simulator(theta):
 )
 @pytest.mark.parametrize("img_shape", ((3, 32, 24), (3, 64, 64)))
 def test_transformer_vitembedding(config, img_shape):
-    net = TransformerEmbedding(config=config)
+    net = TransformerEmbedding(**config)
 
     def simulator(theta):
        x = MultivariateNormal(
@@ -311,6 +311,30 @@ def simulator(theta):
     _test_helper_embedding_net(prior, xo, simulator, net)
 
 
+@pytest.mark.parametrize("seq_length", (10, 20))
+def test_transformer_embedding_scalar_timeseries(seq_length):
+    net = TransformerEmbedding(
+        pos_emb="rotary",
+        feature_space_dim=32,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+        vit=False,
+        head_dim=None,
+        intermediate_size=64,
+        num_hidden_layers=2,
+        attention_dropout=0.1,
+    )
+
+    def simulator(theta):
+        batch = theta.shape[0]
+        return torch.randn(batch, seq_length) + theta[:, 0:1]
+
+    xo = torch.randn(1, seq_length)  # shape: (1, T)
+    prior = MultivariateNormal(torch.zeros(1), torch.eye(1))
+
+    _test_helper_embedding_net(prior, xo, simulator, net)
+
+
 def _test_helper_embedding_net(prior, xo, simulator, net):
     estimator_provider = posterior_nn(
        "mdn",
