Skip to content

Commit f9a2e31

Browse files
committed
Updates with convergences sim.
1 parent d3c02b0 commit f9a2e31

File tree

2 files changed

+392
-0
lines changed

2 files changed

+392
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import argparse
2+
import itertools
3+
import random
4+
import time
5+
from typing import List, Tuple
6+
7+
import numpy as np
8+
from loguru import logger
9+
from numba import jit
10+
11+
12+
@jit(nopython=True)
13+
def check_success(prob: np.ndarray) -> np.ndarray:
14+
"""Check the success of an event based on a given probability.
15+
16+
Args:
17+
prob (float): The probability of success.
18+
19+
Returns:
20+
int: 1 if the event is successful, 0 otherwise.
21+
"""
22+
return np.random.binomial(1, prob, 1)
23+
24+
25+
@jit(nopython=True)
26+
def determine_action(state: np.ndarray, states: int) -> int:
27+
"""Determine the action based on the state.
28+
29+
Args:
30+
state (int): The current state.
31+
states (int): The number of states.
32+
33+
Returns:
34+
int: 0 if state is less than or equal to states, else 1.
35+
"""
36+
return 0 if state <= states else 1
37+
38+
39+
@jit(nopython=True)
40+
def perform_tmc(n: int, epochs: int, examples: np.ndarray, labels: np.ndarray, clauses: int,
41+
states: int, probabilities: np.ndarray) -> np.ndarray:
42+
"""Perform TMC and return the state of Tsetlin Automata.
43+
44+
Args:
45+
n (int): Number of features.
46+
epochs (int): Number of training epochs.
47+
examples (np.ndarray): Array of examples.
48+
labels (np.ndarray): Array of labels.
49+
clauses (int): Number of clauses.
50+
states (int): Number of states.
51+
probabilities (np.ndarray): Array of probabilities for each clause.
52+
53+
Returns:
54+
np.ndarray: The state of the Tsetlin Automata.
55+
"""
56+
ta_state = np.random.choice(np.array([states, states + 1]), size=(clauses, n, 2)).astype(np.int32)
57+
58+
for i in range(epochs):
59+
# print('')
60+
# print("###################################Epoch", i, "###################################")
61+
# print('')
62+
for e in range(len(examples)):
63+
# print("------------------------",e,"------------------")
64+
# FeedBack on Positives
65+
if labels[e] == 1:
66+
for f in range(n):
67+
for j in range(clauses):
68+
p = probabilities[j]
69+
# Include with some probability
70+
if determine_action(ta_state[j, f, examples[e][f]], states) == 1:
71+
if check_success(p):
72+
if ta_state[j, f, examples[e][f]] < 2 * states:
73+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
74+
else:
75+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
76+
77+
# Exclude Negations
78+
if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 1:
79+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
80+
81+
# Include with some probability
82+
if determine_action(ta_state[j, f, examples[e][f]], states) == 0:
83+
if check_success(1 - p):
84+
if ta_state[j, f, examples[e][f]] > 1:
85+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
86+
else:
87+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
88+
89+
# Exclude Negations
90+
if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 0:
91+
if ta_state[j, f, int(not examples[e][f])] > 1:
92+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
93+
94+
if labels[e] == 0:
95+
for f in range(n):
96+
for j in range(clauses):
97+
p = probabilities[j]
98+
if determine_action(ta_state[j, f, examples[e][f]], states) == 1:
99+
if check_success(p):
100+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
101+
else:
102+
if ta_state[j, f, examples[e][f]] < 2 * states:
103+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
104+
105+
if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 1:
106+
if check_success(p):
107+
if ta_state[j, f, int(not examples[e][f])] < 2 * states + 1:
108+
ta_state[j, f, int(not examples[e][f])] = ta_state[
109+
j, f, int(not examples[e][f])] + 1
110+
else:
111+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
112+
113+
if determine_action(ta_state[j, f, examples[e][f]], states) == 0:
114+
if check_success(1 - p):
115+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
116+
117+
else:
118+
if ta_state[j, f, examples[e][f]] > 1:
119+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
120+
121+
if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 0:
122+
if check_success(1 - p):
123+
if ta_state[j, f, int(not examples[e][f])] > 1:
124+
ta_state[j, f, int(not examples[e][f])] = ta_state[
125+
j, f, int(not examples[e][f])] - 1
126+
else:
127+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] + 1
128+
129+
return ta_state
130+
131+
132+
# @jit(nopython=True)
133+
# Function to calculate accuracy
134+
def calculate_accuracy(examples: List[List[int]], formulas: List[List[int]], n: int, labels: List[int]) -> float:
135+
"""Calculate the accuracy of the learned formulas.
136+
137+
Args:
138+
examples (List[List[int]]): List of examples.
139+
formulas (List[List[int]]): List of learned formulas.
140+
n (int): Number of features.
141+
labels (List[int]): List of correct labels.
142+
143+
Returns:
144+
float: The accuracy of the learned formulas.
145+
"""
146+
accur = 0
147+
148+
for e in range(len(examples)):
149+
allLabels = []
150+
# print("------------", e,"---------------")
151+
predicted = 0
152+
for c in formulas:
153+
label = 1
154+
for f in range(n):
155+
if c.__contains__((f + 1)) and examples[e][f] == 0:
156+
label = 0
157+
continue
158+
if c.__contains__(-1 * (f + 1)) and examples[e][f] == 1:
159+
label = 0
160+
continue
161+
allLabels.append(label)
162+
if allLabels.__contains__(1):
163+
predicted = 1
164+
# if sum(allLabels) > len(formulas) / 2:
165+
# predicted = 1
166+
if predicted == labels[e]:
167+
accur = accur + 1
168+
return accur / len(examples)
169+
170+
171+
def main(args):
172+
start_all = time.process_time()
173+
num_features = args.num_features
174+
examples = list(map(list, itertools.product([0, 1], repeat=num_features)))
175+
target = args.target
176+
labels = [1 if all(x == target_i for x, target_i in zip(example, target)) else 0 for example in examples]
177+
178+
accuracies = []
179+
for i in range(args.runs):
180+
logger.info(f"----- Run {i} --------")
181+
X_training, y_training = np.array(examples), np.array(labels)
182+
X_test, y_test = X_training, y_training
183+
logger.info(f"Labels Counts: {np.unique(y_test, return_counts=True)}")
184+
185+
probabilities = np.array([random.uniform(args.lower_prob, args.upper_prob) for _ in range(args.clauses)])
186+
logger.info(f"Probabilities: {probabilities}")
187+
188+
start = time.process_time()
189+
result = perform_tmc(num_features, args.epochs, X_training, y_training, args.clauses, args.states, probabilities)
190+
191+
formulas = [[(j + 1) if result[i, j, 1] > args.states else -(j + 1) for j in range(num_features)] for i in
192+
range(args.clauses)]
193+
194+
logger.info(f"Learned Formulas: {formulas}")
195+
logger.info(f"Time Taken: {time.process_time() - start} seconds")
196+
197+
accuracy = calculate_accuracy(X_test, formulas, num_features, y_test)
198+
accuracies.append(accuracy)
199+
logger.info(f"Accuracy: {accuracy}")
200+
201+
logger.info(f"Accuracies: {accuracies}")
202+
logger.info(f"MIN Accuracy: {min(accuracies)}")
203+
logger.info(f"MAX Accuracy: {max(accuracies)}")
204+
logger.info(f"AVG Accuracy: {sum(accuracies) / len(accuracies)}")
205+
logger.info(f"STD Accuracy: {np.std(accuracies)}")
206+
logger.info(f"Total Time Taken: {time.process_time() - start_all} seconds")
207+
208+
209+
if __name__ == '__main__':
210+
parser = argparse.ArgumentParser(description='Perform TMC.')
211+
parser.add_argument('--num_features', type=int, default=10, help='Number of features')
212+
parser.add_argument('--target', type=int, nargs='+', default=[0, 1, 2, 0, 0, 1, 2, 2, 0, 1], help='Target list')
213+
parser.add_argument('--states', type=int, default=10000, help='Number of states')
214+
parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs')
215+
parser.add_argument('--clauses', type=int, default=1, help='Number of clauses')
216+
parser.add_argument('--runs', type=int, default=1, help='Number of runs')
217+
parser.add_argument('--lower_prob', type=float, default=0.6, help='Lower bound for probability')
218+
parser.add_argument('--upper_prob', type=float, default=0.8, help='Upper bound for probability')
219+
220+
args = parser.parse_args()
221+
main(args)
+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# This is a sample Python script.
2+
import argparse
3+
# Press Shift+F10 to execute it or replace it with your code.
4+
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
5+
6+
import itertools
7+
import random
8+
import numpy as np
9+
from loguru import logger
10+
from scipy.stats import bernoulli
11+
12+
import time
13+
from numba import jit
14+
15+
@jit(nopython=True)
16+
def success(prob: float) -> int:
17+
"""Determine success based on a probability value."""
18+
return 1 if np.random.binomial(1, prob, 1) else 0
19+
20+
21+
@jit(nopython=True)
22+
def action(state: np.ndarray, states: int) -> int:
23+
"""Determine action based on state and states."""
24+
return 0 if state <= states else 1
25+
26+
27+
@jit(nopython=True)
28+
def TMC(n: int, epochs: int, examples: np.ndarray, labels: np.ndarray,
29+
clauses: int, states: int, p: float) -> np.ndarray:
30+
"""Perform the Tsetlin Machine Classifier algorithm."""
31+
ta_state = np.random.choice(np.array([states, states + 1]), size=(clauses, n, 2)).astype(np.int32)
32+
33+
for i in range(epochs):
34+
#print('')
35+
#print("###################################Epoch", i, "###################################")
36+
#print('')
37+
for e in range(len(examples)):
38+
# print("------------------------",e,"------------------")
39+
# FeedBack on Positives
40+
if labels[e] == 1:
41+
for f in range(n):
42+
for j in range(clauses):
43+
#p = probabilitie[j]
44+
# Include with some probability
45+
if action(ta_state[j, f, examples[e][f]], states) == 1:
46+
if success(p):
47+
if ta_state[j, f, examples[e][f]] < 2 * states:
48+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
49+
else:
50+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
51+
52+
# Exclude Negations
53+
if action(ta_state[j, f, int(not examples[e][f])], states) == 1:
54+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
55+
56+
# Include with some probability
57+
if action(ta_state[j, f, examples[e][f]], states) == 0:
58+
if success(1 - p):
59+
if ta_state[j, f, examples[e][f]] > 1:
60+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
61+
else:
62+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
63+
64+
# Exclude Negations
65+
if action(ta_state[j, f, int(not examples[e][f])], states) == 0:
66+
if ta_state[j, f, int(not examples[e][f])] > 1:
67+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
68+
69+
if labels[e] == 0:
70+
for f in range(n):
71+
for j in range(clauses):
72+
#p = probabilitie[j]
73+
if action(ta_state[j, f, examples[e][f]], states) == 1:
74+
if success(p):
75+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
76+
else:
77+
if ta_state[j, f, examples[e][f]] < 2 * states:
78+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
79+
80+
if action(ta_state[j, f, int(not examples[e][f])], states) == 1:
81+
if success(p):
82+
if ta_state[j, f, int(not examples[e][f])] < 2 * states + 1:
83+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] + 1
84+
else:
85+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
86+
87+
if action(ta_state[j, f, examples[e][f]], states) == 0:
88+
if success(1 - p):
89+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1
90+
91+
else:
92+
if ta_state[j, f, examples[e][f]] > 1:
93+
ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1
94+
95+
if action(ta_state[j, f, int(not examples[e][f])], states) == 0:
96+
if success(1 - p):
97+
if ta_state[j, f, int(not examples[e][f])] > 1:
98+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1
99+
else:
100+
ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] + 1
101+
102+
return ta_state
103+
104+
105+
def main(n: int, states: int, epochs: int, clauses: int, runs: int, p: float) -> None:
106+
start_all = time.process_time()
107+
examples = list(map(list, itertools.product([0, 1], repeat=n)))
108+
suc, fails = 0, []
109+
110+
for run_number in range(runs):
111+
start = time.process_time()
112+
examples = list(map(list, itertools.product([0, 1], repeat=n)))
113+
logger.info("------------- Run {} -----------------------------", run_number)
114+
115+
target = np.random.choice(3, n, replace=True)
116+
labels = [1] * len(examples)
117+
118+
for X in examples:
119+
if any((x == 1 and t == 0) or (x == 0 and t == 1) for x, t in zip(X, target)):
120+
labels[examples.index(X)] = 0
121+
122+
s = ""
123+
t = []
124+
for i in range(len(target)):
125+
if target[i] == 1:
126+
s += f"X_{i+1} and "
127+
t.append(i+1)
128+
if target[i] == 0:
129+
s += f"not X_{i + 1} and "
130+
t.append(-i - 1)
131+
132+
logger.info("Target: {}", s[:-5]) # removed the last " and "
133+
logger.info("Target List: {}", t)
134+
135+
result = TMC(n, epochs, np.array(examples), np.array(labels), clauses, states, p)
136+
137+
formulas = []
138+
for i in range(clauses):
139+
c = [j + 1 if result[i, j, 1] > states else -j - 1 for j in range(n) if result[i, j, 0] > states or result[i, j, 1] > states]
140+
formulas.append(c)
141+
142+
logger.info("Learned: {}", formulas)
143+
144+
if t in formulas:
145+
suc += 1
146+
else:
147+
fails.append(t)
148+
149+
logger.info("Contains Target: {}", t in formulas)
150+
logger.info("Time for this run: {:.2f} seconds", time.process_time() - start)
151+
152+
logger.info("Successes: {}/{}", suc, runs)
153+
logger.info("Fails: {}", fails)
154+
logger.info("Total Time: {:.2f} seconds", time.process_time() - start_all)
155+
156+
157+
158+
159+
if __name__ == '__main__':
160+
parser = argparse.ArgumentParser(description='Execute Tsetlin Machine Classifier Algorithm.')
161+
parser.add_argument('--n', type=int, default=5, help='Number of elements.')
162+
parser.add_argument('--states', type=int, default=10000, help='Number of states.')
163+
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs.')
164+
parser.add_argument('--clauses', type=int, default=1, help='Number of clauses.')
165+
parser.add_argument('--runs', type=int, default=100, help='Number of runs.')
166+
parser.add_argument('--p', type=float, default=0.75, help='Probability value.')
167+
168+
args = parser.parse_args()
169+
170+
logger.info(f"Executing with arguments: {args}")
171+
main(args.n, args.states, args.epochs, args.clauses, args.runs, args.p)

0 commit comments

Comments
 (0)