|
| 1 | +import argparse |
| 2 | +import itertools |
| 3 | +import random |
| 4 | +import time |
| 5 | +from typing import List, Tuple |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from loguru import logger |
| 9 | +from numba import jit |
| 10 | + |
| 11 | + |
| 12 | +@jit(nopython=True) |
| 13 | +def check_success(prob: np.ndarray) -> np.ndarray: |
| 14 | + """Check the success of an event based on a given probability. |
| 15 | +
|
| 16 | + Args: |
| 17 | + prob (float): The probability of success. |
| 18 | +
|
| 19 | + Returns: |
| 20 | + int: 1 if the event is successful, 0 otherwise. |
| 21 | + """ |
| 22 | + return np.random.binomial(1, prob, 1) |
| 23 | + |
| 24 | + |
| 25 | +@jit(nopython=True) |
| 26 | +def determine_action(state: np.ndarray, states: int) -> int: |
| 27 | + """Determine the action based on the state. |
| 28 | +
|
| 29 | + Args: |
| 30 | + state (int): The current state. |
| 31 | + states (int): The number of states. |
| 32 | +
|
| 33 | + Returns: |
| 34 | + int: 0 if state is less than or equal to states, else 1. |
| 35 | + """ |
| 36 | + return 0 if state <= states else 1 |
| 37 | + |
| 38 | + |
| 39 | +@jit(nopython=True) |
| 40 | +def perform_tmc(n: int, epochs: int, examples: np.ndarray, labels: np.ndarray, clauses: int, |
| 41 | + states: int, probabilities: np.ndarray) -> np.ndarray: |
| 42 | + """Perform TMC and return the state of Tsetlin Automata. |
| 43 | +
|
| 44 | + Args: |
| 45 | + n (int): Number of features. |
| 46 | + epochs (int): Number of training epochs. |
| 47 | + examples (np.ndarray): Array of examples. |
| 48 | + labels (np.ndarray): Array of labels. |
| 49 | + clauses (int): Number of clauses. |
| 50 | + states (int): Number of states. |
| 51 | + probabilities (np.ndarray): Array of probabilities for each clause. |
| 52 | +
|
| 53 | + Returns: |
| 54 | + np.ndarray: The state of the Tsetlin Automata. |
| 55 | + """ |
| 56 | + ta_state = np.random.choice(np.array([states, states + 1]), size=(clauses, n, 2)).astype(np.int32) |
| 57 | + |
| 58 | + for i in range(epochs): |
| 59 | + # print('') |
| 60 | + # print("###################################Epoch", i, "###################################") |
| 61 | + # print('') |
| 62 | + for e in range(len(examples)): |
| 63 | + # print("------------------------",e,"------------------") |
| 64 | + # FeedBack on Positives |
| 65 | + if labels[e] == 1: |
| 66 | + for f in range(n): |
| 67 | + for j in range(clauses): |
| 68 | + p = probabilities[j] |
| 69 | + # Include with some probability |
| 70 | + if determine_action(ta_state[j, f, examples[e][f]], states) == 1: |
| 71 | + if check_success(p): |
| 72 | + if ta_state[j, f, examples[e][f]] < 2 * states: |
| 73 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1 |
| 74 | + else: |
| 75 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1 |
| 76 | + |
| 77 | + # Exclude Negations |
| 78 | + if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 1: |
| 79 | + ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1 |
| 80 | + |
| 81 | + # Include with some probability |
| 82 | + if determine_action(ta_state[j, f, examples[e][f]], states) == 0: |
| 83 | + if check_success(1 - p): |
| 84 | + if ta_state[j, f, examples[e][f]] > 1: |
| 85 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1 |
| 86 | + else: |
| 87 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1 |
| 88 | + |
| 89 | + # Exclude Negations |
| 90 | + if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 0: |
| 91 | + if ta_state[j, f, int(not examples[e][f])] > 1: |
| 92 | + ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1 |
| 93 | + |
| 94 | + if labels[e] == 0: |
| 95 | + for f in range(n): |
| 96 | + for j in range(clauses): |
| 97 | + p = probabilities[j] |
| 98 | + if determine_action(ta_state[j, f, examples[e][f]], states) == 1: |
| 99 | + if check_success(p): |
| 100 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1 |
| 101 | + else: |
| 102 | + if ta_state[j, f, examples[e][f]] < 2 * states: |
| 103 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1 |
| 104 | + |
| 105 | + if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 1: |
| 106 | + if check_success(p): |
| 107 | + if ta_state[j, f, int(not examples[e][f])] < 2 * states + 1: |
| 108 | + ta_state[j, f, int(not examples[e][f])] = ta_state[ |
| 109 | + j, f, int(not examples[e][f])] + 1 |
| 110 | + else: |
| 111 | + ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] - 1 |
| 112 | + |
| 113 | + if determine_action(ta_state[j, f, examples[e][f]], states) == 0: |
| 114 | + if check_success(1 - p): |
| 115 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] + 1 |
| 116 | + |
| 117 | + else: |
| 118 | + if ta_state[j, f, examples[e][f]] > 1: |
| 119 | + ta_state[j, f, examples[e][f]] = ta_state[j, f, examples[e][f]] - 1 |
| 120 | + |
| 121 | + if determine_action(ta_state[j, f, int(not examples[e][f])], states) == 0: |
| 122 | + if check_success(1 - p): |
| 123 | + if ta_state[j, f, int(not examples[e][f])] > 1: |
| 124 | + ta_state[j, f, int(not examples[e][f])] = ta_state[ |
| 125 | + j, f, int(not examples[e][f])] - 1 |
| 126 | + else: |
| 127 | + ta_state[j, f, int(not examples[e][f])] = ta_state[j, f, int(not examples[e][f])] + 1 |
| 128 | + |
| 129 | + return ta_state |
| 130 | + |
| 131 | + |
| 132 | +# @jit(nopython=True) |
| 133 | +# Function to calculate accuracy |
| 134 | +def calculate_accuracy(examples: List[List[int]], formulas: List[List[int]], n: int, labels: List[int]) -> float: |
| 135 | + """Calculate the accuracy of the learned formulas. |
| 136 | +
|
| 137 | + Args: |
| 138 | + examples (List[List[int]]): List of examples. |
| 139 | + formulas (List[List[int]]): List of learned formulas. |
| 140 | + n (int): Number of features. |
| 141 | + labels (List[int]): List of correct labels. |
| 142 | +
|
| 143 | + Returns: |
| 144 | + float: The accuracy of the learned formulas. |
| 145 | + """ |
| 146 | + accur = 0 |
| 147 | + |
| 148 | + for e in range(len(examples)): |
| 149 | + allLabels = [] |
| 150 | + # print("------------", e,"---------------") |
| 151 | + predicted = 0 |
| 152 | + for c in formulas: |
| 153 | + label = 1 |
| 154 | + for f in range(n): |
| 155 | + if c.__contains__((f + 1)) and examples[e][f] == 0: |
| 156 | + label = 0 |
| 157 | + continue |
| 158 | + if c.__contains__(-1 * (f + 1)) and examples[e][f] == 1: |
| 159 | + label = 0 |
| 160 | + continue |
| 161 | + allLabels.append(label) |
| 162 | + if allLabels.__contains__(1): |
| 163 | + predicted = 1 |
| 164 | + # if sum(allLabels) > len(formulas) / 2: |
| 165 | + # predicted = 1 |
| 166 | + if predicted == labels[e]: |
| 167 | + accur = accur + 1 |
| 168 | + return accur / len(examples) |
| 169 | + |
| 170 | + |
| 171 | +def main(args): |
| 172 | + start_all = time.process_time() |
| 173 | + num_features = args.num_features |
| 174 | + examples = list(map(list, itertools.product([0, 1], repeat=num_features))) |
| 175 | + target = args.target |
| 176 | + labels = [1 if all(x == target_i for x, target_i in zip(example, target)) else 0 for example in examples] |
| 177 | + |
| 178 | + accuracies = [] |
| 179 | + for i in range(args.runs): |
| 180 | + logger.info(f"----- Run {i} --------") |
| 181 | + X_training, y_training = np.array(examples), np.array(labels) |
| 182 | + X_test, y_test = X_training, y_training |
| 183 | + logger.info(f"Labels Counts: {np.unique(y_test, return_counts=True)}") |
| 184 | + |
| 185 | + probabilities = np.array([random.uniform(args.lower_prob, args.upper_prob) for _ in range(args.clauses)]) |
| 186 | + logger.info(f"Probabilities: {probabilities}") |
| 187 | + |
| 188 | + start = time.process_time() |
| 189 | + result = perform_tmc(num_features, args.epochs, X_training, y_training, args.clauses, args.states, probabilities) |
| 190 | + |
| 191 | + formulas = [[(j + 1) if result[i, j, 1] > args.states else -(j + 1) for j in range(num_features)] for i in |
| 192 | + range(args.clauses)] |
| 193 | + |
| 194 | + logger.info(f"Learned Formulas: {formulas}") |
| 195 | + logger.info(f"Time Taken: {time.process_time() - start} seconds") |
| 196 | + |
| 197 | + accuracy = calculate_accuracy(X_test, formulas, num_features, y_test) |
| 198 | + accuracies.append(accuracy) |
| 199 | + logger.info(f"Accuracy: {accuracy}") |
| 200 | + |
| 201 | + logger.info(f"Accuracies: {accuracies}") |
| 202 | + logger.info(f"MIN Accuracy: {min(accuracies)}") |
| 203 | + logger.info(f"MAX Accuracy: {max(accuracies)}") |
| 204 | + logger.info(f"AVG Accuracy: {sum(accuracies) / len(accuracies)}") |
| 205 | + logger.info(f"STD Accuracy: {np.std(accuracies)}") |
| 206 | + logger.info(f"Total Time Taken: {time.process_time() - start_all} seconds") |
| 207 | + |
| 208 | + |
| 209 | +if __name__ == '__main__': |
| 210 | + parser = argparse.ArgumentParser(description='Perform TMC.') |
| 211 | + parser.add_argument('--num_features', type=int, default=10, help='Number of features') |
| 212 | + parser.add_argument('--target', type=int, nargs='+', default=[0, 1, 2, 0, 0, 1, 2, 2, 0, 1], help='Target list') |
| 213 | + parser.add_argument('--states', type=int, default=10000, help='Number of states') |
| 214 | + parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs') |
| 215 | + parser.add_argument('--clauses', type=int, default=1, help='Number of clauses') |
| 216 | + parser.add_argument('--runs', type=int, default=1, help='Number of runs') |
| 217 | + parser.add_argument('--lower_prob', type=float, default=0.6, help='Lower bound for probability') |
| 218 | + parser.add_argument('--upper_prob', type=float, default=0.8, help='Upper bound for probability') |
| 219 | + |
| 220 | + args = parser.parse_args() |
| 221 | + main(args) |
0 commit comments