With the help of the above-mentioned GitHub issue, I managed to solve the problem for my particular use case. Although it is specific to my setup, I want to share the solution with you anyway. An extra hurdle was the fact that I am using a custom generator for my data. A simplified version of this class is shown in the following code:
import numpy as np
import keras
class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras.

    Serves batches of (image, one-hot mask) pairs, where every sample is
    read from the file ``'data/' + ID + '.npy'``.
    """

    def __init__(self, list_IDs, batch_size=2, dim=(144, 144, 144), n_classes=2):
        """Store the configuration and build the initial sample order.

        Args:
            list_IDs: Sample identifiers (strings); each one names a
                ``.npy`` file inside the ``data/`` directory.
            batch_size: Number of samples per batch.
            dim: Spatial shape of a single sample volume.
            n_classes: Number of classes used for the one-hot encoding
                of the mask.
        """
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        # Populates self.indexes, which __getitem__ relies on.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch.

        Only full batches are counted; a trailing partial batch is dropped.
        """
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def on_epoch_end(self):
        """Rebuild the sample order used by ``__getitem__``.

        The hook inherited from the base class does nothing, so without
        this override ``self.indexes`` would never exist and the first
        batch request would raise ``AttributeError``. Keras also calls
        this hook after every epoch.
        """
        self.indexes = np.arange(len(self.list_IDs))

    def __getitem__(self, index):
        """Generate one batch of data.

        Args:
            index: Zero-based position of the batch within the epoch.

        Returns:
            Tuple ``(x, y)`` with the images and the one-hot encoded masks.
        """
        # Positions (into list_IDs) of the samples that form this batch.
        start = index * self.batch_size
        indexes = self.indexes[start:start + self.batch_size]
        # Translate positions into sample IDs.
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_temp)

    def __data_generation(self, list_IDs_temp):
        """Load the given samples and assemble them into one batch.

        Args:
            list_IDs_temp: IDs of the samples that make up the batch.

        Returns:
            Tuple ``(x, y)``; ``x`` is allocated as
            ``(batch_size, *dim, 1)`` and ``y`` is the mask after one-hot
            encoding with ``n_classes`` classes.
        """
        # Buffers are uninitialised; every row is overwritten in the loop
        # because __len__ only ever yields full batches.
        x = np.empty((self.batch_size, *self.dim, 1))
        y = np.empty((self.batch_size, *self.dim, 1))

        for i, ID in enumerate(list_IDs_temp):
            data = np.load('data/' + ID + '.npy')
            # NOTE(review): the assignments below need the slices to
            # broadcast into shape (*dim, 1) -- confirm the layout of the
            # stored arrays.
            x[i,] = data[:, :, :, 0]  # Image
            y[i,] = data[:, :, :, 1]  # Mask

        # One-hot encoding of the mask.
        y = keras.utils.to_categorical(y, num_classes=self.n_classes)
        return x, y
Actually, a few lines of code did the trick. With an extra input argument `class_weights` to my generator, a line in the `__getitem__()` method that converts the class weights to sample weights for each individual batch, and a return of the sample weights from the same method, I solved the issue. The class weights are passed as a list with the following structure: `class_weights = [weight_class_0, weight_class_1]`. My basic generator class now looks like this (I have marked the changes with a comment):
import numpy as np
import keras
class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras, including per-pixel sample weights.

    Serves batches of (image, one-hot mask, sample weights), where every
    sample is read from the file ``'data/' + ID + '.npy'``.
    """

    def __init__(self, list_IDs, class_weights, batch_size=2, dim=(144, 144, 144),
                 n_classes=2):
        """Store the configuration and build the initial sample order.

        Args:
            list_IDs: Sample identifiers (strings); each one names a
                ``.npy`` file inside the ``data/`` directory.
            class_weights: Sequence ``[weight_class_0, weight_class_1]``
                that is indexed by class to obtain the sample weights.
            batch_size: Number of samples per batch.
            dim: Spatial shape of a single sample volume.
            n_classes: Number of classes used for the one-hot encoding
                of the mask.
        """
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        self.class_weights = class_weights  # CLASS WEIGHTS FIX
        # Populates self.indexes, which __getitem__ relies on.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch.

        Only full batches are counted; a trailing partial batch is dropped.
        """
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def on_epoch_end(self):
        """Rebuild the sample order used by ``__getitem__``.

        The hook inherited from the base class does nothing, so without
        this override ``self.indexes`` would never exist and the first
        batch request would raise ``AttributeError``. Keras also calls
        this hook after every epoch.
        """
        self.indexes = np.arange(len(self.list_IDs))

    def __getitem__(self, index):
        """Generate one batch of data together with its sample weights.

        Args:
            index: Zero-based position of the batch within the epoch.

        Returns:
            Tuple ``(x, y, sample_weights)``: the images, the one-hot
            encoded masks, and one weight per mask element.
        """
        # Positions (into list_IDs) of the samples that form this batch.
        start = index * self.batch_size
        indexes = self.indexes[start:start + self.batch_size]
        # Translate positions into sample IDs.
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        # Generate data
        x, y = self.__data_generation(list_IDs_temp)
        # Compute sample weights  CLASS WEIGHTS FIX
        # Channel 1 of the one-hot mask (0 or 1) is used as the index into
        # class_weights. NOTE(review): this lookup assumes two classes.
        sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int'))
        return x, y, sample_weights  # CLASS WEIGHTS FIX

    def __data_generation(self, list_IDs_temp):
        """Load the given samples and assemble them into one batch.

        Args:
            list_IDs_temp: IDs of the samples that make up the batch.

        Returns:
            Tuple ``(x, y)``; ``x`` is allocated as
            ``(batch_size, *dim, 1)`` and ``y`` is the mask after one-hot
            encoding with ``n_classes`` classes.
        """
        # Buffers are uninitialised; every row is overwritten in the loop
        # because __len__ only ever yields full batches.
        x = np.empty((self.batch_size, *self.dim, 1))
        y = np.empty((self.batch_size, *self.dim, 1))

        for i, ID in enumerate(list_IDs_temp):
            data = np.load('data/' + ID + '.npy')
            # NOTE(review): the assignments below need the slices to
            # broadcast into shape (*dim, 1) -- confirm the layout of the
            # stored arrays.
            x[i,] = data[:, :, :, 0]  # Image
            y[i,] = data[:, :, :, 1]  # Mask

        # One-hot encoding of the mask.
        y = keras.utils.to_categorical(y, num_classes=self.n_classes)
        return x, y
It might seem a bit like a magic one-liner, but what `sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int'))` does is the following: it takes the y-values belonging to the less common class (in my case, the one to segment) and gives each pixel in this 3D image a sample weight. Because y is one-hot encoded, this channel is 1 for pixels of the less common class and 0 for all others, so it can be used directly as an index into `class_weights`. The resulting sample weight is therefore either the class weight for the common class or the one for the uncommon class, depending on which class the pixel belongs to.
The output of this generator class can then be used in the `model.fit()` method of the Keras model, as long as `sample_weight_mode="temporal"` is passed to `model.compile()`.