Sharing ThreadLocal values across thread pools

Background

In real projects we often use a thread pool to process large numbers of tasks. However, code running on a pool thread cannot see the ThreadLocal values set by the thread that submitted the task, which is inconvenient.
For example, to improve performance we may use a thread pool to call several services concurrently. In that situation it is very useful to carry the caller's ThreadLocal values into those tasks non-invasively, without modifying the original code.

Principle

To share ThreadLocal values across threads, we have to work around the access restrictions the JDK imposes.
ThreadLocal achieves thread isolation by keeping a map (ThreadLocalMap) inside each Thread: every get/set goes to the current thread's own map. An InheritableThreadLocal additionally copies the parent thread's values into the child thread's map once, when the child thread is created. This is how access to these variables is isolated between threads.

// Excerpt from the JDK's java.lang.ThreadLocal source.
// Returns the ThreadLocalMap owned by the given thread (its threadLocals field).
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

// Look up the current thread's ThreadLocalMap and store the value in it, keyed by this
// ThreadLocal; if the thread has no map yet, create one holding the value.
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

To work around this, we keep a static (thread-local) registry of the ThreadLocal instances that the current thread wants to share across threads. When a task is submitted, we create a running context that copies those values. Right before the task runs on the pool thread we install the copied values, and after it finishes we restore the pool thread's original values.

// Wrap the method that actually runs the task with this structure.
// Exceptions from the task are not caught here: they propagate to whoever runs the wrapper
// (e.g. the executor), while the finally block still restores the thread's original context.
public void run() {
    Map<MyThreadLocal<Object>, Object> replace = null;
    try {
        // Install the captured context.
        replace = replace();
        runnable.run();
    } finally {
        // Restore the original context, even if the task threw.
        restore(replace);
    }
}

