# Source: ManimML/examples/decision_tree/decision_tree_surface.py (Python, 91 lines, 2.9 KiB)

import numpy as np
from collections import deque
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree as ctree
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
class AABB:
    """Axis-aligned bounding box over a feature space.

    ``limits`` holds one ``[lower, upper]`` row per feature. A fresh box is
    unbounded: every row is ``[-inf, +inf]``.
    """

    def __init__(self, n_features):
        # One open interval per feature dimension.
        self.limits = np.array([[-np.inf, np.inf] for _ in range(n_features)])

    def split(self, f, v):
        """Cut this box along feature ``f`` at value ``v``.

        Returns a ``(left, right)`` pair of new boxes. Both start as copies
        of this box's limits; ``left`` gets ``v`` as the upper bound of
        feature ``f`` and ``right`` gets ``v`` as its lower bound. This box
        itself is not modified.
        """
        dims = self.limits.shape[0]
        lower_half, upper_half = AABB(dims), AABB(dims)

        lower_half.limits = self.limits.copy()
        lower_half.limits[f, 1] = v

        upper_half.limits = self.limits.copy()
        upper_half.limits[f, 0] = v

        return lower_half, upper_half
def tree_bounds(tree, n_features=None):
    """Compute the bounding box implied by the decision rules at each node.

    tree: fitted tree structure exposing ``feature``, ``threshold``,
        ``children_left``, ``children_right`` and ``node_count`` (e.g. the
        ``tree_`` attribute of a sklearn decision tree).
    n_features: number of feature dimensions for every box. When omitted it
        is taken from ``tree.n_features`` if the tree exposes it, otherwise
        it is estimated as the highest feature index used by a split, plus
        one.

    Returns a list of ``AABB`` objects indexed by node id. The root (node 0)
    is unbounded; each child box is its parent's box cut at the parent's
    split threshold.
    """
    if n_features is None:
        # Prefer the count recorded on the tree: the split-based estimate
        # below undercounts when the highest-index features are never used.
        n_features = getattr(tree, "n_features", None)
    if n_features is None:
        # Fallback estimate. Leaf entries of tree.feature are expected to be
        # negative sentinels, so a leaf-only tree would give a negative count;
        # clamp to zero instead.
        n_features = max(int(np.max(tree.feature)) + 1, 0)

    aabbs = [AABB(n_features) for _ in range(tree.node_count)]

    # deque.pop() takes from the right, so this is a depth-first walk.
    # Visit order does not matter: a parent's box is always final before
    # its children are pushed.
    pending = deque([0])
    while pending:
        node = pending.pop()
        left_child = tree.children_left[node]
        right_child = tree.children_right[node]
        if left_child != ctree.TREE_LEAF:
            aabbs[left_child], aabbs[right_child] = aabbs[node].split(
                tree.feature[node], tree.threshold[node]
            )
            pending.extend([left_child, right_child])
    return aabbs
def decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
    """Extract the rectangular decision areas of a fitted tree classifier.

    tree_classifier: object with a fitted ``tree_`` attribute, e.g. an
        instance of sklearn.tree.DecisionTreeClassifier
    maxrange: finite plotting bounds ``[x_min, x_max, y_min, y_max]``. Leaf
        rectangles are clipped to these, which also replaces the open
        (+/-inf) sides of a leaf's box.
    x: index of the feature that goes on the x axis
    y: index of the feature that goes on the y axis
    n_features: override autodetection of number of features

    Returns a 2-D array with one row per leaf:
    ``[x_min, x_max, y_min, y_max, class_index]``, where ``class_index`` is
    the argmax of the leaf's ``tree.value`` entry.
    """
    tree = tree_classifier.tree_
    aabbs = tree_bounds(tree, n_features)
    maxrange = np.array(maxrange)

    rectangles = []
    for node_id, box in enumerate(aabbs):
        # Only leaves define a final decision area.
        if tree.children_left[node_id] != ctree.TREE_LEAF:
            continue
        limits = box.limits
        rectangles.append(
            [
                limits[x, 0],
                limits[x, 1],
                limits[y, 0],
                limits[y, 1],
                np.argmax(tree.value[node_id]),
            ]
        )

    rectangles = np.array(rectangles)
    # Clip to the plotting range: lower bounds (columns 0 and 2) against
    # maxrange[0] / maxrange[2], upper bounds (columns 1 and 3) against
    # maxrange[1] / maxrange[3].
    rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
    rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
    return rectangles
def plot_areas(rectangles):
    """Draw decision rectangles on the current matplotlib axes.

    rectangles: iterable of rows ``[x_min, x_max, y_min, y_max, class_index]``.

    Each rectangle is drawn semi-transparent in a colour chosen by its class
    index: class 0 is blue and class 1 is red. Higher class indices cycle
    through further single-letter matplotlib colours instead of raising
    IndexError.
    """
    palette = ["b", "r", "g", "c", "m", "y", "k"]
    for rect in rectangles:
        color = palette[int(rect[4]) % len(palette)]
        # NOTE(review): debug output kept as-is. The last two values subtract
        # x entries from y entries, so they are not the width/height that is
        # drawn below.
        print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
        patch = Rectangle(
            [rect[0], rect[2]],  # lower-left corner (x_min, y_min)
            rect[1] - rect[0],  # width
            rect[3] - rect[2],  # height
            color=color,
            alpha=0.3,
        )
        plt.gca().add_artist(patch)