--- title: "Multinomial Classification" output: rmarkdown::html_vignette vignette: > %\VignetteIndexEntry{Multinomial Classification} %\VignetteEngine{knitr::knitr} %\VignetteEncoding{UTF-8} --- ```{r, include = FALSE} use_saved_results <- TRUE knitr::opts_chunk$set( collapse = TRUE, comment = "#>", echo = TRUE, eval = !use_saved_results, message = FALSE, warning = FALSE ) if (use_saved_results) { results <- readRDS("vignette_mc.rds") pred <- results$pred } ``` ```{r, eval=TRUE} library(dplyr); library(tidyr); library(purrr) # Data wrangling library(ggplot2); library(stringr) # Plotting library(tidyfit) # Auto-ML modeling ``` Multinomial classification is possible in `tidyfit` using the methods powered by `glmnet`, `e1071` and `randomForest` (LASSO, Ridge, ElasticNet, AdaLASSO, SVM and Random Forest). Currently, none of the other methods support multinomial classification.^[Feature selection methods such as `relief` or `chisq` can be used with multinomial response variables. I may also add support for multinomial classification with `mboost` in future.] When the response variable contains more than 2 classes, `classify` automatically uses a multinomial response for the above-mentioned methods. Here's an example using the built-in `iris` dataset: ```{r, eval=TRUE} data("iris") # For reproducibility set.seed(42) ix_tst <- sample(1:nrow(iris), round(nrow(iris)*0.2)) data_trn <- iris[-ix_tst,] data_tst <- iris[ix_tst,] as_tibble(iris) ``` ## Penalized classification algorithms to predict `Species` The code chunk below fits the above mentioned algorithms on the training split, using a 10-fold cross validation to select optimal penalties. We then obtain out-of-sample predictions using `predict`. Unlike binomial classification, the `fit` and `pred` objects contain a `class` column with separate coefficients and predictions for each class. The predictions sum to one across classes: ```{r} fit <- data_trn %>% classify(Species ~ ., LASSO = m("lasso"), Ridge = m("ridge"), ElasticNet = m("enet"), AdaLASSO = m("adalasso"), SVM = m("svm"), `Random Forest` = m("rf"), `Least Squares` = m("ridge", lambda = 1e-5), .cv = "vfold_cv") pred <- fit %>% predict(data_tst) ``` Note that we can add unregularized least squares estimates by setting `lambda = 0` (or very close to zero). Next, we can use `yardstick` to calculate the log loss accuracy metric and compare the performance of the different models: ```{r, fig.width=7, fig.height=3, fig.align="center", eval=TRUE} metrics <- pred %>% group_by(model, class) %>% mutate(row_n = row_number()) %>% spread(class, prediction) %>% group_by(model) %>% yardstick::mn_log_loss(truth, setosa:virginica) metrics %>% mutate(model = str_wrap(model, 11)) %>% ggplot(aes(model, .estimate)) + geom_col(fill = "darkblue") + theme_bw() + theme(axis.title.x = element_blank()) ``` The least squares estimate performs poorest, while the random forest (nonlinear) and the support vector machine (SVM) achieve the best results. The SVM is estimated with a linear kernel by default (use `kernel = ` to use a different kernel).