데이터 셋
CIFAR Dataset
import os
import os.path as osp
import pickle
import tarfile
from string import Template

import numpy as np
import requests
7
# Base URL that the CIFAR archive file names are appended to.
cifar_base_url = "https://www.cs.toronto.edu/~kriz/"
# Directory containing this module; archives and pickle caches are stored here.
dataset_dir = osp.dirname(osp.abspath(__file__))
# Path of the pickle cache for a given class count (e.g. ".../cifar-10.pkl").
save_file_template = Template(osp.join(dataset_dir, "cifar-$nr_classes.pkl"))
# File name of the downloadable archive for a given class count.
data_file_template = Template('cifar-$nr_classes-python.tar.gz')
12
13
def _download(file_name):
    """Download ``file_name`` from ``cifar_base_url`` into ``dataset_dir``.

    Does nothing when the target file already exists.

    Parameters
    ----------
    file_name : name of the archive, appended to ``cifar_base_url`` to form
        the URL and joined to ``dataset_dir`` to form the local path.

    Raises
    ------
    requests.HTTPError : if the server answers with an error status.
    """
    file_path = osp.join(dataset_dir, file_name)

    if osp.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    url = cifar_base_url + file_name
    # Write to a temporary name first: the existence check above would
    # otherwise accept a truncated file left behind by an interrupted download.
    tmp_path = file_path + ".part"
    # The timeout bounds connect/read stalls, not the total transfer time.
    with requests.get(url=url, stream=True, timeout=30) as r:
        # Fail loudly instead of saving an HTML error page as the archive.
        r.raise_for_status()
        with open(tmp_path, 'wb') as fd:
            for chunk in r.iter_content(chunk_size=8192):
                fd.write(chunk)
    os.replace(tmp_path, file_path)
    print("Done")
27
28
def _read_dataset(nr_classes):
    """Read the downloaded CIFAR archive and collect its contents in a dict.

    Archive members are classified by their base file name: names containing
    "data" or "train" are training batches, names containing "meta" hold the
    class names, and names containing "test" hold the test batch.

    Parameters
    ----------
    nr_classes : class count substituted into ``data_file_template`` to pick
        the archive; also forwarded to the batch/meta readers.

    Returns
    -------
    dict with keys 'train_img', 'train_label', 'test_img', 'test_label' and
    'classes'. 'classes' is None when the archive has no meta member.
    """
    file_name = data_file_template.substitute(nr_classes=nr_classes)
    file_path = osp.join(dataset_dir, file_name)
    print("Extracting " + file_name + " ...")

    train_data = []
    train_label = []
    test_data = None
    test_label = None
    # Initialised so that a missing meta member does not raise
    # UnboundLocalError when the result dict is built.
    meta = None

    with tarfile.open(file_path, 'r:gz') as tar:
        for tar_info in tar:
            if not tar_info.isfile():
                continue
            tar_info_name = tar_info.name.split('/')[-1]
            # The order of these checks matters: training members are
            # matched first, then meta, then test.
            if "data" in tar_info_name or "train" in tar_info_name:
                extracted = tar.extractfile(tar_info)
                raw_data, raw_label = _read_cfar_data(extracted, nr_classes)
                train_data.append(raw_data)
                train_label.extend(raw_label)
            elif "meta" in tar_info_name:
                extracted = tar.extractfile(tar_info)
                meta = _read_cfar_meta(extracted, nr_classes)
            elif "test" in tar_info_name:
                extracted = tar.extractfile(tar_info)
                test_data, test_label = _read_cfar_data(extracted, nr_classes)
    print("Done!")

    # With 10 classes the training batches are stacked row-wise; otherwise
    # only the first training member is used.
    train_img = np.vstack(train_data) if nr_classes == 10 else train_data[0]
    train_label = np.array(train_label)
    test_img = test_data
    test_label = np.array(test_label)
    dataset = {'train_img': train_img, 'train_label': train_label,
               'test_img': test_img, 'test_label': test_label, 'classes': meta}
    return dataset
63
64
def _read_cfar_data(tar_extracted, nr_classes):
    """Unpickle one CIFAR batch and return its image data and labels.

    Parameters
    ----------
    tar_extracted : binary file-like object holding a pickled batch dict.
    nr_classes : when 10 the labels are read from key 'labels', otherwise
        from key 'fine_labels'.

    Returns
    -------
    (data, labels) : the dict's 'data' entry, returned as stored (not
        reshaped), and the selected label entry.
    """
    # NOTE(security): pickle.load can execute arbitrary code from the stream;
    # only use this on archives obtained from a trusted source.
    raw_dict = pickle.load(tar_extracted, encoding='latin1')
    label_key = 'labels' if nr_classes == 10 else 'fine_labels'
    return raw_dict['data'], raw_dict[label_key]
