Skip to content

Commit 9e965d0

Browse files
authored
Merge pull request #110 from ECP-CANDLE/RylieWeaver9-HPO
learning rate base 10 functionality in HPO
2 parents 07711eb + 1427b69 commit 9e965d0

File tree

1 file changed

+185
-122
lines changed

1 file changed

+185
-122
lines changed

workflows/common/python/ga_utils.py

+185-122
Original file line numberDiff line numberDiff line change
@@ -5,98 +5,190 @@
55
import random
66
import sys
77

8+
"""
9+
This script contains the hyperparameter parsing, mutation, and random draw logic for the genetic algorithm
10+
(GA) hyperparameter optimization using deap. The params list are created, then the Hyperparemeters are
11+
parsed from a JSON file based on their class. Default sigma values for mutation are given but can be
12+
provided in the JSON file. The mutation function is defined specially for each parameter type to not corrupt
13+
data types. The float parameter also offers a log 10 random draw and mutation functionality.
814
15+
Note that there are both parameter types and element types, which are not always the same. For example,There
16+
could be floats in a categorical parameter.
17+
"""
18+
19+
20+
"""Setup:"""
21+
22+
# import logging
23+
# logging.basicConfig()
24+
# log = logging.getLogger("a")
25+
# global log
26+
27+
# Functionality for boolean hyperparameters
28+
def str_to_bool(s):
29+
if s.lower() == "true":
30+
return True
31+
else:
32+
return False
33+
34+
# Parse function to determine if something is a number
935
def is_number(s):
1036
try:
1137
float(s)
1238
return True
1339
except ValueError:
1440
return False
41+
42+
# Create parameters from JSON file (main functionality)
43+
def create_parameters(param_file):
44+
with open(param_file) as json_file:
45+
data = json.load(json_file)
1546

47+
params = []
48+
for item in data:
49+
name = item["name"]
50+
type = item["type"]
51+
52+
if type == "int":
53+
lower = int(item["lower"])
54+
upper = int(item["upper"])
55+
# Allow for optional sigma
56+
sigma = None if "sigma" not in item else int(item["sigma"])
57+
params.append(IntParameter(name, lower, upper, sigma))
58+
59+
elif type == "float":
60+
lower = float(item["lower"])
61+
upper = float(item["upper"])
62+
# Allow for optional sigma
63+
sigma = None if "sigma" not in item else float(item["sigma"])
64+
params.append(FloatParameter(name, lower, upper, sigma))
65+
66+
elif type == "categorical":
67+
values = item["values"]
68+
element_type = item["element_type"]
69+
params.append(CategoricalParameter(name, values, element_type))
1670

17-
class ConstantParameter(object):
71+
elif type == "ordered":
72+
values = item["values"]
73+
element_type = item["element_type"]
74+
# Allow for optional sigma
75+
sigma = None if "sigma" not in item else item["sigma"]
76+
params.append(OrderedParameter(name, values, sigma, element_type))
1877

19-
def __init__(self, name, value):
20-
self.name = name
21-
self.value = value
78+
elif type == "logical":
79+
params.append(LogicalParameter(name))
2280

23-
def randomDraw(self):
24-
return self.value
81+
elif type == "constant":
82+
values = item["value"]
83+
params.append(ConstantParameter(name, values))
2584

26-
def mutate(self, x, mu, indpb):
27-
return self.value
85+
return params
2886

29-
def parse(self, s):
30-
if is_number(s):
31-
if "." in s or "e" in s:
32-
return float(s)
33-
return int(s)
34-
return s
3587

88+
"""Numeric Parameters:"""
3689

90+
# Numeric parameter superclass (int or float)
3791
class NumericParameter(object):
3892

