
Commit 59b3d86

Merge pull request #2124 from fzyzcjy/patch-1
Fix super tiny type error
2 parents: 67b0b3d + b44e4e4

7 files changed: +15 -9 lines
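The change is the same across all seven files: each scheduler's `_get_lr` computes one learning rate per optimizer param group, so the accurate return annotation is `List[float]`, not the `float` the base class previously declared. A minimal sketch of the pattern being annotated (the `SchedulerBase` stand-in and the `ToyScheduler` halving rule are hypothetical, for illustration only):

import abc
from typing import List


class SchedulerBase(abc.ABC):
    """Minimal stand-in for timm's Scheduler base class."""

    def __init__(self, base_values: List[float]):
        # one base LR per optimizer param group
        self.base_values = base_values

    @abc.abstractmethod
    def _get_lr(self, t: int) -> List[float]:
        # the fix: one LR per param group, hence List[float], not float
        ...


class ToyScheduler(SchedulerBase):
    """Hypothetical subclass: halves every LR at each step t."""

    def _get_lr(self, t: int) -> List[float]:
        return [v * 0.5 ** t for v in self.base_values]


print(ToyScheduler([0.1, 0.01])._get_lr(2))  # [0.025, 0.0025]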

timm/scheduler/cosine_lr.py

+2 -1

@@ -8,6 +8,7 @@
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -77,7 +78,7 @@ def __init__(
         else:
            self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
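The warmup branch visible in this hunk interpolates linearly from `warmup_lr_init` toward each base LR. A standalone sketch of that arithmetic (values are made up; `warmup_steps` is computed here as the per-step increment, matching how the diff's `__init__` context appears to use it):

# Standalone sketch of the warmup arithmetic in the hunk above
# (illustrative values; timm precomputes warmup_steps in __init__).
warmup_t = 5                 # number of warmup steps
warmup_lr_init = 1e-6        # LR at t = 0
base_values = [0.1, 0.01]    # one target LR per param group

# per-step increment toward each base value
warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]

for t in range(warmup_t):
    lrs = [warmup_lr_init + t * s for s in warmup_steps]
    print(t, lrs)  # a List[float], one entry per param group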

timm/scheduler/multistep_lr.py

+1 -1

@@ -53,7 +53,7 @@ def get_curr_decay_steps(self, t):
         # assumes self.decay_t is sorted
         return bisect.bisect_right(self.decay_t, t + 1)
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

timm/scheduler/plateau_lr.py

+2 -1

@@ -5,6 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -106,5 +107,5 @@ def _apply_noise(self, epoch):
                 param_group['lr'] = new_lr
             self.restore_lr = restore_lr
 
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         assert False, 'should not be called as step is overridden'

timm/scheduler/poly_lr.py

+2 -1

@@ -6,6 +6,7 @@
 """
 import math
 import logging
+from typing import List
 
 import torch
 
@@ -73,7 +74,7 @@ def __init__(
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

timm/scheduler/scheduler.py

+3 -3

@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 
@@ -65,10 +65,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.__dict__.update(state_dict)
 
     @abc.abstractmethod
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         pass
 
-    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
+    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
         proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
         if not proceed:
             return None
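Downstream, `_get_values` hands that list to the optimizer update, pairing each entry with a param group. Roughly, as a simplified sketch of the assumed update path (not timm's exact code, which also broadcasts scalar values and handles LR noise):

# Simplified sketch: write one LR from the list into each param group.
def update_groups(optimizer, lrs):
    for param_group, lr in zip(optimizer.param_groups, lrs):
        param_group['lr'] = lr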

timm/scheduler/step_lr.py

+3 -1

@@ -6,6 +6,8 @@
 """
 import math
 import torch
+from typing import List
+
 
 from .scheduler import Scheduler
 
@@ -51,7 +53,7 @@ def __init__(
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

timm/scheduler/tanh_lr.py

+2 -1

@@ -8,6 +8,7 @@
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -75,7 +76,7 @@ def __init__(
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
