重学Java——ThreadLocal源码解读

ThreadLocal

今天端午,就看一下轻松点的东西吧,上次说消息机制,说到Looper时,就是把Looper存储在ThreadLocal中,然后在对应的线程获取到对象,今天就来看下ThreadLocal的源码解读吧。

ThreadLocal的简单使用

还是上次讲的那个例子

    final ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
    threadLocal.set(1);
    new Thread(new Runnable() {
        @Override
        public void run() {
            Log.d(TAG, "run: "+threadLocal.get());     
        }
    }).start();
复制代码

得到的结果是***MainActivity: run: null,可见在哪个线程放在数据,只有在对应的那个线程取出。

ThreadLocal源码分析

每一个线程Thread的源码内部有属性

/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
复制代码

再看ThreadLocal的源码,主要关心的就是存储问题,也就是set和get方法,先来看下set

    public void set(T value) {
        //获取当前线程
        Thread t = Thread.currentThread();
        //得到ThreadLocalMap,这是个专门用于存储线程的ThreadLocal的数据
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
复制代码

看第二行ThreadLocalMap map = getMap(t);,这是个专门用于存储线程的ThreadLocal的数据,set的步骤是:

  1. 获取当前线程的成员变量map
  2. map非空,则重新将ThreadLocal和新的value副本放入到map中。
  3. map空,则对线程的成员变量ThreadLocalMap进行初始化创建,并将ThreadLocal和value副本放入map中。
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {
        //每一个线程中都一个对应的threadLocal,然后又通过ThreadLocal负责来维护对应的ThreadLocalMap
        //通过ThreadLocal来获取来设置线程的变量值
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
复制代码

先暂时不去看ThreadLocalMap的源码,只要知道它是用于存储就行,我们先看下get方法

    public T get() {
        Thread t = Thread.currentThread();
        //还是能过当前线程get到这个threadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                //从map里取到值就直接返回
                return result;
            }
        }
        //没有取到值就返回默认初始值
        return setInitialValue();
    }

    ThreadLocalMap getMap(Thread t) {
        //当前线程的ThreadLocal
        return t.threadLocals;
    }

	
    private T setInitialValue() {
        //这个等下看,看字面意思就是初始化value
        T value = initialValue();
        Thread t = Thread.currentThread();
        //下面的就是和set方法就是一样的了
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
复制代码

再看一下initialValue方法

    /**
     * Returns the current thread's "initial value" for this
     * thread-local variable.  This method will be invoked the first
     * time a thread accesses the variable with the {@link #get}
     * method, unless the thread previously invoked the {@link #set}
     * method, in which case the {@code initialValue} method will not
     * be invoked for the thread.  Normally, this method is invoked at
     * most once per thread, but it may be invoked again in case of
     * subsequent invocations of {@link #remove} followed by {@link #get}.
     *
     * <p>This implementation simply returns {@code null}; if the
     * programmer desires thread-local variables to have an initial
     * value other than {@code null}, {@code ThreadLocal} must be
     * subclassed, and this method overridden.  Typically, an
     * anonymous inner class will be used.
     *
     * @return the initial value for this thread-local
     */
    protected T initialValue() {
        return null;
    }
复制代码

就返回了一个null,为什么不直接用null呢,这也是复制这一大段注释的原因,此实现只返回{@code null};如果程序员希望线程局部变量的初始值不是{@code null},则必须对{@code ThreadLocal}进行子类化,并且此方法将被重写。通常,将使用匿名内部类。

再回到get方法,可以得出get的步骤为:

  1. 获取当前线程的ThreadLocalMap对象threadLocal
  2. 从map中获取线程存储的K-V Entry节点。
  3. 从Entry节点获取存储的Value副本值返回。
  4. map为空的话返回初始值null,即线程变量副本为null,在使用时需要注意判断NullPointerException。

ThreadLocalMap

现在我们集中精力来看ThreadLocalMap的源码

static class ThreadLocalMap {
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
    
        private static final int INITIAL_CAPACITY = 16;
    
        private Entry[] table;

        private int size = 0;

        private int threshold;
}
复制代码

上面是ThreadLocalMap的一些属性,结构看起来和HashMap结构差不多,可以看到ThreadLocalMap的Entry继承自WeakReference,并使用ThreadLocal为键值。

这里为什么不使用普通的key-value形式来定义存储结构,实质上就会造成节点的生命周期与线程绑定,只要线程没有销毁,那么节点在GC是一直是处于可达状态,是没办法回收的,而程序本身并没有方法判断是否可以清理节点。弱引用的性质就是GC到达时,那么这个对象就会被回收。当某个ThreadLocal已经没有强引用可达,则随着它被GC回收,在ThreadLocalMap里对应的Entry就会失效,这也为Map本身垃圾清理提供了便利。

        /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         * 设置resize阈值以维持最坏2/3的负载因子
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * Increment i modulo len.
         * 下一个索引
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * Decrement i modulo len.
         * 上一个索引
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }
复制代码

