VOC格式xml数据集转YOLO格式方案

syan 发布于 2020-05-23 2373 次阅读


下面贴出两段代码。
为了做robomaster的雷达,需要跑一份目标检测神经网络,这里选择了最新最强最快的YOLO v4.
但是官方提供的参考数据集(DJI ROCO)是VOC格式的xml标注,无法直接交给网络使用,需要先进行格式转换:从xml中获取每个目标的类别、框的中心坐标以及框的宽高,并按图片尺寸归一化到0~1之间。

学习了一下xml文件的处理
参考了CSDN上的文章[YoloV3] DJIXML数据集转Yolo脚本

重要更正:YOLO v4 的标注文件中只能写类别的序号(与自己生成的 .names 文件中类别的顺序一一对应),而不能直接写类别名称。下面第一段代码默认输出的是类别名称,使用前需要改为序号。

第一次用这里生成的标注文件训练网络时,300轮后loss就降到了0.001,你敢信?这显然不正常,很可能说明标签根本没有被正确读取(原因正是上面提到的类别名称问题),因此还请自行修改。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Created on Sat Jan 11 16:31:30 2020
XmlToTxt, for DJI ROCO Dataset
@author: HNU robomaster 跃鹿
原文链接:https://blog.csdn.net/chenhanxuan1999/java/article/details/103970884
Fix by: HDU robomaster Syan
'''
import os
import sys
import xml.etree.ElementTree as ET

# Dataset sub-folder names, one per DJI ROCO competition split.
North = 'robomaster_North China Regional Competition'
Central = 'robomaster_Central China Regional Competition'
Final = 'robomaster_Final Tournament'

# The split converted on this run; switch to North or Central as needed.
Competition_Path = Final

# Input folder holding the .xml annotations of the selected split, and the
# folder that receives the generated .txt label files.
_split_root = ''.join(['D:/darknet-master/data/obj/DJI ROCO/', Competition_Path])
filedir = _split_root + '/image_annotation'
outdir = _split_root + '/processedTXT'

print(filedir)

def del_all_files(path):
    """Delete everything inside the directory *path*, keeping *path* itself.

    Files and symlinks directly inside *path* are removed. Sub-directories
    are removed recursively together with all of their contents.

    Args:
        path: directory whose contents should be deleted.
    """
    import shutil

    for entry in os.listdir(path):
        c_path = os.path.join(path, entry)
        # Symlinks are unlinked rather than followed, so a link pointing at a
        # directory outside *path* never causes that directory to be emptied.
        if os.path.isdir(c_path) and not os.path.islink(c_path):
            shutil.rmtree(c_path)
        else:
            os.remove(c_path)

def mkdir(path):
    """Ensure *path* exists as an empty directory.

    Surrounding whitespace is stripped first, then trailing backslash
    characters, then trailing forward-slash characters. A missing directory
    is created together with any missing parents. Despite the name, this is
    destructive: if the path already exists, everything inside it is deleted
    via ``del_all_files``.

    Args:
        path: directory to create or empty.
    """
    cleaned = path.strip().rstrip("\\").rstrip("/")
    if os.path.exists(cleaned):
        # Already there: wipe its contents so no stale output survives.
        del_all_files(cleaned)
        return
    os.makedirs(cleaned)

def xml_to_txt(indir, outdir, classes=None):
    """Convert one VOC-style XML annotation file into a YOLO label file.

    Despite their names, *indir* and *outdir* are file paths: the XML file to
    read and the ``.txt`` file to (over)write. Each ``<object>`` becomes one
    line of the form ``<label> <x_center> <y_center> <width> <height>``. The
    box values are divided by the image ``<width>``/``<height>`` from
    ``<size>`` and rounded to 6 decimals.

    The label is the object's ``<name>``. For ``armor`` objects the
    ``<armor_color>`` and ``<armor_class>`` texts are appended with ``-``
    (e.g. ``armor-red-1``); whichever of them is missing is left out.

    Args:
        indir: path of the XML annotation file.
        outdir: path of the label file to write (overwritten, UTF-8).
        classes: optional ordered list/tuple of label names, e.g. the
            contents of your ``.names`` file. When given, the label is
            written as its index in *classes*, which is what darknet/YOLO
            reads. When ``None`` (default), the label name is written.

    Raises:
        ValueError: if *classes* is given and a label is not in it.
    """
    parser = ET.XMLParser(encoding="utf-8")
    root = ET.parse(indir, parser=parser).getroot()

    size_info = root.find("size")
    width = float(size_info.find("width").text)
    height = float(size_info.find("height").text)

    lines = []
    for obj in root.findall("object"):
        name = obj.find("name").text
        label = name
        if name == 'armor':
            # find() returns None for an absent child element; getattr with a
            # default turns that into None instead of raising AttributeError.
            # Absent parts are skipped rather than crashing on str + None.
            extras = [getattr(obj.find(tag), 'text', None)
                      for tag in ("armor_color", "armor_class")]
            label = '-'.join([name] + [e for e in extras if e is not None])

        if classes is not None:
            try:
                label = classes.index(label)
            except ValueError:
                raise ValueError("label %r in %s is not listed in classes"
                                 % (label, indir)) from None

        bndbox = obj.find("bndbox")
        xmin = float(bndbox.find("xmin").text)
        ymin = float(bndbox.find("ymin").text)
        xmax = float(bndbox.find("xmax").text)
        ymax = float(bndbox.find("ymax").text)

        # Box center and size, normalised by the image dimensions.
        x = (xmax + xmin) / 2 / width
        y = (ymax + ymin) / 2 / height
        w = (xmax - xmin) / width
        h = (ymax - ymin) / height

        fields = [str(label)] + [str(round(v, 6)) for v in (x, y, w, h)]
        lines.append(' '.join(fields) + '\n')

    # Write only after the whole file parsed successfully, and close the
    # handle deterministically, so an error never leaves a partial label file.
    with open(outdir, 'w', encoding='utf-8') as f_w:
        f_w.writelines(lines)

def main():
    """Convert every ``.xml`` file in ``filedir`` into a ``.txt`` in ``outdir``.

    ``outdir`` is created if missing, or emptied if it exists (via
    ``mkdir``), so labels from an earlier run are removed. Each output file
    keeps the input's base name with a ``.txt`` extension. Entries whose
    extension is not ``.xml`` (any case) are skipped, so stray files such as
    OS thumbnails are never handed to the XML parser.
    """
    # List the input first so a missing annotation folder fails before the
    # output folder gets wiped.
    file_list = sorted(os.listdir(filedir))
    mkdir(outdir)
    for file_name in file_list:
        file_prefix, ext = os.path.splitext(file_name)
        if ext.lower() != '.xml':
            continue
        xml_to_txt(os.path.join(filedir, file_name),
                   os.path.join(outdir, file_prefix + '.txt'))


if __name__ == "__main__":
    main()

另外,我自己写了一段小代码,用于生成图片列表文件 train.txt,供 YOLO 训练时读取:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Created on Sat May 23 12:03:30 2020
Get img list, for DJI ROCO Dataset
@author: HDU robomaster Syan
'''
import os
import sys

