Bug fixes and linting for the activation functions addition.

This commit is contained in:
Alec Helbling
2023-01-25 08:40:32 -05:00
parent ce184af78e
commit f56620f047
42 changed files with 1275 additions and 387 deletions

View File

@ -3,6 +3,7 @@ import numpy as np
from collections import deque
from sklearn.tree import _tree as ctree
class AABB:
"""Axis-aligned bounding box"""
@ -19,6 +20,7 @@ class AABB:
return left, right
def tree_bounds(tree, n_features=None):
"""Compute final decision rule for each node in tree"""
if n_features is None:
@ -34,6 +36,7 @@ def tree_bounds(tree, n_features=None):
queue.extend([l, r])
return aabbs
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
"""Extract decision areas.
@ -70,6 +73,7 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None)
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
return rectangles
def plot_areas(rectangles):
for rect in rectangles:
color = ["b", "r"][int(rect[4])]
@ -83,6 +87,7 @@ def plot_areas(rectangles):
)
plt.gca().add_artist(rp)
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color
polygon_dict = {
@ -156,6 +161,7 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
return_polygons.append(polygon)
return return_polygons
class IrisDatasetPlot(VGroup):
def __init__(self, iris):
points = iris.data[:, 0:2]
@ -269,7 +275,6 @@ class IrisDatasetPlot(VGroup):
class DecisionTreeSurface(VGroup):
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
# take the tree and construct the surface from it
self.tree_clf = tree_clf
@ -346,8 +351,8 @@ class DecisionTreeSurface(VGroup):
def make_split_to_animation_map(self):
"""
Returns a dictionary mapping a given split
node to an animation to be played
Returns a dictionary mapping a given split
node to an animation to be played
"""
# Create an initial decision tree surface
# Go through each split node