手撕 HashMap(空口无凭,实战为真)

如果你还不了解 HashMap,建议你先看我上一篇博客:
分享 HashMap 的精髓,它永远比你自己的写 map 效率高

这篇博客更多的阐述了如何去写,更多的是对代码的分析,需要有阅读代码的能力。
如果你不太擅长阅读代码,可以去看我的上一篇博客,里面更多的是对知识点语言描述,相对会更容易理解。

键值对的存储方式:Node

首先开始构建HashMap,并添加静态内部类Node
然后给Node创建属性 key, value, next
然后实现Node的基本方法
方法比较简单,只添加少量注释,粗略扫一遍即可

public class MyHashMap<K, V> extends AbstractMap<K, V> {
    static class Node<K, V> implements Map.Entry<K, V> {
    	final int hash; // 记录key的hashCode()值
        final K key;    // 键
        V value;        // 值
        Node<k, V> next;      // 指向下一个链表结点
        // 构造方法 给属性赋值
        Node(int hash, K key, V value) {
        	this.hash = hash;
            this.key = key;
            this.value = value;
        }
        // get 方法
        public final int getHash()       { return hash; }
        public final K getKey()          { return key; }
        public final V getValue()        { return value; }
        // set方法 设置新value 返回旧value
        public final V setValue(V newValue) {
            V oldValue = value;
            value = newValue;
            return oldValue;
        }
        public final String toString() { return key + "=" + value; }
        public final int hashCode() {
            return key.hashCode() ^ value.hashCode();
        }
        public final boolean equals(Object o) {
            if (o == this)
                return true;
            if (o instanceof Map.Entry) {
                Map.Entry<?,?> e = (Map.Entry<?,?>)o;
                if (e.hashCode() == this.hashCode())
                    return true;
            }
            return false;
        }
	}
}

属性:static和private成员变量

首先介绍 final static 变量

// 默认的数组大小 16
static final int DEFAULT_INITIAL_CAPACITY = 16;
// 数组最大的大小 2的30次方
static final int MAXIMUM_CAPACITY = 1 << 30;
// 默认的负载因子为 0.75
static final float DEFAULT_LOAD_FACTOR = 0.75f;

属于对象的成员变量

private Node<K, V>[] table; // 存放键值对的数组
private int size;           // 存放键值对的数目
private int threshold;      // 扩容阈值(表示size达到多少要扩容)
private int modCount;       // 相当于版本号
private final float loadFactor; // 负载因子(不可变)

构造方法(避免占用资源,延迟初始化)

  • 这里的代码与 jdk 源码大致相同,仅检查参数合理性,然后为属性赋值。
  • 我们发现,在给 threshold 赋值时,没有直接使用传递的 initialCapacity 参数,而是用了tableSizeFor() 方法计算出不小于它的最小的2的次方数作为它的值
  • 在 HashMap 刚刚创建时,内部的数组并没有初始化,这样子只有到真正用的时候才初始化,节省资源。
public MyHashMap(int initialCapacity, float loadFactor) {
    // 如果容量小于0,抛出异常
    if (initialCapacity < 0)
        throw new IllegalArgumentException("Illegal initial capacity: " +
                                           initialCapacity);
    // 容量大于最大值,则为最大值
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    // 负载因子不能小于0 否则抛出异常
    if (loadFactor <= 0)
        throw new IllegalArgumentException("Illegal load factor: " +
                                           loadFactor);
    // 赋值操作
    this.loadFactor = loadFactor;
    // tableSizeFor() 计算容量并赋值
    this.threshold = tableSizeFor(initialCapacity);
}

另两个构造方法十分简单,明白主构造函数即可

public MyHashMap(int initialCapacity) {
    this(initialCapacity, DEFAULT_LOAD_FACTOR);
}
public MyHashMap() {
	this.threshold = DEFAULT_INITIAL_CAPACITY;
    this.loadFactor = DEFAULT_LOAD_FACTOR;
}

