Key points on computing accuracy and loss in PyTorch

For the past few days I had some doubts about how accuracy and loss are computed; it turned out I simply had not understood it properly yet.

Here is an example:
def train(train_loader, model, criteon, optimizer, epoch):
    train_loss = 0
    num_correct = 0
    for step, (x, y) in enumerate(train_loader):
        # x: [b, 3, 224, 224], y: [b]
        x, y = x.to(device), y.to(device)

        model.train()
        logits = model(x)
        loss = criteon(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate the batch-mean loss; divided by the number of batches at the end
        train_loss += loss.item()
        train_losses.append(train_loss)  # train_losses: a list assumed to be defined outside this function

        # count how many predictions in this batch are correct
        pred = logits.argmax(dim=1)
        num_correct += torch.eq(pred, y).sum().float().item()

    logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(
        epoch, train_loss / len(train_loader), num_correct / len(train_loader.dataset)))
    return num_correct / len(train_loader.dataset), train_loss / len(train_loader)
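
Two details in the function above are easy to gloss over. With the default reduction='mean', criteon(logits, y) already returns the average loss over one batch, so the accumulated train_loss is divided by len(train_loader) (the number of batches); num_correct, in contrast, counts individual samples, so it is divided by len(train_loader.dataset) (the number of samples). Below is a minimal, self-contained sketch of just these two quantities; the logits, labels, and class count are made up for illustration:

import torch
import torch.nn as nn

criteon = nn.CrossEntropyLoss()  # default reduction='mean': returns the batch-average loss

# fake outputs for a batch of 3 samples and 3 classes, plus ground-truth labels
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.1, 0.2, 3.0]])
y = torch.tensor([0, 2, 2])

loss = criteon(logits, y)                      # scalar: mean loss over the 3 samples
pred = logits.argmax(dim=1)                    # tensor([0, 1, 2])
num_correct = torch.eq(pred, y).sum().item()   # 2 of the 3 predictions match y

print(loss.item(), num_correct / y.size(0))    # batch loss and batch accuracy (2/3)

Summing these per-batch quantities over an epoch and normalizing as in train() gives the epoch-level loss and accuracy that logger.info reports.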