LSTM not converging
I am sorry if this question is basic, but I am quite new to neural networks in general. I am trying to build an LSTM to predict a certain property of a light curve (the output is 0 or 1). I built it in PyTorch. Here is my code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np

torch.manual_seed(1)
torch.cuda.set_device(0)
from fastai.learner import *

n_hidden = 64
n_classes = 2
bs = 1

class TESS_LSTM(nn.Module):
    def __init__(self, nl):
        super().__init__()
        self.nl = nl
        self.rnn = nn.LSTM(1, n_hidden, nl, dropout=0.01, bidirectional=True)
        self.l_out = nn.Linear(n_hidden*2, n_classes)
        self.init_hidden(bs)

    def forward(self, input):
        outp, h = self.rnn(input.view(len(input), bs, -1), self.h)
        #self.h = repackage_var(h)
        return F.log_softmax(self.l_out(outp), dim=2)

    def init_hidden(self, bs):
        self.h = (V(torch.zeros(self.nl*2, bs, n_hidden)),
                  V(torch.zeros(self.nl*2, bs, n_hidden)))

model = TESS_LSTM(2).cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):
    model.zero_grad()
    tag_scores = model(data_x)
    loss = loss_function(tag_scores.reshape(len(data_x), n_classes), data_y.reshape(len(data_y)))
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print("Loss at epoch %d = " % epoch, loss)
Also, here are the data:
data_x = tensor([
[0.9995450377],
[0.9991207719],
[0.9986526966],
[1.0017241240],
[1.0016067028],
[1.0000480413],
[1.0016841888],
[1.0010652542],
[0.9991232157],
[1.0004128218],
[0.9986800551],
[1.0011130571],
[1.0001415014],
[1.0004080534],
[1.0016922951],
[1.0008358955],
[1.0001622438],
[1.0004277229],
[1.0011759996],
[1.0013391972],
[0.9995799065],
[1.0019282103],
[1.0006642342],
[1.0006272793],
[1.0011570454],
[1.0015332699],
[1.0011225939],
[1.0003337860],
[1.0014277697],
[1.0003565550],
[0.9989787340],
[1.0006136894],
[1.0003052950],
[1.0001049042],
[1.0020918846],
[0.9999115467],
[1.0006635189],
[1.0007561445],
[1.0016170740],
[1.0008252859],
[0.9997656345],
[1.0001330376],
[1.0017272234],
[1.0004107952],
[1.0012439489],
[0.9994274378],
[1.0014992952],
[1.0015807152],
[1.0004781485],
[1.0010997057],
[1.0011326075],
[1.0005493164],
[1.0014353991],
[0.9990324974],
[1.0012129545],
[0.9990709424],
[1.0006347895],
[1.0000327826],
[1.0005196333],
[1.0012207031],
[1.0003460646],
[1.0004434586],
[1.0003618002],
[1.0005420446],
[1.0005528927],
[1.0006977320],
[1.0005317926],
[1.0000808239],
[1.0005664825],
[0.9994245768],
[0.9999254942],
[1.0011985302],
[1.0009841919],
[0.9999029040],
[1.0014100075],
[1.0014085770],
[1.0005567074],
[1.0016088486],
[0.9997186661],
[0.9998687506],
[0.9988344908],
[0.9999858141],
[1.0004914999],
[1.0003308058],
[1.0001890659],
[1.0002681017],
[1.0029908419],
[1.0005286932],
[1.0004363060],
[0.9994311333],
[1.0011523962],
[1.0008679628],
[1.0014137030],
[0.9994244576],
[1.0003470182],
[1.0001592636],
[1.0002418756],
[0.9992931485],
[1.0016175508],
[1.0000959635],
[1.0005099773],
[1.0008889437],
[0.9998087287],
[0.9995828867],
[0.9997566342],
[1.0002474785],
[1.0010808706],
[1.0002821684],
[1.0013456345],
[1.0013040304],
[1.0010949373],
[1.0002720356],
[0.9996811152],
[1.0006061792],
[1.0012511015],
[0.9999302626],
[0.9985374212],
[1.0002642870],
[0.9996038675],
[1.0007606745],
[0.9992995858],
[1.0000385046],
[0.9997834563],
[1.0005996227],
[1.0006167889],
[1.0015753508],
[1.0010306835],
[0.9997833371],
[1.0010590553],
[1.0008200407],
[1.0008001328],
[1.0014072657],
[0.9994395375],
[0.9991182089],
[1.0011717081],
[1.0007920265],
[1.0011025667],
[1.0004047155],
[1.0017303228],
[1.0014981031],
[0.9995774031],
[0.9999650121],
[0.9992966652],
[1.0013586283],
[1.0003392696],
[1.0005040169],
[1.0008341074],
[1.0014744997],
[0.9996585250],
[1.0019916296],
[1.0007069111],
[1.0004591942],
[1.0004271269],
[0.9991059303],
[1.0003436804],
[0.9990482330],
[0.9980322123],
[0.9980198145],
[0.9966595173],
[0.9969686270],
[0.9977232814],
[0.9969192147],
[0.9962794185],
[0.9947851300],
[0.9946336746],
[0.9943053722],
[0.9946651459],
[0.9930071235],
[0.9940539598],
[0.9950682521],
[0.9947031140],
[0.9950703979],
[0.9945428371],
[0.9945927858],
[0.9937841296],
[0.9944553375],
[0.9929991364],
[0.9940859079],
[0.9930059314],
[0.9942978621],
[0.9950152636],
[0.9943225384],
[0.9934711456],
[0.9929080606],
[0.9934846163],
[0.9954113960],
[0.9925802350],
[0.9929560423],
[0.9933584929],
[0.9929228425],
[0.9930893779],
[0.9936142564],
[0.9943635464],
[0.9933300614],
[0.9925817847],
[0.9927681088],
[0.9930697680],
[0.9937900901],
[0.9919354320],
[0.9937084913],
[0.9951301217],
[0.9926426411],
[0.9933566451],
[0.9937180877],
[0.9922621250],
[0.9933888316],
[0.9936477542],
[0.9916112423],
[0.9943441153],
[0.9934164286],
[0.9949553013],
[0.9941871166],
[0.9933763146],
[0.9959306121],
[0.9930690527],
[0.9928541183],
[0.9936354756],
[0.9931223392],
[0.9936516881],
[0.9935654402],
[0.9932218790],
[0.9943401814],
[0.9931038022],
[0.9926875830],
[0.9928631186],
[0.9936705232],
[0.9939361215],
[0.9942125678],
[0.9939611554],
[0.9936586618],
[0.9933990240],
[0.9948219061],
[0.9940339923],
[0.9950091243],
[0.9952197671],
[0.9947227240],
[0.9935435653],
[0.9956403971],
[0.9943848252],
[0.9942221045],
[0.9960014224],
[0.9931004643],
[0.9960579872],
[0.9951166511],
[0.9964768291],
[0.9968702793],
[0.9967978597],
[0.9971982837],
[0.9977793097],
[0.9982623458],
[0.9988413453],
[1.0008778572],
[1.0013417006],
[1.0000336170],
[0.9979853630],
[0.9988892674],
[0.9994396567],
[1.0002176762],
[1.0017417669],
[1.0013097525],
[1.0011264086],
[1.0004124641],
[1.0003939867],
[0.9996479750],
[0.9995540380],
[1.0003930330],
[1.0016323328],
[1.0004589558],
[0.9996963739],
[0.9989817142],
[0.9998068213],
[1.0011200905],
[1.0006275177],
[1.0000452995],
[1.0012514591],
[1.0002357960],
[0.9993159175],
[1.0002738237],
[0.9994575381],
[0.9986617565],
[0.9982920289],
[0.9998571873],
[0.9996472597],
[1.0012613535],
[1.0015693903],
[0.9999635220],
[1.0006184578],
[1.0010757446],
[0.9988756776],
[1.0004955530],
[1.0011548996],
[1.0007628202],
[1.0006260872],
[0.9989725947],
[1.0013129711],
[0.9994829297],
[0.9998571873],
[0.9994959831],
[1.0007432699],
[0.9995724559],
[0.9999076724],
[0.9992097020],
[1.0011855364],
[0.9987785220],
[1.0010210276],
[0.9998293519],
[0.9996315837],
[0.9999501705],
[1.0001417398],
[1.0005141497],
[0.9993781447],
[1.0003532171],
[0.9999422431],
[1.0014258623],
[1.0012118816],
[0.9994109273],
[1.0019438267],
[1.0012354851],
[1.0009905100],
[1.0001032352],
[0.9999653101],
[0.9991906881],
[1.0004152060],
[0.9998226762],
[0.9999175668],
[0.9994540215],
[1.0000722408],
[1.0019129515],
[0.9997307658],
[0.9996227026],
[1.0011816025],
[0.9993667006],
[1.0010036230],
[0.9993645549],
[1.0004647970],
[0.9995272160],
[0.9989504814],
[0.9981039166],
[1.0006005764],
[0.9998896718],
[1.0004893541],
[0.9991874099],
[1.0005015135],
[0.9995905161],
[0.9990965128],
[1.0012912750],
[1.0004948378],
[1.0002779961],
[0.9988743067],
[1.0019037724],
[1.0006437302],
[0.9999380112],
[1.0001602173],
[0.9997741580],
[0.9988395572],
[0.9999371171],
[0.9989091754],
[0.9987531900],
[1.0003957748],
[0.9997722507],
[0.9988819361],
[0.9998422265],
[0.9986129999],
[0.9989410639],
[1.0016149282],
[0.9997441173],
[1.0002747774],
[0.9990793467],
[1.0006495714],
[1.0004252195],
[0.9997921586],
[0.9987344146],
[0.9998763800],
[0.9988097548],
[1.0007627010],
[1.0004670620],
[1.0007309914],
[0.9987894297],
[1.0000542402],
[1.0004990101],
[0.9999514818],
[0.9998412132],
[1.0000183582],
[1.0003197193],
[0.9991712570],
[0.9992188215],
[0.9986482859],
[1.0010583401],
[1.0011837482],
[0.9993829727],
[0.9995718002],
[0.9997168183],
[1.0017461777],
[0.9998381138],
[0.9990652204],
[1.0001449585],
[0.9998424053],
[1.0011798143],
[1.0013160706],
[0.9995942712],
[1.0001651049],
[1.0001466274],
[0.9982855320],
[0.9992064238],
[1.0009102821],
[0.9982813597],
[1.0000503063],
[0.9982630014],
[1.0017516613],
[0.9995808005],
[0.9989835620],
[1.0003046989],
[1.0019340515],
[0.9996930957],
[1.0000711679],
[1.0011881590],
[1.0009138584],
[1.0013902187],
[0.9994105101],
[0.9986224174],
[0.9995336533],
[1.0006912947],
[0.9995169044],
[0.9998968840],
[0.9989182949],
[0.9999300838],
[0.9991120696],
[0.9996063709],
[1.0008803606],
[1.0019868612],
[1.0004760027],
[0.9996407032],
[1.0011100769],
[1.0026890039],
[0.9996611476],
[0.9991108775],
[0.9982090592],
[1.0000833273],
[1.0015701056],
[0.9994426966],
[0.9999341369],
[1.0002813339],
[0.9998958707],
[1.0011670589],
[1.0009137392],
[0.9994600415],
[1.0010378361],
[1.0008393526],
[1.0013997555],
[0.9994245768],
[0.9995403886],
[0.9997746348],
[0.9997846484],
[1.0012620687],
[1.0009645224],
[0.9995513558],
[1.0008162260],
[1.0008013248],
[0.9990139604],
[1.0004394054],
[0.9991726875],
[1.0009342432],
[1.0008635521],
[1.0007735491],
[1.0013785362],
[0.9997245073],
[0.9989474416],
[0.9996470809],
[1.0008428097],
[1.0017400980],
[0.9994468689],
[0.9999369979],
[1.0007227659],
[1.0012919903],
[0.9981160164],
[0.9999316335],
[0.9997596741],
[1.0008264780],
[0.9994930029],
[1.0001339912],
[0.9998437166],
[0.9999112487],
[1.0001872778],
[1.0006663799],
[1.0007426739],
[1.0016776323],
[0.9996471405],
[0.9981047511],
[1.0007015467],
[1.0006203651],
[0.9987628460],
[0.9981441498],
[0.9981172085],
[0.9999507666],
[1.0002735853],
[1.0006685257],
[1.0001268387],
[1.0000184774],
[0.9998023510],
[1.0006322861]], device='cuda:0')
and
data_y = tensor([
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0]
], device='cuda:0')
I read data_x and data_y from a file, which is why I just pasted the values here. In the plot of the light curve (image not included here), 0 corresponds to blue and 1 to red.
And this is the output:
Loss at epoch 0 = tensor(0.6795, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 10 = tensor(0.4872, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 20 = tensor(0.4818, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 30 = tensor(0.4834, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 40 = tensor(0.4828, device='cuda:0', grad_fn=<NllLossBackward>)
I tried reducing and increasing the learning rate, trying SGD and RMSprop, and increasing the number of epochs, but the loss always stops at about 0.48. This is part of the output of model(data_x):
tensor([[[-0.3617, -1.1924]],
[[-0.3046, -1.3373]],
[[-0.2696, -1.4424]],
[[-0.2477, -1.5169]],
[[-0.2345, -1.5654]],
[[-0.2262, -1.5971]],
And all the other values are similar to these. I expected the LSTM to at least overfit the data, or at least to predict 0 for everything (given that I have just a few ones, the loss would still be pretty small). Instead it just predicts these numbers, and I am not sure why it gets stuck there. I have tried every debugging method I know (which, given my limited AI experience, is not many). How can I fix this?
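One sanity check on the plateau value (an aside, assuming the label counts above are right): a classifier that ignores the input and always predicts the empirical class frequencies attains a negative log-likelihood equal to the entropy of the labels. Counting the ones in data_y above (93 out of 500), that entropy is almost exactly the 0.48 the loss gets stuck at, which would be consistent with the network collapsing to the class prior at every time step rather than using the input:

import math

# 93 of the 500 labels in data_y are 1 (the red block above)
p = 93 / 500
entropy = -p * math.log(p) - (1 - p) * math.log(1 - p)
print(entropy)  # ~0.4804 -- essentially the plateau the loss converges to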
lstm pytorch convergence

asked Apr 4 at 5:16 by Bill
edited Apr 4 at 12:48 by Stephen Rauch♦
– Armen Aghajanyan (Apr 4 at 22:07): Can you print out data_x.size() and data_y.size()?

– Bill (Apr 4 at 22:48): @ArmenAghajanyan this is the output for both: torch.Size([500, 1]). The size of the vectors is the right one needed by the PyTorch LSTM. I actually tried replacing all the ones in the output with zeros (so all the outputs are zeros), and in that case the loss goes down to 10^-5, so the LSTM seems able to learn in general; it just has a problem in this case (actually, even if I have only one "1" and the rest zeros, it also stops learning).

– Armen Aghajanyan (Apr 5 at 23:04): Instead of using 2 outputs followed by log_softmax trained with NLL, use 1 output followed by a sigmoid and trained with binary_cross_entropy.
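To make that last suggestion concrete, here is a minimal sketch of the single-logit variant (an untested illustration, not the commenter's exact code; it drops the fastai V() wrapper and the explicit hidden state for brevity, and uses BCEWithLogitsLoss, which fuses the sigmoid and binary cross-entropy for numerical stability; data_x and data_y are the tensors from the question, with data_y cast to float for the BCE loss):

import torch
import torch.nn as nn

n_hidden = 64

class TESS_LSTM_BCE(nn.Module):
    def __init__(self, nl):
        super().__init__()
        self.rnn = nn.LSTM(1, n_hidden, nl, dropout=0.01, bidirectional=True)
        self.l_out = nn.Linear(n_hidden*2, 1)  # one logit per time step

    def forward(self, input):
        outp, _ = self.rnn(input.view(len(input), 1, -1))  # default zero hidden state
        return self.l_out(outp)  # raw logits; the sigmoid lives inside the loss

model = TESS_LSTM_BCE(2).cuda()
loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):
    model.zero_grad()
    logits = model(data_x)  # shape (500, 1, 1)
    loss = loss_function(logits.view(-1), data_y.view(-1).float())
    loss.backward()
    optimizer.step()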
add a comment |
$begingroup$
I am sorry if this questions is basic but I am quite new to NN in general. I am trying to build an LSTM to predict certain properties of a light curve (the output is 0 or 1). I build it in pytorch. Here is my code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
torch.manual_seed(1)
torch.cuda.set_device(0)
from fastai.learner import *
n_hidden = 64
n_classes = 2
bs = 1
class TESS_LSTM(nn.Module):
def __init__(self, nl):
super().__init__()
self.nl = nl
self.rnn = nn.LSTM(1, n_hidden, nl, dropout=0.01, bidirectional=True)
self.l_out = nn.Linear(n_hidden*2, n_classes)
self.init_hidden(bs)
def forward(self, input):
outp,h = self.rnn(input.view(len(input), bs, -1), self.h)
#self.h = repackage_var(h)
return F.log_softmax(self.l_out(outp),dim=2)
def init_hidden(self, bs):
self.h = (V(torch.zeros(self.nl*2, bs, n_hidden)),
V(torch.zeros(self.nl*2, bs, n_hidden)))
model = TESS_LSTM(2).cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(50):
model.zero_grad()
tag_scores = model(data_x)
loss = loss_function(tag_scores.reshape(len(data_x),n_classes), data_y.reshape(len(data_y)))
loss.backward()
optimizer.step()
if epoch%10==0:
print("Loss at epoch %d = " %epoch, loss)
Also:
data_x = tensor([
[0.9995450377],
[0.9991207719],
[0.9986526966],
[1.0017241240],
[1.0016067028],
[1.0000480413],
[1.0016841888],
[1.0010652542],
[0.9991232157],
[1.0004128218],
[0.9986800551],
[1.0011130571],
[1.0001415014],
[1.0004080534],
[1.0016922951],
[1.0008358955],
[1.0001622438],
[1.0004277229],
[1.0011759996],
[1.0013391972],
[0.9995799065],
[1.0019282103],
[1.0006642342],
[1.0006272793],
[1.0011570454],
[1.0015332699],
[1.0011225939],
[1.0003337860],
[1.0014277697],
[1.0003565550],
[0.9989787340],
[1.0006136894],
[1.0003052950],
[1.0001049042],
[1.0020918846],
[0.9999115467],
[1.0006635189],
[1.0007561445],
[1.0016170740],
[1.0008252859],
[0.9997656345],
[1.0001330376],
[1.0017272234],
[1.0004107952],
[1.0012439489],
[0.9994274378],
[1.0014992952],
[1.0015807152],
[1.0004781485],
[1.0010997057],
[1.0011326075],
[1.0005493164],
[1.0014353991],
[0.9990324974],
[1.0012129545],
[0.9990709424],
[1.0006347895],
[1.0000327826],
[1.0005196333],
[1.0012207031],
[1.0003460646],
[1.0004434586],
[1.0003618002],
[1.0005420446],
[1.0005528927],
[1.0006977320],
[1.0005317926],
[1.0000808239],
[1.0005664825],
[0.9994245768],
[0.9999254942],
[1.0011985302],
[1.0009841919],
[0.9999029040],
[1.0014100075],
[1.0014085770],
[1.0005567074],
[1.0016088486],
[0.9997186661],
[0.9998687506],
[0.9988344908],
[0.9999858141],
[1.0004914999],
[1.0003308058],
[1.0001890659],
[1.0002681017],
[1.0029908419],
[1.0005286932],
[1.0004363060],
[0.9994311333],
[1.0011523962],
[1.0008679628],
[1.0014137030],
[0.9994244576],
[1.0003470182],
[1.0001592636],
[1.0002418756],
[0.9992931485],
[1.0016175508],
[1.0000959635],
[1.0005099773],
[1.0008889437],
[0.9998087287],
[0.9995828867],
[0.9997566342],
[1.0002474785],
[1.0010808706],
[1.0002821684],
[1.0013456345],
[1.0013040304],
[1.0010949373],
[1.0002720356],
[0.9996811152],
[1.0006061792],
[1.0012511015],
[0.9999302626],
[0.9985374212],
[1.0002642870],
[0.9996038675],
[1.0007606745],
[0.9992995858],
[1.0000385046],
[0.9997834563],
[1.0005996227],
[1.0006167889],
[1.0015753508],
[1.0010306835],
[0.9997833371],
[1.0010590553],
[1.0008200407],
[1.0008001328],
[1.0014072657],
[0.9994395375],
[0.9991182089],
[1.0011717081],
[1.0007920265],
[1.0011025667],
[1.0004047155],
[1.0017303228],
[1.0014981031],
[0.9995774031],
[0.9999650121],
[0.9992966652],
[1.0013586283],
[1.0003392696],
[1.0005040169],
[1.0008341074],
[1.0014744997],
[0.9996585250],
[1.0019916296],
[1.0007069111],
[1.0004591942],
[1.0004271269],
[0.9991059303],
[1.0003436804],
[0.9990482330],
[0.9980322123],
[0.9980198145],
[0.9966595173],
[0.9969686270],
[0.9977232814],
[0.9969192147],
[0.9962794185],
[0.9947851300],
[0.9946336746],
[0.9943053722],
[0.9946651459],
[0.9930071235],
[0.9940539598],
[0.9950682521],
[0.9947031140],
[0.9950703979],
[0.9945428371],
[0.9945927858],
[0.9937841296],
[0.9944553375],
[0.9929991364],
[0.9940859079],
[0.9930059314],
[0.9942978621],
[0.9950152636],
[0.9943225384],
[0.9934711456],
[0.9929080606],
[0.9934846163],
[0.9954113960],
[0.9925802350],
[0.9929560423],
[0.9933584929],
[0.9929228425],
[0.9930893779],
[0.9936142564],
[0.9943635464],
[0.9933300614],
[0.9925817847],
[0.9927681088],
[0.9930697680],
[0.9937900901],
[0.9919354320],
[0.9937084913],
[0.9951301217],
[0.9926426411],
[0.9933566451],
[0.9937180877],
[0.9922621250],
[0.9933888316],
[0.9936477542],
[0.9916112423],
[0.9943441153],
[0.9934164286],
[0.9949553013],
[0.9941871166],
[0.9933763146],
[0.9959306121],
[0.9930690527],
[0.9928541183],
[0.9936354756],
[0.9931223392],
[0.9936516881],
[0.9935654402],
[0.9932218790],
[0.9943401814],
[0.9931038022],
[0.9926875830],
[0.9928631186],
[0.9936705232],
[0.9939361215],
[0.9942125678],
[0.9939611554],
[0.9936586618],
[0.9933990240],
[0.9948219061],
[0.9940339923],
[0.9950091243],
[0.9952197671],
[0.9947227240],
[0.9935435653],
[0.9956403971],
[0.9943848252],
[0.9942221045],
[0.9960014224],
[0.9931004643],
[0.9960579872],
[0.9951166511],
[0.9964768291],
[0.9968702793],
[0.9967978597],
[0.9971982837],
[0.9977793097],
[0.9982623458],
[0.9988413453],
[1.0008778572],
[1.0013417006],
[1.0000336170],
[0.9979853630],
[0.9988892674],
[0.9994396567],
[1.0002176762],
[1.0017417669],
[1.0013097525],
[1.0011264086],
[1.0004124641],
[1.0003939867],
[0.9996479750],
[0.9995540380],
[1.0003930330],
[1.0016323328],
[1.0004589558],
[0.9996963739],
[0.9989817142],
[0.9998068213],
[1.0011200905],
[1.0006275177],
[1.0000452995],
[1.0012514591],
[1.0002357960],
[0.9993159175],
[1.0002738237],
[0.9994575381],
[0.9986617565],
[0.9982920289],
[0.9998571873],
[0.9996472597],
[1.0012613535],
[1.0015693903],
[0.9999635220],
[1.0006184578],
[1.0010757446],
[0.9988756776],
[1.0004955530],
[1.0011548996],
[1.0007628202],
[1.0006260872],
[0.9989725947],
[1.0013129711],
[0.9994829297],
[0.9998571873],
[0.9994959831],
[1.0007432699],
[0.9995724559],
[0.9999076724],
[0.9992097020],
[1.0011855364],
[0.9987785220],
[1.0010210276],
[0.9998293519],
[0.9996315837],
[0.9999501705],
[1.0001417398],
[1.0005141497],
[0.9993781447],
[1.0003532171],
[0.9999422431],
[1.0014258623],
[1.0012118816],
[0.9994109273],
[1.0019438267],
[1.0012354851],
[1.0009905100],
[1.0001032352],
[0.9999653101],
[0.9991906881],
[1.0004152060],
[0.9998226762],
[0.9999175668],
[0.9994540215],
[1.0000722408],
[1.0019129515],
[0.9997307658],
[0.9996227026],
[1.0011816025],
[0.9993667006],
[1.0010036230],
[0.9993645549],
[1.0004647970],
[0.9995272160],
[0.9989504814],
[0.9981039166],
[1.0006005764],
[0.9998896718],
[1.0004893541],
[0.9991874099],
[1.0005015135],
[0.9995905161],
[0.9990965128],
[1.0012912750],
[1.0004948378],
[1.0002779961],
[0.9988743067],
[1.0019037724],
[1.0006437302],
[0.9999380112],
[1.0001602173],
[0.9997741580],
[0.9988395572],
[0.9999371171],
[0.9989091754],
[0.9987531900],
[1.0003957748],
[0.9997722507],
[0.9988819361],
[0.9998422265],
[0.9986129999],
[0.9989410639],
[1.0016149282],
[0.9997441173],
[1.0002747774],
[0.9990793467],
[1.0006495714],
[1.0004252195],
[0.9997921586],
[0.9987344146],
[0.9998763800],
[0.9988097548],
[1.0007627010],
[1.0004670620],
[1.0007309914],
[0.9987894297],
[1.0000542402],
[1.0004990101],
[0.9999514818],
[0.9998412132],
[1.0000183582],
[1.0003197193],
[0.9991712570],
[0.9992188215],
[0.9986482859],
[1.0010583401],
[1.0011837482],
[0.9993829727],
[0.9995718002],
[0.9997168183],
[1.0017461777],
[0.9998381138],
[0.9990652204],
[1.0001449585],
[0.9998424053],
[1.0011798143],
[1.0013160706],
[0.9995942712],
[1.0001651049],
[1.0001466274],
[0.9982855320],
[0.9992064238],
[1.0009102821],
[0.9982813597],
[1.0000503063],
[0.9982630014],
[1.0017516613],
[0.9995808005],
[0.9989835620],
[1.0003046989],
[1.0019340515],
[0.9996930957],
[1.0000711679],
[1.0011881590],
[1.0009138584],
[1.0013902187],
[0.9994105101],
[0.9986224174],
[0.9995336533],
[1.0006912947],
[0.9995169044],
[0.9998968840],
[0.9989182949],
[0.9999300838],
[0.9991120696],
[0.9996063709],
[1.0008803606],
[1.0019868612],
[1.0004760027],
[0.9996407032],
[1.0011100769],
[1.0026890039],
[0.9996611476],
[0.9991108775],
[0.9982090592],
[1.0000833273],
[1.0015701056],
[0.9994426966],
[0.9999341369],
[1.0002813339],
[0.9998958707],
[1.0011670589],
[1.0009137392],
[0.9994600415],
[1.0010378361],
[1.0008393526],
[1.0013997555],
[0.9994245768],
[0.9995403886],
[0.9997746348],
[0.9997846484],
[1.0012620687],
[1.0009645224],
[0.9995513558],
[1.0008162260],
[1.0008013248],
[0.9990139604],
[1.0004394054],
[0.9991726875],
[1.0009342432],
[1.0008635521],
[1.0007735491],
[1.0013785362],
[0.9997245073],
[0.9989474416],
[0.9996470809],
[1.0008428097],
[1.0017400980],
[0.9994468689],
[0.9999369979],
[1.0007227659],
[1.0012919903],
[0.9981160164],
[0.9999316335],
[0.9997596741],
[1.0008264780],
[0.9994930029],
[1.0001339912],
[0.9998437166],
[0.9999112487],
[1.0001872778],
[1.0006663799],
[1.0007426739],
[1.0016776323],
[0.9996471405],
[0.9981047511],
[1.0007015467],
[1.0006203651],
[0.9987628460],
[0.9981441498],
[0.9981172085],
[0.9999507666],
[1.0002735853],
[1.0006685257],
[1.0001268387],
[1.0000184774],
[0.9998023510],
[1.0006322861]], device='cuda:0')
and
data_y = tensor([
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0]
], device='cuda:0')
I read data_x and data_y from a file, so that's why I just pasted the values here. See the image below: 0 corresponds to blue and 1 to red.
And this is the output:
Loss at epoch 0 = tensor(0.6795, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 10 = tensor(0.4872, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 20 = tensor(0.4818, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 30 = tensor(0.4834, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 40 = tensor(0.4828, device='cuda:0', grad_fn=<NllLossBackward>)
I tried reducing and increasing the learning rate, trying SGD and RMSprop increasing the number of epochs, but the loss always stops at 0.48. This is part of the output of model(data_x)
:
tensor([[[-0.3617, -1.1924]],
[[-0.3046, -1.3373]],
[[-0.2696, -1.4424]],
[[-0.2477, -1.5169]],
[[-0.2345, -1.5654]],
[[-0.2262, -1.5971]],
And all the other values are similar to this. I expected at least that the LSTM will overfit my model, or at least predict 0 for everything (given that I have just few ones, the loss would still be pretty small). But instead it just predicts these numbers and I am not sure why it stops there. I tried any debugging method I know (which are not very many given my AI experience). How can I fix this?
lstm pytorch convergence
$endgroup$
I am sorry if this questions is basic but I am quite new to NN in general. I am trying to build an LSTM to predict certain properties of a light curve (the output is 0 or 1). I build it in pytorch. Here is my code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
torch.manual_seed(1)
torch.cuda.set_device(0)
from fastai.learner import *
n_hidden = 64
n_classes = 2
bs = 1
class TESS_LSTM(nn.Module):
def __init__(self, nl):
super().__init__()
self.nl = nl
self.rnn = nn.LSTM(1, n_hidden, nl, dropout=0.01, bidirectional=True)
self.l_out = nn.Linear(n_hidden*2, n_classes)
self.init_hidden(bs)
def forward(self, input):
outp,h = self.rnn(input.view(len(input), bs, -1), self.h)
#self.h = repackage_var(h)
return F.log_softmax(self.l_out(outp),dim=2)
def init_hidden(self, bs):
self.h = (V(torch.zeros(self.nl*2, bs, n_hidden)),
V(torch.zeros(self.nl*2, bs, n_hidden)))
model = TESS_LSTM(2).cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(50):
model.zero_grad()
tag_scores = model(data_x)
loss = loss_function(tag_scores.reshape(len(data_x),n_classes), data_y.reshape(len(data_y)))
loss.backward()
optimizer.step()
if epoch%10==0:
print("Loss at epoch %d = " %epoch, loss)
Also:
data_x = tensor([
[0.9995450377],
[0.9991207719],
[0.9986526966],
[1.0017241240],
[1.0016067028],
[1.0000480413],
[1.0016841888],
[1.0010652542],
[0.9991232157],
[1.0004128218],
[0.9986800551],
[1.0011130571],
[1.0001415014],
[1.0004080534],
[1.0016922951],
[1.0008358955],
[1.0001622438],
[1.0004277229],
[1.0011759996],
[1.0013391972],
[0.9995799065],
[1.0019282103],
[1.0006642342],
[1.0006272793],
[1.0011570454],
[1.0015332699],
[1.0011225939],
[1.0003337860],
[1.0014277697],
[1.0003565550],
[0.9989787340],
[1.0006136894],
[1.0003052950],
[1.0001049042],
[1.0020918846],
[0.9999115467],
[1.0006635189],
[1.0007561445],
[1.0016170740],
[1.0008252859],
[0.9997656345],
[1.0001330376],
[1.0017272234],
[1.0004107952],
[1.0012439489],
[0.9994274378],
[1.0014992952],
[1.0015807152],
[1.0004781485],
[1.0010997057],
[1.0011326075],
[1.0005493164],
[1.0014353991],
[0.9990324974],
[1.0012129545],
[0.9990709424],
[1.0006347895],
[1.0000327826],
[1.0005196333],
[1.0012207031],
[1.0003460646],
[1.0004434586],
[1.0003618002],
[1.0005420446],
[1.0005528927],
[1.0006977320],
[1.0005317926],
[1.0000808239],
[1.0005664825],
[0.9994245768],
[0.9999254942],
[1.0011985302],
[1.0009841919],
[0.9999029040],
[1.0014100075],
[1.0014085770],
[1.0005567074],
[1.0016088486],
[0.9997186661],
[0.9998687506],
[0.9988344908],
[0.9999858141],
[1.0004914999],
[1.0003308058],
[1.0001890659],
[1.0002681017],
[1.0029908419],
[1.0005286932],
[1.0004363060],
[0.9994311333],
[1.0011523962],
[1.0008679628],
[1.0014137030],
[0.9994244576],
[1.0003470182],
[1.0001592636],
[1.0002418756],
[0.9992931485],
[1.0016175508],
[1.0000959635],
[1.0005099773],
[1.0008889437],
[0.9998087287],
[0.9995828867],
[0.9997566342],
[1.0002474785],
[1.0010808706],
[1.0002821684],
[1.0013456345],
[1.0013040304],
[1.0010949373],
[1.0002720356],
[0.9996811152],
[1.0006061792],
[1.0012511015],
[0.9999302626],
[0.9985374212],
[1.0002642870],
[0.9996038675],
[1.0007606745],
[0.9992995858],
[1.0000385046],
[0.9997834563],
[1.0005996227],
[1.0006167889],
[1.0015753508],
[1.0010306835],
[0.9997833371],
[1.0010590553],
[1.0008200407],
[1.0008001328],
[1.0014072657],
[0.9994395375],
[0.9991182089],
[1.0011717081],
[1.0007920265],
[1.0011025667],
[1.0004047155],
[1.0017303228],
[1.0014981031],
[0.9995774031],
[0.9999650121],
[0.9992966652],
[1.0013586283],
[1.0003392696],
[1.0005040169],
[1.0008341074],
[1.0014744997],
[0.9996585250],
[1.0019916296],
[1.0007069111],
[1.0004591942],
[1.0004271269],
[0.9991059303],
[1.0003436804],
[0.9990482330],
[0.9980322123],
[0.9980198145],
[0.9966595173],
[0.9969686270],
[0.9977232814],
[0.9969192147],
[0.9962794185],
[0.9947851300],
[0.9946336746],
[0.9943053722],
[0.9946651459],
[0.9930071235],
[0.9940539598],
[0.9950682521],
[0.9947031140],
[0.9950703979],
[0.9945428371],
[0.9945927858],
[0.9937841296],
[0.9944553375],
[0.9929991364],
[0.9940859079],
[0.9930059314],
[0.9942978621],
[0.9950152636],
[0.9943225384],
[0.9934711456],
[0.9929080606],
[0.9934846163],
[0.9954113960],
[0.9925802350],
[0.9929560423],
[0.9933584929],
[0.9929228425],
[0.9930893779],
[0.9936142564],
[0.9943635464],
[0.9933300614],
[0.9925817847],
[0.9927681088],
[0.9930697680],
[0.9937900901],
[0.9919354320],
[0.9937084913],
[0.9951301217],
[0.9926426411],
[0.9933566451],
[0.9937180877],
[0.9922621250],
[0.9933888316],
[0.9936477542],
[0.9916112423],
[0.9943441153],
[0.9934164286],
[0.9949553013],
[0.9941871166],
[0.9933763146],
[0.9959306121],
[0.9930690527],
[0.9928541183],
[0.9936354756],
[0.9931223392],
[0.9936516881],
[0.9935654402],
[0.9932218790],
[0.9943401814],
[0.9931038022],
[0.9926875830],
[0.9928631186],
[0.9936705232],
[0.9939361215],
[0.9942125678],
[0.9939611554],
[0.9936586618],
[0.9933990240],
[0.9948219061],
[0.9940339923],
[0.9950091243],
[0.9952197671],
[0.9947227240],
[0.9935435653],
[0.9956403971],
[0.9943848252],
[0.9942221045],
[0.9960014224],
[0.9931004643],
[0.9960579872],
[0.9951166511],
[0.9964768291],
[0.9968702793],
[0.9967978597],
[0.9971982837],
[0.9977793097],
[0.9982623458],
[0.9988413453],
[1.0008778572],
[1.0013417006],
[1.0000336170],
[0.9979853630],
[0.9988892674],
[0.9994396567],
[1.0002176762],
[1.0017417669],
[1.0013097525],
[1.0011264086],
[1.0004124641],
[1.0003939867],
[0.9996479750],
[0.9995540380],
[1.0003930330],
[1.0016323328],
[1.0004589558],
[0.9996963739],
[0.9989817142],
[0.9998068213],
[1.0011200905],
[1.0006275177],
[1.0000452995],
[1.0012514591],
[1.0002357960],
[0.9993159175],
[1.0002738237],
[0.9994575381],
[0.9986617565],
[0.9982920289],
[0.9998571873],
[0.9996472597],
[1.0012613535],
[1.0015693903],
[0.9999635220],
[1.0006184578],
[1.0010757446],
[0.9988756776],
[1.0004955530],
[1.0011548996],
[1.0007628202],
[1.0006260872],
[0.9989725947],
[1.0013129711],
[0.9994829297],
[0.9998571873],
[0.9994959831],
[1.0007432699],
[0.9995724559],
[0.9999076724],
[0.9992097020],
[1.0011855364],
[0.9987785220],
[1.0010210276],
[0.9998293519],
[0.9996315837],
[0.9999501705],
[1.0001417398],
[1.0005141497],
[0.9993781447],
[1.0003532171],
[0.9999422431],
[1.0014258623],
[1.0012118816],
[0.9994109273],
[1.0019438267],
[1.0012354851],
[1.0009905100],
[1.0001032352],
[0.9999653101],
[0.9991906881],
[1.0004152060],
[0.9998226762],
[0.9999175668],
[0.9994540215],
[1.0000722408],
[1.0019129515],
[0.9997307658],
[0.9996227026],
[1.0011816025],
[0.9993667006],
[1.0010036230],
[0.9993645549],
[1.0004647970],
[0.9995272160],
[0.9989504814],
[0.9981039166],
[1.0006005764],
[0.9998896718],
[1.0004893541],
[0.9991874099],
[1.0005015135],
[0.9995905161],
[0.9990965128],
[1.0012912750],
[1.0004948378],
[1.0002779961],
[0.9988743067],
[1.0019037724],
[1.0006437302],
[0.9999380112],
[1.0001602173],
[0.9997741580],
[0.9988395572],
[0.9999371171],
[0.9989091754],
[0.9987531900],
[1.0003957748],
[0.9997722507],
[0.9988819361],
[0.9998422265],
[0.9986129999],
[0.9989410639],
[1.0016149282],
[0.9997441173],
[1.0002747774],
[0.9990793467],
[1.0006495714],
[1.0004252195],
[0.9997921586],
[0.9987344146],
[0.9998763800],
[0.9988097548],
[1.0007627010],
[1.0004670620],
[1.0007309914],
[0.9987894297],
[1.0000542402],
[1.0004990101],
[0.9999514818],
[0.9998412132],
[1.0000183582],
[1.0003197193],
[0.9991712570],
[0.9992188215],
[0.9986482859],
[1.0010583401],
[1.0011837482],
[0.9993829727],
[0.9995718002],
[0.9997168183],
[1.0017461777],
[0.9998381138],
[0.9990652204],
[1.0001449585],
[0.9998424053],
[1.0011798143],
[1.0013160706],
[0.9995942712],
[1.0001651049],
[1.0001466274],
[0.9982855320],
[0.9992064238],
[1.0009102821],
[0.9982813597],
[1.0000503063],
[0.9982630014],
[1.0017516613],
[0.9995808005],
[0.9989835620],
[1.0003046989],
[1.0019340515],
[0.9996930957],
[1.0000711679],
[1.0011881590],
[1.0009138584],
[1.0013902187],
[0.9994105101],
[0.9986224174],
[0.9995336533],
[1.0006912947],
[0.9995169044],
[0.9998968840],
[0.9989182949],
[0.9999300838],
[0.9991120696],
[0.9996063709],
[1.0008803606],
[1.0019868612],
[1.0004760027],
[0.9996407032],
[1.0011100769],
[1.0026890039],
[0.9996611476],
[0.9991108775],
[0.9982090592],
[1.0000833273],
[1.0015701056],
[0.9994426966],
[0.9999341369],
[1.0002813339],
[0.9998958707],
[1.0011670589],
[1.0009137392],
[0.9994600415],
[1.0010378361],
[1.0008393526],
[1.0013997555],
[0.9994245768],
[0.9995403886],
[0.9997746348],
[0.9997846484],
[1.0012620687],
[1.0009645224],
[0.9995513558],
[1.0008162260],
[1.0008013248],
[0.9990139604],
[1.0004394054],
[0.9991726875],
[1.0009342432],
[1.0008635521],
[1.0007735491],
[1.0013785362],
[0.9997245073],
[0.9989474416],
[0.9996470809],
[1.0008428097],
[1.0017400980],
[0.9994468689],
[0.9999369979],
[1.0007227659],
[1.0012919903],
[0.9981160164],
[0.9999316335],
[0.9997596741],
[1.0008264780],
[0.9994930029],
[1.0001339912],
[0.9998437166],
[0.9999112487],
[1.0001872778],
[1.0006663799],
[1.0007426739],
[1.0016776323],
[0.9996471405],
[0.9981047511],
[1.0007015467],
[1.0006203651],
[0.9987628460],
[0.9981441498],
[0.9981172085],
[0.9999507666],
[1.0002735853],
[1.0006685257],
[1.0001268387],
[1.0000184774],
[0.9998023510],
[1.0006322861]], device='cuda:0')
and
data_y = tensor([
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0]
], device='cuda:0')
I read data_x and data_y from a file, so that's why I just pasted the values here. See the image below: 0 corresponds to blue and 1 to red.
And this is the output:
Loss at epoch 0 = tensor(0.6795, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 10 = tensor(0.4872, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 20 = tensor(0.4818, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 30 = tensor(0.4834, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 40 = tensor(0.4828, device='cuda:0', grad_fn=<NllLossBackward>)
I tried reducing and increasing the learning rate, trying SGD and RMSprop increasing the number of epochs, but the loss always stops at 0.48. This is part of the output of model(data_x)
:
tensor([[[-0.3617, -1.1924]],
[[-0.3046, -1.3373]],
[[-0.2696, -1.4424]],
[[-0.2477, -1.5169]],
[[-0.2345, -1.5654]],
[[-0.2262, -1.5971]],
And all the other values are similar to this. I expected at least that the LSTM will overfit my model, or at least predict 0 for everything (given that I have just few ones, the loss would still be pretty small). But instead it just predicts these numbers and I am not sure why it stops there. I tried any debugging method I know (which are not very many given my AI experience). How can I fix this?
lstm pytorch convergence
lstm pytorch convergence
edited Apr 4 at 12:48 by Stephen Rauch♦
asked Apr 4 at 5:16 by Bill
Can you print out data_x.size() and data_y.size()?
– Armen Aghajanyan, Apr 4 at 22:07
@ArmenAghajanyan This is the output for both: torch.Size([500, 1]). The sizes are the ones the PyTorch LSTM needs. I actually tried replacing all the ones in the output with zeros (so all the outputs are zeros), and in that case the loss goes down to 10^-5, so the LSTM does seem able to learn in general; it just has a problem with this case (in fact, even with only one "1" and the rest zeros, it also stops learning).
– Bill, Apr 4 at 22:48
Instead of using 2 outputs followed by log_softmax and trained with NLL, use 1 output followed by a sigmoid and trained with binary_cross_entropy.
– Armen Aghajanyan, Apr 5 at 23:04
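A minimal sketch of what that suggested change could look like (the class name, shapes, and placeholder data below are illustrative assumptions, not the asker's actual code). nn.BCEWithLogitsLoss fuses the sigmoid with binary cross-entropy, which is numerically more stable than applying them separately:
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical single-logit variant of the question's LSTM.
class TESS_LSTM_Binary(nn.Module):
    def __init__(self, n_hidden=64):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=n_hidden)
        self.head = nn.Linear(n_hidden, 1)  # one logit per time step

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.head(out)  # raw logits; the sigmoid lives in the loss

model = TESS_LSTM_Binary().to(device)
criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary_cross_entropy, fused
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Placeholder tensors in the (seq_len, batch, features) layout nn.LSTM
# expects by default; BCE needs float targets in [0, 1].
data_x = torch.randn(500, 1, 1, device=device)
data_y = torch.zeros(500, 1, 1, device=device)
data_y[100:193] = 1.0  # a stretch of ones, mimicking the pasted labels

for epoch in range(50):
    optimizer.zero_grad()
    loss = criterion(model(data_x), data_y)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Loss at epoch {epoch} = {loss.item():.4f}')
If the plateau persists, nn.BCEWithLogitsLoss also takes a pos_weight argument, which is one way to compensate for the class imbalance (few ones among many zeros) discussed in the comments above.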