```python
# Setup: libraries and classifiers used throughout this chapter
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
```
18 Multiclass Classification
18.1 Introduction
Thus far, we have only covered methods for binary classification - that is, for predicting between two categories.
What happens when our target variable of interest contains more than two categories? For example, instead of predicting whether or not someone has heart disease, perhaps we want to predict what type of disease they have, out of three options.
Read on to find out…
If you do not have the sklearn library installed, you will need to run pip install scikit-learn in the Jupyter/Colab terminal. (The package is named scikit-learn on PyPI, even though it is imported as sklearn.) Remember: you only need to install once per machine (or Colab session).
18.2 Naturally Multiclass Models
Some model specifications lend themselves naturally to the multiclass setting.
Let’s take a quick look at how each of these predicts for three or more classes.
18.2.1 Multiclass KNN
Recall that in a binary setting, KNN considers the “votes” of the \(K\) most similar observations in the training set to classify a new observation.
In a multiclass setting, nothing changes! KNN still considers the “votes” of the closest observations; we simply now have votes for more than two options.
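To see this in code, here is a minimal sketch (using sklearn's built-in iris data, which has a three-class target, purely as a stand-in for your own data):

```python
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

# Iris has three classes - any (X, y) with 3+ categories works the same way.
X, y = load_iris(return_X_y=True)

# Exactly the same syntax as the binary case; the neighbor "votes"
# are simply tallied across three possible labels.
knn = KNeighborsClassifier(n_neighbors=5).fit(X, y)
knn.predict(X[:5])
```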
18.2.2 Multiclass Trees
Similarly, in a binary setting, Decision Trees assign new observations to the class that is most common in the node/leaf (or “bucket”) that they land in.
The same is true in the multiclass setting. However, it’s important to remember that the splits in the tree itself were chosen automatically during the model fitting procedure to make the nodes as “pure” as possible - that is, to have mostly one class represented in each leaf. This means the fitted tree for a two-class prediction setting might look very different from the fitted tree for a three-class setting!
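A quick sketch (again on the stand-in iris data) shows that the fitting syntax is unchanged, even though the chosen splits now target three-way purity:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# The splits are re-chosen during fitting to keep each leaf as "pure"
# as possible, now with three classes competing in every node.
tree = DecisionTreeClassifier(max_depth=3).fit(X, y)
tree.predict(X[:5])
```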
18.2.3 Multiclass LDA
In the binary setting, LDA relies on the assumption that the “scores” (linear combinations of predictors) for observations in the two classes were generated from two Normal distributions with different means. After using the training data to pick a score function and estimate the means, we then assign new observations to the class whose distribution would be most likely to have produced that score.
Instead of two Normal distributions, we can easily imagine three or more! We still use the observed data to pick a score function and then approximate the means and standard deviations of the Normal distributions, and we still assign new predictions to the “most likely” group.
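As a sketch (stand-in iris data again), sklearn's LDA handles the multiclass case automatically:

```python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X, y = load_iris(return_X_y=True)

# One Normal distribution of scores is estimated per class; each new
# observation is assigned to the class most likely to have produced it.
lda = LinearDiscriminantAnalysis().fit(X, y)
lda.predict(X[:5])
```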
18.3 Multiclass from Binary Classifiers
Some models simply cannot be easily “upgraded” to the multiclass setting. Of those we have studied, Logistic Regression and SVC/SVM fall into this category.
In Logistic Regression, we rely on the logistic function to transform our linear combination of predictors into a probability. We only have one “score” from the linear combination, and we can only turn it into one probability. Thus, it only makes sense to fit this model to compare two classes; i.e., to predict the “probability of Class 1”.
In SVC, our goal is to find a separating line that maximizes the margin to the two classes. What do we do with three classes? Find three separating lines? But then which margins do we look at? And which classes do we measure the margins between? There is no way to define our “model preferences” to include “large margins” in this setting.
So, how do we proceed? There are two approaches to using binary classification models to answer multiclass prediction questions…
18.3.1 One vs. Rest (OvR)
The first approach is to target only one category at a time, and fit a model that can separate that category’s observations from all the rest. This is called “One vs Rest” or OvR modeling.
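sklearn provides a wrapper that carries out this strategy with any binary classifier. A minimal sketch, again using the iris data as a stand-in:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = load_iris(return_X_y=True)

# One logistic regression per category: "this category" vs. "everything else".
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, y)
len(ovr.estimators_)  # 3 fitted models, one per category
```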
18.3.2 One vs. One (OvO)
The second approach is to fit a separate model for every pair of categories, so that each model learns to tell two specific categories apart. This is called “One vs One” or OvO modeling.
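The matching sklearn wrapper is OneVsOneClassifier; a minimal sketch on the same stand-in data:

```python
from sklearn.datasets import load_iris
from sklearn.multiclass import OneVsOneClassifier
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# One SVC per pair of categories: 3 classes -> 3*2/2 = 3 pairwise models.
ovo = OneVsOneClassifier(SVC()).fit(X, y)
len(ovo.estimators_)  # 3 fitted models, one per pair
```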
18.3.3 How to choose
In general, the OvO approach is better because:
It gives better predictions. Distinguishing between individual groups gives more information than lumping many (possibly dissimilar) groups into a “Rest” category.
It gives more interpretable information. We can discuss the coefficient estimates of the individual models to figure out what patterns exist between the categories.
However, the OvR might be preferred when:
You have many categories. Consider a problem with 10 classes to predict. In OvR, we then need to fit 10 models for each specification. In OvO, we need one model per pair of classes - that is, \(10 \times 9 / 2 = 45\) different models for each specification!
You are interested in what makes a single category stand out. For example, perhaps you are using these models to understand what features define different bacteria species. You are not trying to figure out how Bacteria A is different from Bacteria B or Bacteria C specifically; you are trying to figure out what makes Bacteria A unique among the rest.
You have “layers” of categories. For example, in the heart attack data, notice that Chest Pain category 0 was “asymptomatic”, aka, no pain. We might be most interested in learning what distinguishes no pain (0) from yes pain (“the rest”); but we are still secondarily interested in distinguishing the three pain types.
18.4 Metrics and Multiclass Estimators
Recall that in the binary setting, we have two metrics that do not change based on which class is considered “Class 1” or the “Target Class”:
accuracy: How many predictions were correct?
ROC-AUC: A measure of the trade-off between getting Class 1 wrong and getting Class 0 wrong as the decision boundary changes.
We also have many metrics that are asymmetrical, and are calculated differently for different target classes:
precision: How many of the predicted Target Class were truly from the Target Class?
recall: How many of the true Target Class observations were successfully identified as Target Class?
F1 Score: “Average” of precision and recall.
F2 Score: “Average” of precision and 2*recall.
Now that we are in the multiclass setting, we can think of precision, recall, and F1 Score as “OvR” metrics: They measure the model’s ability to successfully predict one category of interest out of the pack.
We can think of ROC-AUC as an “OvO” metric: It measures the model’s trade-off between success for two classes.
Only accuracy is truly a multiclass metric!
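sklearn's roc_auc_score reflects this: for a multiclass target it requires predicted probabilities for every class, plus a multi_class argument saying how to combine the pairwise or one-vs-rest comparisons. A sketch (scoring on the training data purely for illustration):

```python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import roc_auc_score

X, y = load_iris(return_X_y=True)
probs = LinearDiscriminantAnalysis().fit(X, y).predict_proba(X)

roc_auc_score(y, probs, multi_class="ovo")  # averaged over pairs of classes
roc_auc_score(y, probs, multi_class="ovr")  # averaged over one-vs-rest splits
```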
18.4.1 Macro and micro
So, if we want to use a metric besides accuracy to measure our model’s success, what should we do? Three options:

1. We look at the micro version of the metric: we choose the one category that is most important to us to be the target category, and then we measure that. Realistically, we only report micro metrics to summarize how well we can predict each individual category. We don’t use them to select between models - because if our definition of the “best” model is just the one that pulls out the target category, why are we bothering with multiclass in the first place?

2. We look at the macro version of the metric: the average of the micro versions across all the possible categories. This is the most common approach; you will often see classification models measured by f1_macro.

3. We look at a weighted average of the micro metrics. This might be useful if there is one category that matters more, but we still care about all the categories. (Such as in the cp variable, where we care most about distinguishing 0 from the rest, but we still want to separate 1-3.)
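In sklearn, these options correspond to the average argument of metric functions like f1_score - though beware that the names do not line up exactly with the terms above: average=None returns the per-category scores (our “micro” versions; sklearn's own average="micro" instead pools all predictions globally), and average="weighted" weights by how common each category is, not by how much we care about it. A sketch on the stand-in iris data:

```python
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report, f1_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
preds = KNeighborsClassifier().fit(X, y).predict(X)

f1_score(y, preds, average=None)        # one F1 per category
f1_score(y, preds, average="macro")     # plain average across categories
f1_score(y, preds, average="weighted")  # average weighted by category size

# classification_report prints the per-category and averaged metrics at once.
print(classification_report(y, preds))
```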
18.5 Conclusion
There are many reasons why it’s important for a data scientist to understand the intuition and motivation behind the models they use, even if the computation and math are taken care of by the software.
Multiclass classification is a great example of this principle. What if we had just chucked some multiclass data into all our classification models: KNN, Trees, Logistic, LDA, QDA, SVC, and SVM? Some models would be fine, while others would handle the multiclass problem very differently than they handle binary settings - and this could lead to bad model fits or, worse, incorrect interpretations of the results!