@@ -628,91 +628,234 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
628628
629629
630630class TransformerEmbedding(nn.Module):
631-    def __init__(self, config):
631+ r"""
632+ Transformer-based embedding network for **time series** and **image** data.
633+
634+ This module provides a flexible embedding architecture that supports both
635+ (1) 1D / multivariate time series (e.g., experimental trials, temporal signals),
636+ and
637+ (2) image inputs via a lightweight Vision Transformer (ViT)-style patch embedding.
638+
639+ It is designed for simulation-based inference (SBI) workflows where raw
640+ observations must be encoded into fixed-dimensional embeddings before passing
641+ them to a neural density estimator.
642+
643+ Parameters
644+ ----------
645+ pos_emb :
646+ Positional embedding type. One of ``{"rotary", "positional", "none"}``.
647+ pos_emb_base :
648+ Base frequency for rotary positional embeddings.
649+ rms_norm_eps :
650+ Epsilon for RMSNorm layers.
651+ router_jitter_noise :
652+ Noise added when routing tokens to MoE experts.
653+ vit_dropout :
654+ Dropout applied inside ViT patch embedding layers.
655+ mlp_activation :
656+ Activation used inside the feedforward blocks.
657+ is_causal :
658+ If ``True``, applies a causal mask during attention (useful for time-series).
659+ vit :
660+ If ``True``, enables Vision Transformer mode for 2D image inputs.
661+ num_hidden_layers :
662+ Number of transformer encoder blocks.
663+ num_attention_heads :
664+ Number of self-attention heads.
665+ num_key_value_heads :
666+ Number of KV heads (for multi-query attention).
667+ intermediate_size :
668+ Hidden dimension of feedforward network (or MoE experts).
669+ ffn :
670+ Feedforward type. One of ``{"mlp", "moe"}``.
671+ head_dim :
672+ Per-head embedding dimension. If ``None``, inferred as
673+ ``feature_space_dim // num_attention_heads``.
674+ attention_dropout :
675+ Dropout used inside the attention mechanism.
676+ feature_space_dim :
677+ Dimensionality of the token embeddings flowing through the transformer.
678+ - For time-series, this is the model dimension.
679+ - For images (``vit=True``), this is the post-patch-projection embedding size.
680+ final_emb_dimension :
681+ Output embedding dimension. Defaults to ``feature_space_dim // 2``.
682+ image_size :
683+ Input image height/width (only if ``vit=True``).
684+ patch_size :
685+ ViT patch size (only if ``vit=True``).
686+ num_channels :
687+ Number of image channels for ViT mode.
688+ num_local_experts :
689+ Number of MoE experts (only relevant when ``ffn="moe"``).
690+ num_experts_per_tok :
691+ How many experts each token is routed to in MoE mode.
692+
693+ Notes
694+ -----
695+ **Time-series mode (``vit=False``)**
696+ - Inputs of shape ``(batch, seq_len)`` (scalar series) are automatically
697+ projected to ``(batch, seq_len, feature_space_dim)``.
698+ - Inputs of shape ``(batch, seq_len, features)`` are used as-is.
699+ - Causal masking is applied if ``is_causal=True`` (default).
700+ - Suitable for experimental trials, temporal dynamics, or sets of sequential
701+ observations.
702+
703+ **Image mode (``vit=True``)**
704+ - Inputs must have shape ``(batch, channels, height, width)``.
705+ - Images are patchified, linearly projected, and fed to the transformer.
706+ - Causal masking is disabled in this mode.
707+
708+ **Output**
709+ The embedding is obtained by selecting the final token and applying a linear
710+ projection, resulting in a tensor of shape:
711+
712+ ``(batch, final_emb_dimension)``
713+
714+ Example
715+ -------
716+ **1D time-series (default mode)**::
717+
718+ from sbi.neural_nets.embedding_nets import TransformerEmbedding
719+ import torch
720+
721+ x = torch.randn(16, 100) # (batch, seq_len)
722+ emb = TransformerEmbedding(feature_space_dim=64)
723+ z = emb(x)
724+
725+ **Image input (ViT-style)**::
726+
727+ from sbi.neural_nets.embedding_nets import TransformerEmbedding
728+ import torch
729+
730+ x = torch.randn(8, 3, 64, 64) # (batch, C, H, W)
731+ emb = TransformerEmbedding(
732+ vit=True,
733+ image_size=64,
734+ patch_size=8,
735+ num_channels=3,
736+ feature_space_dim=128,
737+ )
738+ z = emb(x)
739+ """
740+
741+    def __init__(
742+        self,
743+        *,
744+        pos_emb: str = "rotary",
745+        pos_emb_base: float = 10e4,
746+        rms_norm_eps: float = 1e-05,
747+        router_jitter_noise: float = 0.0,
748+        vit_dropout: float = 0.5,
749+        mlp_activation: str = "gelu",
750+        is_causal: bool = True,
751+        vit: bool = False,
752+        num_hidden_layers: int = 4,
753+        num_attention_heads: int = 12,
754+        num_key_value_heads: int = 12,
755+        intermediate_size: int = 256,
756+        ffn: str = "mlp",
757+        head_dim: Optional[int] = None,
758+        attention_dropout: float = 0.5,
759+        feature_space_dim: int,
760+        final_emb_dimension: Optional[int] = None,
761+        image_size: Optional[int] = None,
762+        patch_size: Optional[int] = None,
763+        num_channels: Optional[int] = None,
764+        num_local_experts: Optional[int] = None,
765+        num_experts_per_tok: Optional[int] = None,
766+    ):
632767        super().__init__()
633768 """
634769 Main class for constructing a transformer embedding
635- Basic configuration parameters:
636- pos_emb (string): position encoding to be used, currently available:
637- {"rotary", "positional", "none"}
638- pos_emb_base (float): base used to construct the positinal encoding
639- rms_norm_eps (float): noise added to the rms variance computation
640- ffn (string): feedforward layer after used after computing the attention:
641- {"mlp", "moe"}
642- mlp_activation (string): activation function to be used within the ffn
643- layer
644- is_causal (bool): specifies whether causal mask should be created
645- vit (bool): specifies the whether a convolutional layer should be used for
646- processing images, inspired by the vision transformer
647- num_hidden_layer (int): number of transformer blocks
648- num_attention_heads (int): number of attention heads
649- num_key_value_heads (int): number of key/value heads
650- feature_space_dim (int): dimension of the feature vectors
651- intermediate_size (int): hidden size of the feedforward layer
652- head_dim (int): dimension key/query vectors
653- attention_dropout (float): value for the dropout of the attention layer
770+
771+        Args:
772+            pos_emb: position encoding to be used, currently available:
773+                {"rotary", "positional", "none"}
774+            pos_emb_base: base used to construct the positional encoding
775+            rms_norm_eps: epsilon added to the RMS variance computation
776+            ffn: feedforward layer used after computing the attention:
777+                {"mlp", "moe"}
778+            mlp_activation: activation function to be used within the ffn
779+                layer
780+            is_causal: specifies whether a causal mask should be created
781+            vit: specifies whether a convolutional layer should be used for
782+                processing images, inspired by the Vision Transformer
783+            num_hidden_layers: number of transformer blocks
784+            num_attention_heads: number of attention heads
785+            num_key_value_heads: number of key/value heads
786+            feature_space_dim: dimension of the feature vectors
787+            intermediate_size: hidden size of the feedforward layer
788+            head_dim: dimension of the key/query vectors
789+            attention_dropout: value for the dropout of the attention layer
654790
655791 MoE:
656- router_jitter_noise (float) : noise added before routing the input vectors
657- to the experts
658- num_local_experts (int) : total number of experts
659- num_experts_per_tok (int) : number of experts each token is assigned to
792+ router_jitter_noise: noise added before routing the input vectors
793+ to the experts
794+ num_local_experts: total number of experts
795+ num_experts_per_tok: number of experts each token is assigned to
660796
661797 ViT
662- feature_space_dim (int): dimension of the feature vectors after
663- preprocessing the images
664- image_size (int): dimension of the squared image used to created
665- the positional encoders
666- a rectagular image can be used at training/inference time by
667- resampling the encoders
668- patch_size (int): size of the square patches used to create the
669- positional encoders
670- num_channels (int): number of channels of the input image
671- vit_dropout (float): value for the dropout of the attention layer
672- """
673-        self.config = {
674-            "pos_emb": "rotary",
675-            "pos_emb_base": 10e4,
676-            "rms_norm_eps": 1e-05,
677-            "router_jitter_noise": 0.0,
678-            "vit_dropout": 0.5,
679-            "mlp_activation": "gelu",
680-            "is_causal": True,
681-            "vit": False,
682-            "num_hidden_layers": 4,
683-            "num_attention_heads": 12,
684-            "num_key_value_heads": 12,
685-            "intermediate_size": 256,
686-            "ffn": "mlp",
687-            "head_dim": None,
688-            "attention_dropout": 0.5,
689-        }
798+            feature_space_dim: dimension of the feature vectors after
799+                preprocessing the images
800+            image_size: dimension of the square image used to create
801+                the positional encoders;
802+                a rectangular image can be used at training/inference time by
803+                resampling the encoders
804+            patch_size: size of the square patches used to create the
805+                positional encoders
806+            num_channels: number of channels of the input image
807+            vit_dropout: value for the dropout in the ViT patch embedding layers
808+        """
809+
810+        self.config = dict(
811+            pos_emb=pos_emb,
812+            pos_emb_base=pos_emb_base,
813+            rms_norm_eps=rms_norm_eps,
814+            router_jitter_noise=router_jitter_noise,
815+            vit_dropout=vit_dropout,
816+            mlp_activation=mlp_activation,
817+            is_causal=is_causal,
818+            vit=vit,
819+            num_hidden_layers=num_hidden_layers,
820+            num_attention_heads=num_attention_heads,
821+            num_key_value_heads=num_key_value_heads,
822+            intermediate_size=intermediate_size,
823+            ffn=ffn,
824+            head_dim=head_dim,
825+            attention_dropout=attention_dropout,
826+            feature_space_dim=feature_space_dim,
827+            image_size=image_size,
828+            patch_size=patch_size,
829+            num_channels=num_channels,
830+            num_local_experts=num_local_experts,
831+            num_experts_per_tok=num_experts_per_tok,
832+        )
690833
691-        self.config.update(config)
834+        self.preprocess = ViTEmbeddings(self.config) if vit else IdentityEncoder()
692835
693-        self.preprocess = (
694-            ViTEmbeddings(self.config) if self.config["vit"] else IdentityEncoder()
695-        )
836+        self._supports_scalar_series = not vit
837+        if self._supports_scalar_series:
838+            self.scalar_projection = nn.Linear(
839+                1, feature_space_dim
840+            )  # proj 1D → model dim
696841
697842        self.layers = nn.ModuleList([
698-            TransformerBlock(self.config)
699-            for _ in range(self.config["num_hidden_layers"])
843+            TransformerBlock(self.config) for _ in range(num_hidden_layers)
700844        ])
701-        self.is_causal = self.config["is_causal"] and not self.config["vit"]
845+        self.is_causal = is_causal and not vit
702846
703-        self.norm = RMSNorm(
704-            self.config["feature_space_dim"], eps=self.config["rms_norm_eps"]
705-        )
706-        final_emb_dimension = self.config.get(
707-            "final_emb_dimension", self.config["feature_space_dim"] // 2
708-        )
709-        if not config["vit"] and final_emb_dimension > self.config["feature_space_dim"]:
847+        self.norm = RMSNorm(feature_space_dim, eps=rms_norm_eps)
848+
849+        if final_emb_dimension is None:
850+            final_emb_dimension = feature_space_dim // 2
851+
852+        if not vit and final_emb_dimension > feature_space_dim:
710853            raise ValueError(
711-                "The final embedding dimension should be equal or smaller than "
712-                "the input dimension"
854+                "The final embedding dimension should be "
855+                "equal or smaller than the input dimension"
713856            )
714857        self.aggregator = nn.Linear(
715-            self.config["feature_space_dim"],
858+            feature_space_dim,
716859            final_emb_dimension,
717860        )
718861        self.causal_mask_cache_ = (None, None, None)
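
The hunk above replaces the single ``config`` dict argument with explicit
keyword-only parameters; ``feature_space_dim`` becomes the only required
argument. A minimal before/after sketch, using only parameters that appear in
this diff (the output shape follows the class docstring)::

    import torch
    from sbi.neural_nets.embedding_nets import TransformerEmbedding

    # old (removed): settings merged into an internal defaults dict
    # emb = TransformerEmbedding({"feature_space_dim": 64, "is_causal": True})

    # new (added): keyword-only arguments; feature_space_dim is required
    emb = TransformerEmbedding(feature_space_dim=64, is_causal=True)

    x = torch.randn(16, 100)  # (batch, seq_len) scalar time series
    z = emb(x)  # (16, 32) per the class docstring: final_emb_dimension = 64 // 2
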
@@ -764,21 +907,26 @@ def forward(
764907    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
765908 """
766909 Args:
767- input (`torch.Tensor`): input of shape `(batch, seq_len,
768- feature_space_dim)`
769- or `(batch, num_channels, height, width)` if using ViT
770- attention_mask (`torch.Tensor`, *optional*) :
910+ input:
911+ input of shape `(batch, seq_len, feature_space_dim)`
912+ or `(batch, num_channels, height, width)` if using ViT
913+ attention_mask:
771914 attention mask of size `(batch_size, sequence_length)`
772- output_attentions (`bool`, *optional*) :
915+ output_attentions:
773916 Whether or not to return the attention tensors
774- cache_attention_mask (`bool`, *optional*) :
917+ cache_attention_mask:
775918 Whether or not to cache the expanded attention mask, useful if using
776919 multiple batches with identical input shapes
777- kwargs (`dict`, *optional*) :
920+ kwargs:
778921 Arbitrary kwargs
779922 """
780923
781924        input = self.preprocess(input)
925+
926+        if self._supports_scalar_series and input.ndim == 2:
927+            input = input.unsqueeze(-1)  # (B, T, 1)
928+            input = self.scalar_projection(input)  # (B, T, feature_space_dim)
929+
782930        if self.is_causal:
783931            dtype, device = input.dtype, input.device
784932
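
The class docstring describes this module as an encoder for SBI workflows, where
the embedding feeds a neural density estimator. A minimal sketch of that wiring,
assuming the ``posterior_nn(..., embedding_net=...)`` factory and the ``SNPE``
interface of current sbi releases (import paths may differ between versions)::

    import torch
    from sbi.inference import SNPE                # assumed inference class
    from sbi.neural_nets import posterior_nn      # assumed factory location
    from sbi.neural_nets.embedding_nets import TransformerEmbedding
    from sbi.utils import BoxUniform

    prior = BoxUniform(low=-torch.ones(3), high=torch.ones(3))

    # Encode raw (batch, seq_len) observations into (batch, 32) summaries
    # before they reach the normalizing flow.
    embedding_net = TransformerEmbedding(feature_space_dim=64, final_emb_dimension=32)

    density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)
    inference = SNPE(prior=prior, density_estimator=density_estimator)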