Fork/Join框架

Fork/Join框架是Java7提供的用户并行任务执行的框架。原理是将一个大任务分割成多个小任务,最终再将小任务的结果汇聚,从而得到大任务的结果。

工作窃取算法

工作窃取算法(work-stealing)是指某个线程从其他任务队列窃取任务来执行。如下图所示:

当我们处理一个大任务时,可以把它拆分成多个小任务,为了减少线程间的竞争,将其放到不同的队列中去。然后为每个队列创建一个单独的线程来处理任务,但是有的线程会先处理完其内部的任务。这个时候就会帮助其他线程处理任务,从其他线程队列的尾部窃取一个任务处理。为了减少线程间的竞争,通常会使用双端队列,本线程会从队列的头部取任务,窃取线程会从队列的尾部取任务。

窃取算法的优点是充分利用的线程的并行计算,减少了线程间的冲突产生。其缺点是并不能完全的避免线程竞争,比如双端队列中只有一个任务时。同时双端队列和多线程也消耗了更多的线程资源。

参考:http://ifeve.com/talk-concurrency-forkjoin/

Fork/Join详解

Fork/Join框架的步骤主要分为下面两步:

(1)分隔任务:首先需要一个Fork类用来将大任务分割成子任务,直到分割出的子任务足够小。(代码结构可以类比递归)

(2)执行任务并合并结果:分隔的子任务存放在双端队列中,然后几个启动线程分别从双端队列获取任务并执行。子任务执行完的结果都放在一个队列里,启动一个线程从队列里拿数据,然后合并这些数据。

JDK中提供的Fork/Join框架的类图关系如下:

从上面类图可以看出,Fork/Join框架主要使用ForkJoinTask和ForkJoinPool实现:

ForkJoinTask:这是一个Fork/Join的执行任务,它提供了在任务中执行fork()和join()的工作机制。通常情况下,我们也不会直接继承ForkJoinTask抽象类,而是继承其子类:RecursiveTask(有返回结果)和RecursiveAction(无返回结果)。

ForkJoinPool:ForkJoinTask需要提交到ForkJoinPool中执行,任务分隔出来的子任务会添加到当前的双端队列中,并且进入队列的头部。当工作线程的队列中无任务时,会从其他队列中随机窃取一个任务执行。

其原理图可参见如下:

简单使用

下面给出一个数字求和的例子:计算1+2+...+200000的结果。如下所示:

/**
 * 1到200000求和任务分解
 * 
 * @author xuefeihu
 *
 */
public class CountTask extends RecursiveTask<Long> {

	private static final long serialVersionUID = 1L;
	private static final int THRESHOLD = 10000;
	private long start;
	private long end;

	public CountTask(long start, long end) {
		this.start = start;
		this.end = end;
	}

	@Override
	protected Long compute() {
		long sum = 0;
		boolean canCompute = (end - start) < THRESHOLD;
		if (canCompute) {
			for (long i = start; i <= end; i++) {
				sum += i;
			}
		} else {
			// 分成100个小任务
			long step = (start + end) / 100;
			ArrayList<CountTask> subTasks = new ArrayList<>();
			long pos = start;
			for (int i = 0; i < 100; i++) {
				long lastOne = pos + step;
				if (lastOne > end) {
					lastOne = end;
				}
				CountTask subTask = new CountTask(pos, lastOne);
				pos += step + 1;
				subTasks.add(subTask);
				subTask.fork();
			}
			for (CountTask task : subTasks) {
				sum += task.join();
			}
		}

		return sum;
	}

	public static void main(String[] args) {
		ForkJoinPool forkJoinPool = new ForkJoinPool();
		CountTask task = new CountTask(0, 200000L);
		ForkJoinTask<Long> result = forkJoinPool.submit(task);
		try {
			long res = result.get();
			System.out.println("sum=" + res);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

}

从上面代码可以看出,对于CountTask继承了有返回值的RecursiveTask。里面实现了compute()方法,在方法里面给出了小任务的执行逻辑和大任务的分解逻辑。在大任务的分解逻辑中,存储了小任务的数组subTasks,用以join()返回结果使用。

在main函数里面,首先创建了ForkJoinPool线程池。然后将CountTask(大任务/对应一个工作线程和一个双端队列)提交到线程池中。最后调用get()方法进入同步等待结果输出。

异常处理

ForkJoinTask在执行的过程中有可能抛出异常,但是我们在主线程中无法获取子线程的异常信息。因此ForkJoinTask提供了isCompletedAbnormally(),用于判断Task是否执行异常或者被取消等。获取异常信息可使用如下代码:

if(task.isCompletedAbnormally()) {
    // 如有异常则返回CancellationException,没有则返回null
    System.out.println(task.getException());
}

代码实现

ForkJoinPool由ForkJoinTask数组和ForkJoinWorkerThread数组组成,ForkJoinTask数组负责存放程序提交给ForkJoinPool的任务,而ForkJoinWorkerThread数组负责执行这些任务。

ForkJoinTask的fork方法实现,当我们调用ForkJoinTask的fork方法时,程序会调用ForkJoinWorkerThread的push方法异步的执行这个任务,然后立即返回结果;对于非ForkJoinWorkerThread的线程会提交到ForkJoinPool.common中。代码如下:

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

push方法把当前任务存放在ForkJoinTask数组queue里。然后再调用ForkJoinPool的signalWork()方法唤醒或创建一个工作线程来执行任务。代码如下:

        final void push(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a; ForkJoinPool p;
            int b = base, s = top, n;
            if ((a = array) != null) {    // ignore if queue removed
                int m = a.length - 1;     // fenced write for task visibility
                U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
                U.putOrderedInt(this, QTOP, s + 1);
                if ((n = s - b) <= 1) {
                    if ((p = pool) != null)
                        p.signalWork(p.workQueues, this);
                }
                else if (n >= m)
                    growArray();
            }
        }

ForkJoinTask的join方法实现。join()方法可以阻塞当前线程并等待结果。下面看一下join()方法的实现,代码如下:

    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

首先调用了doJoin()方法,通过doJoin()方法得到当前任务的状态来判断返回什么结果,任务状态有四种:已完成(NORMAL),被取消(CANCELLED),信号(SIGNAL)和出现异常(EXCEPTIONAL)。

如果任务状态是已完成,则直接返回任务结果。

如果任务状态是被取消,则直接抛出CancellationException。

如果任务状态是抛出异常,则直接抛出对应的异常。

下面是doJoin()方法的实现代码:

    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

在doJoin()方法里,首先通过查看任务的状态,看任务是否已经执行完了,如果执行完了,则直接返回任务状态,如果没有执行完,则从任务数组里取出任务并执行。如果任务顺利执行完成了,则设置任务状态为NORMAL,如果出现异常,则纪录异常,并将任务状态设置为EXCEPTIONAL。