Daniel Ciocîrlan
Much virtual ink has been spilled talking about Monads in Scala. Even here on Rock the JVM, we’ve talked about Monads in a variety of ways: as a practical necessity, in their most abstract possible representation, or as a component of Cats, a popular FP library for Scala that we teach here on the site.
In this article, we’ll approach the M word from yet another angle: the DRY principle and how the monad abstraction ends up simplifying our code, while making it more general, more expressive and more powerful at the same time.
The only requirements of this article are:
We’ll be writing Scala 3 for this article.
The Scala standard library is generally pretty consistent. Tools like map and flatMap are indispensable in data processing and various pieces of business logic. Those of you coming from the Essentials course should have these concepts fresh in your mind.
We can transform data structures easily with map, flatMap and for comprehensions:
val aList = List(1,2,3,4)
val anIncrementedList = aList.map(_ + 1) // [2,3,4,5]
val aFlatMappedList = aList.flatMap(x => List(x, x + 1)) // [1,2,2,3,3,4,4,5]
The same concept applies to many other useful data structures: Option, Try and IO all share the same properties.
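As a small illustration of that shared shape, here is a sketch using only Option and Try from the standard library (IO lives in external libraries such as Cats Effect, so it is left out here). The value names are made up for this example.

```scala
import scala.util.Try

// Option: map transforms the value inside, flatMap chains a computation that may be absent
val anOption: Option[Int] = Option(42)
val anIncrementedOption: Option[Int] = anOption.map(_ + 1)                 // Some(43)
val aFlatMappedOption: Option[Int] = anOption.flatMap(x => Option(x * 10)) // Some(420)

// Try: same shape, but the wrapper tracks success or failure
val aTry: Try[Int] = Try("42".toInt)
val anIncrementedTry: Try[Int] = aTry.map(_ + 1)         // Success(43)
val aFailedTry: Try[Int] = aTry.flatMap(x => Try(x / 0)) // Failure(java.lang.ArithmeticException)

// for-comprehensions work on both, because the compiler rewrites them into flatMap/map calls
val aCombinedOption: Option[Int] = for {
  a <- anOption
  b <- anIncrementedOption
} yield a + b // Some(85)
```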
Assume that you’re working on a small toolkit for your team. You want to be able to create a function that returns the combinations of strings and numbers out of many possible kinds of data structures:
With what we know so far, let’s sketch an implementation for lists:
def combineLists(str: List[String])(num: List[Int]): List[(String, Int)] = for {
  s <- str
  n <- num
} yield (s, n)
Given that Option and Try are different types, we have no choice but to duplicate this API for these types as well:
def combineOptions(str: Option[String])(num: Option[Int]): Option[(String, Int)] = for {
  s <- str
  n <- num
} yield (s, n)

// Try is not imported by default
import scala.util.Try

def combineTry(str: Try[String])(num: Try[Int]): Try[(String, Int)] = for {
  s <- str
  n <- num
} yield (s, n)
The method signatures are different (and they have to be), yet they are strikingly similar, and their implementation is identical.
Can we do better?
Guided by the DRY principle, let’s abstract away this common structure.
What do we need? We need to be able to write a for-comprehension for any kind of data structure that has map and flatMap. We’re getting a little ahead of ourselves, so let’s try simply creating a type class that has these functionalities. We’ll provide a method for map and a method for flatMap.
The type class is unsurprisingly called Monad and it looks something like this:
trait Monad[M[_]] {
  def pure[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]

  // for free
  def map[A, B](ma: M[A])(f: A => B): M[B] =
    flatMap(ma)(a => pure(f(a)))
}
Notice our method signatures look a little different, because the methods do not belong to the magical data structures directly, but they belong to the type class. Otherwise, we have the same semantics as the map/flatMap combinations of standard data types.
Because we structured our code in this way, we get the map method for free: it can be fully implemented in terms of pure + flatMap, so our Monad type is a natural and direct descendant of the Functor type class. But that’s a subject we approached in another article.
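To make the "map in terms of pure + flatMap" claim concrete, here is a minimal sketch. The trait is repeated so the snippet stands alone, and OptionMonad is a hypothetical instance written just for this illustration: it implements only pure and flatMap, yet map is available and agrees with the standard Option.map.

```scala
trait Monad[M[_]] {
  def pure[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]

  // derived: wrap the result of f with pure, then chain it with flatMap
  def map[A, B](ma: M[A])(f: A => B): M[B] =
    flatMap(ma)(a => pure(f(a)))
}

// Hypothetical instance for illustration: note that map is NOT overridden
object OptionMonad extends Monad[Option] {
  def pure[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
}

val aDerivedMap = OptionMonad.map(Option(41))(_ + 1)              // Some(42)
val aStandardMap = Option(41).map(_ + 1)                          // Some(42)
val aDerivedMapOnNone = OptionMonad.map(Option.empty[Int])(_ + 1) // None
```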
In any event, now that we have a type class, we need to follow the type class pattern: create some type class instances, then create a single, unifying, general API that requires the presence of a type class instance for our particular type. The steps follow.
Let’s create, as an example, a type class instance for Monad for List:
given monadList: Monad[List] with {
  override def pure[A](a: A) = List(a)
  override def flatMap[A, B](ma: List[A])(f: A => List[B]) = ma.flatMap(f)
}
As an implementation, we rely on the standard Scala collection library.
Let’s also define our unifying API:
def combine[M[_]](str: M[String])(num: M[Int])(using monad: Monad[M]): M[(String, Int)] =
  monad.flatMap(str)(s => monad.map(num)(n => (s, n)))
Notice that we’re not using for-comprehensions as we can’t yet (we depend on the monad instance), but the semantics stay the same. Given this new method signature and implementation, we can now safely combine lists:
println(combine(List("a", "b", "c"))(List(1,2,3))) // [(a,1), (a,2), (a,3), (b,1), (b,2), (b,3), (c,1), (c,2), (c,3)]
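The flatMap/map chain in combine mirrors what the compiler does with a for-comprehension. The sketch below shows the two forms side by side on plain lists; combineWithFor and combineDesugared are names invented for this comparison.

```scala
// The for-comprehension form, as in combineLists
def combineWithFor(str: List[String])(num: List[Int]): List[(String, Int)] =
  for {
    s <- str
    n <- num
  } yield (s, n)

// The rewritten form: flatMap for every generator except the last, map for the last one
def combineDesugared(str: List[String])(num: List[Int]): List[(String, Int)] =
  str.flatMap(s => num.map(n => (s, n)))

val withFor = combineWithFor(List("a", "b"))(List(1, 2))
val desugared = combineDesugared(List("a", "b"))(List(1, 2))
// both are List((a,1), (a,2), (b,1), (b,2))
```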
The magical benefit is that if we have a given Monad[Option] in scope, we can safely call combine on Options as well. Our user-facing API does not require a specific type now, but only the presence of a given Monad of that type in scope. You can now combine Options, Try, IOs, States, or whatever other kinds of monads you want.
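For instance, instances for Option and Try might look like the sketch below. The names monadOption and monadTry are assumptions chosen to match monadList, and the trait and combine are repeated so the snippet stands alone.

```scala
import scala.util.{Success, Try}

trait Monad[M[_]] {
  def pure[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
  def map[A, B](ma: M[A])(f: A => B): M[B] =
    flatMap(ma)(a => pure(f(a)))
}

// Assumed instance names, following the monadList convention
given monadOption: Monad[Option] with {
  override def pure[A](a: A): Option[A] = Some(a)
  override def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
}

given monadTry: Monad[Try] with {
  // the value is already computed at this point, so wrap it in Success directly
  override def pure[A](a: A): Try[A] = Success(a)
  override def flatMap[A, B](ma: Try[A])(f: A => Try[B]): Try[B] = ma.flatMap(f)
}

def combine[M[_]](str: M[String])(num: M[Int])(using monad: Monad[M]): M[(String, Int)] =
  monad.flatMap(str)(s => monad.map(num)(n => (s, n)))

val combinedOptions = combine(Option("a"))(Option(1))          // Some((a,1))
val combinedWithNone = combine(Option("a"))(Option.empty[Int]) // None
val combinedTries = combine(Try("a"))(Try(1))                  // Success((a,1))
```

The same combine body serves all three types; only the given instance picked by the compiler changes.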
The Monad type class was driven, in this article, by the necessity to not duplicate any code.
That said, we can go further and make our instances of M[_] (whatever M is) behave like our original duplicated API. In other words, we can grant instances of M[_] the ability to do for comprehensions. How can we do that? Using extension methods.
For any M[_] for which there is a given Monad[M] in scope, we can enhance instances of M with the map and flatMap methods, so they behave exactly like our original lists:
extension [M[_], A](ma: M[A])(using monad: Monad[M]) {
  def map[B](f: A => B): M[B] = monad.map(ma)(f)
  def flatMap[B](f: A => M[B]): M[B] = monad.flatMap(ma)(f)
}
So with that in place, we can create another version of our general combine API to look like this:
def combine_v2[M[_] : Monad](str: M[String])(num: M[Int]): M[(String, Int)] = for {
  s <- str
  n <- num
} yield (s, n)
where the M[_] : Monad means that we require the presence of a given Monad[M] in scope. Because of that, we can now use map and flatMap on our instances of M[String] and M[Int] as well, thereby enabling for-comprehensions in the most general terms.
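A context bound is shorthand for a using clause, so the two signatures below should be interchangeable from the caller's side. This is a sketch under that assumption: combineContextBound, combineUsingClause and pureVia are hypothetical names, and the trait, an Option instance and the extension methods are repeated so the snippet stands alone.

```scala
trait Monad[M[_]] {
  def pure[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
  def map[A, B](ma: M[A])(f: A => B): M[B] =
    flatMap(ma)(a => pure(f(a)))
}

given monadOption: Monad[Option] with {
  override def pure[A](a: A): Option[A] = Some(a)
  override def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
}

extension [M[_], A](ma: M[A])(using monad: Monad[M]) {
  def map[B](f: A => B): M[B] = monad.map(ma)(f)
  def flatMap[B](f: A => M[B]): M[B] = monad.flatMap(ma)(f)
}

// Context bound: the Monad[M] instance is required but left unnamed
def combineContextBound[M[_]: Monad](str: M[String])(num: M[Int]): M[(String, Int)] =
  for {
    s <- str
    n <- num
  } yield (s, n)

// The same requirement spelled out as an explicit using clause
def combineUsingClause[M[_]](str: M[String])(num: M[Int])(using Monad[M]): M[(String, Int)] =
  for {
    s <- str
    n <- num
  } yield (s, n)

// With a context bound the instance has no name, so summon retrieves it when needed
def pureVia[M[_]: Monad, A](a: A): M[A] = summon[Monad[M]].pure(a)

val viaContextBound = combineContextBound(Option("a"))(Option(1)) // Some((a,1))
val viaUsingClause = combineUsingClause(Option("a"))(Option(1))   // Some((a,1))
val aPureOption = pureVia[Option, Int](42)                        // Some(42)
```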
Testing the new API, we’ll see that all of them work in the same way:
combineLists(List("a", "b", "c"))(List(1,2,3,4)) == combine(List("a", "b", "c"))(List(1,2,3,4)) // true
combineLists(List("a", "b", "c"))(List(1,2,3,4)) == combine_v2(List("a", "b", "c"))(List(1,2,3,4)) // true
Although we explored many facets of Monads on the blog and in the Cats course, in this article we saw another necessity for Monads as a type class: the need to not repeat ourselves in API design and implementation. Through Monads, we can now create general ways of chaining computations, with significant implications in our code at all levels.