Breast cancer image segmentation using U-Net

1. Data preparation

First, we need to prepare the data. In this example, the dataset contains two kinds of images: original images of breast cancer tissue and the corresponding mask images, which serve as the segmentation labels. We need to preprocess these images so they can be fed into our neural network model.

1.1 Data preprocessing

We use the Albumentations library for data augmentation, performing operations such as rotation, flipping, and brightness/contrast adjustment on the images. We also split the dataset into training, validation, and test sets. Note that Albumentations needs to be installed through pip:

pip install albumentations

The code is as follows:

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Rotate(limit=30, p=0.3),
    A.RandomResizedCrop(height=256, width=256, scale=(0.8, 1.0), p=0.2),
    A.Normalize(),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Normalize(),
    ToTensorV2()
])

test_transform = A.Compose([
    A.Normalize(),
    ToTensorV2()
])
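
With the transforms defined, we can build the datasets and loaders. A minimal sketch, assuming a hypothetical directory layout (data/train/images, data/train/masks, and so on), a batch size of 8, and the BreastCancerSegmentationDataset class defined in the next subsection:

# Hypothetical directory layout and batch size; adjust to your data
train_dataset = BreastCancerSegmentationDataset(
    "data/train/images", "data/train/masks", transform=train_transform)
val_dataset = BreastCancerSegmentationDataset(
    "data/val/images", "data/val/masks", transform=val_transform)
test_dataset = BreastCancerSegmentationDataset(
    "data/test/images", "data/test/masks", transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)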

1.2 Custom dataset

We create a custom dataset class that reads images from disk, applies the preprocessing operations, and returns each image together with its corresponding mask.

The code is as follows:

import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class BreastCancerSegmentationDataset(Dataset):
    """
    乳腺癌分割数据集
    """
    def __init__(self, img_dir, mask_dir, transform=None, one_hot_encode=True, target_size=(256, 256)):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.one_hot_encode = one_hot_encode
        self.target_size = target_size
        self.img_filenames = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, index):
        # Build the mask path from the image filename, following the file layout, to pair each original image with its mask
        img_name = self.img_filenames[index]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name[:-4] + '_mask'+img_name[-4:])
        # Skip .ipynb_checkpoints files
        if img_path.endswith(".ipynb_checkpoints") or mask_path.endswith(".ipynb_checkpoints"):
            return self.__getitem__((index + 1) % len(self))
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        # Sanity check: make sure both files were actually read from disk
        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Mask not found at {mask_path}")

        # Force the image and mask to the same target size
        image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        # One-hot encode the mask
        if self.one_hot_encode:
            mask = one_hot_encode(mask, num_classes=3)

        # Apply the transform (data augmentation) if one is provided
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Convert the mask to a float tensor
        mask = np.asarray(mask)           # convert to a NumPy array
        mask = mask.astype(np.float32)    # cast dtype to float32
        mask = torch.from_numpy(mask)     # convert to a Tensor
        
        return image, mask
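
Note that __getitem__ calls a one_hot_encode helper that is not shown above. A minimal sketch of such a helper, assuming the grayscale mask stores class IDs as pixel values 0 to num_classes - 1 (the actual encoding of your masks may differ):

def one_hot_encode(mask, num_classes=3):
    # Turn an (H, W) array of integer class IDs into an (H, W, num_classes)
    # one-hot array; assumes pixel values are already 0..num_classes-1
    one_hot = np.zeros((*mask.shape, num_classes), dtype=np.uint8)
    for c in range(num_classes):
        one_hot[..., c] = (mask == c).astype(np.uint8)
    return one_hot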

2. Build U-Net model

We use torchvision's FCN-ResNet50 as our base model; as the docstring below notes, this is a fully convolutional network (FCN) with a ResNet50 backbone rather than a literal U-Net. We create a class called UNetTrainer to handle the training, evaluation, and testing process.

The code is as follows:

import torch
import torch.nn as nn
from torchvision.models.segmentation import fcn_resnet50

class UNetTrainer:
    """In fact we use a fully convolutional network (FCN) with a ResNet50 backbone, not a true U-Net."""
    def __init__(self, num_classes=3, lr=1e-4):
        # fcn_resnet50 pairs a ResNet50 encoder with an FCN head. Initializing
        # with pretrained weights can speed up training and improve final
        # performance; here the model is trained from scratch (pretrained=False)
        self.model = fcn_resnet50(pretrained=False, num_classes=num_classes)
        # Use cross-entropy as the loss function
        self.criterion = nn.CrossEntropyLoss()
        # Use the Adam optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # Train on the GPU if one is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
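
A minimal usage sketch, assuming the loaders built earlier and the train method defined in the next section:

trainer = UNetTrainer(num_classes=3, lr=1e-4)
trainer.train(num_epochs=20, train_loader=train_loader, val_loader=val_loader)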

3. Model training and evaluation

We use cross-entropy loss as the loss function and the Adam optimizer for model training. During training, we train on the training set and evaluate the model on the validation set at the end of each epoch. We use IoU (Intersection over Union) and the F1 score as performance metrics. Note that train, evaluate, and test are all member functions of UNetTrainer.

The code is as follows:

import time

from sklearn.metrics import f1_score  # used in evaluate below

def train(self, num_epochs, train_loader, val_loader):
    for epoch in range(num_epochs):
        print(f"Epoch {
      
      epoch + 1}/{
      
      num_epochs}")
        print("-" * 10)

        start_time = time.time()

        self.evaluate(epoch, train_loader, "train")
        self.evaluate(epoch, val_loader, "val")

        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Epoch time: {
      
      elapsed_time:.4f}s")

    print("Training complete")

def evaluate(self, epoch, dataloader, phase):
    if phase == "train":
        self.model.train()
    else:
        self.model.eval()

    running_loss = 0.0
    running_iou = 0.0
    running_f1_score = 0.0

    for images, masks in dataloader:
        images = images.to(self.device)
        masks = masks.to(self.device)

        # Convert one-hot masks of shape (B, H, W, C) to integer class indices
        masks = torch.argmax(masks, dim=3).long()

        self.optimizer.zero_grad()

        with torch.set_grad_enabled(phase == "train"):
            outputs = self.model(images)['out']
            preds = torch.argmax(outputs, dim=1)
            loss = self.criterion(outputs, masks)

            if phase == "train":
                loss.backward()
                self.optimizer.step()

        running_loss += loss.item() * images.size(0)
        running_iou += self.calculate_iou(preds, masks.data)
        running_f1_score += f1_score(masks.cpu().numpy().ravel(), preds.cpu().numpy().ravel(), average="macro")

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_iou = running_iou / len(dataloader)
    epoch_f1_score = running_f1_score / len(dataloader)

    print(f"{
      
      phase} Loss: {
      
      epoch_loss:.4f} IoU: {
      
      epoch_iou:.4f} F1: {
      
      epoch_f1_score:.4f}")
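
The evaluate method also calls self.calculate_iou, which is not shown above. A minimal sketch of such a helper, computing the IoU averaged over the three classes (the original implementation may differ):

def calculate_iou(self, preds, masks, num_classes=3, eps=1e-6):
    # Mean IoU over classes for batches of predicted and target index maps
    ious = []
    for c in range(num_classes):
        pred_c = (preds == c)
        mask_c = (masks == c)
        intersection = (pred_c & mask_c).sum().item()
        union = (pred_c | mask_c).sum().item()
        ious.append((intersection + eps) / (union + eps))
    return sum(ious) / num_classes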

4. Model testing

After training is complete, we merge the training and validation sets and retrain the model on this combined dataset. We then evaluate the final model using the test set.

The code is as follows:

def test(self, train_loader, val_loader, test_loader):
    # Combine train and val datasets
    combined_dataset = torch.utils.data.ConcatDataset([train_loader.dataset, val_loader.dataset])
    combined_loader = torch.utils.data.DataLoader(combined_dataset, batch_size=train_loader.batch_size, shuffle=True)

    # Retrain the model on the combined dataset
    print("Retraining the model on the combined dataset")
    self.train(20, combined_loader, test_loader)

    # Evaluate the model on the test dataset
    print("Evaluating the model on the test dataset")
    self.evaluate(0, test_loader, "test")
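
A usage sketch for this final step, assuming the loaders built earlier (note that test_loader is passed to train as the validation loader, so test metrics are printed during retraining as well):

trainer.test(train_loader, val_loader, test_loader)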

5. Suggestions for improving model performance

In order to further improve model performance and training efficiency, here are some suggestions:

  1. Early stopping: if the validation loss shows no meaningful improvement for a number of epochs, training can be stopped early. This prevents the model from overfitting and saves training time (a minimal sketch follows this list).
  2. Learning rate scheduling: learning rate schedules, such as learning rate decay or cyclical learning rates, can help the model converge faster and improve final performance to a certain extent.
  3. Data augmentation: try additional augmentation methods, such as random cropping and elastic deformation, to improve the generalization ability of the model.
  4. Regularization: adding regularization (such as L1/L2 penalties or Dropout) can help prevent overfitting and improve the generalization ability of the model.
  5. Model architecture: try other, more advanced network architectures, such as Attention U-Net or DeepLabv3, to improve model performance.
  6. Cross-validation: use k-fold cross-validation to evaluate model performance. This provides a more stable performance estimate and helps avoid overfitting.
  7. Model ensembling: combining the predictions of multiple models can further improve performance; this can be achieved through simple averaging, voting, or more complex ensemble strategies.

Applying these methods to the model training process can help us further improve the performance of breast cancer image segmentation tasks.
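
As an illustration of the first suggestion, here is a minimal early-stopping sketch. The patience value is a hypothetical choice, and it assumes evaluate is modified to return the epoch's validation loss (the version above only prints it):

def train_with_early_stopping(self, num_epochs, train_loader, val_loader, patience=5):
    # Assumes evaluate() returns the epoch's average validation loss
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(num_epochs):
        self.evaluate(epoch, train_loader, "train")
        val_loss = self.evaluate(epoch, val_loader, "val")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break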
