圓州率
🌐

Feature Image

分類與回歸樹 (Classification and Regression Tree, CART)

數學, 資料科學, 機器學習, 監督式學習
一種基決策樹的方法,用於預測分類變量,也可用於非線性預測。
   最後更新:

簡介

分類與回歸樹 (Regression and classification tree, CART) 是一種基於決策樹的方式,將資料根據特徵切分成數個區域後,並對各自區域做預測,不只能用於類別型預測,也能用在非線性的預測問題。

  1. 分類樹:用於預測類別型資料,如顏色、性別。
  2. 回歸樹:用於預測數值型資料,如身高、體重。

範例:給定 $X$ 預測 $Y$。

先將資料根據特徵切分成數個區塊。

然後對各自區塊做預測。

而本文僅針對 recursive binary tree 討論,其兩大特色為

  1. 對區域做二元切分切成兩份 (binary) 子區域
  2. 每份子區域等於簡化後的問題,可以遞迴 (recursive) 的再次做切分

回歸樹

單維度切分

考慮二元切分與常數預測,每次切分考慮切割點 $t_m$ (或區域 $R_m$),與對該切分的預測值 $c_m$。

$$ \hat{y} = f (x) = \sum_{m = 1}^{M} c_m I (x \in R_m). $$

以最小化平方誤差合 (sum of squared error, SSE) 為目標決定預測值 $c_m$

$$ \begin{align*} c_m & = \argmin_{c_m} \sum_{i = 1}^{n} (y_i - \hat{y}_i)^2 \\ & = \argmin_{c_m} \sum_{x_i \in R_m} (y_i - c_m)^2 \\ & = \text{mean} (y_i | x_i \in R_m). \end{align*} $$

下一步,根據切分點 $t$ 將資列切分成兩份

$$ \begin{align*} R_1 (t) = \{ X : X \leq t \} \quad \text{and} \quad R_2 (t) = \{ X : X > t \} \end{align*} $$

此時同樣以最小化 SSE 為目標決定 $t$

$$ \begin{align*} \min_{t} \left[ \min_{c_1} \sum_{x_i \in R_1 (t)} (y_i - c_1)^2 + \min_{c_2} \sum_{x_i \in R_2 (t)} (y_i - c_2)^2 \right] \end{align*} $$

我們不會考慮在 $t \in \mathbb R$,取而代之的是僅考慮有資料點取值得範圍。定義 $X_{(i)}$ 是 $X_1, X_2, \cdots, X_n$ 的排序統計量,我們從第二小的數值 $X_{(2)}$ 嘗試切分至第二大的數值 $X_{(n - 1)}$,並找到能最小化 SSE 的 $t$。

$$ \begin{align*} \min_{t \in X} \left[ \min_{c_1} \sum_{x_i \in R_1 (t)} (y_i - c_1)^2 + \min_{c_2} \sum_{x_i \in R_2 (t)} (y_i - c_2)^2 \right] \end{align*} $$

多維度切分

假設 $X \in \bb R^{n \times p}$,即每有 $n$ 筆資料,每筆資料有 $p$ 維度,現在考慮在第 $j$ 個維度將資料切分

$$ \begin{align*} R_1 (j, t) = \{ X : X_j \leq t \} \quad \text{and} \quad R_2 (j, t) = \{ X : X_j > t \}. \end{align*} $$

同樣是尋找能讓 SSE 最小化的 $j$ 與 $t$

$$ \begin{align*} \min_{j, t} \left[ \min_{c_1} \sum_{x_i \in R_1 (j, t)} (y_i - c_1)^2 + \min_{c_2} \sum_{x_i \in R_2 (j, t)} (y_i - c_2)^2 \right]. \end{align*} $$

讓 $j$ 從 $1$ 嘗試到 $p$,更精確地說

$$ \begin{align*} \min_{j \in \{1, 2, \cdots, p\}} \min_{t \in X_j} \left[ \min_{c_1} \sum_{x_i \in R_1 (j, t)} (y_i - c_1)^2 + \min_{c_2} \sum_{x_i \in R_2 (j, t)} (y_i - c_2)^2 \right]. \end{align*} $$

iris 作為示範 (只考慮數值部分)

Sepal.Length ($X_1$)Sepal.Width ($X_2$)Petal.Length ($X_3$)Petal.Width ($Y$)
5.13.51.40.2
4.93.01.40.2
4.73.21.30.2
4.63.11.50.2
5.03.61.40.2
5.43.91.70.4
4.63.41.40.3
5.03.41.50.2

iris_1.png

從 $j = 1$ 開始,即嘗試切分 $X_1$

iris_2.png

相似的,對 $X_2$ 與 $X_3$ 做切分

iris_3.png

iris_4.png

最終目標是找到 $j$ 與 $t$ 能最小化 SSE

$$ \begin{align*} \min_{j \in \{1, 2, 3\}} \min_{t \in X_j} \left[ \min_{c_1} \sum_{x_i \in R_1 (j, t)} (y_i - c_1)^2 + \min_{c_2} \sum_{x_i \in R_2 (j, t)} (y_i - c_2)^2 \right]. \end{align*} $$

