ThreadLocal源码分析---相同线程数据共享,不同线程数据隔离

      本来是想说下java中的Thread的,毕竟之前就写了join和Thread的状态,顺便想把下面的也说了,不过碰到了一只拦路虎ThreadLocal,两者有许多关联,为了彻底搞清楚Thread,还是要先说一说ThreadLocal,所以又回到了这里。

      ThreadLocal 大家应该都有些熟悉,刚刚开始熟悉java的时候,大伙往数据库里面写数据,需要创建数据库连接,不知道大家当时有没有遇到过连接被占用的问题,\(^o^)/~,查资料之后才知道要使用ThreadLocal,当然,现在都是直接使用工具的,比如C3P0,druid等等。

      ThreadLocal平时使用好像并不多,如果不是面试需要,可能有许多小伙伴根本不想去看^_^,其实ThreadLocal用处很广泛的,封装的也比较经典。

      使用方面的话,我刚刚已经说了,连接数据库,还有,不知道大家用不用PageHelper,这个分页插件,它内部的机制也是使用ThreadLocal传递页码等信息,在进行SQL之前,将这些参数封装到SQL语句中,从而实现分页的功能。再比如,Spring框架,这个大家应该是比较熟悉的,我们使用的注入bean,这个Bean都的作用域默认的都是singleton,使用该属性定义Bean时,IOC容器仅创建一个Bean实例,IOC容器每次返回的是同一个Bean实例。那么大家想一想,在多线程情况下,一个Bean实例,真的可以么?不会有安全问题?

      大家如果去看源码的话,就会发现,其实这个也是使用的ThreadLocal,同一个线程数据共享,不同的线程数据隔离,保证了每个请求各自数据的独特性,安全性。

     ThreadLocal是如何实现这种机制的呢?我们来看下。

