
Machine Learning - Cross Validation


On this page, W3schools.com collaborates with NYC Data Science Academy, to deliver digital training content to our students.


Cross Validation

When adjusting models, we aim to increase overall model performance on unseen data. Hyperparameter tuning can lead to much better performance on test sets. However, tuning parameters against the test set leaks information about that set into the model, so the test score becomes overly optimistic and the model can perform worse on truly unseen data. To correct for this, we can perform cross validation, which evaluates the model on validation folds taken from the training data instead.

To better understand cross validation (CV), we will apply different CV methods to the iris dataset. Let us first load the data and separate it into the features X and the target labels y.

from sklearn import datasets

X, y = datasets.load_iris(return_X_y=True)
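Before splitting the data, it can help to look at how it is laid out, because the order of the rows affects some of the CV methods below. This short check uses only NumPy and the same load_iris call; the printed labels are just for illustration.

```python
import numpy as np
from sklearn import datasets

X, y = datasets.load_iris(return_X_y=True)

# 150 observations (rows) with 4 measurements (columns) per flower
print("X shape:", X.shape)

# Number of observations for each class label 0, 1 and 2
print("Class counts:", np.bincount(y))

# The rows are stored grouped by class, so the labels never decrease
print("Sorted by class:", bool(np.all(np.diff(y) >= 0)))
```

The dataset is balanced (the same number of flowers per class), but all flowers of one class come before the next class.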

There are many methods of cross validation; we will start by looking at k-fold cross validation.


K-Fold

The training data is split into k smaller sets, called folds. The model is trained on k-1 of the folds, and the remaining fold is used as a validation set to evaluate it. This is repeated k times, so that each fold serves as the validation set exactly once and we get k scores.
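The procedure above can be written out by hand as a loop over the splits that KFold produces. This is a sketch for understanding only; the variable names are our own, and in practice cross_val_score (used later on this page) does the fitting and scoring for you.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)

kf = KFold(n_splits=5)
fold_scores = []
validated = []  # every index that served as validation data

for train_idx, val_idx in kf.split(X):
    # Train a fresh model on the k-1 training folds
    model = DecisionTreeClassifier(random_state=42)
    model.fit(X[train_idx], y[train_idx])
    # Score it on the held-out fold (accuracy for classifiers)
    fold_scores.append(model.score(X[val_idx], y[val_idx]))
    validated.extend(val_idx)

print("Fold scores:", fold_scores)
print("Mean:", np.mean(fold_scores))

# cross_val_score runs an equivalent fit-and-score step for each split
auto_scores = cross_val_score(DecisionTreeClassifier(random_state=42), X, y, cv=kf)
print("Matches cross_val_score:", np.allclose(fold_scores, auto_scores))
```

Each of the 150 observations ends up in exactly one validation fold.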

As we will be trying to classify different species of iris flowers, we will need to import a classifier model; for this exercise we will be using a DecisionTreeClassifier. We will also need to import the CV classes and functions from sklearn.

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold, cross_val_score

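With the data loaded, we can now create a model for evaluation. We do not need to fit it ourselves: cross_val_score fits a copy of the model on the training folds of each split.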

clf = DecisionTreeClassifier(random_state=42)

Now let's evaluate our model and see how it performs on each of the k folds.

k_folds = KFold(n_splits=5)

scores = cross_val_score(clf, X, y, cv=k_folds)
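One detail that often surprises beginners: cross_val_score works on copies of clf, so clf itself is still unfitted afterwards. The sketch below checks this by catching scikit-learn's NotFittedError, then fits the model explicitly; the flag name clf_was_fitted is our own.

```python
from sklearn import datasets
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=42)
scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5))

# cross_val_score fitted copies of clf, so clf itself cannot predict yet
try:
    clf.predict(X[:1])
    clf_was_fitted = True
except NotFittedError:
    clf_was_fitted = False
print("clf fitted after cross_val_score:", clf_was_fitted)

# To use the model afterwards, fit it on the data yourself
clf.fit(X, y)
print("Prediction for the first flower:", clf.predict(X[:1]))
```

CV is used to estimate how well a model configuration performs; the final model is then trained separately.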

It is also good practice to see how CV performed overall by averaging the scores for all folds.
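Besides the average, the spread of the fold scores tells us how stable the model is across different subsets of the data. A minimal sketch that reports both, using the NumPy array returned by cross_val_score:

```python
from sklearn import datasets
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)
scores = cross_val_score(DecisionTreeClassifier(random_state=42), X, y, cv=KFold(n_splits=5))

# The mean summarizes overall performance; the standard deviation
# shows how much the score varies from one fold to the next
mean_score = scores.mean()
std_score = scores.std()
print(f"Accuracy: {mean_score:.3f} +/- {std_score:.3f} over {len(scores)} folds")
```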

Example

Run k-fold CV:

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

k_folds = KFold(n_splits=5)

scores = cross_val_score(clf, X, y, cv=k_folds)

print("Cross Validation Scores: ", scores)
print("Average CV Score: ", scores.mean())
print("Number of CV Scores used in Average: ", len(scores))


Stratified K-Fold

In cases where classes are imbalanced, we need a way to account for the imbalance in both the train and validation sets. To do so we can stratify the target classes, meaning that each fold keeps approximately the same proportion of each class as the full dataset.
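To see what stratification changes, we can count the classes in each validation fold for plain KFold and for StratifiedKFold. The helper fold_class_counts below is a hypothetical name written for this illustration, not part of sklearn.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import KFold, StratifiedKFold

X, y = datasets.load_iris(return_X_y=True)

def fold_class_counts(cv):
    """Hypothetical helper: class counts of each validation fold produced by cv."""
    return [np.bincount(y[val_idx], minlength=3) for _, val_idx in cv.split(X, y)]

print("KFold validation folds:")
for counts in fold_class_counts(KFold(n_splits=5)):
    print("  ", counts)

print("StratifiedKFold validation folds:")
for counts in fold_class_counts(StratifiedKFold(n_splits=5)):
    print("  ", counts)
```

Because the iris rows are grouped by class, the first plain KFold validation fold contains only one class, while every StratifiedKFold fold contains 10 flowers of each class.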

Example

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

sk_folds = StratifiedKFold(n_splits=5)

scores = cross_val_score(clf, X, y, cv=sk_folds)

print("Cross Validation Scores: ", scores)
print("Average CV Score: ", scores.mean())
print("Number of CV Scores used in Average: ", len(scores))

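While the number of folds is the same, the average CV score increases compared with the basic k-fold when the folds are stratified. Note that iris is actually balanced (50 flowers per class), but its rows are ordered by class, so unshuffled KFold produces validation folds with very different class proportions; stratification avoids this.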
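Another way to break up the class ordering is to let KFold shuffle the rows before splitting, using its shuffle and random_state parameters. The sketch below compares the three options side by side; the seed 42 is an arbitrary choice, and the shuffled result will change with a different seed, so treat the numbers as an illustration rather than a ranking.

```python
from sklearn import datasets
from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=42)

# shuffle=True randomizes which rows land in each fold;
# random_state makes that shuffle reproducible
cv_options = {
    "KFold": KFold(n_splits=5),
    "KFold (shuffled)": KFold(n_splits=5, shuffle=True, random_state=42),
    "StratifiedKFold": StratifiedKFold(n_splits=5),
}

results = {}
for name, cv in cv_options.items():
    results[name] = cross_val_score(clf, X, y, cv=cv).mean()
    print(f"{name}: {results[name]:.3f}")
```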


Leave-One-Out (LOO)

Instead of selecting the number of splits in the training data set as in k-fold, LeaveOneOut uses 1 observation to validate and the remaining n-1 observations to train. This is an exhaustive technique: every observation is used as the validation set exactly once, so the model is fit n times.
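Put another way, Leave-One-Out is k-fold with k equal to the number of observations. The sketch below compares the splits produced by LeaveOneOut and KFold(n_splits=n) index by index to show that they match.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import KFold, LeaveOneOut

X, y = datasets.load_iris(return_X_y=True)

loo = LeaveOneOut()
print("Number of LOO splits:", loo.get_n_splits(X))

# Compare each LOO split with the matching KFold split (one fold per observation)
kf = KFold(n_splits=len(X))
same_splits = all(
    np.array_equal(loo_train, kf_train) and np.array_equal(loo_val, kf_val)
    for (loo_train, loo_val), (kf_train, kf_val) in zip(loo.split(X), kf.split(X))
)
print("Identical to KFold(n_splits=n):", same_splits)
```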

Example

Run LOO CV:

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import LeaveOneOut, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

loo = LeaveOneOut()

scores = cross_val_score(clf, X, y, cv=loo)

print("Cross Validation Scores: ", scores)
print("Average CV Score: ", scores.mean())
print("Number of CV Scores used in Average: ", len(scores))

We can observe that the number of cross validation scores computed is equal to the number of observations in the dataset. In this case there are 150 observations in the iris dataset.

The average CV score is 94%.
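Since every Leave-One-Out validation set holds a single flower, each individual score can only be 0 (wrong) or 1 (right), and the average is simply the fraction of flowers classified correctly. A quick sketch to confirm this:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)
scores = cross_val_score(DecisionTreeClassifier(random_state=42), X, y, cv=LeaveOneOut())

# Each validation set holds one flower, so every score is 0.0 or 1.0
print("Distinct score values:", np.unique(scores))

# The average is therefore the fraction of flowers classified correctly
n_correct = int(scores.sum())
print(f"{n_correct} of {len(scores)} correct, average {scores.mean():.3f}")
```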


Leave-P-Out (LPO)

Leave-P-Out is a generalization of the Leave-One-Out idea: we can select the number of observations, p, to use in our validation set, and every possible subset of p observations is used as the validation set once.
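On the full iris data the splits are too many to read, so here is a sketch on a hypothetical toy array of 4 observations where every LeavePOut(p=2) split can be listed.

```python
import numpy as np
from sklearn.model_selection import LeavePOut

# Toy data: 4 observations with 2 features each
X_toy = np.arange(8).reshape(4, 2)

lpo = LeavePOut(p=2)
splits = list(lpo.split(X_toy))
for train_idx, val_idx in splits:
    print("train:", train_idx, "validate:", val_idx)

# One split per possible pair of observations
print("Number of splits:", lpo.get_n_splits(X_toy))
```

With 4 observations there are 6 possible pairs, so 6 validation sets of size 2.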

Example

Run LPO CV:

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import LeavePOut, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

lpo = LeavePOut(p=2)

scores = cross_val_score(clf, X, y, cv=lpo)

print("Cross Validation Scores: ", scores)
print("Average CV Score: ", scores.mean())
print("Number of CV Scores used in Average: ", len(scores))

As we can see, this is an exhaustive method with many more scores calculated than Leave-One-Out, even with p = 2 (150 × 149 / 2 = 11,175 splits), yet it achieves roughly the same average CV score.
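The number of Leave-P-Out splits is the binomial coefficient C(n, p), which grows very quickly with p. This sketch asks each splitter for its split count with get_n_splits and compares it with Python's math.comb, without fitting any models.

```python
from math import comb

from sklearn import datasets
from sklearn.model_selection import LeaveOneOut, LeavePOut

X, y = datasets.load_iris(return_X_y=True)
n = len(X)

print("LeaveOneOut splits:", LeaveOneOut().get_n_splits(X))
for p in (1, 2, 3):
    # LeavePOut validates on every possible subset of p observations
    print(f"LeavePOut(p={p}) splits:", LeavePOut(p=p).get_n_splits(X), "C(n, p) =", comb(n, p))
```

Already at p = 3 the model would have to be fit over half a million times, which is why Leave-P-Out is rarely practical beyond small datasets.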


Shuffle Split

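Unlike KFold, ShuffleSplit leaves out a percentage of the data, which is used in neither the train nor the validation set of a split. To do so we must decide what the train and test sizes are, as well as the number of splits. Each split is drawn independently at random, so validation sets can overlap from one split to the next.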
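To make the left-out portion visible, the sketch below prints the size of each part of every split for train_size=0.6 and test_size=0.3. We add random_state=0 (an arbitrary seed, not used in the example on this page) so the splits are reproducible.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import ShuffleSplit

X, y = datasets.load_iris(return_X_y=True)

# random_state is our addition so the random splits can be reproduced
ss = ShuffleSplit(train_size=0.6, test_size=0.3, n_splits=5, random_state=0)

split_sizes = []
for train_idx, val_idx in ss.split(X):
    unused = len(X) - len(train_idx) - len(val_idx)
    # Observations shared by the train and validation sets of this split (should be 0)
    shared = np.intersect1d(train_idx, val_idx).size
    split_sizes.append((len(train_idx), len(val_idx), unused, shared))
    print(f"train: {len(train_idx)}, validate: {len(val_idx)}, unused: {unused}, shared: {shared}")
```

With 150 flowers, 60% for training and 30% for validation leaves the remaining 10% unused in each split.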

Example

Run Shuffle Split CV:

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import ShuffleSplit, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

ss = ShuffleSplit(train_size=0.6, test_size=0.3, n_splits=5)

scores = cross_val_score(clf, X, y, cv=ss)

print("Cross Validation Scores: ", scores)
print("Average CV Score: ", scores.mean())
print("Number of CV Scores used in Average: ", len(scores))

Ending Notes

These are just a few of the CV methods that can be applied to models. There are many more cross validation classes, and some estimators also come with built-in CV variants (for example, LogisticRegressionCV and RidgeCV). Check out sklearn's cross validation documentation for more CV options.


Copyright 1999-2024 by Refsnes Data. All Rights Reserved. W3Schools is Powered by W3.CSS.