前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >死磕 java线程系列之ForkJoinPool深入解析

死磕 java线程系列之ForkJoinPool深入解析

作者头像
彤哥
发布2019-11-14 20:15:58
6290
发布2019-11-14 20:15:58
举报
文章被收录于专栏:彤哥读源码彤哥读源码

简介

随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。

今天,我们就来看一道面试题:

如何充分利用多核CPU,计算很大数组中所有整数的和?

剖析

  • 单线程相加? 我们最容易想到就是单线程相加,一个for循环搞定。
  • 线程池相加? 如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。
  • 其它? Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^

三种实现

OK,剖析完了,我们直接来看三种实现,不墨迹,直接上菜。

代码语言:javascript
复制
    /**

     * 计算1亿个整数的和

     */

    publicclassForkJoinPoolTest01{

    publicstaticvoid main(String[] args) throwsExecutionException, InterruptedException{

    // 构造数据

    int length = 100000000;

    long[] arr = newlong[length];

    for(int i = 0; i < length; i++) {

                arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);

    }

    // 单线程

            singleThreadSum(arr);

    // ThreadPoolExecutor线程池

            multiThreadSum(arr);

    // ForkJoinPool线程池

            forkJoinSum(arr);


    }


    privatestaticvoid singleThreadSum(long[] arr) {

    long start = System.currentTimeMillis();


    long sum = 0;

    for(int i = 0; i < arr.length; i++) {

    // 模拟耗时

                sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);

    }


    System.out.println("sum: "+ sum);

    System.out.println("single thread elapse: "+ (System.currentTimeMillis() - start));


    }


    privatestaticvoid multiThreadSum(long[] arr) throwsExecutionException, InterruptedException{

    long start = System.currentTimeMillis();


    int count = 8;

    ExecutorService threadPool = Executors.newFixedThreadPool(count);

    List<Future<Long>> list = newArrayList<>();

    for(int i = 0; i < count; i++) {

    int num = i;

    // 分段提交任务

    Future<Long> future = threadPool.submit(() -> {

    long sum = 0;

    for(int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {

    try{

    // 模拟耗时

                            sum += (arr[j]/3*3/3*3/3*3/3*3/3*3);

    } catch(Exception e) {

                            e.printStackTrace();

    }

    }

    return sum;

    });

                list.add(future);

    }


    // 每个段结果相加

    long sum = 0;

    for(Future<Long> future : list) {

                sum += future.get();

    }


    System.out.println("sum: "+ sum);

    System.out.println("multi thread elapse: "+ (System.currentTimeMillis() - start));

    }


    privatestaticvoid forkJoinSum(long[] arr) throwsExecutionException, InterruptedException{

    long start = System.currentTimeMillis();


    ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();

    // 提交任务

    ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(newSumTask(arr, 0, arr.length));

    // 获取结果

    Long sum = forkJoinTask.get();


            forkJoinPool.shutdown();


    System.out.println("sum: "+ sum);

    System.out.println("fork join elapse: "+ (System.currentTimeMillis() - start));

    }


    privatestaticclassSumTaskextendsRecursiveTask<Long> {

    privatelong[] arr;

    privateint from;

    privateint to;


    publicSumTask(long[] arr, int from, int to) {

    this.arr = arr;

    this.from = from;

    this.to = to;

    }


    @Override

    protectedLong compute() {

    // 小于1000的时候直接相加,可灵活调整

    if(to - from <= 1000) {

    long sum = 0;

    for(int i = from; i < to; i++) {

    // 模拟耗时

                        sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);

    }

    return sum;

    }


    // 分成两段任务

    int middle = (from + to) / 2;

    SumTask left = newSumTask(arr, from, middle);

    SumTask right = newSumTask(arr, middle, to);


    // 提交左边的任务

                left.fork();

    // 右边的任务直接利用当前线程计算,节约开销

    Long rightResult = right.compute();

    // 等待左边计算完毕

    Long leftResult = left.join();

    // 返回结果

    return leftResult + rightResult;

    }

    }

    }

彤哥偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。

所以,为了演示ForkJoinPool的牛逼之处,我把每个数都 /3*3/3*3/3*3/3*3/3*3了一顿操作,用来模拟计算耗时。

来看结果:

代码语言:javascript
复制
    sum: 107352457433800662

    single thread elapse: 789

    sum: 107352457433800662

    multi thread elapse: 228

    sum: 107352457433800662

    fork join elapse: 189

可以看到,ForkJoinPool相对普通线程池还是有很大提升的。

问题:普通线程池能否实现ForkJoinPool这种计算方式呢,即大任务拆中任务,中任务拆小任务,最后再汇总?

你可以试试看(-?_-?)

OK,下面我们正式进入ForkJoinPool的解析。

分治法

  • 基本思想 把一个规模大的问题划分为规模较小的子问题,然后分而治之,最后合并子问题的解得到原问题的解。
  • 步骤 (1)分割原问题: (2)求解子问题: (3)合并子问题的解为原问题的解。

在分治法中,子问题一般是相互独立的,因此,经常通过递归调用算法来求解子问题。

  • 典型应用场景 (1)二分搜索 (2)大整数乘法 (3)Strassen矩阵乘法 (4)棋盘覆盖 (5)归并排序 (6)快速排序 (7)线性时间选择 (8)汉诺塔

ForkJoinPool继承体系

