Title: | Visualise simple decision tree partitions |
---|---|
Description: | Simple functions for plotting 2D decision tree partition plots. |
Authors: | Grant McDermott [aut, cre], Achim Zeileis [ctb], Brian Heseung Kim [ctb], Julia Silge [ctb] |
Maintainer: | Grant McDermott |
License: | MIT + file LICENSE |
Version: | 0.0.1.9004 |
Built: | 2024-11-25 03:39:24 UTC |
Source: | https://github.com/grantmcdermott/parttree |
geom_parttree() is a simple extension of ggplot2::geom_rect() that first calls parttree() to convert the input tree object into a data frame of partition rectangles that geom_rect() can draw.
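To make the "extension of geom_rect()" idea concrete, here is a minimal sketch that performs the two steps by hand: call parttree() on an rpart tree, then pass the resulting rectangles to ggplot2::geom_rect(). It assumes, based on aes(fill = Species) in the Examples, that the prediction column is named after the response (Species here), and that the first split variable (Petal.Length) lands on the x axis, as in the orientation example. It illustrates the idea and is not the package's internal code.

```r
library(rpart)
library(ggplot2)
library(parttree)

iris_tree <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)

# Step 1: convert the tree into one rectangle (row) per leaf node
pt <- parttree(iris_tree)

# Step 2: draw the rectangles with geom_rect(). inherit.aes = FALSE stops the
# layer from inheriting the plot's x/y mapping, which the partition data lacks
p_manual <- ggplot(iris, aes(x = Petal.Length, y = Petal.Width)) +
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = Species),
    inherit.aes = FALSE,
    alpha = 0.1
  ) +
  geom_point(aes(col = Species))

p_manual
```

In everyday use, geom_parttree(data = iris_tree, aes(fill = Species), alpha = 0.1) should produce the same picture with less typing.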
geom_parttree( mapping = NULL, data = NULL, stat = "identity", position = "identity", linejoin = "mitre", na.rm = FALSE, show.legend = NA, inherit.aes = TRUE, flipaxes = FALSE, ... )
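Argument | Description |
---|---|
mapping | Set of aesthetic mappings created by ggplot2::aes(). If specified and inherit.aes = TRUE (the default), it is combined with the default mapping at the top level of the plot. |
data | An rpart::rpart.object or an object of compatible type (e.g. a decision tree constructed via the parsnip front-end, or a partykit tree; see parttree() for the supported classes). |
stat | The statistical transformation to use on the data for this layer, either as a ggproto Stat object or as a string naming the stat (e.g. "identity", the default). |
position | Position adjustment, either as a string naming the adjustment (e.g. "identity", the default, or "jitter" to use ggplot2::position_jitter()), or the result of a call to a position adjustment function. |
linejoin | Line join style (round, mitre, bevel). |
na.rm | If FALSE, the default, missing values are removed with a warning. If TRUE, missing values are silently removed. |
show.legend | Logical. Should this layer be included in the legends? NA, the default, includes it if any aesthetics are mapped; FALSE never includes it; TRUE always includes it. |
inherit.aes | If FALSE, overrides the default aesthetics, rather than combining with them. |
flipaxes | Logical. By default, the "x" and "y" axes variables for plotting are determined by the first split in the tree. This can cause plot orientation mismatches depending on how users specify the other layers of their plot. Setting to TRUE flips the "x" and "y" variables for the geom_parttree() layer. |
... | Other arguments passed on to ggplot2::layer(). These are often aesthetics set to a fixed value, such as col = NA or alpha = 0.1 (see Examples). |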
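Because the default orientation depends on the first split, it can help to check that variable before building the plot. A minimal sketch using only rpart: the root node is the first row of the tree's frame, and its var column names the first split variable. The helper needs_flip() is hypothetical (not part of parttree) and assumes, based on the orientation example in Examples, that the first split variable is drawn on the x axis by default.

```r
library(rpart)

iris_tree <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)

# Row 1 of the rpart frame is the root node; `var` records its split variable
# (stored as a factor, so convert to character before comparing)
first_split <- as.character(iris_tree$frame$var[1])
print(first_split)

# Hypothetical helper, not part of parttree: TRUE when the first split
# variable differs from the variable mapped to x in the existing plot,
# i.e. when flipaxes = TRUE is probably needed
needs_flip <- function(tree, x_var) {
  as.character(tree$frame$var[1]) != x_var
}

needs_flip(iris_tree, "Petal.Length")  # plot `p` in the Examples: FALSE
needs_flip(iris_tree, "Petal.Width")   # plot `p2` in the Examples: TRUE
```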
Because of the way that ggplot2 validates inputs and assembles plot layers, the data input for geom_parttree() (i.e. the decision tree object) must be assigned in the layer itself, not in the initialising ggplot2::ggplot() call. See Examples.
geom_parttree() aims to "work out of the box" with minimal input from the user, apart from specifying the data object. This includes handling the data transformation in a way that, generally, produces optimal corner coordinates for each partition (i.e. xmin, xmax, ymin, and ymax). However, it also understands the following aesthetics that users may choose to specify manually:

- fill (particularly encouraged, since this provides a visual cue regarding the prediction in each partition region)
- colour
- alpha
- linetype
- size
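A short sketch of combining a mapped fill with fixed settings for some of the other aesthetics listed above. The particular values ("grey40", "dashed", 0.2) are arbitrary choices for illustration; fixed values are passed through ... in the same way as col = NA and alpha = 0.1 in the Examples.

```r
library(rpart)
library(ggplot2)
library(parttree)

iris_tree <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)

p_styled <- ggplot(iris, aes(x = Petal.Length, y = Petal.Width)) +
  geom_parttree(
    data = iris_tree,
    aes(fill = Species),          # mapped: fill shows each partition's prediction
    colour = "grey40",            # fixed: border colour
    linetype = "dashed",          # fixed: border line type
    alpha = 0.2                   # fixed: fill transparency
  ) +
  geom_point(aes(col = Species))

p_styled
```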
See also: parttree(), ggplot2::geom_rect().
library(rpart)
library(ggplot2)
library(parttree)

### Simple decision tree (max of two predictor variables)
iris_tree = rpart(Species ~ Petal.Length + Petal.Width, data = iris)

## Plot with original iris data only
p = ggplot(data = iris, aes(x = Petal.Length, y = Petal.Width)) +
  geom_point(aes(col = Species))

## Add tree partitions to the plot (borders only)
p + geom_parttree(data = iris_tree)

## Better to use fill and highlight predictions
p + geom_parttree(data = iris_tree, aes(fill = Species), alpha = 0.1)

## To drop the black border lines (i.e. fill only)
p + geom_parttree(data = iris_tree, aes(fill = Species), col = NA, alpha = 0.1)

### Example with plot orientation mismatch
p2 = ggplot(iris, aes(x = Petal.Width, y = Petal.Length)) +
  geom_point(aes(col = Species))

## Oops: the partitions don't line up with the points
p2 + geom_parttree(data = iris_tree, aes(fill = Species), alpha = 0.1)

## Fix with 'flipaxes = TRUE'
p2 + geom_parttree(data = iris_tree, aes(fill = Species), alpha = 0.1, flipaxes = TRUE)

### Various front-end frameworks are also supported, e.g.:
library(parsnip)

iris_tree_parsnip = decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Species ~ Petal.Length + Petal.Width, data = iris)

p + geom_parttree(data = iris_tree_parsnip, aes(fill = Species), alpha = 0.1)

### Trees with a continuous dependent (response) variable, i.e. regression
### trees, are also supported. But you may need to adjust (or switch off) the
### fill legend to match the original data, e.g.:
iris_tree_cont = rpart(Petal.Length ~ Sepal.Length + Petal.Width, data = iris)

p3 = ggplot(data = iris, aes(x = Petal.Width, y = Sepal.Length)) +
  geom_parttree(data = iris_tree_cont, aes(fill = Petal.Length), alpha = 0.5) +
  geom_point(aes(col = Petal.Length)) +
  theme_minimal()

## Legend scales don't quite match here:
p3

## Better to scale fill to the original data
p3 + scale_fill_continuous(limits = range(iris$Petal.Length))
Extracts the terminal leaf nodes of a decision tree with one or two numeric predictor variables. These leaf nodes are then converted into a data frame, where each row represents a partition (or leaf or terminal node) that can easily be plotted in coordinate space.
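The idea behind this extraction can be sketched with rpart alone: every leaf is reached by a chain of split rules, and applying those rules one by one to an unbounded box yields that leaf's rectangle. This is an illustrative reimplementation under simplifying assumptions (numeric splits only, exactly two predictors, first split variable on the x axis as in the geom_parttree() Examples), not how parttree() is actually implemented. It relies on rpart::path.rpart() and parses its text labels, so cut points are only as precise as those printed labels; leaf_box() is a hypothetical helper.

```r
library(rpart)

fit <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)

# Leaves are the rows of fit$frame whose split variable is "<leaf>";
# the row names are rpart's node numbers
leaf_nodes <- as.integer(rownames(fit$frame)[fit$frame$var == "<leaf>"])

# Split rules from the root down to each leaf, e.g. "Petal.Length>=2.45"
paths <- path.rpart(fit, nodes = leaf_nodes, print.it = FALSE)

# Assumption: the first split variable goes on x, the other predictor on y
xvar <- as.character(fit$frame$var[1])

# Hypothetical helper: shrink an unbounded box by each rule on the path
leaf_box <- function(rules) {
  box <- c(xmin = -Inf, xmax = Inf, ymin = -Inf, ymax = Inf)
  for (rule in rules[-1]) {  # the first element is always "root"
    m <- regmatches(rule, regexec("^([^<>]+)(<|>=)[[:space:]]*(.+)$", rule))[[1]]
    axis <- if (trimws(m[2]) == xvar) "x" else "y"
    cut <- as.numeric(m[4])
    if (m[3] == "<") {
      # "var < cut": the leaf lies below the cut point
      box[[paste0(axis, "max")]] <- min(box[[paste0(axis, "max")]], cut)
    } else {
      # "var >= cut": the leaf lies at or above the cut point
      box[[paste0(axis, "min")]] <- max(box[[paste0(axis, "min")]], cut)
    }
  }
  box
}

boxes <- as.data.frame(do.call(rbind, lapply(paths, leaf_box)))
boxes$node <- leaf_nodes
boxes
```

Edges that no split constrains stay at -Inf/Inf; ggplot2::geom_rect() draws such edges out to the panel boundary.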
parttree(tree, keep_as_dt = FALSE, flipaxes = FALSE)
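Argument | Description |
---|---|
tree | A tree object. Supported classes include rpart::rpart.object, or compatible classes from supported front-ends such as parsnip, as well as partykit trees like those returned by partykit::ctree() or partykit::as.party() (see Examples). |
keep_as_dt | Logical. The function relies on data.table for internal data manipulation, but it coerces the final return object into a regular data frame (the default) unless the user specifies TRUE. |
flipaxes | Logical. The function will automatically set the y-axis variable as the first split variable in the tree provided unless the user specifies TRUE. |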
This function can be used with a regression or classification tree containing one or (at most) two numeric predictors.
A data frame comprising seven columns: the leaf node, its path, a set of coordinates understandable to ggplot2 (i.e., xmin, xmax, ymin, ymax), and a final column corresponding to the predicted value for that leaf.
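A quick way to see this structure is to inspect the output for a small rpart tree. The comments describe what the description above leads one to expect rather than verbatim printed output, since column order may differ between package versions; the prediction column being named after the response (Species) is an assumption inferred from aes(fill = Species) in the geom_parttree() Examples.

```r
library(rpart)
library(parttree)

rp <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)
pt <- parttree(rp)

# One row per leaf (terminal node); this tree has three leaves
nrow(pt)

# Column names: leaf node, path, the four corners, and the prediction
names(pt)

# With the default keep_as_dt = FALSE the result is a regular data frame,
# so the corner columns can be pulled out with ordinary indexing
pt[, c("xmin", "xmax", "ymin", "ymax")]
```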
See also: geom_parttree(), rpart::rpart(), partykit::ctree().
library("parttree")

## rpart trees
library("rpart")
rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris)
parttree(rp)

## conditional inference trees
library("partykit")
ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris)
parttree(ct)

## rpart via partykit
rp2 = as.party(rp)
parttree(rp2)