AI开发平台MODELARTS-构建学习器:基于learner.predict进行模型推理

时间:2024-05-28 16:22:36

基于learner.predict进行模型推理

learner.predict(
	img_path='your_local_path_of_image',
    checkpoint='your_local_path_of_pretrained_model',
    gpu_ids=None,
    save_dir='your_local_path_for_saving_output'
)
表4 learner.predict参数

参数名称

可选/必选

参数类型

参数描述

img_path

必选

string

图片路径,当前predict仅支持推理图片。

checkpoint

可选

string

预训练模型路径,默认为None。当基于learner.fit完成训练且该参数为None,则基于训练后的模型参数进行推理。如果指定checkpoint路径,则加载对应路径的模型参数进行推理。

gpu_ids

可选

int/list

推理时使用的GPU,默认为None(使用cpu进行推理)。

save_dir

可选

string

默认为初始化Learner时指定的work_dir,可指定其他本地路径。

model

可选

Model object

自定义Model对象,仅用于open-mmlab系列模型,默认为None。默认值时使用基于learner.fit训练好的模型进行推理。

score_thr

可选

float

推理时结果置信度阈值,默认为0.3,仅用于open-mmlab系列模型。

ret_vis

可选

boolean

是否可视化推理结果,默认为False,仅用于open-mmlab系列模型。

support.huaweicloud.com/devtool-modelarts/devtool-modelarts_0221.html