ThreadLocal以及FastThreadlocal源码解读

ThreadLocal在线程之间隔离变量十分方便,但是JDK自带的这个东东,也并不是毫无缺点的,这里先不讲他的缺点,因为大多数情况下,JDK的ThreadLocal还是够用的。
JDK对ThreadLocal的介绍
该类提供了线程局部 (thread-local) 变量。这些变量不同于它们的普通对应物,因为访问某个变量(通过其 get 或 set 方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。 每个线程都保持对其线程局部变量副本的隐式引用,只要线程是活动的并且 ThreadLocal 实例是可访问的;在线程消失之后,其线程局部实例的所有副本都会被垃圾回收(除非存在对这些副本的其他引用)

使用ThreadLocal

ThreadLocal其中供我们使用的API重要的有

	public T get();

    public void set(T value);

    public void remove();

    protected T initialValue();
  1. get() : 返回此线程局部变量的当前线程副本中的值。如果变量没有用于当前线程的值,则先将其初始化为调用 initialValue() 方法返回的值(默认返回的是null)
  2. set():将此线程局部变量的当前线程副本中的值设置为指定值。大部分子类不需要重写此方法,它们只依靠 initialValue() 方法来设置线程局部变量的值。
  3. remove() :移除此线程局部变量当前线程的值。如果此线程局部变量随后被当前线程读取, 且这期间当前线程没有设置其值,则将调用其 initialValue() 方法重新初始化其值。这将导致在当前线程多次调用 initialValue 方法。则不会对该线程再调用 initialValue 方法。通常,此方法对每个线程最多调用一次,但如果在调用 get() 后又调用了 remove() ,则可能再次调用此方法。
  4. initialValue():返回此线程局部变量的当前线程的“初始值”。线程第一次使用 get() 方法变量时将调用此方法,但如果线程之前调用了 set(T) 方法,

get()方法的源码如下:

  /**
     * Returns the value in the current thread's copy of this
     * thread-local variable.  If the variable has no value for the
     * current thread, it is first initialized to the value returned
     * by an invocation of the {@link #initialValue} method.
     *
     * @return the current thread's value of this thread-local
     */
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

代码通过getMap(t) 获取了一个ThreadLocalMap如果这个map存在的话,就从这个map中以ThreadLocal 为key迁移获取对应的值,否则返回setInitialValue();
getMap()这个方法就有意思了,

ThreadLocalMap getMap(Thread t) {
     return t.threadLocals;
 }

他直接返回了Thread中的threadLocals变量,哈哈,看来还是对Thread 这个Java并发基础类的源代码不熟悉啊。?
此 threadLocals实则是一个ThreadLocal.ThreadLocalMap ThreadLocal类中定义的静态内部类。

static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

    /**
     * ThreadLocalMap is a customized hash map suitable only for
     * maintaining thread local values. No operations are exported
     * outside of the ThreadLocal class. The class is package private to
     * allow declaration of fields in class Thread.  To help deal with
     * very large and long-lived usages, the hash table entries use
     * WeakReferences for keys. However, since reference queues are not
     * used, stale entries are guaranteed to be removed only when
     * the table starts running out of space.
     */
    static class ThreadLocalMap {

        /**
         * The entries in this hash map extend WeakReference, using
         * its main ref field as the key (which is always a
         * ThreadLocal object).  Note that null keys (i.e. entry.get()
         * == null) mean that the key is no longer referenced, so the
         * entry can be expunged from table.  Such entries are referred to
         * as "stale entries" in the code that follows.
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

这个map 持有一个 Entry 数组,Entry 继承了 WeakReference ,也就是弱引用,如果一个对象具有弱引用,在GC线程扫描内存区域的过程中,不管当前内存空间足够与否,都会回收内存。这个JVM特性不熟悉需要去看一看JVM 和JMM 相关的书

总的来说,每个线程对象中都有一个 ThreadLocalMap 属性,该属性存储 ThreadLocal 为 key ,值则是我们调用 ThreadLocal 的 set 方法设置的,也就是说,一个ThreakLocal 对象对应一个 value。
那么我们获取到了Map之后,调用getEntry()方法又做了什么呢?

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

如果hash获取正常直接返回,否则调用

      private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

去循环找到key为k的元素,如果在循环的过程中,遇到了key为null的元素,则调用expungeStaleEntry(i); 清楚key为null的元素。

		private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

清楚key为null的代码,其实就是数组的循环移位,效率不是很高。

Set()方法


	public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

先得到当前线程,然后根据当前线程得到线程的 ThreadLocalMap 属性,如果 Map 为null, 则创建一个Map ,并将值放置到Map中,否则,直接将值放置到Map中。

  void createMap(Thread t, T firstValue) {
      t.threadLocals = new ThreadLocalMap(this, firstValue);
  }

map.set(ThreadLocal<?> key, Object value) 方法如何实现的:


        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

分析一下set(0方法,首先拿到key对应的hash索引。
首先通过下标找对应的entry对象,如果没有,则创建一个新的 entry对象
如果找到了,但key冲突了或者key是null,则将下标加一(加一后如果小于数组长度则使用该值,否则使用0),
再次尝试获取对应的 entry,如果不为null,则在循环中继续判断key 是否重复或者k是否是null
新建Entry对象之后,size++是理所当然的,后面如果清理key为null的元素失败&&size大于阈值,则执行扩容。
看到那层for循环,大家也都想起了数据结构课程学习到的线性探测法。不错,ThreadLocal正是使用了现行探测法来进行扩容的,并没有使用HashMap的拉链法。

回到Set方法,如果没有Hash冲突,那么会执行一遍


       private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

遍历所有的entry,并判断他们的key,如果key是null,则调用 expungeStaleEntry 方法,也就是清除 entry。最后返回 true。

如果返回了 false ,说明没有清除,并且 size 还 大于等于 10 ,就需要 rahash,该方法如下:


        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }

又看到了那个expungeStaleEntries(),ThreadLocal的作者无时无刻不在想着 清理key为null的Entry
remove()方法


   private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

遍历清楚key为指定k的Entry,同时又执行expungeStaleEntry()。
ThreadLocal就介绍到这里吧,整体来说,代码不是太难,不过有的细节地方看了几次都不明白,最后看了大佬写的博客才明白含义。所以ThreadLocal的最佳实践,就是业务线程处理完之后,一定要记得调用remove(),.
楼主之所以,想分析一下ThreadLocal的源代码,是因为看到了Netty一直在吹捧他的FastThreadLocal是多么多么块,不过Netty利用对象池化,和内存填充等技术,在高并发场景下的确比JDK原生的ThreadLocal更胜一筹,这个留作下次分析。

原创文章 132 获赞 23 访问量 3万+

猜你喜欢

转载自blog.csdn.net/qq_33797928/article/details/94591744