18  Multiclass Classification

18.1 Introduction

Thus far, we have only covered methods for binary classification - that is, for predicting between two categories.

What happens when our target variable of interest contains more than two categories? For example, instead of predicting whether or not someone has heart disease, perhaps we want to predict what type of disease they have, out of three options.

Read on to find out…

Note

If you do not have the sklearn library installed, then you will need to run

pip install scikit-learn

in the Jupyter/Colab terminal to install. Remember: you only need to install once per machine (or Colab session).

import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

18.2 Naturally Multiclass Models

Some model specifications lend themselves naturally to the multiclass setting.

Let’s take a quick look at how each of these models (KNN, Decision Trees, and LDA) predicts when there are three or more classes.

18.2.1 Multiclass KNN

Recall that in a binary setting, KNN considers the “votes” of the \(K\) most similar observations in the training set to classify a new observation.

In a multiclass setting, nothing changes! KNN still considers the “votes” of the closest observations; we simply now have votes for more than two options.
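
For example, here is a minimal sketch using the three-class iris data that ships with sklearn, purely as a stand-in for any multiclass target; the syntax is identical to the binary case:

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

# iris has three species, so the target is already multiclass
X, y = load_iris(return_X_y=True)

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X, y)

knn.predict(X[:5])        # each prediction is whichever of the three classes wins the vote
knn.predict_proba(X[:5])  # vote proportions, now spread across three columns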

18.2.2 Multiclass Trees

Similarly, in a binary setting, Decision Trees assign new observations to the class that is most common in the node/leaf (or “bucket”) that they land in.

The same is true for the multiclass setting. However, it’s important to remember that the splits in the tree itself were chosen automatically during the model fitting procedure to try to make the nodes have as much “purity” as possible - that is, to have mostly one class represented in each leaf. This means the fitted tree for a two-class prediction setting might look very different from the fitted tree for a three-class setting!
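
Here is a minimal sketch, again using the built-in three-class iris data as a stand-in. The fitting and plotting code is unchanged from the binary setting; only the leaves now report counts for three classes:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

X, y = load_iris(return_X_y=True)

tree = DecisionTreeClassifier(max_depth=3)
tree.fit(X, y)

# each leaf shows counts for all three classes; a new observation is
# assigned the most common class in the leaf it lands in
plot_tree(tree, filled=True)
plt.show()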

18.2.3 Multiclass LDA

In the binary setting, LDA relies on the assumption that the “scores” (linear combinations of predictors) for observations in the two classes were generated from two Normal distributions with different means. After using the training data to pick a score function and estimate means, we then assign new predictions to the class whose distribution would be most likely to output that data.

Instead of two Normal distributions, we can easily imagine three or more! We still use the observed data to pick a score function and then approximate the means and standard deviations of the Normal distributions, and we still assign new predictions to the “most likely” group.
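
A minimal sketch on the same stand-in data; LinearDiscriminantAnalysis estimates one Normal distribution per class and returns a probability for each:

from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X, y = load_iris(return_X_y=True)

lda = LinearDiscriminantAnalysis()
lda.fit(X, y)

lda.predict_proba(X[:5])  # one estimated probability per class, for each observation
lda.predict(X[:5])        # the prediction is the "most likely" group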

Practice Activity

Open this Colab notebook. Fit a multiclass KNN, Decision Tree, and LDA for the heart disease data; this time predicting the type of chest pain (categories 0 - 3) that a patient experiences. For the decision tree, plot the fitted tree, and interpret the first couple splits.

18.3 Multiclass from Binary Classifiers

Some models simply cannot be easily “upgraded” to the multiclass setting. Of those we have studied, Logistic Regression and SVC/SVM fall into this category.

In Logistic Regression, we rely on the logistic function to transform our linear combination of predictors into a probability. We only have one “score” from the linear combination, and we can only turn it into one probability. Thus, it only makes sense to fit this model to compare two classes; i.e., to predict the “probability of Class 1”.
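
To see why, it helps to write the model out, with generic coefficients \(\beta_0, \beta_1, \dots, \beta_p\) standing in for whatever predictors we use:

\[
P(\text{Class 1}) = \frac{1}{1 + e^{-(\beta_0 + \beta_1 x_1 + \dots + \beta_p x_p)}}
\]

The exponent contains a single linear score, so the model produces a single probability; the probability of Class 0 is simply whatever is left over.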

In SVC, our goal is to find a separating line that maximizes the margin to the two classes. What do we do with three classes? Find three separating lines? But then which margins do we look at? And which classes do we measure the margins between? There is no way to define our “model preferences” to include “large margins” in this setting.

So, how do we proceed? There are two approaches to using binary classification models to answer multiclass prediction questions…

18.3.1 One vs. Rest (OvR)

The first approach is to try to target only one category at a time, and fit a model that can extract those observations from the rest of them. This is called “One vs Rest” or OvR modeling.
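
sklearn can automate the bookkeeping with its OneVsRestClassifier wrapper, which fits one binary model per category behind the scenes. A rough sketch, shown here on the built-in iris data rather than the heart disease data:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = load_iris(return_X_y=True)

# one "this class vs. everything else" logistic regression per category
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000))
ovr.fit(X, y)

len(ovr.estimators_)  # one fitted binary model per class (3 here)
ovr.predict(X[:5])    # final prediction: the class whose model is most confident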

Practice Activity

Open this Colab notebook. Create a new column in the ha dataset called “cp_is_3”, which is equal to 1 if the cp variable is equal to 3 and 0 otherwise.

Then, fit a Logistic Regression to predict this new target, and report the F1 Score.

Repeat for the other three cp categories. Which category was the OvR approach best at distinguishing?

Check In

Your four OvR Logistic Regressions produced four probabilities for each observation: prob of cp_is_0, prob of cp_is_1, etc.

Is it guaranteed that these four probabilities add up to 1? Why or why not?

18.3.2 One vs. One (OvO)

The second approach is to fit a separate model for every pair of categories, so that each model only needs to separate two classes. This is called “One vs One” or OvO modeling.
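
sklearn also offers a OneVsOneClassifier wrapper that fits one binary model for every pair of categories. A rough sketch on the same stand-in data:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsOneClassifier

X, y = load_iris(return_X_y=True)

# one logistic regression per pair of categories: 3 classes means 3 pairs
ovo = OneVsOneClassifier(LogisticRegression(max_iter=1000))
ovo.fit(X, y)

len(ovo.estimators_)  # K*(K-1)/2 fitted binary models
ovo.predict(X[:5])    # final prediction: the class that wins the most pairwise "votes"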

Practice Activity

Open this Colab notebook. Reduce your dataset to only the 0 and 1 types of chest pain.

Then, fit a Logistic Regression to predict between the two groups, and report the ROC-AUC.

Repeat, comparing category 0 to category 2 and to category 3. Which pair was the OvO approach best at distinguishing?

Check In
  • Why do you think we reported ROC-AUC instead of F1 Score this time?

  • Your three OvO Logistic Regressions produced three probabilities for each observation: prob of 0 compared to 1, prob of 0 compared to 2, and prob of 0 compared to 3. Is it guaranteed that these three probabilities add up to 1? Why or why not?

  • If we had done all the OvO pairs, how many regressions would we have fit?

  • How would you use the results of all the OvO pairs to arrive at one final class prediction?

18.3.3 How to choose

In general, the OvO approach is better because:

  • It gives better predictions. Distinguishing between individual groups gives more information than lumping many (possibly dissimilar) groups into a “Rest” category.

  • It gives more interpretable information. We can discuss the coefficient estimates of the individual models to figure out what patterns exist between the categories.

However, the OvR approach might be preferred when:

  • You have many categories. Consider a problem with 10 classes to predict. In OvR, we only need to fit 10 models for each specification. In OvO, we need one model for every pair of classes, which is \(10 \times 9 / 2 = 45\) different models for each specification!

  • You are interested in what makes a single category stand out. For example, perhaps you are using these models to understand what features define different bacteria species. You are not trying to figure out how Bacteria A is different from Bacteria B or Bacteria C specifically; you are trying to figure out what makes Bacteria A unique among the rest.

  • You have “layers” of categories. For example, in the heart attack data, notice that Chest Pain category 0 was “asymptomatic”, aka, no pain. We might be most interested in learning what distinguishes no pain (0) from any pain (“the rest”); but we are still secondarily interested in distinguishing the three pain types.

18.4 Metrics and Multiclass Estimators

Recall that in the binary setting, we have two metrics that do not change based on which class is considered “Class 1” or the “Target Class”:

  • accuracy: How many predictions were correct

  • ROC-AUC: A measure of the trade-off between getting Class 1 wrong and getting Class 0 wrong as the decision threshold changes.

We also have many metrics that are asymmetrical, and are calculated differently for different target classes:

  • precision: How many of the predicted Target Class were truly from the Target Class?

  • recall: How many of the true Target Class observations were successfully identified as Target Class?

  • F1 Score: “Average” of precision and recall.

  • F2 Score: “Average” of precision and recall, with recall weighted twice as heavily as precision.

Now that we are in the multiclass setting, we can think of precision, recall, and F1 Score as “OvR” metrics: They measure the model’s ability to successfully predict one category of interest out of the pack.

We can think of ROC-AUC as an “OvO” metric: It measures the model’s trade-off between success for two classes.

Only accuracy is truly a multiclass metric!
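
In sklearn, this shows up in the arguments the metric functions require. A rough sketch, where clf, X_test, and y_test are placeholder names for a fitted multiclass model and held-out data (not objects created above):

from sklearn.metrics import accuracy_score, roc_auc_score

# clf, X_test, and y_test are assumed to exist from an earlier train/test split
y_pred = clf.predict(X_test)
y_proba = clf.predict_proba(X_test)  # one column of probabilities per class

accuracy_score(y_test, y_pred)                      # needs no notion of a "target class"
roc_auc_score(y_test, y_proba, multi_class="ovo")   # average AUC over every pair of classes
roc_auc_score(y_test, y_proba, multi_class="ovr")   # average AUC of each class vs. the rest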

Check In

If you randomly guess categories in a two-class setting by flipping a coin, how often do you expect to be right?

If you randomly guess categories in a six-class setting by rolling a die, how often do you expect to be right?

What does this tell you about what you should consider a “good” accuracy for a model to achieve in multiclass settings?

18.4.1 Macro and micro

So, if we want to use a metric besides accuracy to measure our model’s success, what should we do? There are three options (a short sklearn sketch follows this list):

  1. We look at the micro version of the metric: we choose one category that is most important to us to be the target category, and then we measure that. Realistically, we only really report micro metrics to summarize how well we can predict each individual category. We don’t use them to select between models - because if our definition of “best” model is just the one that pulls out the target category, why are we bothering with multiclass in the first place?

  2. We look at the macro version of the metric: the average of the micro versions across all the possible categories. This is the most common approach; you will often see classification models measured by f1_macro.

  3. We look at a weighted average of the micro metrics. This might be useful if there is one category that matters more, but we still care about all the categories. (Such as in the cp variable, where we care most about distinguishing 0 from the rest, but we still want to separate 1-3.)
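
Here is a rough sketch of how these options map onto sklearn’s f1_score, where y_test and y_pred are placeholder names for the true and predicted labels. One caution: sklearn reserves average='micro' for a different calculation that pools all the classes’ counts together, so the per-category scores in option 1 come from average=None, and an importance-weighted average as in option 3 is easiest to compute by hand from those per-category scores:

import numpy as np
from sklearn.metrics import f1_score

# y_test and y_pred are placeholders for the true and predicted multiclass labels
per_class = f1_score(y_test, y_pred, average=None)    # option 1: one F1 per category

f1_score(y_test, y_pred, average="macro")             # option 2: plain average of the per-category F1s (f1_macro)

np.average(per_class, weights=[0.5, 0.2, 0.2, 0.1])   # option 3: your own importance weights (hypothetical values for 4 categories)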

18.5 Conclusion

There are many reasons why it’s important for a data scientist to understand the intuition and motivation behind the models they use, even if the computation and math are taken care of by the software.

Multiclass classification is a great example of this principle. What if we had just chucked some multiclass data into all our classification models: KNN, Trees, Logistic, LDA, QDA, SVC, and SVM? Some models would be fine, while others would handle the multiclass problem very differently than they handle binary settings - and this could lead to bad model fits, or worse, incorrect interpretations of the results!