Skip to content

Commit de6d0d0

Browse files
committed
Fixed TypeError in inner class
1 parent cd635b9 commit de6d0d0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchtune/modules/position_embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, Optional
7+
from typing import Any, Optional, Union
88

99
import torch
1010
from torch import nn
@@ -349,7 +349,7 @@ def __init__(self, dim_model: int, kdim: int, hidden_size: int) -> None:
349349
)
350350

351351
# concave function to amplify differences among local positions
352-
def phi(self, c: nn.Parameter, x: int | torch.Tensor) -> torch.Tensor:
352+
def phi(self, c: nn.Parameter, x: Union[int, torch.Tensor]) -> torch.Tensor:
353353
return torch.log1p(c * x)
354354

355355
def forward(self, src: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)