Class imbalance and classification metrics with aircraft wildlife strikes

By Julia Silge in rstats tidymodels

June 21, 2021

This is the latest in my series of screencasts demonstrating how to use the tidymodels packages, from just starting out to tuning more complex models with many hyperparameters. I recently participated in SLICED, a competitive data science prediction challenge. I did not necessarily cover myself in glory, but in today’s screencast, I walk through the aircraft wildlife strikes data set we used and show how different choices around handling class imbalance affect different classification metrics. ✈️


Here is the code I used in the video, for those who prefer reading instead of or in addition to video.

Explore data

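Our modeling goal is to predict whether an aircraft strike with wildlife resulted in damage to the aircraft. There are two data sets provided: training (which has the label damaged) and testing (which does not). In the training data, damaged is numeric, so we recode any value greater than zero as "damage" and everything else as "no damage".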

library(tidyverse)

train_raw <- read_csv("train.csv", guess_max = 1e5) %>%
  mutate(damaged = case_when(
    damaged > 0 ~ "damage",
    TRUE ~ "no damage"
  ))
test_raw <- read_csv("test.csv", guess_max = 1e5)

There is a lot of information available in this data!

skimr::skim(train_raw)

Table 1: Data summary

Name                      train_raw
Number of rows            21000
Number of columns         34
Column type frequency:
  character               20
  numeric                 14
Group variables           None

Variable type: character

skim_variable     n_missing complete_rate min max empty n_unique whitespace
operator_id               0          1.00   3   5     0      276          0
operator                  0          1.00   3  33     0      275          0
aircraft                  0          1.00   3  20     0      424          0
aircraft_type          4992          0.76   1   1     0        2          0
aircraft_make          5231          0.75   2   3     0       62          0
engine_model           6334          0.70   1   2     0       39          0
engine_type            5703          0.73   1   3     0        8          0
engine3_position      19671          0.06   1  11     0        4          0
airport_id                0          1.00   3   5     0     1039          0
airport                  34          1.00   4  53     0     1038          0
state                  2664          0.87   2   2     0       60          0
faa_region             2266          0.89   3   3     0       14          0
flight_phase           6728          0.68   4  12     0       12          0
visibility             7699          0.63   3   7     0        5          0
precipitation         10327          0.51   3  15     0        8          0
species_id                0          1.00   1   6     0      447          0
species_name              7          1.00   4  50     0      445          0
species_quantity        532          0.97   1   8     0        4          0
flight_impact          8944          0.57   4  21     0        6          0
damaged                   0          1.00   6   9     0        2          0

Variable type: numeric

skim_variable     n_missing complete_rate      mean       sd    p0      p25      p50       p75  p100 hist
id                        0          1.00  14980.94  8663.24     1  7458.75  14978.5  22472.25 30000 ▇▇▇▇▇
incident_year             0          1.00   2006.06     6.72  1990  2001.00   2007.0   2012.00  2015 ▂▃▅▆▇
incident_month            0          1.00      7.19     2.79     1     5.00      8.0      9.00    12 ▃▅▆▇▆
incident_day              0          1.00     15.63     8.82     1     8.00     15.0     23.00    31 ▇▇▇▇▆
aircraft_model         6259          0.70     24.65    21.70     0    10.00     22.0     37.00    98 ▇▆▂▁▁
aircraft_mass          5694          0.73      3.50     0.89     1     3.00      4.0      4.00     5 ▁▁▂▇▁
engine_make            6155          0.71     21.22    11.04     1    10.00     22.0     34.00    47 ▇▂▆▇▁
engines                5696          0.73      2.05     0.46     1     2.00      2.0      2.00     4 ▁▇▁▁▁
engine1_position       5838          0.72      2.99     2.09     1     1.00      1.0      5.00     7 ▇▁▂▅▁
engine2_position       6776          0.68      2.91     2.01     1     1.00      1.0      5.00     7 ▇▁▂▅▁
engine4_position      20650          0.02      2.02     1.43     1     1.00      1.0      4.00     5 ▇▁▁▃▁
height                 8469          0.60    819.24  1772.53     0     0.00     50.0    800.00 24000 ▇▁▁▁▁
speed                 12358          0.41    141.39    52.25     0   120.00    137.0    160.00  2500 ▇▁▁▁▁
distance               8913          0.58      0.66     3.33     0     0.00      0.0      0.00   100 ▇▁▁▁▁

The data is imbalanced: only about 8.6% of incidents (1,799 of 21,000) resulted in damage.

train_raw %>%
  count(damaged)
## # A tibble: 2 x 2
##   damaged       n
##   <chr>     <int>
## 1 damage     1799
## 2 no damage 19201
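
To put this imbalance in perspective, here is a quick base R calculation using the counts above. A trivial "model" that always predicts "no damage" would already be about 91% accurate. That is one reason accuracy alone is a poor measure for this problem.

```r
# Class counts copied from count(damaged) above
n_damage <- 1799
n_no_damage <- 19201
n_total <- n_damage + n_no_damage

# Share of incidents that resulted in damage
prop_damage <- n_damage / n_total

# Accuracy of a baseline that always predicts "no damage"
baseline_accuracy <- n_no_damage / n_total

round(c(prop_damage = prop_damage, baseline_accuracy = baseline_accuracy), 3)
# prop_damage is about 0.086, baseline_accuracy is about 0.914
```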

For numeric predictors, I often like to make a pairs plot for EDA.

library(GGally)

train_raw %>%
  select(damaged, incident_year, height, speed, distance) %>%
  ggpairs(columns = 2:5, aes(color = damaged, alpha = 0.5))

For categorical predictors, plots like these, which show the proportion of damage within each level, can be useful. Notice especially that NA values look like they may be informative, so we likely don’t want to throw them out.

train_raw %>%
  select(
    damaged, precipitation, visibility, engine_type,
    flight_impact, flight_phase, species_quantity
  ) %>%
  pivot_longer(precipitation:species_quantity) %>%
  ggplot(aes(y = value, fill = damaged)) +
  geom_bar(position = "fill") +
  facet_wrap(vars(name), scales = "free", ncol = 2) +
  labs(x = NULL, y = NULL, fill = NULL)
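
As a toy illustration of why we might keep NA values instead of dropping them, this base R sketch recodes missing visibility as an explicit "unknown" category and then compares damage rates across categories. This is roughly what step_unknown() does later in the recipe. The data here is made up, not taken from the strike data. In this toy example, the missing rows have the highest damage rate, and that information would be lost if we dropped them.

```r
# Toy data, made up for illustration (not from the strike data set)
toy <- data.frame(
  visibility = c("DAY", "DAY", NA, NA, "NIGHT", NA, "DAY", "NIGHT"),
  damaged = c("no damage", "no damage", "damage", "damage",
              "no damage", "no damage", "no damage", "damage"),
  stringsAsFactors = FALSE
)

# Recode missing values as an explicit "unknown" category instead of dropping rows
visibility <- ifelse(is.na(toy$visibility), "unknown", toy$visibility)

# Proportion of damaging strikes within each category
damage_rate <- tapply(toy$damaged == "damage", visibility, mean)
damage_rate
# DAY = 0, NIGHT = 0.5, unknown = 2/3
```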

Let’s use the following variables for this post.

bird_df <- train_raw %>%
  select(
    damaged, flight_impact, precipitation,
    visibility, flight_phase, engines, incident_year,
    incident_month, species_id, engine_type,
    aircraft_model, species_quantity, height, speed
  )

Build a model

If I had enough time to try many models, I would split the provided training data via initial_split(), but I learned that two hours isn’t really enough time for me to try that many models. Let’s just create resampling folds from the provided training data.

library(tidymodels)

set.seed(123)
bird_folds <- vfold_cv(bird_df, v = 5, strata = damaged)
bird_folds
## #  5-fold cross-validation using stratification 
## # A tibble: 5 x 2
##   splits               id   
##   <list>               <chr>
## 1 <split [16800/4200]> Fold1
## 2 <split [16800/4200]> Fold2
## 3 <split [16800/4200]> Fold3
## 4 <split [16800/4200]> Fold4
## 5 <split [16800/4200]> Fold5
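
Stratifying by damaged means each fold keeps roughly the same share of damaging strikes as the full training set. The base R sketch below shows the basic idea under simplifying assumptions: shuffle the rows within each class, then deal them out to five folds in turn. It is not how rsample implements stratification, so the fold sizes differ slightly from the output above.

```r
set.seed(123)
# Outcome vector with the same class counts as the training data
y <- rep(c("damage", "no damage"), times = c(1799, 19201))
v <- 5
fold <- integer(length(y))

for (cls in unique(y)) {
  idx <- which(y == cls)
  idx <- idx[sample.int(length(idx))]             # shuffle rows within this class
  fold[idx] <- rep_len(seq_len(v), length(idx))   # deal them out to folds in turn
}

# Each fold keeps roughly the overall 8.6% damage rate
round(prop.table(table(fold, y), margin = 1), 3)
```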

The SLICED prediction problem was evaluated on a single metric, log loss, so let’s create a metric set with that metric plus a few others for demonstration purposes.

bird_metrics <- metric_set(mn_log_loss, accuracy, sensitivity, specificity)
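
For reference, here is what these metrics compute, written out by hand in base R for four made-up predictions. The sketch treats "damage" as the event of interest. In yardstick, the first factor level is the event by default, and "damage" sorts before "no damage". The sketch also assumes that hard class predictions come from a 0.5 threshold on the predicted probability of damage. Log loss uses the probabilities directly, while accuracy, sensitivity, and specificity use only the hard classes.

```r
# Made-up truth and predicted probabilities of "damage" for four strikes
truth <- c("damage", "no damage", "no damage", "damage")
pred_damage <- c(0.8, 0.1, 0.3, 0.4)

# Log loss: mean negative log of the probability assigned to the true class
p_true <- ifelse(truth == "damage", pred_damage, 1 - pred_damage)
log_loss <- -mean(log(p_true))

# Hard class predictions, assuming a 0.5 threshold on P(damage)
pred_class <- ifelse(pred_damage > 0.5, "damage", "no damage")

# "damage" is treated as the event (positive) class
accuracy    <- mean(pred_class == truth)
sensitivity <- sum(pred_class == "damage" & truth == "damage") / sum(truth == "damage")
specificity <- sum(pred_class == "no damage" & truth == "no damage") / sum(truth == "no damage")

round(c(log_loss = log_loss, accuracy = accuracy,
        sensitivity = sensitivity, specificity = specificity), 3)
# log_loss is about 0.400, accuracy = 0.75, sensitivity = 0.5, specificity = 1
```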

This data requires lots of preprocessing, such as handling new levels in the test set, pooling infrequent factor levels, and imputing or replacing the NA values.

bird_rec <- recipe(damaged ~ ., data = bird_df) %>%
  step_novel(all_nominal_predictors()) %>%
  step_other(all_nominal_predictors(), threshold = 0.01) %>%
  step_unknown(all_nominal_predictors()) %>%
  step_impute_median(all_numeric_predictors()) %>%
  step_zv(all_predictors())

bird_rec
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor         13
## 
## Operations:
## 
## Novel factor level assignment for all_nominal_predictors()
## Collapsing factor levels for all_nominal_predictors()
## Unknown factor level assignment for all_nominal_predictors()
## Median Imputation for all_numeric_predictors()
## Zero variance filter on all_predictors()
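
The pooling and imputation steps both learn something from the training data and then reuse it on new data. The hypothetical helpers below, learn_levels_to_keep() and pool_levels(), are a simplified base R sketch of the ideas behind step_other() and step_impute_median(). They are not the recipes implementation. For brevity, they also fold levels never seen in training into "other", a job that step_novel() handles separately in the real recipe.

```r
# Hypothetical helpers illustrating the ideas behind step_other() and
# step_impute_median(); this is not how recipes implements them.

# "Training" step: learn which levels are frequent enough to keep
learn_levels_to_keep <- function(x, threshold = 0.01) {
  freq <- table(x) / length(x)
  names(freq)[freq >= threshold]
}

# "Prediction" step: pool everything else, including unseen levels, into "other"
pool_levels <- function(x, keep, other = "other") {
  ifelse(x %in% keep, x, other)
}

train_species <- c(rep("gull", 60), rep("hawk", 39), "owl")   # made-up data
keep <- learn_levels_to_keep(train_species, threshold = 0.05)
pooled <- pool_levels(c("gull", "owl", "heron"), keep)
pooled
# "gull" "other" "other": owl is rare (1%) and heron never appeared in training

# Median imputation: learn the median on training data, reuse it on new data
train_height <- c(0, 50, NA, 800, 100)
height_median <- median(train_height, na.rm = TRUE)   # 75
new_height <- c(NA, 20)
new_height[is.na(new_height)] <- height_median
new_height
# 75 20
```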

For this post, let’s use a model I didn’t try out during the stream: a bagged tree model. It’s similar to the kinds of models that perform well in SLICED-like situations, but it is easy to set up and very fast to fit.

library(baguette)

bag_spec <-
  bag_tree(min_n = 10) %>%
  set_engine("rpart", times = 25) %>%
  set_mode("classification")

imb_wf <-
  workflow() %>%
  add_recipe(bird_rec) %>%
  add_model(bag_spec)

imb_fit <- fit(imb_wf, data = bird_df)
imb_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: bag_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
## 
## • step_novel()
## • step_other()
## • step_unknown()
## • step_impute_median()
## • step_zv()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Bagged CART (classification with 25 members)
## 
## Variable importance scores include:
## 
## # A tibble: 13 x 4
##    term             value std.error  used
##    <chr>            <dbl>     <dbl> <int>
##  1 flight_impact    480.       6.81    25
##  2 aircraft_model   363.       4.97    25
##  3 incident_year    354.       5.51    25
##  4 species_id       337.       4.62    25
##  5 height           332.       5.45    25
##  6 speed            297.       4.82    25
##  7 incident_month   285.       6.18    25
##  8 flight_phase     246.       4.41    25
##  9 engine_type      213.       3.31    25
## 10 visibility       196.       3.82    25
## 11 precipitation    136.       3.23    25
## 12 engines          117.       2.67    25
## 13 species_quantity  83.7      3.12    25

We automatically get variable importance scores too, which is nice! We see that flight_impact and aircraft_model are the two most important predictors for this model.
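
Bagging (bootstrap aggregating) fits the same kind of model to many bootstrap resamples of the training data and averages their predictions. The sketch below shows the idea with rpart directly on simulated data. The speed/height relationship is made up. Here, times <- 25 mirrors the engine argument above, and minsplit = 10 is meant to roughly correspond to min_n = 10. baguette handles all of this for us and may set other rpart options, so treat this as a conceptual sketch only.

```r
library(rpart)   # ships with R as a recommended package

set.seed(42)
n <- 500
sim <- data.frame(
  speed  = runif(n, 0, 300),
  height = runif(n, 0, 2000)
)
# Made-up relationship: damage becomes more likely at higher speed
p <- plogis(-4 + 0.015 * sim$speed)
sim$damaged <- factor(ifelse(runif(n) < p, "damage", "no damage"),
                      levels = c("damage", "no damage"))

times <- 25
# Each column holds one tree's predicted P(damage) for every row of sim
prob_matrix <- sapply(seq_len(times), function(b) {
  boot_idx <- sample.int(n, n, replace = TRUE)   # bootstrap resample
  tree <- rpart(damaged ~ speed + height, data = sim[boot_idx, ],
                method = "class",
                control = rpart.control(minsplit = 10))
  predict(tree, newdata = sim, type = "prob")[, "damage"]
})

# Bagged prediction: average the trees' probabilities
bagged_prob <- rowMeans(prob_matrix)
summary(bagged_prob)
```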

Resample and compare models

Now let’s evaluate how this model performs using resampling.

doParallel::registerDoParallel()
set.seed(123)
imb_rs <-
  fit_resamples(
    imb_wf,
    resamples = bird_folds,
    metrics = bird_metrics
  )

collect_metrics(imb_rs)
## # A tibble: 4 x 6
##   .metric     .estimator  mean     n  std_err .config             
##   <chr>       <chr>      <dbl> <int>    <dbl> <chr>               
## 1 accuracy    binary     0.925     5 0.00221  Preprocessor1_Model1
## 2 mn_log_loss binary     0.212     5 0.00511  Preprocessor1_Model1
## 3 sens        binary     0.278     5 0.00941  Preprocessor1_Model1
## 4 spec        binary     0.986     5 0.000843 Preprocessor1_Model1

This is quite good compared to how other folks did with this data, especially for such a simple model. We could take this as a starting point and move to a similar but better-performing model like xgboost.

What happens, though, if we change the preprocessing recipe to account for the class imbalance?
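
SMOTE creates synthetic minority-class rows by interpolating between a minority observation and one of its nearest minority-class neighbors. The hypothetical smote_sketch() below is a simplified base R version of that idea, not the themis implementation. themis’s step_smote() uses neighbors = 5 by default, which is the same k used here. Because the method relies on distances between rows, all predictors must be numeric with no missing values. That is why the balanced recipe creates dummy variables before step_smote(). Like other themis subsampling steps, step_smote() is skipped when the recipe is applied to new data, so the held-out fold in each resample is not oversampled.

```r
# Hypothetical helper sketching the SMOTE idea; not the themis implementation
smote_sketch <- function(X_min, n_new, k = 5) {
  stopifnot(is.matrix(X_min), nrow(X_min) > k)
  d <- as.matrix(dist(X_min))   # Euclidean distances between minority rows
  diag(d) <- Inf                # a row is not its own neighbor
  new_rows <- matrix(NA_real_, nrow = n_new, ncol = ncol(X_min),
                     dimnames = list(NULL, colnames(X_min)))
  for (i in seq_len(n_new)) {
    a <- sample.int(nrow(X_min), 1)          # random minority row
    neighbors <- order(d[a, ])[seq_len(k)]   # its k nearest minority rows
    b <- neighbors[sample.int(k, 1)]         # one of those neighbors at random
    gap <- runif(1)                          # position along the segment a -> b
    new_rows[i, ] <- X_min[a, ] + gap * (X_min[b, ] - X_min[a, ])
  }
  new_rows
}

set.seed(234)
# Made-up minority-class rows (numeric, no missing values)
X_min <- cbind(speed = runif(20, 100, 200), height = runif(20, 0, 500))
synthetic <- smote_sketch(X_min, n_new = 30)
head(synthetic)
```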

library(themis)

bal_rec <- bird_rec %>%
  step_dummy(all_nominal_predictors()) %>%
  step_smote(damaged)

bal_wf <-
  workflow() %>%
  add_recipe(bal_rec) %>%
  add_model(bag_spec)

set.seed(234)
bal_rs <-
  fit_resamples(
    bal_wf,
    resamples = bird_folds,
    metrics = bird_metrics
  )

collect_metrics(bal_rs)
## # A tibble: 4 x 6
##   .metric     .estimator  mean     n std_err .config             
##   <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy    binary     0.919     5 0.00215 Preprocessor1_Model1
## 2 mn_log_loss binary     0.224     5 0.00559 Preprocessor1_Model1
## 3 sens        binary     0.322     5 0.00967 Preprocessor1_Model1
## 4 spec        binary     0.975     5 0.00103 Preprocessor1_Model1
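
One way to see why accuracy dropped even though sensitivity improved is that accuracy is approximately a prevalence-weighted average of sensitivity and specificity. With only about 8.6% of strikes causing damage, specificity dominates. The base R calculation below plugs in the resampled estimates from the two outputs above. It treats the fold averages as if they decompose exactly, which is only approximately true. For comparison, it also computes balanced accuracy, the simple average of sensitivity and specificity (yardstick has bal_accuracy() for this).

```r
prevalence <- 1799 / 21000   # share of "damage" in the training data

# Resampled sensitivity and specificity copied from the two collect_metrics() outputs
rs <- data.frame(
  model = c("imbalanced", "smote"),
  sens  = c(0.278, 0.322),
  spec  = c(0.986, 0.975)
)

# accuracy is approximately P(damage) * sens + P(no damage) * spec
rs$implied_accuracy <- prevalence * rs$sens + (1 - prevalence) * rs$spec

# balanced accuracy weights both classes equally
rs$balanced_accuracy <- (rs$sens + rs$spec) / 2

rs
# implied accuracy is about 0.925 vs 0.919; balanced accuracy is 0.632 vs 0.6485
```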

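Notice that the log loss and accuracy got worse, while the sensitivity got better. This trade-off is common and expected. SMOTE trains the model on a much more balanced data set, so the model tends to predict "damage" more often, and with higher probability, than the true base rate of about 9% warrants. That helps it detect more damaging strikes (higher sensitivity) but costs specificity, accuracy, and log loss. Frankly, I wish I hadn’t been so laser-focused on getting subsampling to work during the SLICED stream! In most real-world situations, a single metric is not adequate to measure how useful a model will be in practice, and, unfortunately, we are often most interested in detecting the minority class. This means that learning how to account for class imbalance is important in many real modeling scenarios. However, if you are ever in a situation where you are evaluated on a single metric like log loss, you may want to stick with an imbalanced fit.

To make predictions on the test set, we select the same predictors (plus id, so each prediction can be matched back to its incident) and use augment() with the imbalanced workflow we fit earlier.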

test_df <- test_raw %>%
  select(
    id, flight_impact, precipitation,
    visibility, flight_phase, engines, incident_year,
    incident_month, species_id, engine_type,
    aircraft_model, species_quantity, height, speed
  )

augment(imb_fit, test_df) %>%
  select(id, .pred_damage)
## # A tibble: 9,000 x 2
##       id .pred_damage
##    <dbl>        <dbl>
##  1 11254     0.346   
##  2 27716     0.00606 
##  3 29066     0.000544
##  4  3373     0.0406  
##  5  1996     0.153   
##  6 18061     0.000654
##  7 22237     0.00489 
##  8 25346     0.274   
##  9 21554     0.348   
## 10  4273     0.00390 
## # … with 8,990 more rows