博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
深度学习目标检测工具箱mmdetection,训练自己的数据
阅读量:2814 次
发布时间:2019-05-13

本文共 3663 字,大约阅读时间需要 12 分钟。

文章目录

一、简介

商汤科技(2018 COCO 目标检测挑战赛冠军)和香港中文大学最近开源了一个基于Pytorch实现的深度学习目标检测工具箱mmdetection,支持Faster-RCNN,Mask-RCNN,Fast-RCNN等主流的目标检测框架,后续会加入Cascade-RCNN以及其他一系列目标检测框架。

相比于Facebook开源的Detectron框架,作者声称mmdetection有三点优势:performance稍高、训练速度稍快、所需显存稍小。

我很早就听说了这个工具箱,但是一直没有开源。现在总算是开源了,于是就迫不及待地测试一下这个框架的效果。下面将记录一下我的测试过程,已经期间所遇到的一些坑。

2019.05.26更新了网盘里的demo.py文件,现在的demo应该可以匹配官方的最新代码了。

2019.02.22上传多个mmdetection模型,检测器包括Faster-RCNN、Mask-RCNN、RetinaNet,backbone包括Resnet-50、Resnet-101、ResNext-101,网盘链接:(密码:dpyl)

二、安装教程

本人的系统环境:

  • Ubuntu 16.04
  • Cuda 9.0 + Cudnn 7.0.5
  • Python 3.6 (mmdetection要求Python版本需要3.4+)
  • Anaconda 3 (可选)

这里推荐大家使用Anaconda,可以比较方便的创建Python虚拟环境,避免不同的Python库之间产生冲突。在安装mmdetection之前,需要安装以下几个依赖库:

  • PyTorch 0.4.1 和 torchvision
  • PyTorch 1.0 PyTorch 1.1 (Pytorch 0.4.1的版本需要切换branch,在clone了mmdetection的git之后需要git checkout pytorch-0.4.1)
  • Cython
  • mmcv

下面是我的安装和测试步骤,以Anaconda 3为例。

1. 使用conda创建Python虚拟环境(可选)

conda create -n mmdetection python=3.6source activate mmdetection

这样就创建了名为mmdetection的Python3.6环境,并且在terminal中激活。如果不需要虚拟环境,则将下文的conda install改为pip install

2. 安装PyTorch 1.1

conda install pytorch=1.1 -c pytorch

安装好以后,进入Python环境,输入以下代码测试是否安装成功,不报错则说明安装成功

import torch

3. 安装Cython

conda install cython

4. 安装mmcv

官方代码已更新,直接运行下一步就可以自动安装所有依赖库了

5. 安装mmdetection

git clone https://github.com/open-mmlab/mmdetection.gitcd mmdetectionpython setup.py develop

到此,我们就完成了mmdetection及其依赖库的安装

6. 测试Demo

将下方的代码写入py文件,并存放到mmdetection文件夹目录下,然后运行。该代码的功能是检测图片中的目标,测试模型是官方给出的Faster-RCNN-fpn-resnet50的模型,运行代码会自动下载模型。由于模型是存储在亚马逊云服务器上,速度可能会稍慢,如果下载失败可以通过我的网盘链接(密码:dpyl)进行下载,存放到mmdetection文件夹目录下,然后修改下方代码的相关部分

from mmdet.apis import init_detector, inference_detector, show_result # 首先下载模型文件https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pthconfig_file = 'configs/faster_rcnn_r50_fpn_1x.py'checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth' # 初始化模型model = init_detector(config_file, checkpoint_file) # 测试一张图片img = 'test.jpg'result = inference_detector(model, img)show_result(img, result, model.CLASSES) # 测试一系列图片imgs = ['test1.jpg', 'test2.jpg']for i, result in enumerate(inference_detector(model, imgs, device='cuda:0')):    show_result(imgs[i], result, model.CLASSES, out_file='result_{}.jpg'.format(i))

7. 准备自己的数据

mmdetection支持coco格式和voc格式的数据集,下面将分别介绍这两种数据集的使用方式

  • coco数据集
    官方推荐coco数据集按照以下的目录形式存储,以coco2017数据集为例
mmdetection├── mmdet├── tools├── configs├── data│   ├── coco│   │   ├── annotations│   │   ├── train2017│   │   ├── val2017│   │   ├── test2017

推荐以软连接的方式创建data文件夹,下面是创建软连接的步骤

cd mmdetectionmkdir dataln -s $COCO_ROOT data

其中,$COCO_ROOT需改为你的coco数据集根目录

  • voc数据集
    与coco数据集类似,将voc数据集按照以下的目录形式存储,以VOC2007为例
mmdetection├── mmdet├── tools├── configs├── data│   ├── VOCdevkit│   │   ├── VOC2007│   │   │   ├── Annotations│   │   │   ├── JPEGImages│   │   │   ├── ImageSets│   │   │   │   ├── Main│   │   │   │   │   ├── test.txt│   │   │   │   │   ├── trainval.txt

同样推荐以软连接的方式创建

cd mmdetectionmkdir dataln -s $VOC2007_ROOT data/VOCdevkit

其中,$VOC2007_ROOT需改为你的VOC2007数据集根目录

然后,下载 pascal_voc_mod.py 和 voc_classes.txt (上方的模型下载地址中有)存放到mmdetection根目录下,运行以下代码
mmdetection官方代码已更新,不再需要自己生成

如果需要标注自己的数据,推荐使用工具标注

然后在运行 pascal_voc_mod.py 之前,修改 voc_classes.txt 里的类别名为你自己设定的类别名,再运行py文件
然后需要修改mmdet/datasets/voc.py文件中的CLASSES为你自己的类别

8. 训练

官方推荐使用分布式的训练方式,这样速度更快,如果是coco训练集,修改CONFIG_FILE中的pretrained参数,改为你的模型路径,然后运行下方代码

./tools/dist_train.sh 
[optional arguments]

如果是voc训练集,还需要修改config文件中的相关参数,可以参考 faster_rcnn_r50_mod.py (上方网盘地址中有),然后再运行上面的代码

mmdetection官方代码已更新,目前已支持voc格式的数据集,不再需要自己修改

如果不想采用分布式的训练方式,或者你只有一块显卡,则运行下方的代码

python tools/train.py 
--gpus
--work_dir

至此,如果一切顺利的话,你的模型应该就开始训练了

转载地址:http://ojqqd.baihongyu.com/

你可能感兴趣的文章
原厂内核移植流程
查看>>
内核与文件系统
查看>>
手动构建rootfs及文件功能分析
查看>>
利用nfs调试rootfs
查看>>
c++的泛型编程与模板
查看>>
Linux环境变量详解
查看>>
shell中变量的使用
查看>>
shell中的数值操作
查看>>
shell中的函数
查看>>
内核的进程管理与调度
查看>>
Linux环境下修改或指定python的默认版本
查看>>
linux 安装Anaconda/Miniconda以后无法识别conda命令
查看>>
linux
查看>>
caffe/ windows 10 /Can't parse message of type "caffe.NetParameter" because it is missing required
查看>>
matplotlib/plt 函数savefig保存的图像有空白
查看>>
Machine learning Year阅读笔记
查看>>
windows10 + caffe 配置CUDA10.0
查看>>
python/matplotlib绘制统计直方图
查看>>
linxu/torch/pytorch会自动安装 cuda
查看>>
详解微信小程序胶囊按钮返回|首页自定义导航栏功能
查看>>