library(tidyverse)
## ── Attaching packages ──────────────────
## ✔ ggplot2 3.2.1     ✔ purrr   0.3.2
## ✔ tibble  2.1.3     ✔ dplyr   0.8.3
## ✔ tidyr   1.0.0     ✔ stringr 1.4.0
## ✔ readr   1.3.1     ✔ forcats 0.4.0
## ── Conflicts ── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(knitr)
library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
library(rpart)
options(scipen = 4)

Prediction vs. Inference

Unless you’ve already taken a class on data mining or machine learning, a lot of the analytics tasks you’ve undertaken have probably taken the form of “inference” problems: estimating the parameters of a model and drawing conclusions about which inputs have statistically significant associations with the outcome.

These are all what Mullainathan and Spiess call “\(\hat \beta\)” problems (“beta-hat problems”). Essentially, these are all problems that begin with you writing down a model

\[ y = X\beta + \epsilon, \] estimating \(\beta\), and making some conclusions about the world based on those estimates and corresponding statistical significance analyses.

Prediction problems are different. When we’re doing prediction, we aren’t interested in \(\beta\). Instead, we’re interested in being able to accurately predict \(y\) from information \(x\). These are what M&S call “\(\hat y\)” problems (“y-hat problems”). Prediction is a very useful paradigm to know about for a number of reasons. First, prediction problems are ubiquitous, even in policy settings. Second, prediction is much easier than inference. Whereas inference often relies on various assumptions holding, prediction is largely assumption-free.
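As a minimal sketch of the distinction (using the built-in mtcars data purely for illustration): an inference analysis looks at the estimated coefficients and their significance, while a prediction analysis cares only about the fitted or predicted outcomes.

fit <- lm(mpg ~ wt + hp, data = mtcars)
coef(summary(fit))   # the "beta-hat" view: estimates, standard errors, p-values
head(predict(fit))   # the "y-hat" view: predicted outcomes for each observation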

These notes introduce you to prediction in its most common form—binary classification—and teach you the basics of training and evaluating different classifiers.

Classification

Many of the problems you’ll come across in the future will likely be classification problems (and if they’re not, you can often turn them into classification problems). These are problems where your outcome variable \(y\) is binary or categorical. E.g., \(y\) might be the indicator that an email is spam, that a transaction is fraudulent, or that a student graduates from college. There are certainly notable cases where a good prediction of a quantitative outcome is desired (e.g., stock prices, home values, crop yields), but “most” problems do wind up being ones of classification.

Let’s get started. We’ll use a marketing data set where we have observations on whether bank customers who were contacted by the bank’s sales team opened a term deposit (“subscribed”). Let’s start by loading the data.

marketing <- read_delim("http://www.andrew.cmu.edu/user/achoulde/94842/data/bank-full.csv", 
                        delim = ";")
## Parsed with column specification:
## cols(
##   age = col_double(),
##   job = col_character(),
##   marital = col_character(),
##   education = col_character(),
##   default = col_character(),
##   balance = col_double(),
##   housing = col_character(),
##   loan = col_character(),
##   contact = col_character(),
##   day = col_double(),
##   month = col_character(),
##   duration = col_double(),
##   campaign = col_double(),
##   pdays = col_double(),
##   previous = col_double(),
##   poutcome = col_character(),
##   y = col_character()
## )

What does the data contain?

str(marketing)
## Classes 'spec_tbl_df', 'tbl_df', 'tbl' and 'data.frame': 45211 obs. of  17 variables:
##  $ age      : num  58 44 33 47 33 35 28 42 58 43 ...
##  $ job      : chr  "management" "technician" "entrepreneur" "blue-collar" ...
##  $ marital  : chr  "married" "single" "married" "married" ...
##  $ education: chr  "tertiary" "secondary" "secondary" "unknown" ...
##  $ default  : chr  "no" "no" "no" "no" ...
##  $ balance  : num  2143 29 2 1506 1 ...
##  $ housing  : chr  "yes" "yes" "yes" "yes" ...
##  $ loan     : chr  "no" "no" "yes" "no" ...
##  $ contact  : chr  "unknown" "unknown" "unknown" "unknown" ...
##  $ day      : num  5 5 5 5 5 5 5 5 5 5 ...
##  $ month    : chr  "may" "may" "may" "may" ...
##  $ duration : num  261 151 76 92 198 139 217 380 50 55 ...
##  $ campaign : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ pdays    : num  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 ...
##  $ previous : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ poutcome : chr  "unknown" "unknown" "unknown" "unknown" ...
##  $ y        : chr  "no" "no" "no" "no" ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   age = col_double(),
##   ..   job = col_character(),
##   ..   marital = col_character(),
##   ..   education = col_character(),
##   ..   default = col_character(),
##   ..   balance = col_double(),
##   ..   housing = col_character(),
##   ..   loan = col_character(),
##   ..   contact = col_character(),
##   ..   day = col_double(),
##   ..   month = col_character(),
##   ..   duration = col_double(),
##   ..   campaign = col_double(),
##   ..   pdays = col_double(),
##   ..   previous = col_double(),
##   ..   poutcome = col_character(),
##   ..   y = col_character()
##   .. )
marketing <- marketing %>%
  mutate(y = as.numeric(y == "yes"))

Our outcome variable here is y, whether or not a person opens an account. You’ll see above that we transformed the original yes/no y to an indicator that a person subscribes.
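Before doing anything else, it’s worth seeing how lopsided the outcome is; this will matter later when we think about baseline accuracy. A quick tabulation:

# How many customers subscribed vs. did not?
table(marketing$y)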

Classification as a \(\hat \beta\) problem: inference with logistic regression

You’re already familiar with linear regression, and you could certainly regress y on the other variables in the data using linear regression to construct a model. That approach turns out to be ill-suited to the analysis of binary outcome data. Among other things, when \(y\) is either \(0\) or \(1\), it is odd to fit a line to predict this type of outcome. For one, the linear model generally won’t be constrained to predict values between 0 and 1; it can produce negative values or values greater than 1. Linear increases on the probability scale are also unlikely to be a good description of the association between a given input \(x\) and the outcome \(y\).

The standard approach to modeling binary outcome data in regression is to use logistic regression. This is an example of a generalized linear model (GLM). GLMs generalize what you know about linear regression to outcome variable types that aren’t well modeled by the “gaussian” linear model \(y = X\beta + \epsilon\). Binary outcomes \(y\) clearly aren’t generated by this sort of process.

The setup for logistic regression is essentially as follows. The observed outcome \(y\) for an observation with features \(x\) is thought to come from a Bernoulli\((p(x))\) random variable, where the success probability \(p(x)\) is parameterized by

\[ \log\left( \frac{p}{1 - p} \right) = \beta_0 + \beta_1 x_1 + \dots + \beta_p x_p \] Essentially this is saying that instead of modeling \(y\) (or, technically, \(E(y | x)\)) as a linear function of \(x\), we’ll model \(y\) as a Bernoulli realization where the success probability \(p\) is a function of \(x\). The “linear” part of the generalized linear model is what’s being illustrated in the above expression: A transformation of \(p\) is being modeled as a linear function of \(x\). You can compare this with the standard linear regression model, which says that \(y \sim N(\mu(x), \sigma^2)\), where the mean \(\mu = E(Y \mid x)\) is a linear function of \(x\):

\[ \mu = \beta_0 + \beta_1 x_1 + \dots + \beta_p x_p \]
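The transformation on the left-hand side of the logistic model is the log-odds (logit). Inverting it maps any value of the linear predictor back into a probability between 0 and 1, which is a quick way to see why the logistic model avoids the out-of-range predictions discussed above. A minimal sketch:

# Inverse logit: maps the linear predictor back to a probability in (0, 1)
inv.logit <- function(eta) 1 / (1 + exp(-eta))
inv.logit(c(-5, 0, 5))   # roughly 0.007, 0.5, 0.993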

So let’s fit a logistic regression model and look at what we find.

marketing.glm <- glm(y ~ ., data = marketing, family = binomial())
summary(marketing.glm)
## 
## Call:
## glm(formula = y ~ ., family = binomial(), data = marketing)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -5.7286  -0.3744  -0.2530  -0.1502   3.4288  
## 
## Coefficients:
##                        Estimate   Std. Error z value Pr(>|z|)    
## (Intercept)        -2.535637780  0.183703164 -13.803  < 2e-16 ***
## age                 0.000112719  0.002205165   0.051 0.959233    
## jobblue-collar     -0.309872593  0.072669201  -4.264 2.01e-05 ***
## jobentrepreneur    -0.357103762  0.125564459  -2.844 0.004455 ** 
## jobhousemaid       -0.504001652  0.136469021  -3.693 0.000221 ***
## jobmanagement      -0.165278440  0.073292526  -2.255 0.024130 *  
## jobretired          0.252362639  0.097217516   2.596 0.009436 ** 
## jobself-employed   -0.298336079  0.111996400  -2.664 0.007726 ** 
## jobservices        -0.223797106  0.084064904  -2.662 0.007763 ** 
## jobstudent          0.382135715  0.109029897   3.505 0.000457 ***
## jobtechnician      -0.176016548  0.068931178  -2.554 0.010664 *  
## jobunemployed      -0.176713126  0.111642461  -1.583 0.113456    
## jobunknown         -0.313264379  0.233463307  -1.342 0.179656    
## maritalmarried     -0.179453495  0.058910580  -3.046 0.002318 ** 
## maritalsingle       0.092497647  0.067260667   1.375 0.169066    
## educationsecondary  0.183528258  0.064792557   2.833 0.004618 ** 
## educationtertiary   0.378941502  0.075319068   5.031 4.88e-07 ***
## educationunknown    0.250478833  0.103896567   2.411 0.015915 *  
## defaultyes         -0.016681215  0.162837013  -0.102 0.918407    
## balance             0.000012835  0.000005148   2.493 0.012651 *  
## housingyes         -0.675384337  0.043869060 -15.395  < 2e-16 ***
## loanyes            -0.425371663  0.059989904  -7.091 1.33e-12 ***
## contacttelephone   -0.163374330  0.075185612  -2.173 0.029784 *  
## contactunknown     -1.623216856  0.073171806 -22.184  < 2e-16 ***
## day                 0.009968922  0.002496619   3.993 6.53e-05 ***
## monthaug           -0.693907553  0.078474461  -8.842  < 2e-16 ***
## monthdec            0.691124324  0.176682753   3.912 9.17e-05 ***
## monthfeb           -0.147320938  0.089413545  -1.648 0.099427 .  
## monthjan           -1.261718795  0.121702801 -10.367  < 2e-16 ***
## monthjul           -0.830795589  0.077404978 -10.733  < 2e-16 ***
## monthjun            0.453622601  0.093669266   4.843 1.28e-06 ***
## monthmar            1.589890543  0.119853742  13.265  < 2e-16 ***
## monthmay           -0.399111424  0.072285121  -5.521 3.36e-08 ***
## monthnov           -0.873398521  0.084409802 -10.347  < 2e-16 ***
## monthoct            0.881437433  0.108030525   8.159 3.37e-16 ***
## monthsep            0.874058052  0.119497320   7.314 2.58e-13 ***
## duration            0.004193695  0.000064532  64.986  < 2e-16 ***
## campaign           -0.090781782  0.010137033  -8.955  < 2e-16 ***
## pdays              -0.000102685  0.000306089  -0.335 0.737268    
## previous            0.010152353  0.006502908   1.561 0.118476    
## poutcomeother       0.203478400  0.089855382   2.265 0.023543 *  
## poutcomesuccess     2.291056017  0.082348964  27.821  < 2e-16 ***
## poutcomeunknown    -0.091793506  0.093474710  -0.982 0.326093    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 32631  on 45210  degrees of freedom
## Residual deviance: 21562  on 45168  degrees of freedom
## AIC: 21648
## 
## Number of Fisher Scoring iterations: 6
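The coefficients are on the log-odds scale. One common way to read them is to exponentiate, which turns them into multiplicative effects on the odds of subscribing; for example, the poutcomesuccess coefficient of about 2.29 corresponds to roughly a 9.9-fold increase in the odds, holding the other variables fixed. A quick check:

# Exponentiated coefficients are odds ratios (multiplicative effects on the odds)
exp(coef(marketing.glm))[c("poutcomesuccess", "housingyes")]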

Classification as a \(\hat y\) problem: Try everything and see what sticks

First step: Set aside a test set

# Initialize the seed for random number generation
set.seed(12345)

# Upsample the data to artificially overcome class imbalance
marketing.more.idx <- sample(which(marketing$y == 1), 15000, replace = TRUE)
marketing.upsample <- rbind(marketing,
                            marketing[marketing.more.idx, ])

# Trim job strings to 5 characters
# marketing.upsample <- transform(marketing.upsample, job = strtrim(job, 5))

# Randomly select 20% of the data to be held out for model validation
test.indexes <- sample(1:nrow(marketing.upsample), 
                       round(0.2 * nrow(marketing.upsample)))
train.indexes <- setdiff(1:nrow(marketing.upsample), test.indexes)

# Just pull the covariates available to marketers (cols 1:8) and the outcome (col 17)
marketing.train <- marketing.upsample[train.indexes, c(1:8, 17)]
marketing.test <- marketing.upsample[test.indexes, c(1:8, 17)]

When building models it is important that we hold out a subset of our data, typically called a “test set”, a “validation set”, or a “holdout set”. In this example we’re holding out a random 20% of our data. The purpose of this test set is to ensure that we get reasonable estimates of the prediction accuracy of our model even if we make mistakes during the “training” process that result in “overfitting”.

When you have a large number of covariates, it’s easy to overfit the data. When you overfit the training data, you get a model that describes the training data really well, but which doesn’t give good predictions on unseen data.
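A quick way to see overfitting in action (an illustrative sketch, separate from the analysis below) is to grow a deliberately over-complex tree and compare its training accuracy to its test accuracy; the gap between the two is the signature of overfitting.

# Illustrative only: an over-grown tree with a tiny complexity penalty
overfit.tree <- rpart(as.factor(y) ~ ., data = marketing.train,
                      control = rpart.control(minsplit = 2, cp = 0.0001))
# Training accuracy vs. test accuracy
train.pred <- predict(overfit.tree, type = "class")
test.pred  <- predict(overfit.tree, newdata = marketing.test, type = "class")
c(train = mean(train.pred == marketing.train$y),
  test  = mean(test.pred == marketing.test$y))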

[Figure: Overfitting. Source: http://pingax.com/regularization-implementation-r/]


Second step: “Train” models

library(glmnet)  # Regularized regression
## Loading required package: Matrix
## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## Loaded glmnet 3.0-1
library(ranger)  # random forests

We’ll start by fitting a logistic regression model to the training data.

marketing.glm <- glm(y ~ ., data = marketing.train, family = binomial())
pred.test.glm <- as.numeric(predict(marketing.glm, newdata = marketing.test, type = "response") > 0.5)

The code above fits a logistic regression model to the training data, and then gets predicted probabilities for the test data. Comparing those probabilities to 0.5 thresholds them to form 0/1 predictions of whether each person subscribes.

Evaluate predictive performance

# Confusion matrix for logistic regression
conf.glm <- table(marketing.test$y, pred.test.glm)
conf.glm
##    pred.test.glm
##        0    1
##   0 7302  724
##   1 3005 1011
# How accurate is our model?
sum(diag(conf.glm)) / sum(conf.glm)
## [1] 0.6903338

That’s way better than 50%! But… is 50% accuracy really the baseline we want? You often hear that something is “better than a coin flip” or “no better than a coin flip”. Is a fair coin flip really the right baseline? Generally, no. Let’s look at what fraction of our training data are actually subscribers.

mean(marketing.train$y)
## [1] 0.3378314

Hmm… So if we guessed that no one subscribes, our accuracy would already be 0.6621686. That makes our accuracy of 0.6903338 a lot less impressive by comparison.
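For a direct comparison with the test-set accuracy above, here’s the accuracy of the “predict that no one subscribes” rule evaluated on the test set itself:

# Accuracy on the test set of always predicting "no subscription"
mean(marketing.test$y == 0)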

Now let’s fit some other models

Regularized logistic regression (Lasso)

###  Regularized logistic regression, with parameters tuned through cross-validation
# Extract y column
y.marketing <- marketing.train$y
# Get a numeric design matrix x
x.marketing <- model.matrix(~ . - y - 1, data = marketing.train)
x.marketing.test <- model.matrix(~ . - y - 1, data = marketing.test)
# Run cross-validated regularized regression
marketing.cv.glmnet <- cv.glmnet(x.marketing, y.marketing, family = "binomial")
# Have a look at the cv error plot
plot(marketing.cv.glmnet)

Let’s get our predictions for the test data

# Extract predictions from model selected by the 1se rule (simplest model within 1 standard error from the minimum)
pred.test.glmnet <- predict(marketing.cv.glmnet, x.marketing.test, s = "lambda.1se", type = "class")

# Confusion matrix for regularized logistic regression
conf.glmnet <- table(marketing.test$y, pred.test.glmnet)
conf.glmnet
##    pred.test.glmnet
##        0    1
##   0 7570  456
##   1 3267  749

How did we do?

sum(diag(conf.glmnet)) / sum(conf.glmnet)
## [1] 0.6908321

Well… that wasn’t any better…
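One small thing worth checking before moving on (a quick sketch; it may or may not change anything) is whether the less heavily penalized lambda.min fit does better than the 1se fit we used above:

# Predictions from the lambda that minimizes cross-validated deviance
pred.test.glmnet.min <- predict(marketing.cv.glmnet, x.marketing.test,
                                s = "lambda.min", type = "class")
mean(as.numeric(pred.test.glmnet.min) == marketing.test$y)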

Tree model (with rpart)

library(partykit)
## Loading required package: grid
## Loading required package: libcoin
## Loading required package: mvtnorm
marketing.tree <- rpart(as.factor(y) ~ ., data = marketing.train, 
                        control = rpart.control(minsplit=50, cp=0.002))
marketing.party <- as.party(marketing.tree)
plot(marketing.party, gp = gpar(fontsize = 10))

pred.test.tree <- as.numeric(predict(marketing.tree, newdata = marketing.test)[,"1"] > 0.5)
# Confusion matrix for tree model
conf.tree <- table(marketing.test$y, pred.test.tree)
conf.tree
##    pred.test.tree
##        0    1
##   0 7400  626
##   1 2883 1133

How did we do?

sum(diag(conf.tree)) / sum(conf.tree)
## [1] 0.7086032

That’s a little better.

Random forest (with ranger)

marketing.rf <- ranger(y ~ ., data = marketing.train, importance = 'impurity')
pred.test.rf <- as.numeric(predict(marketing.rf, data = marketing.test)$predictions > 0.5)
# Confusion matrix for random forest model
conf.rf <- table(marketing.test$y, pred.test.rf)
conf.rf
##    pred.test.rf
##        0    1
##   0 7450  576
##   1 2275 1741

How did we do?

sum(diag(conf.rf)) / sum(conf.rf)
## [1] 0.7632453

Way better!

But is overall accuracy really what we care about? How will we use this model in the future? Presumably we’ll be using the model to help guide a new marketing campaign. In that case our task will be to select a subset of new customers who we should contact, instead of contacting everyone. How do we think about our model’s performance in that type of setting?
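One simple way to get at that question, sketched here with the random-forest predictions we already have, is to rank customers by predicted probability, pretend we can only afford to contact the top 1,000, and look at the subscription rate among those contacts (the 1,000 cutoff is arbitrary and purely illustrative):

# Rank test-set customers by the random forest's predicted probability of subscribing
rf.probs <- predict(marketing.rf, data = marketing.test)$predictions
top.1000 <- order(rf.probs, decreasing = TRUE)[1:1000]
# Subscription rate among the 1,000 customers we would have contacted
mean(marketing.test$y[top.1000])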

Here’s a function that calculates a bunch of classification metrics from a model’s confusion table. We’ll compute these metrics for each of our models.

classSummary <- function(tbl) {
  # tbl: a 2x2 confusion table with true classes (0, 1) in rows
  # and predicted classes (0, 1) in columns
  n <- sum(tbl)                     # total number of observations
  prev <- sum(tbl[2,]) / sum(tbl)   # prevalence: fraction of true positives in the data
  acc <- sum(diag(tbl)) / n         # overall accuracy
  prop.pos <- sum(tbl[,2]) / n      # fraction of cases predicted positive
  ppv <- tbl[2,2] / sum(tbl[,2])    # positive predictive value (precision)
  fpr <- tbl[1,2] / sum(tbl[1,])    # false positive rate
  fnr <- tbl[2,1] / sum(tbl[2,])    # false negative rate
  spec <- 1 - fpr                   # specificity (true negative rate)
  sens <- 1 - fnr                   # sensitivity (true positive rate)
  lr.pos <- sens / fpr              # positive likelihood ratio: TPR / FPR
  lr.neg <- fnr / spec              # negative likelihood ratio: FNR / TNR
  out <- data.frame(value = round(c(n, prev, acc, 
                                    prop.pos,
                                    ppv, fpr, fnr, spec, sens,
                                    lr.pos, lr.neg), 3))
  rownames(out) <- c("count",
                     "prevalence",
                     "accuracy",
                     "prop.positive",
                     "PPV",
                     "FPR",
                     "FNR",
                     "Specificity (TNR)",
                     "Sensitivity (TPR)",
                     "LR+",
                     "LR-")
  out
}
classSummary(conf.glm)
##                       value
## count             12042.000
## prevalence            0.333
## accuracy              0.690
## prop.positive         0.144
## PPV                   0.583
## FPR                   0.090
## FNR                   0.748
## Specificity (TNR)     0.910
## Sensitivity (TPR)     0.252
## LR+                   2.791
## LR-                   0.822
classSummary(conf.glmnet)
##                       value
## count             12042.000
## prevalence            0.333
## accuracy              0.691
## prop.positive         0.100
## PPV                   0.622
## FPR                   0.057
## FNR                   0.813
## Specificity (TNR)     0.943
## Sensitivity (TPR)     0.187
## LR+                   3.283
## LR-                   0.862
classSummary(conf.tree)
##                       value
## count             12042.000
## prevalence            0.333
## accuracy              0.709
## prop.positive         0.146
## PPV                   0.644
## FPR                   0.078
## FNR                   0.718
## Specificity (TNR)     0.922
## Sensitivity (TPR)     0.282
## LR+                   3.617
## LR-                   0.779
classSummary(conf.rf)
##                       value
## count             12042.000
## prevalence            0.333
## accuracy              0.763
## prop.positive         0.192
## PPV                   0.751
## FPR                   0.072
## FNR                   0.566
## Specificity (TNR)     0.928
## Sensitivity (TPR)     0.434
## LR+                   6.041
## LR-                   0.610

Let’s bind those together to make them easier to compare

tibble(metric = rownames(classSummary(conf.glm)),
       logistic = classSummary(conf.glm)$value,
       lasso = classSummary(conf.glmnet)$value,
       tree = classSummary(conf.tree)$value,
       rf = classSummary(conf.rf)$value)
## # A tibble: 11 x 5
##    metric             logistic     lasso      tree        rf
##    <chr>                 <dbl>     <dbl>     <dbl>     <dbl>
##  1 count             12042     12042     12042     12042    
##  2 prevalence            0.333     0.333     0.333     0.333
##  3 accuracy              0.69      0.691     0.709     0.763
##  4 prop.positive         0.144     0.1       0.146     0.192
##  5 PPV                   0.583     0.622     0.644     0.751
##  6 FPR                   0.09      0.057     0.078     0.072
##  7 FNR                   0.748     0.813     0.718     0.566
##  8 Specificity (TNR)     0.91      0.943     0.922     0.928
##  9 Sensitivity (TPR)     0.252     0.187     0.282     0.434
## 10 LR+                   2.79      3.28      3.62      6.04 
## 11 LR-                   0.822     0.862     0.779     0.61
marketing.preds <- tibble(glm = predict(marketing.glm, newdata = marketing.test, type = "response"),
                       lasso = predict(marketing.cv.glmnet, x.marketing.test, s = "lambda.min", type = "response")[,1],
                       tree = predict(marketing.tree, newdata = marketing.test)[,"1"],
                       rf = predict(marketing.rf, data = marketing.test)$predictions,
                       y = marketing.test$y)

Let’s look at ROC curves and the AUC. ROC curves trace out the TPR on the y axis and the FPR on the x axis as we vary the threshold used for classification.
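To make that concrete, here is the single point on the random forest’s ROC curve corresponding to one arbitrary, illustrative threshold of 0.3:

# TPR and FPR for the random forest at a threshold of 0.3
pred.03 <- as.numeric(marketing.preds$rf > 0.3)
c(TPR = mean(pred.03[marketing.preds$y == 1]),
  FPR = mean(pred.03[marketing.preds$y == 0]))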

roc.list <- with(marketing.preds, roc(y ~ glm + lasso + tree + rf))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
plot(roc.list[[1]])
plot(roc.list[[2]], col = "red", add = TRUE)
plot(roc.list[[3]], col = "purple", add = TRUE)
plot(roc.list[[4]], col = "steelblue", add = TRUE)

Let’s calculate the AUCs for these (the areas under the curve). The AUC summarizes an entire ROC curve in a single number: 0.5 corresponds to random guessing, and 1 corresponds to a classifier that ranks every subscriber above every non-subscriber.

with(marketing.preds, auc(y ~ glm))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Area under the curve: 0.665
with(marketing.preds, auc(y ~ lasso))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Area under the curve: 0.665
with(marketing.preds, auc(y ~ rf))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Area under the curve: 0.847
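The same call works for the tree model, whose curve we plotted above:

with(marketing.preds, auc(y ~ tree))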

OK… but what variables are important? We don’t have p-values or coefficient estimates, but random forests do provide “importance” measures that tell us how much each variable contributes to the model’s predictions.

sort(marketing.rf$variable.importance)
##   default      loan   marital education       job   housing       age 
##  21.16536 119.14167 152.43088 175.14662 305.50842 357.57898 791.81364 
##   balance 
## 876.18534
library(edarf)
pd <- partial_dependence(marketing.rf, 
                        data = marketing.test,
                        vars = c("balance"))
plot_pd(pd)

One of the reasons the logistic models might not be performing as well is that variables like balance appear to have non-linear relationships with the outcome. There looks to be a sharp discontinuity in the relationship between the outcome and balance, as modeled by the random forest.
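A hedged way to probe this with the logistic model (an illustrative sketch; the breakpoints below are arbitrary) is to replace the linear balance term with a coarse binned version and see whether test accuracy moves:

# Replace the linear balance term with arbitrary, illustrative bins
marketing.glm.bins <- glm(y ~ . - balance + cut(balance, c(-Inf, 0, 500, 2000, Inf)),
                          data = marketing.train, family = binomial())
pred.test.bins <- as.numeric(predict(marketing.glm.bins, newdata = marketing.test,
                                     type = "response") > 0.5)
mean(pred.test.bins == marketing.test$y)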