Implementing C++ Coroutines

Coroutines (alongside Concepts) are my favourite C++20 feature. They allow us to (re)write deeply nested asynchronous code as (seemingly) synchronous code, removing much of the callback nesting and the complexity that comes with it.

Introduction

Imagine a land far away where coroutines haven't been invented (imagine 6 years ago).

You're working on an embedded system that makes long-lasting calls to the hardware. As a requirement, your software cannot wait for hardware operations to complete, as there are a lot of hardware boards to communicate with, and waiting on each one would take far too long.

For the sake of this article, imagine there are 10 hardware components our software must detect, communicate with and overall manage.
Each of these boards is relatively slow and takes around 1 second to complete a single non-intensive operation.

We don't want to synchronously block on every single call we make to any singular board as our software would run at a snail's pace. Imagine starting up Firefox took 10 seconds! The users would riot!

So, what are your options to prevent your entire application from blocking on long-lasting operations and to make sure your users (and POs) don't riot against you and your slow software?

The most straightforward option is to use threads, however they come with significant drawbacks:

  • Starting a thread is relatively expensive: we need to allocate memory for a stack (which we might not be fully using), we need to make some syscalls to the kernel to set up our new thread, ...
  • Multithreaded programs come with the most nefarious race conditions and potential deadlocks. Of course, you can work around these issues via clever use of locks, mutexes and best practices, but they remain a constant source of late-night debugging sessions.
  • Your system doesn't have an infinite number of threads it can create. Although unlikely, if you have to create a lot of threads, you might start running out of memory to create those threads.*
* "Why would you ever create that many threads?" I hear you ask. Our example here only uses 10 boards, but that number could be many times larger. A thread (on most Linux architectures) uses 2MB of memory, by default, for its stack. If you have 100MB of RAM you won't be able to create a lot of threads while also supporting the kernel itself and your software.

Alright, so you decide to not use threads because you don't want to be woken up at 3 AM to deal with a race condition in your software that's affecting your company's biggest customer.

What else can we try?

Promises (based on the JavaScript ones) seem like a great solution:

  • You can use .then() calls to handle your asynchronous code and provide a callback to run when your asynchronous code completes
  • You can provide .catch() functions to handle errors
  • You can enqueue all of your work on an Eventloop or an I/O service (such as boost::asio::io_service (https://www.boost.org/doc/libs/1_65_1/doc/html/boost_asio/overview/core/basics.html))
  • You no longer need to spawn a thread for each blocking call you make to your hardware, as you can simply enqueue the work on your Eventloop or I/O service (or, at the very least, spawn far fewer threads).

This is already a step in the right direction, but some problems start rearing their heads:

  • .then() is great, until you have a lot of nested .then() calls:
MySocketType socket{};
socket.connect("MyServer:1234").then(
    [&socket]() {
        socket.send("FirstPartOfData").then(
            [&socket]() { // socket must be captured to be usable in here
                socket.send("SecondPartOfData").then(
                ); // and many more ...
            }
        ).catch(
            [](std::runtime_error const & ex) {
                std::cerr << "Error sending data to server: " << ex.what() << "\n";
            }
        );
    }
).catch(
    [](std::runtime_error const & ex) {
        std::cerr << "Creating connection failed: " << ex.what() << "\n";
    }
);

.then() calls can become really unreadable!

  • You have to remember to always .catch() errors on .then() calls or your errors will get swallowed in your Promise type, forever lost.

Now, both of these problems could be managed with good documentation, good code structure, and wrapper objects and function calls. However, those wrappers will still have the same problems internally, and your user code cannot fully avoid nesting .then() calls if it needs to make a lot of sequential asynchronous calls.


Why use coroutines

Enter, coroutines!

Invented in ye ancient times, coroutines (specifically awaiting Promises) have been widely used in JavaScript (since 2016), Python (since 2015) and C# (since 2012) to great success!

Now, we finally get to do the same in C++.
Why are coroutines so amazing? Well, here is our earlier example rewritten with coroutines:

MySocketType socket{};

try {
  co_await socket.connect("MyServer:1234");
} catch (std::runtime_error const & ex) {
  std::cerr << "Creating connection failed: " << ex.what() << "\n";
}

try {
  co_await socket.send("FirstPartOfData");
  co_await socket.send("SecondPartOfData");
  // and more!
} catch (std::runtime_error const & ex) {
  std::cerr << "Error sending data to server: " << ex.what() << "\n";
}

Much more readable, no?

Instead of callback after callback, we get to write code that looks synchronous (which is a danger we'll discuss later) but is asynchronous without needing to register a callback every time an asynchronous operation completes.

This allows you to write far more maintainable code as you'll no longer have nested callbacks inside of more nested callbacks dealing with sequential asynchronous calls.

Now that I've finally explained why you should be using coroutines, let's move on to actually implementing them.


Implementing Coroutines

Let's start by making a Promise type. What's a Promise?

Note that this section talks a lot about Promises in 3 different contexts:
- Promise as the custom type we will be implementing.
- Promise as the general concept of a Promise (explained below).
- promise_type is what we provide to the compiler to define how our coroutines will work.

A Promise represents the state of some asynchronous work (our coroutine), some kind of work we promise to complete to whoever gets our Promise object.
A Coroutine is inherently tied to a Promise. A Coroutine cannot exist without a Promise representing it.

Our Promise type will need a couple of basics:

  • A way to check if the associated asynchronous work has completed
  • A callback to execute once our work is done
  • A way to add a callback
  • A way to execute our callback

Our implementation might look like:

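#include <functional>
#include <utility>

class Promise {
private:
  bool m_isReady = false; // must start out as "not ready"
  std::function<void()> m_callback;

public:
  Promise() = default;

  bool IsReady() const {
    return m_isReady;
  }

  void AddCallback(std::function<void()> cb) {
    m_callback = std::move(cb);
  }

  void Set() {
    // we can only execute a Promise once
    if (m_isReady) return;

    m_isReady = true;
    // calling an empty std::function throws, so only call it if one was added
    if (m_callback) m_callback();
  }
};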

This covers the absolute basics of our Promise type. We'll be expanding it once we start implementing coroutines.
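
One detail worth keeping in mind for later: get_return_object() hands out a Promise by value, so whoever calls the coroutine ends up holding a copy. A minimal sketch, assuming we want every copy to observe the same completion, is to keep the state behind a std::shared_ptr. SharedPromise, PromiseState and CopiesShareState are names I made up for this illustration; this is one possible approach, not the only one.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical state block that all copies of one promise point to.
struct PromiseState {
  bool isReady = false;
  std::function<void()> callback;
};

// Hypothetical variant of Promise whose copies share a single PromiseState.
class SharedPromise {
private:
  std::shared_ptr<PromiseState> m_state = std::make_shared<PromiseState>();

public:
  bool IsReady() const { return m_state->isReady; }

  void AddCallback(std::function<void()> cb) {
    m_state->callback = std::move(cb);
  }

  void Set() {
    if (m_state->isReady) return;  // a promise completes only once
    m_state->isReady = true;
    if (m_state->callback) m_state->callback();  // skip if nobody registered
  }
};

// Completing one copy is visible through the other copy.
bool CopiesShareState() {
  SharedPromise producer;
  SharedPromise consumer = producer;  // a copy, same underlying state
  assert(!consumer.IsReady());

  bool notified = false;
  consumer.AddCallback([&notified] { notified = true; });
  producer.Set();
  return consumer.IsReady() && notified;
}
```

With a plain bool member, the copy held by the caller would never learn that the coroutine's own copy was completed.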


To implement coroutines, we need to tell the compiler which promise_type we want to use for our coroutines.
We do this by specializing std::coroutine_traits for our Promise type. The compiler expects to find a promise_type type in the std::coroutine_traits struct, which it will use as the Promise for the coroutines.

The promise_type we provide also tells the compiler how our coroutine will work and which underlying Promise type it actually uses.

A promise_type must implement the following functions:

  • get_return_object(), which must return your Promise (in our case: Promise)
  • initial_suspend(), which must return an awaitable (we'll cover this later)
  • final_suspend(), which must return an awaitable (we'll cover this later)
  • return_void() if your Promise is of type void, otherwise return_value()
  • unhandled_exception(), which handles any uncaught exceptions in your coroutine

Our implementation will look like:

template<typename ... Args>
struct std::coroutine_traits<Promise, Args...> {
  // This struct could also be a 
  // using promise_type = SOMETYPE;
  // if you prefer to define your promise_type outside this struct.
  struct promise_type {
    Promise promise; // The Promise object associated with our coroutine
    Promise get_return_object() {
      return promise;
    }

    *TO BE DISCUSSED* initial_suspend() noexcept;
    *TO BE DISCUSSED* final_suspend() noexcept;

    void return_void(); // IMPLEMENTATION TO BE DISCUSSED

    void unhandled_exception(); // IMPLEMENTATION TO BE DISCUSSED
  };
};

This struct already gives the compiler its most crucial piece of information: which Promise type do we want to use for our coroutine? In our case: Promise.

Promise MyFirstCoroutine()
{
  std::cout << "Yippee!\n";
  co_return; // a function is only a coroutine if it uses co_await, co_yield or co_return
}

This coroutine does not do any asynchronous work yet, but we'll get to that real soon.

Awaiting Promises

Earlier, I showed that the amazing thing about coroutines is that you can co_await asynchronous work, allowing you to write your asynchronous statements sequentially without having to provide all kinds of callbacks to control your program's flow.

This is where Awaitables and Awaiters come into play.
When you type co_await MyAsyncFunction(); the compiler will, magically, do a number of things for you:

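  1. It will convert everything to the right of co_await (the expression) into an Awaitable (by passing it through promise_type::await_transform() if you define one, otherwise the expression is used as-is).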
  2. It will then try to obtain an Awaiter (object) from the Awaitable by:
    1. Calling Awaitable.operator co_await() (if it exists)
    2. Searching for the non-member overload operator co_await(Awaitable &&)
    3. Or finally, use the Awaitable as an Awaiter (if possible)

Since we'll always be awaiting coroutines which return our Promise type, our Promise type will need to provide a way to get an Awaiter (since the Promise type is an Awaitable).

But, first things first, we need an actual Awaiter object to return from our Promise. An Awaiter must implement the following:

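  • await_ready(), to check whether the result is already available, in which case the awaiting coroutine does not need to be suspended.
  • await_suspend(), called once the awaiting coroutine has been suspended; this is where we arrange for it to be resumed later.
  • await_resume(), called when the awaiting coroutine continues; whatever it returns becomes the result of the co_await expression.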

Our Awaiter implementation will be leveraging our Promise, as that's what our Awaitable is:

class Awaiter {
private:
  Promise & m_promise;

public:
  Awaiter(Promise & promise) : m_promise(promise) {}

  bool await_ready() {
    return m_promise.IsReady();
  }

  *TO BE DISCUSSED* await_suspend(std::coroutine_handle<> handle);

  void await_resume();
};

await_ready()

The simplest function here is await_ready(). It returns a bool with a very simple effect:

  • If we return true, then the work we are awaiting is already done (or was never asynchronous to begin with), so the awaiting coroutine does not need to be suspended and simply continues.
  • If we return false, then the work is not done yet, so the awaiting coroutine is suspended and await_suspend() is called.

await_ready() allows us to slightly optimize our coroutines by avoiding the cost of suspending one if we already know that the awaited work has finished.


await_suspend()

await_suspend() is the most complicated part of the Awaiter.
It is passed a std::coroutine_handle<>, which is the handle (internal state) of the coroutine that is doing the co_await.
await_suspend()'s return type largely determines how it works:

  • If await_suspend() returns void, then our coroutine (the current coroutine) stays suspended and control immediately returns to whoever called or resumed it (which may or may not be another coroutine).
  • If await_suspend() returns a bool, then:
    • if we return true, our coroutine stays suspended and control immediately returns to whoever called or resumed it
    • if we return false, our coroutine is not left suspended but is resumed right away
  • If await_suspend() returns a std::coroutine_handle, then we resume the returned handle while our coroutine remains suspended (assuming we didn't return our own handle).

The difficult question is: Which one should you pick? And as usual: it depends 😄

Usually, you'll want to use the void variant, as you do want to suspend your coroutine while waiting on asynchronous work (await_ready() should already ensure you're not suspending your coroutine once it has finished waiting on its work). That said, the 3 variants all have their place and allow you to do very interesting things with your coroutines (such as collecting coroutine handles, forcefully not suspending coroutines if you know you're only doing synchronous work, ...).


To keep going with our implementation, I'll use the void variant so we showcase the "basic" case of always suspending your coroutine.

class Awaiter {
private:
  Promise & m_promise;

public:
  Awaiter(Promise & promise) : m_promise(promise) {}

  bool await_ready() {
    return m_promise.IsReady();
  }

  void await_suspend(std::coroutine_handle<> handle) {
    m_promise.AddCallback([handle]() {
      handle.resume();
    });
  }

  void await_resume();
};

A very simple implementation: Upon reaching a suspension point (co_await) we simply add a callback to our Promise object (that we get from the coroutine) which simply resumes the coroutine once the Promise finishes.

await_resume() won't really do anything in our example, but you can always customize it to your needs; it can contain debug logs or some memory cleanup, ...


operator co_await: Tying the Promise and Awaiter together.

Now that we have a simple Promise and Awaiter object, we can finally tie them together and allow our coroutines to become Awaitables and be awaited.

We do this by implementing operator co_await on our Promise:

class Promise {
public:
  // ... Previous code shown above ...

  Awaiter operator co_await() {
    return Awaiter{ *this };
  }
};

That's all! When co_await Promise; is called we return an Awaiter which will determine whether our current coroutine needs to be suspended or not.

If we tie it all together we have:

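#include <coroutine>
#include <functional>
#include <utility>

class Awaiter; // defined below; Promise only needs the name here

class Promise {
private:
  bool m_isReady = false; // must start out as "not ready"
  std::function<void()> m_callback;

public:
  Promise() = default;

  bool IsReady() const {
    return m_isReady;
  }

  void AddCallback(std::function<void()> cb) {
    m_callback = std::move(cb);
  }

  void Set() {
    // we can only execute a Promise once
    if (m_isReady) return;

    m_isReady = true;
    // calling an empty std::function throws, so only call it if one was added
    if (m_callback) m_callback();
  }

  Awaiter operator co_await(); // defined once Awaiter is a complete type
};

class Awaiter {
private:
  Promise & m_promise;

public:
  Awaiter(Promise & promise) : m_promise(promise) {}

  bool await_ready() {
    return m_promise.IsReady();
  }

  void await_suspend(std::coroutine_handle<> handle) {
    m_promise.AddCallback([handle]() {
      handle.resume();
    });
  }

  void await_resume() {}
};

inline Awaiter Promise::operator co_await() {
  return Awaiter{ *this };
}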

We're now very close to having a fully working coroutine implementation.


initial_suspend() and final_suspend()

Our std::coroutine_traits implementation is still missing a couple of pieces, starting with initial_suspend() and final_suspend().

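These functions are here to answer two design questions of your choosing.

initial_suspend() answers the first one: When I call a coroutine, does it execute immediately, or does it not do any work until I await (or otherwise resume) it?

The former are called eager coroutines, while the latter are called lazy coroutines. Both can make a lot of sense, and both happen in the real world. In the end, you should decide on what makes the most sense for your project, and for your own (or your team's) understanding.

To choose, you give initial_suspend() the return type std::suspend_never if you want eager coroutines or std::suspend_always if you want lazy coroutines. Keep in mind that a lazy coroutine only runs once something resumes its handle, so the Awaiter we wrote earlier would have to take care of that too.

final_suspend() answers the second one: What happens once the coroutine body has finished? With std::suspend_never the coroutine's internal state is destroyed automatically, while with std::suspend_always the coroutine stays suspended at its final suspension point until someone calls destroy() on its handle. Our example uses std::suspend_never for both functions.

Depending on your choice for initial_suspend(), calling a coroutine will cause it to execute immediately or only once awaited.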


return_void() and return_value()

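At the end of a coroutine, you either explicitly co_return;, co_return <value>; or let execution run off the end of the function (which behaves like co_return;, as long as the function contains a co_await, co_yield or co_return somewhere to make it a coroutine in the first place).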
This acts just like a normal return statement in a non-coroutine function, however, for our co_return statement we need to tell the compiler what to do.

This is what we implement return_void() or return_value() for: These functions are called when the compiler hits a co_return statement.

Your promise_type can only implement one of these (which makes sense: your Promise can only support 1 type).

In our example, Promise didn't support any non-void types, so we would implement return_void() to do the following:

template<typename ... Args>
struct std::coroutine_traits<Promise, Args...> {
  struct promise_type {
    Promise promise; // The Promise object associated with our coroutine
    Promise get_return_object() {
      return promise;
    }

    std::suspend_never initial_suspend() noexcept { return {}; }
    std::suspend_never final_suspend() noexcept { return {}; }

    void return_void() {
      promise.Set();
    }

    void unhandled_exception(); // IMPLEMENTATION TO BE DISCUSSED
  };
};

Since our coroutine is ending in return_void() (or return_value()) we simply tell our Promise that it is now done, which in turn will trigger its callback (which is being awaited somewhere by an Awaiter).


unhandled_exception()

To finish up our implementation, the final function we must implement answers the question: what happens if an exception is thrown in our coroutine and it's not caught?

You could do plenty of things (it depends), but the most obvious thing is to simply re-throw the error:

template<typename ... Args>
struct std::coroutine_traits<Promise, Args...> {
  struct promise_type {
    Promise promise; // The Promise object associated with our coroutine
    Promise get_return_object() {
      return promise;
    }

    std::suspend_never initial_suspend() noexcept { return {}; }
    std::suspend_never final_suspend() noexcept { return {}; }

    void return_void() {
      promise.Set();
    }

    void unhandled_exception() {
      std::rethrow_exception(std::current_exception());
    }
  };
};

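Be aware that the rethrown exception travels to whoever called or last resumed the coroutine (for example your eventloop), which is not necessarily the place where your Promise is being awaited. If you want it to be caught at the co_await, store std::current_exception() in your Promise instead and rethrow it from await_resume().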


Integrating your coroutines

Now we have a full coroutine implementation, but we don't yet have a way to actually do asynchronous work.

This involves posting your asynchronous work on either a thread(pool) or an eventloop of some sort, but implementing these would be an entire topic of its own, so I'd suggest using an existing library (such as the boost::asio mentioned earlier) rather than rolling your own solution (which is a lot more complicated than implementing coroutines).


Maybe I'll do a write-up one day of making your own eventloop as this is a very interesting topic and will help you truly understand coroutines and how they work.

Thanks for reading, I hope this helps you in any future coroutine endeavours!

I'll be writing a blog post about asynchronous memory safety, as this is a topic you're bound to run into and smash your head against.