Skip to content

Commit 4f35247

Browse files
authored
Accept list of variables in probleminstance (#14)
* better assertion error in algorithms for variables given * handle names that do not have brackets * accept lists of variables in probleminstance
1 parent 472b0c3 commit 4f35247

File tree

8 files changed

+30
-12
lines changed

8 files changed

+30
-12
lines changed

pycona/active_algorithms/gquacq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
4343
"""
4444
if X is None:
4545
X = instance.X
46-
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
46+
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
47+
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
4748

4849
self.env.init_state(instance, oracle, verbose, metrics)
4950

pycona/active_algorithms/growacq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
4040
"""
4141
if X is None:
4242
X = instance.X
43-
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
43+
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
44+
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
4445

4546
self.env.init_state(instance, oracle, verbose, metrics)
4647

pycona/active_algorithms/mquacq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
3434
"""
3535
if X is None:
3636
X = instance.X
37-
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
37+
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
38+
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
3839

3940
self.env.init_state(instance, oracle, verbose, metrics)
4041

pycona/active_algorithms/mquacq2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
4444
"""
4545
if X is None:
4646
X = instance.X
47-
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
47+
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
48+
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
4849

4950
self.env.init_state(instance, oracle, verbose, metrics)
5051

pycona/active_algorithms/pquacq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
3838
"""
3939
if X is None:
4040
X = instance.X
41-
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
41+
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
42+
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
4243

4344
self.env.init_state(instance, oracle, verbose, metrics)
4445

pycona/active_algorithms/quacq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
3434
"""
3535
if X is None:
3636
X = instance.X
37-
assert isinstance(X, list) and set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a list of variables"
37+
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
38+
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
3839

3940
self.env.init_state(instance, oracle, verbose, metrics)
4041

pycona/problem_instance/problem_instance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ def variables(self, vars):
123123
"""
124124
self._variables = vars
125125
if vars is not None:
126-
self.X = list(self._variables.flatten())
126+
if isinstance(vars, NDVarArray):
127+
self.X = list(self._variables.flatten())
128+
else:
129+
self.X = vars
130+
self._variables = cp.cpm_array(vars)
127131

128132
@property
129133
def X(self):

pycona/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,9 @@ def get_var_name(var):
322322
:return: The name of the variable without its indices.
323323
"""
324324
name = re.findall(r"\[\d+[,\d+]*\]", var.name)
325-
name = var.name.replace(name[0], '')
326-
return name
325+
if name: # Check if we found any indices
326+
return var.name.replace(name[0], '')
327+
return var.name # Return original name if no indices found
327328

328329

329330
def get_var_ndims(var):
@@ -344,9 +345,11 @@ def get_var_dims(var):
344345
Get the dimensions of a variable.
345346
346347
:param var: The variable.
347-
:return: The dimensions of the variable.
348+
:return: The dimensions of the variable. Returns empty list if variable has no indices.
348349
"""
349350
dims = re.findall(r"\[\d+[,\d+]*\]", var.name)
351+
if not dims: # If no indices found
352+
return []
350353
dims_str = "".join(dims)
351354
dims = re.split(r"[\[\]]", dims_str)[1]
352355
dims = [int(dim) for dim in re.split(",", dims)]
@@ -406,13 +409,18 @@ def get_variables_from_constraints(constraints):
406409
:param constraints: List of constraints.
407410
:return: List of variables involved in the constraints.
408411
"""
409-
410412
# Create set to hold unique variables
411413
variable_set = set()
412414
for constraint in constraints:
413415
variable_set.update(get_variables(constraint))
414416

415-
extract_nums = lambda s: list(map(int, s.name[s.name.index("[") + 1:s.name.index("]")].split(',')))
417+
def extract_nums(s):
418+
dims = re.findall(r"\[\d+[,\d+]*\]", s.name)
419+
if not dims:
420+
return [0] # Default value for variables without indices
421+
dims_str = "".join(dims)
422+
dims = re.split(r"[\[\]]", dims_str)[1]
423+
return [int(dim) for dim in re.split(",", dims)]
416424

417425
variable_list = sorted(variable_set, key=extract_nums)
418426
return variable_list

0 commit comments

Comments
 (0)