jdk源码之ThreadLocal

ThreadLocal算是为多线程解决并发问题提供了一种新的思路,为了更好地使用它,读其优秀的实现。

其构造方法是空的,那么直接看其set()方法

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
先得到当前线程,再以当前线程为key,调用getMap得到当前线程对应的ThreadLocalMap,如果ThreadLocalMap不为空,那么set(),以当前ThreadLocal实例为key与value构成键值对。否则createMap()
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
从getMap可以注意到在Thread的类有一个ThreadLocalMap类型的threadLocals成员变量。一开始是自然为null。
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
可以看到createMap直接给当前线程的threadLocals成员变量赋值。
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

ThreadLocalMap是ThreadLocal的内部类,可以从构造看到,给ThreadLocalMap的table成员赋值为Entry数组,且长度默认为16,然后firstKey为传入的ThreadLocal实例,取其threadLocalHashCode跟16取余(通过与其长度-1与实现),得到对应key存储在Entry数组中的下标i,然后赋值,并设为size为1,然后将threshold设为Entry数组长度的2/3倍。

实际上ThreadLocalMap类似于一个hashMap。每个键值对是一Entry形式存储。Entry是ThreadLocalMap的内部类

        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

其元素的key是对于的ThreadLocal实例,且其实弱引用的形式。我们可以大致猜到为什么要用弱引用。

java引用分为强软弱虚,弱引用一旦可达性分析不到,那么在下次gc时,会被回收。这样我们就好猜到,这里使用弱引用的原因,为了避免把节点的生命周期跟线程绑定在一起,即只有线程销毁后,该节点才引用不可达。设为弱引用,那么只要某个ThreadLocal失去可达性,那么随着它的回收,某个对应的entry将会失效,这为ThreadLocalMap本身的垃圾清理提供了便利。

那么进入ThreadLocal的set()另一个分支,ThreadLocalMap存在那么调用ThreadLocalMap的set()函数,传入当前ThreadLocal实例,跟value。
        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
通过key的ThreadLocalHashCode跟len的取余,得到对应的下标i,如果下标i对应的位置上有Entry存在,那么判断其位置上原来的k是否与key相同,如果相同则直接更新value,返回。如果k为null,那么说明此entry无效,那么通过replaceStaleEntry()方法,把key和value放在这个位置,且尽可能清理无效的entry。我们先来看下nextIndex()方法
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }
可以看到其解决hash冲突采用的是线性探测法。我们再来注意下threadLocalHashCode成员
    private final int threadLocalHashCode = nextHashCode();
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

这儿有个魔数HASH_INCREMENT = 0x61c88647,关于这个魔数的选取网上说法:与斐波那契散列有关,0x61c88647对应的十进制为1640531527。斐波那契散列的乘数可以用(long) ((1L << 31) * (Math.sqrt(5) - 1))可以得到2654435769,如果把这个值给转为带符号的int,则会得到-1640531527。

(我倒是发现个有趣的巧合,0x61c88647如果单位是秒的话,那么化成年约为51年,如果按照第一台计算机诞生1946年算,这是数字恰好是ThreadLocal上写的年份,1997。程序员的调皮~)

反正不管怎么样,通过理论与实践,当我们用0x61c88647作为魔数累加为每个ThreadLocal分配各自的ID也就是threadLocalHashCode再与2的幂取模,得到的结果分布很均匀。

我们再来看下replaceStaleEntry()方法

        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
先从失效的entry向前找,一但遇上entry为null的则停下,找到最前一个无效的slot的下标为slotToExpunge。然后开始从传入的staleSlot之后开始向后遍历。如果找到key,则将其与之前无效的entry交换,如果向前找无效的slot的过程中,没找到无效的slot即满足slotToExpunge == staleSlot,那么把slotToExpunge赋值为当前下标i,如果找到了那么则不更改slotToExpunge,然后以slotToExpunge为参数调用expungeStaleEntry(),如果没找到key,当前位置为空索引并且向前找无效的slot的过程中,没找到无效的slot那么把slotToExpunge赋值为当前下标i,然后继续往后找,直到找到。
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

这个函数是核心清理函数,现将stateSlot位置的entry设为null,因此这个位置的entry已经无效,然后从stateSlot后位置开始向后遍历,如果是无效entry(即弱引用的ThreadLocal已经被回收)清理,即对应entry中的value置为null,将指向这个table[i]的entry置为null,如果不是,即k(ThreadLocal)不为null,那么将重新hash散列,如果重新散列的位置不为当前位置,那么将当前位置设为null,且从重新散列的位置开始找找到对应的table[i]为null的slot,插入。返回的i是stateSlot后的第一个空的slot的下标。

然后调用了cleanSomeSlots(),传入expungeStaleEntry的返回值跟len

        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                //当前i对应的entry不可能为无效,要么指向的ThreadLocal没被回收,或者entry本身为空,
                //所以从下一个位置开始判断
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }
这个算是启发式的清理,正常情况下,判断logn次没有无效的entry则函数结束,如果有一次发现了无效的entry,则将n设为长度len,并且调用expungeStaleEntry()进行连续段的清理。

回到ThreadLocalMap的set方法中,直到找到空的slot,然后构造新的Entry键值对,插入,然后size自增,然后如果没有清理出新的位置,并且size大小大于负载容限threshold,于是调用rehash(),内部先清理,再判断是否需要扩容。看下rehash()函数。

        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }
先调用expungeStaleEntries()进行全量清理
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
可以看到,遍历所有位置,清理失效的。在全量清理后,size可能会明显变小,于是降低了此时的扩容门槛,从原来的2/3threshold变成了1/2。2/3threshold-2/3threshold/4=1/2threshold。
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }

这个流程很清晰,扩容为原来的两倍。

set方法到这里就结束了了我们来简单看下get()方法。

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
先从thread上读看看有没threadLocalMap,如果没有的话则调用setInitialValue()方法新建个map放进去,如果有则根据当前ThreadLocal查找有效Entry,再取出值返回。先看下getEntry函数,查找有效entry
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
可以看到如果当前散列位置的entry不为null且有效则返回,否则调用getEntryAfterMiss(key, i, e),继续查找
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }
如果为null则返回null,如果不为null则说明当前无效,继续往后线性探测,如果碰到有效的则返回,碰到无效的,则调用expungeStaleEntry函数,清理连续段内的无效entry。所以,最后返回要么找到有效的,要么没找到返回null。
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

在get中thread内的map为null时调用setInitialValue()初始化。

get()流程到这介绍完了。

我们再来看下remove方法

     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }
     private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

找到相应位置,清理弱引用。

看到Thread内的成员threadLocals时还注意到inheritableThreadLocals成员

    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
在Thread的init()函数中注意到
        if (parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
可以看到inheritableThreadLocals是接收来自父类的inheritableThreadLocals的。那么说明inheritableThreadLocals可以实现父子线程之间的数据共享。虽然ThreadLocal本身是线程隔离的。
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }
    private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }
还是比较简单的,做的事情就是以父线程的 inheritableThreadLocals为数据源,过滤出有效的entry,初始化到自己的 inheritableThreadLocals中。其中childValue必须被重写。
    T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }
算是在这里让子线程看到,那么仅仅在子线程创建的时候会去拷一份父线程的且调用childValue函数,也就是说如果父线程是在子线程创建后再set某个InheritableThreadLocal对象的值,对子线程是不可见的。







猜你喜欢

转载自blog.csdn.net/panxj856856/article/details/80884635