tableSizeFor(),对数组长度的高效计算

  • 此处代码与 jdk 源码相同
  • 为了保证效率,此处计算全部采用位运算
  • a >>> b 表示a的二进制数中,所有位的数字,全部右移b位
  • a | b 是或运算,a与b表示的二进制数字中的所有相同的数字,如果都为0,则计算出的该位为0,否则该位为1。
  • 先跳过第一行,我们看后面几行,先进行右移运算,然后再或运算。
  • 我们发现,本来正数的最高位肯定不为0,所以在二进制中一定为1,第一次右移之后,原来的第一位就变成第二位,所以右移后的数字在第二位为1,,进行异或操作,第一位有1,计算出来的结果为1,第二位有1,计算出来也为1。所以第一次运算过后保证二进制数据的前两位一定为1。
  • 同理,第二次计算过后,前4位一定为1,然后8位,16位,32位,因为Integer一共只有32位,所以可以保证从第一位到最后一位都为1。
  • 然后这时只要加上1,就是2的倍数。
  • 现在很容易明白,前面减去1是为了防止这个数本身已经是2的倍数,通过计算后会变成它的两倍。
static final int tableSizeFor(int cap) {
    int n = cap - 1;
    n |= n >>> 1;
    n |= n >>> 2;
    n |= n >>> 4;
    n |= n >>> 8;
    n |= n >>> 16;
    return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
}

hash()方法,提高离散程度的高效算法

  • 这段代码仅仅是对 jdk 源码中的代码写得更清晰一点,效果与源码是相同的。
  • 此方法是对 key 的 hashCode() 值进行二次计算,将原值与它的无符号右移16位进行异或计算。
  • 因为我们平时使用到的 HashMap 用到的容量很小,但是 hashCode() 值很大,我们用哈希函数取模计算位置时,仅仅只使用到了最后的那几位。
  • 而这个方法对 hashCode() 值的前后16位进行异或操作,使得后16位的哈希值也同时包含了前16位的特性,这样的函数可以充分利用整个 hashCode() 值的更多位上的数值,拥有更好的离散性,因此可以更好避免冲突,从而使得 HashMap 拥有更好的性能。
static final int hash(Object key) {
    if(key == null)
        return 0;
    int h = key.hashCode();
    return h ^ (h >>> 16);
}

get()方法,三大主方法之一

get() 方法用于获取 key 对应的 value,要获取到 value 首先要获取到键值对结点对象。如果获取到,则返回 value,否则返回null

public V get(Object key) {
    Node<K, V> e = getNode(key);
    return e == null ? null : e.value;
}

所以主要的精力放在写 getNode() 方法上

  • 二次哈希
  • 计算出结点的位置
  • 然后遍历数组该位置的桶中所有结点
  • 如果存在结点,则返回,否则遍历结束后返回 null

计算位置是get,put,remove都需要用到的方法,所以把它单独写出,可以提高公用代码量。虽然只有一行,但是更便于表示它的步骤,代码更易阅读。

static final int getIndex(int hash, int n) {
    return hash & (n - 1);
}

getNode() 方法

private Node<K, V> getNode(Object key) {
    // hash()
    int h = hash(key);
    Node<K, V>[] tab = table;
    // 计算位置
    int n = tab.length;
    int i = getIndex(h, n);
    // 遍历该数组该位置桶中的所有结点
    Node<K, V> p = tab[i];
    if(key == null) {
        while(p != null) {
            if(key == p.getKey())
                return p; // 找到
            p = p.next;
        }
    }
    else {
        while(p != null) {
            if(key.equals(p.getKey()))
                return p; // 找到
            p = p.next;
        }
    }
    // 遍历所有没有找到
    return null;
}

resize()方法,学会put()方法的前提

  • 扩容方法,一般按照2倍的方式进行扩容。
  • 在数组还未初始化时,和存入键值对数目大于 threshold 时调用
  1. 计算新的数组容量和可容纳键值对数目 newCap newThr
  2. 根据新的容量创建新的数组
  3. 将旧数组中的内容全部添加入新数组

该方法比源码简略,且不涉及红黑树操作,更便于理解

