diff --git a/deep_learning4e.py b/deep_learning4e.py index 734a9307c..0a0387afc 100644 --- a/deep_learning4e.py +++ b/deep_learning4e.py @@ -13,23 +13,6 @@ class Node: - """ - A node in a computational graph contains the pointer to all its parents. - :param val: value of current node - :param parents: a container of all parents of current node - """ - - def __init__(self, val=None, parents=None): - if parents is None: - parents = [] - self.val = val - self.parents = parents - - def __repr__(self): - return "".format(self.val) - - -class NNUnit(Node): """ A single unit of a layer in a neural network :param weights: weights between parent nodes and current node @@ -37,7 +20,7 @@ class NNUnit(Node): """ def __init__(self, weights=None, value=None): - super().__init__(value) + self.value = value self.weights = weights or [] @@ -47,8 +30,8 @@ class Layer: :param size: number of units in the current layer """ - def __init__(self, size=3): - self.nodes = [NNUnit() for _ in range(size)] + def __init__(self, size): + self.nodes = [Node() for _ in range(size)] def forward(self, inputs): """Define the operation to get the output of this layer""" @@ -65,7 +48,7 @@ def forward(self, inputs): """Take each value of the inputs to each unit in the layer.""" assert len(self.nodes) == len(inputs) for node, inp in zip(self.nodes, inputs): - node.val = inp + node.value = inp return inputs @@ -79,7 +62,7 @@ def forward(self, inputs): assert len(self.nodes) == len(inputs) res = softmax1D(inputs) for node, val in zip(self.nodes, res): - node.val = val + node.value = val return res @@ -91,11 +74,11 @@ class DenseLayer(Layer): :param activation: (Activation object) activation function """ - def __init__(self, in_size=3, out_size=3, activation=None): + def __init__(self, in_size=3, out_size=3, activation=Sigmoid): super().__init__(out_size) self.out_size = out_size self.inputs = None - self.activation = Sigmoid() if not activation else activation + self.activation = activation() # initialize weights for node in self.nodes: node.weights = random_weights(-0.5, 0.5, in_size) @@ -105,8 +88,8 @@ def forward(self, inputs): res = [] # get the output value of each unit for unit in self.nodes: - val = self.activation.f(dot_product(unit.weights, inputs)) - unit.val = val + val = self.activation.function(dot_product(unit.weights, inputs)) + unit.value = val res.append(val) return res @@ -131,7 +114,7 @@ def forward(self, features): for node, feature in zip(self.nodes, features): out = conv1D(feature, node.weights) res.append(out) - node.val = out + node.value = out return res @@ -157,7 +140,7 @@ def forward(self, features): out = [max(feature[i:i + self.kernel_size]) for i in range(len(feature) - self.kernel_size + 1)] res.append(out) - self.nodes[i].val = out + self.nodes[i].value = out return res @@ -181,7 +164,7 @@ def init_examples(examples, idx_i, idx_t, o_units): return inputs, targets -def gradient_descent(dataset, net, loss, epochs=1000, l_rate=0.01, batch_size=1, verbose=None): +def stochastic_gradient_descent(dataset, net, loss, epochs=1000, l_rate=0.01, batch_size=1, verbose=None): """ Gradient descent algorithm to update the learnable parameters of a network. :return: the updated network @@ -200,6 +183,7 @@ def gradient_descent(dataset, net, loss, epochs=1000, l_rate=0.01, batch_size=1, # update weights with gradient descent weights = vector_add(weights, scalar_vector_product(-l_rate, gs)) total_loss += batch_loss + # update the weights of network each batch for i in range(len(net)): if weights[i]: @@ -310,7 +294,7 @@ def BackPropagation(inputs, targets, theta, net, loss): # backward pass for i in range(h_layers, 0, -1): layer = net[i] - derivative = [layer.activation.derivative(node.val) for node in layer.nodes] + derivative = [layer.activation.derivative(node.value) for node in layer.nodes] delta[i] = element_wise_product(previous, derivative) # pass to layer i-1 in the next iteration previous = matrix_multiplication([delta[i]], theta[i])[0] @@ -344,7 +328,7 @@ def forward(self, inputs): for i in range(len(self.nodes)): val = [(inputs[i] - mu) * self.weights[0] / np.sqrt(self.eps + stderr ** 2) + self.weights[1]] res.append(val) - self.nodes[i].val = val + self.nodes[i].value = val return res @@ -354,15 +338,12 @@ def get_batch(examples, batch_size=1): yield examples[i: i + batch_size] -def NeuralNetLearner(dataset, hidden_layer_sizes=None, learning_rate=0.01, epochs=100, - optimizer=gradient_descent, batch_size=1, verbose=None): +def NeuralNetLearner(dataset, hidden_layer_sizes, l_rate=0.01, epochs=1000, batch_size=1, + optimizer=stochastic_gradient_descent, verbose=None): """ Simple dense multilayer neural network. :param hidden_layer_sizes: size of hidden layers in the form of a list """ - - if hidden_layer_sizes is None: - hidden_layer_sizes = [4] input_size = len(dataset.inputs) output_size = len(dataset.values[dataset.target]) @@ -376,7 +357,7 @@ def NeuralNetLearner(dataset, hidden_layer_sizes=None, learning_rate=0.01, epoch raw_net.append(DenseLayer(hidden_input_size, output_size)) # update parameters of the network - learned_net = optimizer(dataset, raw_net, mean_squared_error_loss, epochs, l_rate=learning_rate, + learned_net = optimizer(dataset, raw_net, mean_squared_error_loss, epochs, l_rate=l_rate, batch_size=batch_size, verbose=verbose) def predict(example): @@ -395,7 +376,8 @@ def predict(example): return predict -def PerceptronLearner(dataset, learning_rate=0.01, epochs=100, optimizer=gradient_descent, batch_size=1, verbose=None): +def PerceptronLearner(dataset, l_rate=0.01, epochs=1000, batch_size=1, + optimizer=stochastic_gradient_descent, verbose=None): """ Simple perceptron neural network. """ @@ -406,7 +388,7 @@ def PerceptronLearner(dataset, learning_rate=0.01, epochs=100, optimizer=gradien raw_net = [InputLayer(input_size), DenseLayer(input_size, output_size)] # update the network - learned_net = optimizer(dataset, raw_net, mean_squared_error_loss, epochs, l_rate=learning_rate, + learned_net = optimizer(dataset, raw_net, mean_squared_error_loss, epochs, l_rate=l_rate, batch_size=batch_size, verbose=verbose) def predict(example): diff --git a/gui/eight_puzzle.py b/gui/eight_puzzle.py index 82acced03..5733228d7 100644 --- a/gui/eight_puzzle.py +++ b/gui/eight_puzzle.py @@ -1,138 +1,151 @@ -# author ad71 -from tkinter import * +import os.path +import random +import time from functools import partial +from tkinter import * -import time -import random -import numpy as np +from search import astar_search, EightPuzzle -import sys -import os.path sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from search import astar_search, EightPuzzle -import utils - root = Tk() state = [1, 2, 3, 4, 5, 6, 7, 8, 0] puzzle = EightPuzzle(tuple(state)) solution = None -b = [None]*9 +b = [None] * 9 + # TODO: refactor into OOP, remove global variables def scramble(): - """ Scrambles the puzzle starting from the goal state """ + """Scrambles the puzzle starting from the goal state""" + + global state + global puzzle + possible_actions = ['UP', 'DOWN', 'LEFT', 'RIGHT'] + scramble = [] + for _ in range(60): + scramble.append(random.choice(possible_actions)) - global state - global puzzle - possible_actions = ['UP', 'DOWN', 'LEFT', 'RIGHT'] - scramble = [] - for _ in range(60): - scramble.append(random.choice(possible_actions)) + for move in scramble: + if move in puzzle.actions(state): + state = list(puzzle.result(state, move)) + puzzle = EightPuzzle(tuple(state)) + create_buttons() - for move in scramble: - if move in puzzle.actions(state): - state = list(puzzle.result(state, move)) - puzzle = EightPuzzle(tuple(state)) - create_buttons() def solve(): - """ Solves the puzzle using astar_search """ + """Solves the puzzle using astar_search""" + + return astar_search(puzzle).solution() - return astar_search(puzzle).solution() def solve_steps(): - """ Solves the puzzle step by step """ - - global puzzle - global solution - global state - solution = solve() - print(solution) - - for move in solution: - state = puzzle.result(state, move) - create_buttons() - root.update() - root.after(1, time.sleep(0.75)) + """Solves the puzzle step by step""" + + global puzzle + global solution + global state + solution = solve() + print(solution) + + for move in solution: + state = puzzle.result(state, move) + create_buttons() + root.update() + root.after(1, time.sleep(0.75)) + def exchange(index): - """ Interchanges the position of the selected tile with the zero tile under certain conditions """ - - global state - global solution - global puzzle - zero_ix = list(state).index(0) - actions = puzzle.actions(state) - current_action = '' - i_diff = index//3 - zero_ix//3 - j_diff = index%3 - zero_ix%3 - if i_diff == 1: - current_action += 'DOWN' - elif i_diff == -1: - current_action += 'UP' - - if j_diff == 1: - current_action += 'RIGHT' - elif j_diff == -1: - current_action += 'LEFT' - - if abs(i_diff) + abs(j_diff) != 1: - current_action = '' - - if current_action in actions: - b[zero_ix].grid_forget() - b[zero_ix] = Button(root, text=f'{state[index]}', width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, zero_ix)) - b[zero_ix].grid(row=zero_ix//3, column=zero_ix%3, ipady=40) - b[index].grid_forget() - b[index] = Button(root, text=None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, index)) - b[index].grid(row=index//3, column=index%3, ipady=40) - state[zero_ix], state[index] = state[index], state[zero_ix] - puzzle = EightPuzzle(tuple(state)) + """Interchanges the position of the selected tile with the zero tile under certain conditions""" + + global state + global solution + global puzzle + zero_ix = list(state).index(0) + actions = puzzle.actions(state) + current_action = '' + i_diff = index // 3 - zero_ix // 3 + j_diff = index % 3 - zero_ix % 3 + if i_diff == 1: + current_action += 'DOWN' + elif i_diff == -1: + current_action += 'UP' + + if j_diff == 1: + current_action += 'RIGHT' + elif j_diff == -1: + current_action += 'LEFT' + + if abs(i_diff) + abs(j_diff) != 1: + current_action = '' + + if current_action in actions: + b[zero_ix].grid_forget() + b[zero_ix] = Button(root, text=f'{state[index]}', width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, zero_ix)) + b[zero_ix].grid(row=zero_ix // 3, column=zero_ix % 3, ipady=40) + b[index].grid_forget() + b[index] = Button(root, text=None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, index)) + b[index].grid(row=index // 3, column=index % 3, ipady=40) + state[zero_ix], state[index] = state[index], state[zero_ix] + puzzle = EightPuzzle(tuple(state)) + def create_buttons(): - """ Creates dynamic buttons """ - - # TODO: Find a way to use grid_forget() with a for loop for initialization - b[0] = Button(root, text=f'{state[0]}' if state[0] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 0)) - b[0].grid(row=0, column=0, ipady=40) - b[1] = Button(root, text=f'{state[1]}' if state[1] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 1)) - b[1].grid(row=0, column=1, ipady=40) - b[2] = Button(root, text=f'{state[2]}' if state[2] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 2)) - b[2].grid(row=0, column=2, ipady=40) - b[3] = Button(root, text=f'{state[3]}' if state[3] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 3)) - b[3].grid(row=1, column=0, ipady=40) - b[4] = Button(root, text=f'{state[4]}' if state[4] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 4)) - b[4].grid(row=1, column=1, ipady=40) - b[5] = Button(root, text=f'{state[5]}' if state[5] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 5)) - b[5].grid(row=1, column=2, ipady=40) - b[6] = Button(root, text=f'{state[6]}' if state[6] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 6)) - b[6].grid(row=2, column=0, ipady=40) - b[7] = Button(root, text=f'{state[7]}' if state[7] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 7)) - b[7].grid(row=2, column=1, ipady=40) - b[8] = Button(root, text=f'{state[8]}' if state[8] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), command=partial(exchange, 8)) - b[8].grid(row=2, column=2, ipady=40) + """Creates dynamic buttons""" + + # TODO: Find a way to use grid_forget() with a for loop for initialization + b[0] = Button(root, text=f'{state[0]}' if state[0] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 0)) + b[0].grid(row=0, column=0, ipady=40) + b[1] = Button(root, text=f'{state[1]}' if state[1] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 1)) + b[1].grid(row=0, column=1, ipady=40) + b[2] = Button(root, text=f'{state[2]}' if state[2] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 2)) + b[2].grid(row=0, column=2, ipady=40) + b[3] = Button(root, text=f'{state[3]}' if state[3] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 3)) + b[3].grid(row=1, column=0, ipady=40) + b[4] = Button(root, text=f'{state[4]}' if state[4] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 4)) + b[4].grid(row=1, column=1, ipady=40) + b[5] = Button(root, text=f'{state[5]}' if state[5] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 5)) + b[5].grid(row=1, column=2, ipady=40) + b[6] = Button(root, text=f'{state[6]}' if state[6] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 6)) + b[6].grid(row=2, column=0, ipady=40) + b[7] = Button(root, text=f'{state[7]}' if state[7] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 7)) + b[7].grid(row=2, column=1, ipady=40) + b[8] = Button(root, text=f'{state[8]}' if state[8] != 0 else None, width=6, font=('Helvetica', 40, 'bold'), + command=partial(exchange, 8)) + b[8].grid(row=2, column=2, ipady=40) + def create_static_buttons(): - """ Creates scramble and solve buttons """ + """Creates scramble and solve buttons""" + + scramble_btn = Button(root, text='Scramble', font=('Helvetica', 30, 'bold'), width=8, command=partial(init)) + scramble_btn.grid(row=3, column=0, ipady=10) + solve_btn = Button(root, text='Solve', font=('Helvetica', 30, 'bold'), width=8, command=partial(solve_steps)) + solve_btn.grid(row=3, column=2, ipady=10) - scramble_btn = Button(root, text='Scramble', font=('Helvetica', 30, 'bold'), width=8, command=partial(init)) - scramble_btn.grid(row=3, column=0, ipady=10) - solve_btn = Button(root, text='Solve', font=('Helvetica', 30, 'bold'), width=8, command=partial(solve_steps)) - solve_btn.grid(row=3, column=2, ipady=10) def init(): - """ Calls necessary functions """ - - global state - global solution - state = [1, 2, 3, 4, 5, 6, 7, 8, 0] - scramble() - create_buttons() - create_static_buttons() + """Calls necessary functions""" + + global state + global solution + state = [1, 2, 3, 4, 5, 6, 7, 8, 0] + scramble() + create_buttons() + create_static_buttons() + init() root.mainloop() diff --git a/gui/genetic_algorithm_example.py b/gui/genetic_algorithm_example.py index 418da02e9..c987151c8 100644 --- a/gui/genetic_algorithm_example.py +++ b/gui/genetic_algorithm_example.py @@ -1,4 +1,3 @@ -# author: ad71 # A simple program that implements the solution to the phrase generation problem using # genetic algorithms as given in the search.ipynb notebook. # @@ -9,17 +8,13 @@ # Displays a progress bar that indicates the amount of completion of the algorithm # Displays the first few individuals of the current generation -import sys -import time -import random import os.path -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - from tkinter import * from tkinter import ttk import search -from utils import argmax + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) LARGE_FONT = ('Verdana', 12) EXTRA_LARGE_FONT = ('Consolas', 36, 'bold') @@ -34,20 +29,20 @@ # genetic algorithm variables # feel free to play around with these -target = 'Genetic Algorithm' # the phrase to be generated -max_population = 100 # number of samples in each population -mutation_rate = 0.1 # probability of mutation -f_thres = len(target) # fitness threshold -ngen = 1200 # max number of generations to run the genetic algorithm +target = 'Genetic Algorithm' # the phrase to be generated +max_population = 100 # number of samples in each population +mutation_rate = 0.1 # probability of mutation +f_thres = len(target) # fitness threshold +ngen = 1200 # max number of generations to run the genetic algorithm -generation = 0 # counter to keep track of generation number +generation = 0 # counter to keep track of generation number -u_case = [chr(x) for x in range(65, 91)] # list containing all uppercase characters -l_case = [chr(x) for x in range(97, 123)] # list containing all lowercase characters -punctuations1 = [chr(x) for x in range(33, 48)] # lists containing punctuation symbols +u_case = [chr(x) for x in range(65, 91)] # list containing all uppercase characters +l_case = [chr(x) for x in range(97, 123)] # list containing all lowercase characters +punctuations1 = [chr(x) for x in range(33, 48)] # lists containing punctuation symbols punctuations2 = [chr(x) for x in range(58, 65)] punctuations3 = [chr(x) for x in range(91, 97)] -numerals = [chr(x) for x in range(48, 58)] # list containing numbers +numerals = [chr(x) for x in range(48, 58)] # list containing numbers # extend the gene pool with the required lists and append the space character gene_pool = [] @@ -55,44 +50,51 @@ gene_pool.extend(l_case) gene_pool.append(' ') + # callbacks to update global variables from the slider values def update_max_population(slider_value): - global max_population - max_population = slider_value + global max_population + max_population = slider_value + def update_mutation_rate(slider_value): - global mutation_rate - mutation_rate = slider_value + global mutation_rate + mutation_rate = slider_value + def update_f_thres(slider_value): - global f_thres - f_thres = slider_value + global f_thres + f_thres = slider_value + def update_ngen(slider_value): - global ngen - ngen = slider_value + global ngen + ngen = slider_value + # fitness function def fitness_fn(_list): - fitness = 0 - # create string from list of characters - phrase = ''.join(_list) - # add 1 to fitness value for every matching character - for i in range(len(phrase)): - if target[i] == phrase[i]: - fitness += 1 - return fitness + fitness = 0 + # create string from list of characters + phrase = ''.join(_list) + # add 1 to fitness value for every matching character + for i in range(len(phrase)): + if target[i] == phrase[i]: + fitness += 1 + return fitness + # function to bring a new frame on top def raise_frame(frame, init=False, update_target=False, target_entry=None, f_thres_slider=None): - frame.tkraise() - global target - if update_target and target_entry is not None: - target = target_entry.get() - f_thres_slider.config(to=len(target)) - if init: - population = search.init_population(max_population, gene_pool, len(target)) - genetic_algorithm_stepwise(population) + frame.tkraise() + global target + if update_target and target_entry is not None: + target = target_entry.get() + f_thres_slider.config(to=len(target)) + if init: + population = search.init_population(max_population, gene_pool, len(target)) + genetic_algorithm_stepwise(population) + # defining root and child frames root = Tk() @@ -101,7 +103,7 @@ def raise_frame(frame, init=False, update_target=False, target_entry=None, f_thr # pack frames on top of one another for frame in (f1, f2): - frame.grid(row=0, column=0, sticky='news') + frame.grid(row=0, column=0, sticky='news') # Home Screen (f1) widgets target_entry = Entry(f1, font=('Consolas 46 bold'), exportselection=0, foreground=p_blue, justify=CENTER) @@ -109,64 +111,79 @@ def raise_frame(frame, init=False, update_target=False, target_entry=None, f_thr target_entry.pack(expand=YES, side=TOP, fill=X, padx=50) target_entry.focus_force() -max_population_slider = Scale(f1, from_=3, to=1000, orient=HORIZONTAL, label='Max population', command=lambda value: update_max_population(int(value))) +max_population_slider = Scale(f1, from_=3, to=1000, orient=HORIZONTAL, label='Max population', + command=lambda value: update_max_population(int(value))) max_population_slider.set(max_population) max_population_slider.pack(expand=YES, side=TOP, fill=X, padx=40) -mutation_rate_slider = Scale(f1, from_=0, to=1, orient=HORIZONTAL, label='Mutation rate', resolution=0.0001, command=lambda value: update_mutation_rate(float(value))) +mutation_rate_slider = Scale(f1, from_=0, to=1, orient=HORIZONTAL, label='Mutation rate', resolution=0.0001, + command=lambda value: update_mutation_rate(float(value))) mutation_rate_slider.set(mutation_rate) mutation_rate_slider.pack(expand=YES, side=TOP, fill=X, padx=40) -f_thres_slider = Scale(f1, from_=0, to=len(target), orient=HORIZONTAL, label='Fitness threshold', command=lambda value: update_f_thres(int(value))) +f_thres_slider = Scale(f1, from_=0, to=len(target), orient=HORIZONTAL, label='Fitness threshold', + command=lambda value: update_f_thres(int(value))) f_thres_slider.set(f_thres) f_thres_slider.pack(expand=YES, side=TOP, fill=X, padx=40) -ngen_slider = Scale(f1, from_=1, to=5000, orient=HORIZONTAL, label='Max number of generations', command=lambda value: update_ngen(int(value))) +ngen_slider = Scale(f1, from_=1, to=5000, orient=HORIZONTAL, label='Max number of generations', + command=lambda value: update_ngen(int(value))) ngen_slider.set(ngen) ngen_slider.pack(expand=YES, side=TOP, fill=X, padx=40) -button = ttk.Button(f1, text='RUN', command=lambda: raise_frame(f2, init=True, update_target=True, target_entry=target_entry, f_thres_slider=f_thres_slider)).pack(side=BOTTOM, pady=50) +button = ttk.Button(f1, text='RUN', + command=lambda: raise_frame(f2, init=True, update_target=True, target_entry=target_entry, + f_thres_slider=f_thres_slider)).pack(side=BOTTOM, pady=50) # f2 widgets canvas = Canvas(f2, width=canvas_width, height=canvas_height) canvas.pack(expand=YES, fill=BOTH, padx=20, pady=15) button = ttk.Button(f2, text='EXIT', command=lambda: raise_frame(f1)).pack(side=BOTTOM, pady=15) + # function to run the genetic algorithm and update text on the canvas def genetic_algorithm_stepwise(population): - root.title('Genetic Algorithm') - for generation in range(ngen): - # generating new population after selecting, recombining and mutating the existing population - population = [search.mutate(search.recombine(*search.select(2, population, fitness_fn)), gene_pool, mutation_rate) for i in range(len(population))] - # genome with the highest fitness in the current generation - current_best = ''.join(argmax(population, key=fitness_fn)) - # collecting first few examples from the current population - members = [''.join(x) for x in population][:48] - - # clear the canvas - canvas.delete('all') - # displays current best on top of the screen - canvas.create_text(canvas_width / 2, 40, fill=p_blue, font='Consolas 46 bold', text=current_best) - - # displaying a part of the population on the screen - for i in range(len(members) // 3): - canvas.create_text((canvas_width * .175), (canvas_height * .25 + (25 * i)), fill=lp_blue, font='Consolas 16', text=members[3 * i]) - canvas.create_text((canvas_width * .500), (canvas_height * .25 + (25 * i)), fill=lp_blue, font='Consolas 16', text=members[3 * i + 1]) - canvas.create_text((canvas_width * .825), (canvas_height * .25 + (25 * i)), fill=lp_blue, font='Consolas 16', text=members[3 * i + 2]) - - # displays current generation number - canvas.create_text((canvas_width * .5), (canvas_height * 0.95), fill=p_blue, font='Consolas 18 bold', text=f'Generation {generation}') - - # displays blue bar that indicates current maximum fitness compared to maximum possible fitness - scaling_factor = fitness_fn(current_best) / len(target) - canvas.create_rectangle(canvas_width * 0.1, 90, canvas_width * 0.9, 100, outline=p_blue) - canvas.create_rectangle(canvas_width * 0.1, 90, canvas_width * 0.1 + scaling_factor * canvas_width * 0.8, 100, fill=lp_blue) - canvas.update() - - # checks for completion - fittest_individual = search.fitness_threshold(fitness_fn, f_thres, population) - if fittest_individual: - break + root.title('Genetic Algorithm') + for generation in range(ngen): + # generating new population after selecting, recombining and mutating the existing population + population = [ + search.mutate(search.recombine(*search.select(2, population, fitness_fn)), gene_pool, mutation_rate) for i + in range(len(population))] + # genome with the highest fitness in the current generation + current_best = ''.join(max(population, key=fitness_fn)) + # collecting first few examples from the current population + members = [''.join(x) for x in population][:48] + + # clear the canvas + canvas.delete('all') + # displays current best on top of the screen + canvas.create_text(canvas_width / 2, 40, fill=p_blue, font='Consolas 46 bold', text=current_best) + + # displaying a part of the population on the screen + for i in range(len(members) // 3): + canvas.create_text((canvas_width * .175), (canvas_height * .25 + (25 * i)), fill=lp_blue, + font='Consolas 16', text=members[3 * i]) + canvas.create_text((canvas_width * .500), (canvas_height * .25 + (25 * i)), fill=lp_blue, + font='Consolas 16', text=members[3 * i + 1]) + canvas.create_text((canvas_width * .825), (canvas_height * .25 + (25 * i)), fill=lp_blue, + font='Consolas 16', text=members[3 * i + 2]) + + # displays current generation number + canvas.create_text((canvas_width * .5), (canvas_height * 0.95), fill=p_blue, font='Consolas 18 bold', + text=f'Generation {generation}') + + # displays blue bar that indicates current maximum fitness compared to maximum possible fitness + scaling_factor = fitness_fn(current_best) / len(target) + canvas.create_rectangle(canvas_width * 0.1, 90, canvas_width * 0.9, 100, outline=p_blue) + canvas.create_rectangle(canvas_width * 0.1, 90, canvas_width * 0.1 + scaling_factor * canvas_width * 0.8, 100, + fill=lp_blue) + canvas.update() + + # checks for completion + fittest_individual = search.fitness_threshold(fitness_fn, f_thres, population) + if fittest_individual: + break + raise_frame(f1) -root.mainloop() \ No newline at end of file +root.mainloop() diff --git a/gui/grid_mdp.py b/gui/grid_mdp.py index 540bc2611..cb04c54b9 100644 --- a/gui/grid_mdp.py +++ b/gui/grid_mdp.py @@ -1,26 +1,22 @@ -# author: ad71 +import os.path +import sys import tkinter as tk import tkinter.messagebox -from tkinter import ttk - from functools import partial - -import sys -import os.path -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - -from mdp import * -import utils -import numpy as np -import time +from tkinter import ttk import matplotlib import matplotlib.animation as animation +from matplotlib import pyplot as plt +from matplotlib import style from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg -from matplotlib.ticker import MaxNLocator from matplotlib.figure import Figure -from matplotlib import style -from matplotlib import pyplot as plt +from matplotlib.ticker import MaxNLocator + +from mdp import * + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + matplotlib.use('TkAgg') style.use('ggplot') @@ -41,617 +37,640 @@ green8 = '#008080' green4 = '#004040' -cell_window_mantainer=None +cell_window_mantainer = None + def extents(f): - ''' adjusts axis markers for heatmap ''' + """adjusts axis markers for heatmap""" + + delta = f[1] - f[0] + return [f[0] - delta / 2, f[-1] + delta / 2] - delta = f[1] - f[0] - return [f[0] - delta/2, f[-1] + delta/2] def display(gridmdp, _height, _width): - ''' displays matrix ''' + """displays matrix""" - dialog = tk.Toplevel() - dialog.wm_title('Values') + dialog = tk.Toplevel() + dialog.wm_title('Values') - container = tk.Frame(dialog) - container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + container = tk.Frame(dialog) + container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - for i in range(max(1, _height)): - for j in range(max(1, _width)): - label = ttk.Label(container, text=f'{gridmdp[_height - i - 1][j]:.3f}', font=('Helvetica', 12)) - label.grid(row=i + 1, column=j + 1, padx=3, pady=3) + for i in range(max(1, _height)): + for j in range(max(1, _width)): + label = ttk.Label(container, text=f'{gridmdp[_height - i - 1][j]:.3f}', font=('Helvetica', 12)) + label.grid(row=i + 1, column=j + 1, padx=3, pady=3) + + dialog.mainloop() - dialog.mainloop() def display_best_policy(_best_policy, _height, _width): - ''' displays best policy ''' + """displays best policy""" + dialog = tk.Toplevel() + dialog.wm_title('Best Policy') - dialog = tk.Toplevel() - dialog.wm_title('Best Policy') + container = tk.Frame(dialog) + container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - container = tk.Frame(dialog) - container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + for i in range(max(1, _height)): + for j in range(max(1, _width)): + label = ttk.Label(container, text=_best_policy[i][j], font=('Helvetica', 12, 'bold')) + label.grid(row=i + 1, column=j + 1, padx=3, pady=3) - for i in range(max(1, _height)): - for j in range(max(1, _width)): - label = ttk.Label(container, text=_best_policy[i][j], font=('Helvetica', 12, 'bold')) - label.grid(row=i + 1, column=j + 1, padx=3, pady=3) + dialog.mainloop() - dialog.mainloop() def initialize_dialogbox(_width, _height, gridmdp, terminals, buttons): - ''' creates dialogbox for initialization ''' - - dialog = tk.Toplevel() - dialog.wm_title('Initialize') - - container = tk.Frame(dialog) - container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - container.grid_rowconfigure(0, weight=1) - container.grid_columnconfigure(0, weight=1) - - wall = tk.IntVar() - wall.set(0) - term = tk.IntVar() - term.set(0) - reward = tk.DoubleVar() - reward.set(0.0) - - label = ttk.Label(container, text='Initialize', font=('Helvetica', 12), anchor=tk.N) - label.grid(row=0, column=0, columnspan=3, sticky='new', pady=15, padx=5) - label_reward = ttk.Label(container, text='Reward', font=('Helvetica', 10), anchor=tk.N) - label_reward.grid(row=1, column=0, columnspan=3, sticky='new', pady=1, padx=5) - entry_reward = ttk.Entry(container, font=('Helvetica', 10), justify=tk.CENTER, exportselection=0, textvariable=reward) - entry_reward.grid(row=2, column=0, columnspan=3, sticky='new', pady=5, padx=50) - - rbtn_term = ttk.Radiobutton(container, text='Terminal', variable=term, value=TERM_VALUE) - rbtn_term.grid(row=3, column=0, columnspan=3, sticky='nsew', padx=160, pady=5) - rbtn_wall = ttk.Radiobutton(container, text='Wall', variable=wall, value=WALL_VALUE) - rbtn_wall.grid(row=4, column=0, columnspan=3, sticky='nsew', padx=172, pady=5) - - initialize_widget_disability_checks(_width, _height, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term) - - btn_apply = ttk.Button(container, text='Apply', command=partial(initialize_update_table, _width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall)) - btn_apply.grid(row=5, column=0, sticky='nsew', pady=5, padx=5) - btn_reset = ttk.Button(container, text='Reset', command=partial(initialize_reset_all, _width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term)) - btn_reset.grid(row=5, column=1, sticky='nsew', pady=5, padx=5) - btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy) - btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5) - - dialog.geometry('400x200') - dialog.mainloop() - -def update_table(i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall): - ''' functionality for 'apply' button ''' - - if wall.get() == WALL_VALUE: - buttons[i][j].configure(style='wall.TButton') - buttons[i][j].config(text='Wall') - label_reward.config(foreground='#999') - entry_reward.config(state=tk.DISABLED) - rbtn_term.state(['!focus', '!selected']) - rbtn_term.config(state=tk.DISABLED) - gridmdp[i][j] = WALL_VALUE - - elif wall.get() != WALL_VALUE: - if reward.get() != 0.0: - gridmdp[i][j] = reward.get() - buttons[i][j].configure(style='reward.TButton') - buttons[i][j].config(text=f'R = {reward.get()}') - - if term.get() == TERM_VALUE: - if (i, j) not in terminals: - terminals.append((i, j)) - rbtn_wall.state(['!focus', '!selected']) - rbtn_wall.config(state=tk.DISABLED) - - if gridmdp[i][j] < 0: - buttons[i][j].configure(style='-term.TButton') - - elif gridmdp[i][j] > 0: - buttons[i][j].configure(style='+term.TButton') - - elif gridmdp[i][j] == 0.0: - buttons[i][j].configure(style='=term.TButton') - -def initialize_update_table(_width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall): - ''' runs update_table for all cells ''' - - for i in range(max(1, _height)): - for j in range(max(1, _width)): - update_table(i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall) - -def reset_all(_height, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term): - ''' functionality for reset button ''' - - reward.set(0.0) - term.set(0) - wall.set(0) - gridmdp[i][j] = 0.0 - buttons[i][j].configure(style='TButton') - buttons[i][j].config(text=f'({_height - i - 1}, {j})') - - if (i, j) in terminals: - terminals.remove((i, j)) - - label_reward.config(foreground='#000') - entry_reward.config(state=tk.NORMAL) - rbtn_term.config(state=tk.NORMAL) - rbtn_wall.config(state=tk.NORMAL) - rbtn_wall.state(['!focus', '!selected']) - rbtn_term.state(['!focus', '!selected']) - -def initialize_reset_all(_width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term): - ''' runs reset_all for all cells ''' - - for i in range(max(1, _height)): - for j in range(max(1, _width)): - reset_all(_height, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term) + """creates dialogbox for initialization""" + + dialog = tk.Toplevel() + dialog.wm_title('Initialize') + + container = tk.Frame(dialog) + container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + container.grid_rowconfigure(0, weight=1) + container.grid_columnconfigure(0, weight=1) + + wall = tk.IntVar() + wall.set(0) + term = tk.IntVar() + term.set(0) + reward = tk.DoubleVar() + reward.set(0.0) + + label = ttk.Label(container, text='Initialize', font=('Helvetica', 12), anchor=tk.N) + label.grid(row=0, column=0, columnspan=3, sticky='new', pady=15, padx=5) + label_reward = ttk.Label(container, text='Reward', font=('Helvetica', 10), anchor=tk.N) + label_reward.grid(row=1, column=0, columnspan=3, sticky='new', pady=1, padx=5) + entry_reward = ttk.Entry(container, font=('Helvetica', 10), justify=tk.CENTER, exportselection=0, + textvariable=reward) + entry_reward.grid(row=2, column=0, columnspan=3, sticky='new', pady=5, padx=50) + + rbtn_term = ttk.Radiobutton(container, text='Terminal', variable=term, value=TERM_VALUE) + rbtn_term.grid(row=3, column=0, columnspan=3, sticky='nsew', padx=160, pady=5) + rbtn_wall = ttk.Radiobutton(container, text='Wall', variable=wall, value=WALL_VALUE) + rbtn_wall.grid(row=4, column=0, columnspan=3, sticky='nsew', padx=172, pady=5) + + initialize_widget_disability_checks(_width, _height, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, + rbtn_term) + + btn_apply = ttk.Button(container, text='Apply', + command=partial(initialize_update_table, _width, _height, gridmdp, terminals, buttons, + reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall)) + btn_apply.grid(row=5, column=0, sticky='nsew', pady=5, padx=5) + btn_reset = ttk.Button(container, text='Reset', + command=partial(initialize_reset_all, _width, _height, gridmdp, terminals, buttons, reward, + term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term)) + btn_reset.grid(row=5, column=1, sticky='nsew', pady=5, padx=5) + btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy) + btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5) + + dialog.geometry('400x200') + dialog.mainloop() + + +def update_table(i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, + rbtn_wall): + """functionality for 'apply' button""" + if wall.get() == WALL_VALUE: + buttons[i][j].configure(style='wall.TButton') + buttons[i][j].config(text='Wall') + label_reward.config(foreground='#999') + entry_reward.config(state=tk.DISABLED) + rbtn_term.state(['!focus', '!selected']) + rbtn_term.config(state=tk.DISABLED) + gridmdp[i][j] = WALL_VALUE + + elif wall.get() != WALL_VALUE: + if reward.get() != 0.0: + gridmdp[i][j] = reward.get() + buttons[i][j].configure(style='reward.TButton') + buttons[i][j].config(text=f'R = {reward.get()}') + + if term.get() == TERM_VALUE: + if (i, j) not in terminals: + terminals.append((i, j)) + rbtn_wall.state(['!focus', '!selected']) + rbtn_wall.config(state=tk.DISABLED) + + if gridmdp[i][j] < 0: + buttons[i][j].configure(style='-term.TButton') + + elif gridmdp[i][j] > 0: + buttons[i][j].configure(style='+term.TButton') + + elif gridmdp[i][j] == 0.0: + buttons[i][j].configure(style='=term.TButton') + + +def initialize_update_table(_width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, + entry_reward, rbtn_term, rbtn_wall): + """runs update_table for all cells""" + + for i in range(max(1, _height)): + for j in range(max(1, _width)): + update_table(i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, + rbtn_wall) + + +def reset_all(_height, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall, + rbtn_term): + """functionality for reset button""" + reward.set(0.0) + term.set(0) + wall.set(0) + gridmdp[i][j] = 0.0 + buttons[i][j].configure(style='TButton') + buttons[i][j].config(text=f'({_height - i - 1}, {j})') + + if (i, j) in terminals: + terminals.remove((i, j)) + + label_reward.config(foreground='#000') + entry_reward.config(state=tk.NORMAL) + rbtn_term.config(state=tk.NORMAL) + rbtn_wall.config(state=tk.NORMAL) + rbtn_wall.state(['!focus', '!selected']) + rbtn_term.state(['!focus', '!selected']) + + +def initialize_reset_all(_width, _height, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, + rbtn_wall, rbtn_term): + """runs reset_all for all cells""" + + for i in range(max(1, _height)): + for j in range(max(1, _width)): + reset_all(_height, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, + rbtn_wall, rbtn_term) + def external_reset(_width, _height, gridmdp, terminals, buttons): - ''' reset from edit menu ''' + """reset from edit menu""" + for i in range(max(1, _height)): + for j in range(max(1, _width)): + gridmdp[i][j] = 0.0 + buttons[i][j].configure(style='TButton') + buttons[i][j].config(text=f'({_height - i - 1}, {j})') - terminals = [] - for i in range(max(1, _height)): - for j in range(max(1, _width)): - gridmdp[i][j] = 0.0 - buttons[i][j].configure(style='TButton') - buttons[i][j].config(text=f'({_height - i - 1}, {j})') def widget_disability_checks(i, j, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term): - ''' checks for required state of widgets in dialogboxes ''' + """checks for required state of widgets in dialog boxes""" - if gridmdp[i][j] == WALL_VALUE: - label_reward.config(foreground='#999') - entry_reward.config(state=tk.DISABLED) - rbtn_term.config(state=tk.DISABLED) - rbtn_wall.state(['!focus', 'selected']) - rbtn_term.state(['!focus', '!selected']) + if gridmdp[i][j] == WALL_VALUE: + label_reward.config(foreground='#999') + entry_reward.config(state=tk.DISABLED) + rbtn_term.config(state=tk.DISABLED) + rbtn_wall.state(['!focus', 'selected']) + rbtn_term.state(['!focus', '!selected']) - if (i, j) in terminals: - rbtn_wall.config(state=tk.DISABLED) - rbtn_wall.state(['!focus', '!selected']) + if (i, j) in terminals: + rbtn_wall.config(state=tk.DISABLED) + rbtn_wall.state(['!focus', '!selected']) -def flatten_list(_list): - ''' returns a flattened list ''' - - return sum(_list, []) - -def initialize_widget_disability_checks(_width, _height, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term): - ''' checks for required state of widgets when cells are initialized ''' - - bool_walls = [['False']*max(1, _width) for _ in range(max(1, _height))] - bool_terms = [['False']*max(1, _width) for _ in range(max(1, _height))] - - for i in range(max(1, _height)): - for j in range(max(1, _width)): - if gridmdp[i][j] == WALL_VALUE: - bool_walls[i][j] = 'True' - - if (i, j) in terminals: - bool_terms[i][j] = 'True' - - bool_walls_fl = flatten_list(bool_walls) - bool_terms_fl = flatten_list(bool_terms) - - if bool_walls_fl.count('True') == len(bool_walls_fl): - print('`') - label_reward.config(foreground='#999') - entry_reward.config(state=tk.DISABLED) - rbtn_term.config(state=tk.DISABLED) - rbtn_wall.state(['!focus', 'selected']) - rbtn_term.state(['!focus', '!selected']) - - if bool_terms_fl.count('True') == len(bool_terms_fl): - rbtn_wall.config(state=tk.DISABLED) - rbtn_wall.state(['!focus', '!selected']) - rbtn_term.state(['!focus', 'selected']) - -def dialogbox(i, j, gridmdp, terminals, buttons, _height): - ''' creates dialogbox for each cell ''' - - global cell_window_mantainer - if(cell_window_mantainer!=None): - cell_window_mantainer.destroy() - - dialog = tk.Toplevel() - cell_window_mantainer=dialog - dialog.wm_title(f'{_height - i - 1}, {j}') - - container = tk.Frame(dialog) - container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - container.grid_rowconfigure(0, weight=1) - container.grid_columnconfigure(0, weight=1) - - wall = tk.IntVar() - wall.set(gridmdp[i][j]) - term = tk.IntVar() - term.set(TERM_VALUE if (i, j) in terminals else 0.0) - reward = tk.DoubleVar() - reward.set(gridmdp[i][j] if gridmdp[i][j] != WALL_VALUE else 0.0) - - label = ttk.Label(container, text=f'Configure cell {_height - i - 1}, {j}', font=('Helvetica', 12), anchor=tk.N) - label.grid(row=0, column=0, columnspan=3, sticky='new', pady=15, padx=5) - label_reward = ttk.Label(container, text='Reward', font=('Helvetica', 10), anchor=tk.N) - label_reward.grid(row=1, column=0, columnspan=3, sticky='new', pady=1, padx=5) - entry_reward = ttk.Entry(container, font=('Helvetica', 10), justify=tk.CENTER, exportselection=0, textvariable=reward) - entry_reward.grid(row=2, column=0, columnspan=3, sticky='new', pady=5, padx=50) - - rbtn_term = ttk.Radiobutton(container, text='Terminal', variable=term, value=TERM_VALUE) - rbtn_term.grid(row=3, column=0, columnspan=3, sticky='nsew', padx=160, pady=5) - rbtn_wall = ttk.Radiobutton(container, text='Wall', variable=wall, value=WALL_VALUE) - rbtn_wall.grid(row=4, column=0, columnspan=3, sticky='nsew', padx=172, pady=5) - - widget_disability_checks(i, j, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term) - - btn_apply = ttk.Button(container, text='Apply', command=partial(update_table, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_term, rbtn_wall)) - btn_apply.grid(row=5, column=0, sticky='nsew', pady=5, padx=5) - btn_reset = ttk.Button(container, text='Reset', command=partial(reset_all, _height, i, j, gridmdp, terminals, buttons, reward, term, wall, label_reward, entry_reward, rbtn_wall, rbtn_term)) - btn_reset.grid(row=5, column=1, sticky='nsew', pady=5, padx=5) - btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy) - btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5) - - dialog.geometry('400x200') - dialog.mainloop() +def flatten_list(_list): + """returns a flattened list""" + return sum(_list, []) -class MDPapp(tk.Tk): - - def __init__(self, *args, **kwargs): - - tk.Tk.__init__(self, *args, **kwargs) - tk.Tk.wm_title(self, 'Grid MDP') - self.shared_data = { - 'height': tk.IntVar(), - 'width': tk.IntVar() - } - self.shared_data['height'].set(1) - self.shared_data['width'].set(1) - self.container = tk.Frame(self) - self.container.pack(side='top', fill='both', expand=True) - self.container.grid_rowconfigure(0, weight=1) - self.container.grid_columnconfigure(0, weight=1) - - self.frames = {} - - self.menu_bar = tk.Menu(self.container) - self.file_menu = tk.Menu(self.menu_bar, tearoff=0) - self.file_menu.add_command(label='Exit', command=self.exit) - self.menu_bar.add_cascade(label='File', menu=self.file_menu) - - self.edit_menu = tk.Menu(self.menu_bar, tearoff=1) - self.edit_menu.add_command(label='Reset', command=self.master_reset) - self.edit_menu.add_command(label='Initialize', command=self.initialize) - self.edit_menu.add_separator() - self.edit_menu.add_command(label='View matrix', command=self.view_matrix) - self.edit_menu.add_command(label='View terminals', command=self.view_terminals) - self.menu_bar.add_cascade(label='Edit', menu=self.edit_menu) - self.menu_bar.entryconfig('Edit', state=tk.DISABLED) - - self.build_menu = tk.Menu(self.menu_bar, tearoff=1) - self.build_menu.add_command(label='Build and Run', command=self.build) - self.menu_bar.add_cascade(label='Build', menu=self.build_menu) - self.menu_bar.entryconfig('Build', state=tk.DISABLED) - tk.Tk.config(self, menu=self.menu_bar) - - for F in (HomePage, BuildMDP, SolveMDP): - frame = F(self.container, self) - self.frames[F] = frame - frame.grid(row=0, column=0, sticky='nsew') - - self.show_frame(HomePage) - - def placeholder_function(self): - ''' placeholder function ''' - - print('Not supported yet!') - - def exit(self): - ''' function to exit ''' - - if tkinter.messagebox.askokcancel('Exit?', 'All changes will be lost'): - quit() - - def new(self): - ''' function to create new GridMDP ''' - - self.master_reset() - build_page = self.get_page(BuildMDP) - build_page.gridmdp = None - build_page.terminals = None - build_page.buttons = None - self.show_frame(HomePage) - - def get_page(self, page_class): - ''' returns pages from stored frames ''' - - return self.frames[page_class] - def view_matrix(self): - ''' prints current matrix to console ''' +def initialize_widget_disability_checks(_width, _height, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, + rbtn_term): + """checks for required state of widgets when cells are initialized""" - build_page = self.get_page(BuildMDP) - _height = self.shared_data['height'].get() - _width = self.shared_data['width'].get() - print(build_page.gridmdp) - display(build_page.gridmdp, _height, _width) + bool_walls = [['False'] * max(1, _width) for _ in range(max(1, _height))] + bool_terms = [['False'] * max(1, _width) for _ in range(max(1, _height))] - def view_terminals(self): - ''' prints current terminals to console ''' + for i in range(max(1, _height)): + for j in range(max(1, _width)): + if gridmdp[i][j] == WALL_VALUE: + bool_walls[i][j] = 'True' - build_page = self.get_page(BuildMDP) - print('Terminals', build_page.terminals) + if (i, j) in terminals: + bool_terms[i][j] = 'True' - def initialize(self): - ''' calls initialize from BuildMDP ''' + bool_walls_fl = flatten_list(bool_walls) + bool_terms_fl = flatten_list(bool_terms) - build_page = self.get_page(BuildMDP) - build_page.initialize() + if bool_walls_fl.count('True') == len(bool_walls_fl): + print('`') + label_reward.config(foreground='#999') + entry_reward.config(state=tk.DISABLED) + rbtn_term.config(state=tk.DISABLED) + rbtn_wall.state(['!focus', 'selected']) + rbtn_term.state(['!focus', '!selected']) - def master_reset(self): - ''' calls master_reset from BuildMDP ''' + if bool_terms_fl.count('True') == len(bool_terms_fl): + rbtn_wall.config(state=tk.DISABLED) + rbtn_wall.state(['!focus', '!selected']) + rbtn_term.state(['!focus', 'selected']) - build_page = self.get_page(BuildMDP) - build_page.master_reset() - def build(self): - ''' runs specified mdp solving algorithm ''' +def dialogbox(i, j, gridmdp, terminals, buttons, _height): + """creates dialogbox for each cell""" + global cell_window_mantainer + if (cell_window_mantainer != None): + cell_window_mantainer.destroy() + + dialog = tk.Toplevel() + cell_window_mantainer = dialog + dialog.wm_title(f'{_height - i - 1}, {j}') + + container = tk.Frame(dialog) + container.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + container.grid_rowconfigure(0, weight=1) + container.grid_columnconfigure(0, weight=1) + + wall = tk.IntVar() + wall.set(gridmdp[i][j]) + term = tk.IntVar() + term.set(TERM_VALUE if (i, j) in terminals else 0.0) + reward = tk.DoubleVar() + reward.set(gridmdp[i][j] if gridmdp[i][j] != WALL_VALUE else 0.0) + + label = ttk.Label(container, text=f'Configure cell {_height - i - 1}, {j}', font=('Helvetica', 12), anchor=tk.N) + label.grid(row=0, column=0, columnspan=3, sticky='new', pady=15, padx=5) + label_reward = ttk.Label(container, text='Reward', font=('Helvetica', 10), anchor=tk.N) + label_reward.grid(row=1, column=0, columnspan=3, sticky='new', pady=1, padx=5) + entry_reward = ttk.Entry(container, font=('Helvetica', 10), justify=tk.CENTER, exportselection=0, + textvariable=reward) + entry_reward.grid(row=2, column=0, columnspan=3, sticky='new', pady=5, padx=50) + + rbtn_term = ttk.Radiobutton(container, text='Terminal', variable=term, value=TERM_VALUE) + rbtn_term.grid(row=3, column=0, columnspan=3, sticky='nsew', padx=160, pady=5) + rbtn_wall = ttk.Radiobutton(container, text='Wall', variable=wall, value=WALL_VALUE) + rbtn_wall.grid(row=4, column=0, columnspan=3, sticky='nsew', padx=172, pady=5) + + widget_disability_checks(i, j, gridmdp, terminals, label_reward, entry_reward, rbtn_wall, rbtn_term) + + btn_apply = ttk.Button(container, text='Apply', + command=partial(update_table, i, j, gridmdp, terminals, buttons, reward, term, wall, + label_reward, entry_reward, rbtn_term, rbtn_wall)) + btn_apply.grid(row=5, column=0, sticky='nsew', pady=5, padx=5) + btn_reset = ttk.Button(container, text='Reset', + command=partial(reset_all, _height, i, j, gridmdp, terminals, buttons, reward, term, wall, + label_reward, entry_reward, rbtn_wall, rbtn_term)) + btn_reset.grid(row=5, column=1, sticky='nsew', pady=5, padx=5) + btn_ok = ttk.Button(container, text='Ok', command=dialog.destroy) + btn_ok.grid(row=5, column=2, sticky='nsew', pady=5, padx=5) + + dialog.geometry('400x200') + dialog.mainloop() - frame = SolveMDP(self.container, self) - self.frames[SolveMDP] = frame - frame.grid(row=0, column=0, sticky='nsew') - self.show_frame(SolveMDP) - build_page = self.get_page(BuildMDP) - gridmdp = build_page.gridmdp - terminals = build_page.terminals - solve_page = self.get_page(SolveMDP) - _height = self.shared_data['height'].get() - _width = self.shared_data['width'].get() - solve_page.create_graph(gridmdp, terminals, _height, _width) - def show_frame(self, controller, cb=False): - ''' shows specified frame and optionally runs create_buttons ''' +class MDPapp(tk.Tk): - if cb: - build_page = self.get_page(BuildMDP) - build_page.create_buttons() - frame = self.frames[controller] - frame.tkraise() + def __init__(self, *args, **kwargs): + + tk.Tk.__init__(self, *args, **kwargs) + tk.Tk.wm_title(self, 'Grid MDP') + self.shared_data = { + 'height': tk.IntVar(), + 'width': tk.IntVar()} + self.shared_data['height'].set(1) + self.shared_data['width'].set(1) + self.container = tk.Frame(self) + self.container.pack(side='top', fill='both', expand=True) + self.container.grid_rowconfigure(0, weight=1) + self.container.grid_columnconfigure(0, weight=1) + + self.frames = {} + + self.menu_bar = tk.Menu(self.container) + self.file_menu = tk.Menu(self.menu_bar, tearoff=0) + self.file_menu.add_command(label='Exit', command=self.exit) + self.menu_bar.add_cascade(label='File', menu=self.file_menu) + + self.edit_menu = tk.Menu(self.menu_bar, tearoff=1) + self.edit_menu.add_command(label='Reset', command=self.master_reset) + self.edit_menu.add_command(label='Initialize', command=self.initialize) + self.edit_menu.add_separator() + self.edit_menu.add_command(label='View matrix', command=self.view_matrix) + self.edit_menu.add_command(label='View terminals', command=self.view_terminals) + self.menu_bar.add_cascade(label='Edit', menu=self.edit_menu) + self.menu_bar.entryconfig('Edit', state=tk.DISABLED) + + self.build_menu = tk.Menu(self.menu_bar, tearoff=1) + self.build_menu.add_command(label='Build and Run', command=self.build) + self.menu_bar.add_cascade(label='Build', menu=self.build_menu) + self.menu_bar.entryconfig('Build', state=tk.DISABLED) + tk.Tk.config(self, menu=self.menu_bar) + + for F in (HomePage, BuildMDP, SolveMDP): + frame = F(self.container, self) + self.frames[F] = frame + frame.grid(row=0, column=0, sticky='nsew') + + self.show_frame(HomePage) + + def placeholder_function(self): + """placeholder function""" + + print('Not supported yet!') + + def exit(self): + """function to exit""" + if tkinter.messagebox.askokcancel('Exit?', 'All changes will be lost'): + quit() + + def new(self): + """function to create new GridMDP""" + + self.master_reset() + build_page = self.get_page(BuildMDP) + build_page.gridmdp = None + build_page.terminals = None + build_page.buttons = None + self.show_frame(HomePage) + + def get_page(self, page_class): + """returns pages from stored frames""" + return self.frames[page_class] + + def view_matrix(self): + """prints current matrix to console""" + + build_page = self.get_page(BuildMDP) + _height = self.shared_data['height'].get() + _width = self.shared_data['width'].get() + print(build_page.gridmdp) + display(build_page.gridmdp, _height, _width) + + def view_terminals(self): + """prints current terminals to console""" + build_page = self.get_page(BuildMDP) + print('Terminals', build_page.terminals) + + def initialize(self): + """calls initialize from BuildMDP""" + + build_page = self.get_page(BuildMDP) + build_page.initialize() + + def master_reset(self): + """calls master_reset from BuildMDP""" + build_page = self.get_page(BuildMDP) + build_page.master_reset() + + def build(self): + """runs specified mdp solving algorithm""" + + frame = SolveMDP(self.container, self) + self.frames[SolveMDP] = frame + frame.grid(row=0, column=0, sticky='nsew') + self.show_frame(SolveMDP) + build_page = self.get_page(BuildMDP) + gridmdp = build_page.gridmdp + terminals = build_page.terminals + solve_page = self.get_page(SolveMDP) + _height = self.shared_data['height'].get() + _width = self.shared_data['width'].get() + solve_page.create_graph(gridmdp, terminals, _height, _width) + + def show_frame(self, controller, cb=False): + """shows specified frame and optionally runs create_buttons""" + if cb: + build_page = self.get_page(BuildMDP) + build_page.create_buttons() + frame = self.frames[controller] + frame.tkraise() class HomePage(tk.Frame): - def __init__(self, parent, controller): - ''' HomePage constructor ''' - - tk.Frame.__init__(self, parent) - self.controller = controller - frame1 = tk.Frame(self) - frame1.pack(side=tk.TOP) - frame3 = tk.Frame(self) - frame3.pack(side=tk.TOP) - frame4 = tk.Frame(self) - frame4.pack(side=tk.TOP) - frame2 = tk.Frame(self) - frame2.pack(side=tk.TOP) - - s = ttk.Style() - s.theme_use('clam') - s.configure('TButton', background=grayd, padding=0) - s.configure('wall.TButton', background=gray2, foreground=white) - s.configure('reward.TButton', background=gray9) - s.configure('+term.TButton', background=green8) - s.configure('-term.TButton', background=pblue, foreground=white) - s.configure('=term.TButton', background=green4) - - label = ttk.Label(frame1, text='GridMDP builder', font=('Helvetica', 18, 'bold'), background=grayef) - label.pack(pady=75, padx=50, side=tk.TOP) - - ec_btn = ttk.Button(frame3, text='Empty cells', width=20) - ec_btn.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) - ec_btn.configure(style='TButton') - - w_btn = ttk.Button(frame3, text='Walls', width=20) - w_btn.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) - w_btn.configure(style='wall.TButton') - - r_btn = ttk.Button(frame3, text='Rewards', width=20) - r_btn.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) - r_btn.configure(style='reward.TButton') - - term_p = ttk.Button(frame3, text='Positive terminals', width=20) - term_p.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) - term_p.configure(style='+term.TButton') - - term_z = ttk.Button(frame3, text='Neutral terminals', width=20) - term_z.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) - term_z.configure(style='=term.TButton') - - term_n = ttk.Button(frame3, text='Negative terminals', width=20) - term_n.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) - term_n.configure(style='-term.TButton') - - label = ttk.Label(frame4, text='Dimensions', font=('Verdana', 14), background=grayef) - label.pack(pady=15, padx=10, side=tk.TOP) - entry_h = tk.Entry(frame2, textvariable=self.controller.shared_data['height'], font=('Verdana', 10), width=3, justify=tk.CENTER) - entry_h.pack(pady=10, padx=10, side=tk.LEFT) - label_x = ttk.Label(frame2, text='X', font=('Verdana', 10), background=grayef) - label_x.pack(pady=10, padx=4, side=tk.LEFT) - entry_w = tk.Entry(frame2, textvariable=self.controller.shared_data['width'], font=('Verdana', 10), width=3, justify=tk.CENTER) - entry_w.pack(pady=10, padx=10, side=tk.LEFT) - button = ttk.Button(self, text='Build a GridMDP', command=lambda: controller.show_frame(BuildMDP, cb=True)) - button.pack(pady=10, padx=10, side=tk.TOP, ipadx=20, ipady=10) - button.configure(style='reward.TButton') + def __init__(self, parent, controller): + """HomePage constructor""" + + tk.Frame.__init__(self, parent) + self.controller = controller + frame1 = tk.Frame(self) + frame1.pack(side=tk.TOP) + frame3 = tk.Frame(self) + frame3.pack(side=tk.TOP) + frame4 = tk.Frame(self) + frame4.pack(side=tk.TOP) + frame2 = tk.Frame(self) + frame2.pack(side=tk.TOP) + + s = ttk.Style() + s.theme_use('clam') + s.configure('TButton', background=grayd, padding=0) + s.configure('wall.TButton', background=gray2, foreground=white) + s.configure('reward.TButton', background=gray9) + s.configure('+term.TButton', background=green8) + s.configure('-term.TButton', background=pblue, foreground=white) + s.configure('=term.TButton', background=green4) + + label = ttk.Label(frame1, text='GridMDP builder', font=('Helvetica', 18, 'bold'), background=grayef) + label.pack(pady=75, padx=50, side=tk.TOP) + + ec_btn = ttk.Button(frame3, text='Empty cells', width=20) + ec_btn.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) + ec_btn.configure(style='TButton') + + w_btn = ttk.Button(frame3, text='Walls', width=20) + w_btn.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) + w_btn.configure(style='wall.TButton') + + r_btn = ttk.Button(frame3, text='Rewards', width=20) + r_btn.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) + r_btn.configure(style='reward.TButton') + + term_p = ttk.Button(frame3, text='Positive terminals', width=20) + term_p.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) + term_p.configure(style='+term.TButton') + + term_z = ttk.Button(frame3, text='Neutral terminals', width=20) + term_z.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) + term_z.configure(style='=term.TButton') + + term_n = ttk.Button(frame3, text='Negative terminals', width=20) + term_n.pack(pady=0, padx=0, side=tk.LEFT, ipady=10) + term_n.configure(style='-term.TButton') + + label = ttk.Label(frame4, text='Dimensions', font=('Verdana', 14), background=grayef) + label.pack(pady=15, padx=10, side=tk.TOP) + entry_h = tk.Entry(frame2, textvariable=self.controller.shared_data['height'], font=('Verdana', 10), width=3, + justify=tk.CENTER) + entry_h.pack(pady=10, padx=10, side=tk.LEFT) + label_x = ttk.Label(frame2, text='X', font=('Verdana', 10), background=grayef) + label_x.pack(pady=10, padx=4, side=tk.LEFT) + entry_w = tk.Entry(frame2, textvariable=self.controller.shared_data['width'], font=('Verdana', 10), width=3, + justify=tk.CENTER) + entry_w.pack(pady=10, padx=10, side=tk.LEFT) + button = ttk.Button(self, text='Build a GridMDP', command=lambda: controller.show_frame(BuildMDP, cb=True)) + button.pack(pady=10, padx=10, side=tk.TOP, ipadx=20, ipady=10) + button.configure(style='reward.TButton') class BuildMDP(tk.Frame): - def __init__(self, parent, controller): - - tk.Frame.__init__(self, parent) - self.grid_rowconfigure(0, weight=1) - self.grid_columnconfigure(0, weight=1) - self.frame = tk.Frame(self) - self.frame.pack() - self.controller = controller - - def create_buttons(self): - ''' creates interactive cells to build MDP ''' - - _height = self.controller.shared_data['height'].get() - _width = self.controller.shared_data['width'].get() - self.controller.menu_bar.entryconfig('Edit', state=tk.NORMAL) - self.controller.menu_bar.entryconfig('Build', state=tk.NORMAL) - self.gridmdp = [[0.0]*max(1, _width) for _ in range(max(1, _height))] - self.buttons = [[None]*max(1, _width) for _ in range(max(1, _height))] - self.terminals = [] - - s = ttk.Style() - s.theme_use('clam') - s.configure('TButton', background=grayd, padding=0) - s.configure('wall.TButton', background=gray2, foreground=white) - s.configure('reward.TButton', background=gray9) - s.configure('+term.TButton', background=green8) - s.configure('-term.TButton', background=pblue, foreground=white) - s.configure('=term.TButton', background=green4) - - for i in range(max(1, _height)): - for j in range(max(1, _width)): - self.buttons[i][j] = ttk.Button(self.frame, text=f'({_height - i - 1}, {j})', width=int(196/max(1, _width)), command=partial(dialogbox, i, j, self.gridmdp, self.terminals, self.buttons, _height)) - self.buttons[i][j].grid(row=i, column=j, ipady=int(336/max(1, _height)) - 12) - - def initialize(self): - ''' runs initialize_dialogbox ''' - - _height = self.controller.shared_data['height'].get() - _width = self.controller.shared_data['width'].get() - initialize_dialogbox(_width, _height, self.gridmdp, self.terminals, self.buttons) - - def master_reset(self): - ''' runs external reset ''' - - _height = self.controller.shared_data['height'].get() - _width = self.controller.shared_data['width'].get() - if tkinter.messagebox.askokcancel('Reset', 'Are you sure you want to reset all cells?'): - external_reset(_width, _height, self.gridmdp, self.terminals, self.buttons) + def __init__(self, parent, controller): + + tk.Frame.__init__(self, parent) + self.grid_rowconfigure(0, weight=1) + self.grid_columnconfigure(0, weight=1) + self.frame = tk.Frame(self) + self.frame.pack() + self.controller = controller + + def create_buttons(self): + """creates interactive cells to build MDP""" + _height = self.controller.shared_data['height'].get() + _width = self.controller.shared_data['width'].get() + self.controller.menu_bar.entryconfig('Edit', state=tk.NORMAL) + self.controller.menu_bar.entryconfig('Build', state=tk.NORMAL) + self.gridmdp = [[0.0] * max(1, _width) for _ in range(max(1, _height))] + self.buttons = [[None] * max(1, _width) for _ in range(max(1, _height))] + self.terminals = [] + + s = ttk.Style() + s.theme_use('clam') + s.configure('TButton', background=grayd, padding=0) + s.configure('wall.TButton', background=gray2, foreground=white) + s.configure('reward.TButton', background=gray9) + s.configure('+term.TButton', background=green8) + s.configure('-term.TButton', background=pblue, foreground=white) + s.configure('=term.TButton', background=green4) + + for i in range(max(1, _height)): + for j in range(max(1, _width)): + self.buttons[i][j] = ttk.Button(self.frame, text=f'({_height - i - 1}, {j})', + width=int(196 / max(1, _width)), + command=partial(dialogbox, i, j, self.gridmdp, self.terminals, + self.buttons, _height)) + self.buttons[i][j].grid(row=i, column=j, ipady=int(336 / max(1, _height)) - 12) + + def initialize(self): + """runs initialize_dialogbox""" + + _height = self.controller.shared_data['height'].get() + _width = self.controller.shared_data['width'].get() + initialize_dialogbox(_width, _height, self.gridmdp, self.terminals, self.buttons) + + def master_reset(self): + """runs external reset""" + _height = self.controller.shared_data['height'].get() + _width = self.controller.shared_data['width'].get() + if tkinter.messagebox.askokcancel('Reset', 'Are you sure you want to reset all cells?'): + external_reset(_width, _height, self.gridmdp, self.terminals, self.buttons) class SolveMDP(tk.Frame): - def __init__(self, parent, controller): - - tk.Frame.__init__(self, parent) - self.grid_rowconfigure(0, weight=1) - self.grid_columnconfigure(0, weight=1) - self.frame = tk.Frame(self) - self.frame.pack() - self.controller = controller - self.terminated = False - self.iterations = 0 - self.epsilon = 0.001 - self.delta = 0 + def __init__(self, parent, controller): - def process_data(self, terminals, _height, _width, gridmdp): - ''' preprocess variables ''' + tk.Frame.__init__(self, parent) + self.grid_rowconfigure(0, weight=1) + self.grid_columnconfigure(0, weight=1) + self.frame = tk.Frame(self) + self.frame.pack() + self.controller = controller + self.terminated = False + self.iterations = 0 + self.epsilon = 0.001 + self.delta = 0 - flipped_terminals = [] + def process_data(self, terminals, _height, _width, gridmdp): + """preprocess variables""" - for terminal in terminals: - flipped_terminals.append((terminal[1], _height - terminal[0] - 1)) + flipped_terminals = [] - grid_to_solve = [[0.0]*max(1, _width) for _ in range(max(1, _height))] - grid_to_show = [[0.0]*max(1, _width) for _ in range(max(1, _height))] + for terminal in terminals: + flipped_terminals.append((terminal[1], _height - terminal[0] - 1)) - for i in range(max(1, _height)): - for j in range(max(1, _width)): - if gridmdp[i][j] == WALL_VALUE: - grid_to_show[i][j] = 0.0 - grid_to_solve[i][j] = None + grid_to_solve = [[0.0] * max(1, _width) for _ in range(max(1, _height))] + grid_to_show = [[0.0] * max(1, _width) for _ in range(max(1, _height))] - else: - grid_to_show[i][j] = grid_to_solve[i][j] = gridmdp[i][j] + for i in range(max(1, _height)): + for j in range(max(1, _width)): + if gridmdp[i][j] == WALL_VALUE: + grid_to_show[i][j] = 0.0 + grid_to_solve[i][j] = None - return flipped_terminals, grid_to_solve, np.flipud(grid_to_show) + else: + grid_to_show[i][j] = grid_to_solve[i][j] = gridmdp[i][j] - def create_graph(self, gridmdp, terminals, _height, _width): - ''' creates canvas and initializes value_iteration_paramteres ''' + return flipped_terminals, grid_to_solve, np.flipud(grid_to_show) - self._height = _height - self._width = _width - self.controller.menu_bar.entryconfig('Edit', state=tk.DISABLED) - self.controller.menu_bar.entryconfig('Build', state=tk.DISABLED) + def create_graph(self, gridmdp, terminals, _height, _width): + """creates canvas and initializes value_iteration_parameters""" + self._height = _height + self._width = _width + self.controller.menu_bar.entryconfig('Edit', state=tk.DISABLED) + self.controller.menu_bar.entryconfig('Build', state=tk.DISABLED) - self.terminals, self.gridmdp, self.grid_to_show = self.process_data(terminals, _height, _width, gridmdp) - self.sequential_decision_environment = GridMDP(self.gridmdp, terminals=self.terminals) + self.terminals, self.gridmdp, self.grid_to_show = self.process_data(terminals, _height, _width, gridmdp) + self.sequential_decision_environment = GridMDP(self.gridmdp, terminals=self.terminals) - self.initialize_value_iteration_parameters(self.sequential_decision_environment) + self.initialize_value_iteration_parameters(self.sequential_decision_environment) - self.canvas = FigureCanvasTkAgg(fig, self.frame) - self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True) - self.anim = animation.FuncAnimation(fig, self.animate_graph, interval=50) - self.canvas.show() + self.canvas = FigureCanvasTkAgg(fig, self.frame) + self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True) + self.anim = animation.FuncAnimation(fig, self.animate_graph, interval=50) + self.canvas.show() - def animate_graph(self, i): - ''' performs value iteration and animates graph ''' + def animate_graph(self, i): + """performs value iteration and animates graph""" - # cmaps to use: bone_r, Oranges, inferno, BrBG, copper - self.iterations += 1 - x_interval = max(2, len(self.gridmdp[0])) - y_interval = max(2, len(self.gridmdp)) - x = np.linspace(0, len(self.gridmdp[0]) - 1, x_interval) - y = np.linspace(0, len(self.gridmdp) - 1, y_interval) + # cmaps to use: bone_r, Oranges, inferno, BrBG, copper + self.iterations += 1 + x_interval = max(2, len(self.gridmdp[0])) + y_interval = max(2, len(self.gridmdp)) + x = np.linspace(0, len(self.gridmdp[0]) - 1, x_interval) + y = np.linspace(0, len(self.gridmdp) - 1, y_interval) - sub.clear() - sub.imshow(self.grid_to_show, cmap='BrBG', aspect='auto', interpolation='none', extent=extents(x) + extents(y), origin='lower') - fig.tight_layout() + sub.clear() + sub.imshow(self.grid_to_show, cmap='BrBG', aspect='auto', interpolation='none', extent=extents(x) + extents(y), + origin='lower') + fig.tight_layout() - U = self.U1.copy() + U = self.U1.copy() - for s in self.sequential_decision_environment.states: - self.U1[s] = self.R(s) + self.gamma * max([sum([p * U[s1] for (p, s1) in self.T(s, a)]) for a in self.sequential_decision_environment.actions(s)]) - self.delta = max(self.delta, abs(self.U1[s] - U[s])) + for s in self.sequential_decision_environment.states: + self.U1[s] = self.R(s) + self.gamma * max( + [sum([p * U[s1] for (p, s1) in self.T(s, a)]) for a in self.sequential_decision_environment.actions(s)]) + self.delta = max(self.delta, abs(self.U1[s] - U[s])) - self.grid_to_show = grid_to_show = [[0.0]*max(1, self._width) for _ in range(max(1, self._height))] - for k, v in U.items(): - self.grid_to_show[k[1]][k[0]] = v + self.grid_to_show = grid_to_show = [[0.0] * max(1, self._width) for _ in range(max(1, self._height))] + for k, v in U.items(): + self.grid_to_show[k[1]][k[0]] = v - if (self.delta < self.epsilon * (1 - self.gamma) / self.gamma) or (self.iterations > 60) and self.terminated == False: - self.terminated = True - display(self.grid_to_show, self._height, self._width) + if (self.delta < self.epsilon * (1 - self.gamma) / self.gamma) or ( + self.iterations > 60) and self.terminated == False: + self.terminated = True + display(self.grid_to_show, self._height, self._width) - pi = best_policy(self.sequential_decision_environment, value_iteration(self.sequential_decision_environment, .01)) - display_best_policy(self.sequential_decision_environment.to_arrows(pi), self._height, self._width) - - ax = fig.gca() - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) - ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + pi = best_policy(self.sequential_decision_environment, + value_iteration(self.sequential_decision_environment, .01)) + display_best_policy(self.sequential_decision_environment.to_arrows(pi), self._height, self._width) - def initialize_value_iteration_parameters(self, mdp): - ''' initializes value_iteration parameters ''' + ax = fig.gca() + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.yaxis.set_major_locator(MaxNLocator(integer=True)) - self.U1 = {s: 0 for s in mdp.states} - self.R, self.T, self.gamma = mdp.R, mdp.T, mdp.gamma + def initialize_value_iteration_parameters(self, mdp): + """initializes value_iteration parameters""" + self.U1 = {s: 0 for s in mdp.states} + self.R, self.T, self.gamma = mdp.R, mdp.T, mdp.gamma - def value_iteration_metastep(self, mdp, iterations=20): - ''' runs value_iteration ''' + def value_iteration_metastep(self, mdp, iterations=20): + """runs value_iteration""" - U_over_time = [] - U1 = {s: 0 for s in mdp.states} - R, T, gamma = mdp.R, mdp.T, mdp.gamma + U_over_time = [] + U1 = {s: 0 for s in mdp.states} + R, T, gamma = mdp.R, mdp.T, mdp.gamma - for _ in range(iterations): - U = U1.copy() + for _ in range(iterations): + U = U1.copy() - for s in mdp.states: - U1[s] = R(s) + gamma * max([sum([p * U[s1] for (p, s1) in T(s, a)]) for a in mdp.actions(s)]) + for s in mdp.states: + U1[s] = R(s) + gamma * max([sum([p * U[s1] for (p, s1) in T(s, a)]) for a in mdp.actions(s)]) - U_over_time.append(U) - return U_over_time + U_over_time.append(U) + return U_over_time if __name__ == '__main__': - app = MDPapp() - app.geometry('1280x720') - app.mainloop() \ No newline at end of file + app = MDPapp() + app.geometry('1280x720') + app.mainloop() diff --git a/gui/romania_problem.py b/gui/romania_problem.py index 08219bb55..9ec94099d 100644 --- a/gui/romania_problem.py +++ b/gui/romania_problem.py @@ -621,9 +621,7 @@ def reset_map(): # TODO: Add more search algorithms in the OptionMenu - - -def main(): +if __name__ == "__main__": global algo, start, goal, next_button root = Tk() root.title("Road Map of Romania") @@ -672,7 +670,3 @@ def main(): frame1.pack(side=BOTTOM) create_map(root) root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/gui/tic-tac-toe.py b/gui/tic-tac-toe.py index 4f51425c1..66d9d6e75 100644 --- a/gui/tic-tac-toe.py +++ b/gui/tic-tac-toe.py @@ -1,11 +1,12 @@ -from tkinter import * -import sys import os.path -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from tkinter import * + from games import minmax_decision, alpha_beta_player, random_player, TicTacToe # "gen_state" can be used to generate a game state to apply the algorithm from tests.test_games import gen_state +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + ttt = TicTacToe() root = None buttons = [] @@ -152,8 +153,7 @@ def check_victory(button): return True # check if previous move was on the secondary diagonal and caused a win - if x + y \ - == 2 and buttons[0][2]['text'] == buttons[1][1]['text'] == buttons[2][0]['text'] != " ": + if x + y == 2 and buttons[0][2]['text'] == buttons[1][1]['text'] == buttons[2][0]['text'] != " ": buttons[0][2].config(text="/" + tt + "/") buttons[1][1].config(text="/" + tt + "/") buttons[2][0].config(text="/" + tt + "/") @@ -213,7 +213,7 @@ def exit_game(root): root.destroy() -def main(): +if __name__ == "__main__": global result, choices root = Tk() @@ -230,7 +230,3 @@ def main(): menu = OptionMenu(root, choices, "Vs Random", "Vs Pro", "Vs Legend") menu.pack() root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/gui/tsp.py b/gui/tsp.py index 1830cba23..590fff354 100644 --- a/gui/tsp.py +++ b/gui/tsp.py @@ -1,21 +1,19 @@ from tkinter import * from tkinter import messagebox -import sys -import os.path -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from search import * + import utils -import numpy as np +from search import * -distances = {} +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +distances = {} -class TSP_problem(Problem): - """ subclass of Problem to define various functions """ +class TSProblem(Problem): + """subclass of Problem to define various functions""" def two_opt(self, state): - """ Neighbour generating function for Traveling Salesman Problem """ + """Neighbour generating function for Traveling Salesman Problem""" neighbour_state = state[:] left = random.randint(0, len(neighbour_state) - 1) right = random.randint(0, len(neighbour_state) - 1) @@ -25,15 +23,15 @@ def two_opt(self, state): return neighbour_state def actions(self, state): - """ action that can be excuted in given state """ + """action that can be executed in given state""" return [self.two_opt] def result(self, state, action): - """ result after applying the given action on the given state """ + """result after applying the given action on the given state""" return action(state) def path_cost(self, c, state1, action, state2): - """ total distance for the Traveling Salesman to be covered if in state2 """ + """total distance for the Traveling Salesman to be covered if in state2""" cost = 0 for i in range(len(state2) - 1): cost += distances[state2[i]][state2[i + 1]] @@ -41,12 +39,12 @@ def path_cost(self, c, state1, action, state2): return cost def value(self, state): - """ value of path cost given negative for the given state """ + """value of path cost given negative for the given state""" return -1 * self.path_cost(None, None, None, state) -class TSP_Gui(): - """ Class to create gui of Traveling Salesman using simulated annealing where one can +class TSPGui(): + """Class to create gui of Traveling Salesman using simulated annealing where one can select cities, change speed and temperature. Distances between cities are euclidean distances between them. """ @@ -67,7 +65,7 @@ def __init__(self, root, all_cities): Label(self.root, text="Map of Romania", font="Times 13 bold").grid(row=0, columnspan=10) def create_checkboxes(self, side=LEFT, anchor=W): - """ To select cities which are to be a part of Traveling Salesman Problem """ + """To select cities which are to be a part of Traveling Salesman Problem""" row_number = 0 column_number = 0 @@ -85,7 +83,7 @@ def create_checkboxes(self, side=LEFT, anchor=W): row_number += 1 def create_buttons(self): - """ Create start and quit button """ + """Create start and quit button""" Button(self.frame_select_cities, textvariable=self.button_text, command=self.run_traveling_salesman).grid(row=5, column=4, sticky=E + W) @@ -93,7 +91,7 @@ def create_buttons(self): row=5, column=5, sticky=E + W) def create_dropdown_menu(self): - """ Create dropdown menu for algorithm selection """ + """Create dropdown menu for algorithm selection""" choices = {'Simulated Annealing', 'Genetic Algorithm', 'Hill Climbing'} self.algo_var.set('Simulated Annealing') @@ -102,19 +100,19 @@ def create_dropdown_menu(self): dropdown_menu.config(width=19) def run_traveling_salesman(self): - """ Choose selected citites """ + """Choose selected cities""" cities = [] for i in range(len(self.vars)): if self.vars[i].get() == 1: cities.append(self.all_cities[i]) - tsp_problem = TSP_problem(cities) + tsp_problem = TSProblem(cities) self.button_text.set("Reset") self.create_canvas(tsp_problem) def calculate_canvas_size(self): - """ Width and height for canvas """ + """Width and height for canvas""" minx, maxx = sys.maxsize, -1 * sys.maxsize miny, maxy = sys.maxsize, -1 * sys.maxsize @@ -137,7 +135,7 @@ def calculate_canvas_size(self): self.canvas_height = canvas_height def create_canvas(self, problem): - """ creating map with cities """ + """creating map with cities""" map_canvas = Canvas(self.frame_canvas, width=self.canvas_width, height=self.canvas_height) map_canvas.grid(row=3, columnspan=10) @@ -163,18 +161,18 @@ def create_canvas(self, problem): variable=self.speed, label="Speed ----> ", showvalue=0, font="Times 11", relief="sunken", cursor="gumby") speed_scale.grid(row=1, columnspan=5, sticky=N + S + E + W) - + if self.algo_var.get() == 'Simulated Annealing': self.temperature = IntVar() temperature_scale = Scale(self.frame_canvas, from_=100, to=0, orient=HORIZONTAL, - length=200, variable=self.temperature, label="Temperature ---->", - font="Times 11", relief="sunken", showvalue=0, cursor="gumby") + length=200, variable=self.temperature, label="Temperature ---->", + font="Times 11", relief="sunken", showvalue=0, cursor="gumby") temperature_scale.grid(row=1, column=5, columnspan=5, sticky=N + S + E + W) self.simulated_annealing_with_tunable_T(problem, map_canvas) elif self.algo_var.get() == 'Genetic Algorithm': self.mutation_rate = DoubleVar() self.mutation_rate.set(0.05) - mutation_rate_scale = Scale(self.frame_canvas, from_=0, to=1, orient=HORIZONTAL, + mutation_rate_scale = Scale(self.frame_canvas, from_=0, to=1, orient=HORIZONTAL, length=200, variable=self.mutation_rate, label='Mutation Rate ---->', font='Times 11', relief='sunken', showvalue=0, cursor='gumby', resolution=0.001) mutation_rate_scale.grid(row=1, column=5, columnspan=5, sticky='nsew') @@ -182,23 +180,23 @@ def create_canvas(self, problem): elif self.algo_var.get() == 'Hill Climbing': self.no_of_neighbors = IntVar() self.no_of_neighbors.set(100) - no_of_neighbors_scale = Scale(self.frame_canvas, from_=10, to=1000, orient=HORIZONTAL, + no_of_neighbors_scale = Scale(self.frame_canvas, from_=10, to=1000, orient=HORIZONTAL, length=200, variable=self.no_of_neighbors, label='Number of neighbors ---->', - font='Times 11',relief='sunken', showvalue=0, cursor='gumby') + font='Times 11', relief='sunken', showvalue=0, cursor='gumby') no_of_neighbors_scale.grid(row=1, column=5, columnspan=5, sticky='nsew') self.hill_climbing(problem, map_canvas) def exp_schedule(k=100, lam=0.03, limit=1000): - """ One possible schedule function for simulated annealing """ + """One possible schedule function for simulated annealing""" - return lambda t: (k * math.exp(-lam * t) if t < limit else 0) + return lambda t: (k * np.exp(-lam * t) if t < limit else 0) def simulated_annealing_with_tunable_T(self, problem, map_canvas, schedule=exp_schedule()): - """ Simulated annealing where temperature is taken as user input """ + """Simulated annealing where temperature is taken as user input""" current = Node(problem.initial) - while(1): + while True: T = schedule(self.temperature.get()) if T == 0: return current.state @@ -207,7 +205,7 @@ def simulated_annealing_with_tunable_T(self, problem, map_canvas, schedule=exp_s return current.state next = random.choice(neighbors) delta_e = problem.value(next.state) - problem.value(current.state) - if delta_e > 0 or probability(math.exp(delta_e / T)): + if delta_e > 0 or probability(np.exp(delta_e / T)): map_canvas.delete("poly") current = next @@ -221,10 +219,10 @@ def simulated_annealing_with_tunable_T(self, problem, map_canvas, schedule=exp_s map_canvas.after(self.speed.get()) def genetic_algorithm(self, problem, map_canvas): - """ Genetic Algorithm modified for the given problem """ + """Genetic Algorithm modified for the given problem""" def init_population(pop_number, gene_pool, state_length): - """ initialize population """ + """initialize population""" population = [] for i in range(pop_number): @@ -232,7 +230,7 @@ def init_population(pop_number, gene_pool, state_length): return population def recombine(state_a, state_b): - """ recombine two problem states """ + """recombine two problem states""" start = random.randint(0, len(state_a) - 1) end = random.randint(start + 1, len(state_a)) @@ -243,7 +241,7 @@ def recombine(state_a, state_b): return new_state def mutate(state, mutation_rate): - """ mutate problem states """ + """mutate problem states""" if random.uniform(0, 1) < mutation_rate: sample = random.sample(range(len(state)), 2) @@ -251,17 +249,18 @@ def mutate(state, mutation_rate): return state def fitness_fn(state): - """ calculate fitness of a particular state """ - + """calculate fitness of a particular state""" + fitness = problem.value(state) return int((5600 + fitness) ** 2) current = Node(problem.initial) population = init_population(100, current.state, len(current.state)) all_time_best = current.state - while(1): - population = [mutate(recombine(*select(2, population, fitness_fn)), self.mutation_rate.get()) for i in range(len(population))] - current_best = utils.argmax(population, key=fitness_fn) + while True: + population = [mutate(recombine(*select(2, population, fitness_fn)), self.mutation_rate.get()) + for _ in range(len(population))] + current_best = np.argmax(population, key=fitness_fn) if fitness_fn(current_best) > fitness_fn(all_time_best): all_time_best = current_best self.cost.set("Cost = " + str('%0.3f' % (-1 * problem.value(all_time_best)))) @@ -280,10 +279,10 @@ def fitness_fn(state): map_canvas.after(self.speed.get()) def hill_climbing(self, problem, map_canvas): - """ hill climbing where number of neighbors is taken as user input """ + """hill climbing where number of neighbors is taken as user input""" def find_neighbors(state, number_of_neighbors=100): - """ finds neighbors using two_opt method """ + """finds neighbors using two_opt method""" neighbors = [] for i in range(number_of_neighbors): @@ -293,9 +292,9 @@ def find_neighbors(state, number_of_neighbors=100): return neighbors current = Node(problem.initial) - while(1): + while True: neighbors = find_neighbors(current.state, self.no_of_neighbors.get()) - neighbor = utils.argmax_random_tie(neighbors, key=lambda node: problem.value(node.state)) + neighbor = np.argmax_random_tie(neighbors, key=lambda node: problem.value(node.state)) map_canvas.delete('poly') points = [] for city in current.state: @@ -317,7 +316,8 @@ def on_closing(self): if messagebox.askokcancel('Quit', 'Do you want to quit?'): self.root.destroy() -def main(): + +if __name__ == '__main__': all_cities = [] for city in romania_map.locations.keys(): distances[city] = {} @@ -334,13 +334,9 @@ def main(): root = Tk() root.title("Traveling Salesman Problem") - cities_selection_panel = TSP_Gui(root, all_cities) + cities_selection_panel = TSPGui(root, all_cities) cities_selection_panel.create_checkboxes() cities_selection_panel.create_buttons() cities_selection_panel.create_dropdown_menu() root.protocol('WM_DELETE_WINDOW', cities_selection_panel.on_closing) root.mainloop() - - -if __name__ == '__main__': - main() diff --git a/gui/vacuum_agent.py b/gui/vacuum_agent.py index 23292efb3..b07dab282 100644 --- a/gui/vacuum_agent.py +++ b/gui/vacuum_agent.py @@ -1,15 +1,14 @@ -from tkinter import * -import random -import sys import os.path -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from tkinter import * + from agents import * +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + loc_A, loc_B = (0, 0), (1, 0) # The two locations for the Vacuum world class Gui(Environment): - """This GUI environment has two locations, A and B. Each can be Dirty or Clean. The agent perceives its location and the location's status.""" @@ -33,7 +32,7 @@ def thing_classes(self): def percept(self, agent): """Returns the agent's location, and the location status (Dirty/Clean).""" - return (agent.location, self.status[agent.location]) + return agent.location, self.status[agent.location] def execute_action(self, agent, action): """Change the location status (Dirty/Clean); track performance. @@ -137,8 +136,7 @@ def move_agent(env, agent, before_step): # TODO: Add more agents to the environment. # TODO: Expand the environment to XYEnvironment. -def main(): - """The main function of the program.""" +if __name__ == "__main__": root = Tk() root.title("Vacuum Environment") root.geometry("420x380") @@ -154,7 +152,3 @@ def main(): create_agent(env, agent) next_button.config(command=lambda: env.update_env(agent)) root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/gui/xy_vacuum_environment.py b/gui/xy_vacuum_environment.py index 4ba4497ea..093abc6c3 100644 --- a/gui/xy_vacuum_environment.py +++ b/gui/xy_vacuum_environment.py @@ -1,10 +1,10 @@ -from tkinter import * -import random -import sys import os.path -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from tkinter import * + from agents import * +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + class Gui(VacuumEnvironment): """This is a two-dimensional GUI environment. Each location may be @@ -13,8 +13,10 @@ class Gui(VacuumEnvironment): xi, yi = (0, 0) perceptible_distance = 1 - def __init__(self, root, width=7, height=7, elements=['D', 'W']): + def __init__(self, root, width=7, height=7, elements=None): super().__init__(width, height) + if elements is None: + elements = ['D', 'W'] self.root = root self.create_frames() self.create_buttons() @@ -71,10 +73,10 @@ def display_element(self, button): def execute_action(self, agent, action): """Determines the action the agent performs.""" - xi, yi = ((self.xi, self.yi)) + xi, yi = (self.xi, self.yi) if action == 'Suck': dirt_list = self.list_things_at(agent.location, Dirt) - if dirt_list != []: + if dirt_list: dirt = dirt_list[0] agent.performance += 100 self.delete_thing(dirt) @@ -166,11 +168,9 @@ def __init__(self, program=None): self.direction = Direction("up") -# TODO: -# Check the coordinate system. -# Give manual choice for agent's location. -def main(): - """The main function.""" +# TODO: Check the coordinate system. +# TODO: Give manual choice for agent's location. +if __name__ == "__main__": root = Tk() root.title("Vacuum Environment") root.geometry("420x440") @@ -189,7 +189,3 @@ def main(): next_button.config(command=env.update_env) reset_button.config(command=lambda: env.reset_env(agt)) root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/learning4e.py b/learning4e.py index 7dba31cfa..3cf41ad1e 100644 --- a/learning4e.py +++ b/learning4e.py @@ -568,7 +568,7 @@ def LogisticLinearLeaner(dataset, learning_rate=0.01, epochs=100): # pass over all examples for example in examples: x = [1] + example - y = Sigmoid().f(dot_product(w, x)) + y = Sigmoid().function(dot_product(w, x)) h.append(Sigmoid().derivative(y)) t = example[idx_t] err.append(t - y) @@ -580,7 +580,7 @@ def LogisticLinearLeaner(dataset, learning_rate=0.01, epochs=100): def predict(example): x = [1] + example - return Sigmoid().f(dot_product(w, x)) + return Sigmoid().function(dot_product(w, x)) return predict diff --git a/pytest.ini b/pytest.ini index 7d983c3fc..5b9f41dbc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] filterwarnings = - ignore::ResourceWarning + ignore::DeprecationWarning + ignore::RuntimeWarning diff --git a/tests/test_deep_learning4e.py b/tests/test_deep_learning4e.py index 305c2e65c..060e55788 100644 --- a/tests/test_deep_learning4e.py +++ b/tests/test_deep_learning4e.py @@ -11,8 +11,8 @@ def test_neural_net(): iris = DataSet(name='iris') classes = ['setosa', 'versicolor', 'virginica'] iris.classes_to_numbers(classes) - nnl_gd = NeuralNetLearner(iris, [4], learning_rate=0.15, epochs=100, optimizer=gradient_descent) - nnl_adam = NeuralNetLearner(iris, [4], learning_rate=0.001, epochs=200, optimizer=adam) + nnl_gd = NeuralNetLearner(iris, [4], l_rate=0.15, epochs=100, optimizer=stochastic_gradient_descent) + nnl_adam = NeuralNetLearner(iris, [4], l_rate=0.001, epochs=200, optimizer=adam) tests = [([5.0, 3.1, 0.9, 0.1], 0), ([5.1, 3.5, 1.0, 0.0], 0), ([4.9, 3.3, 1.1, 0.1], 0), @@ -32,8 +32,8 @@ def test_perceptron(): iris = DataSet(name='iris') classes = ['setosa', 'versicolor', 'virginica'] iris.classes_to_numbers(classes) - pl_gd = PerceptronLearner(iris, learning_rate=0.01, epochs=100, optimizer=gradient_descent) - pl_adam = PerceptronLearner(iris, learning_rate=0.01, epochs=100, optimizer=adam) + pl_gd = PerceptronLearner(iris, l_rate=0.01, epochs=100, optimizer=stochastic_gradient_descent) + pl_adam = PerceptronLearner(iris, l_rate=0.01, epochs=100, optimizer=adam) tests = [([5, 3, 1, 0.1], 0), ([5, 3.5, 1, 0], 0), ([6, 3, 4, 1.1], 1), diff --git a/utils4e.py b/utils4e.py index b0fbf8df8..777a88e4a 100644 --- a/utils4e.py +++ b/utils4e.py @@ -400,7 +400,7 @@ def gaussian_kernel_2D(size=3, sigma=0.5): class Activation: - def f(self, x): + def function(self, x): return NotImplementedError def derivative(self, x): @@ -414,7 +414,7 @@ def softmax1D(x): class Sigmoid(Activation): - def f(self, x): + def function(self, x): if x >= 100: return 1 if x <= -100: @@ -427,7 +427,7 @@ def derivative(self, value): class Relu(Activation): - def f(self, x): + def function(self, x): return max(0, x) def derivative(self, value): @@ -436,7 +436,7 @@ def derivative(self, value): class Elu(Activation): - def f(self, x, alpha=0.01): + def function(self, x, alpha=0.01): return x if x > 0 else alpha * (np.exp(x) - 1) def derivative(self, value, alpha=0.01): @@ -445,7 +445,7 @@ def derivative(self, value, alpha=0.01): class Tanh(Activation): - def f(self, x): + def function(self, x): return np.tanh(x) def derivative(self, value): @@ -454,7 +454,7 @@ def derivative(self, value): class LeakyRelu(Activation): - def f(self, x, alpha=0.01): + def function(self, x, alpha=0.01): return x if x > 0 else alpha * x def derivative(self, value, alpha=0.01):