一、前言
在深度学习中,数据集是非常重要的资源之一,但是我们有时需要从一个大的数据集中提取出特定的类别,这样可以让我们在模型训练、测试或者其他操作上更加方便。本文将介绍如何使用Python代码从COCO、VOC数据集中提取特定的类。
二、准备工作
在进行以下操作前,需要下载并解压相应的数据集,以COCO2017数据集为例,可以在官方网站(http://cocodataset.org/#download)下载并解压到合适的目录中。
安装必要的Python库:
# 安装必要的Python库
!pip install numpy matplotlib scipy scikit-image pillow pycocotools lxml
三、COCO数据集
- 导入必要的库
import json
import os
import shutil
from pycocotools.coco import COCO
- 要求输入的参数
# Required input parameters for the COCO section.
# dataDir is the dataset root; dataType selects both the image sub-folder
# and the instances_<dataType>.json annotation file under dataDir/annotations.
dataDir = './coco2017'
dataType = 'train2017'
annFile = os.path.join(dataDir, 'annotations', 'instances_{}.json'.format(dataType))
# NOTE(review): in the visible code get_coco_category rebinds `catIds` locally
# and nothing reads `catNms`, so these two lists look like reference-only data.
# They also do not appear to have the same number of entries, so they should
# not be zipped into an id <-> name mapping without checking - TODO confirm.
catIds = [1, 2, 3, 4, 6, 7, 8, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
catNms = ['person', 'bicycle', 'car', 'motorcycle', 'bus',
'train', 'truck', 'traffic light', 'stop sign', 'fire hydrant',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror',
'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush']
- 定义函数提取COCO数据集中特定的类
def get_coco_category(category_list):
    """Copy every COCO image that contains ANY of the requested categories.

    For each matching image the picture is copied from ``dataDir/dataType``
    into ``outputPath``, and the annotations that belong to the requested
    categories are dumped to ``annOutputPath/<image stem>.json``.

    Reads the module-level names ``annFile``, ``dataDir``, ``dataType``,
    ``outputPath`` and ``annOutputPath``; both output directories must
    already exist.

    :param category_list: list of category names to extract
    :return: None
    """
    # Nothing requested -> nothing to extract. (An empty name filter would
    # otherwise make getCatIds return every category.)
    if not category_list:
        return

    # Initialise the COCO API from the annotation file.
    coco = COCO(annFile)
    catIds = coco.getCatIds(catNms=category_list)

    # getImgIds(catIds=[a, b, c]) keeps only images containing ALL of the
    # given categories. The goal here is images containing any of them, so
    # query one category at a time and take the union.
    imgIds = set()
    for catId in catIds:
        imgIds.update(coco.getImgIds(catIds=[catId]))

    # Sorted only to make the processing order deterministic.
    for imgId in sorted(imgIds):
        imgInfo = coco.loadImgs(imgId)[0]
        imgName = imgInfo['file_name']

        srcImgPath = os.path.join(dataDir, dataType, imgName)
        dstImgPath = os.path.join(outputPath, imgName)
        shutil.copy(srcImgPath, dstImgPath)

        # Keep only the annotations of the requested categories.
        annIds = coco.getAnnIds(imgIds=imgId, catIds=catIds, iscrowd=None)
        anns = coco.loadAnns(annIds)

        # splitext instead of str.replace so a non-.jpg image never ends up
        # with an annotation file that keeps the image extension.
        annName = os.path.splitext(imgName)[0] + '.json'
        dstAnnPath = os.path.join(annOutputPath, annName)
        with open(dstAnnPath, 'w') as wf:
            json.dump(anns, wf)
- 示例
# Categories to extract.
categories = ['person', 'dog', 'cat']
# Output directory for the copied images (read as a global by get_coco_category).
outputPath = './coco_person_cat_dog'
os.makedirs(outputPath, exist_ok=True)
# Output directory for the per-image annotation JSON files
# (read as a global by get_coco_category).
annOutputPath = os.path.join(outputPath, 'annotations')
os.makedirs(annOutputPath, exist_ok=True)
get_coco_category(categories)
四、VOC数据集
- 导入必要的库
import os
import shutil
from lxml import etree
from lxml.etree import Element, SubElement
- 要求输入的参数
# Required input parameters for the VOC section.
# NOTE: dataDir / dataType rebind the names assigned in the COCO section;
# the VOC functions join them as dataDir/dataType/Annotations and
# dataDir/dataType/JPEGImages.
dataDir = './VOCdevkit'
dataType = 'VOC2012'
# Class names to keep; compared against the lower-cased <name> of each object.
categories = ['person', 'dog', 'cat']
- 定义相关函数
def get_voc_category(category_list):
    """Copy every VOC image (and its XML) that contains a requested class.

    Scans ``dataDir/dataType/Annotations``. When an annotation file has at
    least one ``<object>`` whose lower-cased, stripped ``<name>`` is in
    ``category_list``, the XML file and the matching ``.jpg`` from
    ``JPEGImages`` are copied unchanged into ``outputPath/Annotations`` and
    ``outputPath/JPEGImages``.

    Reads the module-level names ``dataDir``, ``dataType`` and ``outputPath``.

    :param category_list: list of class names to extract
    :return: None
    """
    # Source annotation / image folders.
    annDir = os.path.join(dataDir, dataType, 'Annotations')
    imgDir = os.path.join(dataDir, dataType, 'JPEGImages')
    # Destination annotation / image folders.
    newAnnDir = os.path.join(outputPath, 'Annotations')
    newImgDir = os.path.join(outputPath, 'JPEGImages')
    os.makedirs(newAnnDir, exist_ok=True)
    os.makedirs(newImgDir, exist_ok=True)

    wanted = set(category_list)

    for fileName in os.listdir(annDir):
        # Stray non-annotation files (e.g. .DS_Store) would make etree.parse raise.
        if not fileName.lower().endswith('.xml'):
            continue

        annFilePath = os.path.join(annDir, fileName)
        xmlRoot = etree.parse(annFilePath).getroot()

        # findtext with a default tolerates an <object> without a <name>.
        hasWanted = any(
            obj.findtext('name', default='').lower().strip() in wanted
            for obj in xmlRoot.findall('object')
        )
        if not hasWanted:
            continue

        imageName = os.path.splitext(fileName)[0] + '.jpg'
        shutil.copy(os.path.join(imgDir, imageName), os.path.join(newImgDir, imageName))
        shutil.copy(annFilePath, os.path.join(newAnnDir, fileName))
def create_new_voc_annotations(category_list=None):
    """Write reduced VOC annotation files that keep only the selected classes.

    For every XML file in ``dataDir/dataType/Annotations`` a new
    ``<annotation>`` document is built containing one ``<object>`` (with only
    ``<name>`` and ``<bndbox>``) per original object whose lower-cased,
    stripped name is selected. Files with no selected object are skipped. The
    result is written to ``outputPath/Annotations/<same file name>``, which
    replaces any file of that name already there.

    Reads the module-level names ``dataDir``, ``dataType``, ``outputPath`` and,
    when ``category_list`` is omitted, ``categories``.

    :param category_list: class names to keep; defaults to the module-level
        ``categories`` list
    :return: None
    """
    if category_list is None:
        category_list = categories
    wanted = set(category_list)

    # Source annotation folder.
    annDir = os.path.join(dataDir, dataType, 'Annotations')
    # Destination annotation folder.
    newAnnDir = os.path.join(outputPath, 'Annotations')
    os.makedirs(newAnnDir, exist_ok=True)

    for fileName in os.listdir(annDir):
        # Stray non-annotation files would make etree.parse raise.
        if not fileName.lower().endswith('.xml'):
            continue

        xmlRoot = etree.parse(os.path.join(annDir, fileName)).getroot()
        newXmlRoot = Element('annotation')

        for obj in xmlRoot.findall('object'):
            name = obj.findtext('name', default='').lower().strip()
            bndBox = obj.find('bndbox')
            # Skip unwanted classes and objects without a bounding box.
            if name not in wanted or bndBox is None:
                continue

            # Attach straight to the new root. SubElement() takes a tag
            # name, so passing an already-built Element to it raises TypeError.
            newObject = SubElement(newXmlRoot, 'object')
            SubElement(newObject, 'name').text = name
            newBndBox = SubElement(newObject, 'bndbox')
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                SubElement(newBndBox, tag).text = bndBox.findtext(tag)

        # Only write a file when at least one object was kept.
        if len(newXmlRoot) > 0:
            newTree = etree.ElementTree(newXmlRoot)
            newTree.write(os.path.join(newAnnDir, fileName), pretty_print=True)
- 示例
# Output root; both VOC functions read it as a global and create
# their sub-folders (Annotations, JPEGImages) underneath it.
outputPath = './VOC_person_cat_dog'
os.makedirs(outputPath, exist_ok=True)
# Step 1: copy images and original XML files that contain a selected class.
get_voc_category(categories)
# Step 2: build filtered annotations (selected classes only); this targets the
# same outputPath/Annotations directory and the same file names as step 1.
create_new_voc_annotations()
以上是提取COCO、VOC数据集中特定的类的完整攻略,示例代码已经给出,可以根据实际需要进行修改。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python实现提取COCO,VOC数据集中特定的类 - Python技术站