iris_5.png

而最好的切點落在 $(j, t) = (3, 1.9)$,即

$$ \begin{align*} R_1 = \{ X : X_3 \leq 1.9 \} \quad \text{and} \quad R_2 = \{ X : X_3 > 1.9 \}. \end{align*} $$

best_cut.png

現在將 $R_1$ 的點以紅色標記。

best_cut_1.png

下一步將 $R_1$ 的點移除,剩下的將是 $R_2$ 的點。

best_cut_2.png

我們能對剩下的資料再次進行切分

best_cut_3.png

並不斷重複這個過程,直到切分 SSE 沒有明顯的最小值,或資料已經切分至太細。

複雜度標準 (cost complexity criterion)

僅追求最小化 SSE 會導致切分區域過多,即"樹非常大",會導致 over fitting;相對的,過小的樹可能會沒辦法抓到資料的特徵。因此定義 複雜度標準 (cost complexity criterion),概念是同時平衡"錯誤量"與"樹的複雜度"。

令 $|T|$ 是 $T$ 這棵樹的節點數量,並令

$$ \begin{align*} N_m & = | R_m |, \\ c_m & = \frac{1}{N_m} \sum_{x_i \in R_m} y_i, \\ Q_m (T) & = \frac{1}{N_m} \sum_{x_i \in R_m} (y_i - c_m)^2. \end{align*} $$

給定一數 $\alpha \geq 0$,定義

$$ \begin{align*} C_\alpha (T) = \sum_{m = 1}^{|T|} N_m Q_m (T) + \alpha |T|. \end{align*} $$

我們的目標是最小化這個複雜度標準,其中的 $\alpha$ 是個調整參數,他會懲罰那些過大過複雜的樹,即"$\alpha$ 越大,樹越小"。對 $\alpha$ 的選擇,可考慮最小化 5-cross validation sum of squares。

分類樹

對類別型變數 (例如:顏色、性別) SSE 不再可用,若假設 $y$ 有 $K$ 種可能的分類,標作 $\\{1, 2, \cdots, K\\}$,我們會用純度來描述 $Q_m (T)$,並定義 $p_{mk}$ 是第 $m$ 個區域中第 $k$ 種類所佔的比例

$$ \begin{align*} \hat p_{mk} = \frac{1}{N_m} \sum_{x_i \in R_m} I (y_i = k) \end{align*} $$

最終根據節點 $m$ 中所佔比例最高的種類當成最終預測

$$ \begin{align*} k (m) = \argmax_{k \in \{1, 2, \cdots, K\}} \hat p_{mk}. \end{align*} $$

下列提供數種不同純度標準 $Q_m (T)$

  1. 錯分率 (misclassification error):

    $$ \begin{align*} \frac{1}{N_m} \sum_{x_i \in R_m} I (y_i \ne k(m)) = 1 - \hat p_{m k (m)}. \end{align*} $$
  2. Gini index:

    $$ \begin{align*} \sum_{k \ne l} \hat p_{m k} \hat p_{m l} = \sum_{k = 1}^{K} \hat p_{m k} (1 - \hat p_{m k}). \end{align*} $$
  3. Cross-entropy or deviance:

    $$ \begin{align*} - \sum_{k = 1}^{K} \hat p_{m k} \ln (\hat p_{m k}). \end{align*} $$

其他建樹方案

要建立決策樹,我們有兩步需要考慮

  1. 如何將資料切分
  2. 如何對切分的區域做預測

線性組合切法

在一般的切分中僅考慮 $X_j < t$,但可以考慮資料的線性組合切割 $\sum_{j = 1}^{p} a_j X_j < t$,而其中的係數 $a_j$ 與切分點 $t$ 是以最小化 Gini index 為目標所決定的。

例如:iris的前兩個維度能看到完美切分點。

linear_combination_splits.png

回歸預測

不僅可以做常數的預測,也能對切分出來的區域單獨做回歸,而最終目標仍是最小化 cost complexity criterion (或 SSE,當 $\alpha = 0$ 的時候)。

regression_fit.gif

R 語言範例

載入必要 packages

library(rpart)      # CART
library(rpart.plot) # CART plot
library(caret)      # model training
library(dplyr)      # %>%

採樣

# sampling
set.seed(0)
trainIndex = createDataPartition(iris$Species, p = 0.7, list = FALSE)
trainData = iris[trainIndex, ]
testData = iris[-trainIndex, ]

用 caret 的 train 訓練 CART 模型

# cart 
cart = train(Species ~ ., 
             data = trainData, 
             method = "rpart", 
             trControl = trainControl(method = "cv", number = 10), # k-Fold cross validation
             tuneLength = 10) # Try possible number of parameter

模型修剪,平衡錯誤率與樹大小

# complexity parameter prune
pruned_cart = prune(cart$finalModel, cp = 0.01)

混淆矩陣表現

# confusion matrix
predictions = predict(pruned_cart, testData, type = "class")
confusionMatrix(predictions, testData$Species)

繪製模型結果

rpart.png

參考資料

  1. The Element of Statistical Learning (Trevor Hastie)
  2. Kno’s SML 2024