Let’s start by loading the parttree package alongside rpart, which comes bundled with the base R installation and is what we’ll use for fitting our decision trees (at least, to start with). For the basic examples that follow, I’ll use the well-known Palmer Penguins dataset to demonstrate functionality. You can load this dataset via the parent package (as I have here) or import it directly as a CSV.
library(rpart) # For fitting decision trees
library(parttree) # This package (will automatically load ggplot2 too)
theme_set(theme_linedraw())
# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
#> # A tibble: 6 × 8
#> species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#> <fct> <fct> <dbl> <dbl> <int> <int>
#> 1 Adelie Torgersen 39.1 18.7 181 3750
#> 2 Adelie Torgersen 39.5 17.4 186 3800
#> 3 Adelie Torgersen 40.3 18 195 3250
#> 4 Adelie Torgersen NA NA NA NA
#> 5 Adelie Torgersen 36.7 19.3 193 3450
#> 6 Adelie Torgersen 39.3 20.6 190 3650
#> # ℹ 2 more variables: sex <fct>, year <int>
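Row 4 above is missing every measurement. Before fitting anything, it can help to count how many rows are actually usable for the two predictors we’ll use below. This is just a base-R sketch with complete.cases(), not a parttree feature; the result lines up with the "2 observations deleted due to missingness" note that rpart prints for the fitted tree later on.

```r
data("penguins", package = "palmerpenguins")

## Rows with both predictors of the first tree observed
vars = c("flipper_length_mm", "bill_length_mm")
n_complete = sum(complete.cases(penguins[, vars]))

n_complete                  # usable rows
nrow(penguins) - n_complete # rows with a missing predictor
```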
Say we are interested in predicting the penguins' species as a function of 1) flipper length and 2) bill length. We can visualize these relationships as a simple scatter plot prior to doing any formal modeling.
p =
ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
geom_point(aes(col = species))
p
Recasting in terms of a decision tree is easily done (e.g., with rpart). However, visualizing the resulting tree predictions against the raw data is hard to do out of the box, and this is where parttree enters the fray. The main function that users will interact with is geom_parttree(), which provides a new geom layer for ggplot2 objects.
## Fit a decision tree using the same variables as the above plot
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
## Visualize the tree partitions by adding it to our plot with geom_parttree()
p +
geom_parttree(data = tree, aes(fill=species), alpha = 0.1) +
labs(caption = "Note: Points denote observations. Shaded regions denote model predictions.")
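The shaded regions are the tree's in-sample class predictions, so a numerical counterpart to the plot is a simple confusion table. The sketch below uses rpart's own predict() method (not a parttree function) on the rows with both predictors observed. For penguins these are exactly the rows used in the fit, because the two incomplete rows are missing both predictors.

```r
library(rpart)
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Rows with both predictors observed
cc = complete.cases(penguins[, c("flipper_length_mm", "bill_length_mm")])
obs = penguins$species[cc]
pred = predict(tree, newdata = penguins[cc, ], type = "class")

## Rows = predicted region (shaded fill), columns = observed species (points)
table(predicted = pred, observed = obs)
mean(pred == obs) # in-sample accuracy
```

The off-diagonal counts are the points whose colour disagrees with the shaded region they sit in.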
Trees with a continuous response variable (i.e., regression trees) are also supported. However, I recommend adjusting the plot fill aesthetic, since your model will likely partition the data into intervals that don’t match up exactly with the raw data. The easiest way to do this is to set your colour and fill aesthetics together as part of the same scale_colour_* call.
tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data=penguins)
ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
geom_parttree(data = tree2, aes(fill=body_mass_g), alpha = 0.3) +
geom_point(aes(col = body_mass_g)) +
scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
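Why don't the intervals match the raw data? A regression tree predicts a single constant per terminal node (the mean response of the training rows in that node), so the fill takes only a handful of values while the points span a continuous range. A quick check with rpart's predict() and the fitted object's frame component makes this visible. This is a sketch that reuses the tree2 specification from above.

```r
library(rpart)
data("penguins", package = "palmerpenguins")

tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

## Fitted values for the training rows, and the node means of the leaves
fitted_mass = predict(tree2)
leaf_means = tree2$frame$yval[tree2$frame$var == "<leaf>"]

sort(unique(fitted_mass))                 # a few distinct predicted values...
range(penguins$body_mass_g, na.rm = TRUE) # ...versus the continuous raw range
```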
Currently, the package works with decision trees created by the rpart and partykit packages. It also supports other front-end packages that call rpart::rpart() as the underlying engine, in particular tidymodels (parsnip or workflows) and mlr3.
Here’s a quick example with parsnip.
set.seed(123) ## For consistent jitter
library(parsnip)
library(titanic) ## Just for a different data set
titanic_train$Survived = as.factor(titanic_train$Survived)
## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
decision_tree() |>
set_engine("rpart") |>
set_mode("classification") |>
fit(Survived ~ Pclass + Age, data = titanic_train)
## Plot the data and model partitions
titanic_train |>
ggplot(aes(x=Pclass, y=Age)) +
geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
geom_jitter(aes(col=Survived), alpha=0.7)
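The same model can also be wrapped in a workflows pipeline, which is the other tidymodels front end mentioned above. The sketch below assumes the workflows package is installed and that geom_parttree() accepts a fitted workflow in the same way as a parsnip fit. extract_fit_engine() pulls out the engine-level object, which confirms that rpart is doing the actual fitting.

```r
library(ggplot2)
library(parttree)
library(parsnip)
library(workflows)
library(titanic)

titanic_train$Survived = as.factor(titanic_train$Survived)

## Same specification as the parsnip example, wrapped in a workflow
ti_wf =
  workflow() |>
  add_formula(Survived ~ Pclass + Age) |>
  add_model(
    decision_tree() |>
      set_engine("rpart") |>
      set_mode("classification")
  ) |>
  fit(data = titanic_train)

## The engine-level fit is a plain rpart object
engine_fit = extract_fit_engine(ti_wf)
class(engine_fit)

## Plot as before, passing the fitted workflow as the layer data
set.seed(123) ## For consistent jitter
ggplot(titanic_train, aes(x = Pclass, y = Age)) +
  geom_parttree(data = ti_wf, aes(fill = Survived), alpha = 0.1) +
  geom_jitter(aes(col = Survived), alpha = 0.7)
```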
Under the hood, geom_parttree() calls the companion parttree() function, which coerces the rpart tree object into a data frame that is easily understood by ggplot2. For example, consider again our first “tree” model from earlier. Here’s the print output of the raw model.
tree
#> n=342 (2 observations deleted due to missingness)
#>
#> node), split, n, loss, yval, (yprob)
#> * denotes terminal node
#>
#> 1) root 342 191 Adelie (0.441520468 0.198830409 0.359649123)
#> 2) flipper_length_mm< 206.5 213 64 Adelie (0.699530516 0.295774648 0.004694836)
#> 4) bill_length_mm< 43.35 150 5 Adelie (0.966666667 0.033333333 0.000000000) *
#> 5) bill_length_mm>=43.35 63 5 Chinstrap (0.063492063 0.920634921 0.015873016) *
#> 3) flipper_length_mm>=206.5 129 7 Gentoo (0.015503876 0.038759690 0.945736434) *
And here’s what we get after we feed it to parttree().
parttree(tree)
#> node species path xmin
#> 1 3 Gentoo flipper_length_mm >= 206.5 206.5
#> 2 4 Adelie flipper_length_mm < 206.5 --> bill_length_mm < 43.35 -Inf
#> 3 5 Chinstrap flipper_length_mm < 206.5 --> bill_length_mm >= 43.35 -Inf
#> xmax ymin ymax
#> 1 Inf -Inf Inf
#> 2 206.5 -Inf 43.35
#> 3 206.5 43.35 Inf
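Since parttree() returns an ordinary data frame with xmin, xmax, ymin and ymax columns, you can draw the partitions yourself with ggplot2's geom_rect(). This is only a sketch to show the idea; it assumes the column names printed above, and geom_parttree() takes care of these details for you.

```r
library(rpart)
library(parttree)
library(ggplot2)
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(tree)

## Draw the leaf rectangles manually. ggplot2 places -Inf/Inf bounds at the
## panel edges, so the unbounded regions need no special treatment here.
ggplot() +
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    alpha = 0.1
  ) +
  geom_point(
    data = penguins,
    aes(x = flipper_length_mm, y = bill_length_mm, col = species)
  )
```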
Again, the resulting data frame is designed to be amenable to a ggplot2 geom layer, with columns like xmin, xmax, etc. specifying aesthetics that ggplot2 recognises. (Fun fact: geom_parttree() is really just a thin wrapper around geom_rect().) The goal of the package is to abstract away these kinds of details from the user, so we can just specify geom_parttree(), with a valid tree object as the data input, and be done with it. However, while this generally works well, it can sometimes lead to unexpected behaviour in terms of plot orientation. That’s because it’s hard to guess ahead of time what the user will specify as the x and y variables (i.e., axes) in their other plot layers. To see what I mean, let’s redo our penguin plot from earlier, but this time switch the axes in the main ggplot() call.
## First, redo our first plot but this time switch the x and y variables
p3 =
ggplot(
data = penguins,
aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
) +
geom_point(aes(col = species))
## Add on our tree (and some preemptive titling..)
p3 +
geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
labs(
title = "Oops!",
subtitle = "Looks like a mismatch between our x and y axes..."
)
As was the case here, this kind of orientation mismatch is normally (hopefully) pretty easy to recognize. To fix it, we can use the flipaxes = TRUE argument to flip the orientation of the geom_parttree layer.
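Applied to the mismatched plot above, the fix looks like this. This is a sketch that assumes your installed parttree version uses the flipaxes argument name given above.

```r
library(rpart)
library(parttree)
library(ggplot2)
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Same axis order as p3: bill length on x, flipper length on y
p3 =
  ggplot(data = penguins, aes(x = bill_length_mm, y = flipper_length_mm)) +
  geom_point(aes(col = species))

## flipaxes = TRUE flips the orientation of the partition layer to match
p_fixed =
  p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1, flipaxes = TRUE) +
  labs(title = "Fixed", subtitle = "Partitions now follow the switched axes")
p_fixed
```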
While the package has been primarily designed to work with ggplot2, the parttree() infrastructure can also be used to generate plots with base graphics. Here, the ctree() function from partykit is used for fitting the tree.
library(partykit)
## CTree and corresponding partition
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)
## Color palette
pal = palette.colors(4, "R4")[-1]
## Maximum/minimum for plotting range as rect() does not handle Inf well
m = 1000
## scatter plot() with added rect()
plot(
bill_length_mm ~ flipper_length_mm,
data = penguins, col = pal[species], pch = 19
)
rect(
pmax(-m, pt$xmin), pmax(-m, pt$ymin), pmin(m, pt$xmax), pmin(m, pt$ymax),
col = adjustcolor(pal, alpha.f = 0.1)[pt$species]
)
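The fixed m = 1000 cap works here because both measurements sit well below 1000, but it would cut the rectangles short for variables on a larger scale (body mass in grams, say). One alternative, sketched below, is to clamp the infinite bounds to the current plotting region, which base graphics reports via par("usr") once plot() has been called. The helper clamp_to_usr() is hypothetical and not part of parttree. The sketch also maps each rectangle's class to the palette through the factor levels of species, so the lookup does not depend on how the species column is stored.

```r
library(partykit)
library(parttree)
data("penguins", package = "palmerpenguins")

ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)
pal = palette.colors(4, "R4")[-1]

## Hypothetical helper: replace -Inf/Inf bounds with the edges of the
## current plot region, as reported by par("usr") = c(x1, x2, y1, y2)
clamp_to_usr = function(pt) {
  usr = par("usr")
  data.frame(
    xleft   = pmax(usr[1], pt$xmin),
    ybottom = pmax(usr[3], pt$ymin),
    xright  = pmin(usr[2], pt$xmax),
    ytop    = pmin(usr[4], pt$ymax)
  )
}

plot(
  bill_length_mm ~ flipper_length_mm,
  data = penguins, col = pal[species], pch = 19
)
r = clamp_to_usr(pt) # must run after plot() so that par("usr") is set

## Index the palette by factor level rather than by the raw column values
fill_idx = as.integer(factor(pt$species, levels = levels(penguins$species)))
rect(r$xleft, r$ybottom, r$xright, r$ytop,
     col = adjustcolor(pal, alpha.f = 0.1)[fill_idx])
```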