# Dataset sub-folder names, one per DJI ROCO competition split.
North = 'robomaster_North China Regional Competition'
Central = 'robomaster_Central China Regional Competition'
Final = 'robomaster_Final Tournament'

# Splits whose images are listed, in this order.
Competition_Path = list((North, Central, Final))

# Dataset root; keep the trailing '/' since sub-paths are appended directly.
filedir = 'D:/darknet-master/data/obj/DJI ROCO/'

def main(darknet_root='D:/darknet-master/'):
    """Write ``filedir + 'train.txt'`` listing every ``.jpg`` image of each split.

    For each split folder name in ``Competition_Path``, the directory
    ``filedir + split + '/image'`` is scanned in sorted order. Every file whose
    name ends in ``.jpg`` (any letter case) is written as one line. The list
    file is UTF-8 and is overwritten on every run.

    Args:
        darknet_root: prefix removed from ``filedir`` so the listed paths are
            relative to it (presumably the directory darknet is launched
            from; confirm for your setup). The default matches the prefix of
            the ``filedir`` constant. If ``filedir`` does not start with
            *darknet_root*, the full ``filedir`` is used, producing absolute
            paths instead of silently chopping a fixed number of characters.
    """
    if filedir.startswith(darknet_root):
        rel_root = filedir[len(darknet_root):]
    else:
        rel_root = filedir

    with open(filedir + 'train.txt', 'w', encoding='utf-8') as f_w:
        for split in Competition_Path:
            image_dir = filedir + split + '/image'
            for file_name in sorted(os.listdir(image_dir)):
                if file_name.lower().endswith('.jpg'):
                    f_w.write(rel_root + split + '/image/' + file_name + '\n')


if __name__ == "__main__":
    main()