39-
def __init__(self, name, lower, upper, sigma):
93+
def __init__(self, name, lower, upper, sigma=None, use_log_scale=False):
94+
# Check for valid bounds
95+
if lower >= upper:
96+
raise ValueError("Lower bound must be less than upper bound.")
97+
if lower <= 0 and use_log_scale:
98+
raise ValueError("Lower bound must be positive for log scale.")
99+
40100
self.name = name
41101
self.lower = lower
42102
self.upper = upper
43-
self.sigma = sigma
44-
45-
def randomDraw(self):
46-
x = self.uni_rand_func(self.lower, self.upper)
103+
self.use_log_scale = use_log_scale
104+
# Calculate default sigma if not provided
105+
self.sigma = sigma if sigma is not None else self.calculate_default_sigma()
106+
107+
# Default sigma calculation
108+
def calculate_default_sigma(self):
109+
if self.use_log_scale:
110+
return self.default_log_sigma()
111+
else:
112+
return self.default_sigma()
113+
114+
def default_sigma(self):
115+
return (self.upper - self.lower) / 10
116+
117+
def default_log_sigma(self):
118+
log_lower = math.log10(self.lower)
119+
log_upper = math.log10(self.upper)
120+
return (log_upper - log_lower) / 10
121+
122+
# General random draw function (returns float)
123+
def draw_float(self):
124+
if self.use_log_scale:
125+
log_lower = math.log10(self.lower)
126+
log_upper = math.log10(self.upper)
127+
x_log = random.uniform(log_lower, log_upper)
128+
x = 10 ** x_log
129+
else:
130+
x = random.uniform(self.lower, self.upper)
47131
return x
132+
133+
# General mutation function (returns float)
134+
def mut_float(self, x, mu, indpb):
135+
if random.random() <= indpb:
136+
if self.use_log_scale:
137+
# Convert to log scale for mutation and then back
138+
x_log = math.log10(x)
139+
x_log += random.gauss(mu, self.sigma)
140+
x = 10 ** x_log
141+
x = max(self.lower, min(self.upper, x))
142+
else:
143+
x += random.gauss(mu, self.sigma)
144+
x = max(self.lower, min(self.upper, x))
145+
return x
146+
48147

49-
148+
# Integer parameter class
50149
class IntParameter(NumericParameter):
51150

52-
def __init__(self, name, lower, upper, sigma):
53-
super(IntParameter, self).__init__(name, lower, upper, sigma)
54-
self.uni_rand_func = random.randint
151+
def __init__(self, name, lower, upper, sigma=None, use_log_scale=False):
152+
super(IntParameter, self).__init__(name, lower, upper, sigma, use_log_scale)
153+
154+
# Round the float and explicitly set as int for random draw and mutation
55155

156+
def randomDraw(self):
157+
return int(round(self.draw_float()))
158+
56159
def mutate(self, x, mu, indpb):
57-
if random.random() <= indpb:
58-
x += random.gauss(mu, self.sigma)
59-
x = int(max(self.lower, min(self.upper, round(x))))
60-
return x
160+
return int(round(self.mut_float(x, mu, indpb)))
61161

62162
def parse(self, s):
63163
return int(s)
64164

65165

166+
# Float parameter class
66167
class FloatParameter(NumericParameter):
67168

68-
def __init__(self, name, lower, upper, sigma):
69-
super(FloatParameter, self).__init__(name, lower, upper, sigma)
70-
self.uni_rand_func = random.uniform
169+
def __init__(self, name, lower, upper, sigma=None, use_log_scale=False):
170+
super(FloatParameter, self).__init__(name, lower, upper, sigma, use_log_scale)
171+
172+
def randomDraw(self):
173+
return self.draw_float()
71174

72175
def mutate(self, x, mu, indpb):
73-
if random.random() <= indpb:
74-
x += random.gauss(mu, self.sigma)
75-
x = max(self.lower, min(self.upper, x))
76-
return x
176+
return self.mut_float(x, mu, indpb)
77177

78178
def parse(self, s):
79179
return float(s)
80180

81181

82-
# import logging
83-
# logging.basicConfig()
84-
# log = logging.getLogger("a")
85-
86-
87-
def str_to_bool(s):
88-
if s.lower() == "true":
89-
return True
90-
else:
91-
return False
92-
182+
"""List Parameters:"""
93183

184+
# List parameter superclass (categorical, ordered, or logical)
94185
class ListParameter(object):
95186

96-
def __init__(self, name, categories, element_type):
187+
def __init__(self, name, elements, element_type):
97188
self.name = name
98-
self.categories = categories
189+
self.elements = elements
99190

191+
# Determine element type within parameter type
100192
if element_type == "float":
101193
self.parse_func = float
102194
elif element_type == "int":
@@ -110,57 +202,72 @@ def __init__(self, name, categories, element_type):
110202
"Invalid type: {} - must be one of 'float', 'int', 'string', or 'logical'"
111203
)
112204

205+
def randomDraw(self):
206+
i = random.randint(0, len(self.elements) - 1)
207+
return self.elements[i]
208+
209+
def calculate_default_sigma(self):
210+
default_sigma = (len(self.elements)) / 10
211+
return default_sigma
212+
113213
def parse(self, s):
114214
return self.parse_func(s)
115215

116-
216+
# Categorical parameter class
117217
class CategoricalParameter(ListParameter):
118218

