
Knowledge Card: Decision Tree

Basic Info

| Attribute | Content |
| --- | --- |
| Topic | Decision Tree |
| Mastery level | ★★★★☆ |
| Learning priority | P1 |
| Estimated time | 6 hours |
| Interview frequency | ★★★★☆ |

Core Principles

A decision tree builds a classification/regression model by recursively splitting the feature space:

Splitting criteria:
- Classification: Gini impurity / information gain (entropy)
- Regression: MSE reduction

Stopping conditions:
- max_depth, min_samples_split, min_samples_leaf

Information Gain

Gini: G = 1 - Σ p_k²
Entropy: H = -Σ p_k log p_k
Information gain: IG = H(parent) - Σ (N_child/N) * H(child)

where p_k is the proportion of class k in the node, and N_child/N is the fraction of samples routed to each child.
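As a numeric sanity check of the formulas above, here is a small computation on a hypothetical 10-sample split (the counts are toy values chosen for illustration):

```python
import numpy as np

def gini(y):
    # G = 1 - sum(p_k^2)
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return 1 - np.sum(p ** 2)

def entropy(y):
    # H = -sum(p_k * log2(p_k)); absent classes never appear in `counts`
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return -np.sum(p * np.log2(p))

# Parent node: 5 positives, 5 negatives; split into (4 pos, 1 neg) and (1 pos, 4 neg)
parent = np.array([1] * 5 + [0] * 5)
left = np.array([1] * 4 + [0] * 1)
right = np.array([1] * 1 + [0] * 4)

ig = entropy(parent) - (len(left) / len(parent)) * entropy(left) \
                     - (len(right) / len(parent)) * entropy(right)
print(f"Gini(parent) = {gini(parent):.3f}")  # 0.500
print(f"IG of split  = {ig:.3f}")            # 0.278
```

A maximally impure binary parent has Gini 0.5 and entropy 1 bit; this split recovers about 0.28 bits.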

Pruning Strategies

```python
# Pre-pruning: stop early during training
# - max_depth: maximum tree depth
# - min_samples_split: minimum samples required to split a node
# - min_samples_leaf: minimum samples required at a leaf node
# - min_impurity_decrease: minimum impurity decrease required for a split

# Post-pruning: prune after training
# - ccp_alpha: cost-complexity pruning parameter
```
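A quick sketch of what post-pruning looks like in sklearn (the iris data and variable names are my example, not from the original card): `cost_complexity_pruning_path` enumerates the effective alphas along the pruning path, and refitting with a larger `ccp_alpha` prunes more aggressively, giving a smaller tree.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# The pruning path lists the effective alphas at which subtrees collapse
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X, y)

# Refit at each alpha; larger ccp_alpha -> fewer nodes (clip guards against
# tiny negative alphas from floating-point error)
node_counts = [
    DecisionTreeClassifier(random_state=42, ccp_alpha=a).fit(X, y).tree_.node_count
    for a in np.clip(path.ccp_alphas, 0, None)
]
print(node_counts)  # monotonically non-increasing
```

Because weakest-link pruning is nested, each tree along the path is a subtree of the previous one.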

sklearn Implementation

```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree import plot_tree

# Classification (assumes X_train, y_train, feature_names are already defined)
clf = DecisionTreeClassifier(
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=5,
    criterion='gini',  # 'gini' or 'entropy'
    random_state=42
)
clf.fit(X_train, y_train)

# Feature importances
for feature, importance in zip(feature_names, clf.feature_importances_):
    print(f"{feature}: {importance:.4f}")

# Regression
reg = DecisionTreeRegressor(max_depth=5, min_samples_leaf=10)
reg.fit(X_train, y_train)
```
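`plot_tree` is imported above but not used; for a text-only environment, `export_text` from the same module dumps the learned rules. A short sketch on the iris dataset (my example data, not from the original):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# Text rendering of the learned rules; plot_tree draws the same structure graphically
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Reading the printed rules is a quick way to verify that the splits the tree learned are interpretable.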

Implementing CART from Scratch

```python
import numpy as np

class DecisionTree:
    def __init__(self, max_depth=5, min_samples_split=10):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def _gini(self, y):
        """Compute Gini impurity."""
        _, counts = np.unique(y, return_counts=True)
        probs = counts / len(y)
        return 1 - np.sum(probs ** 2)

    def _best_split(self, X, y):
        """Find the best split point."""
        best_gain = 0
        best_feature, best_threshold = None, None
        parent_gini = self._gini(y)  # constant for this node; compute once

        for feature in range(X.shape[1]):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left_mask = X[:, feature] <= threshold
                if left_mask.sum() == 0 or left_mask.sum() == len(y):
                    continue

                left_gini = self._gini(y[left_mask])
                right_gini = self._gini(y[~left_mask])

                n_left = left_mask.sum()
                n_total = len(y)
                weighted_gini = (n_left / n_total) * left_gini + (1 - n_left / n_total) * right_gini
                gain = parent_gini - weighted_gini

                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold

        return best_feature, best_threshold, best_gain
```
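The class above stops at finding the best split. A minimal completion with a recursive build and a prediction pass, reusing the same Gini criterion, might look like the sketch below (the `CARTClassifier` name, the dict-based node layout, and the toy data are my additions, not the original author's code):

```python
import numpy as np

class CARTClassifier:
    """Minimal CART classifier: Gini splits, recursive build, majority-vote leaves."""

    def __init__(self, max_depth=5, min_samples_split=10):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def _gini(self, y):
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return 1 - np.sum(p ** 2)

    def _best_split(self, X, y):
        best_gain, best_feature, best_threshold = 0.0, None, None
        parent_gini = self._gini(y)
        for feature in range(X.shape[1]):
            for threshold in np.unique(X[:, feature]):
                left = X[:, feature] <= threshold
                n_left = left.sum()
                if n_left == 0 or n_left == len(y):
                    continue
                w = n_left / len(y)
                gain = parent_gini - w * self._gini(y[left]) - (1 - w) * self._gini(y[~left])
                if gain > best_gain:
                    best_gain, best_feature, best_threshold = gain, feature, threshold
        return best_feature, best_threshold

    def _build(self, X, y, depth):
        # Stop on: depth limit, too few samples, or a pure node -> majority-class leaf
        if (depth >= self.max_depth or len(y) < self.min_samples_split
                or len(np.unique(y)) == 1):
            return {"leaf": np.bincount(y).argmax()}
        feature, threshold = self._best_split(X, y)
        if feature is None:  # no split improves Gini
            return {"leaf": np.bincount(y).argmax()}
        mask = X[:, feature] <= threshold
        return {"feature": feature, "threshold": threshold,
                "left": self._build(X[mask], y[mask], depth + 1),
                "right": self._build(X[~mask], y[~mask], depth + 1)}

    def fit(self, X, y):
        self.tree = self._build(np.asarray(X), np.asarray(y), 0)
        return self

    def _predict_one(self, x, node):
        while "leaf" not in node:
            node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
        return node["leaf"]

    def predict(self, X):
        return np.array([self._predict_one(x, self.tree) for x in np.asarray(X)])

# Sanity check on a linearly separable toy set
X_demo = np.array([[0.], [1.], [2.], [3.], [10.], [11.], [12.], [13.]])
y_demo = np.array([0, 0, 0, 0, 1, 1, 1, 1])
model = CARTClassifier(max_depth=3, min_samples_split=2).fit(X_demo, y_demo)
print(model.predict(X_demo))
```

The dict-of-dicts node layout keeps the sketch short; a production implementation would typically use arrays (as sklearn's `tree_` does) for speed.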

Common Interview Questions

Q1: Gini impurity vs. information gain?

  • Gini is faster to compute (no logarithms)
  • In practice the difference in results is usually small
  • sklearn defaults to Gini; the CART algorithm also uses Gini
  • Information gain is biased toward multi-valued features (it tends to pick features with many distinct values)
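To see the multi-valued bias concretely: with an ID3-style multiway split (one branch per distinct value, the setting where this bias shows up; CART's binary splits are less affected), an ID-like feature achieves the maximum possible information gain while being useless for generalization. A hypothetical toy example:

```python
import numpy as np

def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / len(y)
    return -np.sum(p * np.log2(p))

def info_gain(x, y):
    # ID3-style multiway split: one child per distinct value of x
    h = entropy(y)
    for v in np.unique(x):
        mask = x == v
        h -= mask.mean() * entropy(y[mask])
    return h

y = np.array([0, 1, 0, 1, 0, 1])
id_feature = np.arange(6)                    # unique value per sample
binary_feature = np.array([0, 0, 0, 1, 1, 1])

print(info_gain(id_feature, y))      # 1.0: every child is pure, so IG = H(parent)
print(info_gain(binary_feature, y))
```

The ID feature "wins" every comparison despite carrying no predictive signal, which is why C4.5-style corrections normalize the gain.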

Q2: Why do decision trees overfit, and how is it addressed?

Decision trees overfit easily (an unconstrained tree can fit the training set perfectly). Common remedies:

  1. Limit depth (max_depth)
  2. Require a minimum number of samples per leaf (min_samples_leaf)
  3. Pruning (ccp_alpha)
  4. Ensembling (random forest / XGBoost), the most common approach in practice
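A quick illustration of the overfitting claim (synthetic data with 20% label noise; the dataset and parameters are my example): an unconstrained tree fits the training set perfectly but drops noticeably on held-out data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# 20% of labels randomly flipped -> memorizing the training set cannot generalize
X, y = make_classification(n_samples=500, n_features=20, n_informative=5,
                           flip_y=0.2, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=42)

deep = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)       # no constraints
shallow = DecisionTreeClassifier(max_depth=4, random_state=42).fit(X_tr, y_tr)

print(f"deep:    train={deep.score(X_tr, y_tr):.2f}  test={deep.score(X_te, y_te):.2f}")
print(f"shallow: train={shallow.score(X_tr, y_tr):.2f}  test={shallow.score(X_te, y_te):.2f}")
```

The unconstrained tree scores 1.00 on training data because it memorizes the flipped labels; the train/test gap is the overfitting signature.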

Related Topics