Skip to content

Commit 0171696

Browse files
committed
Minor changes
1 parent b795329 commit 0171696

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

oml/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from oml.models.audio.ecapa_tdnn.extractor import ECAPATDNNExtractor
12
from oml.models.meta.projection import ExtractorWithMLP
23
from oml.models.meta.siamese import (
34
ConcatSiamese,

oml/models/audio/ecapa_tdnn/extractor.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@ def feat_dim(self) -> int:
8282
return self.model.fc6.weight.shape[0]
8383

8484
def forward(self, x: torch.Tensor) -> torch.Tensor:
85-
assert x.ndim == 2, "The model expects input audio to have shape (batch_size, n_samples)"
85+
assert x.ndim == 2 or (
86+
x.ndim == 3 and x.shape[1] == 1
87+
), "The model expects input audio to have shape (batch_size, n_samples) or (batch_size, 1, n_samples)"
88+
89+
if x.ndim == 3:
90+
x = x.squeeze(1)
8691

8792
x = self.model.forward(x, aug=False)
8893
if self.normalise_features:

oml/registry/models.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch import nn
55

66
from oml.interfaces.models import IExtractor, IPairwiseModel
7+
from oml.models.audio.ecapa_tdnn.extractor import ECAPATDNNExtractor
78
from oml.models.meta.projection import ExtractorWithMLP
89
from oml.models.meta.siamese import (
910
ConcatSiamese,
@@ -22,6 +23,7 @@
2223
"vit_clip": ViTCLIPExtractor,
2324
"vit_unicom": ViTUnicomExtractor,
2425
"extractor_with_mlp": ExtractorWithMLP,
26+
"ecapa_tdnn": ECAPATDNNExtractor,
2527
}
2628

2729
PAIRWISE_MODELS_REGISTRY = {

0 commit comments

Comments
 (0)