Parameter tuning plays a critical role in a machine learning model. The following example uses grid search to tune a random forest model:
# Random Forest Classifier - Grid Search
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.model_selection import train_test_split,GridSearchCV
>>> from sklearn.ensemble import RandomForestClassifier
>>> pipeline = Pipeline([
...     ('clf',RandomForestClassifier(criterion='gini',class_weight = {0:0.3,1:0.7}))])
The tuning parameters are the same as the random forest parameters used earlier; the difference is that grid search evaluates every combination of them through the pipeline, which is why each name carries the clf__ prefix of the pipeline step. The grid contains 3 x 3 x 2 x 2 = 36 parameter combinations, and each one is fitted five times because of the five-fold cross-validation, giving 36 x 5 = 180 model fits:
>>> parameters = {
...     'clf__n_estimators':(2000,3000,5000),
...     'clf__max_depth':(5,15,30),
...     'clf__min_samples_split':(2,3),
...     'clf__min_samples_leaf':(1,2) }
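The count of 36 combinations and 180 fits can be checked directly from the grid. The following is a minimal sketch that uses only the standard library; the names cv_folds, combinations, n_combinations, and n_fits are introduced here for illustration and are not part of the original listing:

```python
from itertools import product

# Same grid as in the listing above.
parameters = {
    'clf__n_estimators': (2000, 3000, 5000),
    'clf__max_depth': (5, 15, 30),
    'clf__min_samples_split': (2, 3),
    'clf__min_samples_leaf': (1, 2),
}
cv_folds = 5  # illustrative name for the cv=5 setting

# Each combination is one element of the Cartesian product of the value tuples.
combinations = [dict(zip(parameters, values))
                for values in product(*parameters.values())]

n_combinations = len(combinations)
n_fits = n_combinations * cv_folds  # one fit per combination per fold

print("Parameter combinations:", n_combinations)
print("Cross-validation fits:", n_fits)
```

Adding one more value to any parameter multiplies the total, so the cost of a grid grows quickly with its size.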
>>> grid_search = GridSearchCV(pipeline,parameters,n_jobs=-1,cv=5,verbose=1,
...     scoring='accuracy')
>>> grid_search.fit(x_train,y_train)
>>> import pandas as pd
>>> from sklearn.metrics import accuracy_score,classification_report
# best_score_ is the mean cross-validated accuracy of the best combination
>>> print ('Best Training score: %0.3f' % grid_search.best_score_)
>>> print ('Best parameters set:')
>>> best_parameters = grid_search.best_estimator_.get_params()
>>> for param_name in sorted(parameters.keys()):
...     print (' %s: %r' % (param_name, best_parameters[param_name]))
>>> predictions = grid_search.predict(x_test)
>>> print ("Testing accuracy:",round(accuracy_score(y_test, predictions),4))
>>> print ("\nComplete report of Testing data\n",classification_report(y_test, predictions))
>>> print ("\nRandom Forest Grid Search- Test Confusion Matrix\n",pd.crosstab(
...     y_test, predictions,rownames = ["Actual"],colnames = ["Predicted"]))
In the preceding results, grid search does not appear to offer much advantage over the random forest result explored earlier. In practice, however, it usually gives better and more robust results than a simple manual exploration of models, because it systematically evaluates many different combinations and selects the best parameter combination within the grid.
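To see how the individual combinations compare, the cross-validation scores stored by GridSearchCV can be inspected. The following is a minimal sketch under stated assumptions: it uses synthetic data from make_classification as a stand-in for the HR attrition data, a deliberately tiny grid (small_grid, a name introduced here) so that it runs in seconds, and a fixed random_state for repeatability. None of these values come from the original example:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Synthetic, imbalanced stand-in data (assumption: NOT the HR attrition data).
X, y = make_classification(n_samples=300, n_features=10,
                           weights=[0.8, 0.2], random_state=42)

pipeline = Pipeline([
    ('clf', RandomForestClassifier(criterion='gini',
                                   class_weight={0: 0.3, 1: 0.7},
                                   random_state=42)),
])

# Tiny illustrative grid: 2 x 2 = 4 combinations, 20 fits with cv=5.
small_grid = {
    'clf__n_estimators': (20, 50),
    'clf__max_depth': (3, 5),
}

search = GridSearchCV(pipeline, small_grid, cv=5, scoring='accuracy')
search.fit(X, y)

# cv_results_ holds one entry per combination; sort them by rank.
results = search.cv_results_
ranked = sorted(
    zip(results['rank_test_score'], results['mean_test_score'],
        results['std_test_score'], results['params']),
    key=lambda row: row[0],
)
for rank, mean, std, params in ranked:
    print('rank %d: mean accuracy %0.3f (+/- %0.3f) %r'
          % (rank, mean, std, params))
```

When the top-ranked combinations differ by less than their standard deviations, as may happen on small data, the grid search result is unlikely to be clearly better than a reasonable manual choice.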
R code for the random forest classifier with grid search applied to the HR attrition data:
# Grid Search - Random Forest
library(e1071)
library(randomForest)
rf_grid = tune(randomForest,Attrition_ind~.,data = train_data,classwt = c(0.3,0.7),
  ranges = list( mtry = c(5,6), maxnodes = c(32,64), ntree = c(3000,5000), nodesize = c(1,2) ),
  tunecontrol = tune.control(cross = 5) )
print(paste("Best parameter from Grid Search"))
print(summary(rf_grid))
best_model = rf_grid$best.model

# newdata (not data) is needed here; without it, predict() returns out-of-bag predictions
tr_y_pred=predict(best_model,newdata = train_data,type ="response")
ts_y_pred=predict(best_model,newdata = test_data,type= "response")
tr_y_act = train_data$Attrition_ind; ts_y_act= test_data$Attrition_ind

# accrcy, prec_zero, recl_zero, prec_one, recl_one and the frac_* class
# fractions are assumed to be defined earlier
# Train metrics
tr_tble = table(tr_y_act,tr_y_pred)
print(paste("Random Forest Grid search Train Confusion Matrix"))
print(tr_tble)
tr_acc = accrcy(tr_y_act,tr_y_pred)
trprec_zero = prec_zero(tr_y_act,tr_y_pred); trrecl_zero = recl_zero(tr_y_act,tr_y_pred)
trprec_one = prec_one(tr_y_act,tr_y_pred); trrecl_one = recl_one(tr_y_act,tr_y_pred)
trprec_ovll = trprec_zero *frac_trzero + trprec_one*frac_trone
trrecl_ovll = trrecl_zero *frac_trzero + trrecl_one*frac_trone
print(paste("Random Forest Grid Search Train accuracy:",tr_acc))
print(paste("Random Forest Grid Search - Train Classification Report"))
print(paste("Zero_Precision",trprec_zero,"Zero_Recall",trrecl_zero))
print(paste("One_Precision",trprec_one,"One_Recall",trrecl_one))
print(paste("Overall_Precision",round(trprec_ovll,4),"Overall_Recall",round(trrecl_ovll,4)))

# Test metrics
ts_tble = table(ts_y_act,ts_y_pred)
print(paste("Random Forest Grid search Test Confusion Matrix"))
print(ts_tble)
ts_acc = accrcy(ts_y_act,ts_y_pred)
tsprec_zero = prec_zero(ts_y_act,ts_y_pred); tsrecl_zero = recl_zero(ts_y_act,ts_y_pred)
tsprec_one = prec_one(ts_y_act,ts_y_pred); tsrecl_one = recl_one(ts_y_act,ts_y_pred)
tsprec_ovll = tsprec_zero *frac_tszero + tsprec_one*frac_tsone
tsrecl_ovll = tsrecl_zero *frac_tszero + tsrecl_one*frac_tsone
print(paste("Random Forest Grid Search Test accuracy:",ts_acc))
print(paste("Random Forest Grid Search - Test Classification Report"))
print(paste("Zero_Precision",tsprec_zero,"Zero_Recall",tsrecl_zero))
print(paste("One_Precision",tsprec_one,"One_Recall",tsrecl_one))
print(paste("Overall_Precision",round(tsprec_ovll,4),"Overall_Recall",round(tsrecl_ovll,4)))