119-
def __init__(self, name, categories, element_type):
120-
super(CategoricalParameter, self).__init__(name, categories,
121-
element_type)
122-
123-
def randomDraw(self):
124-
i = random.randint(0, len(self.categories) - 1)
125-
return self.categories[i]
219+
def __init__(self, name, elements, element_type):
220+
super(CategoricalParameter, self).__init__(name, elements, element_type)
126221

222+
# Mutation picks randomly from the elements while avoiding the same value
127223
def mutate(self, x, mu, indpb):
128-
global log
129-
if random.random() <= indpb:
224+
if random.random() <= indpb and len(self.elements) > 1: # Avoid mutation forever loop if only one category
130225
a = self.randomDraw()
131226
while x == a:
132227
a = self.randomDraw()
133228
x = a
134229
return x
135230

136-
231+
# Ordered parameter class
137232
class OrderedParameter(ListParameter):
138233

139-
def __init__(self, name, categories, sigma, element_type):
140-
super(OrderedParameter, self).__init__(name, categories, element_type)
141-
self.sigma = sigma
142-
143-
def randomDraw(self):
144-
i = random.randint(0, len(self.categories) - 1)
145-
return self.categories[i]
146-
147-
def drawIndex(self, i):
148-
n = random.randint(1, self.sigma)
149-
n = i + (n if random.random() < 0.5 else -n)
150-
n = max(0, min(len(self.categories) - 1, n))
151-
return n
234+
def __init__(self, name, elements, sigma, element_type):
235+
super(OrderedParameter, self).__init__(name, elements, element_type)
236+
self.sigma = sigma if sigma is not None else self.calculate_default_sigma()
152237

238+
# Gaussian mutation is applied to the index and rounded/bounded
153239
def mutate(self, x, mu, indpb):
154240
if random.random() <= indpb:
155-
i = self.categories.index(x)
156-
n = self.drawIndex(i)
157-
while n == i:
158-
n = self.drawIndex(i)
159-
160-
x = self.categories[n]
241+
i = self.elements.index(x)
242+
i_new = i + random.gauss(mu, self.sigma)
243+
i_new = int(round(max(0, min(len(self.elements) - 1, i_new))))
244+
x = self.elements[i_new]
161245
return x
162246

163247

248+
"""Other Parameters:"""
249+
250+
# Constant parameter class (usually foe epochs or pathing parameters not related to the HPO process)
251+
class ConstantParameter(object):
252+
253+
def __init__(self, name, value):
254+
self.name = name
255+
self.value = value
256+
257+
def randomDraw(self):
258+
return self.value
259+
260+
def mutate(self, x, mu, indpb):
261+
return self.value
262+
263+
def parse(self, s):
264+
if is_number(s):
265+
if "." in s or "e" in s:
266+
return float(s)
267+
return int(s)
268+
return s
269+
270+
# Logical parameter class
164271
class LogicalParameter:
165272

166273
def __init__(self, name):
@@ -181,50 +288,6 @@ def parse(self, s):
181288
return False
182289

183290

184-
def create_parameters(param_file, ignore_sigma=False):
185-
with open(param_file) as json_file:
186-
data = json.load(json_file)
187-
188-
params = []
189-
for item in data:
190-
name = item["name"]
191-
t = item["type"]
192-
if ignore_sigma:
193-
sigma = float("nan")
194-
if t == "int" or t == "float":
195-
lower = item["lower"]
196-
upper = item["upper"]
197-
if not ignore_sigma:
198-
sigma = item["sigma"]
199-
200-
if t == "int":
201-
params.append(
202-
IntParameter(name, int(lower), int(upper), int(sigma)))
203-
else:
204-
params.append(
205-
FloatParameter(name, float(lower), float(upper),
206-
float(sigma)))
207-
208-
elif t == "categorical":
209-
vs = item["values"]
210-
element_type = item["element_type"]
211-
params.append(CategoricalParameter(name, vs, element_type))
212-
213-
elif t == "logical":
214-
params.append(LogicalParameter(name))
215-
216-
elif t == "ordered":
217-
vs = item["values"]
218-
if not ignore_sigma:
219-
sigma = item["sigma"]
220-
element_type = item["element_type"]
221-
params.append(OrderedParameter(name, vs, sigma, element_type))
222-
elif t == "constant":
223-
vs = item["value"]
224-
params.append(ConstantParameter(name, vs))
225-
226-
return params
227-
228-
291+
# Run main function
229292
if __name__ == "__main__":
230293
create_parameters(sys.argv[1])

0 commit comments

Comments
 (0)