-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathweight_methods.py
119 lines (103 loc) · 3.73 KB
/
weight_methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from abc import abstractmethod
from typing import Dict, List, Tuple, Union
import torch
class WeightMethod:
    """Base class for multi-task loss-weighting strategies.

    Subclasses implement :meth:`get_weighted_loss`, which folds the vector of
    per-task losses into one scalar (plus a dict of diagnostics); the base
    class then provides the autograd plumbing in :meth:`backward`.
    """

    def __init__(self, n_tasks: int, device: torch.device):
        super().__init__()
        # Number of tasks whose losses will be combined.
        self.n_tasks = n_tasks
        # Device on which any method-owned buffers should live.
        self.device = device

    @abstractmethod
    def get_weighted_loss(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ],
        last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor],
        representation: Union[torch.nn.parameter.Parameter, torch.Tensor],
        **kwargs,
    ):
        """Combine per-task ``losses`` into ``(scalar_loss, extra_outputs)``."""
        pass

    def backward(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        last_shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        representation: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, None], Union[dict, None]]:
        """Compute the weighted loss, run autograd on it, and return both
        the scalar loss and the method's extra outputs."""
        total, extras = self.get_weighted_loss(
            losses=losses,
            shared_parameters=shared_parameters,
            task_specific_parameters=task_specific_parameters,
            last_shared_parameters=last_shared_parameters,
            representation=representation,
            **kwargs,
        )
        total.backward()
        return total, extras

    def __call__(
        self,
        losses: torch.Tensor,
        shared_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        task_specific_parameters: Union[
            List[torch.nn.parameter.Parameter], torch.Tensor
        ] = None,
        **kwargs,
    ):
        """Convenience entry point; delegates straight to :meth:`backward`."""
        return self.backward(
            losses=losses,
            shared_parameters=shared_parameters,
            task_specific_parameters=task_specific_parameters,
            **kwargs,
        )

    def parameters(self) -> List[torch.Tensor]:
        """Learnable parameters owned by the weighting method (none by default)."""
        return []
class EqualWeight(WeightMethod):
    """Fixed (non-learned) task weighting.

    When ``task_weights`` is omitted, every task gets weight 1, i.e. a plain
    unweighted sum of the per-task losses.
    """

    def __init__(
        self,
        n_tasks: int,
        device: torch.device,
        task_weights: Union[List[float], torch.Tensor] = None,
    ):
        super().__init__(n_tasks, device=device)
        if task_weights is None:
            weights = torch.ones((n_tasks,))
        elif isinstance(task_weights, torch.Tensor):
            weights = task_weights
        else:
            weights = torch.tensor(task_weights)
        # One weight per task, stored on the target device.
        assert len(weights) == n_tasks
        self.task_weights = weights.to(device)

    def get_weighted_loss(self, losses, **kwargs):
        """Return the weighted sum of ``losses`` and the weights used."""
        weighted_sum = (losses * self.task_weights).sum()
        return weighted_sum, dict(weights=self.task_weights)
class WeightMethods:
    """Factory/wrapper that instantiates a weighting method by name and
    forwards the ``WeightMethod`` interface to it.

    :param method: key into the module-level ``METHODS`` registry
        (e.g. ``"equalweight"``).
    :param n_tasks: number of tasks whose losses are combined.
    :param device: torch device passed through to the underlying method.
    :param kwargs: extra constructor arguments for the chosen method.
    """

    def __init__(self, method: str, n_tasks: int, device: torch.device, **kwargs):
        assert method in METHODS, f"unknown method {method}."
        self.method = METHODS[method](n_tasks=n_tasks, device=device, **kwargs)

    def get_weighted_loss(self, losses, **kwargs):
        """Delegate to the wrapped method's ``get_weighted_loss``."""
        return self.method.get_weighted_loss(losses, **kwargs)

    def backward(
        self, losses, **kwargs
    ) -> Tuple[Union[torch.Tensor, None], Union[Dict, None]]:
        """Delegate to the wrapped method's ``backward`` (runs autograd)."""
        return self.method.backward(losses, **kwargs)

    def __call__(self, losses, **kwargs):
        # BUG FIX: this hook was originally named ``__ceil__`` (the math.ceil
        # protocol), so calling a WeightMethods instance raised TypeError
        # instead of running backward() as clearly intended.
        return self.backward(losses, **kwargs)

    def __ceil__(self, losses, **kwargs):
        # Kept only for backward compatibility with the original misnamed hook.
        return self.backward(losses, **kwargs)

    def parameters(self):
        """Learnable parameters of the wrapped weighting method."""
        return self.method.parameters()
# Registry mapping a method name (as accepted by WeightMethods) to the
# class implementing it.
METHODS = {
    "equalweight": EqualWeight,
}