|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | + |
| 4 | +''' |
| 5 | + Gini impurity is a measure of how often a randomly |
| 6 | + chosen element from the set would be incorrectly labeled |
| 7 | + if it was randomly labeled according to the distribution |
| 8 | + of labels in the subset. |
| 9 | +''' |
| 10 | + |
| 11 | +# Read example from csv |
| 12 | +#example = "ImpurityMeasures/example.csv" |
| 13 | +example = "ImpurityMeasures/example2.csv" |
| 14 | +data = pd.read_csv(example) |
| 15 | + |
| 16 | +# The fastest way to find occurences within an array |
| 17 | +# Check it out in this stackoverflow thread: |
| 18 | +# https://stackoverflow.com/questions/10741346/numpy-most-efficient-frequency-counts-for-unique-values-in-an-array |
| 19 | +def unique_count(A): |
| 20 | + unique, inverse = np.unique(A, return_inverse=True) |
| 21 | + count = np.zeros(len(unique), dtype=int) |
| 22 | + np.add.at(count, inverse, 1) |
| 23 | + return np.vstack((unique, count)).T |
| 24 | + |
| 25 | +# Get the gini impurity of one node |
| 26 | +def gini(N): |
| 27 | + uniques = unique_count(N) |
| 28 | + total = len(N) |
| 29 | + # Probability of every node leaf Pi |
| 30 | + Pi = np.zeros(uniques.shape[0]) |
| 31 | + for i in range(uniques.shape[0]): |
| 32 | + pi = uniques[i,1]/total |
| 33 | + Pi[i] = pi |
| 34 | + |
| 35 | + # Appliying the Gini formula: 1 - Sum[Pi(t)^2] |
| 36 | + gini = 1 - np.sum(Pi**2) |
| 37 | + |
| 38 | + print(uniques) |
| 39 | + print("Probabilities of Pi: ", Pi) |
| 40 | + print("Gini impurity: {}\n".format(gini)) |
| 41 | + return gini |
| 42 | + |
| 43 | +# Finding the purest node within data |
| 44 | +def findPurest(): |
| 45 | + purest = { |
| 46 | + "column": "", |
| 47 | + "gini": 1 |
| 48 | + } |
| 49 | + |
| 50 | + for node in data: |
| 51 | + print("------------ {} ------------".format(node)) |
| 52 | + gini_ = gini(data[node].values) |
| 53 | + |
| 54 | + if gini_ < purest['gini']: |
| 55 | + purest['column'] = node |
| 56 | + purest['gini'] = gini_ |
| 57 | + |
| 58 | + print("The purest node is: {} \nWith an gini index: {}" |
| 59 | + .format(purest['column'], purest['gini'])) |
| 60 | + |
| 61 | + return purest['gini'] |
| 62 | + |
| 63 | +# You can select the Father node |
| 64 | +#father = findPurest() # is the purest node |
| 65 | +father = "Clase" |
| 66 | +child = "Talla_Camisa" |
| 67 | + |
| 68 | +def giniWeighted(): |
| 69 | + # |
| 70 | + crosstab = pd.crosstab(data[child], data[father], margins=True, margins_name="Total") |
| 71 | + print(crosstab,'\n') |
| 72 | + index = crosstab.index |
| 73 | + crosstab = crosstab.values |
| 74 | + # Sum the ginis |
| 75 | + giniW = 0 |
| 76 | + |
| 77 | + for i in range(len(index)-1): |
| 78 | + print("--------------- {} ---------------".format(index[i])) |
| 79 | + pi = crosstab[i,:-1]/crosstab[i,-1] |
| 80 | + gini_ = 1 - np.sum(pi**2) |
| 81 | + |
| 82 | + # Calculate the weighted gini and sum |
| 83 | + giniW += (crosstab[i,-1]/crosstab[-1,-1])*gini_ |
| 84 | + |
| 85 | + print("Probabilities of Pi: {}\nGini impurity: {}\n" |
| 86 | + .format(pi, gini_)) |
| 87 | + |
| 88 | + print("The weighted gini is: ", giniW) |
| 89 | + return giniW |
| 90 | + |
| 91 | +giniW = giniWeighted() |
0 commit comments