Preparing CIFAR Image Data for PyTorch — Visual Studio Magazine
The Data Science Lab
Preparing CIFAR image data for PyTorch
CIFAR-10 problems analyze raw 32 x 32 color images to predict which of 10 classes the image belongs to. Here, Dr. James McCaffrey of Microsoft Research explains how to get the raw source CIFAR-10 data, convert it from binary to text, and save it as a text file that can be used to train a PyTorch neural network classifier.
A common dataset for image classification experiments is CIFAR-10. The goal of a CIFAR-10 problem is to analyze a 32 x 32 raw color image and predict which of the 10 classes the image belongs to. The 10 classes are Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, and Truck.
CIFAR-10 (Canadian Institute for Advanced Research, 10 classes) data includes 50,000 images for training and 10,000 images for testing. This article explains how to get the raw CIFAR-10 source data, convert the binary data to text, and save the data as a text file that can be used to train a PyTorch neural network classifier.
The most popular neural network libraries, including PyTorch, scikit, and Keras, have some form of built-in CIFAR-10 dataset designed to work with the library. But there are two problems with using an integrated dataset. First, data access becomes a magic black box and important information is hidden. Second, the built-in datasets use all 50,000 training images and all 10,000 test images, which are awkward to work with during development and experimentation because they are so large.
A good way to see where this article is heading is to take a look at the screenshot of a Python program in Figure 1. The program loads a batch of 10,000 training images into memory. The first 5,000 images are converted from binary to text and then saved as “cifar10_train_5000.txt”. The program ends by opening the saved text file and displaying the first image, which is a scary frog with red eyes.
This article assumes that you have intermediate or better knowledge of a C-family programming language, preferably Python, but does not assume that you know anything about the CIFAR-10 dataset. The full source code for the demo program is shown in this article, and the code is also available in the companion file download.
Obtaining Source Data Files
Source data for CIFAR-10 can be found at www.cs.toronto.edu/~kriz/cifar.html. There are three versions of the data: Python pickle format, MATLAB data format, and raw binary format. In my opinion, the Python format is the easiest to use. If you click on the “CIFAR-10 Python version” link, you will download a file named cifar-10-python.tar.gz (a tar archive compressed with gzip) to your machine.
Unlike regular zip files, tar.gz files cannot be extracted by older versions of Windows File Explorer, so you may need a separate utility. I recommend the free 7-Zip utility, available at www.7-zip.org/. After installing 7-Zip, you can open Windows File Explorer, right-click on the cifar-10-python.tar.gz file and select the Extract Here option. This will produce a file named cifar-10-python.tar. If you right-click on this tar file and select the Extract Here option again, you will get an uncompressed root directory named cifar-10-batches-py.
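If you'd rather not install a separate utility, Python's standard tarfile module can perform both extraction steps in one call. The following is a minimal sketch, not part of the demo program; the function name extract_tar_gz is hypothetical, and the usage comment assumes the downloaded archive is in the current directory.

```python
import tarfile
from pathlib import Path


def extract_tar_gz(archive_path, dest_dir="."):
    """Extract a .tar.gz archive into dest_dir; return its top-level names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    # Mode "r:gz" reads the archive with gzip decompression, so the
    # .tar.gz -> .tar -> directory steps happen in a single pass.
    with tarfile.open(str(archive_path), "r:gz") as tar:
        top_level = sorted({m.name.split("/")[0] for m in tar.getmembers()})
        # Newer Python versions support extraction filters; "data" rejects
        # unsafe members such as absolute paths. Older versions lack it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=str(dest), filter="data")
        else:
            tar.extractall(path=str(dest))
    return top_level

# Usage (assumes cifar-10-python.tar.gz is in the current directory):
# extract_tar_gz("cifar-10-python.tar.gz")
```

The returned list is just a convenience for confirming which root directory the archive created.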
The cifar-10-batches-py directory contains six data files that have names without file extensions: data_batch_1, data_batch_2, data_batch_3, data_batch_4, data_batch_5, and test_batch, along with a batches.meta file and a readme.html file. Each of the six data files contains 10,000 images in the Python “pickle” binary format.
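The class names themselves are not stored in the data batches. As far as I know, batches.meta is also a pickled dictionary whose b'label_names' entry lists the 10 class names in label order, so a small helper can recover them. This is a sketch under that assumption; load_label_names is a hypothetical name.

```python
import pickle


def load_label_names(meta_path):
    """Return CIFAR-10 class names (list index == label) from batches.meta."""
    with open(meta_path, "rb") as fin:
        # The files were pickled under Python 2, so encoding="bytes" turns
        # the original str objects into bytes (keys become b'label_names').
        meta = pickle.load(fin, encoding="bytes")
    names = meta[b"label_names"]
    # Decode bytes to str so the names can be printed or used as plot titles.
    return [n.decode("utf-8") if isinstance(n, bytes) else n for n in names]

# Usage, assuming the archive was extracted into the current directory:
# names = load_label_names("./cifar-10-batches-py/batches.meta")
```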
Each image is 32 x 32 pixels. As the images are in color, there are three channels (red, green, blue). Each channel-pixel value is an integer between 0 and 255. Therefore, each image is represented by 32 * 32 * 3 = 3072 values between 0 and 255.
Convert images from binary to text
To convert CIFAR-10 images from pickle binary format to text, you need to write a short Python program. The complete program is presented in Listing 1.
Listing 1: Converting CIFAR-10 Images from Binary to Text
# unpickle_cifar10.py

import numpy as np
import pickle
import matplotlib.pyplot as plt

# -----------------------------------------------------------

print("\nBegin demo ")

print("\nLoading CIFAR-10 images into dict in memory ")
file = "./cifar-10-batches-py/data_batch_1"  # train
# file = "./cifar-10-batches-py/test_batch"  # test
with open(file, 'rb') as fin:
    dict = pickle.load(fin, encoding='bytes')
# keys: b'batch_label' b'labels' b'data' b'filenames'

labels = dict[b'labels']  # 10,000 labels
pixels = dict[b'data']    # 3,072 per image (1024 per channel)

n_images = 5000  # train
# n_images = 1000  # test

print("\nWriting " + str(n_images) + " images to text file ")
fn = "./cifar10_train_5000.txt"  # train
# fn = "./cifar10_test_1000.txt"  # test
fout = open(fn, 'w', encoding='utf-8')
for i in range(n_images):  # n images
    for j in range(3072):  # pixels
        val = pixels[i][j]
        fout.write(str(val) + ",")
    fout.write(str(labels[i]) + "\n")  # label ends the line
fout.close()
print("Done ")

print("\nDisplaying first image in saved file: ")
# max_rows=1 reads only the first line, giving a 1-D array of 3072 values;
# usecols skips the last value (the label)
data = np.loadtxt(fn, delimiter=",", usecols=range(0,3072),
    max_rows=1, dtype=np.int64)
# img = data.reshape(3,32,32).transpose([1, 2, 0])
pxls_R = data[0:1024].reshape(32,32)
pxls_G = data[1024:2048].reshape(32,32)
pxls_B = data[2048:3072].reshape(32,32)
img = np.dstack((pxls_R, pxls_G, pxls_B))  # depth-stack
plt.imshow(img)
plt.show()

print("\nEnd demo ")
The program is named unpickle_cifar10.py. The program starts by importing three modules:
import numpy as np
import pickle
import matplotlib.pyplot as plt
The pickle module is part of the Python standard library. NumPy and matplotlib are not, but distributions such as Anaconda include them by default. If your distribution does not have them, you can install them using the pip package manager.
The program specifies the binary file to use as the source, then loads 10,000 images into memory:
file = "./cifar-10-batches-py/data_batch_1"  # train
# file = "./cifar-10-batches-py/test_batch"  # test
with open(file, 'rb') as fin:
    dict = pickle.load(fin, encoding='bytes')
# keys: b'batch_label' b'labels' b'data' b'filenames'
The pickle.load() function stores the 10,000 images in the specified file, together with their class labels (0 to 9), in a Python dictionary object that has four keys: b'batch_label', b'labels', b'data' and b'filenames'. The important keys are b'data' (the pixel values) and b'labels' (the class labels as integers from 0 to 9). The leading b means the keys are bytes objects rather than Unicode strings. The encoding='bytes' argument produces them, and it is needed because the batch files were pickled with Python 2. The ‘rb’ argument passed to the open() function means “read binary”.
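Before converting anything, it can be worth checking that a batch loaded as expected. The sketch below wraps the loading code in a hypothetical load_batch() helper that converts the two important entries to NumPy arrays and verifies their shapes. The checks are my own additions, not part of the demo program.

```python
import pickle

import numpy as np


def load_batch(path):
    """Load one CIFAR-10 batch file and return (pixels, labels) as arrays."""
    with open(path, "rb") as fin:               # 'rb' = read binary
        d = pickle.load(fin, encoding="bytes")  # keys are bytes, e.g. b'data'
    pixels = np.asarray(d[b"data"], dtype=np.uint8)    # shape (n, 3072)
    labels = np.asarray(d[b"labels"], dtype=np.int64)  # shape (n,), 0 to 9
    # Sanity checks: 3072 values per image, one label per image, labels 0-9.
    if pixels.ndim != 2 or pixels.shape[1] != 3072:
        raise ValueError(f"unexpected pixel array shape {pixels.shape}")
    if len(labels) != len(pixels):
        raise ValueError("number of labels does not match number of images")
    if labels.size and (labels.min() < 0 or labels.max() > 9):
        raise ValueError("labels outside the range 0 to 9")
    return pixels, labels

# Usage, assuming the default extraction directory:
# pixels, labels = load_batch("./cifar-10-batches-py/data_batch_1")
```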
Next, the program extracts the pixel values and labels from the dictionary object and specifies how many images to save out of the 10,000:
labels = dict[b'labels']  # 10,000 labels
pixels = dict[b'data']    # 3,072 per image (1024 per channel)

n_images = 5000  # train
# n_images = 1000  # test
print("\nWriting " + str(n_images) + " images to text file ")
The program uses only 5,000 of the 10,000 images just to show how to limit the number of images. In most scenarios, you will save all 10,000 images from each batch.
The program loops through each image, writing its 3,072 pixel values and its class label to the specified text file:
fn = "./cifar10_train_5000.txt"  # train
# fn = "./cifar10_test_1000.txt"  # test
fout = open(fn, 'w', encoding='utf-8')
for i in range(n_images):  # n images
    for j in range(3072):  # pixels
        val = pixels[i][j]
        fout.write(str(val) + ",")
    fout.write(str(labels[i]) + "\n")
fout.close()
Each image is stored on one line of the destination text file. The first 32 x 32 = 1024 values are the red components of the image. The next 1024 values are the green components, then the next 1024 values are the blue components. The last value of each line is the class label from ‘0’ to ‘9’. Each value is delimited by commas. You can use a different delimiter, such as a tab, if you wish.
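Because of this layout, the position of any single channel-pixel value within a line can be computed directly. Within each channel, the 32 x 32 plane is stored row by row, so the offset is channel * 1024 + row * 32 + col. The helper below (flat_index is a hypothetical name) is a small sketch of that arithmetic.

```python
def flat_index(channel, row, col):
    """Position of one channel-pixel value within a 3072-value CIFAR-10 row.

    channel: 0 = red, 1 = green, 2 = blue; row and col: 0 to 31.
    """
    if channel not in (0, 1, 2) or not (0 <= row < 32 and 0 <= col < 32):
        raise ValueError("channel must be 0-2; row and col must be 0-31")
    return channel * 1024 + row * 32 + col  # 1024 = 32 * 32 values per channel

# Examples: flat_index(0, 0, 0) is 0, flat_index(1, 0, 0) is 1024, and
# flat_index(2, 31, 31) is 3071. In the text file, the label follows at 3072.
```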
If you want to save all 50,000 training images to a single file, you can repeat the conversion code for each of the five data_batch files and pass the ‘a+’ mode argument (append, creating the file if it doesn't exist) to the open() function, so that each batch is added to the end of the file.
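An alternative to repeating the program with ‘a+’ is to open the output file once and write every batch to it. The sketch below swaps the nested write loops for NumPy's savetxt() function. batches_to_text is a hypothetical helper, and the paths in the usage comment assume the default extraction directory.

```python
import pickle

import numpy as np


def batches_to_text(batch_paths, out_path, n_per_batch=None):
    """Write CIFAR-10 batches to one comma-delimited text file.

    Each output line holds 3072 pixel values (red, then green, then blue)
    followed by the class label, matching the layout used in the demo.
    n_per_batch optionally limits how many images are taken from each batch.
    Returns the number of lines written.
    """
    total = 0
    # Opening the file once in 'w' mode and writing every batch to the same
    # handle has the same effect as re-running the conversion with mode 'a+'.
    with open(out_path, "w", encoding="utf-8") as fout:
        for path in batch_paths:
            with open(path, "rb") as fin:
                d = pickle.load(fin, encoding="bytes")
            pixels = np.asarray(d[b"data"], dtype=np.int64)
            labels = np.asarray(d[b"labels"], dtype=np.int64)
            if n_per_batch is not None:
                pixels, labels = pixels[:n_per_batch], labels[:n_per_batch]
            rows = np.column_stack((pixels, labels))  # shape (n, 3073)
            np.savetxt(fout, rows, fmt="%d", delimiter=",")
            total += len(rows)
    return total

# Usage, assuming the default extraction directory:
# paths = [f"./cifar-10-batches-py/data_batch_{k}" for k in range(1, 6)]
# batches_to_text(paths, "./cifar10_train_50000.txt")
```

Writing whole arrays with savetxt() avoids 3,072 separate write() calls per image, at the cost of hiding the exact file layout that the explicit loop makes obvious.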
Viewing a CIFAR-10 image
After converting CIFAR-10 images from binary to text, it’s a good idea to review the text data to make sure the conversion worked correctly. The demo program displays the first saved image using these statements:
data = np.loadtxt(fn, delimiter=",", usecols=range(0,3072),
    max_rows=1, dtype=np.int64)  # first line only; skip the label
pxls_R = data[0:1024].reshape(32,32)
pxls_G = data[1024:2048].reshape(32,32)
pxls_B = data[2048:3072].reshape(32,32)
img = np.dstack((pxls_R, pxls_G, pxls_B))  # depth-stack
plt.imshow(img)
plt.show()
The text file containing the images is loaded into memory using the np.loadtxt() function. You only need the first 3,072 values in a line, but you could also read the class label to use as the image title. Each block of 1,024 channel values (red, green, blue) is reshaped into a 32 x 32 plane. The three planes are combined into a single image using the NumPy dstack() (“depth stack”) function.
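To use the class label as the image title, the whole line can be parsed, including its last value. The sketch below assumes the line layout produced by the demo (3,072 pixel values followed by the label). line_to_image and CLASS_NAMES are hypothetical names, and the class-name order follows the labels 0 to 9 listed on the CIFAR-10 web page.

```python
import numpy as np

# Class names in label order 0-9 (assumed from the CIFAR-10 web page).
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]


def line_to_image(line):
    """Convert one line of the text file to (32x32x3 uint8 image, int label)."""
    vals = np.array([int(v) for v in line.strip().split(",")], dtype=np.int64)
    if vals.size != 3073:
        raise ValueError(f"expected 3073 values, got {vals.size}")
    r = vals[0:1024].reshape(32, 32)     # red plane
    g = vals[1024:2048].reshape(32, 32)  # green plane
    b = vals[2048:3072].reshape(32, 32)  # blue plane
    img = np.dstack((r, g, b)).astype(np.uint8)  # shape (32, 32, 3)
    return img, int(vals[3072])          # last value is the class label

# Usage sketch (shows the first saved image with its class name as title):
# import matplotlib.pyplot as plt
# with open("./cifar10_train_5000.txt", encoding="utf-8") as f:
#     img, label = line_to_image(f.readline())
# plt.imshow(img); plt.title(CLASS_NAMES[label]); plt.show()
```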
The 3,072 pixel values are now in the 32 x 32 x 3 shape that the imshow() function expects. For RGB data, imshow() accepts either integers in the range 0 to 255 or floating-point values in the range 0.0 to 1.0, so there is no need to load the data as type np.float32 (instead of np.int64) and then divide all values by 255.
Instead of working with RGB channels separately, you can work with all three at once:
img = data.reshape(3,32,32).transpose([1, 2, 0])
This saves three lines of code but is harder to understand in my opinion.
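If you want to convince yourself that the two approaches produce the same image, a quick check on an arbitrary 3,072-value vector is enough. Both function names below are hypothetical. The second one slices off the first 3,072 values so that it also works if the label was loaded too.

```python
import numpy as np


def planes_to_image_dstack(data):
    """Channel-by-channel version, as in the demo program."""
    r = data[0:1024].reshape(32, 32)
    g = data[1024:2048].reshape(32, 32)
    b = data[2048:3072].reshape(32, 32)
    return np.dstack((r, g, b))


def planes_to_image_transpose(data):
    """One-line version: reshape to (3, 32, 32), then move channels last."""
    return data[0:3072].reshape(3, 32, 32).transpose([1, 2, 0])


# Any 3072-value vector works for the comparison; a simple ramp is used here.
sample = np.arange(3072, dtype=np.int64)
same = np.array_equal(planes_to_image_dstack(sample),
                      planes_to_image_transpose(sample))
print(same)  # True
```

The transpose() call returns a view with channels moved to the last axis rather than a copy, which imshow() handles without any extra work.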
Several years ago, classification of the MNIST (Modified National Institute of Standards and Technology) image dataset was considered a difficult challenge. MNIST images are single-channel (grayscale) handwritten digits that are only 28 x 28 pixels. Classification of MNIST data has become too easy, so the more difficult CIFAR-10 data is often used instead. With current techniques, it is relatively easy to achieve an accuracy of around 90% on CIFAR-10, but it is quite difficult to achieve an accuracy greater than 90%.
A related dataset you might come across is CIFAR-100. The CIFAR-100 dataset contains 60,000 images with 100 classes (600 images of each class). The 100 classes are items like “apple” (0), “bicycle” (8), “turtle” (93), and “worm” (99).
Dr. James McCaffrey works for Microsoft Research in Redmond, Washington. He has worked on several Microsoft products including Azure and Bing. James can be reached at [email protected].