In classification tasks, the classes are often not equally represented. How do you deal with this class imbalance? There are various techniques you can use, and one of them is setting class weights. In this tutorial, we discuss how to set a weight for each individual class, so that the minority class receives a weight proportional to its underrepresentation.

Dataset

Let’s first create the problem dataset. For now, we only try to identify one kind of image from CIFAR10, for example, the dog. This “dog-detector” is an example of a binary classifier, capable of distinguishing between just two classes, dog and not-dog. Note that the code below encodes dog as 0 and not-dog as 1. Let’s create the target vectors for this classification task:

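import numpy as np
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# CIFAR10 label 5 is "dog"; encode dog as 0 and everything else as 1
y_train_dog = np.where(y_train.flatten() == 5, 0, 1)
y_test_dog = np.where(y_test.flatten() == 5, 0, 1)

# Count how many training samples fall into each class
unique, counts = np.unique(y_train_dog, return_counts=True)
dict(zip(unique, counts))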

Create a Model

Here, we create a simple model for binary classification in TensorFlow Keras.

model=tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32,(3,3),padding='same',input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Activation('relu'))

model.add(tf.keras.layers.Conv2D(32,(3,3)))
model.add(tf.keras.layers.Activation('relu'))

model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))


model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512))
model.add(tf.keras.layers.Activation('relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(2, activation='softmax')) 

model.compile(loss='sparse_categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

Train and Evaluate model

Evaluating a classifier is considerably trickier when the classes are imbalanced. A simple way to evaluate a model is to use its accuracy.

BATCH_SIZE = 32  # assumed value; the original does not define it

model.fit(x_train, y_train_dog,
          batch_size=BATCH_SIZE,
          epochs=3,
          validation_data=(x_test, y_test_dog),
          shuffle=True)

model.evaluate(x_test, y_test_dog, verbose=1)

It has over 90% accuracy! This is simply because only about 10% of the images are dogs, so if you always guess that an image is not a dog, you will be right about 90% of the time.

This demonstrates why accuracy is generally not the preferred performance measure for classifiers, especially when some classes are much more frequent than others.
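To make the point concrete, here is a minimal sketch in plain Python of a "classifier" that always answers not-dog. The toy labels and the helper functions `accuracy` and `recall` are illustrative assumptions, not part of the tutorial's model; the label encoding (0 = dog, 1 = not-dog) follows the one used above.

```python
# Hypothetical toy labels: 0 = dog (minority), 1 = not-dog (majority).
y_true = [0] * 10 + [1] * 90

# A "classifier" that always predicts the majority class (not-dog).
y_pred = [1] * len(y_true)


def accuracy(y_true, y_pred):
    # Fraction of samples whose prediction matches the true label.
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def recall(y_true, y_pred, positive):
    # Fraction of actual `positive` samples that were predicted as `positive`.
    preds_for_positive = [p for t, p in zip(y_true, y_pred) if t == positive]
    return sum(p == positive for p in preds_for_positive) / len(preds_for_positive)


baseline_accuracy = accuracy(y_true, y_pred)
dog_recall = recall(y_true, y_pred, positive=0)
print(baseline_accuracy, dog_recall)
```

The accuracy looks good, yet the recall for the dog class shows that not a single dog was found, which is the failure that accuracy alone hides.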

Set Class Weight

You can set a weight for every class when the dataset is imbalanced. Let’s say you have 5000 samples of class dog and 45000 samples of class not-dog; then you pass class_weight = {0: 5, 1: 0.5}. That gives class “dog” 10 times the weight of class “not-dog”, which means that the loss function assigns a higher cost to errors on these instances.

The loss becomes a weighted average, where the weight of each sample is the class_weight value for that sample's class.
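The effect of the weights on the loss can be sketched in plain Python. The batch below is made up, and dividing by the number of samples is an assumption for illustration; how the weighted terms are normalized is a framework detail that this sketch does not claim to reproduce exactly.

```python
import math

# Class weights from the example above: 0 = dog, 1 = not-dog.
class_weight = {0: 5.0, 1: 0.5}

# Hypothetical batch: true labels, and the probability the model
# assigned to the true class of each sample.
y_true = [0, 1, 1, 1]
p_true = [0.5, 0.9, 0.8, 0.9]

# Per-sample cross-entropy is -log(probability of the true class).
per_sample_loss = [-math.log(p) for p in p_true]

# Each sample gets the weight of its class.
sample_weight = [class_weight[y] for y in y_true]
weighted = [w * l for w, l in zip(sample_weight, per_sample_loss)]

# Assumption: normalize by the number of samples in the batch.
unweighted_loss = sum(per_sample_loss) / len(per_sample_loss)
weighted_loss = sum(weighted) / len(weighted)
print(unweighted_loss, weighted_loss)
```

With these numbers, the single dog sample accounts for most of the weighted loss, so the optimizer is pushed much harder to get dogs right.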

Calculate Class Weight

You can calculate class weights programmatically using scikit-learn's sklearn.utils.class_weight.compute_class_weight().

from sklearn.utils import class_weight

# Recent scikit-learn versions require keyword arguments here
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(y_train_dog),
                                                  y=y_train_dog)

It looks at the distribution of labels and produces weights so that under- and over-represented classes contribute equally to the loss on the training set.
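For reference, the 'balanced' heuristic documented by scikit-learn is n_samples / (n_classes * count_of_class). The sketch below re-implements that formula in plain Python so the numbers are easy to check by hand; the helper name `balanced_class_weights` is hypothetical and is not a scikit-learn function.

```python
from collections import Counter


def balanced_class_weights(labels):
    # weight(c) = n_samples / (n_classes * number of samples in class c)
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count)
            for cls, count in sorted(counts.items())}


# Same proportions as the dog / not-dog training set: 0 = dog, 1 = not-dog.
labels = [0] * 5000 + [1] * 45000
weights = balanced_class_weights(labels)
print(weights)
```

For this split the dog class gets a weight of 5 and the not-dog class a weight of 5/9, so the hand-picked {0: 5, 1: 0.5} from the previous section is a rounded version of the same idea.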

compute_class_weight returns an array, so we need to convert it to a dict in order to use it with Keras.

class_weights = dict(enumerate(class_weights))
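One caveat worth sketching: enumerate keys the weights by their position in the array, which only matches the class labels when those labels are 0, 1, 2, and so on. The values below are hypothetical and just illustrate the difference between the two ways of building the dictionary.

```python
# Hypothetical labels that do not start at 0.
classes = [3, 5]
# Hypothetical weights, in the same order as `classes`.
weights = [0.75, 1.5]

by_position = dict(enumerate(weights))   # keys are array positions
by_label = dict(zip(classes, weights))   # keys are the actual class labels
print(by_position, by_label)
```

In this tutorial the labels are 0 and 1, so both approaches produce the same dictionary.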

Train Model with Class Weight

The class_weight parameter of the fit() function is a dictionary mapping each class index to a weight value. Pass this dictionary to model.fit().

model.fit(x_train, y_train_dog,
          batch_size=BATCH_SIZE,
          epochs=3,
          class_weight=class_weights,
          validation_data=(x_test, y_test_dog),
          shuffle=True)
