스파르타 코딩클럽/[강의] 머신러닝
[머신러닝 심화] 데이터 분석 프로세스 - 데이터 분리
sance
2024. 1. 31. 22:19
과적합(Overfitting)
- 데이터를 너무 과도하게 학습한 나머지 해당 문제만 잘 맞추고 새로운 데이터를 제대로 예측 혹은 분류하지 못하는 현상
- 예시
- 예측 혹은 분류를 하기 위해서 모형 복잡도를 설정
- 모형이 지나치게 복잡할 때 : 과대 적합이 될 수 있음
- 모형이 지나치게 단순할 때 : 과소 적합이 될 수 있음
- 과적합의 원인
- 모델의 복잡도
- 데이터 양이 충분하지 않음
- 학습 반복이 많음(딥러닝의 경우)
- 데이터 불균형
과적합 해결 - 테스트 데이터 분리
- Train Data : 모델을 학습(fit)하기 위한 데이터
- Test Data : 모델을 평가하기 위한 데이터
- 함수 : sklearn.model_selection.train_test_split
- 파라미터
▪️ test_size : 테스트 데이터 세트 크기
▪️ train_size : 학습 데이터 세트 크기
▪️ shuffle : 데이터 분리 시 섞기
▪️ random_state : 호출할 때마다 동일한 학습/테스트 데이터를 생성하기 위한 난수 값. 수행할 때 마다 동일한 데이터 세트로 분리하기 위해 숫자를 고정시켜야 함 - 반환 값 (순서 중요)
▪️ X_train, X_test, y_train, y_test
- 파라미터
데이터 분리 실습
- X 변수 : Fare, Sex
- Y 변수 : Survived
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(titanic_df[['Fare', 'Sex']], titanic_df[['Survived']], test_size=0.3, shuffle=True, random_state=42, stratify=titanic_df[['Survived']])
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
## (623, 2) (268, 2) (623, 1) (268, 1)
# 원 자료 891개 y값의 분포
sns.countplot(titanic_df, x='Survived')
|
sns.countplot(y_train, x='Survived')
|
sns.countplot(y_test, x='Survived')
|
![]() |
![]() |
![]() |