A user implemented linear regression in PyTorch with the following code, and running it raised an error:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
# Training data: four samples, one per row, each with a single feature, so both
# tensors have shape (4, 1) -- matching the 1-input/1-output linear layer below.
# torch.autograd.Variable has been a deprecated no-op since PyTorch 0.4 (Variable
# and Tensor were merged), so plain tensors are created directly.
x_data = torch.tensor([[10.0], [9.0], [3.0], [2.0]])
y_data = torch.tensor([[90.0], [80.0], [50.0], [30.0]])
# NOTE(review): indentation was lost when this listing was pasted, so it will not
# run exactly as shown -- Python requires the class and method bodies to be indented.
# Presumably, in the user's original code, the final line `model = LinearRegression()`
# was indented inside the class body. That is the error discussed below: the name
# `LinearRegression` is not bound until the whole `class` statement has finished
# executing, so referencing it from inside the class body raises NameError.
class LinearRegression(torch.nn.Module):
# Constructor: registers a single torch.nn.Linear layer mapping 1 input feature
# to 1 output value (the model's weight and bias).
def __init__(self):
super(LinearRegression,self). __init__ ()
self.linear = torch.nn.Linear(1,1)
# Forward pass: returns the linear layer's prediction for the input batch x.
def forward(self, x):
y_pred = self.linear(x)
return y_pred
# Creates the model instance; this statement belongs at module level, outside the class.
model = LinearRegression()
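In the code above, `model = LinearRegression()` was placed inside the class body. A class's name is not bound until its `class` statement has finished executing, so referring to `LinearRegression` from inside its own body raises a `NameError`. The statement must be moved outside the class, to module level (with no indentation). The following code fixes the problem: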
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
# Training data: four samples, one per row, each with a single feature, so both
# tensors have shape (4, 1) -- matching the 1-input/1-output linear layer below.
# torch.autograd.Variable has been a deprecated no-op since PyTorch 0.4 (Variable
# and Tensor were merged), so plain tensors are created directly.
x_data = torch.tensor([[10.0], [9.0], [3.0], [2.0]])
y_data = torch.tensor([[90.0], [80.0], [50.0], [30.0]])
class LinearRegression(torch.nn.Module):
    """Linear regression model: a single fully connected layer.

    Computes ``y = x @ W.T + b`` with one :class:`torch.nn.Linear` layer.

    Args:
        in_features: Number of input features per sample (default 1).
        out_features: Number of output values per sample (default 1).
    """

    def __init__(self, in_features: int = 1, out_features: int = 1) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the model's predictions for the input batch ``x``.

        Args:
            x: Input tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        return self.linear(x)


# Instantiate at module level, *outside* the class body: inside the body the
# name ``LinearRegression`` is not bound yet, which would raise NameError.
model = LinearRegression()