
Commit dc392e5
Apply minor fixes to Emformer implementation (#2252)
Summary: Noticed some items to clean up in `Emformer`.

- Make `segment_length` a required argument in `_EmformerLayer`.
- Remove unused variables from `_unpack_state` and `_gen_attention_mask`.

These don't affect `Emformer`'s functionality or public API.

Pull Request resolved: #2252

Reviewed By: carolineechen, mthrok

Differential Revision: D34321430

Pulled By: hwangjeff

fbshipit-source-id: 38a5046f633a3e625352c476ef71c78380ccc597
1 parent e9881b9 commit dc392e5
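
Since the summary notes that `Emformer`'s public API is untouched, existing user code keeps working. A minimal sketch of that public usage, mirroring the example in the torchaudio documentation (the concrete sizes below are illustrative assumptions):

import torch
from torchaudio.models import Emformer

# 20-layer Emformer over 512-dim features, 8 attention heads, segment length 4,
# with one frame of right context (values are illustrative).
emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)

input = torch.rand(128, 400, 512)          # (batch, time + right context, feature)
lengths = torch.randint(1, 200, (128,))    # valid frames per sequence
output, output_lengths = emformer(input, lengths)
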

File tree: 1 file changed (+6, −8)


torchaudio/models/emformer.py (+6, −8)
@@ -321,11 +321,11 @@ class _EmformerLayer(torch.nn.Module):
         input_dim (int): input dimension.
         num_heads (int): number of attention heads.
         ffn_dim: (int): hidden layer dimension of feedforward network.
+        segment_length (int): length of each input segment.
         dropout (float, optional): dropout probability. (Default: 0.0)
         activation (str, optional): activation function to use in feedforward network.
             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
         left_context_length (int, optional): length of left context. (Default: 0)
-        segment_length (int, optional): length of each input segment. (Default: 128)
         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
         weight_init_gain (float or None, optional): scale factor to apply when initializing
             attention module parameters. (Default: ``None``)
@@ -338,10 +338,10 @@ def __init__(
         input_dim: int,
         num_heads: int,
         ffn_dim: int,
+        segment_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,
-        segment_length: int = 128,
         max_memory_size: int = 0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
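
For illustration, this is how the internal layer would now be constructed, with `segment_length` passed explicitly instead of relying on the old default of 128. `_EmformerLayer` is a private class normally built only by `Emformer`; the import path and concrete values below are assumptions for the sketch:

from torchaudio.models.emformer import _EmformerLayer  # private class; import path assumed from the file above

layer = _EmformerLayer(
    512,   # input_dim
    8,     # num_heads
    2048,  # ffn_dim
    4,     # segment_length -- now required; previously defaulted to 128
    dropout=0.1,
    left_context_length=30,
    max_memory_size=0,
)
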
@@ -386,9 +386,7 @@ def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[t
         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
         return [empty_memory, left_context_key, left_context_val, past_length]
 
-    def _unpack_state(
-        self, utterance: torch.Tensor, mems: torch.Tensor, state: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         past_length = state[3][0][0].item()
         past_left_context_length = min(self.left_context_length, past_length)
         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
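
The narrower signature works because the method reads everything it needs from the cached `state` tensors and the layer's own attributes; `utterance` and `mems` were never used. A condensed sketch of the pattern as a method body (the slicing and return at the end are assumptions for completeness, not lines from this commit):

import math
from typing import List, Tuple
import torch

def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # state = [memory, left_context_key, left_context_val, past_length]
    past_length = state[3][0][0].item()
    past_left_context_length = min(self.left_context_length, past_length)
    past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
    # Keep only the valid tail of each cache (assumed layout: caches are padded at the front).
    pre_mems = state[0][self.max_memory_size - past_mem_length:]
    lc_key = state[1][self.left_context_length - past_left_context_length:]
    lc_val = state[2][self.left_context_length - past_left_context_length:]
    return pre_mems, lc_key, lc_val
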
@@ -474,7 +472,7 @@ def _apply_attention_infer(
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_mems, lc_key, lc_val = self._unpack_state(utterance, mems, state)
+        pre_mems, lc_key, lc_val = self._unpack_state(state)
         if self.use_mem:
             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             summary = summary[:1]
@@ -652,10 +650,10 @@ def __init__(
                     input_dim,
                     num_heads,
                     ffn_dim,
+                    segment_length,
                     dropout=dropout,
                     activation=activation,
                     left_context_length=left_context_length,
-                    segment_length=segment_length,
                     max_memory_size=max_memory_size,
                     weight_init_gain=weight_init_gains[layer_idx],
                     tanh_on_mem=tanh_on_mem,
@@ -718,7 +716,7 @@ def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) ->
         return col_widths
 
     def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
-        utterance_length, batch_size, _ = input.shape
+        utterance_length = input.size(0)
         num_segs = math.ceil(utterance_length / self.segment_length)
 
         rc_mask = []
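
Only the time dimension is needed to build the mask, so reading `input.size(0)` avoids unpacking an unused `batch_size`. A quick check of the equivalence, assuming the (time, batch, feature) layout that Emformer uses internally:

import torch

x = torch.rand(400, 8, 512)                # (time, batch, feature), internal layout
utterance_length, batch_size, _ = x.shape  # old code: batch_size was never used
assert utterance_length == x.size(0)       # new code reads only the needed dimension
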
