KFold
대회 준비를 하다가 baseline code에 아래와 같이 scikit learn의 KFold가 있어 따로 정리한다.
뭐 데이터를 당연이 K-fold cross validation으로 진행한다고 얘기하겠지만 구체적인
동작방식이 궁금했다.
sklearn.model_selection.KFold
아래 링크 안의 내용을 직역하면
K-Folds cross-validator
Train/Test sets의 데이터를 split하기 위한 학습/테스트 Index를 제공.
데이터 세트를 k consecutive folds로 분할(default : without shuffle).
그런 다음 각 폴드는 validation로 한 번 사용되는
반면 k - 1개의 나머지 폴드는 training set를 형성.
즉, 위 KFold를 사용하면 알아서 겹치지 않게 fold를 나눠주고 그 형태를
index로 제공한다는 말 같다.
예시 코드를 보자 (원본 코드에서 조금 수정하였다)
import numpy as np
from sklearn.model_selection import KFold
X = np.array([5, 6, 7, 8, 9, 10])
y = np.array([1, 2, 3, 4, 5, 6])
kf = KFold(n_splits=2)
kf.get_n_splits(X)
print(kf)
for train_index, test_index in kf.split(X):
print("TRAIN:", train_index, "TEST:", test_index)
X_train, y_train = X[train_index], y[train_index]
X_test, y_test = X[test_index], y[test_index]
print("X_train : ", X_train, "y_train : ", y_train)
print("X_test : ", X_test, "y_test : ", y_test)
print("==============================================")
TRAIN: [3 4 5] TEST: [0 1 2]
X_train : [ 8 9 10] y_train : [4 5 6]
X_test : [5 6 7] y_test : [1 2 3]
==============================================
TRAIN: [0 1 2] TEST: [3 4 5]
X_train : [5 6 7] y_train : [1 2 3]
X_test : [ 8 9 10] y_test : [4 5 6]
==============================================
결과를 보면 train_index과 test_index를 반환하고, 반환한 index를 기반으로
train data와 train label, test data와 test label을 가져오는 것을 알 수 있다.
또한 fold를 2 fold로 나눴으니 반복은 두 번 된다.
위 KFold는 단점이 있으니 그건 class imbalance를 고려하지 않았다는 점이다.
이를 해결하기 위해 stratifiedKFold 라는 것이 있다고 한다.
stratifiedKFold
위 함수는 각 클래스에 대한 비율을 고려해서 fold를 나눈다고 한다.
예시 코드
import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 1, 1])
skf = StratifiedKFold(n_splits=2)
skf.get_n_splits(X, y)
print(skf)
for train_index, test_index in skf.split(X, y):
print("TRAIN:", train_index, "TEST:", test_index)
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
클래스 분포를 고려하기 위해 위 split 시 y의 label 정보를 주어지면 이를 고려하여 사용할 수 있다고 한다.
참고
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
'Language&Framework&Etc > Python' 카테고리의 다른 글
from win32com.shell import shellcon, shell ImportError: DLL load failed while importing shell: 지정된 프로시저를 찾을 수 없습니다. (0) | 2022.07.19 |
---|---|
@staticmethod 와 @classmethod (0) | 2022.03.07 |
파이썬 정규식 연습장 (0) | 2021.12.01 |
Visual Studio Code에서 내부 라이브러리 디버깅 하는 방법 (0) | 2021.10.25 |
Multiprocessing 파이썬 (0) | 2021.10.15 |