Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - How to use class weights in Keras for image segmentation

I am trying to segment medical images using a version of U-Net implemented in Keras. The inputs to my network are 3D images, and the outputs are two one-hot-encoded 3D segmentation maps. My dataset is very imbalanced (there is relatively little to segment), so I want to use class weights in my loss function (currently binary_crossentropy). With class weights, I hope the model will pay more attention to the small structures it has to segment.

If you know the imbalance of your dataset, you can pass the class_weight parameter to model.fit(). Does this also work for my use case?

question from:https://stackoverflow.com/questions/65881582/how-to-use-class-weights-in-keras-for-image-segmentation


1 Reply


With the help of the above-mentioned GitHub issue, I managed to solve the problem for my particular use case, and I want to share the solution here. An extra hurdle was that I am using a custom generator for my data. A simplified version of this class is the following:

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, batch_size=2, dim=(144,144,144), n_classes=2):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
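        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def on_epoch_end(self):
        'Resets the sample indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))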

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        x, y = self.__data_generation(list_IDs_temp)

        return x, y

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, 1)
        # Initialization
        x = np.empty((self.batch_size, *self.dim, 1))
        y = np.empty((self.batch_size, *self.dim, 1))

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Load dataset
            data = np.load('data/' + ID + '.npy')

            # Store x and y
            x[i,] = data[:, :, :, 0]  # Image
            y[i,] = data[:, :, :, 1]  # Mask

        # One-hot-encoding
        y = keras.utils.to_categorical(y, num_classes=self.n_classes)

        return x, y

A few lines of code did the trick: an extra class_weights argument to the generator, one line in the __getitem__() method that converts the class weights into sample weights for each individual batch, and returning those sample weights from the same method. The class weights are passed as a list with the following structure: class_weights = [weight_class_0, weight_class_1]. My basic generator class now looks like this (changes are marked with a comment):

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, class_weights, batch_size=2, dim=(144,144,144), 
                 n_classes=2):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        self.class_weights = class_weights  # CLASS WEIGHTS FIX
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
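        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def on_epoch_end(self):
        'Resets the sample indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))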

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        x, y = self.__data_generation(list_IDs_temp)

        # Compute sample weights CLASS WEIGHTS FIX
        sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int'))

        return x, y, sample_weights  # CLASS WEIGHTS FIX

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, 1)
        # Initialization
        x = np.empty((self.batch_size, *self.dim, 1))
        y = np.empty((self.batch_size, *self.dim, 1))

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Load dataset
            data = np.load('data/' + ID + '.npy')

            # Store x and y
            x[i,] = data[:, :, :, 0]  # Image
            y[i,] = data[:, :, :, 1]  # Mask

        # One-hot-encoding
        y = keras.utils.to_categorical(y, num_classes=self.n_classes)

        return x, y
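
The generator expects class_weights as a two-element list, but the answer does not say how the values were chosen. One common heuristic, assumed here and not taken from the answer, is inverse class frequency: n_voxels / (n_classes * count_per_class). The helper name balanced_class_weights and the toy mask are hypothetical; this is a minimal sketch, not the author's method.

```python
import numpy as np

def balanced_class_weights(masks, n_classes=2):
    """Return one weight per class, inversely proportional to class frequency.

    `masks` is any array of integer class labels (hypothetical helper,
    not part of the generator above).
    """
    labels = np.asarray(masks).astype(int).ravel()
    counts = np.bincount(labels, minlength=n_classes)
    # Guard against classes that never occur, to avoid division by zero.
    counts = np.maximum(counts, 1)
    return (labels.size / (n_classes * counts)).tolist()

# Toy 4x4x4 mask in which 4 of the 64 voxels are foreground (class 1).
mask = np.zeros((4, 4, 4), dtype=int)
mask[0, 0, :] = 1

class_weights = balanced_class_weights(mask)
print(class_weights)  # [weight_class_0, weight_class_1]
```

In practice the counts would be accumulated over the training masks rather than a single volume, and the resulting list passed as the class_weights argument of the generator.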

It might look like a magic one-liner, but here is what sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int')) does: it takes channel 1 of the one-hot-encoded y, which is 1 for voxels of the less common class (in my case the one to segment) and 0 elsewhere, and uses these values as indices into class_weights. Each voxel of the 3D image thus gets a sample weight equal to the class weight of either the common or the uncommon class, depending on which class the voxel belongs to.
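
The lookup is easier to follow on a tiny array. The sketch below applies the same np.take expression to a batch holding one 1x2x2 volume. The weight values are made up for illustration, and np.eye stands in for keras.utils.to_categorical so that the example needs only NumPy.

```python
import numpy as np

class_weights = [0.5, 8.0]  # hypothetical values: [weight_class_0, weight_class_1]

# Integer labels for a batch of one 1x2x2 volume, shape (1, 1, 2, 2).
labels = np.array([[[[0, 1], [1, 0]]]])

# One-hot encoding, shape (1, 1, 2, 2, 2); stand-in for to_categorical.
y = np.eye(2)[labels]

# Channel 1 is 1.0 for class-1 voxels and 0.0 otherwise, so it doubles as an index.
class_index = np.round(y[:, :, :, :, 1]).astype('int')
sample_weights = np.take(np.array(class_weights), class_index)

print(sample_weights.shape)
print(sample_weights)
```

The result has the shape of y without its class axis, with 8.0 wherever the voxel belongs to class 1 and 0.5 everywhere else.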

The output of this generator class can then be used in the model.fit() method of the Keras model, as long as sample_weight_mode="temporal" is passed to model.compile().
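
To get a feeling for what the sample weights do once the loss receives them, here is a NumPy-only sketch of a per-voxel weighted binary cross-entropy. It is an illustration under assumptions, not Keras's internal code: the function name and the weight values are hypothetical, and the exact reduction Keras applies can differ between versions.

```python
import numpy as np

def weighted_binary_crossentropy(y_true, y_pred, sample_weights, eps=1e-7):
    """Mean of the per-voxel binary cross-entropy, each voxel scaled by its weight."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_voxel = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(sample_weights * per_voxel))

y_true = np.array([0.0, 0.0, 0.0, 1.0])  # one foreground voxel out of four
y_pred = np.array([0.1, 0.1, 0.1, 0.1])  # model predicts background everywhere

# Hypothetical class weights [0.5, 8.0], mapped to per-voxel weights.
weights = np.take(np.array([0.5, 8.0]), y_true.astype(int))

unweighted = weighted_binary_crossentropy(y_true, y_pred, np.ones_like(y_true))
weighted = weighted_binary_crossentropy(y_true, y_pred, weights)
print(unweighted, weighted)
```

With the weights applied, the single missed foreground voxel dominates the loss, which is the effect the class weights are meant to have on an imbalanced segmentation task.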

