


LSTM not converging



$begingroup$


I am sorry if this question is basic, but I am quite new to neural networks in general. I am trying to build an LSTM to predict certain properties of a light curve (the output is 0 or 1). I built it in PyTorch. Here is my code:



import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np

torch.manual_seed(1)
torch.cuda.set_device(0)

from fastai.learner import *

n_hidden = 64
n_classes = 2
bs = 1

class TESS_LSTM(nn.Module):
    def __init__(self, nl):
        super().__init__()
        self.nl = nl
        self.rnn = nn.LSTM(1, n_hidden, nl, dropout=0.01, bidirectional=True)
        self.l_out = nn.Linear(n_hidden*2, n_classes)
        self.init_hidden(bs)

    def forward(self, input):
        outp,h = self.rnn(input.view(len(input), bs, -1), self.h)
        #self.h = repackage_var(h)
        return F.log_softmax(self.l_out(outp),dim=2)

    def init_hidden(self, bs):
        self.h = (V(torch.zeros(self.nl*2, bs, n_hidden)),
                  V(torch.zeros(self.nl*2, bs, n_hidden)))

model = TESS_LSTM(2).cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):
    model.zero_grad()
    tag_scores = model(data_x)
    loss = loss_function(tag_scores.reshape(len(data_x),n_classes), data_y.reshape(len(data_y)))
    loss.backward()
    optimizer.step()

    if epoch%10==0:
        print("Loss at epoch %d = " %epoch, loss)


Also:




data_x = tensor([
[0.9995450377],
[0.9991207719],
[0.9986526966],
[1.0017241240],
[1.0016067028],
[1.0000480413],
[1.0016841888],
[1.0010652542],
[0.9991232157],
[1.0004128218],
[0.9986800551],
[1.0011130571],
[1.0001415014],
[1.0004080534],
[1.0016922951],
[1.0008358955],
[1.0001622438],
[1.0004277229],
[1.0011759996],
[1.0013391972],
[0.9995799065],
[1.0019282103],
[1.0006642342],
[1.0006272793],
[1.0011570454],
[1.0015332699],
[1.0011225939],
[1.0003337860],
[1.0014277697],
[1.0003565550],
[0.9989787340],
[1.0006136894],
[1.0003052950],
[1.0001049042],
[1.0020918846],
[0.9999115467],
[1.0006635189],
[1.0007561445],
[1.0016170740],
[1.0008252859],
[0.9997656345],
[1.0001330376],
[1.0017272234],
[1.0004107952],
[1.0012439489],
[0.9994274378],
[1.0014992952],
[1.0015807152],
[1.0004781485],
[1.0010997057],
[1.0011326075],
[1.0005493164],
[1.0014353991],
[0.9990324974],
[1.0012129545],
[0.9990709424],
[1.0006347895],
[1.0000327826],
[1.0005196333],
[1.0012207031],
[1.0003460646],
[1.0004434586],
[1.0003618002],
[1.0005420446],
[1.0005528927],
[1.0006977320],
[1.0005317926],
[1.0000808239],
[1.0005664825],
[0.9994245768],
[0.9999254942],
[1.0011985302],
[1.0009841919],
[0.9999029040],
[1.0014100075],
[1.0014085770],
[1.0005567074],
[1.0016088486],
[0.9997186661],
[0.9998687506],
[0.9988344908],
[0.9999858141],
[1.0004914999],
[1.0003308058],
[1.0001890659],
[1.0002681017],
[1.0029908419],
[1.0005286932],
[1.0004363060],
[0.9994311333],
[1.0011523962],
[1.0008679628],
[1.0014137030],
[0.9994244576],
[1.0003470182],
[1.0001592636],
[1.0002418756],
[0.9992931485],
[1.0016175508],
[1.0000959635],
[1.0005099773],
[1.0008889437],
[0.9998087287],
[0.9995828867],
[0.9997566342],
[1.0002474785],
[1.0010808706],
[1.0002821684],
[1.0013456345],
[1.0013040304],
[1.0010949373],
[1.0002720356],
[0.9996811152],
[1.0006061792],
[1.0012511015],
[0.9999302626],
[0.9985374212],
[1.0002642870],
[0.9996038675],
[1.0007606745],
[0.9992995858],
[1.0000385046],
[0.9997834563],
[1.0005996227],
[1.0006167889],
[1.0015753508],
[1.0010306835],
[0.9997833371],
[1.0010590553],
[1.0008200407],
[1.0008001328],
[1.0014072657],
[0.9994395375],
[0.9991182089],
[1.0011717081],
[1.0007920265],
[1.0011025667],
[1.0004047155],
[1.0017303228],
[1.0014981031],
[0.9995774031],
[0.9999650121],
[0.9992966652],
[1.0013586283],
[1.0003392696],
[1.0005040169],
[1.0008341074],
[1.0014744997],
[0.9996585250],
[1.0019916296],
[1.0007069111],
[1.0004591942],
[1.0004271269],
[0.9991059303],
[1.0003436804],
[0.9990482330],
[0.9980322123],
[0.9980198145],
[0.9966595173],
[0.9969686270],
[0.9977232814],
[0.9969192147],
[0.9962794185],
[0.9947851300],
[0.9946336746],
[0.9943053722],
[0.9946651459],
[0.9930071235],
[0.9940539598],
[0.9950682521],
[0.9947031140],
[0.9950703979],
[0.9945428371],
[0.9945927858],
[0.9937841296],
[0.9944553375],
[0.9929991364],
[0.9940859079],
[0.9930059314],
[0.9942978621],
[0.9950152636],
[0.9943225384],
[0.9934711456],
[0.9929080606],
[0.9934846163],
[0.9954113960],
[0.9925802350],
[0.9929560423],
[0.9933584929],
[0.9929228425],
[0.9930893779],
[0.9936142564],
[0.9943635464],
[0.9933300614],
[0.9925817847],
[0.9927681088],
[0.9930697680],
[0.9937900901],
[0.9919354320],
[0.9937084913],
[0.9951301217],
[0.9926426411],
[0.9933566451],
[0.9937180877],
[0.9922621250],
[0.9933888316],
[0.9936477542],
[0.9916112423],
[0.9943441153],
[0.9934164286],
[0.9949553013],
[0.9941871166],
[0.9933763146],
[0.9959306121],
[0.9930690527],
[0.9928541183],
[0.9936354756],
[0.9931223392],
[0.9936516881],
[0.9935654402],
[0.9932218790],
[0.9943401814],
[0.9931038022],
[0.9926875830],
[0.9928631186],
[0.9936705232],
[0.9939361215],
[0.9942125678],
[0.9939611554],
[0.9936586618],
[0.9933990240],
[0.9948219061],
[0.9940339923],
[0.9950091243],
[0.9952197671],
[0.9947227240],
[0.9935435653],
[0.9956403971],
[0.9943848252],
[0.9942221045],
[0.9960014224],
[0.9931004643],
[0.9960579872],
[0.9951166511],
[0.9964768291],
[0.9968702793],
[0.9967978597],
[0.9971982837],
[0.9977793097],
[0.9982623458],
[0.9988413453],
[1.0008778572],
[1.0013417006],
[1.0000336170],
[0.9979853630],
[0.9988892674],
[0.9994396567],
[1.0002176762],
[1.0017417669],
[1.0013097525],
[1.0011264086],
[1.0004124641],
[1.0003939867],
[0.9996479750],
[0.9995540380],
[1.0003930330],
[1.0016323328],
[1.0004589558],
[0.9996963739],
[0.9989817142],
[0.9998068213],
[1.0011200905],
[1.0006275177],
[1.0000452995],
[1.0012514591],
[1.0002357960],
[0.9993159175],
[1.0002738237],
[0.9994575381],
[0.9986617565],
[0.9982920289],
[0.9998571873],
[0.9996472597],
[1.0012613535],
[1.0015693903],
[0.9999635220],
[1.0006184578],
[1.0010757446],
[0.9988756776],
[1.0004955530],
[1.0011548996],
[1.0007628202],
[1.0006260872],
[0.9989725947],
[1.0013129711],
[0.9994829297],
[0.9998571873],
[0.9994959831],
[1.0007432699],
[0.9995724559],
[0.9999076724],
[0.9992097020],
[1.0011855364],
[0.9987785220],
[1.0010210276],
[0.9998293519],
[0.9996315837],
[0.9999501705],
[1.0001417398],
[1.0005141497],
[0.9993781447],
[1.0003532171],
[0.9999422431],
[1.0014258623],
[1.0012118816],
[0.9994109273],
[1.0019438267],
[1.0012354851],
[1.0009905100],
[1.0001032352],
[0.9999653101],
[0.9991906881],
[1.0004152060],
[0.9998226762],
[0.9999175668],
[0.9994540215],
[1.0000722408],
[1.0019129515],
[0.9997307658],
[0.9996227026],
[1.0011816025],
[0.9993667006],
[1.0010036230],
[0.9993645549],
[1.0004647970],
[0.9995272160],
[0.9989504814],
[0.9981039166],
[1.0006005764],
[0.9998896718],
[1.0004893541],
[0.9991874099],
[1.0005015135],
[0.9995905161],
[0.9990965128],
[1.0012912750],
[1.0004948378],
[1.0002779961],
[0.9988743067],
[1.0019037724],
[1.0006437302],
[0.9999380112],
[1.0001602173],
[0.9997741580],
[0.9988395572],
[0.9999371171],
[0.9989091754],
[0.9987531900],
[1.0003957748],
[0.9997722507],
[0.9988819361],
[0.9998422265],
[0.9986129999],
[0.9989410639],
[1.0016149282],
[0.9997441173],
[1.0002747774],
[0.9990793467],
[1.0006495714],
[1.0004252195],
[0.9997921586],
[0.9987344146],
[0.9998763800],
[0.9988097548],
[1.0007627010],
[1.0004670620],
[1.0007309914],
[0.9987894297],
[1.0000542402],
[1.0004990101],
[0.9999514818],
[0.9998412132],
[1.0000183582],
[1.0003197193],
[0.9991712570],
[0.9992188215],
[0.9986482859],
[1.0010583401],
[1.0011837482],
[0.9993829727],
[0.9995718002],
[0.9997168183],
[1.0017461777],
[0.9998381138],
[0.9990652204],
[1.0001449585],
[0.9998424053],
[1.0011798143],
[1.0013160706],
[0.9995942712],
[1.0001651049],
[1.0001466274],
[0.9982855320],
[0.9992064238],
[1.0009102821],
[0.9982813597],
[1.0000503063],
[0.9982630014],
[1.0017516613],
[0.9995808005],
[0.9989835620],
[1.0003046989],
[1.0019340515],
[0.9996930957],
[1.0000711679],
[1.0011881590],
[1.0009138584],
[1.0013902187],
[0.9994105101],
[0.9986224174],
[0.9995336533],
[1.0006912947],
[0.9995169044],
[0.9998968840],
[0.9989182949],
[0.9999300838],
[0.9991120696],
[0.9996063709],
[1.0008803606],
[1.0019868612],
[1.0004760027],
[0.9996407032],
[1.0011100769],
[1.0026890039],
[0.9996611476],
[0.9991108775],
[0.9982090592],
[1.0000833273],
[1.0015701056],
[0.9994426966],
[0.9999341369],
[1.0002813339],
[0.9998958707],
[1.0011670589],
[1.0009137392],
[0.9994600415],
[1.0010378361],
[1.0008393526],
[1.0013997555],
[0.9994245768],
[0.9995403886],
[0.9997746348],
[0.9997846484],
[1.0012620687],
[1.0009645224],
[0.9995513558],
[1.0008162260],
[1.0008013248],
[0.9990139604],
[1.0004394054],
[0.9991726875],
[1.0009342432],
[1.0008635521],
[1.0007735491],
[1.0013785362],
[0.9997245073],
[0.9989474416],
[0.9996470809],
[1.0008428097],
[1.0017400980],
[0.9994468689],
[0.9999369979],
[1.0007227659],
[1.0012919903],
[0.9981160164],
[0.9999316335],
[0.9997596741],
[1.0008264780],
[0.9994930029],
[1.0001339912],
[0.9998437166],
[0.9999112487],
[1.0001872778],
[1.0006663799],
[1.0007426739],
[1.0016776323],
[0.9996471405],
[0.9981047511],
[1.0007015467],
[1.0006203651],
[0.9987628460],
[0.9981441498],
[0.9981172085],
[0.9999507666],
[1.0002735853],
[1.0006685257],
[1.0001268387],
[1.0000184774],
[0.9998023510],
[1.0006322861]], device='cuda:0')


and



data_y = tensor([
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0]
], device='cuda:0')


I read data_x and data_y from a file, which is why I have pasted the values here. See the image below: 0 corresponds to blue and 1 to red.



And this is the output:



Loss at epoch 0 = tensor(0.6795, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 10 = tensor(0.4872, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 20 = tensor(0.4818, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 30 = tensor(0.4834, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 40 = tensor(0.4828, device='cuda:0', grad_fn=<NllLossBackward>)


I tried reducing and increasing the learning rate, switching to SGD and RMSprop, and increasing the number of epochs, but the loss always plateaus at about 0.48. This is part of the output of model(data_x):



tensor([[[-0.3617, -1.1924]],

[[-0.3046, -1.3373]],

[[-0.2696, -1.4424]],

[[-0.2477, -1.5169]],

[[-0.2345, -1.5654]],

[[-0.2262, -1.5971]],


All the other values are similar to these. I expected the LSTM to at least overfit my data, or to predict 0 for everything (given that I have only a few ones, the loss would still be fairly small). Instead it predicts these numbers, and I am not sure why it stops there. I have tried every debugging method I know (which is not many, given my limited AI experience). How can I fix this?
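As a sanity check on the 0.48 plateau (a reading of the numbers, not a diagnosis): the data_y listing above contains 93 ones among its 500 labels. The sketch below, in plain Python with illustrative variable names, computes the loss a model would get by outputting the class frequencies at every time step, and converts the last printed row of log-probabilities back to probabilities for comparison.

```python
import math

# Label counts read off the data_y listing above: 93 ones among 500 labels.
n_total = 500
n_ones = 93
p1 = n_ones / n_total   # fraction of class 1
p0 = 1.0 - p1           # fraction of class 0

# Mean negative log-likelihood of a model that outputs (p0, p1) at every
# time step, i.e. the entropy of the label distribution in nats.
baseline_nll = -(p0 * math.log(p0) + p1 * math.log(p1))
print(f"constant-prediction loss: {baseline_nll:.4f}")

# Last row of the printed model output, converted from log-probabilities.
log_probs = (-0.2262, -1.5971)
probs = tuple(math.exp(v) for v in log_probs)
print(f"predicted probabilities: {probs[0]:.3f}, {probs[1]:.3f}")
```

The constant-prediction loss comes out at roughly 0.48, and the printed probabilities are close to the class frequencies (about 0.81 and 0.19). This is consistent with the network outputting nearly the same distribution at every time step instead of using the input.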



[Image: the light curve, with points labelled 0 in blue and points labelled 1 in red]





















$endgroup$











  • Can you print out data_x.size() and data_y.size()?
    – Armen Aghajanyan
    Apr 4 at 22:07










  • @ArmenAghajanyan This is the output for both: torch.Size([500, 1]). The size of the vectors is the one the PyTorch LSTM needs. I also tried replacing all the ones in the output with zeros (so all the targets are zeros), and in that case the loss goes down to 10^-5, so the LSTM seems able to learn in general; it only has a problem in this case (even if I have just one "1" and the rest zeros, it also stops learning).
    – Bill
    Apr 4 at 22:48










  • Instead of using 2 outputs followed by log_softmax trained with NLL, use 1 output followed by sigmoid and trained with binary_cross_entropy.
    – Armen Aghajanyan
    Apr 5 at 23:04
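To make the last suggestion concrete, here is a minimal sketch in plain Python (deliberately without PyTorch) of the two loss formulations for a single time step. The function names and the logit values are made up for illustration and are not taken from the question's code.

```python
import math

def log_softmax2(z0, z1):
    """Log-softmax over two logits, using the usual max-shift for stability."""
    m = max(z0, z1)
    lse = m + math.log(math.exp(z0 - m) + math.exp(z1 - m))
    return (z0 - lse, z1 - lse)

def nll_two_outputs(z0, z1, y):
    """Negative log-likelihood of label y (0 or 1) under a two-logit softmax."""
    return -log_softmax2(z0, z1)[y]

def bce_one_output(z, y):
    """Binary cross-entropy of label y (0 or 1) under sigmoid(z)."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))

# Hypothetical logits for one time step.
z0, z1 = 0.7, -0.4
for y in (0, 1):
    two = nll_two_outputs(z0, z1, y)
    one = bce_one_output(z1 - z0, y)  # single logit = difference of the two
    print(y, round(two, 6), round(one, 6))
```

For each label the two losses agree, because the softmax probability of class 1 equals the sigmoid of z1 - z0. In that sense the single-output version is a reparametrization of the output layer; whether it changes how training behaves on this data is something the sketch does not show.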


















$begingroup$


I am sorry if this questions is basic but I am quite new to NN in general. I am trying to build an LSTM to predict certain properties of a light curve (the output is 0 or 1). I build it in pytorch. Here is my code:



import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np

torch.manual_seed(1)
torch.cuda.set_device(0)

from fastai.learner import *

n_hidden = 64
n_classes = 2
bs = 1

class TESS_LSTM(nn.Module):
    def __init__(self, nl):
        super().__init__()
        self.nl = nl
        self.rnn = nn.LSTM(1, n_hidden, nl, dropout=0.01, bidirectional=True)
        self.l_out = nn.Linear(n_hidden*2, n_classes)
        self.init_hidden(bs)

    def forward(self, input):
        outp, h = self.rnn(input.view(len(input), bs, -1), self.h)
        #self.h = repackage_var(h)
        return F.log_softmax(self.l_out(outp), dim=2)

    def init_hidden(self, bs):
        self.h = (V(torch.zeros(self.nl*2, bs, n_hidden)),
                  V(torch.zeros(self.nl*2, bs, n_hidden)))

model = TESS_LSTM(2).cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):
    model.zero_grad()
    tag_scores = model(data_x)
    loss = loss_function(tag_scores.reshape(len(data_x), n_classes), data_y.reshape(len(data_y)))
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print("Loss at epoch %d = " % epoch, loss)
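The tensor shapes flowing through this model can be traced by hand. The sketch below does that bookkeeping with plain tuples, assuming data_x has 500 rows and one column (torch.Size([500, 1]), as reported in the comments) and the default sequence-first layout of nn.LSTM. The helper names are made up for illustration and are not PyTorch functions.

```python
# Shape bookkeeping for the model above, using plain tuples so that it
# runs without PyTorch. All helper names are illustrative assumptions.

def lstm_output_shape(seq_len, batch, hidden, bidirectional):
    """Per-step output shape of a sequence-first (batch_first=False) LSTM."""
    directions = 2 if bidirectional else 1
    return (seq_len, batch, directions * hidden)

def linear_output_shape(in_shape, out_features):
    """A linear layer only changes the size of the last dimension."""
    return in_shape[:-1] + (out_features,)

seq_len, bs, n_hidden, n_classes = 500, 1, 64, 2

x_shape = (seq_len, bs, 1)               # input.view(len(input), bs, -1)
rnn_shape = lstm_output_shape(seq_len, bs, n_hidden, bidirectional=True)
scores_shape = linear_output_shape(rnn_shape, n_classes)
flat_shape = (seq_len * bs, n_classes)   # tag_scores.reshape(len(data_x), n_classes)

print(x_shape, rnn_shape, scores_shape, flat_shape)
```

With bs = 1 the final reshape only drops the batch dimension, so each of the 500 time steps contributes one row of two log-probabilities to the loss.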


Also:




data_x = tensor([
[0.9995450377],
[0.9991207719],
[0.9986526966],
[1.0017241240],
[1.0016067028],
[1.0000480413],
[1.0016841888],
[1.0010652542],
[0.9991232157],
[1.0004128218],
[0.9986800551],
[1.0011130571],
[1.0001415014],
[1.0004080534],
[1.0016922951],
[1.0008358955],
[1.0001622438],
[1.0004277229],
[1.0011759996],
[1.0013391972],
[0.9995799065],
[1.0019282103],
[1.0006642342],
[1.0006272793],
[1.0011570454],
[1.0015332699],
[1.0011225939],
[1.0003337860],
[1.0014277697],
[1.0003565550],
[0.9989787340],
[1.0006136894],
[1.0003052950],
[1.0001049042],
[1.0020918846],
[0.9999115467],
[1.0006635189],
[1.0007561445],
[1.0016170740],
[1.0008252859],
[0.9997656345],
[1.0001330376],
[1.0017272234],
[1.0004107952],
[1.0012439489],
[0.9994274378],
[1.0014992952],
[1.0015807152],
[1.0004781485],
[1.0010997057],
[1.0011326075],
[1.0005493164],
[1.0014353991],
[0.9990324974],
[1.0012129545],
[0.9990709424],
[1.0006347895],
[1.0000327826],
[1.0005196333],
[1.0012207031],
[1.0003460646],
[1.0004434586],
[1.0003618002],
[1.0005420446],
[1.0005528927],
[1.0006977320],
[1.0005317926],
[1.0000808239],
[1.0005664825],
[0.9994245768],
[0.9999254942],
[1.0011985302],
[1.0009841919],
[0.9999029040],
[1.0014100075],
[1.0014085770],
[1.0005567074],
[1.0016088486],
[0.9997186661],
[0.9998687506],
[0.9988344908],
[0.9999858141],
[1.0004914999],
[1.0003308058],
[1.0001890659],
[1.0002681017],
[1.0029908419],
[1.0005286932],
[1.0004363060],
[0.9994311333],
[1.0011523962],
[1.0008679628],
[1.0014137030],
[0.9994244576],
[1.0003470182],
[1.0001592636],
[1.0002418756],
[0.9992931485],
[1.0016175508],
[1.0000959635],
[1.0005099773],
[1.0008889437],
[0.9998087287],
[0.9995828867],
[0.9997566342],
[1.0002474785],
[1.0010808706],
[1.0002821684],
[1.0013456345],
[1.0013040304],
[1.0010949373],
[1.0002720356],
[0.9996811152],
[1.0006061792],
[1.0012511015],
[0.9999302626],
[0.9985374212],
[1.0002642870],
[0.9996038675],
[1.0007606745],
[0.9992995858],
[1.0000385046],
[0.9997834563],
[1.0005996227],
[1.0006167889],
[1.0015753508],
[1.0010306835],
[0.9997833371],
[1.0010590553],
[1.0008200407],
[1.0008001328],
[1.0014072657],
[0.9994395375],
[0.9991182089],
[1.0011717081],
[1.0007920265],
[1.0011025667],
[1.0004047155],
[1.0017303228],
[1.0014981031],
[0.9995774031],
[0.9999650121],
[0.9992966652],
[1.0013586283],
[1.0003392696],
[1.0005040169],
[1.0008341074],
[1.0014744997],
[0.9996585250],
[1.0019916296],
[1.0007069111],
[1.0004591942],
[1.0004271269],
[0.9991059303],
[1.0003436804],
[0.9990482330],
[0.9980322123],
[0.9980198145],
[0.9966595173],
[0.9969686270],
[0.9977232814],
[0.9969192147],
[0.9962794185],
[0.9947851300],
[0.9946336746],
[0.9943053722],
[0.9946651459],
[0.9930071235],
[0.9940539598],
[0.9950682521],
[0.9947031140],
[0.9950703979],
[0.9945428371],
[0.9945927858],
[0.9937841296],
[0.9944553375],
[0.9929991364],
[0.9940859079],
[0.9930059314],
[0.9942978621],
[0.9950152636],
[0.9943225384],
[0.9934711456],
[0.9929080606],
[0.9934846163],
[0.9954113960],
[0.9925802350],
[0.9929560423],
[0.9933584929],
[0.9929228425],
[0.9930893779],
[0.9936142564],
[0.9943635464],
[0.9933300614],
[0.9925817847],
[0.9927681088],
[0.9930697680],
[0.9937900901],
[0.9919354320],
[0.9937084913],
[0.9951301217],
[0.9926426411],
[0.9933566451],
[0.9937180877],
[0.9922621250],
[0.9933888316],
[0.9936477542],
[0.9916112423],
[0.9943441153],
[0.9934164286],
[0.9949553013],
[0.9941871166],
[0.9933763146],
[0.9959306121],
[0.9930690527],
[0.9928541183],
[0.9936354756],
[0.9931223392],
[0.9936516881],
[0.9935654402],
[0.9932218790],
[0.9943401814],
[0.9931038022],
[0.9926875830],
[0.9928631186],
[0.9936705232],
[0.9939361215],
[0.9942125678],
[0.9939611554],
[0.9936586618],
[0.9933990240],
[0.9948219061],
[0.9940339923],
[0.9950091243],
[0.9952197671],
[0.9947227240],
[0.9935435653],
[0.9956403971],
[0.9943848252],
[0.9942221045],
[0.9960014224],
[0.9931004643],
[0.9960579872],
[0.9951166511],
[0.9964768291],
[0.9968702793],
[0.9967978597],
[0.9971982837],
[0.9977793097],
[0.9982623458],
[0.9988413453],
[1.0008778572],
[1.0013417006],
[1.0000336170],
[0.9979853630],
[0.9988892674],
[0.9994396567],
[1.0002176762],
[1.0017417669],
[1.0013097525],
[1.0011264086],
[1.0004124641],
[1.0003939867],
[0.9996479750],
[0.9995540380],
[1.0003930330],
[1.0016323328],
[1.0004589558],
[0.9996963739],
[0.9989817142],
[0.9998068213],
[1.0011200905],
[1.0006275177],
[1.0000452995],
[1.0012514591],
[1.0002357960],
[0.9993159175],
[1.0002738237],
[0.9994575381],
[0.9986617565],
[0.9982920289],
[0.9998571873],
[0.9996472597],
[1.0012613535],
[1.0015693903],
[0.9999635220],
[1.0006184578],
[1.0010757446],
[0.9988756776],
[1.0004955530],
[1.0011548996],
[1.0007628202],
[1.0006260872],
[0.9989725947],
[1.0013129711],
[0.9994829297],
[0.9998571873],
[0.9994959831],
[1.0007432699],
[0.9995724559],
[0.9999076724],
[0.9992097020],
[1.0011855364],
[0.9987785220],
[1.0010210276],
[0.9998293519],
[0.9996315837],
[0.9999501705],
[1.0001417398],
[1.0005141497],
[0.9993781447],
[1.0003532171],
[0.9999422431],
[1.0014258623],
[1.0012118816],
[0.9994109273],
[1.0019438267],
[1.0012354851],
[1.0009905100],
[1.0001032352],
[0.9999653101],
[0.9991906881],
[1.0004152060],
[0.9998226762],
[0.9999175668],
[0.9994540215],
[1.0000722408],
[1.0019129515],
[0.9997307658],
[0.9996227026],
[1.0011816025],
[0.9993667006],
[1.0010036230],
[0.9993645549],
[1.0004647970],
[0.9995272160],
[0.9989504814],
[0.9981039166],
[1.0006005764],
[0.9998896718],
[1.0004893541],
[0.9991874099],
[1.0005015135],
[0.9995905161],
[0.9990965128],
[1.0012912750],
[1.0004948378],
[1.0002779961],
[0.9988743067],
[1.0019037724],
[1.0006437302],
[0.9999380112],
[1.0001602173],
[0.9997741580],
[0.9988395572],
[0.9999371171],
[0.9989091754],
[0.9987531900],
[1.0003957748],
[0.9997722507],
[0.9988819361],
[0.9998422265],
[0.9986129999],
[0.9989410639],
[1.0016149282],
[0.9997441173],
[1.0002747774],
[0.9990793467],
[1.0006495714],
[1.0004252195],
[0.9997921586],
[0.9987344146],
[0.9998763800],
[0.9988097548],
[1.0007627010],
[1.0004670620],
[1.0007309914],
[0.9987894297],
[1.0000542402],
[1.0004990101],
[0.9999514818],
[0.9998412132],
[1.0000183582],
[1.0003197193],
[0.9991712570],
[0.9992188215],
[0.9986482859],
[1.0010583401],
[1.0011837482],
[0.9993829727],
[0.9995718002],
[0.9997168183],
[1.0017461777],
[0.9998381138],
[0.9990652204],
[1.0001449585],
[0.9998424053],
[1.0011798143],
[1.0013160706],
[0.9995942712],
[1.0001651049],
[1.0001466274],
[0.9982855320],
[0.9992064238],
[1.0009102821],
[0.9982813597],
[1.0000503063],
[0.9982630014],
[1.0017516613],
[0.9995808005],
[0.9989835620],
[1.0003046989],
[1.0019340515],
[0.9996930957],
[1.0000711679],
[1.0011881590],
[1.0009138584],
[1.0013902187],
[0.9994105101],
[0.9986224174],
[0.9995336533],
[1.0006912947],
[0.9995169044],
[0.9998968840],
[0.9989182949],
[0.9999300838],
[0.9991120696],
[0.9996063709],
[1.0008803606],
[1.0019868612],
[1.0004760027],
[0.9996407032],
[1.0011100769],
[1.0026890039],
[0.9996611476],
[0.9991108775],
[0.9982090592],
[1.0000833273],
[1.0015701056],
[0.9994426966],
[0.9999341369],
[1.0002813339],
[0.9998958707],
[1.0011670589],
[1.0009137392],
[0.9994600415],
[1.0010378361],
[1.0008393526],
[1.0013997555],
[0.9994245768],
[0.9995403886],
[0.9997746348],
[0.9997846484],
[1.0012620687],
[1.0009645224],
[0.9995513558],
[1.0008162260],
[1.0008013248],
[0.9990139604],
[1.0004394054],
[0.9991726875],
[1.0009342432],
[1.0008635521],
[1.0007735491],
[1.0013785362],
[0.9997245073],
[0.9989474416],
[0.9996470809],
[1.0008428097],
[1.0017400980],
[0.9994468689],
[0.9999369979],
[1.0007227659],
[1.0012919903],
[0.9981160164],
[0.9999316335],
[0.9997596741],
[1.0008264780],
[0.9994930029],
[1.0001339912],
[0.9998437166],
[0.9999112487],
[1.0001872778],
[1.0006663799],
[1.0007426739],
[1.0016776323],
[0.9996471405],
[0.9981047511],
[1.0007015467],
[1.0006203651],
[0.9987628460],
[0.9981441498],
[0.9981172085],
[0.9999507666],
[1.0002735853],
[1.0006685257],
[1.0001268387],
[1.0000184774],
[0.9998023510],
[1.0006322861]], device='cuda:0')


and



data_y = tensor([
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0],
[0]
], device='cuda:0')


I read data_x and data_y from a file, so that's why I just pasted the values here. See the image below: 0 corresponds to blue and 1 to red.



And this is the output:



Loss at epoch 0 = tensor(0.6795, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 10 = tensor(0.4872, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 20 = tensor(0.4818, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 30 = tensor(0.4834, device='cuda:0', grad_fn=<NllLossBackward>)
Loss at epoch 40 = tensor(0.4828, device='cuda:0', grad_fn=<NllLossBackward>)


I tried reducing and increasing the learning rate, switching to SGD and RMSprop, and increasing the number of epochs, but the loss always stalls at about 0.48. This is part of the output of model(data_x):



tensor([[[-0.3617, -1.1924]],

[[-0.3046, -1.3373]],

[[-0.2696, -1.4424]],

[[-0.2477, -1.5169]],

[[-0.2345, -1.5654]],

[[-0.2262, -1.5971]],


And all the other values are similar to this. I expected the LSTM to at least overfit my data, or at least to predict 0 for everything (given that I have only a few ones, the loss would still be fairly small). But instead it just predicts these numbers, and I am not sure why it stops there. I have tried every debugging method I know (which is not many, given my limited AI experience). How can I fix this?
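One way to read the 0.48 plateau is to compare it with the loss of a model that ignores its input and outputs the same probabilities at every time step. A minimal sketch in plain Python, assuming the label counts from data_y above (93 ones out of 500); the function name constant_nll is made up for illustration:

```python
import math

def constant_nll(p_one, frac_ones):
    """Mean NLL when every time step predicts P(label = 1) = p_one."""
    return -(frac_ones * math.log(p_one)
             + (1 - frac_ones) * math.log(1 - p_one))

n_ones, n_total = 93, 500        # counted from data_y above
frac_ones = n_ones / n_total     # 0.186

# Best constant predictor: output the class frequency at every step.
prior_loss = constant_nll(frac_ones, frac_ones)

# A near-certain "always 0" predictor is penalised heavily on the ones.
confident_zero_loss = constant_nll(0.01, frac_ones)

# Probabilities implied by the last log-softmax row shown above.
p0, p1 = math.exp(-0.2262), math.exp(-1.5971)

print(f"prior loss          = {prior_loss:.4f}")
print(f"confident-zero loss = {confident_zero_loss:.4f}")
print(f"model probabilities = {p0:.3f}, {p1:.3f}")
```

Under these assumptions the constant-prior loss comes out at about 0.480, which matches the plateau, and the printed log-probabilities correspond to roughly 0.80 and 0.20, close to the class frequencies 0.814 and 0.186. This suggests, but does not prove, that the network has settled on predicting the label frequencies regardless of the input. Note also that predicting 0 with near certainty would give a larger loss (about 0.865 here), not a smaller one.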



[Image: plot of the data, with label 0 shown in blue and label 1 in red]





















$endgroup$











lstm pytorch convergence














edited Apr 4 at 12:48 by Stephen Rauch

asked Apr 4 at 5:16 by Bill











  • Can you print out data_x.size() and data_y.size()? – Armen Aghajanyan, Apr 4 at 22:07

  • @ArmenAghajanyan This is the output for both: torch.Size([500, 1]). The size of the vectors is the one required by the PyTorch LSTM. I also tried replacing all the ones in the output with zeros (so all the outputs are zeros), and in that case the loss goes down to 10^-5, so the LSTM seems able to learn in general; it just has a problem in this case (even if I have only one "1" and the rest zeros, it still stops learning). – Bill, Apr 4 at 22:48

  • Instead of using 2 outputs followed by log_softmax trained with NLL, use 1 output followed by a sigmoid and train with binary_cross_entropy. – Armen Aghajanyan, Apr 5 at 23:04
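The suggestion in the last comment replaces the two-logit log_softmax/NLL head with a single logit passed through a sigmoid and trained with binary cross-entropy. Below is a small sketch in plain Python of how the two losses relate for a single time step. The function names and the example logits are made up for illustration, and this is not the PyTorch implementation.

```python
import math

def nll_two_logits(z0, z1, target):
    """NLL of log_softmax over two logits, for target 0 or 1."""
    log_norm = math.log(math.exp(z0) + math.exp(z1))
    chosen = z1 if target == 1 else z0
    return -(chosen - log_norm)

def bce_one_logit(z, target):
    """Binary cross-entropy of sigmoid(z) against target 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

z0, z1 = 0.7, -0.4                        # arbitrary example logits
for target in (0, 1):
    two = nll_two_logits(z0, z1, target)
    one = bce_one_logit(z1 - z0, target)  # single logit = logit difference
    print(target, round(two, 6), round(one, 6))
```

For any pair of logits the two values agree, because a softmax over two logits equals a sigmoid of their difference. In PyTorch the single-output variant would typically use nn.Linear(n_hidden*2, 1) with F.binary_cross_entropy on the sigmoid output (or F.binary_cross_entropy_with_logits on the raw output) and float targets.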

























