Intersection, Union, and Difference of Java Collections

Set Operations

retainAll

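While writing code recently, I needed to perform set operations on collections: intersection, union, and difference.
For union, my first attempt looked like this: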

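    // Assumes JUnit 4, Guava and AssertJ on the test classpath
    import static org.assertj.core.api.Assertions.assertThat;

    import com.google.common.collect.Lists;
    import java.util.List;
    import org.junit.Test;

    @Test
    public void should_get_union_with_removeAll_and_addAll() {
        List<Long> result = Lists.newArrayList();
        List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
        List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
        result.addAll(s1);     // [1, 2, 3]
        result.removeAll(s2);  // remove the elements shared with s2
        result.addAll(s2);     // append s2
        // passes: isSubsetOf only checks that every element of result is in {1, 2, 3, 7}
        assertThat(result).isSubsetOf(1L, 2L, 3L, 7L);
    }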

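This uses List's removeAll() and addAll(): first remove from the first collection every element it shares with the second, then append the second collection. Here is the source of removeAll() in ArrayList (JDK 8):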

    public boolean removeAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return batchRemove(c, false);
    }

    // ...
    
    public boolean retainAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return batchRemove(c, true);
    }

    private boolean batchRemove(Collection<?> c, boolean complement) {
        final Object[] elementData = this.elementData;
        int r = 0, w = 0;
        boolean modified = false;
        try {
            for (; r < size; r++)
                if (c.contains(elementData[r]) == complement)
                    elementData[w++] = elementData[r];
        } finally {
            // Preserve behavioral compatibility with AbstractCollection,
            // even if c.contains() throws.
            if (r != size) {
                System.arraycopy(elementData, r,
                                 elementData, w,
                                 size - r);
                w += size - r;
            }
            if (w != size) {
                // clear to let GC do its work
                for (int i = w; i < size; i++)
                    elementData[i] = null;
                modCount += size - w;
                size = w;
                modified = true;
            }
        }
        return modified;
    }

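In other words, removeAll() and retainAll() share one implementation, batchRemove(), which makes a single linear pass over the list. However, every step of that pass calls c.contains(), and when c is itself a List, contains() is another linear scan. The total cost is therefore O(n·m), which is quadratic for inputs of similar size.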
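The quadratic cost comes from which contains() gets called, not from batchRemove() itself. A minimal sketch, assuming only java.util (the class and method names are hypothetical): wrapping the argument in a HashSet makes each contains() an expected O(1) lookup, so removeAll()/retainAll() on an ArrayList become a single linear pass. Duplicates in the list being modified are still kept.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

class HashLookupListOps {
    // Hypothetical helper: elements of a that do not occur in b (duplicates in a are kept).
    static <T> List<T> difference(List<T> a, List<T> b) {
        List<T> result = new ArrayList<>(a);
        // HashSet.contains is expected O(1), so batchRemove makes one O(n) pass
        // instead of scanning b once for every element of a.
        result.removeAll(new HashSet<>(b));
        return result;
    }

    // Hypothetical helper: elements of a that also occur in b (duplicates in a are kept).
    static <T> List<T> intersection(List<T> a, List<T> b) {
        List<T> result = new ArrayList<>(a);
        result.retainAll(new HashSet<>(b));
        return result;
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 7L);
        System.out.println(difference(s1, s2));   // [1]
        System.out.println(intersection(s1, s2)); // [2, 3, 3]
    }
}
```

Building the HashSet costs O(m) extra time and memory, which is usually a good trade against O(n·m) comparisons.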

Similarly, here is the retainAll() that HashSet uses; HashSet inherits it from AbstractCollection.java:

    // in AbstractCollection
    public boolean retainAll(Collection<?> c) {
        Objects.requireNonNull(c);
        boolean modified = false;
        Iterator<E> it = iterator();
        while (it.hasNext()) {
            if (!c.contains(it.next())) {
                it.remove();
                modified = true;
            }
        }
        return modified;
    }

So it is natural to build union, intersection, and difference on top of these methods:

    // Union
    public Set<Integer> getUnion(Set<Integer> set1, Set<Integer> set2) {
        Set<Integer> result = new HashSet<>();
        result.addAll(set1);
        result.removeAll(set2);
        result.addAll(set2);
        return result;
    }

    // Intersection
    public Set<Integer> getIntersection(Set<Integer> set1, Set<Integer> set2) {
        Set<Integer> result = new HashSet<>();
        result.addAll(set1);
        // keep only the elements that also appear in set2
        result.retainAll(set2);
        return result;
    }

    // Difference
    public Set<Integer> getSubtraction(Set<Integer> set1, Set<Integer> set2) {
        Set<Integer> result = new HashSet<>();
        result.addAll(set1);
        result.removeAll(set2);
        return result;
    }

To spell it out:
s1.retainAll(s2): keeps only the elements of s1 that also appear in s2.
s1.removeAll(s2): removes from s1 every element that appears in s2.
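To make the two definitions concrete, here is a small self-contained run of the three Set operations on sample data (the sample values and the class name are illustrative only). Results are copied into a TreeSet before printing so the order is deterministic. For Sets, the removeAll() step in getUnion is not needed: addAll() already skips elements the set contains.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;

class SetOpsDemo {
    static Set<Integer> union(Set<Integer> a, Set<Integer> b) {
        Set<Integer> result = new HashSet<>(a);
        result.addAll(b);          // a Set silently skips elements it already holds
        return result;
    }

    static Set<Integer> intersection(Set<Integer> a, Set<Integer> b) {
        Set<Integer> result = new HashSet<>(a);
        result.retainAll(b);       // keep only elements that also appear in b
        return result;
    }

    static Set<Integer> difference(Set<Integer> a, Set<Integer> b) {
        Set<Integer> result = new HashSet<>(a);
        result.removeAll(b);       // drop every element that appears in b
        return result;
    }

    public static void main(String[] args) {
        Set<Integer> s1 = new HashSet<>(Arrays.asList(1, 2, 3));
        Set<Integer> s2 = new HashSet<>(Arrays.asList(2, 3, 7));
        System.out.println(new TreeSet<>(union(s1, s2)));        // [1, 2, 3, 7]
        System.out.println(new TreeSet<>(intersection(s1, s2))); // [2, 3]
        System.out.println(new TreeSet<>(difference(s1, s2)));   // [1]
    }
}
```

Each method copies its first argument, so s1 and s2 are left unchanged.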

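Both methods are declared in the Collection interface, so every collection class has them. Using them on Lists, however, has a pitfall: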

    @Test
    public void should_get_union_with_removeAll_and_addAll() {
        List<Long> result = Lists.newArrayList();
        List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
        List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);
        result.addAll(s1);
        result.removeAll(s2);
        result.addAll(s2);
        assertThat(result).hasSize(4); // fails: the result contains 3 twice
    }

Here the result is [1, 2, 3, 3, 7]: duplicate elements in the original collections are kept.
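Tracing the List version step by step shows where the extra 3 comes from. A sketch using only java.util instead of Guava (the class and method names are illustrative); the step comments describe the sample data in main:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ListUnionTrace {
    static List<Long> removeAllThenAddAll(List<Long> s1, List<Long> s2) {
        List<Long> result = new ArrayList<>(s1); // [1, 2, 3]
        result.removeAll(s2);                    // [1]: every element found in s2 is removed
        result.addAll(s2);                       // [1, 2, 3, 3, 7]: s2 is appended as-is, duplicates included
        return result;
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 3L, 7L);
        System.out.println(removeAllThenAddAll(s1, s2)); // [1, 2, 3, 3, 7]
    }
}
```

The removeAll() step only guards against elements shared by s1 and s2; duplicates inside s2 pass straight through addAll(), and so would duplicates in s1 that do not occur in s2.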

Apache

So I went looking for a utility library. Apache's commons-collections4 provides set operations; add the dependency:

        <!-- https://mvnrepository.com/artifact/org.apache.commons/commons-collections4 -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-collections4</artifactId>
            <version>4.2</version>
        </dependency>

The code:

    @Test
    public void should_get_union_with_apache_union() {
        List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
        List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);

        List<Long> result = (List<Long>) CollectionUtils.union(s1, s2);
        assertThat(result).hasSize(4); // fails: the result contains 3 twice
    }

The test still fails. Here is the source:


    public static <O> Collection<O> union(final Iterable<? extends O> a, final Iterable<? extends O> b) {
        final SetOperationCardinalityHelper<O> helper = new SetOperationCardinalityHelper<>(a, b);
        for (final O obj : helper) {
            helper.setCardinality(obj, helper.max(obj));
        }
        return helper.list();
    }


    private static class SetOperationCardinalityHelper<O> extends CardinalityHelper<O> implements Iterable<O> {

        // ...

        /**
         * Add the object {@code count} times to the result collection.
         * @param obj  the object to add
         * @param count  the count
         */
        public void setCardinality(final O obj, final int count) {
            for (int i = 0; i < count; i++) {
                newList.add(obj);
            }
        }
        // ...
    }

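It still does not remove duplicates. Its complexity is not n^2, though: CardinalityHelper keeps the per-element counts in hash maps, so union() runs in roughly linear time, which matches the benchmark below.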
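The source shows that union() works on cardinalities: each distinct element is emitted max(count in a, count in b) times. 3 occurs once in s1 and twice in s2, so it appears twice in the result. Below is a plain-Java sketch of that rule, a hypothetical helper and not the Commons Collections code itself, using a HashMap for the counts. It walks the elements in first-seen order, so its output order may differ from the library's.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class CardinalityUnion {
    // Count how many times each element occurs.
    static <T> Map<T, Integer> counts(List<T> list) {
        Map<T, Integer> map = new HashMap<>();
        for (T t : list) {
            map.merge(t, 1, Integer::sum);
        }
        return map;
    }

    // Multiset union: each distinct element appears max(countA, countB) times.
    static <T> List<T> union(List<T> a, List<T> b) {
        Map<T, Integer> ca = counts(a);
        Map<T, Integer> cb = counts(b);
        Set<T> distinct = new LinkedHashSet<>(a); // first-seen order, for readable output
        distinct.addAll(b);
        List<T> result = new ArrayList<>();
        for (T t : distinct) {
            int n = Math.max(ca.getOrDefault(t, 0), cb.getOrDefault(t, 0));
            for (int i = 0; i < n; i++) {
                result.add(t);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 3L, 7L);
        System.out.println(union(s1, s2)); // [1, 2, 3, 3, 7]
    }
}
```

Every map lookup is expected O(1), which is why this kind of union scales linearly rather than quadratically.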

That is easy to fix: Lists and Sets convert into each other, so the uniqueness of Set elements takes care of the duplicates:

    @Test
    public void should_get_union_with_apache_union_on_sets() {
        List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
        List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);

        // converting to Sets removes duplicates before union() runs
        Set<Long> set1 = new HashSet<>(s1);
        Set<Long> set2 = new HashSet<>(s2);
        List<Long> result = (List<Long>) CollectionUtils.union(set1, set2);
        assertThat(result).hasSize(4);
    }

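The cost is extra allocation: two intermediate input Sets in addition to the output collection, roughly twice the objects created before.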

With that understood, I wrote my own version:

    @Test
    public void should_get_union_with_self_union() {
        List<Long> s1 = Lists.newArrayList(1L, 2L, 3L);
        List<Long> s2 = Lists.newArrayList(2L, 3L, 3L, 7L);

        List<Long> result = getSelfUnion(s1, s2);
        assertThat(result).hasSize(4);
    }

    public List<Long> getSelfUnion(List<Long> s1, List<Long> s2) {
        // initial capacity sized to the combined input
        Set<Long> result = new HashSet<>(s1.size() + s2.size());
        for (Long s : s1) {
            result.add(s);
        }
        for (Long m : s2) {
            result.add(m);
        }
        return Lists.newArrayList(result);
    }

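This meets the original requirement, and the complexity is linear (expected O(n + m), since each HashSet.add() is expected O(1)).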
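The same HashSet idea extends to the other two operations from the start of the post. A sketch under the same assumptions as getSelfUnion (results are deduplicated and unordered); getSelfIntersection and getSelfSubtraction are hypothetical names chosen to match, and only java.util is used:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class SelfSetOps {
    // Distinct elements present in both lists; expected O(n + m).
    static List<Long> getSelfIntersection(List<Long> s1, List<Long> s2) {
        Set<Long> lookup = new HashSet<>(s2);
        Set<Long> result = new HashSet<>();
        for (Long x : s1) {
            if (lookup.contains(x)) {
                result.add(x);
            }
        }
        return new ArrayList<>(result);
    }

    // Distinct elements of s1 that are absent from s2; expected O(n + m).
    static List<Long> getSelfSubtraction(List<Long> s1, List<Long> s2) {
        Set<Long> lookup = new HashSet<>(s2);
        Set<Long> result = new HashSet<>();
        for (Long x : s1) {
            if (!lookup.contains(x)) {
                result.add(x);
            }
        }
        return new ArrayList<>(result);
    }

    public static void main(String[] args) {
        List<Long> s1 = Arrays.asList(1L, 2L, 3L, 3L);
        List<Long> s2 = Arrays.asList(2L, 3L, 3L, 7L);
        System.out.println(getSelfIntersection(s1, s2).size()); // 2 (the values 2 and 3)
        System.out.println(getSelfSubtraction(s1, s2));          // [1]
    }
}
```

If callers need a stable order, swapping the result HashSet for a LinkedHashSet keeps the first-seen order of s1 at a small memory cost.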

Performance Test

That makes three ways to compute a union. I compared them with the following test:

    // longs1 and longs2 are List<Long> test fixtures of the test class
    @Test
    public void testTime() {
        final int runs = 20;
        long cost = 0;
        for (int i = 0; i < runs; i++) {
            long start = System.currentTimeMillis();
            // Approach 1: hand-written HashSet union
            List<Long> result = getSelfUnion(longs1, longs2);
            // Approach 2: Apache Commons CollectionUtils.union
//            List<Long> result2 = (List<Long>) CollectionUtils.union(longs1, longs2);
            // Approach 3: removeAll and addAll
//            List<Long> result3 = getUnion(longs1, longs2);
            long end = System.currentTimeMillis();
            cost = cost + (end - start);
        }
        // divide by the number of runs the loop actually performed
        System.out.println("longs1: " + longs1.size() + ", longs2: " + longs2.size() + ", average cost: " + cost / runs + "ms");
    }

Results:
Approach 1 (getSelfUnion): longs1: 36900, longs2: 16035, average cost: 7ms
Approach 2 (CollectionUtils.union): longs1: 36900, longs2: 16035, average cost: 17ms
Approach 3 (removeAll and addAll): longs1: 36900, longs2: 16035, average cost: 3779ms

So when the data volume is large, be careful about implementing set operations with removeAll()/retainAll() on Lists.
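To reproduce a comparison like this a little more carefully, here is a minimal timing sketch, assuming only java.util (it is not a JMH benchmark, and the data generator, sizes, and names are hypothetical). It warms the code up before measuring, uses System.nanoTime(), and averages over exactly the number of timed runs.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.function.BinaryOperator;

class UnionBenchmark {
    // Hypothetical test data: n random longs in [0, bound), fixed seed for repeatability.
    static List<Long> randomLongs(int n, int bound, long seed) {
        Random random = new Random(seed);
        List<Long> list = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            list.add((long) random.nextInt(bound));
        }
        return list;
    }

    // Hash-based union, same idea as getSelfUnion.
    static List<Long> hashUnion(List<Long> a, List<Long> b) {
        HashSet<Long> set = new HashSet<>(a);
        set.addAll(b);
        return new ArrayList<>(set);
    }

    // removeAll/addAll union on ArrayLists (quadratic, keeps duplicates).
    static List<Long> removeAllUnion(List<Long> a, List<Long> b) {
        List<Long> result = new ArrayList<>(a);
        result.removeAll(b);
        result.addAll(b);
        return result;
    }

    // Average wall-clock time per call in milliseconds, after some warm-up calls.
    static double averageMillis(BinaryOperator<List<Long>> op, List<Long> a, List<Long> b,
                                int warmup, int runs) {
        for (int i = 0; i < warmup; i++) {
            op.apply(a, b);
        }
        long total = 0;
        for (int i = 0; i < runs; i++) {
            long start = System.nanoTime();
            op.apply(a, b);
            total += System.nanoTime() - start;
        }
        return total / (double) runs / 1_000_000.0; // divide by the same count the loop ran
    }

    public static void main(String[] args) {
        // Sizes are illustrative; larger inputs make the quadratic version very slow.
        List<Long> a = randomLongs(10_000, 50_000, 1L);
        List<Long> b = randomLongs(5_000, 50_000, 2L);
        System.out.printf("hashUnion:      %.2f ms%n", averageMillis(UnionBenchmark::hashUnion, a, b, 5, 20));
        System.out.printf("removeAllUnion: %.2f ms%n", averageMillis(UnionBenchmark::removeAllUnion, a, b, 1, 5));
    }
}
```

Absolute numbers depend on the machine and JVM; for serious measurements a harness such as JMH handles warm-up, dead-code elimination, and forking more rigorously.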


Reposted from blog.csdn.net/hustzw07/article/details/82053517