函数工作流 FunctionGraph-在代码中导入torch并使用

时间:2023-11-01 16:16:37

在代码中导入torch并使用

# -*- coding:utf-8 -*-
# FunctionGraph sample: import torch inside a function and train a toy
# linear-regression model with hand-written gradient descent.
import json

# Import the torch dependency.
import torch as t
# NOTE: numpy is not used below; kept from the original sample.
import numpy as np


def handler(event, context):
    """Function entry point: train the toy model, then echo *event* back.

    Args:
        event: Trigger payload. It must be JSON-serializable because it is
            returned, JSON-encoded, as the response body.
        context: Invocation context (not used here).

    Returns:
        A dict with ``statusCode`` 200, ``isBase64Encoded`` False, the
        JSON-encoded event as ``body`` and a JSON ``Content-Type`` header.
    """
    print("start training!")
    train()
    print("finished!")
    return {
        "statusCode": 200,
        "isBase64Encoded": False,
        "body": json.dumps(event),
        "headers": {
            "Content-Type": "application/json",
        },
    }


def get_fake_data(batch_size=8):
    """Return one synthetic batch ``(x, y)`` for the relation y ~= 2x + 3.

    ``x`` is drawn uniformly from [0, 20). ``y = 2x + 3 * (1 + N(0, 1))``,
    so the true intercept is 3 and the noise has standard deviation 3.

    Args:
        batch_size: Number of samples (rows) to generate.

    Returns:
        Two float tensors, each of shape ``(batch_size, 1)``.
    """
    x = t.rand(batch_size, 1) * 20
    y = x * 2 + (1 + t.randn(batch_size, 1)) * 3
    return x, y


def train(*, iterations=2000, lr=0.001, seed=1000):
    """Fit ``y = x @ w + b`` by manual (autograd-free) mini-batch descent.

    Every step draws a fresh batch from ``get_fake_data()`` and applies the
    analytic gradients of ``0.5 * sum((y_pred - y) ** 2)`` to ``w`` and
    ``b`` in place. The learned parameters are printed at the end.

    Args:
        iterations: Number of gradient steps.
        lr: Learning rate.
        seed: Value passed to ``torch.manual_seed`` so runs are repeatable.

    Returns:
        The learned ``(w, b)`` tensors, each of shape ``(1, 1)``.
    """
    t.manual_seed(seed)

    w = t.rand(1, 1)
    b = t.zeros(1, 1)

    for _ in range(iterations):
        x, y = get_fake_data()

        # Forward pass: linear prediction for the whole batch.
        y_pred = x.mm(w) + b.expand_as(y)

        # Backward pass by hand: for loss = 0.5 * sum((y_pred - y) ** 2),
        # d(loss)/d(y_pred) = y_pred - y; chain rule gives dw and db.
        dy_pred = y_pred - y
        dw = x.t().mm(dy_pred)
        db = dy_pred.sum()

        # In-place gradient-descent update.
        w.sub_(lr * dw)
        b.sub_(lr * db)

    print("w=", w.item(), "b=", b.item())
    return w, b
support.huaweicloud.com/usermanual-functiongraph/functiongraph_01_2112.html