스파르타 코딩클럽/[강의] 머신러닝

[머신러닝 심화] 데이터 분석 프로세스 - 데이터 분리

sance 2024. 1. 31. 22:19

과적합(Overfitting)

  • 데이터를 너무 과도하게 학습한 나머지 해당 문제만 잘 맞추고 새로운 데이터를 제대로 예측 혹은 분류하지 못하는 현상
  • 예시
  • 예측 혹은 분류를 하기 위해서 모형 복잡도를 설정
    • 모형이 지나치게 복잡할 때 : 과대 적합이 될 수 있음
    • 모형이 지나치게 단순할 때 : 과소 적합이 될 수 있음
  • 과적합의 원인
    • 모델의 복잡도
    • 데이터 양이 충분하지 않음
    • 학습 반복이 많음(딥러닝의 경우)
    • 데이터 불균형

 

과적합 해결 - 테스트 데이터 분리

  • Train Data : 모델을 학습(fit)하기 위한 데이터
  • Test Data : 모델을 평가하기 위한 데이터
  • 함수 : sklearn.model_selection.train_test_split
    • 파라미터
      ▪️ test_size : 테스트 데이터 세트 크기
      ▪️ train_size : 학습 데이터 세트 크기
      ▪️ shuffle : 데이터 분리 시 섞기
      ▪️ random_state : 호출할 때마다 동일한 학습/테스트 데이터를 생성하기 위한 난수 값. 수행할 때 마다 동일한 데이터 세트로 분리하기 위해 숫자를 고정시켜야 함
    • 반환 값 (순서 중요)
      ▪️ X_train, X_test, y_train, y_test

 

데이터 분리 실습
  • X 변수 : Fare, Sex
  • Y 변수 : Survived
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(titanic_df[['Fare', 'Sex']], titanic_df[['Survived']], test_size=0.3, shuffle=True, random_state=42, stratify=titanic_df[['Survived']])
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

## (623, 2) (268, 2) (623, 1) (268, 1)
# 원 자료 891개 y값의 분포
sns.countplot(titanic_df, x='Survived')
sns.countplot(y_train, x='Survived')
sns.countplot(y_test, x='Survived')