【Java源码解析】ThreadLocal

简介

线程本地变量,用于同一线程之间的传递。每一个线程对象都保存在两个ThreadLocalMap,threadLocals和inheritableThreadLocals,后者会继承父线程的本地变量,以ThreadLocal对象为key,取得map里的值。

源码

属性和构造方法

 1     // 哈希值
 2     private final int threadLocalHashCode = nextHashCode();
 3 
 4     private static AtomicInteger nextHashCode =
 5         new AtomicInteger();
 6 
 7     private static final int HASH_INCREMENT = 0x61c88647;
 8 
 9     private static int nextHashCode() {
10         return nextHashCode.getAndAdd(HASH_INCREMENT);
11     }
12 
13     protected T initialValue() { // 初始值为空
14         return null;
15     }
16 
17     public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) { // 函数式编程,返回一个ThreadLocal对象
18         return new SuppliedThreadLocal<>(supplier);
19     }
20 
21     public ThreadLocal() { // 构造方法
22     }

SuppliedThreadLocal

 1     static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
 2 
 3         private final Supplier<? extends T> supplier;
 4 
 5         SuppliedThreadLocal(Supplier<? extends T> supplier) {
 6             this.supplier = Objects.requireNonNull(supplier);
 7         }
 8 
 9         @Override
10         protected T initialValue() {
11             return supplier.get();
12         }
13     }

基本方法

get()

 1     public T get() {
 2         Thread t = Thread.currentThread(); // 获取当前线程
 3         ThreadLocalMap map = getMap(t); // 根据当前线程获取ThreadLocalMap, Thread#threadLocals,子类可重写getMap()方法,比如InheritableThreadLocal, 返回的就是Thread#inheritableThreadLocals
 4         if (map != null) {
 5             ThreadLocalMap.Entry e = map.getEntry(this); // 以当前ThreadLocal对象为key, 取得value(ThreadLocalMap.Entry)
 6             if (e != null) {
 7                 @SuppressWarnings("unchecked")
 8                 T result = (T)e.value; // 返回结果
 9                 return result;
10             }
11         }
12         return setInitialValue(); // 如果map为空,执行初始化
13     }

getMap()

1     ThreadLocalMap getMap(Thread t) { // 获得map
2         return t.threadLocals;
3     }

setInitialValue()

 1     private T setInitialValue() { // 设置初始值
 2         T value = initialValue(); // 取得初始值
 3         Thread t = Thread.currentThread(); // 当前线程
 4         ThreadLocalMap map = getMap(t);  // 获取map
 5         if (map != null) // 不为空,设置value, key为当前对象
 6             map.set(this, value);
 7         else
 8             createMap(t, value); // 否则,创建map
 9         return value;
10     }

set()

1     public void set(T value) { // 设置值,逻辑同setInitialValue()
2         Thread t = Thread.currentThread();
3         ThreadLocalMap map = getMap(t);
4         if (map != null)
5             map.set(this, value);
6         else
7             createMap(t, value);
8     }

remove()

1      public void remove() { // 移除
2          ThreadLocalMap m = getMap(Thread.currentThread());
3          if (m != null)
4              m.remove(this);
5      }

创建map(ThreadLocalMap)

1     void createMap(Thread t, T firstValue) { // 创建map(ThreadLocalMap)
2         t.threadLocals = new ThreadLocalMap(this, firstValue);
3     }
4 
5     static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) { // 继承map值
6         return new ThreadLocalMap(parentMap);
7     }

ThreadLocalMap

属性

 1         static class Entry extends WeakReference<ThreadLocal<?>> { // 弱引用
 2             Object value;
 3 
 4             Entry(ThreadLocal<?> k, Object v) {
 5                 super(k);
 6                 value = v;
 7             }
 8         }
 9 
10         private static final int INITIAL_CAPACITY = 16; // 初始容量
11 
12         private Entry[] table; // 数组
13 
14         private int size = 0; // 大小
15 
16         private int threshold; // 阈值,长度的2/3
17 
18         private void setThreshold(int len) { // 设置阈值
19             threshold = len * 2 / 3;
20         }
21 
22         private static int nextIndex(int i, int len) { // 下一个索引
23             return ((i + 1 < len) ? i + 1 : 0); // 长度范围内,加1; 超过范围,0
24         }
25 
26         private static int prevIndex(int i, int len) { // 前一个索引
27             return ((i - 1 >= 0) ? i - 1 : len - 1); // 长度范围内,减1; 超过范围,len - 1
28         }

