摘要

主要是记录如何使用mmseg这个分割库,使用自己的数据集进行训练,使用自己的config参数以及最后的测试以及推理。

在这里主要记录的是使用Unet网络的应用,数据采用胃部肠化数据,共有两个标签,一个是正常组织的标签(蓝色),一个是肠化组织标签(红色)。

我们的任务就是训练一个网络来识别正常腺体和异常腺体的区域。

mmseg介绍

就是一个分割工具箱没什么好介绍的,具体内容看链接:

内容具体介绍

  1. 环境配置
  2. 数据集制作
  3. 自定义数据集
  4. 数据集的读取
  5. 模型选取以及参数调整
  6. 训练测试
    主要参考来源:mmsegmentation教程1:自定义数据集、config文件修改、训练教程_AESA相控阵的博客-CSDN博客

环境配置

跟着官方的来就行开始:安装和运行 MMSeg — MMSegmentation 1.2.1 文档

可能会遇到问题,遇到问题装一个这个

pip install mmcv-full

数据集制作

首先自己的数据要按照如下的格式摆放:

├── data
│   ├── my_dataset
│   │   ├── img_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{img_suffix}
│   │   │   │   ├── yyy{img_suffix}
│   │   │   │   ├── zzz{img_suffix}
│   │   │   ├── val
│   │   ├── ann_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{seg_map_suffix}
│   │   │   │   ├── yyy{seg_map_suffix}
│   │   │   │   ├── zzz{seg_map_suffix}
│   │   │   ├── val

几个主要的要点:

  1. 这个data文件夹最好放在mmseg的子路径下,加快读取速度
  2. img_dir放置所有图片,ann_dir放置所有mask(标注文件),提前分好测试集和训练集(分训练集和测试集代码如下),注意文件名要对应(img中在测试集的xxx文件,它的标注文件一定也在测试集中)
    import os
    import random
    import shutil
    random.seed(42)
    
    input_img_dir = ''
    input_label_dir = ''
    output_img_dir = ''
    output_label_dir = ''
    
    for split in ['train', 'val']:
        os.makedirs(os.path.join(output_img_dir, split), exist_ok=True)
        os.makedirs(os.path.join(output_label_dir, split), exist_ok=True)
    
    img_files = os.listdir(input_img_dir)
    label_files = os.listdir(input_label_dir)
    
    train_ratio = 0.7
    val_ratio = 0.3 
    
    random.shuffle(img_files)
    
    total_samples = len(img_files)
    num_train = int(total_samples * train_ratio)
    train_img_files = img_files[:num_train]
    val_img_files = img_files[num_train:]
    for src_dir, split in [(input_img_dir, 'train'), (input_img_dir, 'val')]:
        for img_file in locals()[f'{split}_img_files']:
            label_file = img_file  # 图像和标签文件名一一对应
            src_img_path = os.path.join(input_img_dir, img_file)
            src_label_path = os.path.join(input_label_dir, label_file)
            dst_img_path = os.path.join(output_img_dir, split, img_file)
            dst_label_path = os.path.join(output_label_dir, split, label_file)
            shutil.move(src_img_path, dst_img_path)
            shutil.move(src_label_path, dst_label_path)

自定义数据集

由于我们的任务比较特殊,常规的分割方法所采用的数据集,因此需要自己重新定义。

mmseg/datasets/目录下找到stare.py文件
image.png

同路径下复制一份,重命名为my_dataset.py记住这个名字后续还要用。

下面需要修改三个地方:
image.png

你需要修改标红的三个地方:
STAREDataset:数据集的名称,自定义,自己要记得
CLASSES:表示你数据集的背景+类别
PALETTE:表示你数据集各类别的像素值
img_suffix:原图图像后缀
seg_map_suffix:mask图像后缀
reduce_zero_label:这个后面说

# Copyright (c) OpenMMLab. All rights reserved.

import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class Mydataset(BaseSegDataset):
    """STARE dataset.
    In segmentation map annotation for STARE, 0 stands for background, which is
    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '.ah.png'.
    """
    METAINFO = dict(
        classes=('blue', 'red'),
        palette=[[120, 120, 120], [6, 230, 230]])
    def __init__(self,

                 img_suffix='.png',

                 seg_map_suffix='.png',

                 reduce_zero_label=True,

                 **kwargs) -> None:

        super().__init__(

            img_suffix=img_suffix,

            seg_map_suffix=seg_map_suffix,

            reduce_zero_label=reduce_zero_label,

            **kwargs)

        assert fileio.exists(

            self.data_prefix['img_path'], backend_args=self.backend_args)

最后修改后的样子

然后修改这个__init__文件
image.png

前面添加一个from .my_dataset import Mydataset
后面添加一个, ‘Mydataset’
image.png
image.png
就是你刚刚命名的文件名和dataset的命名

注意:palette是调色板,对应与后续测试的时候每类别的输出颜色,另外,mask也有要求,由于现在的标签是蓝色和红色,即[255, 0, 0]和[0, 0, 255]。但是后续的处理需要变为0:背景,1:标签1,2:标签2。因此标签需要进一步处理。

代码如下:

import os

from tqdm import tqdm

import numpy as np

from PIL import Image

input_folder = ''

output_folder = ''

for filename in tqdm(os.listdir(input_folder)):

    if filename.endswith(".png"):  

        img_path = os.path.join(input_folder, filename)

        img = Image.open(img_path)

        img_array = np.array(img)

        img_array[np.where((img_array == [0, 0, 255]).all(axis=2))] = [1, 1, 1]

        img_array[np.where((img_array == [255, 0, 0]).all(axis=2))] = [2, 2, 2]

        mask = img_array[:, :, 0]

        new_img = Image.fromarray(mask)

        new_img.save(os.path.join(output_folder, filename))

数据集的读取

在configs/_ base_ /datasets/目录下找到stare.py文件,复制一份重命名为my_dataset.py。进行修改
image.png

修改后的代码以及所需要修改的部分都已经添加注释了

# dataset settings
dataset_type = 'MyDataset' # 改成自己的数据集类名
data_root = '/data_sda/data/gastric_2label' # 数据集存储路径
img_scale = (256, 256) # 图像尺寸
crop_size = (256, 256) # 裁剪尺寸
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', scale=img_scale, keep_ratio=True), #这个改一下
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=12, # 这个可以自己测试
    num_workers=1, # 
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            data_prefix=dict(
                img_path='img_split/train', # 改路径
                seg_map_path='label_split/train'), # 改路径
            pipeline=train_pipeline)))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_split/val', # 改路径
            seg_map_path='label_split/val'), # 改路径
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice', 'mIoU', 'mFscore']) # 改一下评价指标
test_evaluator = val_evaluator

模型选取以及参数调整

这里简单写一个用Unet的(其他的其实类似)