private final Node<K, V>[] resize() {
    Node<K,V>[] oldTab = table;
    int oldCap = (oldTab == null) ? 0 : oldTab.length;
    int oldThr = threshold;
    int newCap, newThr;
    // 计算newCap,newThr
    if (oldCap > 0) { // 按照两倍进行扩容
        newCap = oldCap << 1;
        newThr = oldThr << 1;
        // 如果容量较小,则threshold不精确,重新计算
        if(newCap <= 64) {
            newThr = (int) (newCap * loadFactor);
        }
        // 如果到最大容量,则把扩容阈值设为最大
        if(newCap == MAXIMUM_CAPACITY)
            newThr = Integer.MAX_VALUE;
    }
    else { // oldCap为0,说明数组还未被初始化
        newCap = oldThr;
        int ft = (int) (newCap * loadFactor);
        newThr = ft < MAXIMUM_CAPACITY ? ft : Integer.MAX_VALUE;
    }
    // 创建新数组
    @SuppressWarnings("unchecked")
    Node<K,V>[] newTab = (Node<K, V>[]) new Node<?,?>[newCap];
    // 将旧数组结点添加入新数组
    if(oldCap != 0) {
        for(int t = 0; t < oldCap; t++) {
            Node<K, V> p = oldTab[t];
            while(p != null) {
                int h = p.getHash();
                int i = getIndex(h, newCap);
                Node<K, V> next = p.next;
                p.next = newTab[i];
                newTab[i] = p;
                p = next;
            }
        }
    }
    // 更新成员变量
    threshold = newThr;
    table = newTab;
    return newTab;
}

put()方法,三大主方法之一

  1. 如果存在 key,则更新它对应的 value
  2. 如果不存在,则插入新的键值对
  3. 如果数组未初始化,先初始化
  4. 二次哈希值
  5. 计算数组中桶的位置
  6. 遍历桶中结点,如果存在,则更新;否则在尾部插入新节点。
  7. 更新 modCount
  8. 更新 size
  9. 如果 size 大于 threshold,则需要扩容
public V put(K key, V value) {
    Node<K, V>[] tab = table;
    V oldV = null;
    // 1、如果未初始化则初始化
    if(tab == null || tab.length == 0)
        tab = resize();
    // 2、二次hash
    int n = tab.length;
    int h = hash(key);
    // 3、计算位置
    int i = getIndex(h, n);
    // 4、遍历结点,更新或插入
    boolean updated = false; // 表示是否更新过结点
    Node<K, V> p = tab[i];   // 首节点
    if(p == null) {
        tab[i] = new Node<>(h, key, value);
    }
    else {
        if(key == null) {
            if(key == p.getKey()) { // 首节点
                oldV = p.setValue(value);
                updated = true;
            }
            if(! updated) { // 首节点未更新,遍历之后结点
                while(p.next != null) {
                    if(key == p.next.getKey()) {
                        oldV = p.next.setValue(value);
                        updated = true;
                        break;
                    }
                    p = p.next;
                }
            }
            if(! updated) { // 没有更新过,插入新节点
                Node<K, V> node = new Node<>(h, key, value);
                p.next = node;
            }
        }
        else { // key不是null的情况
            if(key.equals(p.getKey())) { // 首节点
                oldV = p.setValue(value);
                updated = true;
            }
            if(! updated) { // 首节点未更新,遍历之后结点
                while(p.next != null) {
                    if(key.equals(p.next.getKey())) {
                        oldV = p.next.setValue(value);
                        updated = true;
                        break;
                    }
                    p = p.next;
                }
            }
            if(! updated) { // 没有更新过,插入新节点
                Node<K, V> node = new Node<>(h, key, value);
                p.next = node;
            }
        }
    }
    // 5、更新modCount
    ++modCount;
    // 6、更新size
    if(! updated) // 没有更新,说明新增结点
        ++size;
    // 7、检查扩容
    if(size > threshold)
        resize();
    return oldV;
}

remove()方法,三大主方法之一

