基于yolo3的图像半自动打标

标签: python深度学习  深度学习  xml

前言

这篇文章以yolo3为基础,用yolo3测试没有打标的图片,将结果转换为xml文件,然后在通过人工复查,一半智能一半人工所以叫做半自动。具体实施:先用少量打标样本训练出一个模型,用该模型测试未打标样本,通过人工复查之后将样本合并,这样减少了打标过程。图片样本过多打标确实是很累啊,我当时是有5000张图片需要打标,打标一整天搞了2000张人都麻木了,加入半自动后只需要微调结果框就好,半天就搞完剩余的3000张。

1. 图像样本

首先在根目录下新建semi-auto文件夹,文件夹包括Annotations和JPEGImages,与voc格式一样,方便最终的合并。
在这里插入图片描述
JPEGImages文件夹下放需要打标的图片,命名我建议从命名为voc格式的,然后接着已经打过标签的图片名往下继续,也就是说已经打过标的图片如果有8张,那么没打过标的就以000009开始。我当时是以时间命名的,后来出现标签与图片不对应的问题,所有的图片都要重新复查,都是泪。
在这里插入图片描述

2. yolo_test.py

2.1 代码

用该代码替换原来的yolo_test.py或则命名为yolo_test1.py。

# -*- coding: utf-8 -*-
import colorsys
import os
from timeit import default_timer as timer
import time

import numpy as np
from keras import backend as K
from keras.models import load_model
from keras.layers import Input
from PIL import Image, ImageFont, ImageDraw

from yolo3.model import yolo_eval, yolo_body, tiny_yolo_body
from yolo3.utils import letterbox_image
from keras.utils import multi_gpu_model

path ='C:/Users/Administrator/Desktop/keras-yolo3-master/semi-auto/JPEGImages/' #待检测图片的位置

# 创建创建一个存储检测结果的dir
result_path = './result'
if not os.path.exists(result_path):
    os.makedirs(result_path)

# result如果之前存放的有文件,全部清除
for i in os.listdir(result_path):
    path_file = os.path.join(result_path,i)  
    if os.path.isfile(path_file):
        os.remove(path_file)

#创建一个记录检测结果的文件
txt_path =result_path + '/result.txt'
file = open(txt_path,'w')  

class YOLO(object):
    _defaults = {
        "model_path": 'model_data/yolov3-8img.h5',
        "anchors_path": 'model_data/yolo_anchors.txt',
        "classes_path": 'model_data/coco_classes.txt',
        "score" : 0.3,
        "iou" : 0.45,
        "model_image_size" : (416, 416),
        "gpu_num" : 1,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults) # set up default values
        self.__dict__.update(kwargs) # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    def _get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    def _get_anchors(self):
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape(-1, 2)

    def generate(self):
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # Load model, or construct model and load weights.
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)
        is_tiny_version = num_anchors==6 # default setting
        try:
            self.yolo_model = load_model(model_path, compile=False)
        except:
            self.yolo_model = tiny_yolo_body(Input(shape=(None,None,3)), num_anchors//2, num_classes) \
                if is_tiny_version else yolo_body(Input(shape=(None,None,3)), num_anchors//3, num_classes)
            self.yolo_model.load_weights(self.model_path) # make sure model, anchors and classes match
        else:
            assert self.yolo_model.layers[-1].output_shape[-1] == \
                num_anchors/len(self.yolo_model.output) * (num_classes + 5), \
                'Mismatch between model and given anchor and class sizes'

        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Generate colors for drawing bounding boxes.
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))
        np.random.seed(10101)  # Fixed seed for consistent colors across runs.
        np.random.shuffle(self.colors)  # Shuffle colors to decorrelate adjacent classes.
        np.random.seed(None)  # Reset seed to default.

        # Generate output tensor targets for filtered bounding boxes.
        self.input_image_shape = K.placeholder(shape=(2, ))
        if self.gpu_num>=2:
            self.yolo_model = multi_gpu_model(self.yolo_model, gpus=self.gpu_num)
        boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                len(self.class_names), self.input_image_shape,
                score_threshold=self.score, iou_threshold=self.iou)
        return boxes, scores, classes
    
    
    def detect_image(self, image_path):
        start = timer() # 开始计时
        for filename in os.listdir(path):        
            image_path = path+'/'+filename
            portion = os.path.split(image_path)
            # file.write(portion[1]+' detect_result:\n')
            image = Image.open(image_path)
            image = image.convert('RGB')
        
            if self.model_image_size != (None, None):
                assert self.model_image_size[0]%32 == 0, 'Multiples of 32 required'
                assert self.model_image_size[1]%32 == 0, 'Multiples of 32 required'
                boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))
            else:
                new_image_size = (image.width - (image.width % 32),
                                  image.height - (image.height % 32))
                boxed_image = letterbox_image(image, new_image_size)
            image_data = np.array(boxed_image, dtype='float32')
    
            print(image_data.shape) #打印图片的尺寸
            image_data /= 255.
            image_data = np.expand_dims(image_data, 0)  # Add batch dimension.
    
            out_boxes, out_scores, out_classes = self.sess.run(
                [self.boxes, self.scores, self.classes],
                feed_dict={
                    self.yolo_model.input: image_data,
                    self.input_image_shape: [image.size[1], image.size[0]],
                    K.learning_phase(): 0
                })
    
            print('Found {} boxes for {}'.format(len(out_boxes), 'img')) # 提示用于找到几个bbox
    
            font = ImageFont.truetype(font='font/FiraMono-Medium.otf',
                        size=np.floor(2e-2 * image.size[1] + 0.2).astype('int32'))
            thickness = (image.size[0] + image.size[1]) // 500
    
            # 保存框检测出的框的个数
            # file.write('find  '+str(len(out_boxes))+' target(s) \n')
    
            for i, c in reversed(list(enumerate(out_classes))):
                predicted_class = self.class_names[c]
                box = out_boxes[i]
                score = out_scores[i]
    
                label = '{} {:.2f}'.format(predicted_class, score)
                draw = ImageDraw.Draw(image)
                label_size = draw.textsize(label, font)
    
                top, left, bottom, right = box
                top = max(0, np.floor(top + 0.5).astype('int32'))
                left = max(0, np.floor(left + 0.5).astype('int32'))
                bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
                right = min(image.size[0], np.floor(right + 0.5).astype('int32'))
    
                # 写入检测位置            
                file.write(portion[1]+' '+predicted_class+' ' +str(right)+' ' + str(bottom)+' ' + str(left)+' '+ str(top)+' '+'\n')
                
                print(label, (left, top), (right, bottom))
    
                if top - label_size[1] >= 0:
                    text_origin = np.array([left, top - label_size[1]])
                else:
                    text_origin = np.array([left, top + 1])
    
                # My kingdom for a good redistributable image drawing library.
                for i in range(thickness):
                    draw.rectangle(
                        [left + i, top + i, right - i, bottom - i],
                        outline=self.colors[c])
                draw.rectangle(
                    [tuple(text_origin), tuple(text_origin + label_size)],
                    fill=self.colors[c])
                draw.text(text_origin, label, fill=(0, 0, 0), font=font)
                del draw

        end = timer()
        print('time consume:%.3f s '%(end - start))
        return image

    def close_session(self):
        self.sess.close()


