Machine Learning/Advanced (hands on machine learning)

16. 결정 트리 - 결정 트리 학습과 시각화, 예측

jwjwvison 2021. 5. 21. 22:10


SVM처럼 결정 트리는 분류와 회귀 작업 그리고 다중출력 작업도 가능한 다재다능한 머신러닝 알고리즘이다. 또한 매우 복잡한 데이터셋도 학습할 수 있는 강력한 알고리즘이다.
결정 트리는 최근에 자주 사용되는 가장 강력한 머신러닝 알고리즘 중 하나인 랜덤 포레스트의 기본 구성 요소이기도 하다.

결정 트리를 이해하기 위해 모델 하나를 만들어서 어떻게 예측을 하는지 살펴보자. 다음은 붓꽃 데이터셋에 DecisionClassifier를 훈련시키는 코드이다.

from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier iris=load_iris() X=iris.data[:,2:] #꽃잎의 길이와 너비 y=iris.target tree_clf=DecisionTreeClassifier(max_depth=2) tree_clf.fit(X,y)
from sklearn.tree import plot_tree plot_tree(tree_clf,filled=True,class_names=True)


트리가 어떻게 예측을 만들어내는지 살펴보자. 새로 발견한 붓꽃의 품종을 분류하려 한다고 가정하면 루트노드에서 시작한다. 이 노드는 꽃잎의 길이가 2.45cm 보다 짧은지 검사한다. 만약 그렇다면 루트 노드에서 왼쪽의 자식 노드로 이동한다. 이 경우는 이 노드가 리프 노드이므로 추가적인 검사를 하지 않는다. 그냥 노드에 있는 예측 클래스를 보고 결정 트리가 새로 발견한 꽃의 품종을 Iris-Setosa(class=setosa)라고 예측한다.


노드의 sample 속성은 얼마나 많은 훈련 샘플이 적용되었는지 헤아린 것이다. 노드의 value 속성은 노드에서 각 클래스에 얼마나 많은 훈련 샘플이 있는지 알려준다. 마지막으로 노드의 gini 속성은 불순도(impurity)를 측정한다. 한 노드의 모든 샘플이 같은 클래스에 속해 있다면 이 노드를 순수(gini=0)하다고 한다.
식 6-1은 훈련 알고리즘이 i번째 노드의 gini 점수 Gi를 계산하는 방법을 보여준다. 깊이 2의 왼쪽 노드의 gini 점수는 1-(0/54)^2 - (49/54)^2 - (5/54)^2 =0.168이다.