PyTorch线性回归和逻辑回归实战示例-创新互联
线性回归实战

使用PyTorch定义线性回归模型一般分以下几步:
1.设计网络架构
2.构建损失函数(loss)和优化器(optimizer)
3.训练(包括前馈(forward)、反向传播(backward)、更新模型参数(update))
#author:yuquanle
#data:2018.2.5
#Study of LinearRegression use PyTorch
import torch
from torch.autograd import Variable
# train data
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
y_data = Variable(torch.Tensor([[2.0], [4.0], [6.0]]))
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(1, 1) # One in and one out
def forward(self, x):
y_pred = self.linear(x)
return y_pred
# our model
model = Model()
criterion = torch.nn.MSELoss(size_average=False) # Defined loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # Defined optimizer
# Training: forward, loss, backward, step
# Training loop
for epoch in range(50):
# Forward pass
y_pred = model(x_data)
# Compute loss
loss = criterion(y_pred, y_data)
print(epoch, loss.data[0])
# Zero gradients
optimizer.zero_grad()
# perform backward pass
loss.backward()
# update weights
optimizer.step()
# After training
hour_var = Variable(torch.Tensor([[4.0]]))
print("predict (after training)", 4, model.forward(hour_var).data[0][0])
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。
当前标题:PyTorch线性回归和逻辑回归实战示例-创新互联
文章路径:http://www.jxjierui.cn/article/dsdspe.html


咨询
建站咨询
