Training Swin Transformer Object Detection with mmdetection

mmdetection official documentation
Environment setup
docker
I used a Docker image with torch 1.5.1+cu101 and installed the mmdetection environment inside it:

pip install mmcv-full
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
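# git clone creates this directory; a ZIP download unpacks to Swin-Transformer-Object-Detection-master, the name used later in this post
cd Swin-Transformer-Object-Detection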
pip install -r requirements/build.txt
pip install -v -e .
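
mmcv-full is compiled against a specific PyTorch/CUDA combination, so it helps to know exactly which build the container provides before installing. The sketch below splits a version string such as the 1.5.1+cu101 mentioned above into its PyTorch version and CUDA tag. `split_torch_version` is a hypothetical helper for illustration, not part of torch or mmcv; inside the container you would pass it `torch.__version__`.

```python
def split_torch_version(version: str):
    """Split a string like '1.5.1+cu101' into ('1.5.1', 'cu101').

    Builds without a '+' suffix return None as the second element.
    """
    base, _, local = version.partition("+")
    return base, (local or None)


if __name__ == "__main__":
    # In the container: import torch; split_torch_version(torch.__version__)
    print(split_torch_version("1.5.1+cu101"))
```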

Install apex

git clone https://github.com/NVIDIA/apex
cd apex
pip install -r requirements.txt
python setup.py install --cpp_ext

Output of a successful installation:

Processing dependencies for apex==0.1
Finished processing dependencies for apex==0.1
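
Besides reading the installer log, a quick way to confirm the result is to check that the package can be found by the interpreter that will run training. This is a minimal sketch using only the standard library; `is_installed` is a hypothetical helper name.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    for name in ("apex", "mmcv", "mmdet"):
        print(name, "found" if is_installed(name) else "missing")
```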

Where each model component lives in the source tree:
backbone: mmdet/models/backbones
neck: mmdet/models/necks
head: mmdet/models/roi_heads
BBox Assigner: mmdet/core/bbox/assigners
BBox Sampler: mmdet/core/bbox/samplers
BBox Encoder: mmdet/core/bbox/coder
BBox Decoder: mmdet/core/bbox/coder
Loss: mmdet/models/losses
BBox PostProcess: mmdet/core/post_processing
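
The components listed above are wired together from config dicts: each dict carries a `type` key naming a registered class, and the remaining keys become constructor arguments. The following is a simplified stand-in for that idea, not mmcv's actual Registry implementation; `Registry`, `NECKS` and `ToyFPN` are illustrative names.

```python
class Registry:
    """Toy registry: maps class names to classes and builds them from dicts."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register(self, cls):
        # Used as a decorator; stores the class under its own name.
        self._modules[cls.__name__] = cls
        return cls

    def build(self, cfg):
        args = dict(cfg)  # copy so the caller's config is not mutated
        cls = self._modules[args.pop("type")]
        return cls(**args)


NECKS = Registry("neck")


@NECKS.register
class ToyFPN:
    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels


# The 'type' key selects the class; the rest are keyword arguments.
neck = NECKS.build(dict(type="ToyFPN", in_channels=[96, 192], out_channels=256))
```

This is why editing a config file is usually enough to swap a backbone, neck or head without touching the training script.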
The model config files are in the "Swin-Transformer-Object-Detection-master/configs/swin/" directory; pick the one for your model and edit it.
Taking "cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py" as an example:

# Example: the head
roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
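                num_classes=15,  # change the number of classes (in every bbox_head stage and in mask_head)
# Choose the norm layer to match the number of GPUs:
# SyncBN for multi-GPU (distributed) training, plain BN for a single GPU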
# norm_cfg=dict(type='SyncBN', requires_grad=True),
norm_cfg=dict(type='BN', requires_grad=True),
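
A cascade config contains several heads, and each carries its own `num_classes`, so it is easy to miss one when editing by hand. The sketch below walks a nested config and overwrites every `num_classes` key it finds. `set_num_classes` is a hypothetical helper, and the small `roi_head` dict is a trimmed illustration rather than the full config.

```python
def set_num_classes(cfg, num_classes):
    """Recursively overwrite every 'num_classes' key; return how many were changed."""
    changed = 0
    if isinstance(cfg, dict):
        for key, value in cfg.items():
            if key == "num_classes":
                cfg[key] = num_classes
                changed += 1
            else:
                changed += set_num_classes(value, num_classes)
    elif isinstance(cfg, (list, tuple)):
        for item in cfg:
            changed += set_num_classes(item, num_classes)
    return changed


# Trimmed illustration: three cascade stages plus a mask head.
roi_head = dict(
    bbox_head=[dict(type="ConvFCBBoxHead", num_classes=80) for _ in range(3)],
    mask_head=dict(type="FCNMaskHead", num_classes=80),
)
count = set_num_classes(roi_head, 15)
```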

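# Adjust the learning rate and related parameters; this post uses lr = 0.00125 * batch_size.
# (The upstream Swin config uses lr=0.0001 with AdamW for a total batch size of 16,
# so treat this value as a starting point and lower it if training diverges.)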
optimizer = dict(_delete_=True, type='AdamW', lr=0.00125, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
# Change the number of epochs
runner = dict(type='EpochBasedRunner', max_epochs=20)
# To disable fp16, set use_fp16 to False
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=False,
)
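
The rule lr = 0.00125 * batch_size in the optimizer comment above is a linear scaling rule: the learning rate grows in proportion to the total number of images per iteration. A minimal sketch of that arithmetic follows; `scaled_lr` is a hypothetical helper, and the per-sample rate is the value used in this post, not a recommendation.

```python
def scaled_lr(lr_per_sample: float, num_gpus: int, samples_per_gpu: int) -> float:
    """Linear scaling: learning rate proportional to the total batch size."""
    total_batch_size = num_gpus * samples_per_gpu
    return lr_per_sample * total_batch_size


if __name__ == "__main__":
    # One GPU with one image per GPU, as configured in this post.
    print(scaled_lr(0.00125, num_gpus=1, samples_per_gpu=1))
    # Four GPUs with two images each.
    print(scaled_lr(0.00125, num_gpus=4, samples_per_gpu=2))
```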

Edit "configs/_base_/datasets/coco_instance.py" as needed:

# Change the dataset type and path
dataset_type = 'CocoDataset'
data_root = '/home/coco/'

# Change img_scale and related parameters; reduce it if you hit "CUDA out of memory"
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    # originally 1333x800
    # dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='Resize', img_scale=(416, 416), keep_ratio=True),
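
With keep_ratio=True, img_scale acts as a bound rather than an exact output size: the image is scaled so that its long edge does not exceed the larger value and its short edge does not exceed the smaller one. The sketch below reproduces that calculation as I understand it from mmcv's rescaling logic; `rescale_keep_ratio` is a hypothetical name and the library's rounding may differ in edge cases.

```python
def rescale_keep_ratio(width: int, height: int, img_scale):
    """Return the (width, height) that fits inside img_scale with the aspect ratio kept."""
    max_long_edge = max(img_scale)
    max_short_edge = min(img_scale)
    scale = min(max_long_edge / max(width, height),
                max_short_edge / min(width, height))
    return int(width * scale + 0.5), int(height * scale + 0.5)


if __name__ == "__main__":
    print(rescale_keep_ratio(1920, 1080, (416, 416)))
    print(rescale_keep_ratio(1920, 1080, (1333, 800)))
```

In this sketch a 1920x1080 frame becomes 416x234 under (416, 416) and 1333x750 under (1333, 800), which gives a feel for how much memory the smaller setting saves.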

# Change the batch size
data = dict(
    samples_per_gpu=1,  # samples per GPU; total batch_size = number of GPUs * samples_per_gpu
    workers_per_gpu=1,  # dataloader workers per GPU
    # train split shown as an example
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',  # annotation file path
        img_prefix=data_root + 'train2017/',  # training image directory
        pipeline=train_pipeline),

Change the class list in both mmdet/datasets/coco.py and mmdet/core/evaluation/class_names.py:

# mmdet/datasets/coco.py
class CocoDataset(CustomDataset):

    # Original 80 COCO classes, commented out:
    #CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #           'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #           'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
    # Custom classes (15, matching num_classes=15 in the config)
    CLASSES = ('person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
               'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car')

# mmdet/core/evaluation/class_names.py
def coco_classes():
    return ['person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
            'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car']
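
The dataset class looks categories up by name, so the names in CLASSES need to match the "categories" entries of the COCO-style annotation file exactly. A mismatch tends to show up only as an empty or partly empty training set. The sketch below is a small pre-flight check; `missing_categories` is a hypothetical helper and the annotation path in the usage comment is the one from this post's config.

```python
import json


def missing_categories(annotation: dict, classes) -> list:
    """Return the class names that have no matching category in the annotation dict."""
    names = {cat["name"] for cat in annotation.get("categories", [])}
    return [c for c in classes if c not in names]


def load_annotation(path: str) -> dict:
    """Load a COCO-style annotation file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Usage (path taken from the config in this post):
# ann = load_annotation("/home/coco/annotations/instances_train2017.json")
# print(missing_categories(ann, ("person", "tool_vehicle", "bicycle")))
```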

Edit "./tools/train.py":

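# Choose one mode: 'none' runs on a single machine with MMDataParallel; the other
# launchers run distributed (single-machine multi-GPU or multi-machine) with MMDistributedDataParallel
parser.add_argument(
    '--launcher',
    choices=['none', 'pytorch', 'slurm', 'mpi'],
    default='none',
    help='job launcher')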
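
The launcher value is what decides between the two modes: anything other than 'none' switches the script to distributed training. A minimal, self-contained sketch of that decision follows; `parse_launcher` is a hypothetical wrapper around the same argparse option and does not initialize any process group.

```python
import argparse


def parse_launcher(argv):
    """Parse --launcher and report whether training would run distributed."""
    parser = argparse.ArgumentParser(description="launcher selection sketch")
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher")
    args = parser.parse_args(argv)
    distributed = args.launcher != "none"
    return args.launcher, distributed


if __name__ == "__main__":
    print(parse_launcher([]))
    print(parse_launcher(["--launcher", "pytorch"]))
```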

Pretrained weights, checkpoint loading and saving are configured in configs/_base_/default_runtime.py:

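checkpoint_config = dict(interval=1)  # save a checkpoint after every epoch
load_from = None  # path to a checkpoint whose model weights are loaded before training (e.g. for fine-tuning)
resume_from = None  # path to a checkpoint to resume training from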
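
To see what the interval setting means for disk usage, the sketch below lists the checkpoint files an epoch-based run would write. It assumes the `epoch_N.pth` naming that mmcv's checkpoint hook uses by default as far as I know; `checkpoint_names` is a hypothetical helper and it ignores extras such as a `latest.pth` link.

```python
def checkpoint_names(max_epochs: int, interval: int) -> list:
    """List checkpoint file names saved every `interval` epochs (assumed naming)."""
    return [f"epoch_{epoch}.pth"
            for epoch in range(1, max_epochs + 1)
            if epoch % interval == 0]


if __name__ == "__main__":
    # max_epochs=20 as set in this post; interval=5 keeps only four files.
    print(checkpoint_names(20, 5))
```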

Training the model
Train on a single GPU (GPU id 3):

python ./tools/train.py configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py --gpu-ids 3

Train on multiple GPUs (4 in this example):

tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 4

Training logs and weights
These are saved under "Swin-Transformer-Object-Detection-master/work_dirs/".
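
When no work directory is given on the command line, the output folder is derived from the config file name. The sketch below shows that derivation as I understand it from tools/train.py; `default_work_dir` is a hypothetical helper name.

```python
import os.path as osp


def default_work_dir(config_path: str) -> str:
    """Derive ./work_dirs/<config name without extension> from a config path."""
    config_name = osp.splitext(osp.basename(config_path))[0]
    return osp.join("./work_dirs", config_name)


if __name__ == "__main__":
    print(default_work_dir(
        "configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_"
        "mstrain_480-800_giou_4conv1f_adamw_3x_coco.py"))
```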

COCO evaluation

python tools/test.py configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py cascade_mask_rcnn_swin_small_patch4_window7.pth --eval segm

Pitfalls and solutions:
error with env var RANK
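
The post does not record the fix for this error. One common cause, stated here as an assumption, is selecting the pytorch launcher while starting tools/train.py directly with python: the distributed setup then reads variables such as RANK from the environment, and they are only set when the job is started through a distributed launch utility such as tools/dist_train.sh. Using --launcher none for single-GPU runs avoids the lookup. The sketch below checks for the variables that environment-based initialization expects; `missing_dist_env` is a hypothetical helper.

```python
import os

# Variables read by environment-variable based initialization of torch.distributed.
REQUIRED_DIST_VARS = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")


def missing_dist_env(env=None) -> list:
    """Return the distributed-training variables that are not set."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_DIST_VARS if name not in env]


if __name__ == "__main__":
    missing = missing_dist_env()
    if missing:
        print("Not a distributed launch; missing:", ", ".join(missing))
    else:
        print("Distributed environment variables are set.")
```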
References:
"轻松掌握 MMDetection 整体构建流程(一)" (An Easy Guide to the Overall MMDetection Build Pipeline, Part 1)

