ガラクタ置き場

COCO データセットから適当にファイルを取得する

ガラクタ

SSDなどの動作確認に使えるかなぁ~?と思ってCOCO データセットから適当にいくつかの画像ファイルを取得するスクリプトを作ってみた。
結局使わなかったけど、せっかくなのでガラクタ置き場に置いておく。(^^ゞ

今回はリポジトリ作るほどではないので、ここにソース貼っとく。
解説するほどでもないので、使い方はソース見てチョ!← テヌキ…

ソース

'''
COCOデータセットから適当にファイルを取得する
'''

import sys
import os
import glob
import shutil
from argparse import ArgumentParser, SUPPRESS, RawTextHelpFormatter

import random
import urllib.request
import zipfile
import json

COCO_2014 = False       # COCO 2014のminivalデータセット を使用する場合はTrueにする

# パラメータ
if COCO_2014 :
    DOWNLOAD_URL = 'https://dl.dropboxusercontent.com/s/o43o90bna78omob/instances_minival2014.json.zip?dl=0'
    JSON_FILE    = 'instances_minival2014.json'
else :
    DOWNLOAD_URL = 'http://images.cocodataset.org/annotations/image_info_test2017.zip'
    JSON_FILE    = 'annotations/image_info_test2017.json'

# zipファイル展開のための一時ファイル
TEMP_ZIP     = 'temp.zip'

def build_argparser():
    parser = ArgumentParser(add_help=False, formatter_class=RawTextHelpFormatter)
    parser.add_argument('-h', '--help', action='help', default=SUPPRESS, 
                        help='Show this help message and exit.')
    parser.add_argument('--num', type=int, default=100,
                        help="ダウンロードするファイル数")
    parser.add_argument('--margin', type=int, default=30,
                        help="ダウンロードエラーのためのマージン数")
    parser.add_argument('--clean', action='store_true', 
                        help="ダウンロード済みのファイルを削除して終了します\n(ダウンロードは行いません)")
    return parser


# コマンドラインパーサの生成&解析
args = build_argparser().parse_args()

if args.clean :
    print("==== CLEAN!! ====")
    if os.path.isfile(TEMP_ZIP):
        os.remove(TEMP_ZIP)
    
    if os.path.isfile(JSON_FILE):
        if COCO_2014 :
            os.remove(JSON_FILE)
        else :
            shutil.rmtree(os.path.dirname(JSON_FILE), ignore_errors=True)
    
    for file in glob.glob('*.jpg'):
        if os.path.isfile(file):
            os.remove(file)
    
    sys.exit(0)

# ファイル数設定
num_files = args.num
num_margin = args.margin

# JSONファイルがなければダウンロードする
if not os.path.isfile(JSON_FILE) :
    # ダウンロードを実行
    print(f'{DOWNLOAD_URL} をダウンロード中...')
    urllib.request.urlretrieve(DOWNLOAD_URL, TEMP_ZIP)
    
    # 展開
    print(f'{TEMP_ZIP} を展開中...')
    with zipfile.ZipFile(TEMP_ZIP) as zf:
        zf.extractall()

    # テンポラリファイルの削除
    os.remove(TEMP_ZIP)

# JSONファイルの読み込み
print(f'{JSON_FILE}の読み込み中...')
with open(JSON_FILE, 'r') as f_json :
    data = json.load(f_json)

# データ個数(5000のハズ)
data_len = len(data["images"])
print(f'DATA_LENGTH = {data_len}')

# 内容をダンプ
# print(json.dumps(data["images"], indent=2))

# ダウンロード数のチェック
if num_files + num_margin > data_len :
    print('ERROR : ファイル数とマージンの合計がデータ個数を超えています')
    sys.exit(1)

# ダウンロードするインデックスを乱数で決定(エラー発生時のためにマージンを積んどく)
download_indexs = random.sample(range(data_len), num_files + num_margin) 

# ダウンロード完了個数
num_downloaded = 0
num_error = 0

# ダウンロードループ
for id in download_indexs :
    # ファイル名
    filename = data["images"][id]["file_name"]
    # URL
    if COCO_2014 :
        url      = data["images"][id]["url"]
    else :
        url      = data["images"][id]["coco_url"]
    # ダウンロード
    print(f'{filename}をダウンロード中...')
    print(f'ID : {id:3}\tNAME : {filename}\tURL : {url}')
    try : 
        urllib.request.urlretrieve(url, filename)
    except urllib.error.HTTPError as e :
        # エラー発生
        num_error += 1
        print(f'**** skip!! ****  {e}')
    else :
        # ダウンロード完了
        print('---- DONE!! ----')
        num_downloaded += 1
        # 所望の数に達したら終了
        if num_downloaded >= num_files :
            break

print(f'ダウンロード数 : {num_downloaded}')
print(f'エラー数       : {num_error}')
if num_error > num_margin :
    print(f'warning : エラー数がマージンを越えました')

今後、なんかに使えると良いなぁ~(笑)