Skip to content

Commit 197c477

Browse files
committed
more examples of sum. Allow sum method to receive no argument, in this case it will sum over all variables
1 parent a4bcb2e commit 197c477

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

brml/prob_tables/array.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import numpy as np
33
from brml.prob_tables.potential import Potential
4+
from brml.prob_tables.const import Const
45

56

67
class Array(Potential):
@@ -22,9 +23,9 @@ class is not implemented.
2223
:return:
2324
"""
2425
# TODO: See if LogArray class is needed or this is enough.
25-
return Array(variables=self.variables, table=np.log(self.table))
26+
return self.__class__(self.variables, np.log(self.table))
2627

27-
def sum(self, variables):
28+
def sum(self, variables=None):
2829
"""
2930
This method sums the probability table over provided variables. It is
3031
the marginalization operation.
@@ -35,24 +36,31 @@ def sum(self, variables):
3536
:param variables:
3637
:return:
3738
"""
39+
# Allow for default "None" to sum over all variables
40+
if variables is None:
41+
variables = self.variables
3842
# New Array will have variables in self that are not in variables
3943
# because those will be summed over
40-
newpot = Array(variables=np.setdiff1d(self.variables, variables))
44+
remaining_vars = np.setdiff1d(self.variables, variables)
45+
if remaining_vars.size == 0:
46+
newpot = Const()
47+
else:
48+
newpot = Array(variables=remaining_vars)
4149
# Find the indexes of the variables that we will sum over. These are
42-
# all those in self that are in variables
50+
# all those in self that are in variables. Notice that nonzero()
51+
# returns a tuple. Each array in the tuple contains indexes of non-zero
52+
# elements for that dimension. That is there is an array for each
53+
# dimension. But since variables are going to be flat, should be first
54+
# element.
4355
to_sumover = np.isin(self.variables, variables).nonzero()[0]
44-
# FIXME: Notice that nonzero() returns a tuple. Each array in the tuple
45-
# FIXME: contains indexes of non-zero elements for that dimension. That
46-
# FIXME: is there is an array for each dimension. But since variables
47-
# FIXME: are going to be flat, should be first element.
4856
# Find the table
4957
t = copy.deepcopy(self.table)
50-
for variable_index in to_sumover:
58+
for variable_index in to_sumover[::-1]:
59+
# FIXME: Notice that at the moment I am summing over in opposite
60+
# FIXME: order, this is so indexes don't disappear before they are
61+
# FIXME: Used
5162
t = t.sum(axis=variable_index)
5263
# Add table to Array
53-
# TODO: When we sum over all variables, Array will be initialized with
54-
# TODO: no variables, thus we need to create Const potential class,
55-
# TODO: having no variables, but just a table with one element.
5664
newpot.set_table(table=t)
5765
return newpot
5866

@@ -73,3 +81,9 @@ def sumpot(self, axis):
7381
print("b: ", b)
7482
print("Number of states of a: ", a.size())
7583
print("Number of states of b: ", b.size())
84+
# Sum
85+
pot = Array([1, 2], np.array([[0.4, 0.6], [0.3, 0.7]]))
86+
print("Potential Array: \n", pot)
87+
print("Marginalize var 1: \n", pot.sum(1))
88+
print("Marginalize var 2: \n", pot.sum(2))
89+
print("Marginalize var [1,2]: \n", pot.sum([1, 2]))

0 commit comments

Comments
 (0)