AI开发平台MODELARTS-背景信息:核心基础类介绍

时间:2024-04-11 20:22:54

核心基础类介绍

使用AI Gallery SDK构建自定义模型,需要了解2个核心基础类“PretrainedModel”“PretrainedConfig”之间的交互。

  • “PretrainedConfig”:预训练模型的配置基类

    提供模型配置的通用属性和两个主要方法,用于序列化和反序列化配置文件。

    PretrainedConfig.from_pretrained(dir) # 从目录中加载序列化对象(本地或者是url),配置文件为dir/config.json
    PretrainedConfig.save_pretrained(dir) # 将配置实例序列化到dir/config.json
  • “PretrainedModel”:预训练模型的基类

    包含一个配置实例“config”,提供两个主要方法,用来加载和保存预训练模型。

    # 1. 调用 init_weights() 来初始化所有模型权重
    # 2. 从目录中(本地或者是url)中导入序列化的模型
    # 3. 使用导入的模型权重覆盖所有初始化的权重
    # 4. 调用 PretrainedConfig.from_pretrained(dir)来将配置设置到self.config中
    PretrainedModel.from_pretrained(dir)
    
    # 将模型实例序列化到 dir/pytorch_model.bin 中
    PretrainedModel.save_pretrained(dir)
    
    # 给定input_ids,生成 output_ids,在循环中调用 PretrainedModel.forward() 来做前向推理
    PretrainedModel.generate()
support.huaweicloud.com/aimarket-modelarts/ma_gallery_0047.html