One way to prevent overfitting is to use more data. This exposes the model to more variations of the data and thus helps it generalize better. To get more data, you can either collect it manually or generate new data from the existing data by applying some transformations. The latter approach is known as data augmentation.
In this blog, we will learn how to perform data augmentation using the Keras ImageDataGenerator class. First, we will discuss the Keras image augmentation API and then learn how to use it.
Keras API
ImageDataGenerator(featurewise_center=False, samplewise_center=False,
                   featurewise_std_normalization=False, samplewise_std_normalization=False,
                   zca_whitening=False, zca_epsilon=1e-06, rotation_range=0,
                   width_shift_range=0.0, height_shift_range=0.0, brightness_range=None,
                   shear_range=0.0, zoom_range=0.0, channel_shift_range=0.0,
                   fill_mode='nearest', cval=0.0, horizontal_flip=False, vertical_flip=False,
                   rescale=None, preprocessing_function=None, data_format=None,
                   validation_split=0.0, dtype=None)
Let’s understand each of its arguments in detail using the following image
featurewise_center: Feature-wise refers to the entire dataset. Here, we first calculate the mean over the entire dataset and then subtract this mean from each image. This shifts the mean of the data distribution close to zero. To calculate the mean, you need to fit the data generator on the training data as
datagen = ImageDataGenerator(featurewise_center=True)
datagen.fit(x_train)
For this, you have to load the entire training dataset into memory, which may be infeasible if the dataset is large. To avoid this, you can calculate the mean from a smaller representative sample.
featurewise_std_normalization: In this, we divide each image by the standard deviation of the entire dataset. Thus, featurewise_center and featurewise_std_normalization together are known as standardization: they make the data have a mean of 0 and a standard deviation of 1.
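As a quick sketch (assuming x_train is a NumPy array of training images already loaded in memory), standardization can be set up as:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(featurewise_center=True,
                             featurewise_std_normalization=True)
# fit() computes the dataset-wide mean and standard deviation
datagen.fit(x_train)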
samplewise_center: Sample-wise refers to a single image. In this, we set the mean pixel value of each image to zero. Since the image mean is a local statistic that can be calculated from the image itself, there is no need to call the fit method.
samplewise_std_normalization: In this, we divide each input image by its standard deviation.
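For example, both sample-wise options can be combined; no fit call is needed because the statistics come from each image itself:

datagen = ImageDataGenerator(samplewise_center=True,
                             samplewise_std_normalization=True)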
zca_whitening: This is a preprocessing method that tries to remove the redundancy (correlation between neighbouring pixels) from the data while, unlike PCA whitening, keeping the spatial structure of the image intact. In effect, it emphasizes the high-frequency components of the image. For the maths behind it, refer to this StackOverflow question. You need to fit on the training data to calculate the principal components. It should be used with featurewise_center=True; otherwise Keras will give you a warning and automatically set featurewise_center=True.
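A minimal sketch (again assuming x_train is available in memory so that the principal components can be computed):

datagen = ImageDataGenerator(zca_whitening=True, featurewise_center=True)
# fit() computes the principal components used for whitening
datagen.fit(x_train)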
Note: For featurewise_center, featurewise_std_normalization, and zca_whitening, one must fit on the training data to calculate the mean, standard deviation, and principal components respectively.
rotation_range: This rotates each image by a random angle of up to the value specified, in either direction. The figure below shows rotations with rotation_range=45.
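For example, to rotate images by a random angle of up to 45 degrees:

datagen = ImageDataGenerator(rotation_range=45)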
width_shift_range: This shifts the image in the horizontal direction. It accepts three kinds of values:
- If it is a float less than 1, the image is shifted by a random fraction of the total width, up to that value. For instance, 0.2 means a horizontal shift of up to 20% of the image width in either direction.
- If it is an integer, the image is shifted horizontally by a whole number of pixels picked from the range (-value, value). For instance, 3 means the shift is picked from [-2, -1, 0, 1, 2], so the image may be shifted by 0, 1, or 2 pixels in either direction.
- If it is a 1-D array-like, the shift is picked randomly from the values in the array.
height_shift_range: Similar to width_shift_range but in the vertical direction.
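For example, with illustrative values, random shifts of up to 20% of the width and 10% of the height:

datagen = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.1)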
brightness_range: This produces images that look as if they were taken under different lighting conditions. You pass the minimum and maximum of a range from which a brightness factor is picked. Values < 1 darken the image, values > 1 brighten it, and 1 means no change. For example, the line below darkens the image as shown
datagen = ImageDataGenerator(brightness_range=[0.2,0.8])
rescale: This multiplies every pixel value by the given factor, which is useful for normalizing pixel values to a specific range. For an 8-bit image, we generally rescale by 1/255 so that the pixel values lie in the range [0, 1].
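For example, to rescale 8-bit images to the range [0, 1]:

datagen = ImageDataGenerator(rescale=1./255)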
shear_range: This is the shear angle in the counter-clockwise direction in degrees.
zoom_range: This zooms the image randomly. If passed as a float, the zoom factor is picked from [lower, upper] = [1-zoom_range, 1+zoom_range]. For instance, 0.2 means zooming in the range [0.8, 1.2]. You can also pass [lower, upper] as a list directly.
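For example, combining the two (illustrative values), a shear of up to 20 degrees and a zoom factor between 0.8 and 1.2:

datagen = ImageDataGenerator(shear_range=20, zoom_range=[0.8, 1.2])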
channel_shift_range: This randomly shifts the channel values by an amount drawn from the specified range. The code below sums up what it actually does.
[np.clip(x_channel + np.random.uniform(-value, value), min_img, max_img) for x_channel in img]
A random value is added to each channel, and the result is clipped to the minimum and maximum of the original image.
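To use it, you only pass the maximum shift; for instance, for 8-bit images a value of 50 shifts each channel by up to 50 intensity levels in either direction:

datagen = ImageDataGenerator(channel_shift_range=50.0)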
horizontal_flip and vertical_flip: Randomly flip the input image in the horizontal and vertical direction, respectively.
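For example, to enable both flips:

datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)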
data_format: Either channels_first or channels_last (default).
preprocessing_function: This function is applied to each input after all the other augmentation steps have run. Below is an example of one such function, which blurs the images
import cv2

def blur(img):
    return cv2.blur(img, (5, 5))

datagen = ImageDataGenerator(preprocessing_function=blur)
How to use this?
Below is the code that I used to generate the above images.
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.preprocessing.image import load_img, ImageDataGenerator, img_to_array

# Load the image, change it into an array and expand the dimensions
img = load_img('D:/downloads/opencv_logo.PNG')
img = img_to_array(img)
img1 = np.expand_dims(img, axis=0)

# Create an instance of the class with the desired operation
datagen = ImageDataGenerator(horizontal_flip=True)

# Depending on the augmentation method you may need to call
# the fit method to calculate the global statistics
data_generator = datagen.flow(img1, batch_size=1)

# Display some augmented samples
plt.figure(figsize=(10,5))
for i in range(6):
    plt.subplot(2,3,i+1)
    for x in data_generator:
        plt.imshow(x[0]/255.)
        plt.xticks([])
        plt.yticks([])
        break
plt.tight_layout()
plt.show()
This way you can create augmented examples. In the next blog, we will discuss how to generate batches of augmented data using the flow method.
Hope you enjoy reading.
If you have any doubt/suggestion please feel free to ask and I will do my best to help or improve myself. Good-bye until next time.