What loss function to use for imbalanced classes (using PyTorch)?
I have a dataset with 3 classes with the following items:
- Class 1: 900 elements
- Class 2: 15000 elements
- Class 3: 800 elements
I need to predict class 1 and class 3, which signal important deviations from the norm. Class 2 is the default “normal” case which I don’t care about.
What kind of loss function would I use here? I was thinking of using CrossEntropyLoss, but since there is a class imbalance, this would need to be weighted, I suppose? How does that work in practice? Like this (using PyTorch)?
import torch
import torch.nn as nn
summed = 900 + 15000 + 800
weight = torch.tensor([900., 15000., 800.]) / summed  # float tensor; weights proportional to class frequency
crit = nn.CrossEntropyLoss(weight=weight)
Or should the weights be inverted, i.e. 1 / weight?
Is this the right approach to begin with or are there other / better methods I could use?
Thanks
neural-network pytorch
asked Apr 1 at 19:00 by Muppet, edited Apr 1 at 22:37
1 Answer
What kind of loss function would I use here?
Cross-entropy is the go-to loss function for classification tasks, whether the classes are balanced or imbalanced. It is the default first choice until domain knowledge suggests a better one.
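As a minimal sketch (the logits and targets below are made-up placeholders, not the asker's data): nn.CrossEntropyLoss expects raw, unnormalized logits of shape (batch, num_classes) plus integer class indices, so the question's classes 1, 2, 3 would be encoded as indices 0, 1, 2. The softmax is applied inside the loss, so the model itself should not end with one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 3 classes (class 1, 2, 3 -> indices 0, 1, 2).
logits = torch.randn(4, 3)            # raw model outputs, no softmax applied
targets = torch.tensor([0, 1, 1, 2])  # integer class indices

crit = nn.CrossEntropyLoss()
loss = crit(logits, targets)

# Same value computed by hand: mean of -log softmax at each sample's true class.
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), targets].mean()
print(loss.item(), manual.item())
```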
This would need to be weighted, I suppose? How does that work in practice?
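Yes, but the weights should be inverted relative to your code: the weight of class $c$ is the size of the largest class divided by the size of class $c$, so rarer classes get larger weights.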
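In symbols, with $n_c$ denoting the number of samples in class $c$ (notation introduced here for illustration), the rule is the first formula below. The second is how PyTorch's CrossEntropyLoss combines per-class weights under its default reduction='mean', where $p_{i,y_i}$ is the softmax probability the model assigns to the true class $y_i$ of sample $i$ in a batch of $N$:

$$w_c = \frac{\max_k n_k}{n_c}, \qquad \mathcal{L} = \frac{\sum_{i=1}^{N} w_{y_i}\,\bigl(-\log p_{i,y_i}\bigr)}{\sum_{i=1}^{N} w_{y_i}}$$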
For example, if class 1 has 900, class 2 has 15000, and class 3 has 800 samples, their weights would be 16.67, 1.0, and 18.75, respectively.
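A minimal sketch of computing these weights from the class counts and passing them to the loss. The counts come from the question; the logits and targets are placeholder values. The weight tensor is kept as floats, since CrossEntropyLoss expects a floating-point weight with one entry per class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Samples per class (class 1, 2, 3 -> indices 0, 1, 2), from the question.
counts = torch.tensor([900.0, 15000.0, 800.0])

# Weight of class c = size of the largest class / size of class c.
weight = counts.max() / counts
print(weight)  # approximately [16.67, 1.00, 18.75]

crit = nn.CrossEntropyLoss(weight=weight)

# Placeholder batch, only to show the call.
logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
loss = crit(logits, targets)
```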
You can also use the smallest class as the numerator, which gives 0.889, 0.053, and 1.0, respectively. This is only a rescaling; the relative weights are the same.
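A quick check of the rescaling claim, sketched with random placeholder logits and targets: with the default reduction='mean', PyTorch divides the weighted sum by the sum of the targets' weights, so multiplying every class weight by the same constant leaves the loss unchanged. (With reduction='sum' the loss would instead be multiplied by that constant.)

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
counts = torch.tensor([900.0, 15000.0, 800.0])

w_largest = counts.max() / counts   # 16.67, 1.0, 18.75
w_smallest = counts.min() / counts  # 0.889, 0.053, 1.0

# Placeholder batch.
logits = torch.randn(32, 3)
targets = torch.randint(0, 3, (32,))

loss_a = nn.CrossEntropyLoss(weight=w_largest)(logits, targets)
loss_b = nn.CrossEntropyLoss(weight=w_smallest)(logits, targets)
print(loss_a.item(), loss_b.item())  # equal up to floating-point rounding
```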
Is this the right approach to begin with or are there other / better methods I could use?
Yes, this is the right approach.
EDIT:
Thanks to @Muppet: we can also use class oversampling, which has a similar effect to using class weights. In PyTorch this is done with WeightedRandomSampler, giving each sample the weight of its class (the same weights as above).
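A minimal sketch of the oversampling route, assuming a map-style dataset whose integer labels are available as a tensor (the TensorDataset with random features below is a placeholder). WeightedRandomSampler takes one weight per sample, not per class, so each sample is given the weight of its class; sampling with replacement then draws the three classes in roughly equal proportions. When oversampling like this, you would typically use an unweighted CrossEntropyLoss, since applying both would compensate for the imbalance twice.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Placeholder dataset with the question's class sizes (indices 0, 1, 2).
labels = torch.cat([
    torch.full((900,), 0, dtype=torch.long),
    torch.full((15000,), 1, dtype=torch.long),
    torch.full((800,), 2, dtype=torch.long),
])
features = torch.randn(len(labels), 10)
dataset = TensorDataset(features, labels)

# Per-class weights, then one weight per sample (the weight of its class).
counts = torch.bincount(labels).float()
class_weight = counts.max() / counts
sample_weight = class_weight[labels]

sampler = WeightedRandomSampler(sample_weight, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=256, sampler=sampler)

# Class proportions in one epoch of resampled data (expected to be near 1/3 each).
seen = torch.cat([y for _, y in loader])
print(torch.bincount(seen, minlength=3).float() / len(seen))
```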
answered Apr 1 at 20:29 by Esmailian, edited Apr 3 at 8:18
I just wanted to add that using WeightedRandomSampler from PyTorch also helped, in case someone else is looking at this. – Muppet, Apr 2 at 17:40