Exploring ForkJoinPool
Introduction
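“Divide and conquer” is an important way to organize our thinking and solve problems. From splitting a system architecture into functional modules down to implementing merge sort, the idea of divide and conquer is everywhere. When implementing a divide-and-conquer algorithm, we usually use recursion: a large task is split into several smaller tasks, and the large task waits for all of its subtasks to finish and then combines their results. In general, a parent task depends on the results of its subtasks, while the subtasks do not depend on one another, so the subtasks can run concurrently to improve performance. ForkJoinPool provides a framework for this kind of concurrent divide-and-conquer processing, letting us gain concurrent execution while writing code in a style similar to plain recursion.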
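As a point of reference, here is a plain sequential merge sort (the class name MergeSortDemo is only for illustration). It has the recursive shape described above: split the array, solve each half, then let the parent merge the two sub-results. The two recursive calls work on disjoint halves, which is exactly the independence that makes them candidates for concurrent execution.

```java
import java.util.Arrays;

public class MergeSortDemo {
    // Sort a[lo, hi) recursively: split, solve the halves, then merge.
    static void mergeSort(int[] a, int lo, int hi) {
        if (hi - lo <= 1) {
            return; // small enough: zero or one element is already sorted
        }
        int mid = (lo + hi) >>> 1;
        mergeSort(a, lo, mid); // the two halves are independent subproblems,
        mergeSort(a, mid, hi); // so in principle they could run concurrently
        merge(a, lo, mid, hi); // the parent combines the sub-results
    }

    // Merge the sorted runs a[lo, mid) and a[mid, hi) in place.
    static void merge(int[] a, int lo, int mid, int hi) {
        int[] left = Arrays.copyOfRange(a, lo, mid);
        int i = 0, j = mid, k = lo;
        while (i < left.length && j < hi) {
            a[k++] = (left[i] <= a[j]) ? left[i++] : a[j++];
        }
        while (i < left.length) {
            a[k++] = left[i++];
        }
        // any remaining right-half elements are already in place
    }

    public static void main(String[] args) {
        int[] data = {5, 2, 9, 1, 7, 3};
        mergeSort(data, 0, data.length);
        System.out.println(Arrays.toString(data)); // [1, 2, 3, 5, 7, 9]
    }
}
```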
Usage
Divide-and-conquer code typically takes the following form:
Result solve(Problem problem) {
    if (problem is small) {
        directly solve problem
    } else {
        split problem into independent parts
        fork new subtasks to solve each part
        join all subtasks
        compose result from subresults
    }
}
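One concrete instance of this pattern, sketched with the JDK's java.util.concurrent.RecursiveTask and ForkJoinPool.commonPool(), sums a range of a long[]. The class name SumTask and the THRESHOLD value of 1,000 are assumptions for illustration, not tuned values; each step of the pseudocode is marked in the comments.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class SumTask extends RecursiveTask<Long> {
    // Hypothetical cutoff: below this many elements, summing directly is
    // assumed to be cheaper than forking. Illustrative, not tuned.
    static final int THRESHOLD = 1_000;

    final long[] array;
    final int lo, hi; // half-open range [lo, hi)

    SumTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected Long compute() {
        if (hi - lo <= THRESHOLD) { // problem is small: solve it directly
            long sum = 0;
            for (int i = lo; i < hi; i++) {
                sum += array[i];
            }
            return sum;
        }
        int mid = (lo + hi) >>> 1;               // split into independent parts
        SumTask left = new SumTask(array, lo, mid);
        SumTask right = new SumTask(array, mid, hi);
        left.fork();                             // fork a subtask for one part
        long rightSum = right.compute();         // solve the other part in this thread
        return left.join() + rightSum;           // join and compose the result
    }

    public static void main(String[] args) {
        long[] data = new long[100_000];
        for (int i = 0; i < data.length; i++) {
            data[i] = i + 1;
        }
        long total = ForkJoinPool.commonPool().invoke(new SumTask(data, 0, data.length));
        System.out.println(total); // 5000050000
    }
}
```

Forking only one half and computing the other in the current thread avoids creating a task that the current thread would immediately wait on.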
Computing Fibonacci numbers:
import java.util.concurrent.RecursiveTask;

class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    @Override
    protected Integer compute() {
        if (n <= 1) return n;
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork(); // run f(n-1) asynchronously
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join();
    }
}
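The Fibonacci task above forks all the way down to n ≤ 1, so most of its time goes into creating tiny tasks; it shows the API rather than being fast. A common refinement, sketched here under an assumed cutoff of 10 (the names FibonacciWithCutoff and SEQUENTIAL_CUTOFF are hypothetical), falls back to plain recursion for small n and runs the root task with ForkJoinPool.invoke:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class FibonacciWithCutoff extends RecursiveTask<Integer> {
    // Hypothetical cutoff: at or below it we compute sequentially instead of forking.
    static final int SEQUENTIAL_CUTOFF = 10;

    final int n;

    FibonacciWithCutoff(int n) {
        this.n = n;
    }

    static int sequentialFib(int n) {
        return n <= 1 ? n : sequentialFib(n - 1) + sequentialFib(n - 2);
    }

    @Override
    protected Integer compute() {
        if (n <= SEQUENTIAL_CUTOFF) {
            return sequentialFib(n); // too small to be worth a fork
        }
        FibonacciWithCutoff f1 = new FibonacciWithCutoff(n - 1);
        f1.fork();                                  // push f1 onto this worker's deque
        FibonacciWithCutoff f2 = new FibonacciWithCutoff(n - 2);
        return f2.compute() + f1.join();            // compute f2 here, then join f1
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool(); // parallelism defaults to available processors
        System.out.println(pool.invoke(new FibonacciWithCutoff(30))); // 832040
        pool.shutdown();
    }
}
```

The order of calls matters: computing f2 before calling f1.join() keeps the current thread busy while f1 sits in the deque, where an idle worker may steal it.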
How it works
The core of ForkJoinPool is its lightweight scheduling mechanism, which adopts the basic work-stealing scheduling strategy of Cilk:
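- Each worker thread maintains its own task queue.
- The task queue is a double-ended queue (deque) that supports not only last-in-first-out push and pop operations but also a first-in-first-out take operation.
- Subtasks forked by a parent task are pushed onto the task queue of the worker thread that is running the parent task.
- A worker thread processes its own task queue in LIFO order, popping tasks from it (the youngest task is handled first).
- When its own task queue is empty, a worker thread tries to steal a task from another, randomly chosen task queue.
- When a worker thread has no task to run and fails to steal one, it backs off (yield, sleep, priority adjustment) and tries again after a while. The exception is when all other threads are idle as well: in that case they all block until a new task is submitted from the top level.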
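The deque discipline in the list can be illustrated single-threaded with java.util.concurrent.ConcurrentLinkedDeque, using the same calls as the TaskQueue in the implementation that follows: push adds at the head, pollFirst is the owner's LIFO pop, and pollLast is a thief's FIFO take. DequeOrderDemo is a hypothetical name, and this only shows the ordering, not real concurrency.

```java
import java.util.Deque;
import java.util.concurrent.ConcurrentLinkedDeque;

public class DequeOrderDemo {
    // Returns {what the owner pops, what a thief takes, what remains}.
    static String[] demo() {
        Deque<String> deque = new ConcurrentLinkedDeque<>();

        // The owning worker forks three subtasks; push() adds at the head.
        deque.push("task-1");
        deque.push("task-2");
        deque.push("task-3");

        // Owner side: pop from the head (LIFO), so the youngest task comes first.
        String ownerGets = deque.pollFirst();
        // Thief side: take from the tail (FIFO), so the oldest task comes first.
        String thiefGets = deque.pollLast();

        return new String[] {ownerGets, thiefGets, deque.toString()};
    }

    public static void main(String[] args) {
        String[] r = demo();
        System.out.println("owner pops:  " + r[0]); // task-3
        System.out.println("thief takes: " + r[1]); // task-1
        System.out.println("remaining:   " + r[2]); // [task-2]
    }
}
```

Because thieves take from the opposite end, they tend to grab the oldest tasks, which in a recursive decomposition are usually the largest ones, and the owner and the thieves rarely contend for the same element.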
A simple implementation:
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class NaiveForkJoinPool {
    private final TaskQueue[] submissionQueues; // tasks submitted from outside the pool
    private final TaskQueue[] workerQueues;     // one deque per worker thread
    private final WorkerThread[] workers;
    private final AtomicInteger aliveCount;     // workers that are not parked
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition taskEmpty = lock.newCondition();
    private final int parallelism;

    public NaiveForkJoinPool(int parallelism) {
        this.parallelism = parallelism;
        submissionQueues = new TaskQueue[parallelism];
        workerQueues = new TaskQueue[parallelism];
        workers = new WorkerThread[parallelism];
        aliveCount = new AtomicInteger(parallelism);
        for (int i = 0; i < parallelism; i++) {
            submissionQueues[i] = new TaskQueue();
            workerQueues[i] = new TaskQueue();
            workers[i] = new WorkerThread(this, workerQueues[i]);
            workers[i].setDaemon(true); // parked workers must not keep the JVM alive
        }
        for (int i = 0; i < parallelism; i++) {
            workers[i].start();
        }
    }

    public <T> T invoke(Task<T> task) {
        submit(task);
        return task.join();
    }

    @SafeVarargs
    public final <T> List<T> invokeAll(Task<T>... tasks) {
        // Submit every task first so they can run in parallel, then join them.
        for (Task<T> task : tasks) {
            submit(task);
        }
        List<T> res = new ArrayList<>(tasks.length);
        for (Task<T> task : tasks) {
            res.add(task.join());
        }
        return res;
    }

    private void submit(Task<?> task) {
        // nextInt(bound) works for any parallelism, not only powers of two
        int i = ThreadLocalRandom.current().nextInt(submissionQueues.length);
        submissionQueues[i].push(task);
        tryCompensate();
    }

    // Wake one parked worker, if there is any.
    void tryCompensate() {
        if (aliveCount.get() < parallelism) {
            lock.lock();
            try {
                if (aliveCount.get() < parallelism) {
                    taskEmpty.signal();
                }
            } finally {
                lock.unlock();
            }
        }
    }

    void runWorker() {
        int startIndex = ThreadLocalRandom.current().nextInt(submissionQueues.length);
        for (Task<?> task = null; ; ) {
            if (task != null || (task = scan(startIndex)) != null) {
                task.runTask();
                task = null;
            } else {
                task = awaitForWork(startIndex);
            }
        }
    }

    // Look for work: external submissions first, then steal from worker deques.
    Task<?> scan(int startIndex) {
        Task<?> task;
        if ((task = scan(startIndex, submissionQueues)) != null) {
            return task;
        }
        return scan(startIndex, workerQueues);
    }

    Task<?> scan(int startIndex, TaskQueue[] queues) {
        int len = queues.length;
        for (int i = startIndex; i < startIndex + len; i++) {
            Task<?> task = queues[i % len].take(); // FIFO end: the oldest task
            if (task != null) {
                return task;
            }
        }
        return null;
    }

    Task<?> awaitForWork(int startIndex) {
        lock.lock();
        try {
            // Decrement before the final scan: a concurrent submitter then either
            // sees the lowered count and signals, or we see its task here.
            // Scanning first could lose a wakeup and leave a task stranded.
            aliveCount.decrementAndGet();
            Task<?> task = scan(startIndex);
            if (task == null) {
                try {
                    taskEmpty.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            aliveCount.incrementAndGet();
            return task;
        } finally {
            lock.unlock();
        }
    }

    class WorkerThread extends Thread {
        final NaiveForkJoinPool pool;
        final TaskQueue workQueue;

        public WorkerThread(NaiveForkJoinPool pool, TaskQueue workQueue) {
            this.pool = pool;
            this.workQueue = workQueue;
        }

        @Override
        public void run() {
            pool.runWorker();
        }
    }

    static abstract class Task<T> {
        static final int NORMAL = 1;
        final AtomicInteger status = new AtomicInteger();
        final CountDownLatch isDone = new CountDownLatch(1);
        private T result;

        public abstract T compute();

        public void runTask() {
            result = compute();
            status.set(NORMAL);
            isDone.countDown(); // publishes result to threads waiting in join()
        }

        // Must be called from a worker thread, i.e. from inside compute().
        public Task<T> fork() {
            WorkerThread t = (WorkerThread) Thread.currentThread();
            t.workQueue.push(this);
            t.pool.tryCompensate();
            return this;
        }

        public T join() {
            if (Thread.currentThread() instanceof WorkerThread) {
                WorkerThread t = (WorkerThread) Thread.currentThread();
                // Run our own tasks in LIFO order until we reach this one; if it is
                // never found, it was stolen and we wait for the thief to finish it.
                for (Task<?> task = t.workQueue.pop(); task != null; task = t.workQueue.pop()) {
                    task.runTask();
                    if (task == this) {
                        return result;
                    }
                }
            }
            waitForComplete();
            return result;
        }

        void waitForComplete() {
            try {
                isDone.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    static class TaskQueue {
        private final Deque<Task<?>> deque = new ConcurrentLinkedDeque<>();

        public void push(Task<?> task) { deque.push(task); } // head
        public Task<?> pop() { return deque.pollFirst(); }   // head: LIFO, owner only
        public Task<?> take() { return deque.pollLast(); }   // tail: FIFO, for stealing
    }
}
Original article. When reposting, please credit: reposted from 并发编程网 (ifeve.com). Article link: Exploring ForkJoinPool