I am trying to train a convolutional neural network model for 5 classes. These classes are very unbalanced, like this:
0: 25810
1: 2443
2: 5292
3: 873
4: 708
As you can see, class 0 has far more data than the rest. So to get a good model I need to balance them. I am currently using this code snippet, whose result I then pass to fit_generator:
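import numpy as np
from sklearn.utils import class_weight

# Keyword arguments are required by recent scikit-learn versions.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes)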
The problem is that I don't notice any difference, and from what I read, this is what everyone uses. Is this a good way to balance classes or should I use another method?
The class_weight parameter of compute_class_weight has three options.
The first is 'balanced', which determines the importance of each class according to the distribution of the data using the following formula:
    n_samples / (n_classes * np.bincount(y))
where n_samples is the total number of samples, n_classes is the number of classes in the dataset, and np.bincount(y) counts the occurrences of each class, i.e. the number of samples per class, which you already have.
For the data you show, using 'balanced' internally produces a set of weights that reduces or increases the importance of certain classes, so that less represented classes are given greater importance and vice versa. Here class 0 gets a weight of about 0.27, while class 4 gets about 9.92.
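As a minimal sketch of that computation, the snippet below applies the 'balanced' formula to the class counts from the question in plain Python. The variable names are my own and not part of scikit-learn.

```python
# Class counts copied from the question (class index -> number of samples).
counts = {0: 25810, 1: 2443, 2: 5292, 3: 873, 4: 708}

n_samples = sum(counts.values())  # total number of samples
n_classes = len(counts)           # number of distinct classes

# Same formula as the 'balanced' option:
# n_samples / (n_classes * samples_in_class)
balanced_weights = {
    cls: n_samples / (n_classes * n) for cls, n in counts.items()
}

for cls, weight in balanced_weights.items():
    print(cls, round(weight, 4))
```

Note that Keras expects class_weight as a dictionary mapping class index to weight, while compute_class_weight returns an array, so the array usually needs converting first, for example with dict(enumerate(class_weights)).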
The second option is to define the weights manually in a dictionary, which requires a broad understanding of the data and the problem you are working on. For example, if class 4 is of vital importance for your problem and failing to recognize it means a huge loss of money, you can give that class a larger weight than the formula would.
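A hedged sketch of the manual approach: start from rounded balanced weights for the counts in the question and raise the weight of one class. CRITICAL_CLASS and BOOST are hypothetical values picked only for illustration, not a recommendation.

```python
# Rounded balanced weights for the class counts in the question.
balanced_weights = {0: 0.27, 1: 2.88, 2: 1.33, 3: 8.05, 4: 9.92}

# Hypothetical business rule: missing class 4 is very costly,
# so its weight is multiplied by an arbitrary factor.
CRITICAL_CLASS = 4
BOOST = 2.0

manual_weights = dict(balanced_weights)  # copy, keep the original intact
manual_weights[CRITICAL_CLASS] *= BOOST

print(manual_weights)

# The dictionary would then be passed to training through the
# class_weight argument, e.g.
# model.fit_generator(train_generator, class_weight=manual_weights)
```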
The third option is None, that is, to use no weights, in which case all classes have equal importance.
These are the existing ways to set class weights for training. To notice the difference between using them or not, you must choose a point of comparison and adapt it to the dataset you have. For example: does training take longer with or without weights? Does the model converge when you use weights, whether balanced by the formula or set manually? How do the confusion matrices and the recall/precision metrics change for the classes that are hardest to classify? You must consider all of this.
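One way to make that comparison concrete is to compute precision and recall per class from the confusion matrix of each run. The sketch below assumes the usual convention that rows are true classes and columns are predicted classes. The two matrices are made-up 3-class toy numbers, not results from any real model.

```python
def per_class_precision_recall(cm):
    """Return {class: (precision, recall)} for a square confusion matrix.

    cm[i][j] is the number of samples of true class i predicted as class j.
    """
    n = len(cm)
    metrics = {}
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        metrics[k] = (precision, recall)
    return metrics


# Toy, invented matrices used only to show how to compare two runs.
cm_no_weights = [[90, 5, 5], [8, 2, 0], [9, 0, 1]]
cm_weighted = [[80, 10, 10], [3, 7, 0], [4, 0, 6]]

for name, cm in [("no weights", cm_no_weights), ("weighted", cm_weighted)]:
    for cls, (p, r) in per_class_precision_recall(cm).items():
        print(f"{name:10s} class {cls}: precision={p:.2f} recall={r:.2f}")
```

Overall accuracy can stay almost the same between the two runs while the recall of the minority classes changes a lot, which is why the per-class view is the one to look at.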
I hope it helps you.