移除 HashMap 中的 key 对应的结点,返回它的 value 值,否则返回 null。

  1. 二次哈希
  2. 计算出在数组中对应的桶的位置
  3. 遍历桶中结点
public V remove(Object key) {
    Node<K, V>[] tab = table;
    // 1、二次哈希
    int h = hash(key);
    // 2、计算位置
    int n = tab.length;
    int i = getIndex(h, n);
    // 3、遍历桶中结点
    Node<K, V> p = tab[i];
    if(p == null) // 桶为空
        return null;
    if(key == null) {
        // 首节点
        if(key == p.getKey()) {
            V value = p.getValue();
            tab[i] = p.next;
            return value;
        }
        // 之后的结点
        while(p.next != null) {
            if(key == p.next.getKey()) {
                V value = p.next.getValue();
                p.next = p.next.next;
                return value;
            }
        }
    }
    else {
        // 首节点
        if(key.equals(p.getKey())) {
            V value = p.getValue();
            tab[i] = p.next;
            return value;
        }
        // 之后的结点
        while(p.next != null) {
            if(key.equals(p.next.getKey())) {
                V value = p.next.getValue();
                p.next = p.next.next;
                return value;
            }
        }
    }
    return null;
}

entrySet():迭代HashMap的方法

  • 首先我们要知道,Map 接口是不具有 Iterator() 方法的,因此自身并不具备迭代遍历的能力。
  • 因此,为了遍历 Map 中的元素,则必须对 Map 中的元素做一个 Collection 集合进行迭代方可。
  • Map 共有三个返回集合的方法,一个是 keySet() 方法,用于返回所有 key 组成的集合,还有一个是 values() 方法,用于返回所有的 value 集合,并且可以重复,最后是 entrySet() 方法,返回所有键值对的集合。
  • 而 entrySet() 是前两个方法的基础,因为只需要迭代所有的键值对,就可以迭代所有的 key 和 value。
  • 对于该 Map 中返回的集合,只是对于 Map 起一个迭代功能的作用,里面的对象都是原有的同一个对象,对集合对象的增删改查也会直接影响到原有的 Map。
  1. 首先需要有一个 EntrySet 类
  2. 既然是作为迭代,那必然先有 EntryIterator 类
  3. 主要完成 EntryIterator 中的关键方法
  4. EntrySet 中的其他方法大家按需写即可
  5. 调用方法时,如果 EntrySet 已经有了对象,则直接返回,如果没有,则新建一个对象返回
// 创建一个变量保存entrySet
private Set<Map.Entry<K,V>> entrySet;
// entrySet()方法
public Set<Entry<K, V>> entrySet() {
    // 没有则创建一个
    if(entrySet == null) {
        entrySet = new EntrySet();
    }
    return entrySet; // 返回
}

EntryIterator 迭代器类

class EntryIterator implements Iterator<Entry<K, V>> {
    Node<K, V> next;        // 下一个结点
    Node<K, V> current;     // 当前结点
    // 快速失败机制,如果和modCount不一样,则抛异常
    int expectedModCount;
    int index;              // 迭代到哪一个桶
    EntryIterator() { // 构造方法
        expectedModCount = modCount;
        current = next = null;
        // 找出第一个结点 
        index = 0;
        while((next = table[index]) == null) {
            index++;
        }
    }
    public Entry<K, V> next() {
        // 如果modCount与原先不同,说明已被修改,抛出异常
        if (modCount != expectedModCount)
            throw new ConcurrentModificationException();
        // 如果next为null,说明已经迭代到末尾,抛出异常
        Node<K, V> e = next;
        if (e == null)
            throw new NoSuchElementException();
        // 找出下一个结点
        if ((next = (current = e).next) == null)
            while(++index < table.length && (next = table[index]) == null) {}
        return e;
    }
    public final void remove() {
        Node<K,V> p = current;
        // 当前为null,说明已经remove过,无法再remove
        if (p == null)
            throw new IllegalStateException();
        // 如果modCount与原先不同,说明已被修改,抛出异常
        if (modCount != expectedModCount)
            throw new ConcurrentModificationException();
        current = null;
        K key = p.key;
        // 调用MyHashMap的remove方法来移除
        MyHashMap.this.remove(key);
        expectedModCount = modCount;
    }
    public final boolean hasNext() {
        return next != null;
    }
}

