Victor Farazdagi

Computer science and applied mathematics

07 Apr 2024

Typestate pattern in Rust

1 Introduction

Rust’s type system is quite sophisticated, allowing for a plethora of useful idioms.

One such idiom is using FSMs (finite state machines) to model an object’s behavior: the object starts in an initial state and then, via a transition function (which may take extra arguments besides the input state), moves to another state, and so on. The destination state is not arbitrary: some transitions are valid and others are not.

The beauty of Rust is that it allows us to encode such FSMs directly into the type system, in the so-called typestate pattern: a pattern where an object’s run-time state is encoded in its type, and is thus enforced at compile time.

Basically, you make it impossible to have an object in an invalid state, or to transition to an arbitrary state without first meeting the required conditions. In addition, other useful properties, such as data shared between states or an object’s capabilities in a given state, are also enforced by the type system.

This pattern is such a beautiful and simple idea that chances are you’ve already used it.

Any time you move an object into a function, where it is consumed and another object is returned, you are using a variation of the typestate pattern: you cannot pass an object of the wrong type into that transition function, nor return an object of a type other than the one in the function’s signature, and the consumed object can no longer be used afterwards. So, transition invariants are enforced at compile time.
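To make this concrete, here is a minimal sketch with two hypothetical types, Draft and Published (they are not part of the article’s example). publish() takes self by value, so the draft is moved into the function and the compiler rejects any later use of it:

```rust
/// A hypothetical document that has not been published yet.
pub struct Draft {
    text: String,
}

/// The same document after publication; only this type exposes text().
pub struct Published {
    text: String,
}

impl Draft {
    pub fn new(text: &str) -> Self {
        Draft { text: text.to_string() }
    }

    /// Consumes the draft (self by value) and returns a value of a different type.
    pub fn publish(self) -> Published {
        Published { text: self.text }
    }
}

impl Published {
    pub fn text(&self) -> &str {
        &self.text
    }
}

pub fn demo() -> String {
    let draft = Draft::new("hello");
    let published = draft.publish();
    // draft.publish(); // would not compile: `draft` was moved by the call above
    published.text().to_string()
}
```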

Builder pattern

To investigate the idea further, consider the builder pattern: you start with an object, configure it by calling various methods and, when you are done, call a `build` method to get the final object. This is also a variation of the `typestate` pattern: you cannot transform the builder into an invalid object, and you cannot obtain the final object directly (it has no public constructor), only via the builder’s transition methods.
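A sketch of a builder whose states are types, using hypothetical names (ConnectionBuilder, NoHost, HasHost, Connection). The host is assumed to be mandatory, so build() exists only on ConnectionBuilder<HasHost>; the optional port can be set in either state:

```rust
/// Builder states (hypothetical). HasHost carries the value it guarantees.
pub struct NoHost;
pub struct HasHost(String);

/// The final object: no public constructor, so the builder is the only way in.
pub struct Connection {
    host: String,
    port: u16,
}

impl Connection {
    pub fn address(&self) -> String {
        format!("{}:{}", self.host, self.port)
    }
}

pub struct ConnectionBuilder<H> {
    host: H,
    port: u16,
}

impl ConnectionBuilder<NoHost> {
    pub fn new() -> Self {
        ConnectionBuilder { host: NoHost, port: 80 }
    }

    /// Setting the host moves the builder from NoHost to HasHost.
    pub fn host(self, host: &str) -> ConnectionBuilder<HasHost> {
        ConnectionBuilder { host: HasHost(host.to_string()), port: self.port }
    }
}

impl<H> ConnectionBuilder<H> {
    /// Optional setting, available in every state; the state is unchanged.
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }
}

impl ConnectionBuilder<HasHost> {
    /// build() exists only once a host has been provided.
    pub fn build(self) -> Connection {
        Connection { host: self.host.0, port: self.port }
    }
}

pub fn demo() -> String {
    // ConnectionBuilder::new().build(); // would not compile: no build() on ConnectionBuilder<NoHost>
    let conn = ConnectionBuilder::new().port(5432).host("db.local").build();
    conn.address()
}
```

Calling build() before host() is a compile-time error rather than a run-time panic, which is the same guarantee the rest of this article builds on.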

2 Illustrative example

Let’s consider a simple example:

  • We are writing a distributed database, which for our purposes is just a bunch of nodes redundantly storing some data and communicating with each other to service users’ requests.
  • Since we are writing a robust, fault-tolerant system, we need to handle nodes leaving and joining the cluster carefully: for instance, before a recently joined node can start serving requests, it must first, say, sync data with its peers. So some preliminary work has to be performed before the node is marked as Running.

Clearly, a node is always in one of a set of predefined states; nodes in different states perform different functions and thus have different capabilities (methods); and there are invariants governing transitions from one state to another. So managing nodes’ states and transitions in our cluster with an FSM seems like a good idea.

Let’s assume that we have the following states for our nodes:

| State   | Description                                               | Next state      |
|---------|-----------------------------------------------------------|-----------------|
| New     | Node has just been discovered and is joining the cluster. | Syncing, Failed |
| Syncing | Node is syncing data with its peers.                      | Running, Failed |
| Running | Node is ready to serve requests.                          | Leaving, Failed |
| Leaving | Node is leaving the cluster.                              | Removed, Failed |
| Failed  | Node has failed and needs to be removed.                  | Removed         |
| Removed | Node is removed from the cluster.                         | -               |
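For contrast, this is how the same table might be encoded at run time with a plain enum (a sketch, not the article’s code). Nothing prevents a caller from requesting an invalid transition, so the check can only produce an error when the program runs; the typestate approach in the following sections moves it into the compiler.

```rust
/// Run-time encoding of the transition table, for comparison only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    New,
    Syncing,
    Running,
    Leaving,
    Failed,
    Removed,
}

impl State {
    /// The "Next state" column of the table.
    pub fn next_states(self) -> &'static [State] {
        use State::*;
        match self {
            New => &[Syncing, Failed],
            Syncing => &[Running, Failed],
            Running => &[Leaving, Failed],
            Leaving => &[Removed, Failed],
            Failed => &[Removed],
            Removed => &[],
        }
    }

    /// An invalid transition is an Err at run time, not a compile error.
    pub fn transition(self, to: State) -> Result<State, String> {
        if self.next_states().contains(&to) {
            Ok(to)
        } else {
            Err(format!("invalid transition {:?} -> {:?}", self, to))
        }
    }
}
```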

2.1 Defining states

Note: all the code for this and the following sections is available on GitHub.

States should be defined as structs, where each state is a separate struct with its own methods and, when necessary, data. If something looks more like an enum, try splitting it into several struct states.

So, let’s define our states:

pub mod states {
    use super::NodeState;

    pub struct New {}
    pub struct Syncing {}
    pub struct Running {}
    pub struct Leaving {}
    pub struct Failed {}
    pub struct Removed {}

    impl NodeState for New {}
    impl NodeState for Syncing {}
    impl NodeState for Running {}
    impl NodeState for Leaving {}
    impl NodeState for Failed {}
    impl NodeState for Removed {}
}

pub trait NodeState {}

Everything is pretty straightforward: we define a states module containing one struct per state, plus a NodeState marker trait that marks a struct as a state.

Now, when defining a Node, we will pass a state as a type parameter:

pub struct Node<S: NodeState> {
    _marker: std::marker::PhantomData<S>,
}
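Why PhantomData? Node does not store a value of type S yet, and Rust rejects a type parameter that is never used (error E0392); PhantomData<S> "uses" S without storing anything. The following self-contained check, which re-declares a trimmed-down version of the types above, confirms that the state costs no memory:

```rust
use std::marker::PhantomData;

pub trait NodeState {}
pub struct New {}
pub struct Running {}
impl NodeState for New {}
impl NodeState for Running {}

pub struct Node<S: NodeState> {
    _marker: PhantomData<S>,
}

/// PhantomData<S> is zero-sized, so the state exists only in the type:
/// Node<New> and Node<Running> are distinct types, both of size 0.
pub fn node_sizes() -> (usize, usize) {
    (
        std::mem::size_of::<Node<New>>(),
        std::mem::size_of::<Node<Running>>(),
    )
}
```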

2.2 Enforcing state type

So, we have a Node structure that takes a state as a type parameter. This means that end-users cannot create a Node parameterized by anything other than a type that implements the NodeState trait.

An additional requirement might be that end-users must not be able to extend the list of available states. This is easily achieved by making NodeState a sealed trait.

To seal the NodeState trait, we can define it as:

pub trait NodeState: private::Sealed {}

mod private {
    use super::states::*;

    pub trait Sealed {}

    impl Sealed for New {}
    impl Sealed for Syncing {}
    impl Sealed for Running {}
    impl Sealed for Leaving {}
    impl Sealed for Failed {}
    impl Sealed for Removed {}
}

Since the Sealed trait is defined in a private module, end-users cannot name it, and thus they cannot implement NodeState (which requires Sealed) for an arbitrary struct.

Invariants:

  • All states must implement NodeState trait.
  • The list of states is closed to extension by end-users (optional, when necessary).

2.3 Enforcing initial state

Note that because Node has a private field, _marker, code outside the module that defines Node cannot create a Node without going through a constructor function. That is, the following will fail to compile:

    // None of these compile: outside the defining module,
    // the private field _marker cannot be set.
    let node = Node::<states::New> {};
    let node = Node::<states::Syncing> {};
    let node = Node::<states::Running> {};
    // ...

And that’s a good thing: the initial node state can be enforced.

Let’s create a node in a valid initial state:

impl Node<states::New> {
    pub fn new() -> Self {
        Node {
            _marker: std::marker::PhantomData,
        }
    }
}

Since this impl block fixes the type parameter, new() can only be called on Node<states::New>, so it can only produce a node in a valid initial state: New.

Invariants:

  • By having a private field in Node (_marker for now; later on we will replace it with the state’s data), we enforce that code outside the module can only create a node via a constructor function.
  • By defining the constructor function for a specific state only, we enforce that a node can only be created in a valid initial state.
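The privacy rules above only take effect across a module boundary, so here is a self-contained sketch that puts the pieces into a node module (the module layout and the reduced set of states are assumptions made for brevity). Code outside node can neither build a Node with a struct literal nor implement NodeState for its own types:

```rust
pub mod node {
    use std::marker::PhantomData;

    pub mod states {
        pub struct New {}
        pub struct Syncing {}
    }

    /// Sealed: the supertrait lives in a private module, so it cannot be named outside `node`.
    pub trait NodeState: private::Sealed {}
    impl NodeState for states::New {}
    impl NodeState for states::Syncing {}

    mod private {
        pub trait Sealed {}
        impl Sealed for super::states::New {}
        impl Sealed for super::states::Syncing {}
    }

    pub struct Node<S: NodeState> {
        _marker: PhantomData<S>, // private: no struct literals outside `node`
    }

    impl Node<states::New> {
        /// The only public way to obtain a Node, and it always starts in New.
        pub fn new() -> Self {
            Node { _marker: PhantomData }
        }
    }
}

use node::{states, Node};

pub fn demo() -> Node<states::New> {
    // Node::<states::Syncing> { _marker: std::marker::PhantomData }; // would not compile: private field
    // impl node::NodeState for String {} // would not compile: String does not implement Sealed
    Node::new()
}
```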

2.4 State transitions

One of the most important aspects of the typestate pattern is its ability to enforce valid state transitions: a node should not be able to transition to some unexpected state. This is done by defining transition methods only on the source states; each method takes self by value (consuming the source node) and returns a node in a valid destination state.

For example, if from the New state you can transition to Syncing, and from Running to Leaving, you should define separate transition methods on Node<New> and Node<Running>, so that both source and destination states are controlled:

use std::error::Error;

type NodeResult<T> = Result<T, Box<dyn Error>>;

impl Node<states::New> {
    /// Start the node.
    pub fn start(self) -> NodeResult<Node<states::Syncing>> {
        Ok(Node {
            _marker: std::marker::PhantomData,
        })
    }
}

impl Node<states::Running> {
    /// Stop the node.
    pub fn stop(self) -> NodeResult<Node<states::Leaving>> {
        Ok(Node {
            _marker: std::marker::PhantomData,
        })
    }
}
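Putting the transitions together, here is a runnable sketch of the happy path from the table. The method names synced() and remove() are hypothetical (the article only names start() and stop()), Failed is left out for brevity, and the private helper to() is an assumption used to avoid repeating the struct literal:

```rust
use std::error::Error;
use std::marker::PhantomData;

type NodeResult<T> = Result<T, Box<dyn Error>>;

pub mod states {
    pub struct New {}
    pub struct Syncing {}
    pub struct Running {}
    pub struct Leaving {}
    pub struct Removed {}
}

pub trait NodeState {}
impl NodeState for states::New {}
impl NodeState for states::Syncing {}
impl NodeState for states::Running {}
impl NodeState for states::Leaving {}
impl NodeState for states::Removed {}

pub struct Node<S: NodeState> {
    _marker: PhantomData<S>,
}

/// Hypothetical private helper: builds a node in any state; only the transitions use it.
fn to<S: NodeState>() -> Node<S> {
    Node { _marker: PhantomData }
}

impl Node<states::New> {
    pub fn new() -> Self { to() }
    pub fn start(self) -> NodeResult<Node<states::Syncing>> { Ok(to()) }
}

impl Node<states::Syncing> {
    /// Hypothetical name for the Syncing -> Running transition.
    pub fn synced(self) -> NodeResult<Node<states::Running>> { Ok(to()) }
}

impl Node<states::Running> {
    pub fn stop(self) -> NodeResult<Node<states::Leaving>> { Ok(to()) }
}

impl Node<states::Leaving> {
    /// Hypothetical name for the Leaving -> Removed transition.
    pub fn remove(self) -> Node<states::Removed> { to() }
}

/// Walks New -> Syncing -> Running -> Leaving -> Removed.
pub fn lifecycle() -> NodeResult<Node<states::Removed>> {
    // Node::new().stop(); // would not compile: no stop() on Node<New>
    let node = Node::new().start()?.synced()?.stop()?;
    Ok(node.remove())
}
```

Because each method consumes self, the old node cannot be used after a transition, and the ? operator stops the chain at the first transition that returns an error.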

2.5 Multiple destination states

In most situations, a single destination state per transition function (given its input state and arguments) is enough:

// InputState, DestState1, DestState2 and SomeType are placeholders.
impl Node<InputState> {
    pub fn transition_fn1(self, input_data1: SomeType) -> NodeResult<Node<DestState1>> {
        // ...
    }

    pub fn transition_fn2(self) -> NodeResult<Node<DestState2>> {
        // ...
    }
}

Since the return type is a Result, a transition can return an error instead of a node when it fails.

However, when a single transition function can return nodes in different destination states (say, depending on its input arguments), you can use the Either sum type idiom (see either::Either if you need something more sophisticated than a raw enum; that crate defines many useful macros and methods):

pub enum Either<L, R> {
    Left(L),
    Right(R),
}

impl Node<InputState> {
    pub fn transition_fn(self, arg: bool) -> Either<Node<DestState1>, Node<DestState2>> {
        if arg {
            // Left carries a node in DestState1...
            Either::Left(Node {
                _marker: std::marker::PhantomData,
            })
        } else {
            // ...and Right carries a node in DestState2.
            Either::Right(Node {
                _marker: std::marker::PhantomData,
            })
        }
    }
}
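A concrete sketch of the Either idiom with real states: a hypothetical finish_sync() method moves a Syncing node either to Running or to Failed. The method name and its bool argument (standing in for the real sync result) are assumptions, and for brevity the node is constructed directly in the Syncing state:

```rust
use std::marker::PhantomData;

/// The raw sum type from the text above.
pub enum Either<L, R> {
    Left(L),
    Right(R),
}

pub mod states {
    pub struct Syncing {}
    pub struct Running {}
    pub struct Failed {}
}

pub trait NodeState {}
impl NodeState for states::Syncing {}
impl NodeState for states::Running {}
impl NodeState for states::Failed {}

pub struct Node<S: NodeState> {
    _marker: PhantomData<S>,
}

impl Node<states::Syncing> {
    pub fn new() -> Self {
        Node { _marker: PhantomData }
    }

    /// Hypothetical transition: `peers_ok` stands in for the real sync outcome.
    pub fn finish_sync(self, peers_ok: bool) -> Either<Node<states::Running>, Node<states::Failed>> {
        if peers_ok {
            Either::Left(Node { _marker: PhantomData })
        } else {
            Either::Right(Node { _marker: PhantomData })
        }
    }
}

/// The caller must handle both outcomes; each arm sees a differently typed node.
pub fn describe(peers_ok: bool) -> &'static str {
    match Node::new().finish_sync(peers_ok) {
        Either::Left(_running) => "running",
        Either::Right(_failed) => "failed",
    }
}
```

Inside each match arm the node has a precise type, so only the methods valid for that state can be called on it.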

2.6 Shared functionality

Now, it is almost trivial to expose some shared functionality for all states:

impl<S: NodeState> Node<S> {
    pub fn fail(self) -> Node<states::Failed> {
        Node {
            _marker: std::marker::PhantomData,
        }
    }
}

So, we just define an implementation for a generic state S, and voila, all states can fail now.

What if we want to expose some functionality only for a specific subset of states? For instance, fail() probably should not be defined on the Failed state, as the node is already in the desired state.

Easy:

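/// Selects the states that are active.
pub trait NodeStateActive: NodeState {}

impl NodeStateActive for states::New {}
impl NodeStateActive for states::Syncing {}
impl NodeStateActive for states::Running {}
impl NodeStateActive for states::Leaving {}

// Replaces the impl<S: NodeState> block above: Failed and Removed get no fail().
impl<S: NodeStateActive> Node<S> {
    pub fn fail(self) -> Node<states::Failed> {
        Node {
            _marker: std::marker::PhantomData,
        }
    }
}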

All we need to do is define a new marker trait, NodeStateActive, and implement it only for the states that should share this functionality.
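A self-contained sketch of the subset approach, assuming a trimmed-down set of states in which New and Running are active and Failed and Removed are not; run() and status() are hypothetical helpers added only to exercise the code:

```rust
use std::marker::PhantomData;

pub mod states {
    pub struct New {}
    pub struct Running {}
    pub struct Failed {}
    pub struct Removed {}
}

pub trait NodeState {}
impl NodeState for states::New {}
impl NodeState for states::Running {}
impl NodeState for states::Failed {}
impl NodeState for states::Removed {}

/// Subset marker: only states that are allowed to fail implement it.
pub trait NodeStateActive: NodeState {}
impl NodeStateActive for states::New {}
impl NodeStateActive for states::Running {}

pub struct Node<S: NodeState> {
    _marker: PhantomData<S>,
}

impl Node<states::New> {
    pub fn new() -> Self {
        Node { _marker: PhantomData }
    }

    /// Hypothetical shortcut to Running, to keep the sketch short.
    pub fn run(self) -> Node<states::Running> {
        Node { _marker: PhantomData }
    }
}

/// Defined once, callable from every active state.
impl<S: NodeStateActive> Node<S> {
    pub fn fail(self) -> Node<states::Failed> {
        Node { _marker: PhantomData }
    }
}

impl Node<states::Failed> {
    /// Hypothetical helper so the result can be inspected.
    pub fn status(&self) -> &'static str {
        "failed"
    }

    pub fn remove(self) -> Node<states::Removed> {
        Node { _marker: PhantomData }
    }
}

pub fn demo() -> (&'static str, &'static str) {
    // Node::new().fail().fail(); // would not compile: Failed does not implement NodeStateActive
    (Node::new().fail().status(), Node::new().run().fail().status())
}
```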

2.7 Storing state data

We are almost done, but there is one more thing: how to store data when transitioning from state to state?

Generally, stored data falls into one of two categories:

  • Data is shared between states. For instance, a node might have some configuration that is used in all states.
  • Data is specific to a state. For instance, a node might have some state-specific data that is only used in that state.

2.7.1 Shared data

In order to store shared data, we can define a NodeContext structure, and inject it into Node:

/// Shared context for all nodes.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct NodeContext {
    pub name: String,
}

impl NodeContext {
    pub fn new(name: &str) -> Self {
        NodeContext {
            name: name.to_string(),
        }
    }
}

pub struct Node<S: NodeState> {
    ctx: NodeContext, // <<<< line added
    _marker: std::marker::PhantomData<S>,
}

In transition functions, we need to pass the context along:

impl Node<states::New> {
    /// Start the node.
    pub fn start(self) -> NodeResult<Node<states::Syncing>> {
        Ok(Node {
            ctx: self.ctx, // <<<< line added
            _marker: std::marker::PhantomData,
        })
    }
}
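A minimal runnable sketch of the context being carried across a transition. The ctx() accessor is a hypothetical addition (the article keeps ctx private without showing a getter), and Result is dropped from start() for brevity:

```rust
use std::marker::PhantomData;

/// Shared context for all nodes (as in the article).
#[derive(Debug, Default, Clone, PartialEq)]
pub struct NodeContext {
    pub name: String,
}

impl NodeContext {
    pub fn new(name: &str) -> Self {
        NodeContext { name: name.to_string() }
    }
}

pub mod states {
    pub struct New {}
    pub struct Syncing {}
}

pub trait NodeState {}
impl NodeState for states::New {}
impl NodeState for states::Syncing {}

pub struct Node<S: NodeState> {
    ctx: NodeContext,
    _marker: PhantomData<S>,
}

impl Node<states::New> {
    pub fn new(ctx: NodeContext) -> Self {
        Node { ctx, _marker: PhantomData }
    }

    /// `self` is consumed, so its context is moved (not cloned) into the new node.
    pub fn start(self) -> Node<states::Syncing> {
        Node { ctx: self.ctx, _marker: PhantomData }
    }
}

/// Hypothetical read-only accessor, available in every state.
impl<S: NodeState> Node<S> {
    pub fn ctx(&self) -> &NodeContext {
        &self.ctx
    }
}
```

Since start() takes self by value, no clone of the context is needed to carry it into the next state.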

2.7.2 State-specific data

When we need to store state-specific data, we can define it in the state’s struct:

pub mod states {
    use super::NodeState;

    pub struct Failed {
        reason: String,
    }

    impl Failed {
        pub fn new(reason: String) -> Self {
            Self { reason }
        }

        pub fn reason(&self) -> &str {
            &self.reason
        }
    }

    impl NodeState for Failed {}
}

Now that a state can carry its own data, we need to store the state object itself in the Node:

pub struct Node<S: NodeState> {
    ctx: NodeContext,
    state: S, // <<<< no more need for phantom data
}

To accommodate this change, we need to update the constructor and the transition functions, for example, for Node<New>:

impl Node<states::New> {
    /// Create a new node in a valid initial state.
    pub fn new(ctx: NodeContext) -> Self {
        Node {
            ctx,
            state: states::New {},
        }
    }

    /// Start the node.
    pub fn start(self) -> NodeResult<Node<states::Syncing>> {
        Ok(Node {
            ctx: self.ctx,
            state: states::Syncing {},
        })
    }
}

For our fail() method, the following update is necessary:

/// Define shared methods on a subset of nodes (active nodes only).
impl<S: NodeStateActive> Node<S> {
    pub fn fail(self, reason: &str) -> Node<states::Failed> {
        Node {
            ctx: self.ctx,
            state: states::Failed::new(reason.to_string()), // <<<< line changed
        }
    }
}

Then we can access state-specific data:

impl Node<states::Failed> {
    // ..skipped..

    pub fn reason(&self) -> &str {
        self.state.reason()
    }
}

Note that reason() is only available on Node<states::Failed>, where it delegates to states::Failed::reason(). And, as with every other invariant in this article, this is checked at compile time: calling reason() on a node in any other state does not compile.
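Putting the shared context and state-specific data together, here is a self-contained sketch. The run() shortcut from New to Running and the name() accessor are hypothetical helpers, and the remaining states are omitted:

```rust
pub mod states {
    pub struct New {}
    pub struct Running {}
    pub struct Failed {
        reason: String,
    }

    impl Failed {
        pub fn new(reason: String) -> Self {
            Self { reason }
        }

        pub fn reason(&self) -> &str {
            &self.reason
        }
    }
}

pub trait NodeState {}
impl NodeState for states::New {}
impl NodeState for states::Running {}
impl NodeState for states::Failed {}

pub trait NodeStateActive: NodeState {}
impl NodeStateActive for states::New {}
impl NodeStateActive for states::Running {}

#[derive(Debug, Default, Clone, PartialEq)]
pub struct NodeContext {
    pub name: String,
}

impl NodeContext {
    pub fn new(name: &str) -> Self {
        NodeContext { name: name.to_string() }
    }
}

pub struct Node<S: NodeState> {
    ctx: NodeContext,
    state: S, // the state value itself; no PhantomData needed
}

impl Node<states::New> {
    pub fn new(ctx: NodeContext) -> Self {
        Node { ctx, state: states::New {} }
    }

    /// Hypothetical shortcut to Running, standing in for start() and syncing.
    pub fn run(self) -> Node<states::Running> {
        Node { ctx: self.ctx, state: states::Running {} }
    }
}

impl<S: NodeStateActive> Node<S> {
    pub fn fail(self, reason: &str) -> Node<states::Failed> {
        Node { ctx: self.ctx, state: states::Failed::new(reason.to_string()) }
    }
}

impl Node<states::Failed> {
    /// Only compiles for Failed nodes: other states have no reason().
    pub fn reason(&self) -> &str {
        self.state.reason()
    }
}

/// Hypothetical accessor for shared data, available in every state.
impl<S: NodeState> Node<S> {
    pub fn name(&self) -> &str {
        &self.ctx.name
    }
}
```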

3 Summary

This article introduced the typestate pattern, a powerful tool for enforcing state machine invariants.

We have created a simple example, where:

  • Only state types can be used as the state type parameter.
  • State list can be closed to extension by end-user.
  • Initial state is enforced.
  • Valid state transitions are enforced.
  • Shared functionality can be exposed for all states.
  • Functionality can be exposed for a subset of states.
  • Shared context data can be stored and accessed between state transitions.
  • State-specific data can be stored and accessed.

And the best part is that all these invariants are checked at compile time.