고차 다항 회귀를 적용하면 보통의 선형 회귀에서보다 훨씬 더 훈련 데이터에 잘 맞추려 할 것이다. 다음 그림은 300차 다항 회귀 모델을 이전의 훈련 데이터에 적용한 것이다. 단순한 선형 모델이나 2차 모델(2차 다항 회귀 모델)과 결과를 비교해보자. 300차 다항 회귀 모델은 훈련 샘플에 가능한 한 가까이 가려고 구불구불하게 나타난다.
이 고차 다항 회귀 모델은 심각하게 훈련 데이터에 과대적합되었다. 반면에 선형 모델은 과소적합이다. 이 경우 가장 일반화가 잘된 모델은 2차 다항 회귀이다.
일반적으로는 어떤 함수로 데이터가 생성됐는지 알 수 없다. 그러면 얼마나 복잡한 모델을 사용할지 어떻게 결정할 수 있을까? 어떻게 모델이 데이터에 과대적합 또는 과소적합되었는지 알 수 있을까?
훈련 데이터에서는 성능이 좋지만 교차 검증 점수가 나쁘다면 모델이 과대적합된 것이다. 만약 양쪽에서 좋지 않으면 과소적합이다. 이때 모델이 너무 단순하거나 너무 복잡하다고 말한다.
또 다른 방법은 학습 곡선을 살펴보는 것이다. 이 그래프는 훈련 세트와 검증 세트의 모델 성능을 훈련 세트 크기(또는 훈련 반복)의 함수로 나타낸다. 다음 코드는 주어진 훈련 데이터에서 모델의 학습 곡선을 그리는 함수를 정의한다.
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
def plot_learning_curves(model, X, y):
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=10)
train_errors, val_errors = [], []
for m in range(1, len(X_train)):
model.fit(X_train[:m], y_train[:m])
y_train_predict = model.predict(X_train[:m])
y_val_predict = model.predict(X_val)
train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
val_errors.append(mean_squared_error(y_val, y_val_predict))
plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
plt.legend(loc="upper right", fontsize=14)
plt.xlabel("Training set size", fontsize=14)
plt.ylabel("RMSE", fontsize=14)
단순 선형 회귀 모델(직선)의 학습 곡선을 살펴보자.
lin_reg = LinearRegression()
plot_learning_curves(lin_reg, X, y)
plt.axis([0, 80, 0, 3])
plt.show()
훈련 데이터의 성능을 먼저 보면, 그래프가 0에서 시작하므로 훈련 세트에 하나 혹은 두 개의 샘플이 있을 땐 모델이 완벽하게 작동하낟. 하지만 훈련 세트에 샘플이 추가됨에 따라 잡음도 있고 비선형이기 때문에 모델이 훈련 데이터를 완벽히 학습하는 것이 불가능해진다. 그래서 곡선이 어느 정도 평편해질 때까지 오차가 계속 상승한다. 이 위치에서는 훈련 세트에 샘플이 추가되어도 평균 오차가 크게 나아지거나 나빠지지 않는다.
검증 데이터는 모델이 적은 수의 훈련 샘플로 훈련될 때는 제대로 일반화될 수 없어서 검증 오차가 초기에 매우 크다. 모데렝 훈련 샘플이 추가됨에 따라 학습이 되고 검증 오차가 천천히 감소한다. 하지만 선형 회귀의 직선은 데이터를 잘 모델링할 수 없으므로 오차의 감소가 완만해져서 훈련 세트의 그래프와 가까워진다.
이제 같은 데이터에서 10차 다항 회귀 모델의 학습 곡선을 그려보자.
from sklearn.pipeline import Pipeline
polynomial_regression=Pipeline([
('poly_features',PolynomialFeatures(degree=3,include_bias=False)),
('lin_reg',LinearRegression()),
])
plot_learning_curves(polynomial_regression,X,y)
'Machine Learning > Advanced (hands on machine learning)' 카테고리의 다른 글
11. 모델 훈련 - 로지스틱 회귀 (0) | 2021.05.19 |
---|---|
10. 모델훈련 - 규제가 있는 선형 모델 (0) | 2021.05.19 |
8. 모델훈련 - 다항 회귀 (0) | 2021.05.18 |
7. 모델훈련 - 경사하강법 (0) | 2021.05.18 |
6. 모델훈련 - 선형 회귀 (0) | 2021.05.18 |