
Efficient broadcast joins in Spark, using Bloom filters

22 November 2018 Mohamed Abdelbary

This article assumes working knowledge of Apache Spark and Python

Joining two RDDs is a common operation when working with Spark. In a lot of cases, a join is used as a form of filtering: you want to perform an operation on a subset of the records in one RDD, represented by entities in another RDD. While you can use an inner join to achieve that effect, sometimes you want to avoid the shuffle that the join operation introduces, especially if the RDD you want to filter by is significantly smaller than the main RDD on which you will perform your further computation.

The next logical step is to do a broadcast join using a set constructed by collecting the smaller RDD you wish to filter by. However, this means collecting the whole RDD into driver memory, and even if it is relatively small (hundreds of thousands to, say, a million records), that can still lead to some undesirable memory pressure.
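
For reference, that naive approach looks something like the sketch below. The sc, small_rdd and large_rdd names, and the assumption that the first field of each record is the join key, are purely illustrative; the collect call is what pulls every key into driver memory.

# A minimal sketch of the naive broadcast filter. All names here are
# placeholders: small_rdd holds the keys to filter by, large_rdd is
# the RDD we actually want to process.
keys_to_keep = set(small_rdd.map(lambda row: row[0]).collect())  # whole key set lands on the driver
keys_broadcast = sc.broadcast(keys_to_keep)

filtered_rdd = large_rdd.filter(lambda row: row[0] in keys_broadcast.value)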

Bloom filters can provide a neat solution to this problem. A Bloom filter is an efficient probabilistic data structure, constructed from a set of values, that is used for membership tests. It can tell you that an arbitrary element might be in the set, or that it is definitely not in the set; that is, false positives are allowed, but false negatives are not. The underlying data structure is a bit array, onto which elements are mapped using hash functions. The mapping basically sets some bits to 1, leaving the rest as 0s. The size of the bit array is determined by how many false positives you are willing to tolerate, so most implementations accept an FPR parameter in the constructor of the data structure (a typical value is 1%). The efficiency of the data structure stems from two reasons:

  • Elements are reduced to a compact bit representation, removing all the overhead introduced by the original data structure, which in the case of a dynamic language like Python (where a single set or list can hold mixed types) can be quite significant.
  • Tolerating a certain FPR means that you need fewer bits to encode the elements you want to test for.

False positives occur when the bits an element hashes to have all been set by other elements (a collision), so the Bloom filter returns a positive membership test (element is in the set) for an element that was never actually added. As you reduce the accepted FPR, the Bloom filter will need a bigger bit array, progressively reducing the space advantage you get.
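
To make the interface concrete, here is a small local sketch using pybloom; the capacity and FPR values are arbitrary.

from pybloom import BloomFilter

# capacity is the number of distinct elements the filter is sized for,
# error_rate is the tolerated false positive rate (FPR)
bloom = BloomFilter(capacity=1000000, error_rate=0.01)
bloom.add("user_42")

print("user_42" in bloom)   # True: the element was added
print("user_999" in bloom)  # False, unless its bits collide with already set bits (a false positive)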

But so what? We still haven't solved the original problem. Even if the Bloom filter ultimately is more compact than a full set of the original elements, we still need to collect the original RDD entirely on the driver before we construct the Bloom filter, correct?

Well, not really. Bloom filters can be merged to build other Bloom filters, so you can build intermediate Bloom filters progressively on Spark executors before merging them into the final Bloom filter on the driver using a Spark action such as reduce. The Python library pybloom provides a neat union method that combines filters much as you would combine regular Python sets. So you can write a simple method like

def _merge_bloom_filters(filter_a, filter_b):
    """
    Takes in two pybloom-based Bloom filters
    and returns the result of their union,
    meaning that any element giving a positive
    answer in either will give a positive
    answer in the union.
    """

    return filter_a.union(filter_b)

to combine two filters. Using that, you can get Spark to progressively build "intermediate" Bloom filters on partitions of your data with a mapPartitions operation, and then merge only the lightweight Bloom filters on the driver. You can have a simple method that constructs a Bloom filter off a partition of data as follows

from pybloom import BloomFilter


def _construct_bloom_filter(records, bloom_capacity, bloom_fpr):
    """
    Constructs a bloom filter that includes as members all
    passed records. The records are assumed to be hashable.
    """

    bloom_filter = BloomFilter(bloom_capacity, bloom_fpr)
    for record in records:
        bloom_filter.add(record)
    # yield (rather than return) a single filter, so the function can be
    # used directly as a mapPartitions callable
    yield bloom_filter

If your input data has too many partitions, you can reduce the partition count before constructing the intermediate filters using a coalesce operation, which won't introduce a shuffle.

And finally, using those two pieces, you can construct the final Bloom filter by taking the union of all the partition filters in a reduce operation. The union is associative and commutative, so the order in which the intermediate filters are merged doesn't matter (note that pybloom will only union filters that share the same capacity and error rate, which is why every partition filter below is constructed with identical parameters). Putting it all together, you can build the final filter using the below snippet

from pybloom import BloomFilter


def bloom_filter_from_rdd(rdd, bloom_fpr=0.01, bloom_capacity=None, num_partitions=30):
    """
    Reads in an rdd, returns
    a bloom filter including the rdd
    records as members.
    
    If bloom_capacity is not specified, the number of elements
    in the rdd (rdd.count()) is used as the bloom capacity.
    You would typically set it yourself, to roughly the number of
    distinct elements you expect, informed by your choice of FPR.
    
    The false positive rate (FPR) for that capacity is
    passed as a function parameter. Bear in mind that a very low
    FPR combined with a large number of elements means a much
    larger underlying bit array, eroding the space advantage, and
    that pybloom's plain BloomFilter will raise an error rather than
    grow if more distinct elements than the capacity are added.
    
    num_partitions specifies number of partitions to coalesce to
    for the intermediate partition bloom filters. If partitions in
    rdd < num_partitions, it won't do anything and will construct intermediate
    bloom filters off existing partitions. The partition bloom filters will
    then get merged into the final bloom filter using the reduce call.
    """

    if not bloom_capacity:
        bloom_capacity = rdd.count()

    if bloom_capacity == 0:
        # an empty set supports the same membership tests as an empty filter
        return set()

    # construct bloom filters
    bloom_filters_rdd = rdd \
        .coalesce(num_partitions) \
        .mapPartitions(lambda records: _construct_bloom_filter(records, bloom_capacity, bloom_fpr))

    # merge partition bloom filters
    # into the single final bloom filter
    return bloom_filters_rdd \
        .reduce(_merge_bloom_filters)

This utilises the two methods we explained before. You can then use the final Bloom filter to filter other RDDs with a simple broadcast filter (pybloom filters are serialisable, hence can be broadcast to Spark executors).
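
For example, assuming the smaller RDD holds the keys you want to keep and the first field of each record in the larger RDD is that key (the sc, small_rdd and large_rdd names are placeholders), the final filtering step might look like this

# Build the filter from the smaller RDD's keys, broadcast it,
# then filter the larger RDD against it on the executors.
bloom_filter = bloom_filter_from_rdd(small_rdd.map(lambda row: row[0]), bloom_fpr=0.01)
bloom_broadcast = sc.broadcast(bloom_filter)

filtered_rdd = large_rdd.filter(lambda row: row[0] in bloom_broadcast.value)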

And there you go: you have a Bloom filter constructed from your RDD in an efficient way that avoids memory pressure on the driver, and that you can use for filtering or broadcast joins, provided you can tolerate a certain false positive rate.

Want to see your career bloom? Join us, we're hiring now.