71
72
def _read_cfar_meta(tar_extracted, nr_classes):
    """Unpickle the CIFAR meta member and return the class names.

    Parameters
    ----------
    tar_extracted : binary file-like object holding the pickled meta dict.
    nr_classes : when 10 the names are read from key 'label_names',
        otherwise from key 'fine_label_names'.

    Returns
    -------
    The selected entry of the meta dict.
    """
    # NOTE(security): pickle.load can execute arbitrary code from the stream;
    # only use this on archives obtained from a trusted source.
    meta = pickle.load(tar_extracted, encoding='latin1')
    names_key = 'label_names' if nr_classes == 10 else 'fine_label_names'
    return meta[names_key]
76
77
def download_cifar(nr_classes=10):
    """Download the CIFAR python archive for the given class count.

    Parameters
    ----------
    nr_classes : 10 or 100; substituted into ``data_file_template`` to build
        the archive file name.

    Raises
    ------
    ValueError : if ``nr_classes`` is neither 10 nor 100.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if nr_classes not in (10, 100):
        raise ValueError(
            "nr_classes must be 10 or 100, got {!r}".format(nr_classes))

    data_file = data_file_template.substitute(nr_classes=nr_classes)
    _download(data_file)
83
84
def init_cifar(nr_classes=10):
    """Download the archive, split it into train/test data, and pickle it.

    The resulting dataset dict is written to the path produced by
    ``save_file_template`` for ``nr_classes``.

    Parameters
    ----------
    nr_classes : class count forwarded to ``download_cifar`` and
        ``_read_dataset`` and used to name the pickle file.
    """
    download_cifar(nr_classes)
    dataset = _read_dataset(nr_classes)
    print("Creating pickle file ...")
    save_file = save_file_template.substitute(nr_classes=nr_classes)
    with open(save_file, 'wb') as f:
        # Same protocol the original selected with the magic value -1.
        pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
    print("Done!")
94
95
def _change_one_hot_label(X, nr_classes=10):
    """Convert a 1-D array of integer labels into a one-hot matrix.

    Parameters
    ----------
    X : array of integer class indices.
    nr_classes : width of the one-hot rows (defaults to 10).

    Returns
    -------
    Float array of shape (number of labels, nr_classes) with a single 1 per
    row at the column given by the label.
    """
    labels = np.asarray(X).ravel().astype(np.intp, copy=False)
    T = np.zeros((labels.size, nr_classes))
    T[np.arange(labels.size), labels] = 1

    return T


def load_cifar(normalize=True, flatten=True, one_hot_label=False, nr_classes=10):
    """Load the CIFAR dataset, building the pickle cache first if needed.

    Parameters
    ----------
    normalize : if True, convert the images to float32 and scale pixel
        values into the range 0.0~1.0 by dividing by 255.
    flatten : if True, keep each image as a 1-D row as stored in the cache;
        if False, reshape the images to (N, 3, 32, 32).
    one_hot_label : if True, return labels as one-hot arrays, i.e. rows such
        as [0,0,1,0,0,0,0,0,0,0] in which only one element is 1.
    nr_classes : class count selecting which cached dataset to load; also
        the width of the one-hot labels.

    Returns
    -------
    (train images, train labels), (test images, test labels), label names
    """

    save_file = save_file_template.substitute(nr_classes=nr_classes)
    if not osp.exists(save_file):
        # Build the cache for the requested variant; calling init_cifar()
        # with its default would create the 10-class file and leave
        # `save_file` missing for any other class count.
        init_cifar(nr_classes)

    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        for key in ('train_label', 'test_label'):
            dataset[key] = _change_one_hot_label(dataset[key], nr_classes)

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 3, 32, 32)

    return (
        (dataset['train_img'], dataset['train_label']),
        (dataset['test_img'], dataset['test_label']),
        dataset['classes'],
    )
141
142
# When run as a script, download the default 10-class dataset and build its
# pickle cache.
if __name__ == '__main__':
    init_cifar()
연습문제
CIFAR-10 데이터 중에서 무작위로 100개를 선택하여 10x10 테이블로 그림을 그려보세요. 그리고 각 이미지의 레이블을 이미지 아래에 표시해 보세요.
CIFAR-10 데이터를 SimpleConvNet에 적용하여 결과를 토론해 보세요.