构造方法

 1         ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { // 构造方法
 2             table = new Entry[INITIAL_CAPACITY]; // 初始化数组,初始容量为16
 3             int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); // firstKey的索引
 4             table[i] = new Entry(firstKey, firstValue); // 构造Entry(firstKey->firstValue)对象,并置于i索引处
 5             size = 1; // 当前大小为1 
 6             setThreshold(INITIAL_CAPACITY); // 设置阈值
 7         }
 8 
 9         private ThreadLocalMap(ThreadLocalMap parentMap) { // 以父ThreadLocalMap构造ThreadLocalMap
10             Entry[] parentTable = parentMap.table; // 取得parentMap的table
11             int len = parentTable.length; // 取得table长度
12             setThreshold(len); // 设置阈值
13             table = new Entry[len]; // 创建数组table
14 
15             for (int j = 0; j < len; j++) {
16                 Entry e = parentTable[j];
17                 if (e != null) {
18                     @SuppressWarnings("unchecked")
19                     ThreadLocal<Object> key = (ThreadLocal<Object>) e.get(); // 取得key
20                     if (key != null) {
21                         Object value = key.childValue(e.value); // 计算子线程的value
22                         Entry c = new Entry(key, value); // 构建Entry对象
23                         int h = key.threadLocalHashCode & (len - 1); // 计算索引
24                         while (table[h] != null) // 如果索引不为空,则计算下一个索引,直到找到空位
25                             h = nextIndex(h, len); // 寻找下一个索引,(hash碰撞时,没有使用链表,而是寻找下一个索引)
26                         table[h] = c;
27                         size++; // 长度加1
28                     }
29                 }
30             }
31         }

基本方法

getEntry()

1         private Entry getEntry(ThreadLocal<?> key) { // 获取entry
2             int i = key.threadLocalHashCode & (table.length - 1); // 计算索引
3             Entry e = table[i]; // 取得entry
4             if (e != null && e.get() == key) // 找到返回
5                 return e;
6             else
7                 return getEntryAfterMiss(key, i, e); // 否则调用getEntryAfterMiss方法
8         }

getEntryAfterMiss()

 1         private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { // 处理碰撞的情况
 2             Entry[] tab = table;
 3             int len = tab.length;
 4 
 5             while (e != null) { // 若e为空,直接返回null, 否则遍历table
 6                 ThreadLocal<?> k = e.get(); // 获得key
 7                 if (k == key) // 若相等,则找到返回
 8                     return e;
 9                 if (k == null)
10                     expungeStaleEntry(i); // 删除过期的Entry对象
11                 else
12                     i = nextIndex(i, len); // 计算下一个索引,继续寻找
13                 e = tab[i];
14             }
15             return null;
16         }

expungeStaleEntry()

 1         private int expungeStaleEntry(int staleSlot) {
 2             Entry[] tab = table;
 3             int len = tab.length;
 4 
 5             tab[staleSlot].value = null; // value置为空
 6             tab[staleSlot] = null; // 槽位置为空
 7             size--; // size减1
 8 
 9             Entry e;
10             int i;
11             for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { // 以staleSlot起始,索引与之碰撞的所有槽位,尝试清除无效的元素
12                 ThreadLocal<?> k = e.get(); // key
13                 if (k == null) { // 过期,清理
14                     e.value = null;
15                     tab[i] = null;
16                     size--;
17                 } else {
18                     int h = k.threadLocalHashCode & (len - 1); // 重新归置元素
19                     if (h != i) {
20                         tab[i] = null; // 原来的槽位清空
21                         while (tab[h] != null) // 以h为始,找空位
22                             h = nextIndex(h, len);
23                         tab[h] = e; // 设置元素
24                     }
25                 }
26             }
27             return i;
28         }

set()

 1         private void set(ThreadLocal<?> key, Object value) {
 2             Entry[] tab = table;
 3             int len = tab.length;
 4             int i = key.threadLocalHashCode & (len - 1); // 计算索引
 5             for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
 6                 ThreadLocal<?> k = e.get();
 7                 if (k == key) { // 命中,更新value, 返回
 8                     e.value = value;
 9                     return;
10                 }
11                 if (k == null) { // 替换过期的槽位
12                     replaceStaleEntry(key, value, i);
13                     return;
14                 }
15             }
16             tab[i] = new Entry(key, value); // 找到空位,新建Entry对象
17             int sz = ++size; // size加1
18             if (!cleanSomeSlots(i, sz) && sz >= threshold) // 清理槽位失败,并且当前size大于阈值,调用rehash方法
19                 rehash();
20         }

