diff --git a/machine_learning/decision_tree.py b/machine_learning/decision_tree.py index dfc2e1676..51f600cac 100644 --- a/machine_learning/decision_tree.py +++ b/machine_learning/decision_tree.py @@ -48,12 +48,15 @@ class Decision_Tree: return if y.ndim != 1: print("Error: Data set labels must be one dimensional") + return if len(X) < 2 * self.min_leaf_size: self.prediction = np.mean(y) + return if self.depth == 1: self.prediction = np.mean(y) + return best_split = 0 min_error = self.mean_squared_error(X,np.mean(y)) * 2