public class LocalTest {
    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) throws Exception {
        Thread threadA = new Thread(new Runnable() {
            @Override
            public void run() {
                threadLocal.set("线程A");
                System.out.println("线程A中threadLocal:" + threadLocal);
                // to do sth
                try {
                    String s = threadLocal.get();
                    Thread.sleep(20);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }

        }, "线程A");
        /* 换一种写法,类似Runnable*/
        Thread threadB = new Thread(() -> {
            threadLocal.set("线程B");
            System.out.println("线程B中threadLocal:" + threadLocal);

            try {
                Thread.sleep(20);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            // to do sth
        }, "线程B");
        threadA.start();
        threadB.start();

       // threadLocal.set("666");
        System.out.println("线程main中threadLocal:" + threadLocal);
        String s = threadLocal.get();
        System.out.println("获取信息s:" + s);
    }
}

        这段代码很简单,就是创建了一个ThreadLocal,然后启动两个线程,分别在ThreadLocal中进行赋值,然后在main方法中去获取,很显然,这个时候是获取不到任何信息的。ThreadLocal中的值只能本线程获取。

        还有就是threadlocal只是一个key而已,它并没有什么所谓的map,也不能存数据,这个map只是使用了ThreadLocal中的一个弱引用的类ThreadLocalMap,至于这个map,也是存放在线程里面的,使用ThreadLocal去获取的时候,就是去获取的一个对象属性。线程中的threadlocals中有多个数据,那是说明有多个ThreadLocal往里面塞值。

      下面就是分析ThreadLocal源码,看看其中的机制了,不过在看源码之前,我还是想说下为什么要用ThreadLocal?最简单的,我直接传参过去不行么?这样也可以进行逻辑处理啊,还有我创建个静态CurrentHashMap不行么?通过Map传递参数。

       行,这样做当然没问题!ThreadLocal也只是起了一个参数传递的作用。

      但是:

      1》直接使用传参,如果是个别方法还好,如果有许多方法需要使用呢?不同Service之间调用呢?为了适应方法,必须都加上规定的入参,第一,这样的方法入参就会很臃肿,第二,就是各个业务之间耦合太深,不符合我们的设计规范,而且改动起来比较繁琐,我觉得这个应该就是使用ThreadLocal的主要原因。

      2》使用一个CurrentHashMap也是没有问题的,但是这样的话你还要专门去维护这个Map,而且这个Map并没有和线程绑定,如果其他线程修改了本线程的数据呢?这样数据安全的问题你是不是也要考虑?等等等等。

      总之,使用ThreadLocal的目的就是解耦,作用就是相同线程数据共享,不同线程数据隔离。

     之前就是一直纠结使用ThreadLocal干啥?使用传参还不是一样,总是不想去了解,知道了它的作用再去看源码,就感觉舒服多了,O(∩_∩)O哈哈~。

     好了,进入正题。

     

//threadLocal set方法 
public void set(T value) {
        // 获取当前线程
        Thread t = Thread.currentThread();
        // 获取当前线程的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

        这个代码很简单,大家都看的懂,我们一点一点来深入下。  获取当前线程不多说。

 ThreadLocalMap map = getMap(t);

        这段代码我还是想扯一下,当时第一眼看到这段代码,以为这也是一个Map的get方法,脑子里老是认为Thread有两个Map。其实我们进入这个方法就会发现,这个其实是获取当前线程中的一个属性threadLocals。

/*------------------threadLocal--------------------------------*/  
ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

/*--------------------thread------------------------------*/
   
 ThreadLocal.ThreadLocalMap threadLocals = null;
有图有真相~

        所以,这段代码是先获取线程中的一个ThreadLocalMap的一个属性,然后去获取里面的值, set()方法也是一样,只是获取ThreadLocalMap,相当于是一个类里面属性的get方法。map如果不为空,才会进行后续操作,否则就要初始化了。 

       我们继续看set方法,第一次set的时候ThreadLocalMap是空的,需要创建新的ThreadLocalMap。大家在调试的时候,建议新建一个线程,使用ThreadLocal,因为main方法在启动的时候会往ThreadLocalMap里面塞东西,我们再进行操作的时候,这个map不是空的。

//ThreadLocalMap 构造方法
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            // 容量初始化16
            table = new Entry[INITIAL_CAPACITY];
            //使用&,类似于求余数,不过速度快很多
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            // 新建一个entry对象
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            //设置threshold
            setThreshold(INITIAL_CAPACITY);
        }

//------------------------------------------------------------
 //负载因子是2/3,不是HashMap的0.75哦
   private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        这是一个有参的构造方法,大家都应该看的懂,就对参数的一些初始化,然后把key,value封装成一个entry,放到了table中,初始化的同时,也将数据保存了进去。

       //set 方法
        private void set(ThreadLocal<?> key, Object value) {
            
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            /* 遍历,线性探测法,就是hash定位的i是不为空,则进行循环 */
            for (Entry e = tab[i];e != null;e = tab[i = nextIndex(i, len)]) {
                // 获取当前位置的key
                ThreadLocal<?> k = e.get();
                // key 相同,value直接替换,然后返回
                if (k == key) {
                    e.value = value;
                    return;
                }
                // key 为空,e又不为空,那就是value不为空,key明显是被垃圾回收了,value设置为 
                //空,返回
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            // 为空直接设置就行
            tab[i] = new Entry(key, value);
            int sz = ++size;
            // rehash
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

          这个是ThreadLocal的set方法,key是弱引用,为什么要使用弱引用?因为如果这里使用普通的key-value形式来定义存储结构,实质上就会造成节点的生命周期与线程强绑定,只要线程没有销毁,那么节点在GC分析中一直处于可达状态,没办法被回收,而程序本身也无法判断是否可以清理节点。弱引用是Java中四档引用的第三档,比软引用更加弱一些,如果一个对象没有强引用链可达,那么一般活不过下一次GC。当某个ThreadLocal已经没有强引用可达,则随着它被垃圾回收,在ThreadLocalMap里对应的Entry的键值会失效,这为ThreadLocalMap本身的垃圾清理提供了便利。

         调用replaceStaleEntry(key, value, i)的最终效果就是把entry(key,value)放到i的位置,或者是新建一个entry放到i这里。下面看下代码:

 private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            //需要去清除的默认i位置
            int slotToExpunge = staleSlot;
			// 往前遍历,找到Slot(槽)为null的位置,然后退出,就是为了找到null之后第一个key=null的槽然后赋值给slotToExpunge(翻译:需要清除的槽)
			// 为什么要这样做呢?因为这个清除都是以槽为null作为标识的,这样应该是防止当前位置(往前)到null的地方没有清除掉
            for (int i = prevIndex(staleSlot, len);(e = tab[i]) != null;i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            //向后遍历,查找key的位置
            for (int i = nextIndex(staleSlot, len);(e = tab[i]) != null;i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
				// 找到了key的位置,因为ThreadLocalMap用的是线性探测寻址,位置不一定是求得的hash
                if (k == key) {
					//更新对应slot的value值,并与staleSlot进行互换位置,
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;
					// 开始清理,如果在整个扫描过程中(包括函数一开始的向前扫描与i之前的向后扫描),找到了之前的无效slot则以那个位置作为清理的起点,
                    //否则则以当前的i作为清理起点
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
					//清理方法,这个是调用expungeStaleEntry,清理槽(slot)
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // 如果当前的槽(slot)已经无效,并且向前扫描过程中没有无效槽(slot),则更新slotToExpunge为当前位置
				//  如果没有找到key,staleSlot这个位置的value会在下面置空,找到的话会和staleSlot互换位置(上面代码),然后这个staleSlot就到后面去了,(*^▽^*))
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // 呶,就在这吧staleSlot的value置空了
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // 没找到key,也是要做一次清除的哦
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

           不得不说,写源码的人就是牛皮(破音)!面面俱到,前后承接。threadlocal清除的思想就是以null为节点进行清除的(当然,也有全局清理的地方,下面会说到),所以replaceStaleEntry,是先去查询key之前的为null的槽(slot),然后才开始向后遍历,对比是否有slot(槽)的key与传来的key相等。之后就是对垃圾数据进行清理了。

        咱之前也是看了好久才看懂(╥╯^╰╥),下面接着看删除过期的key:

  // 删除过时entry
  private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // 删除就是从key为null的地方开始的,这个是在传参的地方控制的
			// staleSlot ->过时的槽
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // 刷新Map,直到槽(slot)为null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);(e = tab[i]) != null;i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
				// 判断key==null,就是已经回收的数据,数据肯定要置空的
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
					//这里是重新进行对entry进行位置处理,看看能不能放到求hash的地方,不然就线性探测,重新放置entry,保证map一直是符合线性探测放置entry
					//因为一些key删除掉了,整个map的排序可能会变得比较混乱,这里重新进行放置。
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
						// 循环找到下一个空的槽(slot),位置赋值给h,然后e,也就是tab[i]赋值给tab[h]---->线性探测法
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

        我们也看到了删除的判断条件是(e = tab[i]) != null,也就是说是以null为节点的,不过却会多次调用。这个方法比较简单就是看一下key是否为空,为空的话就把数据清空,否则的话就使用线性探测法进行数据的重新排列,防止乱序什么的。

//清理一些槽
 private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
				//清除操作
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
					//循环调用
                    i = expungeStaleEntry(i);
                }
            } 
			// 这里应该是涉及到一个频率和分段的问题,初始化的话如果需要回收的少,那么四次也就结束,否则
			//  n = len; 会继续回收(哪位大神有了解可以还望指点下,感觉这样理解不全面)
			while ( (n >>>= 1) != 0); // n >>>= 1 无符号右移一位,直到为0
            return removed;
        }

  然后就是rehash()方法了,注意这个不是扩容哦,扩容之前还有操作的。

 private void rehash() {
            //再清理一次
            expungeStaleEntries();

            // 计算是否需要resize
            if (size >= threshold - threshold / 4)
                resize();
        }
//这里就是全局清理了,就是一个遍历
private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
 //扩容
 private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
            //循环
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
            //初始化
            setThreshold(newLen);
            size = count;
            table = newTab;
        }

这个get方法贴一下,没啥好说的。

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
// --------------------------------------------------------------------
private T setInitialValue() {
        // 初始化value,就是设置成null
        T value = initialValue();
        //获取属性,设置map,上面已经说过了
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

然后我们来看下remove方法。

 //remove方法
 private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];e != null;e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

        在这里可以看到遍历并没有遍历全部,也是遍历到null的位置,因为ThreadLocalMap用的就是线性探测法,同时,在清除的时候也保持其独特的结构,所以最多也就会到下一个null!大佬写代码就是严谨!

        在最后再说明一个问题,如果我们这样创建threadlocal

private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

那么,threadlocal在方法区一直都保持着引用,而框架中的线程都是使用线程池的,不是一定会直接销毁的,所以threadlocal不一定会被清除,所以一定要自己去remove,防止数据与期望不一致。至于ThreadLocal中key为null,那应该是在方法中创建的局部变量threadlocal,这样方法运行完,threadlocal就为空了。

好了,上面就是我对threadlocal 的理解了,如有漏洞,还请斧正哈~

No sacrifice,no victory~

猜你喜欢

转载自blog.csdn.net/zsah2011/article/details/109648019
今日推荐