replaceStaleEntry()

 1         private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
 2             Entry[] tab = table;
 3             int len = tab.length;
 4             Entry e;
 5 
 6             int slotToExpunge = staleSlot;
 7             for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) // 从当前staleSlot往前找
 8                 if (e.get() == null)
 9                     slotToExpunge = i; // 过期槽位起始处,接下来从slotToExpunge清理过期槽位
10 
11             for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { // 从当前staleSlot往后找
12                 ThreadLocal<?> k = e.get();
13                 if (k == key) { // 命中,替换
14                     e.value = value; // 替换value
15                     tab[i] = tab[staleSlot]; // 交换i和staleSlot的元素,i槽位处等待被清理
16                     tab[staleSlot] = e;
17                     if (slotToExpunge == staleSlot) // 在staleSlot槽位之前没有过期的槽位,将slotToExpunge设置为i(staleSlot之后的槽位,因为staleSlot已经设置了有效的元素)
18                         slotToExpunge = i;
19                     cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); // 清理工作
20                     return;
21                 }
22                 if (k == null && slotToExpunge == staleSlot)
23                     slotToExpunge = i; // 在staleSlot槽位之前没有过期的槽位,将slotToExpunge设置为i(staleSlot之后的槽位,因为staleSlot后面会设置有效的元素)
24             }
25             tab[staleSlot].value = null; // 置空
26             tab[staleSlot] = new Entry(key, value); // 设置新的值
27             if (slotToExpunge != staleSlot) // 如果有过期元素,做清理工作
28                 cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
29         }

cleanSomeSlots()

 1         private boolean cleanSomeSlots(int i, int n) {
 2             boolean removed = false;
 3             Entry[] tab = table;
 4             int len = tab.length;
 5             do {
 6                 i = nextIndex(i, len);
 7                 Entry e = tab[i];
 8                 if (e != null && e.get() == null) { // 找到过期元素,执行清理操作
 9                     n = len;
10                     removed = true;
11                     i = expungeStaleEntry(i); // 具体操作还是由expungeStaleEntry完成
12                 }
13             } while ((n >>>= 1) != 0);
14             return removed;
15         }

rehash()

1         private void rehash() {
2             expungeStaleEntries(); // 清理过期的元素
3             if (size >= threshold - threshold / 4)
4                 resize(); // 扩容
5         }

expungeStaleEntries()

1         private void expungeStaleEntries() {
2             Entry[] tab = table;
3             int len = tab.length;
4             for (int j = 0; j < len; j++) { // 从槽位0处,尝试清理过期的条目
5                 Entry e = tab[j];
6                 if (e != null && e.get() == null)
7                     expungeStaleEntry(j); // 调用expungeStaleEntry方法
8             }
9         }

resize()

 1         private void resize() { // 扩容
 2             Entry[] oldTab = table;
 3             int oldLen = oldTab.length;
 4             int newLen = oldLen * 2; // 2倍
 5             Entry[] newTab = new Entry[newLen]; // 新数组
 6             int count = 0;
 7             for (int j = 0; j < oldLen; ++j) {
 8                 Entry e = oldTab[j];
 9                 if (e != null) {
10                     ThreadLocal<?> k = e.get();
11                     if (k == null) {
12                         e.value = null; // 帮助GC
13                     } else {
14                         int h = k.threadLocalHashCode & (newLen - 1); // rehash
15                         while (newTab[h] != null) // 碰撞,计算下一个索引(槽位)
16                             h = nextIndex(h, newLen);
17                         newTab[h] = e;
18                         count++;
19                     }
20                 }
21             }
22 
23             setThreshold(newLen); // 设置新的阈值
24             size = count;
25             table = newTab;
26         }

remove()

 1         private void remove(ThreadLocal<?> key) {
 2             Entry[] tab = table;
 3             int len = tab.length;
 4             int i = key.threadLocalHashCode & (len - 1);
 5             for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
 6                 if (e.get() == key) {
 7                     e.clear(); // 移除
 8                     expungeStaleEntry(i); // 并清理过期槽位
 9                     return;
10                 }
11             }
12         }

行文至此结束。

尊重他人的劳动,转载请注明出处:http://www.cnblogs.com/aniao/p/aniao_threadlocal.html

猜你喜欢

转载自www.cnblogs.com/aniao/p/aniao_threadlocal.html