-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutilize_prototypemodel.py
71 lines (57 loc) · 2.06 KB
/
utilize_prototypemodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
class QModel2:
    """Thin wrapper around a Q-function callable taking (state, action).

    No RNG seeding is performed; the wrapped callable is invoked as-is.
    """

    def __init__(self, func):
        # Callable evaluated as func(state, action).
        self.q_model = func

    def get_q_prediction(self, state, action):
        """Return the wrapped callable's output for (state, action)."""
        predictor = self.q_model
        return predictor(state, action)
class QModel:
    """Wrap a Q-function ``func(state, action, mediator)`` with a fixed seed.

    Before every prediction the *global* NumPy RNG is reset to ``seed``, so
    any randomness drawn inside ``func`` repeats identically on each call.
    Note this also resets the RNG state seen by any other code using
    ``np.random``.
    """

    def __init__(self, func, seed = 1):
        # Callable evaluated as func(state, action, mediator).
        self.q_model = func
        # Seed passed to np.random.seed before each prediction.
        self.seed = seed

    def get_q_prediction(self, state, action, mediator):
        """Reseed the global RNG, then return func(state, action, mediator)."""
        seed_value = self.seed
        np.random.seed(seed_value)
        q_value = self.q_model(state, action, mediator)
        return q_value
class PAModel:
    """Wrap a callable ``func(state, action)`` with a fixed global RNG seed.

    Presumably models an action probability given the state (name "pa") —
    TODO confirm against callers. Every prediction first resets the global
    NumPy RNG to ``seed``.
    """

    def __init__(self, func, seed = 1):
        # Callable evaluated as func(state, action).
        self.pa_model = func
        # Seed passed to np.random.seed before each prediction.
        self.seed = seed

    def get_pa_prediction(self, state, action):
        """Reseed the global RNG, then return func(state, action)."""
        seed_value = self.seed
        np.random.seed(seed_value)
        prediction = self.pa_model(state, action)
        return prediction

    def get_pa_ratio(self, state, policy_action, action):
        """Return prediction(state, policy_action) / prediction(state, action).

        No guard against a zero denominator: the result follows the
        division semantics of whatever type the callable returns.
        """
        # The generator evaluates in order: policy_action first, then action.
        numerator, denominator = (
            self.get_pa_prediction(state, chosen)
            for chosen in (policy_action, action)
        )
        return numerator / denominator
class PMModel:
    """Wrap a callable ``func(state, action, mediator)`` with a fixed seed.

    Each prediction first resets the *global* NumPy RNG to ``seed``. If
    ``false_func`` is given, it is evaluated instead of ``func`` for every
    prediction (presumably to simulate a misspecified mediator model —
    TODO confirm). With ``ratio_noise=True``, ``get_pm_ratio`` perturbs the
    ratio with Gaussian noise (sd 0.05) and clips it to [0.01, 100].
    """

    def __init__(self, func, seed = 1, ratio_noise=False, false_func=None):
        # Callable evaluated as func(state, action, mediator).
        self.pm_model = func
        # Whether get_pm_ratio adds clipped Gaussian noise to the ratio.
        self.ratio_noise = ratio_noise
        # Optional replacement callable; when not None it overrides pm_model.
        self.false_model = false_func
        # Seed passed to np.random.seed before each prediction.
        self.seed = seed

    def get_pm_prediction(self, state, action, mediator):
        """Reseed the global RNG, then evaluate the active model.

        Returns false_model(state, action, mediator) when a false model was
        supplied, otherwise pm_model(state, action, mediator).
        """
        np.random.seed(self.seed)
        model = self.pm_model if self.false_model is None else self.false_model
        return model(state, action, mediator)

    def get_pm_ratio(self, state, policy_action, action, mediator):
        """Return prediction(policy_action) / prediction(action).

        When ``ratio_noise`` is set, one N(0, 0.05**2) draw is added per
        element of the ratio and the result is clipped to [0.01, 100].
        Works for scalar and N-dimensional ratios. A zero denominator is
        not guarded: it yields inf/nan, which clipping then bounds only in
        the noisy branch.
        """
        pm1 = self.get_pm_prediction(state, policy_action, mediator)
        pm2 = self.get_pm_prediction(state, action, mediator)
        pm_ratio = pm1 / pm2
        if self.ratio_noise:
            # np.shape covers scalars (shape ()) and multi-dimensional
            # ratios. `.shape[0]` crashed on scalars and mis-sized noise for
            # N-d input. For 1-D ratios the draws are identical to size=n.
            noise = np.random.normal(scale=0.05, size=np.shape(pm_ratio))
            # Noise can push the ratio to <= 0, so keep it in a sane range.
            pm_ratio = np.clip(pm_ratio + noise, a_min=0.01, a_max=100)
        return pm_ratio
class RatioModel:
    """Wrap a callable ``func(state)`` with a fixed global RNG seed.

    Every prediction first resets the global NumPy RNG to ``seed``, so any
    randomness inside ``func`` repeats identically across calls.
    """

    def __init__(self, func, seed=1):
        # Callable evaluated as func(state).
        self.r_model = func
        # Seed passed to np.random.seed before each prediction.
        self.seed = seed

    def get_r_prediction(self, state):
        """Reseed the global RNG, then return func(state)."""
        seed_value = self.seed
        np.random.seed(seed_value)
        result = self.r_model(state)
        return result