# 图片检测

if __name__ == '__main__':

    t1 = time.time()
    yolo = YOLO()           
    image_path ='C:/Users/Administrator/Desktop/keras-yolo3-master/semi-auto/JPEGImages/'  #待检测图片的位置
    r_image = yolo.detect_image(image_path)
    file.close() 
    yolo.close_session()

需要修改的是17行和206行的路径,改为待检测图片的位置

2.2 运行过程和结果

程序运行过程:
在这里插入图片描述

运行后结果为result文件夹下的result.txt文件

000009.jpg dog 468 470 34 42 
000009.jpg cattle 469 499 41 55 
000010.jpg cattle 448 469 0 58 
000011.jpg cattle 445 456 22 109 

3. txt2xml.py将result.txt转换为xml文件

3.1 代码

新建在根目录下就好

import copy
from lxml.etree import Element, SubElement, tostring, ElementTree
import cv2
 
# 修改为你自己的路径
template_file = 'C:/Users/Administrator/Desktop/keras-yolo3-master/VOCdevkit/VOC2007/Annotations/000001.xml'    #已打标过的xml作为模板
target_dir = 'C:/Users/Administrator/Desktop/keras-yolo3-master/semi-auto/Annotations/'    #xml保存到哪去
image_dir = 'C:/Users/Administrator/Desktop/keras-yolo3-master/semi-auto/JPEGImages/'  # 图片文件夹
train_file = 'C:/Users/Administrator/Desktop/keras-yolo3-master/result/result.txt'  # 存储了图片信息的txt文件
 
with open(train_file) as f:
    trainfiles = f.readlines()  # 标注数据 格式(123.jpg pig x_min y_min x_max y_max)
 
file_names = []
tree = ElementTree()
 
for line in trainfiles:
    trainFile = line.split()
    file_name = trainFile[0]
    print(file_name)
 
    # 如果没有重复,则顺利进行。这给的数据集一张图片的多个框没有写在一起。
    if file_name not in file_names:
        file_names.append(file_name)
        lable = trainFile[1]
        
        #因为要使用labelimg来编辑xml,所以将trainFile[2] float类型转成整型。再将整型转成str类型存到xml文件里面。
        xmin = trainFile[2]
        ymin = trainFile[3]
        xmax = trainFile[4]
        ymax = trainFile[5]
 
        tree.parse(template_file)
        root = tree.getroot()
        root.find('filename').text = file_name
 
        # size
        sz = root.find('size')
        im = cv2.imread(image_dir + file_name)#读取图片信息
 
        sz.find('height').text = str(1080)
        sz.find('width').text = str(1920)
        sz.find('depth').text = str(3)
 
        # object 因为我的数据集都只有一个框
        obj = root.find('object')
 
        obj.find('name').text = lable
        bb = obj.find('bndbox')
        bb.find('xmin').text = xmin
        bb.find('ymin').text = ymin
        bb.find('xmax').text = xmax
        bb.find('ymax').text = ymax
        # 如果重复,则需要添加object框
    else:
        lable = trainFile[1]
        
        xmin = trainFile[2]
        ymin = trainFile[3]
        xmax = trainFile[4]
        ymax = trainFile[5]
 
        xml_file = file_name.replace('jpg', 'xml')
        tree.parse(target_dir + xml_file)#如果已经重复
        root = tree.getroot()
 
        obj_ori = root.find('object')
 
        obj = copy.deepcopy(obj_ori)  # 注意这里深拷贝
 
        obj.find('name').text = lable
        bb = obj.find('bndbox')
        bb.find('xmin').text = xmin
        bb.find('ymin').text = ymin
        bb.find('xmax').text = xmax
        bb.find('ymax').text = ymax
        root.append(obj)
 
    xml_file = file_name.replace('jpg', 'xml')
    tree.write(target_dir + xml_file, encoding='utf-8')

