mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-18 03:05:23 +08:00
61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
from manim import *
|
|
from manim_ml.decision_tree.decision_tree import (
|
|
DecisionTreeDiagram,
|
|
DecisionTreeSurface,
|
|
IrisDatasetPlot,
|
|
)
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from sklearn import datasets
|
|
import sklearn
|
|
import matplotlib.pyplot as plt
|
|
|
|
def learn_iris_decision_tree(iris, max_depth=3):
    """Fit a decision tree classifier on the first two Iris features.

    Args:
        iris: Dataset object exposing ``data`` (a 2-D array; only columns 0
            and 1 are used) and ``target`` (class labels), e.g. the result of
            ``sklearn.datasets.load_iris()``.
        max_depth: Maximum depth of the fitted tree. Defaults to 3.

    Returns:
        The fitted ``DecisionTreeClassifier``.
    """
    # A fixed random_state keeps the fitted tree identical across renders.
    decision_tree = DecisionTreeClassifier(
        random_state=1, max_depth=max_depth, max_leaf_nodes=6
    )
    # Only the first two feature columns are used — presumably so the splits
    # can be visualised on a 2-D plot (see SurfacePlot); confirm if changed.
    decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
    return decision_tree


def make_sklearn_tree(dataset, max_tree_depth=3):
    """Fit a tree on ``dataset`` and return it with its low-level structure.

    Args:
        dataset: Passed straight through to ``learn_iris_decision_tree``.
        max_tree_depth: Maximum depth of the fitted tree. Defaults to 3.

    Returns:
        A ``(classifier, classifier.tree_)`` pair: the fitted classifier and
        its underlying sklearn ``Tree`` object.
    """
    tree = learn_iris_decision_tree(dataset, max_depth=max_tree_depth)
    return tree, tree.tree_
|
|
|
|
class DecisionTreeScene(Scene):
    """Scene that draws a diagram of a decision tree fit on the Iris dataset."""

    def construct(self):
        """Fit the tree, build its diagram, and animate its creation."""
        iris_dataset = datasets.load_iris()
        # Only the low-level sklearn ``Tree`` is needed for the diagram; the
        # fitted classifier itself is unused in this scene.
        _, sklearn_tree = make_sklearn_tree(iris_dataset)
        decision_tree = DecisionTreeDiagram(
            sklearn_tree,
            # NOTE(review): relative paths, presumably resolved against the
            # directory manim is run from. "Veriscolor" in the filename is left
            # as-is because it must match the asset on disk — rename both together.
            class_images_paths=[
                "images/iris_dataset/SetosaFlower.jpeg",
                "images/iris_dataset/VeriscolorFlower.jpeg",
                "images/iris_dataset/VirginicaFlower.jpeg",
            ],
            # Display labels; presumably in iris.target index order (0, 1, 2).
            class_names=["Setosa", "Versicolor", "Virginica"],
            # Labels for the two feature columns the tree was fit on.
            feature_names=["Sepal Length", "Sepal Width"],
        )
        decision_tree.move_to(ORIGIN)
        # Request a breadth-first ("bfs") build-up order for the diagram.
        create_decision_tree = Create(decision_tree, traversal_order="bfs")
        self.play(create_decision_tree)
|
|
|
|
class SurfacePlot(Scene):
    """Scene that plots the Iris data and overlays a decision-tree surface."""

    def construct(self):
        """Draw the Iris scatter plot, then the tree's decision surface on it."""
        data = datasets.load_iris()

        # Build and position the dataset plot before animating it in.
        plot = IrisDatasetPlot(data)
        plot_group = plot.all_group
        plot_group.scale(1.0)  # factor 1.0 leaves the size unchanged; tuning knob
        plot_group.shift([-3, 0.2, 0])
        self.play(Create(plot))

        # Fit the classifier (first element of the returned pair) and lay its
        # decision surface over the plot's first axes entry.
        classifier = make_sklearn_tree(data)[0]
        surface = DecisionTreeSurface(
            classifier,
            data.data,
            plot.axes_group[0],
        )
        self.play(Create(surface))
        self.wait(1)
|