---
title: "Introduction to parttree"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to parttree}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  warning = FALSE
)
```

## Basic use

Let's start by loading the **parttree** package alongside **rpart**, which ships
as a recommended package with standard R installations and is what we'll use for
fitting our decision trees (at least, to start with). For the basic examples
that follow, I'll use the well-known
[Palmer Penguins](https://allisonhorst.github.io/palmerpenguins/) dataset to
demonstrate functionality. You can load this dataset via the parent package (as
I have here), or import it directly as a CSV
[here](https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv).

```{r setup}
library(rpart)     # For fitting decision trees
library(parttree)  # This package (will automatically load ggplot2 too)
theme_set(theme_linedraw())

# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
```

### Categorical predictions

Say we are interested in predicting penguin _species_ as a function of
1) flipper length and 2) bill length. We can visualize these relationships as a
simple scatter plot before doing any formal modeling.

```{r penguin_plot_cat1}
p = ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species))
p
```

Recasting this problem in terms of a decision tree is easily done (e.g., with
`rpart`). However, visualizing the resulting tree predictions against the raw
data is hard to do out of the box, and this is where **parttree** enters the
fray. The main function that users will interact with is `geom_parttree()`,
which provides a new geom layer for **ggplot2** objects.

```{r penguin_plot_cat2}
## Fit a decision tree using the same variables as the above plot
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Visualize the tree partitions by adding it to our plot with geom_parttree()
p +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(caption = "Note: Points denote observations. Shaded regions denote model predictions.")
```

### Continuous predictions

Trees with continuous outcome (dependent) variables are also supported. However,
I recommend adjusting the plot fill aesthetic, since your model will likely
partition the data into intervals that don't match up exactly with the raw
data. The easiest way to do this is by setting your colour and fill aesthetics
together as part of the same `scale_colour_*` call.

```{r penguin_plot_con}
tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree2, aes(fill = body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) +
  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
```

## Supported model classes

Currently, the package works with decision trees created by the
[**rpart**](https://CRAN.R-project.org/package=rpart) and
[**partykit**](https://CRAN.R-project.org/package=partykit) packages. Moreover,
it supports other front-end packages that call `rpart::rpart()` as the
underlying engine; in particular the
[**tidymodels**](https://www.tidymodels.org/)
([parsnip](https://parsnip.tidymodels.org/) or
[workflows](https://workflows.tidymodels.org/)) and
[**mlr3**](https://mlr3.mlr-org.com/) packages. Here's a quick example with
**parsnip**.

```{r titanic_plot}
set.seed(123) ## For consistent jitter

library(parsnip)
library(titanic) ## Just for a different data set

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train |>
  ggplot(aes(x = Pclass, y = Age)) +
  geom_parttree(data = ti_tree, aes(fill = Survived), alpha = 0.1) +
  geom_jitter(aes(col = Survived), alpha = 0.7)
```

## Plot orientation

Under the hood, `geom_parttree()` calls the companion `parttree()` function,
which coerces the **rpart** tree object into a data frame that is easily
understood by **ggplot2**. For example, consider again our first "tree" model
from earlier. Here's the print output of the raw model.

```{r tree}
tree
```

And here's what we get after we feed it to `parttree()`.

```{r tree_parted}
parttree(tree)
```

Again, the resulting data frame is designed to be amenable to a **ggplot2** geom
layer, with columns like `xmin`, `xmax`, etc. specifying aesthetics that
**ggplot2** recognises. (Fun fact: `geom_parttree()` is really just a thin
wrapper around `geom_rect()`.) The goal of the package is to abstract away these
kinds of details from the user, so we can just specify `geom_parttree()` (with a
valid tree object as the data input) and be done with it. However, while this
generally works well, it can sometimes lead to unexpected behaviour in terms of
plot orientation. That's because it's hard to guess ahead of time what the user
will specify as the x and y variables (i.e. axes) in their other plot layers. To
see what I mean, let's redo our penguin plot from earlier, but this time switch
the axes in the main `ggplot()` call.

```{r tree_plot_mismatch}
## First, redo our first plot but this time switch the x and y variables
p3 = ggplot(
  data = penguins,
  aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
) +
  geom_point(aes(col = species))

## Add on our tree (and some preemptive titling..)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
  )
```

As was the case here, this kind of orientation mismatch is normally (hopefully)
pretty easy to recognize. To fix it, we can use the `flipaxes = TRUE` argument
to flip the orientation of the `geom_parttree()` layer.

```{r tree_plot_flip}
p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flipaxes = TRUE ## Flip the orientation
  ) +
  labs(title = "That's better")
```

## Base graphics

While the package has been primarily designed to work with **ggplot2**, the
`parttree()` infrastructure can also be used to generate plots with base
graphics. Here, the `ctree()` function from **partykit** is used for fitting the
tree.

```{r ctree_base_graphics}
library(partykit)

## CTree and corresponding partition
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)

## Color palette
pal = palette.colors(4, "R4")[-1]

## Maximum/minimum for plotting range as rect() does not handle Inf well
m = 1000

## scatter plot() with added rect()
plot(
  bill_length_mm ~ flipper_length_mm,
  data = penguins, col = pal[species], pch = 19
)
rect(
  pmax(-m, pt$xmin), pmax(-m, pt$ymin),
  pmin(m, pt$xmax), pmin(m, pt$ymax),
  col = adjustcolor(pal, alpha.f = 0.1)[pt$species]
)
```
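
The `pmax()`/`pmin()` clamping in the chunk above can also be wrapped in a small helper, which makes the intent a bit more explicit than a hard-coded `m = 1000`. What follows is only a minimal sketch under stated assumptions: `pt_demo` is a made-up stand-in for `parttree()` output (I'm assuming only the `xmin`, `xmax`, `ymin` and `ymax` columns, with infinite outer bounds), and `clamp_bounds()` is a hypothetical helper that is not part of **parttree**.

```{r clamp_bounds_sketch}
## Hypothetical stand-in for parttree() output: two partitions split at x = 200,
## with the outer edges left unbounded (assumption, for illustration only)
pt_demo = data.frame(
  xmin = c(-Inf, 200),
  xmax = c(200, Inf),
  ymin = c(-Inf, -Inf),
  ymax = c(Inf, Inf)
)

## Hypothetical helper (not part of parttree): clamp the partition bounds to
## finite axis ranges so that rect() receives only finite coordinates
clamp_bounds = function(pt, xlim, ylim) {
  pt$xmin = pmax(xlim[1], pt$xmin)
  pt$xmax = pmin(xlim[2], pt$xmax)
  pt$ymin = pmax(ylim[1], pt$ymin)
  pt$ymax = pmin(ylim[2], pt$ymax)
  pt
}

## Axis ranges chosen by hand here; all bounds are finite afterwards
pt_clamped = clamp_bounds(pt_demo, xlim = c(170, 235), ylim = c(30, 60))
pt_clamped
```

After a call to `plot()`, the extents of the current plotting region are available from `par("usr")` (in the order x1, x2, y1, y2), so those values could be passed as `xlim` and `ylim` instead of hand-picked ranges.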