@@ -321,11 +321,11 @@ class _EmformerLayer(torch.nn.Module):
         input_dim (int): input dimension.
         num_heads (int): number of attention heads.
         ffn_dim (int): hidden layer dimension of feedforward network.
+        segment_length (int): length of each input segment.
         dropout (float, optional): dropout probability. (Default: 0.0)
         activation (str, optional): activation function to use in feedforward network.
             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
         left_context_length (int, optional): length of left context. (Default: 0)
-        segment_length (int, optional): length of each input segment. (Default: 128)
         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
         weight_init_gain (float or None, optional): scale factor to apply when initializing
             attention module parameters. (Default: ``None``)
@@ -338,10 +338,10 @@ def __init__(
         input_dim: int,
         num_heads: int,
         ffn_dim: int,
+        segment_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,
-        segment_length: int = 128,
         max_memory_size: int = 0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
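
With `segment_length` promoted to a required argument, any direct construction of `_EmformerLayer` must now supply it explicitly; the remaining keyword defaults shown above are unchanged. A minimal sketch of the new calling convention, assuming the private module path `torchaudio.models.emformer` and purely illustrative sizes:

from torchaudio.models.emformer import _EmformerLayer  # private module; path assumed

# segment_length (16 here, illustrative) must now be passed explicitly;
# omitting it raises a TypeError instead of silently defaulting to 128.
layer = _EmformerLayer(
    input_dim=512,
    num_heads=8,
    ffn_dim=2048,
    segment_length=16,
    left_context_length=30,
    max_memory_size=4,
)
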
@@ -386,9 +386,7 @@ def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[t
         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
         return [empty_memory, left_context_key, left_context_val, past_length]

-    def _unpack_state(
-        self, utterance: torch.Tensor, mems: torch.Tensor, state: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         past_length = state[3][0][0].item()
         past_left_context_length = min(self.left_context_length, past_length)
         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
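
The narrower `_unpack_state` signature is possible because the method only reads `state` and the layer's own hyperparameters; the `utterance` and `mems` arguments were never used. A standalone sketch of that bookkeeping with hypothetical sizes, assuming the `[memory, left_context_key, left_context_val, past_length]` layout produced by `_init_state`:

import math
import torch

# Hypothetical hyperparameters standing in for the layer's attributes.
max_memory_size, left_context_length, segment_length = 4, 30, 16
batch_size, input_dim = 2, 8

# State layout mirroring _init_state: memory bank, left-context keys/values,
# and a counter of frames processed so far.
state = [
    torch.zeros(max_memory_size, batch_size, input_dim),
    torch.zeros(left_context_length, batch_size, input_dim),
    torch.zeros(left_context_length, batch_size, input_dim),
    torch.zeros(1, batch_size, dtype=torch.int32),
]

# The quantities computed in the hunk above depend only on `state` and the
# hyperparameters, which is what lets the utterance/mems parameters go away.
past_length = state[3][0][0].item()
past_left_context_length = min(left_context_length, past_length)
past_mem_length = min(max_memory_size, math.ceil(past_length / segment_length))

# Presumably the method then slices the valid tail of each cached tensor
# (an assumption about the lines that follow the hunk).
pre_mems = state[0][max_memory_size - past_mem_length:]
lc_key = state[1][left_context_length - past_left_context_length:]
lc_val = state[2][left_context_length - past_left_context_length:]
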
@@ -474,7 +472,7 @@ def _apply_attention_infer(
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_mems, lc_key, lc_val = self._unpack_state(utterance, mems, state)
+        pre_mems, lc_key, lc_val = self._unpack_state(state)
         if self.use_mem:
             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             summary = summary[:1]
@@ -652,10 +650,10 @@ def __init__(
                     input_dim,
                     num_heads,
                     ffn_dim,
+                    segment_length,
                     dropout=dropout,
                     activation=activation,
                     left_context_length=left_context_length,
-                    segment_length=segment_length,
                     max_memory_size=max_memory_size,
                     weight_init_gain=weight_init_gains[layer_idx],
                     tanh_on_mem=tanh_on_mem,
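
The wrapper that builds the layer stack forwards the new positional argument straight through, so callers see the same requirement. A usage sketch along the lines of the documented `torchaudio.models.Emformer` example, assuming the public class exposes the same `(input_dim, num_heads, ffn_dim, num_layers, segment_length, ...)` ordering:

import torch
from torchaudio.models import Emformer

# segment_length (the fifth positional argument, 4 here) is now mandatory.
emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)

# Non-streaming forward pass over a padded batch.
input = torch.rand(128, 400, 512)          # (batch, num_frames, feature_dim)
lengths = torch.randint(1, 200, (128,))
output, output_lengths = emformer(input, lengths)

# Streaming inference consumes one segment plus its right context per call.
segment = torch.rand(128, 5, 512)          # 4 segment frames + 1 right-context frame
segment_lengths = torch.full((128,), 5)
output, output_lengths, states = emformer.infer(segment, segment_lengths, None)
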
@@ -718,7 +716,7 @@ def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) ->
         return col_widths

     def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
-        utterance_length, batch_size, _ = input.shape
+        utterance_length = input.size(0)
         num_segs = math.ceil(utterance_length / self.segment_length)

         rc_mask = []
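
Only the time dimension feeds the mask construction, so `input.size(0)` is sufficient and the unused `batch_size` unpacking can go. A tiny worked example of the segment count it derives, with hypothetical sizes:

import math

segment_length = 16      # hypothetical layer setting
utterance_length = 35    # e.g. input.size(0) for a (T, B, D) input

# 35 frames fall into ceil(35 / 16) = 3 segments of sizes 16, 16 and 3.
num_segs = math.ceil(utterance_length / segment_length)
assert num_segs == 3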