Title: | Visualize Simple 2-D Decision Tree Partitions |
---|---|
Description: | Visualize the partitions of simple decision trees, involving one or two predictors, on the scale of the original data. Provides an intuitive alternative to traditional tree diagrams, by visualizing how a decision tree divides the predictor space in a simple 2D plot alongside the original data. The 'parttree' package supports both classification and regression trees from 'rpart' and 'partykit', as well as trees produced by popular frontend systems like 'tidymodels' and 'mlr3'. Visualization methods are provided for both base R graphics and 'ggplot2'. |
Authors: | Grant McDermott [aut, cre] , Achim Zeileis [ctb] , Brian Heseung Kim [ctb] , Julia Silge [ctb] |
Maintainer: | Grant McDermott <[email protected]> |
License: | MIT + file LICENSE |
Version: | 0.1.0 |
Built: | 2025-01-16 16:32:32 UTC |
Source: | https://github.com/grantmcdermott/parttree |
geom_parttree() is a simple wrapper around parttree() that takes a tree model object and converts it into an amenable data frame that ggplot2 knows how to plot. Note that ggplot2 is not a hard dependency of parttree and must therefore be installed separately on the user's system before calling geom_parttree().
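Since ggplot2 is only a soft dependency, scripts that use geom_parttree() may want to check for it before building the plot. The sketch below is one way to do that. It assumes a fallback to the base graphics plot() method for parttree objects is acceptable when ggplot2 is missing. The variable names (iris_tree, has_ggplot2, p) are illustrative.

```r
library(parttree)
library(rpart)

# Classification tree with two numeric predictors
iris_tree = rpart(Species ~ Petal.Length + Petal.Width, data = iris)

# ggplot2 is not a hard dependency of parttree, so check for it explicitly
has_ggplot2 = requireNamespace("ggplot2", quietly = TRUE)

if (has_ggplot2) {
  library(ggplot2)
  # The tree object goes into the geom_parttree() layer, the raw data into ggplot()
  p = ggplot(iris, aes(x = Petal.Length, y = Petal.Width)) +
    geom_parttree(data = iris_tree, aes(fill = Species), alpha = 0.1) +
    geom_point(aes(col = Species))
  print(p)
} else {
  # Base graphics fallback via the parttree plot method; needs no extra packages
  plot(parttree(iris_tree))
}
```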
geom_parttree(
  mapping = NULL,
  data = NULL,
  stat = "identity",
  position = "identity",
  linejoin = "mitre",
  na.rm = FALSE,
  show.legend = NA,
  inherit.aes = TRUE,
  flip = FALSE,
  ...
)
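Argument | Description |
---|---|
mapping | Set of aesthetic mappings created by ggplot2::aes(). If specified and inherit.aes = TRUE (the default), it is combined with the default mapping at the top level of the plot. |
data | An rpart::rpart.object or an object of compatible type (e.g. a decision tree constructed via the partykit, tidymodels, or mlr3 front-ends). |
stat | The statistical transformation to use on the data for this layer. When using a geom_*() function to construct a layer, the stat argument can be used to override the default coupling between geoms and stats. Default is "identity". |
position | A position adjustment to use on the data for this layer. This can be used in various ways, including preventing overplotting and improving the display. The position argument accepts either the result of calling a position function, such as position_jitter(), or a string naming the position adjustment (e.g. "jitter"). Default is "identity". |
linejoin | Line join style (round, mitre, bevel). Default is "mitre". |
na.rm | If FALSE (the default), missing values are removed with a warning. If TRUE, missing values are silently removed. |
show.legend | Logical. Should this layer be included in the legends? NA (the default) includes it if any aesthetics are mapped; FALSE never includes it; TRUE always includes it. |
inherit.aes | If FALSE, overrides the default aesthetics rather than combining with them. Default is TRUE. |
flip | Logical. By default, the "x" and "y" axes variables for plotting are determined by the first split in the tree. This can cause plot orientation mismatches depending on how users specify the other layers of their plot. Setting to TRUE will flip the "x" and "y" variables for the plotted partitions. Default is FALSE. |
... | Other arguments passed on to ggplot2::layer(). These are often aesthetics used to set a fixed value, e.g. colour = NA or alpha = 0.1. |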
Because of the way that ggplot2 validates inputs and assembles plot layers, the data input for geom_parttree() (i.e. the decision tree object) must be assigned in the layer itself, not in the initialising ggplot2::ggplot() call. See Examples.
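One practical consequence of attaching the tree to the layer is that several trees can be overlaid on the same raw data, each in its own geom_parttree() layer. The following is a minimal sketch comparing a default rpart fit with a fully grown one. The settings cp = 0 and minsplit = 2 are illustrative rpart.control() choices, not recommendations. ggplot2::layer_data() is used only to confirm that each layer holds one rectangle per terminal node.

```r
library(ggplot2)
library(parttree)
library(rpart)

# Two trees on the same predictors: the default fit and a fully grown one
tree_default = rpart(Species ~ Petal.Length + Petal.Width, data = iris)
tree_deep = rpart(
  Species ~ Petal.Length + Petal.Width, data = iris,
  control = rpart.control(cp = 0, minsplit = 2)
)

# The raw data go into ggplot(); each tree object goes into its own layer
p = ggplot(iris, aes(x = Petal.Length, y = Petal.Width)) +
  geom_point(aes(col = Species)) +
  geom_parttree(data = tree_default, colour = "black") +
  geom_parttree(data = tree_deep, colour = "red", linetype = "dashed")
print(p)

# Built data for layer 2 (default tree) and layer 3 (deep tree)
nrow(layer_data(p, 2))
nrow(layer_data(p, 3))
```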
A ggplot layer.
geom_parttree() aims to "work out of the box" with minimal input from the user, apart from specifying the data object. This includes taking care of the data transformation in a way that generally produces optimal corner coordinates for each partition (i.e. xmin, xmax, ymin, and ymax). However, it also understands the following aesthetics that users may choose to specify manually:

- fill (particularly encouraged, since this provides a visual cue regarding the prediction in each partition region)
- colour
- alpha
- linetype
- size
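The sketch below maps fill to the predicted class and sets colour, linetype, and alpha as fixed values. This follows the usual ggplot2 split between mapped aesthetics (inside aes()) and constant ones (outside it). The specific colour and line type are arbitrary choices for illustration.

```r
library(ggplot2)
library(parttree)
library(rpart)

iris_tree = rpart(Species ~ Petal.Length + Petal.Width, data = iris)

p = ggplot(iris, aes(x = Petal.Length, y = Petal.Width)) +
  # fill is mapped to the predicted class; the other aesthetics are constants
  geom_parttree(
    data = iris_tree, aes(fill = Species),
    colour = "grey40", linetype = "dashed", alpha = 0.2
  ) +
  geom_point(aes(col = Species))
print(p)
```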
See also plot.parttree(), which provides an alternative plotting method using base R graphics.
# install.packages("ggplot2")
library(ggplot2) # ggplot2 must be installed/loaded separately
library(parttree) # this package
library(rpart) # decision trees
#
## Simple decision tree (max of two predictor variables)
iris_tree = rpart(Species ~ Petal.Length + Petal.Width, data=iris)
# Plot with original iris data only
p = ggplot(data = iris, aes(x = Petal.Length, y = Petal.Width)) +
  geom_point(aes(col = Species))
# Add tree partitions to the plot (borders only)
p + geom_parttree(data = iris_tree)
# Better to use fill and highlight predictions
p + geom_parttree(data = iris_tree, aes(fill = Species), alpha=0.1)
# To drop the black border lines (i.e. fill only)
p + geom_parttree(data = iris_tree, aes(fill = Species), col = NA, alpha = 0.1)
#
## Example with plot orientation mismatch
p2 = ggplot(iris, aes(x=Petal.Width, y=Petal.Length)) +
  geom_point(aes(col=Species))
# Oops
p2 + geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1)
# Fix with 'flip = TRUE'
p2 + geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1, flip = TRUE)
#
## Various front-end frameworks are also supported, e.g.:
# install.packages("parsnip")
library(parsnip)
iris_tree_parsnip = decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Species ~ Petal.Length + Petal.Width, data=iris)
p + geom_parttree(data = iris_tree_parsnip, aes(fill=Species), alpha = 0.1)
#
## Regression trees (i.e. continuous dependent variables) are also supported.
# Note: you may need to adjust (or switch off) the fill legend to match the
# original data, e.g.:
iris_tree_cont = rpart(Petal.Length ~ Sepal.Length + Petal.Width, data=iris)
p3 = ggplot(data = iris, aes(x = Petal.Width, y = Sepal.Length)) +
  geom_parttree(
    data = iris_tree_cont,
    aes(fill = Petal.Length), alpha=0.5
  ) +
  geom_point(aes(col = Petal.Length)) +
  theme_minimal()
# Legend scales don't quite match here:
p3
# Better to scale fill to the original data
p3 + scale_fill_continuous(limits = range(iris$Petal.Length))
Extracts the terminal leaf nodes of a decision tree that contains no more than two numeric predictor variables. These leaf nodes are then converted into a data frame, where each row represents a partition (i.e. a leaf or terminal node) that can easily be plotted in 2-D coordinate space.
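This quick sanity check confirms the one-row-per-leaf structure. It assumes rpart's convention of marking terminal nodes with "<leaf>" in the frame component of the fitted object. The kyphosis tree mirrors the examples further below.

```r
library(parttree)
library(rpart)

rp = rpart(Kyphosis ~ Start + Age, data = kyphosis)
rp_pt = parttree(rp)

# rpart marks terminal nodes with "<leaf>" in its frame component
n_leaves = sum(rp$frame$var == "<leaf>")

# Expect one rectangle (row) per leaf, plus the rectangle limit columns
nrow(rp_pt) == n_leaves
names(rp_pt)
```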
parttree(tree, keep_as_dt = FALSE, flip = FALSE)
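Argument | Description |
---|---|
tree | An rpart.object or a compatible tree object (e.g. one produced via partykit, tidymodels, or mlr3). |
keep_as_dt | Logical. The function relies on data.table for internal data manipulation and, by default, coerces the final return object into a regular data frame. Set to TRUE to keep it as a data.table. Default is FALSE. |
flip | Logical. Should we flip the "x" and "y" variables in the return data frame? The default behaviour is for the first split variable in the tree to take the "y" slot, and any second split variable to take the "x" slot. Setting to TRUE switches these around. Note: This argument is primarily useful when it is passed via geom_parttree() to ensure correct axes orientation as part of a ggplot2 plot. |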
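This short sketch exercises the two optional arguments on the kyphosis tree from the examples. It assumes that flipping only swaps the roles of the x and y limits, so the set of values is unchanged. It also assumes that keep_as_dt = TRUE returns a data.table. Both assumptions come from the argument descriptions above rather than from the package internals.

```r
library(parttree)
library(rpart)

rp = rpart(Kyphosis ~ Start + Age, data = kyphosis)

pt = parttree(rp)                       # default orientation
pt_flip = parttree(rp, flip = TRUE)     # swap the "x" and "y" variables
pt_dt = parttree(rp, keep_as_dt = TRUE) # keep the internal data.table class

# After flipping, the former y-limits should reappear as the x-limits
identical(sort(pt$ymin), sort(pt_flip$xmin))
class(pt_dt)
```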
A data frame comprising seven columns: the leaf node, its path, a set of rectangle limits (i.e., xmin, xmax, ymin, ymax), and a final column corresponding to the predicted value for that leaf.
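Each row is an axis-aligned rectangle, so finding the partition that contains a given point reduces to four interval comparisons. The sketch below uses a small hand-built data frame with the same rectangle columns; the values and the pred column name are made up and are not real parttree() output. It also defines a hypothetical helper, which_partition(). The sketch assumes open-ended partitions are stored with -Inf/Inf limits. It treats lower bounds as inclusive and upper bounds as exclusive, which matches splits of the form x < c versus x >= c. To apply it to real output, you would need to know which original variable occupies the x slot.

```r
# Hand-built stand-in for a parttree() result (hypothetical values): each row
# is one rectangle, with -Inf/Inf marking open-ended edges
pt = data.frame(
  node = c(2, 6, 7),
  xmin = c(-Inf, 2.5, 2.5),
  xmax = c(2.5, Inf, Inf),
  ymin = c(-Inf, -Inf, 1.75),
  ymax = c(Inf, 1.75, Inf),
  pred = c("a", "b", "c"),
  stringsAsFactors = FALSE
)

# Hypothetical helper (not part of parttree): return the prediction of the
# rectangle containing the point (x, y). Lower bounds are inclusive and upper
# bounds exclusive, so a point on a shared edge matches exactly one rectangle.
which_partition = function(pt, x, y) {
  hit = pt$xmin <= x & x < pt$xmax & pt$ymin <= y & y < pt$ymax
  pt$pred[hit]
}

which_partition(pt, 1, 0.2)  # "a"
which_partition(pt, 5, 1)    # "b"
which_partition(pt, 5, 2)    # "c"
which_partition(pt, 2.5, 1)  # "b" (on the x = 2.5 edge)
```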
See also plot.parttree(), geom_parttree(), rpart::rpart(), and partykit::ctree().
library("parttree")
#
## rpart trees
library("rpart")
rp = rpart(Kyphosis ~ Start + Age, data = kyphosis)
# A parttree object is just a data frame with additional attributes
(rp_pt = parttree(rp))
attr(rp_pt, "parttree")
# simple plot
plot(rp_pt)
# removing the (recursive) partition borders helps to emphasise overall fit
plot(rp_pt, border = NA)
# customize further by passing extra options to (tiny)plot
plot(
  rp_pt,
  border  = NA,                                      # no partition borders
  pch     = 16,                                      # filled points
  alpha   = 0.6,                                     # point transparency
  grid    = TRUE,                                    # background grid
  palette = "classic",                               # new colour palette
  xlab    = "Topmost vertebra operated on",          # custom x title
  ylab    = "Patient age (months)",                  # custom y title
  main    = "Tree predictions: Kyphosis recurrence"  # custom title
)
#
## conditional inference trees from partykit
library("partykit")
ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris)
ct_pt = parttree(ct)
plot(ct_pt, pch = 19, palette = "okabe", main = "ctree predictions: iris species")
## rpart via partykit
rp2 = as.party(rp)
parttree(rp2)
#
## various front-end frameworks are also supported, e.g.
# tidymodels
# install.packages("parsnip")
library(parsnip)
decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Species ~ Petal.Length + Petal.Width, data=iris) |>
  parttree() |>
  plot(main = "This time brought to you via parsnip...")
# mlr3 (NB: use `keep_model = TRUE` for mlr3 learners)
# install.packages("mlr3")
library(mlr3)
task_iris = TaskClassif$new("iris", iris, target = "Species")
task_iris$select(c("Petal.Length", "Petal.Width")) # keep only these two features
fit_iris = lrn("classif.rpart", keep_model = TRUE) # NB!
fit_iris$train(task_iris)
plot(parttree(fit_iris), main = "... and now mlr3")
Provides a plot method for parttree objects.
## S3 method for class 'parttree'
plot(
  x,
  raw = TRUE,
  border = "black",
  fill_alpha = 0.3,
  expand = TRUE,
  jitter = FALSE,
  add = FALSE,
  ...
)
Argument | Description |
---|---|
x | A parttree data frame. |
raw | Logical. Should the raw (original) data points be plotted too? Default is TRUE. |
border | Colour of the partition borders (edges). Default is "black". To remove the borders altogether, specify as NA. |
fill_alpha | Numeric in the range [0, 1]. Alpha transparency of the filled partition rectangles. Default is 0.3. |
expand | Logical. Should the partition limits be expanded to meet the edge of the plot axes? Default is TRUE. |
jitter | Logical. Should the raw points be jittered? Default is FALSE. |
add | Logical. Add to an existing plot? Default is FALSE. |
... | Additional arguments passed down to tinyplot (e.g. pch, palette, main). |
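These arguments can be combined to emphasise either the partitions or the raw data. The following is a minimal sketch using the kyphosis tree from the parttree() examples. The particular values, such as fill_alpha = 0.5, are illustrative.

```r
library(parttree)
library(rpart)

rp = rpart(Kyphosis ~ Start + Age, data = kyphosis)
rp_pt = parttree(rp)

# Partitions only: no raw points, limits not stretched to the axes edges,
# and a more opaque fill than the 0.3 default
plot(rp_pt, raw = FALSE, expand = FALSE, fill_alpha = 0.5)

# Emphasise the data instead: jitter the raw points and drop the borders
plot(rp_pt, jitter = TRUE, border = NA)
```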
No return value; called for its side effect of producing a plot.
The parttree() examples above also demonstrate this plot method.