Test code

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Demonstrates propagating {@link MyThreadLocal} values into a thread pool via
 * {@link MyThreadLocalContext#go(Runnable)}.
 *
 * @author Lion Zhou
 * @date 2022/9/15
 */
public class Test {

    /**
     * Submits three tasks to a single-thread pool and prints what each one sees.
     * <ol>
     *   <li>A wrapped task sees the captured value 333, then overwrites its own copy with 111.</li>
     *   <li>An unwrapped task sees {@code null}, because the wrapper restored the pool thread.</li>
     *   <li>A wrapped task sees 333 again.</li>
     * </ol>
     * The method waits for the pool to finish before printing the caller's own value. Each
     * line is printed with a single call so the output is not interleaved.
     */
    public static void test5() {
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        // Submit an empty task so the pool creates its worker thread now, before mtl1 is set.
        // MyThreadLocal is inheritable, so a worker created after set(333) would inherit the
        // value on creation and the demo would show nothing.
        executorService.submit(() -> {
        });

        // Use the thread variable we defined.
        MyThreadLocal<Integer> mtl1 = new MyThreadLocal<>();
        mtl1.set(333);

        executorService.submit(MyThreadLocalContext.go(() -> {
            System.out.println("1:" + mtl1.get());
            // Only changes the pool thread's copy; the wrapper restores it afterwards.
            mtl1.set(111);
        }));

        executorService.submit(() -> {
            // Not wrapped: the pool thread has no value of its own, so this prints null.
            System.out.println("2:" + mtl1.get());
        });

        executorService.submit(MyThreadLocalContext.go(() -> {
            // Wrapped again: still sees the captured 333.
            System.out.println("3:" + mtl1.get());
        }));

        executorService.shutdown();
        try {
            if (!executorService.awaitTermination(10, TimeUnit.SECONDS)) {
                executorService.shutdownNow();
            }
        } catch (InterruptedException e) {
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
        }
        // The caller's own value is untouched by what the tasks did.
        System.out.println("end:" + mtl1.get());
    }

    public static void main(String[] args) {
        test5();
    }
}

Source code

import java.util.WeakHashMap;

/**
 * An {@link InheritableThreadLocal} that registers itself in a per-thread registry whenever a
 * value is set. {@code MyThreadLocalContext} uses this registry to find every instance set on
 * the current thread and replay the values on another thread.
 *
 * @param <T> the type of the stored value
 * @author Lion Zhou
 * @date 2022/9/15
 */
public class MyThreadLocal<T> extends InheritableThreadLocal<T> {

    /**
     * Per-thread registry of the MyThreadLocal instances set on that thread, used later for
     * cross-thread propagation. Only the keys are used; the stored values are always null.
     * Weak keys keep the registry from pinning otherwise-unreachable instances in memory.
     *
     * <p>A child thread gets its own copy of the parent's registry (see {@code childValue}).
     * With the default {@code childValue}, the parent and all its children would share one
     * non-thread-safe WeakHashMap. Instances registered in a child would then also show up in
     * the parent's registry.
     */
    static final InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder =
            new InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>>() {
                @Override
                protected WeakHashMap<MyThreadLocal<Object>, Object> childValue(
                        WeakHashMap<MyThreadLocal<Object>, Object> parentValue) {
                    return parentValue == null ? null : new WeakHashMap<>(parentValue);
                }
            };

    /**
     * Returns the current thread's value. When that value is null, this instance is also
     * dropped from the current thread's registry.
     */
    @Override
    public T get() {
        // Delegate to the original get().
        T value = super.get();
        if (value == null) {
            WeakHashMap<MyThreadLocal<Object>, Object> registry = holder.get();
            if (registry != null) {
                // No value for this key any more: stop tracking it.
                registry.remove(this);
            }
        }
        return value;
    }

    /**
     * Sets the current thread's value and registers this instance in the current thread's
     * registry, creating the registry on first use.
     */
    @Override
    @SuppressWarnings("unchecked") // the registry only stores the instance as a key
    public void set(T value) {
        super.set(value);
        WeakHashMap<MyThreadLocal<Object>, Object> registry = holder.get();
        if (registry == null) {
            registry = new WeakHashMap<>(8);
            holder.set(registry);
        }
        registry.put((MyThreadLocal<Object>) this, null);
    }

    /** Removes the current thread's value and unregisters this instance from its registry. */
    @Override
    public void remove() {
        super.remove();
        WeakHashMap<MyThreadLocal<Object>, Object> registry = holder.get();
        if (registry != null) {
            registry.remove(this);
        }
    }
}
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.WeakHashMap;

/**
 * @author Lion Zhou
 * @date 2022/9/15
 */
public class MyThreadLocalContext {
    
    

    public static Runnable go(Runnable runnable) {
    
    
        InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = MyThreadLocal.holder;

        Map<MyThreadLocal<Object>, Object> map = Collections.emptyMap();
        if (null != holder.get()) {
    
    
            map = new WeakHashMap<>(holder.get().size());
//            System.out.println("start");
            for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.get().entrySet()) {
    
    
//                System.out.println(entry.getKey().get());
                map.put(entry.getKey(), entry.getKey().get());
            }
//            System.out.println("end");
        }
        return new Context(map, runnable);
    }

    public static class Context implements Runnable {
    
    
        Map<MyThreadLocal<Object>, Object> holder;
        Runnable runnable;

        public Context(Map<MyThreadLocal<Object>, Object> holder, Runnable runnable) {
    
    
            this.holder = holder;
            this.runnable = runnable;
        }

        public Map<MyThreadLocal<Object>, Object> replace() {
    
    
            // 保留原本的线程本地变量
            Map<MyThreadLocal<Object>, Object> replace = new WeakHashMap<>();

            // 将复制过来的值重新赋值给当前上下文环境
//            System.out.println("context start");
            // 上下文切换
            for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
    
    
//                System.out.println(String.format("old: %s, new: %s", Optional.ofNullable(entry.getKey().get()).orElse("null").toString(),
//                        entry.getValue()));

                // 保存 线程本地变量 的现场
                replace.put(entry.getKey(), entry.getKey().get());
                // 替换需要的上下文
                entry.getKey().set(entry.getValue());
            }
//            System.out.println("context end");
            return replace;
        }

        public void restore(Map<MyThreadLocal<Object>, Object> restore) {
    
    
            if (null == restore) {
    
    
                return;
            }
            for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
    
    
                // 原本的值
                Object old = restore.get(entry.getKey());
                if (null == old) {
    
    
                    // 原本就为null
                    entry.getKey().remove();
                } else {
    
    
                    entry.getKey().set(old);
                }
            }
        }

        @Override
        public void run() {
    
    
            Map<MyThreadLocal<Object>, Object> replace = null;
            try {
    
    
                replace = replace();
                // 设置上下文
                runnable.run();
            } catch (Exception e) {
    
    
                e.printStackTrace();
            } finally {
    
    
                // 还原上下文
                restore(replace);
            }
        }
    }

}

Summary

The class names in the source code are not great; please forgive me.
The code is deliberately simple and only for demonstration, and it still has some problems. For example, values are not deep-copied when the context is replaced, so a mutable value object is shared between the submitting thread and the task.
WeakHashMap is used for a basic reason: the registry of thread variables is shared across threads rather than owned by a single thread. If it held the original MyThreadLocal instances strongly, they could never be garbage-collected (a memory leak), so weak references are used.

References

  • Alibaba's open-source TransmittableThreadLocal (TTL), a library for passing context between threads that can be used programmatically or through a Java agent

Guess you like

Origin blog.csdn.net/weixin_46080554/article/details/126872872