데이터셋

CIFAR Dataset

cifar dataset

dataset/cifar.py
  1import tarfile
  2import pickle
  3import requests
  4import os.path as osp
  5import numpy as np
  6from string import Template
  7
  8cifar_base_url = "https://www.cs.toronto.edu/~kriz/"
  9dataset_dir = osp.dirname(osp.abspath(__file__))
 10save_file_template = Template(osp.join(dataset_dir, "cifar-$nr_classes.pkl"))
 11data_file_template = Template('cifar-$nr_classes-python.tar.gz')
 12
 13
 14def _download(file_name):
 15    file_path = dataset_dir + "/" + file_name
 16
 17    if osp.exists(file_path):
 18        return
 19
 20    print("Downloading " + file_name + " ... ")
 21    url = cifar_base_url + file_name
 22    r = requests.get(url=url, stream=True)
 23    with open(file_path, 'wb') as fd:
 24        for chunk in r.iter_content(chunk_size=8192):
 25            fd.write(chunk)
 26    print("Done")
 27
 28
 29def _read_dataset(nr_classes):
 30    file_name = data_file_template.substitute(nr_classes=nr_classes)
 31    file_path = osp.join(dataset_dir, file_name)
 32    print("Extracting " + file_name + " ...")
 33    with tarfile.open(file_path, 'r:gz') as tar:
 34        train_data = []
 35        train_label = []
 36        test_data = None
 37        test_label = None
 38
 39        for tar_info in tar:
 40            if tar_info.isfile():
 41                tar_info_name = tar_info.name.split('/')[-1]
 42                if "data" in tar_info_name or "train" in tar_info_name:
 43                    extracted = tar.extractfile(tar_info)
 44                    raw_data, raw_label = _read_cfar_data(
 45                        extracted, nr_classes)
 46                    train_data.append(raw_data)
 47                    train_label.extend(raw_label)
 48                elif "meta" in tar_info_name:
 49                    extracted = tar.extractfile(tar_info)
 50                    meta = _read_cfar_meta(extracted, nr_classes)
 51                elif "test" in tar_info_name:
 52                    extracted = tar.extractfile(tar_info)
 53                    test_data, test_label = _read_cfar_data(
 54                        extracted, nr_classes)
 55    print("Done!")
 56    train_img = np.vstack(train_data) if nr_classes == 10 else train_data[0]
 57    train_label = np.array(train_label)
 58    test_img = test_data
 59    test_label = np.array(test_label)
 60    dataset = {'train_img': train_img, 'train_label': train_label,
 61               'test_img': test_img, 'test_label': test_label, 'classes': meta}
 62    return dataset
 63
 64
 65def _read_cfar_data(tar_extracted, nr_classes):
 66    raw_dict = pickle.load(tar_extracted, encoding='latin1')
 67    raw_data = raw_dict['data']
 68    # raw_data = raw_data.reshape(raw_data.shape[0], 3, 32, 32)
 69    label = raw_dict['labels' if nr_classes == 10 else 'fine_labels']
 70    return raw_data, label
 71
 72
 73def _read_cfar_meta(tar_extracted, nr_classes):
 74    meta = pickle.load(tar_extracted, encoding='latin1')
 75    return meta['label_names'] if nr_classes == 10 else meta['fine_label_names']
 76
 77
 78def download_cifar(nr_classes=10):
 79    assert nr_classes in (10, 100)
 80
 81    data_file = data_file_template.substitute(nr_classes=nr_classes)
 82    _download(data_file)
 83
 84
 85def init_cifar(nr_classes=10):
 86    """다운로드하고 학습데이터와 시험데이터로 분리해 피클로 저장"""
 87
 88    download_cifar(nr_classes)
 89    dataset = _read_dataset(nr_classes)
 90    print("Creating pickle file ...")
 91    with open(save_file_template.substitute(nr_classes=nr_classes), 'wb') as f:
 92        pickle.dump(dataset, f, -1)
 93    print("Done!")
 94
 95
 96def _change_one_hot_label(X):
 97    T = np.zeros((X.size, 10))
 98    for idx, row in enumerate(T):
 99        row[X[idx]] = 1
100
101    return T
102
103
def load_cifar(normalize=True, flatten=True, one_hot_label=False, nr_classes=10):
    """Load the CIFAR dataset, building the local pickle cache if needed.

    Parameters
    ----------
    normalize : if True, pixel values are converted to float32 and divided
        by 255.0, giving values between 0.0 and 1.0.
    flatten : if True, each image is kept as a 1-D array as stored in the
        cache; if False, images are reshaped to (3, 32, 32).
    one_hot_label :
        If True, labels are returned as one-hot arrays, i.e. arrays in which
        exactly one element is 1, such as [0,0,1,0,0,0,0,0,0,0].
    nr_classes : 10 for CIFAR-10, 100 for CIFAR-100.

    Returns
    -------
    (train images, train labels), (test images, test labels), label names
    """
    save_file = save_file_template.substitute(nr_classes=nr_classes)
    if not osp.exists(save_file):
        # Build the cache for the requested variant.  Calling init_cifar()
        # without nr_classes always built CIFAR-10, so CIFAR-100 then failed
        # to open its (still missing) pickle.
        init_cifar(nr_classes)

    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 3, 32, 32)

    return ((dataset['train_img'], dataset['train_label']),
            (dataset['test_img'], dataset['test_label']),
            dataset['classes'])
141
142
if __name__ == '__main__':
    # Usage: python cifar.py [10|100]  -- builds the CIFAR-10 cache by default.
    import sys
    init_cifar(int(sys.argv[1]) if len(sys.argv) > 1 else 10)

연습문제

  1. CIFAR-10 데이터에서 무작위로 이미지 100개를 골라 10×10 격자로 그려 보세요. 그리고 각 이미지 아래에 해당 레이블(클래스 이름)을 표시해 보세요.

  2. CIFAR-10 데이터로 SimpleConvNet을 학습시키고, 그 결과(예: 시험 데이터 정확도)에 대해 토론해 보세요.