Introduction to parttree

Basic use

Let’s start by loading the parttree package alongside rpart, which ships with the standard R distribution as a recommended package and is what we’ll use for fitting our decision trees (at least, to start with). For the basic examples that follow, I’ll use the well-known Palmer Penguins dataset to demonstrate functionality. You can load this dataset via its parent package, palmerpenguins (as I have here), or import it directly as a CSV.

library(rpart)     # For fitting decision trees
library(parttree)  # This package (will automatically load ggplot2 too)

theme_set(theme_linedraw())

# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
#> # A tibble: 6 × 8
#>   species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>   <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#> 1 Adelie  Torgersen           39.1          18.7               181        3750
#> 2 Adelie  Torgersen           39.5          17.4               186        3800
#> 3 Adelie  Torgersen           40.3          18                 195        3250
#> 4 Adelie  Torgersen           NA            NA                  NA          NA
#> 5 Adelie  Torgersen           36.7          19.3               193        3450
#> 6 Adelie  Torgersen           39.3          20.6               190        3650
#> # ℹ 2 more variables: sex <fct>, year <int>

Categorical predictions

Say we are interested in predicting penguin species as a function of 1) flipper length and 2) bill length. We can visualize these relationships as a simple scatter plot prior to doing any formal modeling.

p = 
  ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species))
p

Recasting this in terms of a decision tree is easily done (e.g., with rpart). However, visualizing the resulting tree predictions against the raw data is hard to do out of the box, and this is where parttree enters the fray. The main function that users will interact with is geom_parttree(), which provides a new geom layer for ggplot2 objects.

## Fit a decision tree using the same variables as the above plot
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Visualize the tree partitions by adding it to our plot with geom_parttree()
p +  
  geom_parttree(data = tree, aes(fill=species), alpha = 0.1) +
  labs(caption = "Note: Points denote observations. Shaded regions denote model predictions.")
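
As a quick sanity check on what the shaded regions represent, the sketch below compares the tree's class predictions with the observed species. It refits the same model and only relies on predict() and table(); the helper names ok, pred and actual are my own and purely illustrative.

```r
library(rpart)
data("penguins", package = "palmerpenguins")

## Same model as above
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Keep only rows where both predictors are observed (illustrative helper)
ok = complete.cases(penguins[, c("flipper_length_mm", "bill_length_mm")])

## Predicted class for each penguin, i.e. the region that its point falls in
pred   = predict(tree, newdata = penguins[ok, ], type = "class")
actual = penguins$species[ok]

## Cross-tabulate predictions against the observed species
table(predicted = pred, actual = actual)
```

Off-diagonal counts in the table correspond to points whose colour differs from that of the shaded region they sit in.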

Continuous predictions

Trees with a continuous dependent variable (i.e. regression trees) are also supported. However, I recommend adjusting the plot fill aesthetic, since your model will likely partition the data into intervals that don’t match up exactly with the raw data. The easiest way to do this is by setting your colour and fill aesthetics together as part of the same scale_colour_* call.

tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data=penguins)

ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree2, aes(fill=body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) + 
  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
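
To see why the fill and point colours rarely line up exactly, it can help to look at the predicted values themselves. The following is a minimal sketch, assuming the same tree2 model as above; leaf_means is a name I've made up for illustration.

```r
library(rpart)
data("penguins", package = "palmerpenguins")

## Same regression tree as above
tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

## Each terminal node predicts a single value (the mean body mass of its
## training observations), so the fitted values take only a few distinct levels
leaf_means = sort(unique(predict(tree2)))
leaf_means

## Compare the spread of the predictions with that of the raw data
range(leaf_means)
range(penguins$body_mass_g, na.rm = TRUE)
```

Because the predictions are node-level averages, they span a narrower range than the raw observations, which is why a shared colour/fill scale makes the comparison easier to read.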

Supported model classes

Currently, the package works with decision trees created by the rpart and partykit packages. Moreover, it supports other front-ends that call rpart::rpart() as the underlying engine; in particular the tidymodels (parsnip or workflows) and mlr3 packages. Here’s a quick example with parsnip.

set.seed(123) ## For consistent jitter

library(parsnip)
library(titanic) ## Just for a different data set

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train |>
  ggplot(aes(x=Pclass, y=Age)) +
  geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
  geom_jitter(aes(col=Survived), alpha=0.7)
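
If you want to confirm that a front-end really is using rpart underneath, one option is to pull out the engine-level fit. This is only a sketch: it uses parsnip's extract_fit_engine() helper, and the object names engine_fit and ti_parts are illustrative.

```r
library(parsnip)
library(parttree)
library(titanic)

titanic_train$Survived = as.factor(titanic_train$Survived)

## Same parsnip model as above
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)

## The parsnip object wraps the engine-level fit, which is a plain rpart tree
engine_fit = extract_fit_engine(ti_tree)
class(engine_fit)

## parttree() accepts the parsnip object directly, as geom_parttree() did above
ti_parts = parttree(ti_tree)
head(ti_parts)
```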

Plot orientation

Under the hood, geom_parttree() calls the companion parttree() function, which coerces the rpart tree object into a data frame that is easily understood by ggplot2. For example, consider again our first “tree” model from earlier. Here’s the print output of the raw model.

tree
#> n=342 (2 observations deleted due to missingness)
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 342 191 Adelie (0.441520468 0.198830409 0.359649123)  
#>   2) flipper_length_mm< 206.5 213  64 Adelie (0.699530516 0.295774648 0.004694836)  
#>     4) bill_length_mm< 43.35 150   5 Adelie (0.966666667 0.033333333 0.000000000) *
#>     5) bill_length_mm>=43.35 63   5 Chinstrap (0.063492063 0.920634921 0.015873016) *
#>   3) flipper_length_mm>=206.5 129   7 Gentoo (0.015503876 0.038759690 0.945736434) *

And here’s what we get after we feed it to parttree().

parttree(tree)
#>   node   species                                                  path  xmin
#> 1    3    Gentoo                            flipper_length_mm >= 206.5 206.5
#> 2    4    Adelie  flipper_length_mm < 206.5 --> bill_length_mm < 43.35  -Inf
#> 3    5 Chinstrap flipper_length_mm < 206.5 --> bill_length_mm >= 43.35  -Inf
#>    xmax  ymin  ymax
#> 1   Inf  -Inf   Inf
#> 2 206.5  -Inf 43.35
#> 3 206.5 43.35   Inf
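
Since these columns are just rectangle boundaries, the partitions can also be drawn by hand. The sketch below is my own illustration rather than the package's internal code: it passes the parttree() output to ggplot2's geom_rect(), with as.data.frame() added as a precaution and p_manual as a made-up name.

```r
library(rpart)
library(parttree)
library(ggplot2)
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Plain data frame of partition boundaries
pt = as.data.frame(parttree(tree))

## Draw the partitions by hand; -Inf/Inf extend to the panel edges
p_manual =
  ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    alpha = 0.1, inherit.aes = FALSE
    ) +
  geom_point(aes(col = species))
p_manual
```

Note that here the x and y roles are fixed by whoever writes the geom_rect() call, which is exactly the bookkeeping that the package tries to handle for you.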

Again, the resulting data frame is designed to be amenable to a ggplot2 geom layer, with columns like xmin, xmax, etc. specifying aesthetics that ggplot2 recognises. (Fun fact: geom_parttree() is really just a thin wrapper around geom_rect().) The goal of the package is to abstract away these kinds of details from the user, so we can just specify geom_parttree(), with a valid tree object as the data input, and be done with it.

However, while this generally works well, it can sometimes lead to unexpected behaviour in terms of plot orientation. That’s because it’s hard to guess ahead of time what the user will specify as the x and y variables (i.e. axes) in their other plot layers. To see what I mean, let’s redo our penguin plot from earlier, but this time switch the axes in the main ggplot() call.

## First, redo our first plot but this time switch the x and y variables
p3 = 
  ggplot(
    data = penguins, 
    aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
    ) +
  geom_point(aes(col = species))

## Add on our tree (and some preemptive titling..)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
    )

As was the case here, this kind of orientation mismatch is normally (hopefully) pretty easy to recognize. To fix it, we can use the flipaxes = TRUE argument to flip the orientation of the geom_parttree() layer.

p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flipaxes = TRUE  ## Flip the orientation
    ) +
  labs(title = "That's better")

Base graphics

While the package has been primarily designed to work with ggplot2, the parttree() infrastructure can also be used to generate plots with base graphics. Here, the ctree() function from partykit is used for fitting the tree.

library(partykit)

## CTree and corresponding partition
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)

## Color palette
pal = palette.colors(4, "R4")[-1]

## Maximum/minimum for plotting range as rect() does not handle Inf well
m = 1000

## scatter plot() with added rect()
plot(
  bill_length_mm ~ flipper_length_mm, 
  data = penguins, col = pal[species], pch = 19
  )
rect(
  pmax(-m, pt$xmin), pmax(-m, pt$ymin), pmin(m, pt$xmax), pmin(m, pt$ymax),
  col = adjustcolor(pal, alpha.f = 0.1)[pt$species]
  )