云服务器内容精选

  • 在代码中导入torch并使用 # -*- coding:utf-8 -*-import json# 导入torch依赖import torch as timport numpy as npdef handler (event, context): print("start training!") train() print("finished!") return { "statusCode": 200, "isBase64Encoded": False, "body": json.dumps(event), "headers": { "Content-Type": "application/json" } } def get_fake_data(batch_size=8): x = t.rand(batch_size, 1) * 20; y = x * 2 + (1 + t.randn(batch_size, 1)) * 3 return x, y def train(): t.manual_seed(1000) x, y = get_fake_data() w = t.rand(1, 1) b = t.zeros(1, 1) lr = 0.001 for ii in range(2000): x, y = get_fake_data() y_pred = x.mm(w) + b.expand_as(y) loss = 0.5 * (y_pred - y) ** 2 loss = loss.sum() dloss = 1 dy_pred = dloss * (y_pred - y) dw = x.t().mm(dy_pred) db = dy_pred.sum() w.sub_(lr * dw) b.sub_(lr * db) if ii % 10 == 0: x = t.arange(0, 20).view(-1, 1) y = x.float().mm(w)+ b.expand_as(x) x2, y2 = get_fake_data(batch_size=20) print("w=",w.item(), "b=",b.item()) 父主题: 使用pytorch进行线性回归