In this chapter, we’ll look at some of the functional programming features of Scala, specifically the ubiquitous `map` and `flatMap` functions. We’re interested in these because they’re closely related to the idea of monads, a key feature of functional programming.
You’ll see the `map` function on countless classes in Scala. It’s often described in the context of collections. Classes like `List`, `Set`, and `Map` all have it. For these, it applies a given function to each element in the collection, giving back a new collection based on the result of that function. You “map” some function over each element of your collection.
For example, you could create a function that works out how old a person is given the year of their birth.
```scala
import java.util.Calendar

def age(birthYear: Int) = {
  val currentYear = Calendar.getInstance.get(Calendar.YEAR)
  currentYear - birthYear
}
```
We could call the `map` function on a list of birth years, passing in the function to create a new list of ages.
```scala
val birthdays = List(1990, 1977, 1984, 1961, 1973)

birthdays.map(age)
```
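The result would be a list of ages. For example, we’ve transformed the year 1990 into an age of 25. (The exact values depend on the year in which you run the code; the figures below correspond to 2015.)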
```
res0: List[Int] = List(25, 38, 31, 54, 42)
```
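Because `age` reads the year from the system calendar, its output drifts over time. Below is a minimal sketch of a variant that takes the current year as an explicit argument so that the result is repeatable. The name `ageIn` and its curried signature are assumptions made for illustration; they are not part of the original example.

```scala
object Ages {
  // Hypothetical variant of `age`: the current year is passed in rather
  // than read from java.util.Calendar, so results are repeatable.
  def ageIn(currentYear: Int)(birthYear: Int): Int = currentYear - birthYear

  val birthdays = List(1990, 1977, 1984, 1961, 1973)

  // Supplying only the first parameter list yields an Int => Int,
  // which can be handed straight to map.
  val agesIn2015: List[Int] = birthdays.map(ageIn(2015))
}
```

Fixing the year makes this form easier to test than the calendar-based one.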
Because `map` is a higher-order function, you could also have written `age` inline as a lambda like this:
```scala
birthdays.map(year => Calendar.getInstance.get(Calendar.YEAR) - year)
```
Using the underscore as a shorthand for the lambda’s parameter, it would look like this:
```scala
birthdays.map(Calendar.getInstance.get(Calendar.YEAR) - _)
```
So `map` is a transforming function. For collections, it iterates over the collection applying some function, just like `foreach` does. The difference is that, unlike `foreach`, `map` collects the return values from the function into a new collection and then returns that collection.
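Here is a small sketch of that difference. With `foreach`, the function’s return value is discarded, so building a new list means accumulating results yourself. With `map`, that accumulation is done for you. The object name and the fixed year 2015 are illustrative assumptions that keep the output predictable.

```scala
import scala.collection.mutable.ListBuffer

object ForeachVsMap {
  val years = List(1990, 1977, 1984)

  // foreach throws away each return value, so we collect the results
  // ourselves as a side effect on a mutable buffer.
  val viaForeach: List[Int] = {
    val buffer = ListBuffer[Int]()
    years.foreach(year => buffer += 2015 - year)
    buffer.toList
  }

  // map collects each return value into a new List for us.
  val viaMap: List[Int] = years.map(year => 2015 - year)
}
```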
It’s trivial to implement a mapping function by hand. For example, we could create a class `Mappable` that takes a list of elements of type `A` and provides a `map` method.
```scala
class Mappable[A](val elements: List[A]) {
  def map[B](f: Function1[A, B]): List[B] = {
    ???
  }
}
```
The parameter to `map` is a function that transforms from type `A` to type `B`; it takes an `A` and returns a `B`. I’ve written it longhand as a `Function1`, which is roughly equivalent to Java 8’s `java.util.function.Function` interface. We can also write it using Scala’s shorthand syntax, `A => B`, which is simply another way of writing the same `Function1[A, B]` type.
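The sketch below shows that the two spellings describe the same type, and that calling `f(x)` is sugar for `f.apply(x)`. The value names are illustrative.

```scala
object FunctionTypes {
  // Written longhand as an instance of the Function1 trait...
  val doubleLong: Function1[Int, Int] = new Function1[Int, Int] {
    def apply(x: Int): Int = x * 2
  }

  // ...and typed with the arrow shorthand. Int => Int *is* Function1[Int, Int],
  // so no conversion is needed to assign one to the other.
  val doubleShort: Int => Int = doubleLong

  // f(x) is compiled as f.apply(x), so both calls are identical.
  val viaApply: Int = doubleShort.apply(21)
  val viaCall: Int = doubleShort(21)
}
```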
```scala
def map[B](f: A => B): List[B] = ...
```
Then it’s just a question of creating a new collection and calling the function (using `apply`) with each element as the argument. We’d add each result to the new collection and finally return it.
```scala
class Mappable[A](val elements: List[A]) {
  def map[B](f: A => B): List[B] = {
    // Accumulate the transformed elements in a mutable buffer
    val result = collection.mutable.ListBuffer[B]()
    // Apply f to each element and append its return value
    elements.foreach { result += f.apply(_) }
    result.toList
  }
}
```
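The mutable buffer isn’t essential. As an alternative sketch (our own design choice, not the chapter’s), the same `map` can be written without mutable state by folding over the list from the right and prepending each transformed element. The class name `ImmutableMappable` is hypothetical.

```scala
class ImmutableMappable[A](val elements: List[A]) {
  def map[B](f: A => B): List[B] =
    // foldRight visits elements last-to-first, so prepending with ::
    // leaves the results in the original order.
    elements.foldRight(List.empty[B])((a, acc) => f(a) :: acc)
}
```

The trade-off is a little less obvious control flow in exchange for no mutation at all.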
We can test it by creating a list of numbers, making them “mappable” by creating a new instance of `Mappable`, and calling `map` with an anonymous function that simply doubles the input.
```scala
object Example extends App {
  val numbers: List[Int] = List(1, 2, 54, 4, 12, 43, 54, 23, 34)
  val mappable: Mappable[Int] = new Mappable(numbers)
  val result = mappable.map(_ * 2)
  println(result)
}
```
The output would look like this:
```
List(2, 4, 108, 8, 24, 86, 108, 46, 68)
```
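Because `map` has its own type parameter `B`, the function doesn’t have to return the same type it receives. The short sketch below maps `Int`s to `String`s. It includes a copy of `Mappable` (using a `ListBuffer` accumulator, an assumption) so that it runs on its own. The object name is illustrative.

```scala
import scala.collection.mutable.ListBuffer

// A copy of Mappable so this snippet is self-contained.
class Mappable[A](val elements: List[A]) {
  def map[B](f: A => B): List[B] = {
    val result = ListBuffer[B]()
    elements.foreach(a => result += f(a))
    result.toList
  }
}

object MapToString {
  val numbers = new Mappable(List(1, 2, 54))
  // Here A is Int and B is inferred as String.
  val labels: List[String] = numbers.map(n => s"#$n")
}
```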
You’ll often see the `flatMap` function wherever you see `map`. For collections, it’s very similar in that it maps a function over the collection, storing the results in a new collection, but with a couple of differences: the function you pass in returns a collection rather than a single value, and `flatMap` flattens the results to give a single collection.

So, given a collection of `A`, the `map` function applies a function to each element, transforming an `A` to a `B`. The result is a collection of `B` (i.e. `List[B]`).

Given a collection of `A`, the `flatMap` function applies a function to each element, transforming an `A` to a collection of `B`. This results in a collection of collections of `B` (i.e. `List[List[B]]`), which is then flattened to a single collection of `B` (i.e. `List[B]`).

Let’s say we want a mapping function to return a person’s age plus or minus a year. So if we think a person is 38, we’d return a list of 37, 38, 39.
```scala
import java.util.Calendar

def ages(birthYear: Int): List[Int] = {
  val today = Calendar.getInstance.get(Calendar.YEAR)
  List(today - 1 - birthYear, today - birthYear, today + 1 - birthYear)
}
```
The signature has changed from the previous example: it returns a `List[Int]` rather than just an `Int`. If we pass the list of birth years into the `map` function, we get a list of lists back (`res0` below).
```scala
val birthdays = List(1990, 1977, 1984)

val result = birthdays.map(ages)
println(result)
```
```
scala> birthdays.map(ages)
res0: List[List[Int]] = List(List(24, 25, 26), List(37, 38, 39), List(30, 31, 32))
```
If, however, we pass it into the `flatMap` function, we get a flattened list back. It maps, then flattens.
```
scala> birthdays.flatMap(ages)
res1: List[Int] = List(24, 25, 26, 37, 38, 39, 30, 31, 32)
```
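One way to convince yourself that `flatMap` really is “map, then flatten” is to compare the two directly. The minimal sketch below does this. To keep the output stable, it uses a hypothetical `agesIn` helper with a fixed year of 2015 instead of reading the calendar.

```scala
object MapThenFlatten {
  // Hypothetical deterministic version of `ages`: age plus or minus one,
  // relative to a fixed current year.
  def agesIn(currentYear: Int)(birthYear: Int): List[Int] = {
    val age = currentYear - birthYear
    List(age - 1, age, age + 1)
  }

  val birthdays = List(1990, 1977, 1984)

  val nested: List[List[Int]] = birthdays.map(agesIn(2015))    // map only
  val flattened: List[Int] = nested.flatten                    // then flatten
  val viaFlatMap: List[Int] = birthdays.flatMap(agesIn(2015))  // both at once
}
```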
If you wanted to write your own version of `flatMap`, it might look something like this (notice the return type of the function parameter `f`).
```scala
class FlatMappable[A](elements: A*) {
  def flatMap[B](f: A => List[B]): List[B] = {
    val result = collection.mutable.ListBuffer[B]()
    // Outer loop: apply f to each element, giving a List[B].
    // Inner loop: append every item of that list, which flattens the result.
    elements.foreach { f.apply(_).foreach { result += _ } }
    result.toList
  }
}
```
The first loop will enumerate the elements of the collection and apply the function to each. Because this function itself returns a list, another loop is needed to enumerate each of these, adding them into the result collection. This is the bit that flattens the function’s result.
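As with `map`, the nested loops can be replaced by a fold if you prefer to avoid mutation. In this alternative sketch (the class name `ImmutableFlatMappable` is hypothetical), list concatenation with `:::` plays the role of the inner loop that does the flattening.

```scala
class ImmutableFlatMappable[A](elements: A*) {
  def flatMap[B](f: A => List[B]): List[B] =
    // Working from the last element to the first, prepend the whole list
    // that f returns; joining the lists with ::: is the flattening step.
    elements.foldRight(List.empty[B])((a, acc) => f(a) ::: acc)
}
```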
To test it, let’s start by creating a function that goes from an `Int` to a collection of `Int`. It gives back all the odd numbers between zero and the argument (inclusive).
```scala
def oddNumbersTo(end: Int): List[Int] = {
  val odds = collection.mutable.ListBuffer[Int]()
  for (i <- 0 to end) {
    if (i % 2 != 0) odds += i
  }
  odds.toList
}
```
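The same odd numbers can also be produced without a mutable buffer. Here is an alternative sketch using a stepped range. The name `oddsTo` is hypothetical.

```scala
object Odds {
  // Start at 1 and step by 2 up to `end` inclusive; when end < 1 the range
  // is empty, matching the loop-based version.
  def oddsTo(end: Int): List[Int] = (1 to end by 2).toList
}
```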
We then just create an instance of our class with a few numbers. Call `flatMap` and you’ll see that all the odd numbers from 0 to 1, 0 to 2, and 0 to 10 are collected into a single list.
```scala
object Example {
  def main(args: Array[String]): Unit = {
    val mappable = new FlatMappable(1, 2, 10)
    val result = mappable.flatMap(oddNumbersTo)
    println(result)
  }
}
```
The output would be the following:
```
List(1, 1, 1, 3, 5, 7, 9)
```
We’ve seen how `map` and `flatMap` work for collections, but they also exist on many other classes. More generally, `map` and `flatMap` operate on what are called monads. In fact, having `map` and `flatMap` behaviour is one of the defining features of monads.
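`Option` is one such class outside the collections that has both methods. The small sketch below uses two hypothetical helpers, `parseYear` and `plausible`, to show two things. First, `map` transforms a value that may be absent. Second, `flatMap` avoids ending up with a nested `Option[Option[...]]`, just as it avoided a `List[List[...]]` earlier.

```scala
import scala.util.Try

object OptionExample {
  // Hypothetical helper: Some(year) if the string is a number, else None.
  def parseYear(s: String): Option[Int] = Try(s.toInt).toOption

  // Hypothetical check that rejects implausible birth years.
  def plausible(year: Int): Option[Int] =
    if (year > 1900 && year <= 2015) Some(year) else None

  // map applies the function inside the Option; None would stay None.
  val age: Option[Int] = parseYear("1990").map(year => 2015 - year)

  // map with a function that itself returns an Option nests the results...
  val nested: Option[Option[Int]] = parseYear("1990").map(plausible)

  // ...whereas flatMap maps and then flattens.
  val flat: Option[Int] = parseYear("1990").flatMap(plausible)
  val rejected: Option[Int] = parseYear("1850").flatMap(plausible)
  val unparsable: Option[Int] = parseYear("nineteen").flatMap(plausible)
}
```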
So just what are monads? We’ll look at that next.