EntrySet 类。如果上面的代码都已经写过了,那这里的代码就已经很简单了,注释相应减少,不作重点。

class EntrySet extends AbstractSet<Map.Entry<K,V>> {
    public final int size()                 { return size; }
    public final void clear()               { MyHashMap.this.clear(); }
    public final Iterator<Map.Entry<K,V>> iterator() {
        return new EntryIterator();
    }
    public final boolean contains(Object o) {
        if (!(o instanceof Map.Entry))
            return false;
        Map.Entry<?,?> e = (Map.Entry<?,?>) o;
        Object key = e.getKey();
        Node<K,V> candidate = getNode(key);
        return candidate != null && candidate.equals(e);
    }
    public final boolean remove(Object o) {
        if (o instanceof Map.Entry) {
            Map.Entry<?,?> e = (Map.Entry<?,?>) o;
            Object key = e.getKey();
            Node<K, V> candidate = getNode(key);
            if(candidate != null && candidate.equals(e)) {
                MyHashMap.this.remove(key);
                return true;
            }
            return false;
        }
        return false;
    }
    public final Spliterator<Map.Entry<K,V>> spliterator() {
        // 此方法不做讨论,省略
        return null;
    }
    public final void forEach(Consumer<? super Map.Entry<K,V>> action) {
        if (action == null)
            throw new NullPointerException();
        Iterator<Entry<K, V>> it = iterator();
        while(it.hasNext())
            action.accept(it.next());
    }
}

Map接口

要写HashMap,先从它的接口Map开始写起(仅仅一个接口,只包含了需要实现的方法,不涉及知识点,大家看注释即可理解)。
(当然,为了方便,我们可以直接 import JDK 自带的 Map 接口)

public interface Map<K, V> { //Map接口,指定了需要写哪些方法
    int size();    // 返回存储的键值对个数
    boolean isEmpty(); // 判空
    boolean containsKey(Object key); // 是否包含key
    boolean containsValue(Object value);// 是否包含value
    V get(K key);          // 获取key对应的value
    V put(K key, V value); // 存储键值对
    V remove(Object key);  // 移除指定key的键值对
    // 将Map中的所有键值对存储到该Map中
    void putAll(Map<? extends K, ? extends V> m);
    void clear();    // 清空
    Set<K> keySet(); // 返回key的集合
    Collection<V> values(); // 返回value集合(可重复)
    Set<Map.Entry<K, V>> entrySet(); // 返回键值对集合
    // 内部接口 定义了键值对的所拥有的方法
    interface Entry<K, V> {
        K getKey();
        V getValue();
        // 设置新value 返回旧value
        V setValue(V value);
    }
}

AbstractMap抽象类(isEmpty,toString,equals,hashCode)

看过 java.util 包里内容的应该知道,jdk 源码对于每一个集合接口,都写了一个抽象类,实现了一些基本的方法,这样在写具体的类的时候就不用写这些基本代码。

  • toString()
  • equals(Object obj)
  • hashCode()

笔者并未全部列出,此处不作重点,大家可以看 jdk 的详细源码

public abstract class AbstractMap<K,V> implements Map<K,V> {
    // 用于表示
	public String toString() {
		StringBuilder str = new StringBuilder();
		Iterator<Entry<K,V>> it = entrySet().iterator();
		while(it.hasNext())
			str.append(it.next().toString()).append(" ");
		return str.toString();
	}
	public boolean equals(Object o) {
		if(o == this)
			return true;
		if(o instanceof Map) {
			Map<?,?> map = (Map<?,?>) o;
			Iterator<?> itm = map.entrySet().iterator();
			Iterator<Entry<K, V>> itt = entrySet().iterator();
			while(itm.hasNext() && itt.hasNext()) {
				Object om = itm.next();
				Entry<K, V> e = itt.next();
				if(om.equals(e) == false)
					return false;
			}
			if(itm.hasNext() || itt.hasNext())
				return false;
		}
		return false;
	}
	public int hashCode() {
		int h = 0;
		Iterator<Entry<K, V>> it = entrySet().iterator();
		if(it.hasNext())
			h = it.next().hashCode();
		while(it.hasNext())
			h = h ^ it.next().hashCode();
		return h;
    }
}

