20 Extension Case Study: Springs, Part 1

In this chapter we will create our very first extension: a version of geom_segment() that, instead of drawing a straight line between two points, draws it as a stretched spring. By the end of it you will hopefully have gained a lot of insight into the process of designing a Stat extension, along with its benefits and problems. We will revisit this extension multiple times over the next chapters.

20.1 Thinking a bit before we code

Usually an extension starts with an idea about what we want to be able to draw. In this case the goal is given: draw a spring between two points. But even with such a clear objective there are plenty of unanswered questions, e.g. what should the interface look like? Many of the answers to these questions will be informed by iterations on the implementation, but it is a good idea to think a bit about which capabilities should be present in the layer. How is the diameter of the spring set? Is it supposed to be an aesthetic that can be scaled? What about the spring tension? Is that a scaled aesthetic as well? Or are both simple values set for all springs in the layer?

Another thing we need to start thinking about is how one even draws a spring. I guess there are many ways, but I would naively think that we trace a circle while also moving the “pen” in one direction:

library(ggplot2)
library(tibble)

circle <- tibble(
  x = sin(seq(0, 2*pi, length.out = 100)),
  y = cos(seq(0, 2*pi, length.out = 100)),
  index = 1:100,
  type = "circle"
)
spring <- circle
spring$x <- spring$x + seq(0, 1.5, length.out = 100)
spring$type <- "spring"
ggplot(rbind(circle, spring)) + 
  geom_path(aes(x = x, y = y, group = type, alpha = index), show.legend = FALSE) + 
  facet_wrap(~ type, scales = "free_x")

It is clear that simply continuing to trace the circle while moving along x will make the spring longer, and that the speed of the x-movement will control the tension of the spring. While I’m pretty sure this is not a physically correct parameterisation of a spring, it is good enough for our illustrative purpose.
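To make the link between the speed of the x-movement and the tension concrete, the following minimal sketch traces a single revolution at two different speeds. The helper trace_spring() and its drift argument are names invented for this illustration; they are not part of ggplot2 or of the extension we build later.

```r
# Hypothetical helper: trace one revolution of a unit circle while the
# "pen" drifts `drift` units along x
trace_spring <- function(drift, n = 100) {
  angle <- seq(0, 2 * pi, length.out = n)
  data.frame(
    x = sin(angle) + seq(0, drift, length.out = n),
    y = cos(angle)
  )
}

loose <- trace_spring(drift = 3)
tight <- trace_spring(drift = 0.5)

# Horizontal distance covered by one revolution in each case
loose$x[100] - loose$x[1]
tight$x[100] - tight$x[1]
```

The y values are identical in both cases; only the distance covered per revolution changes, which is what we will call tension.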

20.2 Choosing an approach

When developing a new layer, one has a choice between developing a Stat or a Geom extension. The decision is not guided by whether you want to end up with a geom_spring() constructor, as plenty of Stat extensions are used along with a dedicated geom_*() constructor. Instead you should consider what you are actually doing. Are we drawing using some pre-existing primitive but interpreting input data in a new way? Yes! We are drawing with a path but simply circling around instead of going straight. This is a clear case for a Stat, which is great for transforming input data into something that can be drawn with a pre-existing geom. Further, Stats are much easier to extend than Geoms as they are simply data-transformation pipelines. Thus, in the following we will work towards creating a new Stat for drawing springs.

20.3 Building functionality

When developing a new Stat it often helps immensely to create the data transformation function separately. It allows you to focus on the aspects of ggproto only when you really need to, thus freeing some mental space for actual problem solving. Our data transformation should take a start and end point, a diameter, and a tension. We will define tension to mean “times of diameter moved per revolution”, thus 0 will mean that it doesn’t move at all, and will be forbidden as it would not allow our spring to extend between two points. Further, we will add a parameter n giving the number of points to use per revolution, thus defining the visual fidelity of the spring.

create_spring <- function(x, y, xend, yend, diameter, tension, n) {
  if (tension <= 0) {
    rlang::abort("`tension` must be larger than 0")
  }
  # Calculate direct length of segment
  length <- sqrt((x - xend)^2 + (y - yend)^2)
  
  # Figure out how many revolutions and points we need
  n_revolutions <- length / (diameter * tension)
  n_points <- n * n_revolutions
  
  # Calculate sequence of radians and x and y offset
  radians <- seq(0, n_revolutions * 2 * pi, length.out = n_points)
  x <- seq(x, xend, length.out = n_points)
  y <- seq(y, yend, length.out = n_points)
  
  # Create the new data
  data.frame(
    x = cos(radians) * diameter/2 + x,
    y = sin(radians) * diameter/2 + y
  )
}
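The arithmetic in create_spring() can be worked through by hand for a single segment. The sketch below only repeats the calculations from the function for one set of illustrative inputs; it does not call create_spring() itself, and the variable names seg_length and advance_per_revolution are chosen here for readability.

```r
# Illustrative inputs: a segment from (4, 2) to (10, 6)
x <- 4; y <- 2; xend <- 10; yend <- 6
diameter <- 2
tension <- 0.75
n <- 50

# Straight-line length of the segment
seg_length <- sqrt((x - xend)^2 + (y - yend)^2)

# Each revolution advances diameter * tension along the segment
advance_per_revolution <- diameter * tension
n_revolutions <- seg_length / advance_per_revolution

# n points per revolution; seq() rounds a fractional length.out up
n_points <- n * n_revolutions
radians <- seq(0, n_revolutions * 2 * pi, length.out = n_points)

c(length = seg_length, revolutions = n_revolutions, points = length(radians))
```

Halving the tension halves the advance per revolution, so the same segment needs twice as many revolutions and roughly twice as many points.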

One nice thing about encapsulating the functionality in a separate function is that we can immediately test it out and convince ourselves that the logic works:

spring <- create_spring(
  x = 4, y = 2, xend = 10, yend = 6,
  diameter = 2, tension = 0.75, n = 50
)

ggplot(spring) + 
  geom_path(aes(x = x, y = y))

With the confidence we get from this, we are ready to encapsulate it all in a new Stat. We’ll create the Stat in one go, as this is how it is always defined, and then go over the individual parts afterwards:

StatSpring <- ggproto("StatSpring", Stat, 
  setup_params = function(data, params) {
    if (is.null(params$diameter)) {
      params$diameter <- 1
    } else if (params$diameter == 0) {
      rlang::abort("Springs cannot be defined with a diameter of 0")
    }
    if (is.null(params$tension)) {
      params$tension <- 0.75
    } else if (params$tension <= 0) {
      rlang::abort("Springs must be defined with a tension greater than 0")
    }
    if (is.null(params$n)) {
      params$n <- 50
    } else if (params$n <= 0) {
      rlang::abort("Springs must be defined with `n` greater than 0")
    }
    params
  },
  setup_data = function(data, params) {
    if (anyDuplicated(data$group)) {
      data$group <- paste(data$group, seq_len(nrow(data)), sep = "-")
    }
    data
  },
  compute_panel = function(data, scales, diameter = 1, tension = 0.75, 
                           n = 50) {
    cols_to_keep <- setdiff(names(data), c("x", "y", "xend", "yend"))
    springs <- lapply(seq_len(nrow(data)), function(i) {
      spring_path <- create_spring(data$x[i], data$y[i], data$xend[i], 
                                   data$yend[i], diameter, tension, n)
      cbind(spring_path, unclass(data[i, cols_to_keep]))
    })
    do.call(rbind, springs)
  },
  required_aes = c("x", "y", "xend", "yend")
)

Let’s take the above code piece by piece. We start with the class definition:

StatSpring <- ggproto("StatSpring", Stat, 
  ...
)

This creates a new Stat subclass, named StatSpring. ggproto classes always use CamelCase for naming, and the new class is always saved into a variable of the same name. If you are creating a special version of an existing Stat it is fully allowed to subclass that instead of the Stat base class, but this is more often found in Geom, Coord, and Facet subclasses.

Inside the class definition we provide a range of methods, simply by assigning functions to the argument names of interest. While it is allowed to dream up new methods, it is most common to simply put all functionality inside the definition of already existing methods. The Stat base class has a range of methods open for overwriting, and some you should steer clear of:

print(Stat)
#> <ggproto object: Class Stat, gg>
#>     aesthetics: function
#>     compute_group: function
#>     compute_layer: function
#>     compute_panel: function
#>     default_aes: uneval
#>     extra_params: na.rm
#>     finish_layer: function
#>     non_missing_aes: 
#>     parameters: function
#>     required_aes: 
#>     retransform: TRUE
#>     setup_data: function
#>     setup_params: function

The aesthetics and parameters methods are not meant for overwriting, but the other methods are all open during subclassing. As discussed in the previous chapters, the most important parts are the three compute_* methods. One of these must always be defined (usually the group or panel version). Two other methods that are very commonly overwritten are the setup_* methods, which allow the class to do early checks and modifications of the parameters and data of the stat.

In our implementation we provide a setup_params() method with the purpose of making sure that all the required parameters are present and that they have supported values assigned. This is not strictly necessary as the default values in our compute_*() method will make sure that lack of defined parameters doesn’t cause problems. Still, it is nice to be explicit, especially with checking the values of the parameters as it is possible to create much more meaningful error messages at this point.
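The effect of these checks can be tried outside of ggproto. The sketch below is a stand-alone imitation of part of setup_params(): the function name check_spring_params() is invented for this illustration, and base R's stop() is swapped in for rlang::abort() so that the snippet has no dependencies.

```r
# Hypothetical stand-alone version of the checks in setup_params(),
# using base R's stop() instead of rlang::abort()
check_spring_params <- function(params) {
  if (is.null(params$diameter)) {
    params$diameter <- 1
  } else if (params$diameter == 0) {
    stop("Springs cannot be defined with a diameter of 0", call. = FALSE)
  }
  if (is.null(params$tension)) {
    params$tension <- 0.75
  } else if (params$tension <= 0) {
    stop("Springs must be defined with a tension greater than 0", call. = FALSE)
  }
  params
}

# Missing values are filled in with the defaults
filled <- check_spring_params(list(diameter = 2))

# An unsupported value fails early with a targeted message
msg <- tryCatch(
  check_spring_params(list(tension = -1)),
  error = function(e) conditionMessage(e)
)
```

Without such a check, a tension of -1 would only fail later inside create_spring(), once for every panel being computed.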

We also implement a setup_data() method. We do this to make sure that each row in the input data has a unique group aesthetic assigned. ggplot2 will always create a group aesthetic for us if none is present, but it will depend on the values and types of the other aesthetics. Since our stat will expand each row into multiple rows that should be drawn as separate paths, we need each row to have a unique group. Instead of simply assigning a new unique number to the group column, we paste it together with the old value. This is good practice as the group aesthetic is sometimes used to carry information.
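The group handling can be seen in isolation on a small data frame. This is a toy example with made-up values; only the if block is taken from setup_data() above.

```r
# Toy layer data: two rows share group 1, so their springs would be joined
# into a single path if nothing was done
data <- data.frame(
  x = c(0, 1, 2), y = c(0, 1, 0),
  xend = c(1, 2, 3), yend = c(1, 0, 1),
  group = c(1, 1, 2)
)

# Same logic as setup_data(): keep the old group value, append the row number
if (anyDuplicated(data$group)) {
  data$group <- paste(data$group, seq_len(nrow(data)), sep = "-")
}

data$group
#> [1] "1-1" "1-2" "2-3"
```

The original group value remains recoverable as the part before the dash.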

For the actual computation we choose to define the compute_panel() method. This method will receive the data for each panel (plotting area) and do the transformation on that. Our stat handles each row individually, so if we instead implemented a compute_group() method our data would be split up into individual rows, a rather costly operation with very little benefit for us. As a general rule of thumb, if the stat operates on multiple rows, start by implementing a compute_group() method, and if the stat operates on single rows, implement a compute_panel() method.
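The cost of the group-wise route can be hinted at with base R alone. In the sketch below split() is used as a rough stand-in for group-wise splitting; ggplot2's internals may differ in detail, and the data is a toy example.

```r
# Toy data after setup_data(): every row has its own group
data <- data.frame(
  x = c(0, 1, 2), y = c(0, 1, 0),
  group = c("1-1", "1-2", "2-3")
)

# Rough stand-in for group-wise splitting
pieces <- split(data, data$group)

length(pieces)                    # one piece per row
vapply(pieces, nrow, integer(1))  # each piece holds a single row
```

With unique groups a group-wise method would be called once per row, each time on a one-row data frame, which is exactly the looping we do ourselves inside compute_panel() without the splitting overhead.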

Inside our compute_panel() method we do a bit more than simply call our create_spring() function.

StatSpring$compute_panel
#> <ggproto method>
#>   <Wrapper function>
#>     function (...) 
#> f(...)
#> 
#>   <Inner function (f)>
#>     function(data, scales, diameter = 1, tension = 0.75, 
#>                            n = 50) {
#>     cols_to_keep <- setdiff(names(data), c("x", "y", "xend", "yend"))
#>     springs <- lapply(seq_len(nrow(data)), function(i) {
#>       spring_path <- create_spring(data$x[i], data$y[i], data$xend[i], 
#>                                    data$yend[i], diameter, tension, n)
#>       cbind(spring_path, unclass(data[i, cols_to_keep]))
#>     })
#>     do.call(rbind, springs)
#>   }

We loop over each row of the data and create the points required to draw the spring. Then we combine our new data with all the non-position columns of the row. This is very important, since otherwise the aesthetic mappings to e.g. color and size would be lost. In the end we combine the individual springs into a single data frame that gets returned.
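How the non-position columns end up on every point of a spring can be checked with two small stand-in data frames. The values and the name layer_row are made up for this illustration; the cbind() call follows the same pattern as in compute_panel().

```r
# Toy stand-ins: three points of a spring path and one row of layer data
spring_path <- data.frame(x = c(0, 1, 2), y = c(0, 1, 0))
layer_row <- data.frame(colour = "red", size = 2, group = "1-1")

# Same pattern as in compute_panel(): the single-row values are recycled
# so every point of the spring carries the row's aesthetics
expanded <- cbind(spring_path, unclass(layer_row))

expanded
```

Each of the three points now has the same colour, size, and group, so the geom can draw them as one path with the mapped aesthetics intact.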

The last remaining part of our new class is the required_aes field. This is simply a character vector giving the names of aesthetics that the user must provide for the stat. The required_aes, along with default_aes and non_missing_aes, also define the aesthetics that this stat understands. Together with their counterparts in the geom, these are the aesthetics that a layer can work with; any mapping to an aesthetic not mentioned will cause a warning, and the mapping will be removed.

20.4 Constructors

Users never really see the ggproto objects (unless they go looking for them), since they are abstracted away into the well-known constructor functions that make up the ggplot2 API. Having created our stat, we should probably also create a constructor.

It should be noted that a constructor is not strictly needed. We currently have everything needed to call geom_path(stat = "spring"), but if we don’t provide a constructor there will be no place to document our new functionality.

While this is a stat, we will first and foremost create a geom_*() constructor. This is because most people are accustomed to adding geoms as opposed to stats when building up their plot. The constructor is mainly boilerplate code. Care should be taken to match the argument order and naming used in the geom constructors provided by ggplot2 itself since users have an internalized expectation about how they work.

geom_spring <- function(mapping = NULL, data = NULL, stat = "spring", 
                        position = "identity", ..., diameter = 1, tension = 0.75, 
                        n = 50, arrow = NULL, lineend = "butt", linejoin = "round", 
                        na.rm = FALSE, show.legend = NA, inherit.aes = TRUE) {
  layer(
    data = data, 
    mapping = mapping, 
    stat = stat, 
    geom = GeomPath, 
    position = position, 
    show.legend = show.legend, 
    inherit.aes = inherit.aes, 
    params = list(
      diameter = diameter, 
      tension = tension, 
      n = n, 
      arrow = arrow, 
      lineend = lineend, 
      linejoin = linejoin, 
      na.rm = na.rm, 
      ...
    )
  )
}

As can be seen, the constructor is a simple wrapper around layer() setting sensible defaults and funneling layer parameters into the correct location.

Having everything in place, we can now test out our new layer for the first time.

some_data <- tibble(
  x = runif(5, max = 10),
  y = runif(5, max = 10),
  xend = runif(5, max = 10),
  yend = runif(5, max = 10),
  class = sample(letters[1:2], 5, replace = TRUE)
)

ggplot(some_data) + 
  geom_spring(aes(x = x, y = y, xend = xend, yend = yend))

Having implemented our layer, we can now benefit from all the automatic features normally associated with ggplot2, such as scaling and faceting:

ggplot(some_data) + 
  geom_spring(aes(x = x, y = y, xend = xend, yend = yend, colour = class)) + 
  facet_wrap(~ class)

For completeness we will also create a stat constructor. The stat_spring() function looks almost identical, except that it fixes the stat and lets the user choose the geom instead. Further, a stat constructor usually does not provide named arguments for the parameters of its default geom (here arrow, lineend, and linejoin):

stat_spring <- function(mapping = NULL, data = NULL, geom = "path", 
                        position = "identity", ..., diameter = 1, tension = 0.75, 
                        n = 50, na.rm = FALSE, show.legend = NA, 
                        inherit.aes = TRUE) {
  layer(
    data = data, 
    mapping = mapping, 
    stat = StatSpring, 
    geom = geom, 
    position = position, 
    show.legend = show.legend, 
    inherit.aes = inherit.aes, 
    params = list(
      diameter = diameter, 
      tension = tension, 
      n = n, 
      na.rm = na.rm, 
      ...
    )
  )
}

We will test this by drawing our springs with dots instead:

ggplot(some_data) + 
  stat_spring(aes(x = x, y = y, xend = xend, yend = yend, colour = class),
              geom = 'point', n = 15) + 
  facet_wrap(~ class)

20.5 Post-Mortem

We have now successfully created our first extension. The first implementation is almost never the end of it, though, as new insights appear during development and make you rethink your approach. One shortcoming of our implementation is that diameter and tension can only be set for the full layer. These settings feel more like aesthetics and it would be nice if their values could be mapped to a variable in the data. We will revisit this point in the next chapter, where we rewrite our Stat to support this.

Another, potentially bigger, issue is that the spring path is relative to the coordinate system of the plot. This means that strong deviations from an aspect ratio of 1 will visibly distort the spring, as can be seen in the example below:

ggplot() + 
  geom_spring(aes(x = 0, y = 0, xend = 3, yend = 20))

The same underlying problem means that the diameter is expressed in coordinate space, meaning that it is difficult to define a meaningful default:

ggplot() + 
  geom_spring(aes(x = 0, y = 0, xend = 100, yend = 80))

This is one of the biggest downsides (or features, depending on your need) of shapes and line geoms defined as stat expansions. We will look into fixing this in a later chapter, by creating a new geom from the ground up.