
Coding Warmup 4

Make a classification model and run evaluations.

Part A

We are going to use a toy dataset called bivariate. Training, validation, and testing sets are all provided.

library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.2 --
## v ggplot2 3.4.0      v purrr   0.3.5 
## v tibble  3.1.8      v dplyr   1.0.10
## v tidyr   1.2.1      v stringr 1.5.0 
## v readr   2.1.3      v forcats 0.5.2
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(tidymodels)
## -- Attaching packages -------------------------------------- tidymodels 1.0.0 --
## v broom        1.0.1     v rsample      1.1.1
## v dials        1.1.0     v tune         1.0.1
## v infer        1.0.4     v workflows    1.1.2
## v modeldata    1.0.1     v workflowsets 1.0.0
## v parsnip      1.0.3     v yardstick    1.1.0
## v recipes      1.0.3
## -- Conflicts ----------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
## * Use suppressPackageStartupMessages() to eliminate package startup messages
theme_set(theme_bw())

data(bivariate)

ggplot(bivariate_train, aes(x=A, y=B, color=Class)) +
  geom_point(alpha=.3)
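
The data(bivariate) call above loads the three data frames used throughout: bivariate_train, bivariate_val, and bivariate_test. A quick sanity check of their sizes before modeling (not required for the exercise, just a good habit):

# confirm the three splits loaded by data(bivariate)
dim(bivariate_train)
dim(bivariate_val)
dim(bivariate_test)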

Use logistic_reg and the glm engine to make a classification model of Class ~ A * B. Then use tidy and glance to see some summary information about the model. Does anything stand out to you?

# logistic regression with the glm engine, fit to the training data
log_model <- logistic_reg() %>%
  set_engine('glm') %>%
  set_mode('classification') %>%
  fit(Class ~ A * B,
      data = bivariate_train)

log_model %>% tidy()
## # A tibble: 4 x 5
##   term          estimate  std.error statistic  p.value
##   <chr>            <dbl>      <dbl>     <dbl>    <dbl>
## 1 (Intercept)  0.115     0.404          0.284 7.76e- 1
## 2 A            0.00433   0.000434       9.97  2.01e-23
## 3 B           -0.0553    0.00633       -8.74  2.32e-18
## 4 A:B         -0.0000101 0.00000222    -4.56  5.04e- 6
log_model %>% glance()
## # A tibble: 1 x 8
##   null.deviance df.null logLik   AIC   BIC deviance df.residual  nobs
##           <dbl>   <int>  <dbl> <dbl> <dbl>    <dbl>       <int> <int>
## 1         1329.    1008  -549. 1106. 1126.    1098.        1005  1009
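
As for what stands out: every term in the tidy output, including the A:B interaction, has a small p-value, and glance shows the deviance falling from 1329 (null) to 1098. As a rough overall check, you can turn that deviance drop into a likelihood-ratio test against the null model; a sketch using the glance output:

g <- glance(log_model)

# chi-squared test of the fitted model against the null model:
# deviance reduction on (df.null - df.residual) degrees of freedom
with(g, pchisq(null.deviance - deviance,
               df = df.null - df.residual,
               lower.tail = FALSE))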

Part B

Use augment to get predictions on the test set, then take a look at them.

test_preds <- log_model %>% augment(bivariate_test)

test_preds 
## # A tibble: 710 x 6
##        A     B Class .pred_class .pred_One .pred_Two
##    <dbl> <dbl> <fct> <fct>           <dbl>     <dbl>
##  1  742.  68.8 One   One           0.730      0.270 
##  2  709.  50.4 Two   Two           0.491      0.509 
##  3 1006.  89.9 One   One           0.805      0.195 
##  4 1983. 112.  Two   Two           0.431      0.569 
##  5 1698.  81.0 Two   Two           0.169      0.831 
##  6  948.  98.9 One   One           0.900      0.0996
##  7  751.  54.8 One   One           0.521      0.479 
##  8 1254.  72.2 Two   Two           0.347      0.653 
##  9 4243. 136.  One   Two           0.00568    0.994 
## 10  713.  88.2 One   One           0.910      0.0898
## # ... with 700 more rows
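
Note that .pred_class is just the class probabilities cut at 0.5, the default for a two-class parsnip model. A small sketch to verify that (manual_class is a throwaway name):

test_preds %>%
  mutate(manual_class = factor(if_else(.pred_One >= 0.5, "One", "Two"),
                               levels = levels(Class))) %>%
  summarise(agreement = mean(manual_class == .pred_class))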

Part C

Visually inspect the predictions using the code below.

# log_model: your parsnip model
# bivariate_train / bivariate_val: data from bivariate

# to plot the contour we need a grid of points and the model prediction at each point
x_grid <-
  expand.grid(A = seq(min(bivariate_train$A), max(bivariate_train$A), length.out = 100),
              B = seq(min(bivariate_train$B), max(bivariate_train$B), length.out = 100))
x_grid_preds <- log_model %>% augment(x_grid)

# plot the grid predictions as a contour and overlay the validation data
ggplot(x_grid_preds, aes(x = A, y = B)) + 
  geom_contour(aes(z = .pred_One), breaks = .5, col = "black") + 
  geom_point(data = bivariate_val, aes(col = Class), alpha = 0.3)
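
The contour only marks where the predicted probability of One crosses 0.5. To see the full probability surface rather than just the decision boundary, you can fill the same grid by .pred_One; a sketch reusing x_grid_preds:

ggplot(x_grid_preds, aes(x = A, y = B)) +
  geom_raster(aes(fill = .pred_One), alpha = 0.5) +
  geom_contour(aes(z = .pred_One), breaks = .5, col = "black") +
  geom_point(data = bivariate_val, aes(col = Class), alpha = 0.3) +
  scale_fill_viridis_c(limits = c(0, 1))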

Part D

Evaluate your model using the following functions (which dataset(s) should you use to do this: train, test, or validation?). See if you can provide a basic interpretation of each measure.

  • roc_auc
  • accuracy
  • roc_curve and autoplot
  • f_meas

val_preds <- log_model %>% augment(bivariate_val)

val_preds %>% roc_auc(truth = Class,
                      .pred_One)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.790

val_preds %>% accuracy(truth = Class,
                       estimate = .pred_class)
## # A tibble: 1 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary          0.76


roc_curve(val_preds,
          truth = Class,
          .pred_One) %>%
  autoplot()


f_meas(val_preds,
       truth = Class,
       estimate = .pred_class)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 f_meas  binary         0.827
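
Rather than calling each metric separately, yardstick can bundle them with metric_set. Class metrics read estimate = .pred_class, while probability metrics pick up the probability column passed through the dots; a sketch (classification_metrics is just an illustrative name):

# bundle class metrics and a probability metric into one function
classification_metrics <- metric_set(roc_auc, accuracy, f_meas)

val_preds %>%
  classification_metrics(truth = Class,
                         estimate = .pred_class,
                         .pred_One)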

Part E

Recall Table 8.4 from the textbook. Treat class One as the positive class and class Two as the negative class. Using the output from conf_mat, visually verify that you know how to calculate the following:

  • True Positive Rate (TPR), Sensitivity, or Recall
  • True Negative Rate (TNR) or Specificity
  • False Positive Rate, Type I error
  • False Negative Rate (FNR), Type II error
  • Positive Predictive Value (PPV) or Precision

val_preds %>% conf_mat(truth = Class,
                       estimate = .pred_class) %>%
  autoplot("heatmap")
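
To check your hand calculations, summary() on a conf_mat object computes most of these directly: sens is the TPR, spec is the TNR, and ppv is precision, while FPR and FNR are 1 - spec and 1 - sens. Note that yardstick treats the first factor level, One, as the positive class by default.

val_preds %>%
  conf_mat(truth = Class,
           estimate = .pred_class) %>%
  summary()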