AI开发平台MODELARTS-开发用于预置框架训练的代码:训练代码完整示例

时间:2024-12-10 11:36:22

训练代码完整示例

训练代码示例中涉及的代码与您使用的AI引擎密切相关,以下案例以Tensorflow框架为例。案例中使用到的“mnist.npz”文件需要提前下载并上传至OBS桶中,训练输入为“mnist.npz”所在OBS路径。

以下训练代码样例中包含了保存模型代码。

import os
import argparse
import tensorflow as tf

parser = argparse.ArgumentParser(description='train mnist')
parser.add_argument('--data_url', type=str, default="./Data/mnist.npz", help='path where the dataset is saved')
parser.add_argument('--train_url', type=str, default="./Model", help='path where the model is saved')
args = parser.parse_args()

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data(args.data_url)
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)

model.save(os.path.join(args.train_url, 'model'))
support.huaweicloud.com/usermanual-standard-modelarts/develop-modelarts-0008.html