@@ -321,11 +321,11 @@ class _EmformerLayer(torch.nn.Module):
         input_dim (int): input dimension.
         num_heads (int): number of attention heads.
         ffn_dim: (int): hidden layer dimension of feedforward network.
+        segment_length (int): length of each input segment.
         dropout (float, optional): dropout probability. (Default: 0.0)
         activation (str, optional): activation function to use in feedforward network.
             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
         left_context_length (int, optional): length of left context. (Default: 0)
-        segment_length (int, optional): length of each input segment. (Default: 128)
         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
         weight_init_gain (float or None, optional): scale factor to apply when initializing
             attention module parameters. (Default: ``None``)
@@ -338,10 +338,10 @@ def __init__(
         input_dim: int,
         num_heads: int,
         ffn_dim: int,
+        segment_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,
-        segment_length: int = 128,
         max_memory_size: int = 0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
@@ -386,9 +386,7 @@ def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[t
         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
         return [empty_memory, left_context_key, left_context_val, past_length]
 
-    def _unpack_state(
-        self, utterance: torch.Tensor, mems: torch.Tensor, state: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         past_length = state[3][0][0].item()
         past_left_context_length = min(self.left_context_length, past_length)
         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
@@ -474,7 +472,7 @@ def _apply_attention_infer(
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_mems, lc_key, lc_val = self._unpack_state(utterance, mems, state)
+        pre_mems, lc_key, lc_val = self._unpack_state(state)
         if self.use_mem:
             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             summary = summary[:1]
@@ -652,10 +650,10 @@ def __init__(
                     input_dim,
                     num_heads,
                     ffn_dim,
+                    segment_length,
                     dropout=dropout,
                     activation=activation,
                     left_context_length=left_context_length,
-                    segment_length=segment_length,
                     max_memory_size=max_memory_size,
                     weight_init_gain=weight_init_gains[layer_idx],
                     tanh_on_mem=tanh_on_mem,
@@ -718,7 +716,7 @@ def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) ->
         return col_widths
 
     def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
-        utterance_length, batch_size, _ = input.shape
+        utterance_length = input.size(0)
         num_segs = math.ceil(utterance_length / self.segment_length)
 
         rc_mask = []
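
Note: with this change, segment_length becomes a required positional argument of _EmformerLayer instead of a keyword argument defaulting to 128, and _unpack_state no longer takes the unused utterance and mems tensors. A minimal construction sketch follows; the import path and all hyperparameter values are illustrative assumptions, not part of this diff.

# Minimal sketch of constructing the layer after this change.
# _EmformerLayer is an internal class; the module path below is an assumption.
from torchaudio.models.emformer import _EmformerLayer

layer = _EmformerLayer(
    input_dim=512,        # illustrative model dimension
    num_heads=8,          # illustrative number of attention heads
    ffn_dim=2048,         # illustrative feedforward width
    segment_length=16,    # now required; previously a keyword argument defaulting to 128
    dropout=0.1,
    left_context_length=30,
    max_memory_size=0,
)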