熟悉HashMap的话,其实对负载因子应该很熟悉,ThreadLocal有两个方法用于得到上/下一个索引,用于解决Hash冲突的方式就是简单的步长加1或减1,寻找下一个相邻的位置。

所以很明显,ThreadLocalMap这种线性探测方式来解决Hash冲突效率很低,建议:每个线程只存一个变量,这样的话所有的线程存放到map中的key都是相同的ThreadLocal,如果一个线程要保存多个变量,就需要创建多个ThreadLocal,多个ThreadLocal放入Map中时会极大的增加Hash冲突的可能。

再来看它的set方法

        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
			
            //线性探测
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
		
                //找到对应的entry
                if (k == key) {
                    e.value = value;
                    return;
                }

                //替换失效的entry
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            //如果没有找到对应的key,就在末尾放上new Entry
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                //再次hash
                rehash();
        }


        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            //向前探测
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 //固定步长
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            //向后遍历
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                //找到key,更新为新的value
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    //如果在整个扫描过程中,找到了之前的无效值,那么以它为清理起点,否则以当前的i为清理起点
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            //如果key在table中不存在,在原地放一个new entry
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            //在探测过程中发现无效的位置,则做一次清理(连续段清理+启发式清理)
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

       private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //清理一个连续段
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

        private void rehash() {
            //清理陈旧数据
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            // 清理完后,如果>=3/4阈值,就进行扩容
            if (size >= threshold - threshold / 4)
                resize();
        }
复制代码
  • 在set方法中,先循环查找,如果key的值找到了,直接替换覆盖就可
  • 如果k失效,那么直接调用replaceStaleEntry,效果就是把这个key和value都放在这个位置,同时病理历史key=null的陈旧数据
  • 如果没有找到key,那么在末尾后的一个空位置放上entry,放完后做一次清理,如果清理出去key,并且当前是table大小已经超过阈值,则做一次rehash。
    • 清理一遍陈旧数据
    • >= 3/4阀值,就执行扩容,把table扩容2倍==》注意这里3/4阀值就执行扩容,避免迟滞
    • 把老数据重新哈希散列进新table

可以看到,和HashMap最大的不同在于,ThreadLocalMap的结构非常简单,没有next引用,就是说ThreadLocalMap解决Hash冲突并不是链表的方式,而是线性探测——当key的hashcode值在table数组的位置,如果发现这个位置上已经有其他的key值元素占用了,那么利用固定的算法寻找下一定步长的下一个位置,依次判断,直到找到能够存放的位置

ThreadLocal的问题

对于ThreadLocal的内存泄漏,由于ThreadLocalMap的key是弱引用,而value是强引用,这就导致当ThreadLocal在没有外部对象强引用时,GC会回收key,但value不会回收,如果ThreadLocal的线程一直运行着,那么这个Entry对象的value就可能一直不能回收,引发内存泄漏。

ThreadLocal的设计已经为我们考虑到了这个问题,提供remove方法,将Entry节点和Map的引用关系移除,这样Entry在GC分析时就变成了不可达,下次GC就能回收。

看一下remove的源码

     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }
复制代码

可以看到就是调用了ThreadLocalMap的remove方法

		/**
		* 从map中删除threadLocal
		*/
		private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    //调用entry的clear方法
                    e.clear();
                    //进行清理
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
复制代码

所以我们在使用后,可以显示的调用remove方法,来避免内存泄漏,是一个很好的编程习惯。

参考

ThreadLocal终极源码剖析-一篇足矣!

ThreadLocal-面试必问深度解析

Java并发编程:深入剖析ThreadLocal


我的CSDN

下面是我的公众号,欢迎大家关注我

猜你喜欢

转载自juejin.im/post/5cfa7f156fb9a07ed440f2be