ForkJoinPool是 java 7 中新增的线程池类,它的继承体系如下:

ForkJoinPool和ThreadPoolExecutor都是继承自AbstractExecutorService抽象类,所以它和ThreadPoolExecutor的使用几乎没有多少区别,除了任务变成了ForkJoinTask以外。

这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。

可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码造成干扰,还能利用原有的特性。

ForkJoinTask

两个主要方法

  • fork() fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。
  • join() join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

三个子类

  • RecursiveAction 无返回值任务。
  • RecursiveTask 有返回值任务。
  • CountedCompleter 无返回值任务,完成任务后可以触发回调。

ForkJoinPool内部原理

ForkJoinPool内部使用的是“工作窃取”算法实现的。

(1)每个工作线程都有自己的工作队列WorkQueue;

(2)这是一个双端队列,它是线程私有的;

(3)ForkJoinTask中fork的子任务,将放入运行该任务的工作线程的队头,工作线程将以LIFO的顺序来处理工作队列中的任务;

(4)为了最大化地利用CPU,空闲的线程将从其它线程的队列中“窃取”任务来执行;

(5)从工作队列的尾部窃取任务,以减少竞争;

(6)双端队列的操作:push()/pop()仅在其所有者工作线程中调用,poll()是由其它线程窃取任务时调用的;

(7)当只剩下最后一个任务时,还是会存在竞争,是通过CAS来实现的;

ForkJoinPool最佳实践

(1)最适合的是计算密集型任务;

(2)在需要阻塞工作线程时,可以使用ManagedBlocker;

(3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();

总结

(1)ForkJoinPool特别适合于“分而治之”算法的实现;

(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;

(4)ForkjoinPool内部基于“工作窃取”算法实现;

(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

(6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;

(7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

彩蛋

ManagedBlocker怎么使用?

答:ManagedBlocker相当于明确告诉ForkJoinPool框架要阻塞了,ForkJoinPool就会启另一个线程来运行任务,以最大化地利用CPU。

请看下面的例子,自己琢磨哈^^。

代码语言:javascript
复制
    /**

     * 斐波那契数列

     * 一个数是它前面两个数之和

     * 1,1,2,3,5,8,13,21

     */

    publicclassFibonacci{


    publicstaticvoid main(String[] args) {

    long time = System.currentTimeMillis();

    Fibonacci fib = newFibonacci();

    int result = fib.f(1_000).bitCount();

            time = System.currentTimeMillis() - time;

    System.out.println("result = "+ result);

    System.out.println("test1_000() time = "+ time);

    }


    publicBigInteger f(int n) {

    Map<Integer, BigInteger> cache = newConcurrentHashMap<>();

            cache.put(0, BigInteger.ZERO);

            cache.put(1, BigInteger.ONE);

    return f(n, cache);

    }


    privatefinalBigInteger RESERVED = BigInteger.valueOf(-1000);


    publicBigInteger f(int n, Map<Integer, BigInteger> cache) {

    BigInteger result = cache.putIfAbsent(n, RESERVED);

    if(result == null) {


    int half = (n + 1) / 2;


    RecursiveTask<BigInteger> f0_task = newRecursiveTask<BigInteger>() {

    @Override

    protectedBigInteger compute() {

    return f(half - 1, cache);

    }

    };

                f0_task.fork();


    BigInteger f1 = f(half, cache);

    BigInteger f0 = f0_task.join();


    long time = n > 10_000 ? System.currentTimeMillis() : 0;

    try{


    if(n % 2== 1) {

                        result = f0.multiply(f0).add(f1.multiply(f1));

    } else{

                        result = f0.shiftLeft(1).add(f1).multiply(f1);

    }

    synchronized(RESERVED) {

                        cache.put(n, result);

                        RESERVED.notifyAll();

    }

    } finally{

                    time = n > 10_000 ? System.currentTimeMillis() - time : 0;

    if(time > 50)

    System.out.printf("f(%d) took %d%n", n, time);

    }

    } elseif(result == RESERVED) {

    try{

    ReservedFibonacciBlocker blocker = newReservedFibonacciBlocker(n, cache);

    ForkJoinPool.managedBlock(blocker);

                    result = blocker.result;

    } catch(InterruptedException e) {

    thrownewCancellationException("interrupted");

    }


    }

    return result;

    // return f(n - 1).add(f(n - 2));

    }


    privateclassReservedFibonacciBlockerimplementsForkJoinPool.ManagedBlocker{

    privateBigInteger result;

    privatefinalint n;

    privatefinalMap<Integer, BigInteger> cache;


    publicReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) {

    this.n = n;

    this.cache = cache;

    }


    @Override

    publicboolean block() throwsInterruptedException{

    synchronized(RESERVED) {

    while(!isReleasable()) {

                        RESERVED.wait();

    }

    }

    returntrue;

    }


    @Override

    publicboolean isReleasable() {

    return(result = cache.get(n)) != RESERVED;

    }

    }

    }
本文参与?腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-11-09,如有侵权请联系?cloudcommunity@tencent.com 删除

本文分享自 彤哥读源码 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 简介
  • 剖析
  • 三种实现
  • 分治法
  • ForkJoinPool继承体系
  • ForkJoinTask
    • 两个主要方法
      • 三个子类
      • ForkJoinPool内部原理
      • ForkJoinPool最佳实践
      • 总结
      • 彩蛋
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
      http://www.vxiaotou.com