mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-06 16:18:17 +08:00
91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import numpy as np
|
|
from collections import deque
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from sklearn.tree import _tree as ctree
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.patches import Rectangle
|
|
|
|
|
|
class AABB:
|
|
"""Axis-aligned bounding box"""
|
|
|
|
def __init__(self, n_features):
|
|
self.limits = np.array([[-np.inf, np.inf]] * n_features)
|
|
|
|
def split(self, f, v):
|
|
left = AABB(self.limits.shape[0])
|
|
right = AABB(self.limits.shape[0])
|
|
left.limits = self.limits.copy()
|
|
right.limits = self.limits.copy()
|
|
left.limits[f, 1] = v
|
|
right.limits[f, 0] = v
|
|
|
|
return left, right
|
|
|
|
|
|
def tree_bounds(tree, n_features=None):
|
|
"""Compute final decision rule for each node in tree"""
|
|
if n_features is None:
|
|
n_features = np.max(tree.feature) + 1
|
|
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
|
|
queue = deque([0])
|
|
while queue:
|
|
i = queue.pop()
|
|
l = tree.children_left[i]
|
|
r = tree.children_right[i]
|
|
if l != ctree.TREE_LEAF:
|
|
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
|
|
queue.extend([l, r])
|
|
return aabbs
|
|
|
|
|
|
def decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
|
|
"""Extract decision areas.
|
|
|
|
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
|
|
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
|
|
x: index of the feature that goes on the x axis
|
|
y: index of the feature that goes on the y axis
|
|
n_features: override autodetection of number of features
|
|
"""
|
|
tree = tree_classifier.tree_
|
|
aabbs = tree_bounds(tree, n_features)
|
|
maxrange = np.array(maxrange)
|
|
rectangles = []
|
|
for i in range(len(aabbs)):
|
|
if tree.children_left[i] != ctree.TREE_LEAF:
|
|
continue
|
|
l = aabbs[i].limits
|
|
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
|
|
# clip out of bounds indices
|
|
"""
|
|
if r[0] < maxrange[0]:
|
|
r[0] = maxrange[0]
|
|
if r[1] > maxrange[1]:
|
|
r[1] = maxrange[1]
|
|
if r[2] < maxrange[2]:
|
|
r[2] = maxrange[2]
|
|
if r[3] > maxrange[3]:
|
|
r[3] = maxrange[3]
|
|
print(r)
|
|
"""
|
|
rectangles.append(r)
|
|
rectangles = np.array(rectangles)
|
|
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
|
|
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
|
|
return rectangles
|
|
|
|
|
|
def plot_areas(rectangles):
|
|
for rect in rectangles:
|
|
color = ["b", "r"][int(rect[4])]
|
|
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
|
|
rp = Rectangle(
|
|
[rect[0], rect[2]],
|
|
rect[1] - rect[0],
|
|
rect[3] - rect[2],
|
|
color=color,
|
|
alpha=0.3,
|
|
)
|
|
plt.gca().add_artist(rp)
|