import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier
import warnings
warnings.simplefilter(action= 'ignore', category=FutureWarning)

오늘은 부스팅이 해보고 싶어서 XGB Classifier를 이용하여 분류해보고자 한다.

하이퍼파라미터탐색은 GridSearch를 이용하겠다.

 

전처리된 데이터를 준비하고,

train = pd.read_csv('1차 최종 train.csv')
test = pd.read_csv('test.csv')

 

혹시 모르니 info()를 통해 타입을 살펴보자.

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8125 entries, 0 to 8124
Data columns (total 13 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   Unnamed: 0    8125 non-null   int64
 1   HomePlanet    8125 non-null   int64
 2   CryoSleep     8125 non-null   int64
 3   Cabin         8125 non-null   int64
 4   Destination   8125 non-null   int64
 5   Age           8125 non-null   int64
 6   VIP           8125 non-null   int64
 7   RoomService   8125 non-null   int64
 8   FoodCourt     8125 non-null   int64
 9   ShoppingMall  8125 non-null   int64
 10  Spa           8125 non-null   int64
 11  VRDeck        8125 non-null   int64
 12  Transported   8125 non-null   int64
dtypes: int64(13)

 

이제 input과 target 데이터로 나누자.

train_input = train.drop('Transported',axis=1)
train_target = train['Transported']
train_input.info()

 

[ 하이퍼 파라미터 탐색 / 튜닝]

params = {'learning_rate':(0.01,1),
          'gamma':(0,1),
          'max_depth':(3,10),
          'sub_sample':(0.5,1),
          'min_child_weight': (1, 5),
          'colsample_bytree':(0.1,1),
          'n_estimators' : (50,200)}

gs = GridSearchCV(XGBClassifier(random_state = 42), param_grid=params ,n_jobs=-1)

gs.fit(train_input, train_target)

[15:35:27] WARNING: C:/buildkite-agent/builds/buildkite-windows-cpu-autoscaling-group-i-0fc7796c793e6356f-1/xgboost/xgboost-ci-windows/src/learner.cc:767: 
Parameters: { "sub_sample" } are not used.

GridSearchCV(estimator=XGBClassifier(base_score=None, booster=None,
                                     callbacks=None, colsample_bylevel=None,
                                     colsample_bynode=None,
                                     colsample_bytree=None,
                                     early_stopping_rounds=None,
                                     enable_categorical=False, eval_metric=None,
                                     feature_types=None, gamma=None,
                                     gpu_id=None, grow_policy=None,
                                     importance_type=None,
                                     interaction_constraints=None,
                                     learning_rate=None, max_b...
                                     max_delta_step=None, max_depth=None,
                                     max_leaves=None, min_child_weight=None,
                                     missing=nan, monotone_constraints=None,
                                     n_estimators=100, n_jobs=None,
                                     num_parallel_tree=None, predictor=None,
                                     random_state=42, ...),
             n_jobs=-1,
             param_grid={'colsample_bytree': (0.1, 1), 'gamma': (0, 1),
                         'learning_rate': (0.01, 1), 'max_depth': (3, 10),
                         'min_child_weight': (1, 5), 'n_estimators': (50, 200),
                         'sub_sample': (0.5, 1)})
gs.best_params_

{'colsample_bytree': 1,
 'gamma': 1,
 'learning_rate': 0.01,
 'max_depth': 10,
 'min_child_weight': 5,
 'n_estimators': 50,
 'sub_sample': 0.5}

 

params = gs.best_params_
xgb = XGBClassifier(random_state=42, **params)
xgb.fit(train_input,train_target)
xgb.score(train_input,train_target)

[15:47:29] WARNING: C:/buildkite-agent/builds/buildkite-windows-cpu-autoscaling-group-i-0fc7796c793e6356f-1/xgboost/xgboost-ci-windows/src/learner.cc:767: 
Parameters: { "sub_sample" } are not used.

0.8016

 

그리드 서치를 통해 params의 범위를 지정 후 탐색 한 후 가장 높은 점수를 얻은 최적의

하이퍼 파라미터 조합을 사용해 XGB Classifier을 사용했다

예측 성공률은 80%가 나왔다.

처음 돌린것 치고 괜찮은것 같다.

 

XGB를 다시 사용하고,

 그리드서치를 베이지안 옵티마이저로 변경하여 하이퍼 파라미터만 바꿔서 다시 돌려보자.

from sklearn.model_selection import cross_val_score
# Define the hyperparameters and their ranges to be optimized
params = {'max_depth': (3, 10),
          'learning_rate': (0.01, 1),
          'n_estimators': (50, 200),
          'min_child_weight': (1, 5),
          'subsample': (0.5, 1),
          'gamma': (0, 1),
          'colsample_bytree': (0.1, 1)}

# Define the objective function to be optimized by Bayesian Optimization
def xgb_cv(max_depth, learning_rate, n_estimators, min_child_weight, subsample, gamma, colsample_bytree):
    xgb.set_params(max_depth=int(max_depth),
                   learning_rate=learning_rate,
                   n_estimators=int(n_estimators),
                   min_child_weight=int(min_child_weight),
                   subsample=subsample,
                   gamma=gamma,
                   colsample_bytree=colsample_bytree)
    return cross_val_score(xgb, train_input, train_target, cv=5, scoring='accuracy').mean()

# Define the Bayesian Optimization object
xgbBO = BayesianOptimization(xgb_cv, params)

# Optimize the hyperparameters using Bayesian Optimization
xgbBO.maximize(n_iter=50, init_points=5)

# Print the best hyperparameters and their corresponding score
print(xgbBO.max)

베이지안 옵티마이저 ( 하이퍼 파라미터 탐색 (서치)) 는 그리드와 다르게

함수를 따로 설정해줘야 하는것 같다. def xgb_cv :

 

이퍼 파라미터를 입력으로 가져가서 XGB 분류기의 교차 검증된 정확도 점수를 반환하는

"xgb_cv"라는 목적 함수를 정의한다.

목적 함수는 sikit-learn의 "cross_val_score" 함수를 사용하여 5배 교차 검증을 수행한다.

그런 다음 "xgbB"라는 베이지안 최적화 개체를 만든다."  목적 함수와 하이퍼 파라미터를 인수로 전달한다.

우리는 베이지안 최적화를 사용하여 하이퍼 파라미터를 최적화하기 위해

베이지안 최적화 객체의 "최대화" 방법을 부른다.

마지막으로 베이지안 최적화 개체의 "max" 속성을 사용하여 최상의 하이퍼 파라미터와 해당 점수를 print한다.

 

 탐색 과정 보기

|   iter    |  target   | colsam... |   gamma   | learni... | max_depth | min_ch... | n_esti... | subsample |
-------------------------------------------------------------------------------------------------------------
| 1         | 0.6103    | 0.6987    | 0.1225    | 0.1515    | 8.224     | 2.78      | 88.54     | 0.8106    |
| 2         | 0.5686    | 0.6152    | 0.4665    | 0.6057    | 4.602     | 4.858     | 83.59     | 0.7877    |
| 3         | 0.5669    | 0.6712    | 0.5453    | 0.9401    | 7.787     | 1.631     | 84.9      | 0.7387    |
| 4         | 0.5202    | 0.2059    | 0.7044    | 0.0881    | 7.868     | 2.342     | 148.1     | 0.916     |
| 5         | 0.5641    | 0.518     | 0.08341   | 0.9347    | 6.403     | 2.885     | 142.2     | 0.9287    |
| 6         | 0.7503    | 0.7642    | 0.2208    | 0.01476   | 7.224     | 2.787     | 89.72     | 0.7512    |
| 7         | 0.6741    | 0.6956    | 0.2691    | 0.04304   | 7.473     | 2.803     | 90.27     | 0.7354    |
| 8         | 0.6111    | 0.3772    | 0.1156    | 0.1722    | 6.862     | 2.909     | 89.46     | 0.5736    |
| 9         | 0.5589    | 0.996     | 0.3121    | 0.7178    | 6.272     | 1.441     | 197.7     | 0.9347    |
| 10        | 0.7438    | 1.0       | 0.2819    | 0.01      | 7.435     | 2.692     | 89.73     | 0.888     |
| 11        | 0.6028    | 0.935     | 0.7813    | 0.3266    | 7.275     | 2.008     | 89.35     | 0.9696    |
| 12        | 0.5586    | 0.8707    | 0.4239    | 0.9497    | 7.755     | 2.939     | 90.22     | 0.8634    |
| 13        | 0.7199    | 0.9006    | 0.3698    | 0.03016   | 7.438     | 2.105     | 90.24     | 0.5563    |
| 14        | 0.5978    | 0.6174    | 0.7405    | 0.2071    | 7.302     | 2.448     | 89.83     | 0.9597    |
| 15        | 0.6238    | 0.5315    | 0.3929    | 0.05831   | 5.405     | 1.407     | 139.3     | 0.9426    |
| 16        | 0.5623    | 0.9584    | 0.5933    | 0.7928    | 5.375     | 2.849     | 173.5     | 0.8334    |
| 17        | 0.5706    | 0.5806    | 0.2061    | 0.8379    | 8.009     | 3.695     | 96.87     | 0.5719    |
| 18        | 0.6091    | 0.969     | 0.6108    | 0.5672    | 8.644     | 4.891     | 123.8     | 0.7119    |
| 19        | 0.5499    | 0.4206    | 0.9338    | 0.9778    | 4.152     | 2.256     | 142.6     | 0.8824    |
| 20        | 0.6901    | 0.685     | 0.01052   | 0.03924   | 7.289     | 2.735     | 90.7      | 0.5096    |
| 21        | 0.6047    | 0.6388    | 0.07058   | 0.1224    | 6.596     | 2.135     | 90.52     | 0.6906    |
| 22        | 0.7596    | 1.0       | 0.05041   | 0.01      | 7.305     | 2.524     | 89.87     | 0.5584    |
| 23        | 0.5822    | 0.8653    | 0.4866    | 0.3237    | 7.976     | 1.363     | 90.84     | 0.9786    |
| 24        | 0.5978    | 0.9655    | 0.1968    | 0.2951    | 7.567     | 2.518     | 89.95     | 0.505     |
| 25        | 0.6135    | 0.7961    | 0.8492    | 0.1947    | 7.032     | 1.907     | 90.26     | 0.5117    |
| 26        | 0.5785    | 0.2114    | 0.9544    | 0.9256    | 6.561     | 4.046     | 93.02     | 0.9875    |
| 27        | 0.5724    | 0.3869    | 0.213     | 0.126     | 8.753     | 1.661     | 195.6     | 0.9378    |
| 28        | 0.6228    | 0.7242    | 0.4029    | 0.1495    | 6.745     | 3.011     | 90.32     | 0.6581    |
| 29        | 0.6012    | 0.6911    | 0.2479    | 0.5442    | 7.269     | 2.962     | 89.39     | 0.6084    |
| 30        | 0.5546    | 0.2324    | 0.2779    | 0.3053    | 7.547     | 2.212     | 90.49     | 0.5984    |
| 31        | 0.5579    | 0.8693    | 0.7293    | 0.5319    | 7.034     | 2.044     | 90.46     | 0.6849    |
| 32        | 0.7332    | 0.7667    | 0.3579    | 0.02248   | 7.386     | 1.447     | 89.73     | 0.8661    |
| 33        | 0.5915    | 0.9329    | 0.4795    | 0.7687    | 7.461     | 1.033     | 89.7      | 0.9438    |
| 34        | 0.5662    | 0.2785    | 0.8665    | 0.6453    | 8.654     | 2.789     | 160.3     | 0.7509    |
| 35        | 0.6073    | 0.7983    | 0.1568    | 0.4487    | 6.283     | 3.771     | 160.9     | 0.886     |
| 36        | 0.521     | 0.1814    | 0.7958    | 0.09026   | 9.277     | 4.778     | 134.8     | 0.8938    |
| 37        | 0.6261    | 0.4929    | 0.6858    | 0.05002   | 7.517     | 1.226     | 89.75     | 0.9839    |
| 38        | 0.5979    | 0.6657    | 0.7642    | 0.2697    | 6.881     | 1.312     | 89.71     | 0.6174    |
| 39        | 0.7503    | 0.9457    | 0.07201   | 0.01      | 7.16      | 2.515     | 89.88     | 0.8681    |
| 40        | 0.7575    | 1.0       | 0.0       | 0.01      | 7.228     | 2.885     | 89.85     | 0.7391    |
| 41        | 0.6132    | 0.7363    | 0.1436    | 0.3085    | 7.023     | 2.021     | 89.61     | 0.5638    |
| 42        | 0.624     | 0.891     | 0.06212   | 0.1219    | 7.22      | 3.854     | 90.19     | 0.883     |
| 43        | 0.5767    | 0.7601    | 0.25      | 0.2714    | 7.582     | 1.709     | 89.91     | 0.9229    |
| 44        | 0.614     | 0.3562    | 0.5391    | 0.214     | 7.684     | 3.257     | 91.44     | 0.5431    |
| 45        | 0.6106    | 0.6325    | 0.3444    | 0.1795    | 7.056     | 2.479     | 91.33     | 0.6597    |
| 46        | 0.7557    | 0.8779    | 0.2409    | 0.0188    | 5.539     | 4.755     | 93.72     | 0.755     |
| 47        | 0.6015    | 0.5372    | 0.0837    | 0.1466    | 7.103     | 2.961     | 89.7      | 0.9337    |
| 48        | 0.581     | 0.9351    | 0.1166    | 0.4818    | 9.327     | 1.28      | 78.47     | 0.5416    |
| 49        | 0.5762    | 0.599     | 0.5343    | 0.6791    | 4.135     | 4.219     | 50.72     | 0.7866    |
| 50        | 0.609     | 0.3763    | 0.9039    | 0.2149    | 5.726     | 4.672     | 79.75     | 0.7045    |
| 51        | 0.5929    | 0.37      | 0.6496    | 0.7564    | 4.512     | 3.833     | 190.0     | 0.6005    |
| 52        | 0.6017    | 0.9499    | 0.5243    | 0.5982    | 5.635     | 4.218     | 94.03     | 0.5944    |
| 53        | 0.6108    | 0.8254    | 0.2186    | 0.1845    | 6.988     | 2.436     | 90.58     | 0.6205    |
| 54        | 0.5652    | 0.5436    | 0.9611    | 0.901     | 8.289     | 3.423     | 79.72     | 0.751     |
| 55        | 0.5876    | 0.7708    | 0.5102    | 0.7385    | 5.594     | 3.223     | 153.1     | 0.5488    |
=============================================================================================================
{'target': 0.7596307692307692, 'params': {'colsample_bytree': 1.0, 'gamma': 0.050406580268173413, 'learning_rate': 0.01, 'max_depth': 7.305157214715147, 'min_child_weight': 2.5238667272883433, 'n_estimators': 89.87424334757735, 'subsample': 0.5583965007292178}}
params = xgbBO.max['params']
params['max_depth'] = int(round(params['max_depth']))
params['n_estimators'] = int(round(params['n_estimators']))
params['gamma'] = int(round(params['gamma']))
params['min_child_weight'] = int(round(params['min_child_weight']))
xgb = XGBClassifier(random_state = 42, **params)
xgb.fit(train_input, train_target)



XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=1.0, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=0, gpu_id=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=0.01, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=7, max_leaves=None,
              min_child_weight=3, missing=nan, monotone_constraints=None,
              n_estimators=90, n_jobs=None, num_parallel_tree=None,
              predictor=None, random_state=42, ...)

score를 출력해보자.

xgb.score(train_input, train_target)	

0.8045538461538462

 

'부록' 카테고리의 다른 글

Resnet 논문 리뷰  (0) 2024.04.22
Style GAN & Style GAN2  (1) 2024.04.07
LSTM  (0) 2023.08.02
[딥러닝] 옵티마이저 [퍼온 글]  (0) 2023.03.06
Bayesian Optimization  (2) 2023.02.21

+ Recent posts