题外话

这篇文章我并没有直接拿 HashMap 的源代码来和大家分析。不过事实上笔者的 HashMap 在一般的效用上来说与 JDK 的源码并没有太大的差距。
我们要去学习 HashMap 源码的话,很好的方式就是去亲手实现一个。可能一开始写得并不好,不过随着你不断去研究源码中的这些优点,你可以试着不断将这些优点囊括到你自己的代码中来,这样就会取得长足的进步。
此外,若是你能想到更好的源码中没有用到的优秀算法,那你的能力想必也是非常之高了。

MyHashMap 代码(需要直接复制)

如果有错误,也请在评论区指正。

import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.Consumer;

public class MyHashMap<K, V> extends AbstractMap<K, V> {
    static class Node<K, V> implements Map.Entry<K, V> {
        final int hash;
        final K key;
        V value;
        Node<K, V> next;
        Node(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }
        public final int getHash()       { return hash; }
        public final K getKey()          { return key; }
        public final V getValue()        { return value; }
        public final V setValue(V newValue) {
            V oldValue = value;
            value = newValue;
            return oldValue;
        }
        public final String toString() { return key + "=" + value; }
        public final int hashCode() {
            return hash(key) ^ hash(value);
        }
        public final boolean equals(Object o) {
            if (o == this)
                return true;
            if (o instanceof Map.Entry) {
                Map.Entry<?,?> e = (Map.Entry<?,?>)o;
                if (e.hashCode() == this.hashCode())
                    return true;
            }
            return false;
        }
    }
    static final int DEFAULT_INITIAL_CAPACITY = 16;
    static final int MAXIMUM_CAPACITY = 1 << 30;
    static final float DEFAULT_LOAD_FACTOR = 0.75f;
    static final int hash(Object key) {
        if(key == null)
            return 0;
        int h = key.hashCode();
        return h ^ (h >>> 16);
    }
    static final int tableSizeFor(int cap) {
        int n = cap - 1;
        n |= n >>> 1;
        n |= n >>> 2;
        n |= n >>> 4;
        n |= n >>> 8;
        n |= n >>> 16;
        return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
    }
    static final int getIndex(int hash, int n) {
        return hash & (n - 1);
    }
    private Node<K, V>[] table;
    private int size;
    private int threshold;
    private int modCount;
    private final float loadFactor;
    public MyHashMap(int initialCapacity, float loadFactor) {
        if (initialCapacity < 0)
            throw new IllegalArgumentException("Illegal initial capacity: " +
                                               initialCapacity);
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        if (loadFactor <= 0 || Float.isNaN(loadFactor))
            throw new IllegalArgumentException("Illegal load factor: " +
                                               loadFactor);
        this.loadFactor = loadFactor;
        this.threshold = tableSizeFor(initialCapacity);
    }
    public MyHashMap(int initialCapacity) {
        this(initialCapacity, DEFAULT_LOAD_FACTOR);
    }
    public MyHashMap() {
        this.loadFactor = DEFAULT_LOAD_FACTOR;
    }
    public V get(Object key) {
        Node<K, V> e = getNode(key);
        return e == null ? null : e.value;
    }
    private Node<K, V> getNode(Object key) {
        // hash()
        int h = hash(key);
        Node<K, V>[] tab = table;
        // 计算位置
        int n = tab.length;
        int i = getIndex(h, n);
        // 遍历该数组该位置桶中的所有结点
        Node<K, V> p = tab[i];
        if(key == null) {
            while(p != null) {
                if(key == p.getKey())
                    return p; // 找到
                p = p.next;
            }
        }
        else {
            while(p != null) {
                if(key.equals(p.getKey()))
                    return p; // 找到
                p = p.next;
            }
        }
        // 遍历所有没有找到
        return null;
    }
    private final Node<K, V>[] resize() {
        Node<K,V>[] oldTab = table;
        int oldCap = (oldTab == null) ? 0 : oldTab.length;
        int oldThr = threshold;
        int newCap, newThr;
        // 计算newCap,newThr
        if (oldCap > 0) { // 按照两倍进行扩容
            newCap = oldCap << 1;
            newThr = oldThr << 1;
            // 如果容量较小,则threshold不精确,重新计算
            if(newCap <= 64) {
                newThr = (int) (newCap * loadFactor);
            }
            // 如果到最大容量,则把扩容阈值设为最大
            if(newCap == MAXIMUM_CAPACITY)
                newThr = Integer.MAX_VALUE;
        }
        else { // oldCap为0,说明数组还未被初始化
            newCap = oldThr;
            int ft = (int) (newCap * loadFactor);
            newThr = ft < MAXIMUM_CAPACITY ? ft : Integer.MAX_VALUE;
        }
        // 创建新数组
        @SuppressWarnings("unchecked")
        Node<K,V>[] newTab = (Node<K, V>[]) new Node<?,?>[newCap];
        // 将旧数组结点添加入新数组
        if(oldCap != 0) {
            for(int t = 0; t < oldCap; t++) {
                Node<K, V> p = oldTab[t];
                while(p != null) {
                    int h = p.getHash();
                    int i = getIndex(h, newCap);
                    Node<K, V> next = p.next;
                    p.next = newTab[i];
                    newTab[i] = p;
                    p = next;
                }
            }
        }
        // 更新成员变量
        threshold = newThr;
        table = newTab;
        return newTab;
    }
    public V put(K key, V value) {
        Node<K, V>[] tab = table;
        V oldV = null;
        // 1、如果未初始化则初始化
        if(tab == null || tab.length == 0)
            tab = resize();
        // 2、二次hash
        int n = tab.length;
        int h = hash(key);
        // 3、计算位置
        int i = getIndex(h, n);
        // 4、遍历结点,更新或插入
        boolean updated = false; // 表示是否更新过结点
        Node<K, V> p = tab[i];   // 首节点
        if(p == null) {
            tab[i] = new Node<>(h, key, value);
        }
        else {
            if(key == null) {
                if(key == p.getKey()) { // 首节点
                    oldV = p.setValue(value);
                    updated = true;
                }
                if(! updated) { // 首节点未更新,遍历之后结点
                    while(p.next != null) {
                        if(key == p.next.getKey()) {
                            oldV = p.next.setValue(value);
                            updated = true;
                            break;
                        }
                        p = p.next;
                    }
                }
                if(! updated) { // 没有更新过,插入新节点
                    Node<K, V> node = new Node<>(h, key, value);
                    p.next = node;
                }
            }
            else { // key不是null的情况
                if(key.equals(p.getKey())) { // 首节点
                    oldV = p.setValue(value);
                    updated = true;
                }
                if(! updated) { // 首节点未更新,遍历之后结点
                    while(p.next != null) {
                        if(key.equals(p.next.getKey())) {
                            oldV = p.next.setValue(value);
                            updated = true;
                            break;
                        }
                        p = p.next;
                    }
                }
                if(! updated) { // 没有更新过,插入新节点
                    Node<K, V> node = new Node<>(h, key, value);
                    p.next = node;
                }
            }
        }
        // 5、更新modCount
        ++modCount;
        // 6、更新size
        if(! updated) // 没有更新,说明新增结点
            ++size;
        // 7、检查扩容
        if(size > threshold)
            resize();
        return oldV;
    }
    public V remove(Object key) {
        Node<K, V>[] tab = table;
        // 1、二次哈希
        int h = hash(key);
        // 2、计算位置
        int n = tab.length;
        int i = getIndex(h, n);
        // 3、遍历桶中结点
        Node<K, V> p = tab[i];
        if(p == null) // 桶为空
            return null;
        if(key == null) {
            // 首节点
            if(key == p.getKey()) {
                V value = p.getValue();
                tab[i] = p.next;
                return value;
            }
            // 之后的结点
            while(p.next != null) {
                if(key == p.next.getKey()) {
                    V value = p.next.getValue();
                    p.next = p.next.next;
                    return value;
                }
            }
        }
        else {
            // 首节点
            if(key.equals(p.getKey())) {
                V value = p.getValue();
                tab[i] = p.next;
                return value;
            }
            // 之后的结点
            while(p.next != null) {
                if(key.equals(p.next.getKey())) {
                    V value = p.next.getValue();
                    p.next = p.next.next;
                    return value;
                }
            }
        }
        return null;
    }
    public Set<Entry<K, V>> entrySet() {
        if(entrySet == null) {
            entrySet = new EntrySet();
        }
        return entrySet;
    }
    private Set<Map.Entry<K,V>> entrySet;
    class EntryIterator implements Iterator<Entry<K, V>> {
        Node<K, V> next;        // 下一个结点
        Node<K, V> current;     // 当前结点
        // 快速失败机制,如果和modCount不一样,则抛异常
        int expectedModCount;
        int index;              // 迭代到哪一个桶
        EntryIterator() { // 构造方法
            expectedModCount = modCount;
            current = next = null;
            // 找出第一个结点 
            index = 0;
            while((next = table[index]) == null) {
                index++;
            }
        }
        public Entry<K, V> next() {
            // 如果modCount与原先不同,说明已被修改,抛出异常
            if (modCount != expectedModCount)
                throw new ConcurrentModificationException();
            // 如果next为null,说明已经迭代到末尾,抛出异常
            Node<K, V> e = next;
            if (e == null)
                throw new NoSuchElementException();
            // 找出下一个结点
            if ((next = (current = e).next) == null)
                while(++index < table.length && (next = table[index]) == null) {}
            return e;
        }
        public final void remove() {
            Node<K,V> p = current;
            // 当前为null,说明已经remove过,无法再remove
            if (p == null)
                throw new IllegalStateException();
            // 如果modCount与原先不同,说明已被修改,抛出异常
            if (modCount != expectedModCount)
                throw new ConcurrentModificationException();
            current = null;
            K key = p.key;
            // 调用MyHashMap的remove方法来移除
            MyHashMap.this.remove(key);
            expectedModCount = modCount;
        }
        public final boolean hasNext() {
            return next != null;
        }
    }
    class EntrySet extends AbstractSet<Map.Entry<K,V>> {
        public final int size()                 { return size; }
        public final void clear()               { MyHashMap.this.clear(); }
        public final Iterator<Map.Entry<K,V>> iterator() {
            return new EntryIterator();
        }
        public final boolean contains(Object o) {
            if (!(o instanceof Map.Entry))
                return false;
            Map.Entry<?,?> e = (Map.Entry<?,?>) o;
            Object key = e.getKey();
            Node<K,V> candidate = getNode(key);
            return candidate != null && candidate.equals(e);
        }
        public final boolean remove(Object o) {
            if (o instanceof Map.Entry) {
                Map.Entry<?,?> e = (Map.Entry<?,?>) o;
                Object key = e.getKey();
                Node<K, V> candidate = getNode(key);
                if(candidate != null && candidate.equals(e)) {
                    MyHashMap.this.remove(key);
                    return true;
                }
                return false;
            }
            return false;
        }
        public final Spliterator<Map.Entry<K,V>> spliterator() {
            // 此方法不做讨论,省略
            return null;
        }
        public final void forEach(Consumer<? super Map.Entry<K,V>> action) {
            if (action == null)
                throw new NullPointerException();
            Iterator<Entry<K, V>> it = iterator();
            while(it.hasNext())
                action.accept(it.next());
        }
    }
}
发布了9 篇原创文章 · 获赞 123 · 访问量 4546

猜你喜欢

转载自blog.csdn.net/weixin_44051223/article/details/104100171
今日推荐