You may want Spark to write to some data store (after all, there are some things that other technologies do better). Given an RDD distributed over many partitions on many nodes, how do you write the data efficiently?
You may choose to do something like this (where I've used println statements rather than explicit calls to the technologies of your choice):
rdd foreachPartition { iterator =>
  println("create connection")
  iterator foreach { element =>
    println("insert")
  }
  println("close connection")
}
Naturally, you need to obtain the connection inside the function, as the function will be serialized and run against each partition, which may well live on another machine.
Note that if the connection is referenced via a lazy val then each copy of the function will have its own connection (even on the same machine), and it will only be instantiated when the function is run against its partition.
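A sketch of that pattern (reusing the println stand-in, and assuming a getConnection() helper like the stub defined further down):

rdd foreachPartition { iterator =>
  // never touched on the driver, so only instantiated on the executor
  // the first time it's dereferenced below
  lazy val connection = getConnection()
  iterator foreach { element =>
    println("insert using " + connection)
  }
  // but when is it safe to close it?
}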
So, how do we know when to close the connection? A common answer is to just use connection pools.
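What that might look like, as a very rough sketch: the pool lives in an object so it's initialised once per executor JVM rather than shipped with the closure. (ConnectionPool, borrow and giveBack are made-up names standing in for a real pooling library such as commons-pool or HikariCP; Connection and getConnection are the stubs defined further down.)

import java.util.concurrent.ConcurrentLinkedQueue

object ConnectionPool {
  // initialised locally on each executor JVM, not serialized with the closure
  private val pool = new ConcurrentLinkedQueue[Connection]()

  def borrow(): Connection = Option(pool.poll()).getOrElse(getConnection())
  def giveBack(c: Connection): Unit = pool.offer(c)
}

rdd foreachPartition { iterator =>
  val connection = ConnectionPool.borrow()
  try {
    iterator foreach { element =>
      println("insert")
    }
  } finally {
    ConnectionPool.giveBack(connection)  // returned to the pool rather than closed
  }
}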
Also note that the return type of foreachPartition is Unit, so it's not too surprising that it executes immediately: Unit hints at side effects. A quick look at the code of RDD confirms that it's an action rather than a transformation.
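At the time of writing, the method boils down to a call to sc.runJob, which submits a job straight away (paraphrased; the details vary between Spark versions):

// from org.apache.spark.rdd.RDD (paraphrased)
def foreachPartition(f: Iterator[T] => Unit): Unit = withScope {
  val cleanF = sc.clean(f)
  sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}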
Great. But what if you want to read data from some store and enhance the RDD? Now we're using mapPartitions, which is a transformation and therefore lazy. So, with similar code, the function might look like this:
val mapped = rdd mapPartitions { iterator =>
  lazy val connection = getConnection()
  val mapped = iterator map { element =>
    if (connection.isOpen) select(element)
  }
  connection.close()
  mapped
}
where I'm pretending to open and close some sort of pseudo-connection thus:
import java.util.concurrent.atomic.AtomicInteger

val nSelects     = new AtomicInteger(0)
val nConnections = new AtomicInteger(0)

class Connection {
  @volatile private var open = true
  def close(): Unit     = { open = false }
  def isOpen(): Boolean = open
}

def getConnection(): Connection = {
  nConnections.incrementAndGet()
  new Connection
}

def select(element: Int): Int = nSelects.incrementAndGet()
Now, let's call count() on the RDD, which should force even the lazy map to do something:
println("Finished. Count = " + mapped.count())
println("nSelects = " + nSelects.get() + ", nConnections = " + nConnections.get())
giving:
Finished. Count = 10000
nSelects = 0, nConnections = 0
What's this? We get the correct number of elements but no selects? What gives?
The issue is that mapping over the iterator given to the function passed to mapPartitions is lazy: the work only happens when the resulting iterator is consumed. This has nothing to do with Spark but is due to Scala's Iterator, whose map method is lazy (the result is actually an instance of scala.collection.Iterator$$anon$11). If we ran this code with a real connection, we'd see that it had already been closed by the time the inner function wanted to do something with it. Odd, we might say, since the call to close it comes later in the code.
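You can see the same behaviour in plain Scala, with no Spark involved at all:

val it = Iterator(1, 2, 3).map { x => println("mapping " + x); x * 2 }
// nothing has been printed yet: map on an Iterator merely wraps it
println("closing a connection here would be too early")
it.foreach(println)  // only now do the "mapping ..." lines appear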
We could force the mapping to run before closing the connection by materialising the result (say, by calling toList on it), which may or may not be efficient memory-wise (see the sketch after the batched version below), but one solution suggested to me was:
val batchSize = 100
val batchMapped = rdd mapPartitions { iterator =>
  val batched = iterator.grouped(batchSize)
  batched.flatMap { group =>
    val connection = getConnection()
    // group is a strict Seq, so this map runs while the connection is still open
    val seqMapped = group.map(select(_))
    connection.close()
    seqMapped
  }
}
This creates (and closes) a connection for every batchSize elements, but that can still be more efficient than the usual contention on a connection pool. Remember, even a small amount of contention can make a big difference when you're dealing with hundreds of millions of elements.
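For comparison, the materialise-before-closing alternative mentioned earlier might look something like this (just a sketch; toList means holding a whole partition in memory at once):

val eagerlyMapped = rdd mapPartitions { iterator =>
  val connection = getConnection()
  // toList forces the selects to run while the connection is still open...
  val materialised = iterator.map(element => select(element)).toList
  connection.close()
  // ...at the cost of keeping the entire partition in memory
  materialised.iterator
}

Either way, the important thing is that whatever uses the connection actually runs before the connection is closed.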