A scalable groupByKey and secondary sort for (Java) Spark

Spark is, in our opinion, the new reference Big Data processing framework. Its flexible API allows for unified batch and stream processing and can easily be extended for many purposes. It also incorporates a new API called DataFrames, which makes it very easy to analyze data, with solid columnar storage (Parquet) as the default format. Spark is changing fast and is currently going through a big refactor (Project Tungsten) which will bring it to the next level.

We currently see Spark as the default tool for building Big Data projects. Those who, like us, have been working with Hadoop’s MapReduce for a long time encounter many subtle (and sometimes not so subtle) differences when approaching Big Data problems with both frameworks. This post covers one of them.

Today we will talk about one particular problem: how do we group our data in Spark when some groups might contain a tremendous amount of data?

First of all, for the cases where the function we want to apply to our groups is associative (and commutative), what we need is Spark’s reduceByKey (see this link). This mechanism basically encapsulates Hadoop’s Combiner-Reducer, so Spark itself is in charge of applying the function to partial results as many times as it likes. But what about the cases where we don’t have an associative function to apply? In such cases we would use Spark’s groupByKey(). However, this method assumes that all the data for a key will fit in memory (see this issue).
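
To see why the distinction matters, here is a small plain-Java sketch (no Spark involved; the class name and the two "partitions" are made up for illustration). A sum can be computed per partition and the partial results combined, which is what makes a combiner possible. A median cannot be combined that way, so it needs the whole group.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class CombineSketch {
  static int sum(List<Integer> xs) {
    int total = 0;
    for (int x : xs) total += x;
    return total;
  }

  /** Median of the list (lower middle element for even sizes). */
  static int median(List<Integer> xs) {
    List<Integer> sorted = new ArrayList<>(xs);
    Collections.sort(sorted);
    return sorted.get((sorted.size() - 1) / 2);
  }

  public static void main(String[] args) {
    // Two hypothetical partitions holding values of the same key.
    List<Integer> part1 = Arrays.asList(1, 2, 9);
    List<Integer> part2 = Arrays.asList(3, 4, 5, 6);
    List<Integer> all = new ArrayList<>(part1);
    all.addAll(part2);

    // Combining partial sums gives the same answer as summing everything.
    System.out.println(sum(Arrays.asList(sum(part1), sum(part2))) + " vs " + sum(all));
    // Combining partial medians does not give the median of everything.
    System.out.println(median(Arrays.asList(median(part1), median(part2))) + " vs " + median(all));
  }
}
```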

In Big Data it is common to have outliers (entities inside a dataset that have enormous amounts of data associated with them). Think about the list of followers for Lady Gaga. Think about robots that generate tremendous amounts of page views. Think about click-stream data associated with very popular websites.

  • What if we need to analyze their data in various ways, sorted by time?
  • Or simply choose a representative sample of their data, for example using Reservoir Sampling?
  • Or calculate an approximate distribution, for example using Ted Dunning’s T-Digest?
  • Or train a machine learning model for each such group?

There are many cases where the function we want to apply to our groups is not associative, and where the data for our groups might not always fit in memory. In some cases we might also want to apply several associative functions at once over the same group, without having to launch a separate job for each of them.
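
As an illustration of the kind of per-group function we have in mind, here is a minimal reservoir sampling sketch in plain Java (the class and method names are ours, not part of any library). It only needs an Iterator over the group's values and keeps at most k of them in memory, which is exactly the access pattern that a fully materialized group does not require.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

final class ReservoirSketch {
  /** Keeps a uniform random sample of at most k values using O(k) memory. */
  static <T> List<T> sample(Iterator<T> values, int k, Random rnd) {
    List<T> reservoir = new ArrayList<>(k);
    long seen = 0;
    while (values.hasNext()) {
      T value = values.next();
      seen++;
      if (reservoir.size() < k) {
        // The first k values always enter the reservoir.
        reservoir.add(value);
      } else {
        // Later values replace a random slot with probability k / seen.
        long slot = (long) (rnd.nextDouble() * seen);
        if (slot < k) {
          reservoir.set((int) slot, value);
        }
      }
    }
    return reservoir;
  }

  public static void main(String[] args) {
    List<Integer> group = new ArrayList<>();
    for (int i = 1; i <= 100000; i++) group.add(i);
    // Prints five values drawn from the group; which ones depends on the seed.
    System.out.println(sample(group.iterator(), 5, new Random(42)));
  }
}
```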

The idea here is that it is possible to partition an RDD arbitrarily and iterate over each partition’s data, sorted in a particular way. We then just need to “detect” when a group starts or ends. To make this easier, we created the following abstraction (full code on GitHub):

  /**
   * The API version with secondary sort.
   *
   * @param rdd
   *          a pair RDD that can be grouped by its key
   * @param secondarySort
   *          a comparator to be applied over the values of each group
   * @param handler
   *          the business logic to apply over each group
   * @return an rdd of the results of applying the groupByKey
   */
  public static <G extends Comparable<G>, T, K> JavaRDD<K> groupByKeySecondarySort(JavaPairRDD<G, T> rdd,
      SerializableComparator<T> secondarySort, GroupByKeyHandler<G, T, K> handler) {

    // Move the value into the key so that Spark can sort by (group, value).
    JavaPairRDD<Tuple2<G, T>, Void> allInKey = rdd
        .mapToPair(tuple -> new Tuple2<Tuple2<G, T>, Void>(new Tuple2<>(tuple._1(), tuple._2()), Void.instance()));

    final int partitions = allInKey.partitions().size();
    return allInKey.repartitionAndSortWithinPartitions(new Partitioner() {
      @Override
      public int getPartition(Object obj) {
        Tuple2<G, T> group = (Tuple2<G, T>) obj;
        // floorMod stays non-negative even when hashCode() is Integer.MIN_VALUE
        return Math.floorMod(group._1().hashCode(), partitions);
      }

      @Override
      public int numPartitions() {
        return partitions;
      }
    }, new GroupAndValueTupleComparator<>(secondarySort))
        .mapPartitions(i -> new IteratorIterable<>(new GroupIterator<>(i, handler)));
  }

The groupByKeySecondarySort method above compiles for Java 8, but the same idea could be applied in Scala. What it does, essentially:

  • Accept a PairRDD with a key and a value.
  • Create a new PairRDD where we put both the data and the key in a Tuple2 key.
  • Repartition this RDD by the key element (group._1().hashCode()).
  • Sort each partition by the key + a specific Comparator for the values (to achieve secondary sort).
  • Return an Iterator to Spark that will hold the results of applying the handler to each group.
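
The partitioning and sorting steps above can be simulated on plain Java collections, which may help to see what ends up in each partition. This is a sketch under our own assumptions: Map.Entry stands in for Tuple2, the class and method names are hypothetical, and the comparator is only our guess at what a class like GroupAndValueTupleComparator has to do. The partition index uses Math.floorMod, which stays non-negative even for a hash code of Integer.MIN_VALUE, where Math.abs(h) % n would not.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

final class PartitionSortSketch {
  /** Non-negative partition index derived from the group key only. */
  static int partitionFor(Object groupKey, int partitions) {
    return Math.floorMod(groupKey.hashCode(), partitions);
  }

  /** Orders by group first, then by the user-supplied comparator on the values. */
  static <G extends Comparable<G>, T> Comparator<Map.Entry<G, T>> groupThenValue(
      Comparator<T> secondarySort) {
    return (a, b) -> {
      int byGroup = a.getKey().compareTo(b.getKey());
      return byGroup != 0 ? byGroup : secondarySort.compare(a.getValue(), b.getValue());
    };
  }

  /** In-memory stand-in for repartitioning by group and sorting within partitions. */
  static <G extends Comparable<G>, T> List<List<Map.Entry<G, T>>> repartitionAndSort(
      List<Map.Entry<G, T>> input, int partitions, Comparator<T> secondarySort) {
    List<List<Map.Entry<G, T>>> out = new ArrayList<>();
    for (int i = 0; i < partitions; i++) {
      out.add(new ArrayList<>());
    }
    // All entries of one group land in the same partition.
    for (Map.Entry<G, T> entry : input) {
      out.get(partitionFor(entry.getKey(), partitions)).add(entry);
    }
    // Within a partition, groups are contiguous and values follow the secondary sort.
    Comparator<Map.Entry<G, T>> order = groupThenValue(secondarySort);
    for (List<Map.Entry<G, T>> partition : out) {
      partition.sort(order);
    }
    return out;
  }
}
```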

There are some tricks to it:

  • We can’t return an Iterator to Spark; we are forced to return an Iterable due to this issue. We took the “IteratorIterable” class from there.
  • The Comparator we accept from the end user should be of a special kind we called “SerializableComparator”, so that we can safely ship it to Spark:
  public static interface SerializableComparator<G> extends Comparator<G>, Serializable {
  }
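
The reason this empty interface is enough is that a Java 8 lambda is serializable whenever its target type is. The sketch below checks this with plain Java serialization; the wrapper class and helper methods are hypothetical and only serve the demonstration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Comparator;

final class SerializableComparatorSketch {
  interface SerializableComparator<T> extends Comparator<T>, Serializable {
  }

  /** A lambda whose target type is Serializable, so the lambda is too. */
  static SerializableComparator<Integer> descending() {
    return (a, b) -> b.compareTo(a);
  }

  /** Serializes and deserializes a value, roughly what shipping it to a worker involves. */
  @SuppressWarnings("unchecked")
  static <T> T roundTrip(T value) {
    try {
      ByteArrayOutputStream bytes = new ByteArrayOutputStream();
      try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
        out.writeObject(value);
      }
      try (ObjectInputStream in =
          new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
        return (T) in.readObject();
      }
    } catch (IOException | ClassNotFoundException e) {
      throw new IllegalStateException(e);
    }
  }

  /** True if Java serialization accepts the object. */
  static boolean canSerialize(Object o) {
    try (ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
      out.writeObject(o);
      return true;
    } catch (IOException e) {
      return false; // NotSerializableException is an IOException
    }
  }

  public static void main(String[] args) {
    Comparator<Integer> plain = (a, b) -> b.compareTo(a);
    System.out.println(canSerialize(plain));        // false
    System.out.println(canSerialize(descending())); // true
    // The deserialized copy still sorts in descending order.
    System.out.println(roundTrip(descending()).compare(1, 2) > 0); // true
  }
}
```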

We can also provide a simpler version that doesn’t care about secondary sort – it will just call the previous method with a comparator that always returns 0:

  public static <G extends Comparable<G>, T, K> JavaRDD<K> groupByKey(JavaPairRDD<G, T> rdd,
      GroupByKeyHandler<G, T, K> handler) {
    return groupByKeySecondarySort(rdd, (t1, t2) -> 0, handler);
  }

Finally, here’s an example of how this could be used. This code takes an RDD with a String group and an Integer value and returns an RDD with the String and its top N values:

    List<Tuple2<String, Integer>> elements = Lists.newArrayList();
    elements.add(new Tuple2<>("A", 5));
    elements.add(new Tuple2<>("A", 6));
    elements.add(new Tuple2<>("A", 9));

    elements.add(new Tuple2<>("B", 10));
    elements.add(new Tuple2<>("B", 6));
    elements.add(new Tuple2<>("B", 20));

    elements.add(new Tuple2<>("C", 10));

    // ctx is a JavaSparkContext
    JavaRDD<Tuple2<String, Integer>> tuples = ctx.parallelize(elements, 1);
    JavaPairRDD<String, Integer> pairRdd = tuples.mapToPair(tuple -> tuple);
    final int topN = 2;
    // note how we sort by the value DESC via lambda (elem1, elem2) -> elem2.compareTo(elem1)
    JavaRDD<Tuple2<String, List<Integer>>> topElementsPerLetter = ScalableGroupByKey.groupByKeySecondarySort(pairRdd,
        (elem1, elem2) -> elem2.compareTo(elem1), (group, iterator) -> {
          List<Integer> topElements = new ArrayList<>(topN);
          // values arrive sorted DESC, so the first topN are the largest ones
          while (iterator.hasNext()) {
            Integer el = iterator.next();
            if (topElements.size() < topN) {
              topElements.add(el);
            }
          }
          return new Tuple2<>(group, topElements);
        });

This API has one important tradeoff: it still needs to keep in memory the result of the user code for each group, simply because the result of processing a group has to be returned from the handler. For those cases where the output is at least as big as the input, this abstraction does not make much sense (it would be better to just use repartitionAndSortWithinPartitions followed by mapPartitions, manually). It would be possible to create another, more complicated abstraction that gives the user both an Iterator over the group and an Emitter interface. But then the user code and the Iterator returned to Spark would need to run on different threads and communicate through a BlockingQueue. That is obviously a much more complicated topic, out of the scope of this blog post.
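
Just to give a flavour of what that would involve, here is a rough plain-Java sketch of the Emitter idea. Everything in it is hypothetical (the Emitter interface, the class, the method names), it has nothing Spark-specific in it, and it leaves out error propagation and cancellation, which a real implementation would need.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.Consumer;

final class EmitterSketch {
  /** Hypothetical callback that the user code writes its results to. */
  interface Emitter<K> {
    void emit(K result);
  }

  private static final Object END = new Object(); // marks the end of the output

  /** Runs the user code on its own thread and exposes what it emits as an Iterator. */
  static <K> Iterator<K> run(Consumer<Emitter<K>> userCode, int capacity) {
    // Bounded queue: the producer blocks once `capacity` results are waiting.
    // Note that null results are not supported by ArrayBlockingQueue.
    BlockingQueue<Object> queue = new ArrayBlockingQueue<>(capacity);
    Thread producer = new Thread(() -> {
      try {
        userCode.accept(result -> put(queue, result));
      } finally {
        put(queue, END); // always signal the end so the consumer can finish
      }
    });
    producer.setDaemon(true);
    producer.start();

    return new Iterator<K>() {
      private Object pending; // fetched but not yet returned; null if none

      @Override
      public boolean hasNext() {
        if (pending == null) pending = take(queue);
        return pending != END;
      }

      @Override
      @SuppressWarnings("unchecked")
      public K next() {
        if (!hasNext()) throw new NoSuchElementException();
        K result = (K) pending;
        pending = null;
        return result;
      }
    };
  }

  private static void put(BlockingQueue<Object> queue, Object value) {
    try {
      queue.put(value);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    }
  }

  private static Object take(BlockingQueue<Object> queue) {
    try {
      return queue.take();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    }
  }
}
```

With a small queue capacity, only a handful of results are in memory at any time, no matter how many the user code emits.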

This “hack” on top of Spark’s current API is meant to make things easier when iterating over big groups, but the proper solution will eventually be provided by the Spark API itself, as tracked in this issue.
