python实现提取COCO,VOC数据集中特定的类

一、前言

在深度学习中,数据集是非常重要的资源之一,但是我们有时需要从一个大的数据集中提取出特定的类别,这样可以让我们在模型训练、测试或者其他操作上更加方便。本文将介绍如何使用Python代码从COCO、VOC数据集中提取特定的类。

二、准备工作

在进行以下操作前,需要下载并解压相应的数据集,以COCO2017数据集为例,可以在官方网站(http://cocodataset.org/#download)下载并解压到合适的目录中。

安装必要的Python库:

# 安装必要的Python库
!pip install numpy matplotlib scipy scikit-image pillow

三、COCO数据集

  1. 导入必要的库
import json
import os
import shutil
from pycocotools.coco import COCO
  1. 要求输入的参数
# 要求输入的参数
dataDir = './coco2017'
dataType = 'train2017'
annFile = os.path.join(dataDir, 'annotations', 'instances_{}.json'.format(dataType))
catIds = [1, 2, 3, 4, 6, 7, 8, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
catNms = ['person', 'bicycle', 'car', 'motorcycle', 'bus',
          'train', 'truck', 'traffic light', 'stop sign', 'fire hydrant',
          'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
          'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
          'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
          'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass',
          'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
          'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
          'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror',
          'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop',
          'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
          'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase',
          'scissors', 'teddy bear', 'hair drier', 'toothbrush']
  1. 定义函数提取COCO数据集中特定的类
def get_coco_category(category_list):
    """
    :param category_list: 需要提取的类别列表
    :return: None
    """
    # 初始化COCO API
    coco = COCO(annFile)

    # 获取包含需要提取类别的image ID
    catIds = coco.getCatIds(catNms=category_list)
    imgIds = coco.getImgIds(catIds=catIds)

    for imgId in imgIds:
        imgInfo = coco.loadImgs(imgId)[0]
        imgName = imgInfo['file_name']
        srcImgPath = os.path.join(dataDir, dataType, imgName)
        dstImgPath = os.path.join(outputPath, imgName)
        shutil.copy(srcImgPath, dstImgPath)

        annIds = coco.getAnnIds(imgIds=imgId, catIds=catIds, iscrowd=None)
        anns = coco.loadAnns(annIds)
        dstAnnPath = os.path.join(annOutputPath, imgName.replace('.jpg', '.json'))
        with open(dstAnnPath, 'w') as wf:
            json.dump(anns, wf)
  1. 示例
# 需要提取的类别列表
categories = ['person', 'dog', 'cat']

# 输出路径
outputPath = './coco_person_cat_dog'
os.makedirs(outputPath, exist_ok=True)

# 注释输出路径
annOutputPath = os.path.join(outputPath, 'annotations')
os.makedirs(annOutputPath, exist_ok=True)

get_coco_category(categories)

五、VOC数据集

  1. 导入必要的库
import os
import shutil
from lxml import etree
from lxml.etree import Element, SubElement
  1. 要求输入的参数
# 要求输入的参数
dataDir = './VOCdevkit'
dataType = 'VOC2012'
categories = ['person', 'dog', 'cat']
  1. 定义相关函数
def get_voc_category(category_list):
    """
    :param category_list: 需要提取的类别列表
    :return: None
    """
    # 原始标注文件夹路径
    annDir = os.path.join(dataDir, dataType, 'Annotations')
    imgDir = os.path.join(dataDir, dataType, 'JPEGImages')

    # 新标注文件夹路径
    newAnnDir = os.path.join(outputPath, 'Annotations')
    newImgDir = os.path.join(outputPath, 'JPEGImages')

    os.makedirs(newAnnDir, exist_ok=True)
    os.makedirs(newImgDir, exist_ok=True)

    for fileName in os.listdir(annDir):
        annFilePath = os.path.join(annDir, fileName)
        xmlTree = etree.parse(annFilePath)
        xmlRoot = xmlTree.getroot()
        objects = xmlRoot.findall('object')
        findFlag = False
        for object in objects:
            name = object.find('name').text.lower().strip()
            # 筛选需要的类别
            if name in category_list:
                findFlag = True
                break
        if findFlag:
            imageName = fileName.replace('.xml', '.jpg')
            shutil.copy(os.path.join(imgDir, imageName), os.path.join(newImgDir, imageName))
            shutil.copy(annFilePath, os.path.join(newAnnDir, fileName))


def create_new_voc_annotations():
    """
    :return: None
    """
    # 原始标注文件夹路径
    annDir = os.path.join(dataDir, dataType, 'Annotations')
    imgDir = os.path.join(dataDir, dataType, 'JPEGImages')

    # 新标注文件夹路径
    newAnnDir = os.path.join(outputPath, 'Annotations')
    os.makedirs(newAnnDir, exist_ok=True)

    for fileName in os.listdir(annDir):
        annFilePath = os.path.join(annDir, fileName)
        xmlTree = etree.parse(annFilePath)
        xmlRoot = xmlTree.getroot()
        objects = xmlRoot.findall('object')
        newObjects = []
        for object in objects:
            name = object.find('name').text.lower().strip()
            # 筛选需要的类别
            if name in categories:
                bndBox = object.find('bndbox')
                newObject = Element('object')
                newName = SubElement(newObject, 'name')
                newName.text = name
                newBndBox = SubElement(newObject, 'bndbox')

                newBndBoxXmin = SubElement(newBndBox, 'xmin')
                newBndBoxYmin = SubElement(newBndBox, 'ymin')
                newBndBoxXmax = SubElement(newBndBox, 'xmax')
                newBndBoxYmax = SubElement(newBndBox, 'ymax')

                newBndBoxXmin.text = bndBox.find('xmin').text
                newBndBoxYmin.text = bndBox.find('ymin').text
                newBndBoxXmax.text = bndBox.find('xmax').text
                newBndBoxYmax.text = bndBox.find('ymax').text

                newObjects.append(newObject)

        if len(newObjects) > 0:
            newXmlRoot = Element('annotation')
            newObjects = [SubElement(newXmlRoot, obj) for obj in newObjects]
            newTree = etree.ElementTree(newXmlRoot)
            newTree.write(os.path.join(newAnnDir, fileName), pretty_print=True)
  1. 示例
# 输出路径
outputPath = './VOC_person_cat_dog'
os.makedirs(outputPath, exist_ok=True)

get_voc_category(categories)
create_new_voc_annotations()

以上是提取COCO、VOC数据集中特定的类的完整攻略,示例代码已经给出,可以根据实际需要进行修改。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python实现提取COCO,VOC数据集中特定的类 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • python微信跳一跳系列之棋子定位颜色识别

    下面是“Python微信跳一跳系列之棋子定位颜色识别”的完整攻略。 前言 本攻略是关于使用Python实现微信跳一跳自动玩游戏的系列文章之一,主要介绍棋子定位和颜色识别的方法,用于辅助自动玩游戏。 棋子定位 在跳一跳游戏中,我们利用手机截图并导入电脑后,需要先找到当前界面中棋子所在的位置,从而计算出距离和方向。因此,在Python中需要实现棋子的定位操作。 …

    python 2023年6月6日
    00
  • python使用socket高效传输视频数据帧(连续发送图片)

    下面我将为您详细讲解“python使用socket高效传输视频数据帧(连续发送图片)”的完整实例教程,包括示例说明: 1. 简介 在本教程中,我们将使用Python中的socket库实现高效的视频数据帧传输,特别是连续发送图片。实现这种数据流的目标是传输即时视频,并尽可能地减小延迟。 2. 实现 2.1 导入库 我们首先要导入需要的Python库: impo…

    python 2023年5月13日
    00
  • python多线程同步售票系统

    Python多线程同步售票系统 简介 在本系统中,我们将使用Python的多线程和线程同步技术,编写一个简单的售票系统。该系统包括两个主要模块:票务管理模块和售票模块。 票务管理模块 票务管理模块需要维护车票的总数(假设为100张)和已售出的票数。票务管理员可以通过该模块完成以下操作: 查询当前余票数量 查询已售票数量 增加车票数量 我们可以通过使用Pyth…

    python 2023年5月18日
    00
  • 无法在 Python 2.7 中为 ldap 设置 TIMEOUT

    【问题标题】:Unable to set TIMEOUT for ldap in Python 2.7无法在 Python 2.7 中为 ldap 设置 TIMEOUT 【发布时间】:2023-04-04 10:56:01 【问题描述】: 我想为 ldap 库 (python-ldap-2.4.15-2.el7.x86_64) 和 python 2.7 设置…

    Python开发 2023年4月6日
    00
  • python中各种路径设置的方法详解

    当我们在使用Python开发时,常常需要处理文件或者目录的路径,正确地设置和使用路径是保证程序正常运行的重要基础。本篇攻略将介绍Python中各种路径设置的方法,包括绝对路径、相对路径、os模块、os.path模块及Pathlib库。 绝对路径与相对路径 路径分为绝对路径和相对路径。绝对路径是从根目录开始的完整路径,比如在Windows操作系统中,绝对路径通…

    python 2023年6月2日
    00
  • python连接打印机实现打印文档、图片、pdf文件等功能

    下面我将为您讲解如何使用 Python 连接打印机,实现打印文档、图片、pdf 文件等功能的完整攻略。整个过程包含以下几个步骤: 确定打印机类型 安装打印机驱动程序 安装 Python 插件 编写 Python 程序 执行 Python 程序 下面我将一步一步为您详细讲解如何实现每一步。 1. 确定打印机类型 首先需要确定使用的打印机类型。对于本地打印机,可…

    python 2023年5月23日
    00
  • PyQt5编程扩展之资源文件的使用教程

    我来为您详细讲解“PyQt5编程扩展之资源文件的使用教程”吧。 什么是资源文件 在PyQt5中,资源文件是一种用于存储应用程序中的图像、音频文件和其它资源的文件。资源文件通常以.qrc为扩展名,其中.qrc是XML格式的文件。它允许我们把应用程序中的资源打包成一个文件,这样就可以方便地管理和访问它们。 资源文件的使用 1. 使用工具生成.qrc文件 我们可以…

    python 2023年6月5日
    00
  • python自动化操作之动态验证码、滑动验证码的降噪和识别

    Python自动化操作之动态验证码、滑动验证码的降噪和识别 什么是动态验证码和滑动验证码? 动态验证码和滑动验证码是常见的防止自动化操作的方式。动态验证码是指,验证码在输入之前会动态地改变,比如验证码的旋转角度、字体颜色等。滑动验证码是指,用户需要将图片中的某一个小块通过拖动的方式移动到正确的位置才能够通过验证。 如何降噪和识别动态验证码和滑动验证码? 1.…

    python 2023年6月6日
    00
合作推广
合作推广
分享本页
返回顶部