Tuning Spark Back Pressure by Simulation

Spark back pressure, enabled by setting spark.streaming.backpressure.enabled=true, dynamically resizes batches to avoid queue build-up. It is implemented with a Proportional Integral Derivative (PID) controller. This algorithm has some interesting properties. The most interesting is that, unlike TCP-style probing algorithms, it does not guarantee a stable fixed point. This can show up not just as transient overshoot, but as a batch size oscillating around a (potentially optimal) constant throughput. Overshoot incurs latency; undershoot costs throughput. Catastrophic overshoot leading to an OOM is possible in degenerate circumstances (you need to choose the parameters quite deviously to make this happen). Having witnessed undershoot and slow recovery in production streaming jobs, I decided to investigate further by testing the algorithm with a simulator. This is very simple to do in JUnit: create an instance of PIDRateEstimator and call its compute method within a simulation loop.

PID Controllers

The PID controller is a closed feedback loop on a single process variable. It aims to stabilise that variable (but may not be able to) by minimising three error metrics: the present error (with respect to the last measurement), the accumulated or integral error, and the rate at which the error is changing, or derivative error. The control signal is a weighted sum of these three terms.

u(t) = K_{p}e(t) + K_{i}\int_{0}^{t} e(\tau) d\tau + K_{d}\frac{d}{dt}e(t)

The K_{p} term reacts to the immediate error. The K_{i} term responds to error accumulated over time. The K_{d} term responds to the trend in the error, so the controller can react quickly to underlying changes. Spark allows these weights to be tuned with the following configuration properties:

  • K_{p} spark.streaming.backpressure.pid.proportional (defaults to 1, non-negative) – weight for the contribution of the difference between the current rate and the rate of the last batch to the overall control signal.
  • K_{i} spark.streaming.backpressure.pid.integral (defaults to 0.2, non-negative) – weight for the contribution of the accumulation of the proportional error to the overall control signal.
  • K_{d} spark.streaming.backpressure.pid.derived (defaults to zero, non-negative) – weight for the contribution of the change of the proportional error to the overall control signal.

The default values are typically quite good. Spark adds one parameter not present in classical PID controllers, spark.streaming.backpressure.pid.minRate, which defaults to 100 and must be positive. The controller never sets a rate below this floor. This is definitely a parameter to watch if you use back pressure and know up front that you will need to process messages at a rate much higher than 100 per second. Raising minRate towards the expected rate limits undershoot and so improves stability.
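To make the discrete form of the equation above concrete, here is a dependency-free transcription of the update rule as I read it in the Spark 2.x source of PIDRateEstimator.compute. The class name PidRateSketch is hypothetical and this is not Spark's code. Three details stand out. Spark approximates the integral term from the scheduling delay (the backlog) rather than summing past errors. It subtracts the weighted terms from the previous rate. It clamps the result at minRate.

```java
import java.util.OptionalDouble;

/**
 * Hypothetical, dependency-free transcription of the update rule in Spark's
 * PIDRateEstimator (my reading of the Spark 2.x source, not Spark's code).
 * Rates are in elements per second; times are in milliseconds.
 */
final class PidRateSketch {
  private final long batchIntervalMillis;
  private final double kp;
  private final double ki;
  private final double kd;
  private final double minRate;

  private boolean firstRun = true;
  private long latestTime = -1L;
  private double latestRate = -1D;
  private double latestError = -1D;

  PidRateSketch(long batchIntervalMillis, double kp, double ki, double kd, double minRate) {
    this.batchIntervalMillis = batchIntervalMillis;
    this.kp = kp;
    this.ki = ki;
    this.kd = kd;
    this.minRate = minRate;
  }

  /** Returns the new rate limit, or empty when there is nothing to update yet. */
  OptionalDouble compute(long time, long numElements, long processingDelay, long schedulingDelay) {
    if (time <= latestTime || numElements <= 0 || processingDelay <= 0) {
      return OptionalDouble.empty();
    }
    double delaySinceUpdate = (time - latestTime) / 1000.0;          // seconds since the last update
    double processingRate = numElements * 1000.0 / processingDelay; // rate the last batch achieved
    double error = latestRate - processingRate;                      // proportional error
    // queued work expressed as a rate: stands in for the integral of past errors
    double historicalError = schedulingDelay * processingRate / batchIntervalMillis;
    double dError = (error - latestError) / delaySinceUpdate;        // derivative error
    double newRate = Math.max(minRate, latestRate - kp * error - ki * historicalError - kd * dError);
    latestTime = time;
    if (firstRun) {
      // the first batch only seeds the controller with the observed rate
      latestRate = processingRate;
      latestError = 0D;
      firstRun = false;
      return OptionalDouble.empty();
    }
    latestRate = newRate;
    latestError = error;
    return OptionalDouble.of(newRate);
  }
}
```

With K_{p} = 1 and K_{i} = K_{d} = 0 the update collapses to the rate the last batch actually achieved, clamped to minRate. It is the integral and derivative terms that push the rate away from that observation.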

Constant Throughput Simulator

Consider the case where all else is held equal. The job processes records at a constant rate, and batches are scheduled at a constant frequency. The batch size may vary under back pressure, but the processing rate does not. These assumptions rule out both the amortisation of fixed overhead over larger batches (economies of scale) and any extrinsic fluctuation in rate. The goal of the algorithm is to find, in a reasonable number of iterations, a stable fixed point for the batch size close to the constant rate multiplied by the batch interval. Depending on your batch interval, this may take a long time in real terms. The test is useful in two ways:

  1. Evaluating the entire parameter space of the algorithm: finding toxic combinations of parameters and initial throughputs which lead to divergence, oscillation or sluggish convergence.
  2. Optimising parameters for a use case: for a given expected throughput rate, rapidly simulate the job under backpressure to optimise the choice of parameters.

Running a simulator is much faster than repeatedly running Spark jobs against production data, so it is a practical way to tune the parameters. The suggested settings must still be validated afterwards on an actual cluster. The procedure is as follows:

  1. Fix a batch interval and a target throughput
  2. Choose a range of possible values for K_{p}, K_{i} and K_{d} between zero and two
  3. Choose a range of initial batch sizes (as in spark.streaming.receiver.maxRate), above and below the size implied by the target throughput and frequency
  4. Choose some minimum batch sizes to investigate the effect on damping.
  5. Run a simulation for each element of the cartesian product of these parameters.
  6. Verify the simulation by running a Spark streaming job using the recommended parameters.
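The steps above amount to a grid search. One way to build the grid, sketched here with hypothetical names (ParameterGrid, gains, cartesianProduct), is a set of nested loops. I assume the column order matches the TestPIDController constructor in the simulator code below. Unlike that code, which counts 4,000 cases and leaves the minimum rate out of the product, this version treats minRate as a fifth dimension. Every combination appears exactly once, so the example grid grows to 10 × 10 × 10 × 5 × 4 = 20,000 cases.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: every combination of simulator inputs, one Object[] per JUnit case. */
final class ParameterGrid {

  /** n evenly spaced gains 0, 2/n, 4/n, ... strictly below 2. */
  static double[] gains(int n) {
    double[] values = new double[n];
    for (int j = 0; j < n; ++j) {
      values[j] = 2.0 * j / n;
    }
    return values;
  }

  /** Columns (assumed): interval, p, i, d, minRate, target throughput, initial batch size. */
  static Object[][] cartesianProduct(long batchIntervalSeconds, double targetThroughput,
                                     double[] proportionals, double[] integrals, double[] derivatives,
                                     double[] minRates, double[] initialBatchSizes) {
    List<Object[]> cases = new ArrayList<>();
    for (double p : proportionals) {
      for (double i : integrals) {
        for (double d : derivatives) {
          for (double minRate : minRates) {
            for (double initial : initialBatchSizes) {
              cases.add(new Object[] {batchIntervalSeconds, p, i, d, minRate, targetThroughput, initial});
            }
          }
        }
      }
    }
    return cases.toArray(new Object[0][]);
  }
}
```

Returning cartesianProduct(...) from the @Parameterized.Parameters method would be a drop-in replacement, at the cost of five times as many simulations.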

Out of 4000 test cases, I found over 500 in which the algorithm oscillated violently out of control despite the constant processing rate, typically when the integral and derivative components were large.

In a real Spark Streaming job this would result in an OOM. When I get time I will look for a mathematical justification, to verify that this isn't a bug in the simulator.
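Sorting thousands of runs into converged, oscillating and divergent needs a criterion for each trace of batch sizes. One possible criterion looks only at the tail of the trace, such as the batch sizes printed in each iteration of the simulation loop. It is sketched here with a hypothetical class name and illustrative, untuned thresholds.

```java
/** Hypothetical classifier for a simulated trace of batch sizes; thresholds are illustrative, not tuned. */
final class TraceClassifier {

  enum Outcome { CONVERGED, UNSETTLED, DIVERGED }

  /**
   * Inspects the last {@code window} batch sizes. DIVERGED if any exceeds {@code blowUpFactor}
   * times the target; CONVERGED if all lie within {@code tolerance} (a fraction) of the target;
   * otherwise UNSETTLED, which covers both persistent oscillation and sluggish convergence.
   */
  static Outcome classify(double[] batchSizes, double target, int window,
                          double tolerance, double blowUpFactor) {
    if (window <= 0 || batchSizes.length < window) {
      throw new IllegalArgumentException("need at least " + window + " batches");
    }
    boolean allWithinTolerance = true;
    for (int k = batchSizes.length - window; k < batchSizes.length; ++k) {
      if (batchSizes[k] > blowUpFactor * target) {
        return Outcome.DIVERGED;
      }
      if (Math.abs(batchSizes[k] - target) > tolerance * target) {
        allWithinTolerance = false;
      }
    }
    return allWithinTolerance ? Outcome.CONVERGED : Outcome.UNSETTLED;
  }
}
```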

There are also interesting cases where the algorithm is stable and convergent. The first is catching up from a conservative estimate of throughput, which, across a range of parameters, does not seem to overshoot.

When the algorithm needs to slow down, it always undershoots before converging. This cost is amortised over time, as is the latency incurred. Even so, if throughput is preferred it is better to overestimate the initial batch size; if latency is preferred it is better to underestimate it.

Simulator Code

import java.util.Arrays;
import java.util.Comparator;

import org.apache.spark.streaming.scheduler.rate.PIDRateEstimator;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import scala.Option;

@RunWith(Parameterized.class)
public class TestPIDController {

  @Parameterized.Parameters(name = "Aim for {5}/s starting from initial batch={6} with interval {0}s, p={1}, i={2}, d={3}")
  public static Object[][] generatedParams() {

    double requiredThroughput = 5000;
    long maxlatency = 1L;

    // candidate gains 0, 0.2, ..., 1.8 for each PID term
    double[] proportionals = new double[10];
    double[] integrals = new double[10];
    double[] derivatives = new double[10];
    double v = 0.0;
    for(int j = 0; j < proportionals.length; ++j) {
      proportionals[j] = v;
      integrals[j] = v;
      derivatives[j] = v;
      v += 2D/proportionals.length;
    }

    double[] initialBatchSizes = new double[] { 2500, 4500, 5500, 7500 };
    double[] minBatchSizes = new double[] { 100, 500, 1000, 2500, 4500};
    int numTestCases = proportionals.length * integrals.length * derivatives.length * initialBatchSizes.length;
    // columns match the constructor: interval, p, i, d, minRate, target throughput, initial batch size
    Object[][] cases = new Object[numTestCases][];
    for(int caseNum = 0; caseNum < numTestCases; ++caseNum) {
      cases[caseNum] = new Object[7];
      cases[caseNum][0] = maxlatency;
      cases[caseNum][5] = requiredThroughput;
    }
    // assign each parameter cyclically by case index, then stable-sort on it before assigning the next
    for(int caseNum = 0; caseNum < numTestCases; ++caseNum) {
      cases[caseNum][1] = proportionals[caseNum % proportionals.length];
    }
    Arrays.sort(cases, Comparator.comparingDouble(a -> (double) a[1]));
    for(int caseNum = 0; caseNum < numTestCases; ++caseNum) {
      cases[caseNum][2] = integrals[caseNum % integrals.length];
    }
    Arrays.sort(cases, Comparator.comparingDouble(a -> (double) a[2]));
    for(int caseNum = 0; caseNum < numTestCases; ++caseNum) {
      cases[caseNum][3] = derivatives[caseNum % derivatives.length];
    }
    Arrays.sort(cases, Comparator.comparingDouble(a -> (double) a[3]));
    for(int caseNum = 0; caseNum < numTestCases; ++caseNum) {
      cases[caseNum][4] = minBatchSizes[caseNum % minBatchSizes.length];
    }
    Arrays.sort(cases, Comparator.comparingDouble(a -> (double) a[4]));
    for(int caseNum = 0; caseNum < numTestCases; ++caseNum) {
      cases[caseNum][6] = initialBatchSizes[caseNum % initialBatchSizes.length];
    }
    Arrays.sort(cases, Comparator.comparingDouble(a -> (double) a[6]));
    return cases;
  }

  public TestPIDController(long batchSizeSeconds,
                           double proportional,
                           double summation,
                           double derivative,
                           double minRate,
                           double constantProcessingRate,
                           double initialBatchSize) {
    this.expectedBatchDurationSeconds = batchSizeSeconds;
    this.proportional = proportional;
    this.summation = summation;
    this.derivative = derivative;
    this.minRate = minRate;
    this.constantProcessingRatePerSecond = constantProcessingRate;
    this.initialBatchSize = initialBatchSize;
  }

  private final long expectedBatchDurationSeconds;
  private final double proportional;
  private final double summation;
  private final double derivative;
  private final double minRate;
  private final double constantProcessingRatePerSecond;
  private final double initialBatchSize;

  @Test
  public void ensureRapidConvergence() {
    System.out.println("Time,Scheduling Delay,Processing Delay,Throughput,Batch Size");
    long schedulingDelayMillis = 0;
    double batchSize = initialBatchSize;
    long expectedBatchDurationMillis = 1000 * expectedBatchDurationSeconds;
    double batchTimeSeconds;
    long batchTimeMillis;
    long timeMillis = 0;
    PIDRateEstimator estimator = new PIDRateEstimator(expectedBatchDurationMillis, proportional, summation, derivative, minRate);
    Option<Object> newSize;
    double numProcessed = 0;
    double throughput = Double.NaN;

    for(int i = 0; i < 100; ++i) {
      // sanity check: 100 batches should never take longer than 200 batch intervals
      if(timeMillis > 200 * expectedBatchDurationMillis) {
        Assert.fail("Simulation ran for more than 200 batch intervals");
      }
      numProcessed += batchSize;
      batchTimeSeconds = getTimeToCompleteSeconds(batchSize);
      batchTimeMillis = (long)Math.ceil(batchTimeSeconds * 1000);
      // if the batch finished early and nothing is queued, wait for the next scheduled batch
      long pauseTimeMillis = schedulingDelayMillis == 0 && batchTimeSeconds <= expectedBatchDurationSeconds ? expectedBatchDurationMillis - batchTimeMillis : 0;
      timeMillis += batchTimeMillis + pauseTimeMillis;
      newSize = estimator.compute(timeMillis, (long)batchSize, batchTimeMillis, schedulingDelayMillis);
      if(newSize.isDefined()) {
        // the estimate is a rate per second; with a one-second interval it is also the next batch size
        batchSize = (double)newSize.get();
      }
      // how far this batch overran (positive) or underran (negative) the batch interval
      long processingDelay = batchTimeMillis - expectedBatchDurationMillis;
      schedulingDelayMillis += processingDelay;
      if(schedulingDelayMillis < 0) schedulingDelayMillis = 0;
      throughput = numProcessed/timeMillis*1000;
      System.out.println(String.format("%d,%d,%d,%f,%d", timeMillis,schedulingDelayMillis, processingDelay, throughput,(long)batchSize));
    }

    double percentageError = 100 * Math.abs((constantProcessingRatePerSecond - throughput) / constantProcessingRatePerSecond);
    Assert.assertTrue(String.format("Effective rate %f is %f%% away from target throughput %f",
                                    throughput, percentageError, constantProcessingRatePerSecond), percentageError < 10);
  }

  private double getTimeToCompleteSeconds(double batchSize) {
    return batchSize / (constantProcessingRatePerSecond);
  }
}

Response to Instantaneous and Sustained Shocks

I haven't yet simulated an important property of a control algorithm: how it responds to random extrinsic shocks, both instantaneous and sustained. An instantaneous shock should not move the process variable away from its fixed point for very long. Under a sustained shock, the algorithm should move the process variable to a new stable fixed point.
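One way to set up that experiment is to make the processing rate a function of the batch number. That function would replace the constant rate in getTimeToCompleteSeconds, with the simulation loop index passed in. This is a sketch under my own assumptions: ShockProfiles and its methods are hypothetical and not part of the simulator. Given the expectations above, an impulse should leave the batch size back at its old fixed point within a few batches. A step should move it to a new fixed point of about factor × rate × interval.

```java
import java.util.function.IntToDoubleFunction;

/** Hypothetical processing-rate profiles for simulating extrinsic shocks, indexed by batch number. */
final class ShockProfiles {

  /** Instantaneous shock: the rate is scaled by {@code factor} for batch {@code atBatch} only. */
  static IntToDoubleFunction impulse(double baseRate, double factor, int atBatch) {
    return batch -> batch == atBatch ? baseRate * factor : baseRate;
  }

  /** Sustained shock: the rate is scaled by {@code factor} from batch {@code fromBatch} onwards. */
  static IntToDoubleFunction step(double baseRate, double factor, int fromBatch) {
    return batch -> batch >= fromBatch ? baseRate * factor : baseRate;
  }

  /** Seconds to process a batch of {@code batchSize} elements at the rate in force for that batch. */
  static double timeToCompleteSeconds(IntToDoubleFunction rate, int batch, double batchSize) {
    return batchSize / rate.applyAsDouble(batch);
  }
}
```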

Co-locating Spark Partitions with HBase Regions

HBase scans can be accelerated if they start and stop on a single region server. IO costs can be reduced further if the scan is executed on the same machine as the region server. This article is about extending the Spark RDD abstraction to load an RDD from an HBase table so each partition is co-located with a region server. This pattern could be adopted to read data into Spark from other sharded data stores, whenever there is a metadata protocol available to dictate partitioning.

The strategy involves creating a custom implementation of the Spark class RDD, which understands how to create partitions from metadata about HBase regions. To read data from HBase efficiently, each partition should scan a single region, running on the same machine as the region server that hosts it to minimise IO. Therefore we need the start key, stop key and hostname of the region associated with each Spark partition.

import org.apache.spark.Partition;

/** A Spark partition covering the part of one HBase region that falls inside the requested key range. */
public class HBasePartition implements Partition {

  private final String regionHostname;
  private final int partitionIndex;
  private final byte[] start;
  private final byte[] stop;

  public HBasePartition(String regionHostname, int partitionIndex, byte[] start, byte[] stop) {
    this.regionHostname = regionHostname;
    this.partitionIndex = partitionIndex;
    this.start = start;
    this.stop = stop;
  }

  public String getRegionHostname() {
    return regionHostname;
  }

  public byte[] getStart() {
    return start;
  }

  public byte[] getStop() {
    return stop;
  }

  @Override
  public int index() {
    return partitionIndex;
  }
}

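The HBase interface RegionLocator, which can be obtained from a Connection instance, can be used to build an array of HBasePartitions. For efficiency, any region whose key range does not overlap the supplied start and stop keys can be skipped entirely. HBase represents the first region's start key and the last region's end key as empty byte arrays. The comparisons must therefore treat an empty start key as the beginning of the table and an empty stop key as its end.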

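import java.io.IOException;
import java.io.Serializable;
import java.util.List;

import com.google.common.collect.Lists;
import org.apache.hadoop.hbase.HRegionInfo;
import org.apache.hadoop.hbase.HRegionLocation;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.client.RegionLocator;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.spark.Partition;

public class HBasePartitioner implements Serializable {

  public Partition[] getPartitions(byte[] table, byte[] start, byte[] stop) {
    // close the connection as well as the locator when done
    try(Connection connection = ConnectionFactory.createConnection();
        RegionLocator regionLocator = connection.getRegionLocator(TableName.valueOf(table))) {
      List<HRegionLocation> regionLocations = regionLocator.getAllRegionLocations();
      List<Partition> partitions = Lists.newArrayListWithExpectedSize(regionLocations.size());
      int partition = 0;
      for(HRegionLocation regionLocation : regionLocations) {
        HRegionInfo regionInfo = regionLocation.getRegionInfo();
        byte[] regionStart = regionInfo.getStartKey();
        byte[] regionStop = regionInfo.getEndKey();
        if(!skipRegion(start, stop, regionStart, regionStop)) {
          // Spark requires partition indices 0, 1, 2, ... matching array order
          partitions.add(new HBasePartition(regionLocation.getHostname(), partition++,
                                            max(start, regionStart), min(stop, regionStop)));
        }
      }
      return partitions.toArray(new Partition[partition]);
    } catch (IOException e) {
      throw new RuntimeException("Could not create HBase region partitions", e);
    }
  }

  // empty start key = start of table, empty stop key = end of table
  private static boolean skipRegion(byte[] scanStart, byte[] scanStop, byte[] regionStart, byte[] regionStop) {
    // the region ends before the scan starts, or the scan stops before the region starts
    return (regionStop.length > 0 && Bytes.compareTo(regionStop, scanStart) <= 0)
        || (scanStop.length > 0 && Bytes.compareTo(scanStop, regionStart) <= 0);
  }

  // the earlier of two stop keys (empty = end of table)
  private static byte[] min(byte[] left, byte[] right) {
    if(left.length == 0) return right;
    if(right.length == 0) return left;
    return Bytes.compareTo(left, right) < 0 ? left : right;
  }

  // the later of two start keys (empty = start of table)
  private static byte[] max(byte[] left, byte[] right) {
    if(left.length == 0) return right;
    if(right.length == 0) return left;
    return Bytes.compareTo(left, right) >= 0 ? left : right;
  }
}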

Finally, we can implement an RDD specialised for reading HBasePartitions. To choose, or at least influence, where each partition is executed, we need to override the Scala RDD method getPreferredLocations. This method is not available on JavaRDD, so we are forced to do some Scala conversions. The Scala/Java conversion work is tedious, but necessary when accessing low-level features from a Java-only project.

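import java.io.IOException;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.StreamSupport;

import com.google.common.collect.ImmutableSet;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.ResultScanner;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.EmptyRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.TaskCompletionListener;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import scala.reflect.ClassTag;

public class HBaseRDD<T> extends RDD<T> {

  private static <T> ClassTag<T> createClassTag(Class<T> klass) {
    return scala.reflect.ClassTag$.MODULE$.apply(klass);
  }

  private final HBasePartitioner partitioner;
  private final String tableName;
  private final byte[] startKey;
  private final byte[] stopKey;
  // shipped to executors with the RDD, so it must be serializable
  private final Function<Result, T> mapper;

  public HBaseRDD(SparkContext sparkContext,
                  Class<T> klass,
                  HBasePartitioner partitioner,
                  String tableName,
                  byte[] startKey,
                  byte[] stopKey,
                  Function<Result, T> mapper) {
    super(new EmptyRDD<>(sparkContext, createClassTag(klass)), createClassTag(klass));
    this.partitioner = partitioner;
    this.tableName = tableName;
    this.startKey = startKey;
    this.stopKey = stopKey;
    this.mapper = mapper;
  }

  @Override
  public Iterator<T> compute(Partition split, TaskContext context) {
    HBasePartition partition = (HBasePartition)split;
    try {
      Connection connection = ConnectionFactory.createConnection();
      Table table = connection.getTable(TableName.valueOf(tableName));
      // restrict the scan to this partition's slice of the region
      Scan scan = new Scan().setStartRow(partition.getStart()).setStopRow(partition.getStop());
      ResultScanner scanner = table.getScanner(scan);
      // Spark consumes the iterator after compute returns, so release HBase resources when the task ends
      context.addTaskCompletionListener(new TaskCompletionListener() {
        @Override
        public void onTaskCompletion(TaskContext taskContext) {
          scanner.close();
          try {
            table.close();
            connection.close();
          } catch (IOException e) {
            throw new RuntimeException("Could not close HBase connection", e);
          }
        }
      });
      return JavaConversions.asScalaIterator(
              StreamSupport.stream(scanner.spliterator(), false).map(mapper).iterator());
    } catch (IOException e) {
      throw new RuntimeException("Region scan failed", e);
    }
  }

  @Override
  public Seq<String> getPreferredLocations(Partition split) {
    Set<String> locations = ImmutableSet.of(((HBasePartition)split).getRegionHostname());
    return JavaConversions.asScalaSet(locations).toSeq();
  }

  @Override
  public Partition[] getPartitions() {
    return partitioner.getPartitions(Bytes.toBytes(tableName), startKey, stopKey);
  }
}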

As far as the interface of this class is concerned, it’s just normal Java, so it can be used from a more Java-centric Spark project, despite using some Scala APIs under the hood. We could achieve similar results with mapPartitions, but would have less control over partitioning and co-location.

HBase Connection Management

I have built several web applications recently using Apache HBase as a backend data store. This article addresses some of the design concerns I faced, and the approaches I took, in managing HBase connections efficiently.

One of the first things I noticed about the HBase client API was how long it takes to create a connection. HBase connection creation is effectively ZooKeeper-based service discovery. The end result is a client which knows where the region servers are, and which region server is serving which key space. This operation is expensive and should be performed as rarely as possible.

At first I only created the connection once, when I started the web application. This is very simple and is fine for most use cases.

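public static void main(String[] args) throws Exception {
        Configuration configuration = HBaseConfiguration.create();
        // HBase connections are thread-safe, so this single connection can be shared by every request
        Connection connection = ConnectionFactory.createConnection(configuration);
        // start the web application using this connection
}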

This approach works well unless you need to proxy your end users when querying HBase. If Apache Ranger is enabled on your HBase cluster, proxying your users allows Ranger to apply user-specific authorisation to each query, rather than authorising your web application's service user. This imposes a few constraints. The most relevant is that you need a connection per user, so you can no longer simply connect when your application starts.

Proxy Users

I needed to proxy users while minimising connection creation, so I built a connection pool class which, given a user principal, creates a connection as that user. I used Guava's LoadingCache to handle eviction and concurrency. Guava's cache also supports a removal listener, which lets the pool close a connection when it is evicted from the cache.

To get user proxying working, you need the UserGroupInformation for the web application's own service principal, and you need to have successfully authenticated your user (I used SPNEGO to do this). Hadoop's UserGroupInformation.createProxyUser then creates the proxy user, and HBase's UserProvider wraps it for the HBase client. Your web application service principal also needs to be configured as a proxy user in core-site.xml (the hadoop.proxyuser.<service user>.hosts and hadoop.proxyuser.<service user>.groups properties), which you can manage via tools like Ambari.

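import java.io.Closeable;
import java.io.IOException;
import java.security.Principal;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.security.UserProvider;
import org.apache.hadoop.security.UserGroupInformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ConnectionPool implements Closeable {

  private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionPool.class);
  private final Configuration configuration;
  private final LoadingCache<String, Connection> cache;
  private final ExecutorService threadPool;
  private final UserProvider userProvider;
  private volatile boolean closed = false;
  private final UserGroupInformation loginUser;

  public ConnectionPool(Configuration configuration, UserGroupInformation loginUser) {
    this.loginUser = loginUser;
    this.configuration = configuration;
    this.userProvider = UserProvider.instantiate(configuration);
    this.threadPool = Executors.newFixedThreadPool(50, new ThreadFactoryBuilder().setNameFormat("hbase-client-connection-pool-%d").build());
    this.cache = createCache();
  }

  public Connection getConnection(Principal principal) throws IOException {
    return cache.getUnchecked(principal.getName());
  }

  @Override
  public void close() throws IOException {
    if(!closed) {
      closed = true;
      // invalidating every entry closes its connection via the removal listener
      cache.invalidateAll();
      threadPool.shutdown();
    }
  }

  private Connection createConnection(String userName) throws IOException {
    UserGroupInformation proxyUserGroupInformation = UserGroupInformation.createProxyUser(userName, loginUser);
    return ConnectionFactory.createConnection(configuration, threadPool, userProvider.create(proxyUserGroupInformation));
  }

  private LoadingCache<String, Connection> createCache() {
    return CacheBuilder.newBuilder()
            .expireAfterAccess(10, TimeUnit.MINUTES)
            .<String, Connection>removalListener(eviction -> {
              Connection connection = eviction.getValue();
              if(null != connection) {
                try {
                  connection.close();
                } catch (IOException e) {
                  LOGGER.error("Connection could not be closed for user=" + eviction.getKey(), e);
                }
              }
            })
            .build(new CacheLoader<String, Connection>() {
              @Override
              public Connection load(String userName) throws Exception {
                LOGGER.info("Create connection for user={}", userName);
                return createConnection(userName);
              }
            });
  }
}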

One drawback of this approach is that users experience a slow response the first time they query the server, and again whenever their connection has been evicted from the cache. They will also see this lag if you shard your application behind a load balancer without sticky sessions. With a round-robin strategy, the cost of creating a connection is incurred whenever a user is routed to an instance that does not yet hold a connection for them.
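If Guava is not available, or you would rather cap the number of open connections than expire idle ones, the same pattern can be sketched with the JDK's LinkedHashMap in access order. The pattern is to create a resource lazily per user and close it on eviction. This is a swapped-in technique rather than the pool above, and LruClosingCache is a hypothetical name. It evicts the least recently used entry once a count is exceeded, instead of expiring entries after ten minutes of inactivity. A single lock also serialises lookups, so one slow connection creation blocks other users.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical Guava-free per-user cache that closes resources when they are evicted (LRU by count). */
final class LruClosingCache<V extends AutoCloseable> implements AutoCloseable {
  private final Map<String, V> entries;
  private final Function<String, V> loader;

  LruClosingCache(int maxEntries, Function<String, V> loader) {
    this.loader = loader;
    // accessOrder = true: iteration runs from least to most recently used
    this.entries = new LinkedHashMap<String, V>(16, 0.75f, true) {
      @Override
      protected boolean removeEldestEntry(Map.Entry<String, V> eldest) {
        if (size() <= maxEntries) {
          return false;
        }
        closeQuietly(eldest.getKey(), eldest.getValue()); // close before the map drops the entry
        return true;
      }
    };
  }

  /** Returns the cached resource for the user, creating it on first use. */
  synchronized V get(String user) {
    V value = entries.get(user); // get() also marks the entry as recently used
    if (value == null) {
      value = loader.apply(user);
      entries.put(user, value);  // put() may evict the least recently used entry
    }
    return value;
  }

  /** Closes every cached resource. */
  @Override
  public synchronized void close() {
    for (Map.Entry<String, V> entry : entries.entrySet()) {
      closeQuietly(entry.getKey(), entry.getValue());
    }
    entries.clear();
  }

  private static void closeQuietly(String user, AutoCloseable resource) {
    try {
      resource.close();
    } catch (Exception e) {
      System.err.println("Could not close resource for user=" + user + ": " + e);
    }
  }
}
```

For the pool above, V would be HBase's Connection (which is Closeable), and the loader would wrap createConnection, rethrowing its IOException unchecked.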