需要修改路径

3.2 结果

结果得到3个xml文件:
在这里插入图片描述

4. 用labelimg人工复查

在这里插入图片描述
在这里插入图片描述

总结

半自动打标还是挺省事的,啥时候能搞个全自动就好了,太懒!!!

版权声明:本文为YL950305原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/YL950305/article/details/109573587

智能推荐

【机器学习基础】线性回归

                                                        &nbs...

08-Vue实现书籍购物车案例

书籍购物车案例 index.html main.js style.css 1.内容讲解 写一个table和thead,tbody中每一个tr都用来遍历data变量中的books列表。 结果如下: 在thead中加上购买数量和操作,并在对应的tbody中加入对应的按钮。结果如下: 为每个+和-按钮添加事件,将index作为参数传入,并判断当数量为1时,按钮-不可点击。 结果如下: 为每个移除按钮添加...

堆排序

堆排序就是利用堆进行排序的方法,基本思想是,将代排序列构造成一个大根堆,此时整个序列的最大值就是堆顶的根节点。将它与堆数组的末尾元素交换,此时末尾元素就是最大值,移除末尾元素,然后将剩余n-1个元素重新构造成一个大根堆,堆顶元素为次大元素,再次与末尾元素交换,再移除,如此反复进行,便得到一个有序序列。 (大根堆为每一个父节点都大于两个子节点的堆) 上面思想的实现还要解决两个问题: 1.如何由一个无...

基础知识(变量类型和计算)

一、值类型 常见的有:number、string、Boolean、undefined、Symbol 二、引用类型 常用的有:object、Array、null(指针指向为空)、function 两者的区别: 值类型暂用空间小,所以存放在栈中,赋值时互不干扰,所以b还是100 引用类型暂用空间大,所以存放在堆中,赋值的时候b是引用了和a一样的内存地址,所以a改变了b也跟着改变,b和a相等 如图: 值...

猜你喜欢

Codeforces 1342 C. Yet Another Counting Problem(找规律)

题意: [l,r][l,r][l,r] 范围内多少个数满足 (x%b)%a!=(x%a)%b(x \% b) \% a != (x \% a) \% b(x%b)%a!=(x%a)%b。 一般这种题没什么思路就打表找一下规律。 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30 31 32 33 34 35 36 37 38 39 40 41 49 50...

[笔记]飞浆PaddlePaddle-百度架构师手把手带你零基础实践深度学习-21日学习打卡(Day 3)

[笔记]飞浆PaddlePaddle-百度架构师手把手带你零基础实践深度学习-21日学习打卡(Day 3) (Credit: https://gitee.com/paddlepaddle/Paddle/raw/develop/doc/imgs/logo.png) MNIST数据集 MNIST数据集可以认为是学习机器学习的“hello world”。最早出现在1998年LeC...

哈希数据结构和代码实现

主要结构体: 实现插入、删除、查找、扩容、冲突解决等接口,用于理解哈希这种数据结构 完整代码参见github: https://github.com/jinxiang1224/cpp/tree/master/DataStruct_Algorithm/hash...

解决Ubuntu中解压zip文件(提取到此处)中文乱码问题

在Ubuntu系统下,解压zip文件时,使用右键--提取到此处,得到的文件内部文件名中文出现乱码。 导致此问题出现的原因一般为未下载相应的字体。 解决方案: 在终端中使用unar命令。 需要注意的是系统需要包含unar命令,如果没有,采用如下的方式解决: 实例效果展示: 直接提取到此处: 使用 unar filename.zip得到的文件...

centos7安装mysql8.0.20单机版详细教程

mysql8.0之后与5.7存在着很大的差异,这些差异不仅仅表现在功能和性能上,还表现在基础操作和设置上。这给一些熟悉mysql5.7的小伙伴带来了很多困扰,下面我们就来详细介绍下8.0的安装和配置过程。 mysql在linux上的多种安装方式: 1.yum安装 由于centos默认的yum源中没有mysql,所以我们要使用yum安装mysql就必须自己指定mysql的yum源。在官网下载mysq...