GCC (Graph Contrastive Clustering) paper code reproduction


Foreword

My learning from now on will focus mainly on deep clustering. Contrastive learning has also been very popular in recent years, so it will come up as well. The Graph Contrastive Clustering (GCC) method covered in this article was published at ICCV 2021, and as its name suggests, it combines graphs with contrastive learning. However, the "graph" here only describes the neighbor relationships between samples; no graph neural network is involved. In the code, it amounts to using a MemoryBank to maintain a KNN (k-nearest-neighbor) matrix, so readers without any background in graph neural networks can follow along without worry.
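To make the "graph" concrete: once the k nearest neighbors of every sample are known, the graph is just an adjacency matrix with A[i, j] = 1 when j is one of i's neighbors. The NumPy sketch below is my own illustration, not code from the GCC repository; `knn_adjacency` is a hypothetical helper, and whether GCC symmetrizes its graph is an assumption left as an option.

```python
import numpy as np

def knn_adjacency(neighbors: np.ndarray, symmetric: bool = False) -> np.ndarray:
    """Build a dense 0/1 adjacency matrix from an (n, k) array of neighbor indices.

    Hypothetical helper for illustration: row i marks the k neighbors of sample i.
    """
    n, k = neighbors.shape
    adj = np.zeros((n, n), dtype=np.float32)
    rows = np.repeat(np.arange(n), k)      # row index i repeated once per neighbor
    adj[rows, neighbors.ravel()] = 1.0     # edge i -> j for every neighbor j of i
    if symmetric:
        # Optional (assumption): treat "j is a neighbor of i" as an undirected edge
        adj = np.maximum(adj, adj.T)
    return adj

# Usage: 4 samples, k = 1 neighbor each
nbrs = np.array([[1], [0], [3], [0]])
print(knn_adjacency(nbrs))
```

The dense form is only meant for small demos: for 50,000 samples an n × n float32 matrix already needs 10 GB, so in practice one would keep the (n, k) index array itself.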

1. Graph Contrastive Clustering

[Figure: overall framework of Graph Contrastive Clustering]
The overall structure of the algorithm is quite clear: a CNN backbone first extracts features, an MLP then maps each sample to a representation, and the network is trained by optimizing the RGC (representation graph contrastive) and AGC (assignment graph contrastive) losses. In the code, the "Updated" step corresponds to finding the 5 nearest neighbors of each sample; these neighbors are then used to compute RGC and AGC, which is the "Guided" part of the figure. The blue arrows point to separate, more detailed illustrations of the RGC and AGC parts. RGC plays a role similar to the Instance Head in the AAAI 2021 paper Contrastive Clustering, and AGC is similar to its Cluster Head. The following figure shows the structure of Contrastive Clustering; the two designs are indeed very similar:
[Figure: framework of Contrastive Clustering (AAAI 2021), with its Instance Head and Cluster Head]
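Since RGC and AGC are compared with the instance-level and cluster-level heads of Contrastive Clustering, a simplified NumPy sketch of those two kinds of contrastive loss may help. This is my own illustration and not the exact GCC loss: `info_nce` contrasts each row of `anchors` against all rows of `candidates`, with the positive chosen by `pos_idx` (an augmented view in Contrastive Clustering; a graph neighbor is the kind of positive that GCC's neighbor "guidance" adds). `cluster_contrastive` applies the same loss to the columns of two soft-assignment matrices and subtracts a cluster-size entropy term, computed on the first view only for brevity. Function names and temperatures are assumptions.

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)

def info_nce(anchors, candidates, pos_idx, tau=0.5):
    """Simplified one-directional InfoNCE: row i of `anchors` should match row
    pos_idx[i] of `candidates`; every other candidate row acts as a negative."""
    a, c = l2_normalize(anchors), l2_normalize(candidates)
    logits = a @ c.T / tau                               # cosine similarity / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(len(a)), pos_idx].mean())

def cluster_contrastive(p1, p2, tau=1.0):
    """Cluster-level variant: the K columns of the (n, K) soft-assignment matrices
    are the 'cluster representations'; cluster k of view 1 should match cluster k of view 2."""
    k = p1.shape[1]
    loss = info_nce(p1.T, p2.T, np.arange(k), tau)
    # Entropy of the cluster-size distribution; subtracting it discourages
    # putting all samples into a single cluster
    sizes = p1.sum(axis=0) / p1.sum()
    entropy = float(-(sizes * np.log(sizes + 1e-12)).sum())
    return loss - entropy

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))
z_aug = z + 0.05 * rng.normal(size=z.shape)          # a slightly perturbed "second view"
print(info_nce(z, z_aug, np.arange(8)))              # correct positives: smaller loss
print(info_nce(z, z_aug, np.roll(np.arange(8), 1)))  # wrong positives: larger loss
```

The only difference between the two levels is which axis is treated as the set of "items" being contrasted: rows (samples) for the instance level, columns (clusters) for the cluster level.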

2. Code reproduction

1. Precautions

The author's experimental environment: Ubuntu, a GeForce RTX 3090 GPU, and 64 GB of RAM.
The GCC project code runs only on Linux, not on Windows, because it depends on faiss, an open-source library from Facebook that (at the time of writing) only supported Linux. faiss provides GPU-accelerated similarity search and is used here to find the K nearest neighbors quickly. Because training involves a MemoryBank, it consumes a lot of resources (CPU time and memory). The project also does not support resuming from a checkpoint: the graph structure it maintains lives largely in memory, so once training is interrupted, that state cannot be restored and training has to start over. Since a full training run generally takes more than 24 hours, this is a real headache.
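If you only need the neighbor search to work where faiss is unavailable (for example, to debug the pipeline on Windows), one option is to try faiss and fall back to brute-force NumPy when it cannot be imported. The sketch below is my own workaround, not part of the GCC repository: `mine_knn` is a hypothetical name, it L2-normalizes the features so that inner product equals cosine similarity, and it uses the CPU index `faiss.IndexFlatIP` only (moving the index to the GPU with `faiss.index_cpu_to_gpu` is left out). The fallback builds a dense n × n similarity matrix, so it is only practical for small feature sets.

```python
import numpy as np

def mine_knn(features: np.ndarray, k: int = 5) -> np.ndarray:
    """Return an (n, k) array holding, for each sample, the indices of its k most
    similar other samples (cosine similarity). Hypothetical helper, not GCC code."""
    x = np.asarray(features, dtype=np.float32)
    # Normalize rows so inner product == cosine similarity (new array, input untouched)
    x = np.ascontiguousarray(x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12))
    try:
        import faiss  # may not be installable on every platform
        index = faiss.IndexFlatIP(x.shape[1])  # exact inner-product search on the CPU
        index.add(x)
        _, idx = index.search(x, k + 1)        # k + 1: each sample also finds itself
    except ImportError:
        sims = x @ x.T                          # dense (n, n) similarity matrix
        idx = np.argsort(-sims, axis=1)[:, : k + 1]
    # Drop column 0, which is the sample itself (assumes no duplicate feature vectors)
    return idx[:, 1:]

feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
print(mine_knn(feats, k=1).ravel())  # [1 0 3 2]
```

The default k = 5 matches the five neighbors per sample described above; both code paths return the same indices, so results can be compared across machines.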

Next, we move on to reproducing the code. Running the original project code directly causes some errors, so we made the following corrections.

2. utils/mypath.py

class MyPath(object):
    @staticmethod
    def db_root_dir(database=''):
        db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200', 'tiny_imagenet'}
        assert(database in db_names)

        if database == 'cifar-10':
            return 'gruntdata/dataset'

        elif database == 'cifar-20':
            return 'gruntdata/dataset'

        elif database == 'stl-10':
            return 'gruntdata/dataset'

        elif database == 'tiny_imagenet':
            return 'gruntdata/dataset'

        elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']:
            return 'path/to/imagenet/'

        else:
            raise NotImplementedError

gruntdata/dataset is the directory where the training datasets are stored. In the original code this path contains an extra "/", so it cannot be resolved correctly; simply remove it so the path reads gruntdata/dataset, as shown above.
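Assuming the extra "/" is a leading one, the difference is that "/gruntdata/dataset" is an absolute path looked up from the filesystem root, while "gruntdata/dataset" is resolved relative to the directory you launch the training script from. The standard-library sketch below shows this and fails early if the dataset root is missing; `check_db_root` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

def check_db_root(root: str) -> Path:
    """Resolve a dataset root the way open() would and fail early if it is missing."""
    path = Path(root)
    resolved = path if path.is_absolute() else Path.cwd() / path
    if not resolved.is_dir():
        raise FileNotFoundError(f"dataset root not found: {resolved}")
    return resolved

print(Path("/gruntdata/dataset").is_absolute())  # True: looked up from "/"
print(Path("gruntdata/dataset").is_absolute())   # False: relative to the current directory
```

Calling `check_db_root(MyPath.db_root_dir('cifar-10'))` before training would surface a wrong path in seconds rather than after the data loader starts.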

3. utils/collate.py

import torch
import numpy as np
import collections
# from torch._six import string_classes, int_classes
string_classes = str  # stands in for torch._six.string_classes
int_classes = int     # stands in for torch._six.int_classes

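Depending on your PyTorch version, the import on line 4 (from torch._six import string_classes, int_classes) may fail, because newer PyTorch releases no longer provide these names. In that case, comment it out and add the last two lines of the code above. If this import works in your environment, leave the file unchanged.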
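Instead of editing the file by hand for each environment, a version-tolerant variant keeps the original import and falls back only when it fails. This is my own suggestion, not code from the GCC repository; the fallback values are the same two lines used above.

```python
# Version-tolerant replacement for the torch._six import in utils/collate.py.
# Depending on the PyTorch version, torch._six or one of these names may be missing.
try:
    from torch._six import string_classes, int_classes
except ImportError:  # also covers torch itself missing (ModuleNotFoundError)
    string_classes = str
    int_classes = int

print(isinstance("cifar-10", string_classes), isinstance(5, int_classes))
```

The rest of collate.py only uses these names in isinstance checks, so either branch leaves its behavior for str and int inputs unchanged.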

4. data/datasets_imagenet_dogs.py

import torchvision
from PIL import Image
import numpy as np
from skimage import io
from torch.utils.data.dataset import Dataset


class ImageNetDogs(Dataset):
    base_folder = 'imagenet-dogs'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNetdogs.h5', '918c2871b30a85fa023e0c44e0bee87f'],
        ['ImageNetdogsAll.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]

    splits = ('train', 'test', 'train+unlabeled')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))

        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test/unlabeled set
        
        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'image': image, 'target': index of target class, 'meta': dict}
        """
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        img = Image.fromarray(np.uint8(img)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        out = {'image': img, 'target': target,
               'meta': {'im_size': img_size, 'index': index, 'class_name': 'unlabeled'}}

        return out

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas,labels = [],[]
        source_dataset = torchvision.datasets.ImageFolder(root='gruntdata/dataset/ImageNet-dogs/train')

        for line,target in zip(source_dataset.imgs,source_dataset.targets):
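            try:
                img = io.imread(line[0])
            except Exception:
                # Skip files that cannot be decoded (e.g. corrupt images);
                # uncomment the print below to list them
                # print(line[0])
                continue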
            else:
                datas.append(img)
                labels.append(target)

        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)

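The original author did not explain in detail how to train on the ImageNet-dogs and ImageNet-10 datasets, so I rewrote the ImageNetDogs class. It only works if the ImageNet-dogs dataset (15 classes) has already been placed under gruntdata/dataset, i.e. at gruntdata/dataset/ImageNet-dogs/train with one subfolder per class, which is the layout torchvision's ImageFolder expects. Note that the split argument is only validated: every split loads the same train folder, and all images are read into memory when the dataset is constructed.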
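Because both loaders build on torchvision's ImageFolder, class indices are assigned in sorted order of the class subfolder names, and image files are collected recursively inside each class folder. The standard-library sketch below checks such a layout before committing to a 24-hour run; `summarize_image_folder` is a hypothetical helper, and the extension list is an assumption covering common formats rather than torchvision's exact list.

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")  # assumption

def summarize_image_folder(root: str) -> dict:
    """Map class_name -> (class_index, number_of_image_files) for an ImageFolder-style tree."""
    if not os.path.isdir(root):
        raise FileNotFoundError(f"{root} does not exist; check the dataset location")
    # Class folders sorted by name, mirroring how ImageFolder assigns class indices
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    summary = {}
    for idx, name in enumerate(classes):
        count = 0
        for _, _, files in os.walk(os.path.join(root, name)):  # recurse into subfolders
            count += sum(f.lower().endswith(IMAGE_EXTS) for f in files)
        summary[name] = (idx, count)
    return summary

if __name__ == "__main__":
    root = "gruntdata/dataset/ImageNet-dogs/train"  # path used by the ImageNetDogs loader
    if os.path.isdir(root):
        for cls, (idx, n) in summarize_image_folder(root).items():
            print(f"{idx:3d}  {cls:30s} {n} images")
    else:
        print(f"{root} not found; place the dataset there first")
```

For ImageNet-dogs you would expect 15 entries; a class with 0 images, or an unexpected extra folder, usually means the dataset was extracted one level too deep or too shallow.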

5. data/datasets_imagenet10.py

import torchvision
from PIL import Image
import os
import os.path
import numpy as np
from skimage import io, color

import torchvision.datasets as datasets

from torch.utils.data.dataset import Dataset

class ImageNet10(Dataset):
    """ImageNet-10 dataset read from an ImageFolder-style directory.

    Images are loaded from gruntdata/dataset/ImageNet-10/train/ (one subfolder
    per class) and kept in memory as NumPy arrays.

    Args:
        split (string): One of {'train', 'test'}. It is only validated; both
            splits read the same train folder.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): Unused; the dataset must already be on disk.

    """
    base_folder = 'imagenet-10'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNet10_112.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]

    splits = ('train', 'test')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))

        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test/unlabeled set

        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'image': image, 'target': index of target class, 'meta': dict}
        """
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        img = Image.fromarray(np.uint8(img)).convert('RGB')
        # class_name = self.classes[target]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        out = {'image': img, 'target': target,
               'meta': {'im_size': img_size, 'index': index, 'class_name': 'unlabeled'}}

        return out

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas,labels = [],[]
        source_dataset = torchvision.datasets.ImageFolder(root='gruntdata/dataset/ImageNet-10/train/')

        for line,tar in zip(source_dataset.imgs,source_dataset.targets):
            try:
                img = io.imread(line[0])
                # img = color.gray2rgb(img)
            except Exception:
                # Report and skip files that cannot be decoded
                print(line[0])
                continue
            else:
                datas.append(img)
                labels.append(tar)

        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)

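The code for reading ImageNet-10 follows the same pattern as the ImageNet-dogs loader. The only differences are the directory it reads (gruntdata/dataset/ImageNet-10/train/) and that it prints the path of each unreadable image instead of skipping it silently.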


Summary

With these changes, the project code should run smoothly. If you have any questions, feel free to discuss them in the comments.

References

A curated GitHub list of deep clustering papers and code
Contrastive Clustering (AAAI 2021), original paper
Graph Contrastive Clustering (ICCV 2021), original paper


Original post: blog.csdn.net/weixin_43594279/article/details/125071968