AI开发平台MODELARTS-模型适配:获取模型shape
时间:2025-01-03 09:38:46
获取模型shape
由于在后续模型转换时需要知道待转换模型的shape信息,此处指导如何通过训练好的stable diffusion pytorch模型获取模型shape,主要有如下两种方式获取:
- 方式一:通过stable diffusion的pytorch模型获取模型shape。
- 方式二:通过查看ModelArts-Ascend代码仓库,根据每个模型的configs文件获取已知的shape大小。
下文主要介绍如何通过方式一获取模型shape。
- 在pipeline应用准备章节,已经下载到sd的pytorch模型(/home_host/work/runwayml/pytorch_models)。进入工作目录:
cd /home_host/work
- 新建Python脚本文件“parse_models_shape.py”用于获取shape。其中,model_path是指上面下载的pytorch_models的路径。
# parse_models_shape.py import torch import numpy as np from diffusers import StableDiffusionPipeline model_path = '/home_host/work/runwayml/pytorch_models' pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32) # TEXT ENCODER num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( "A sample prompt", padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) print("# TEXT ENCODER") print(f"input_ids: {np.array(text_input.input_ids.shape).tolist()}") # UNET unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size print("# UNET") print(f"sample: [{2}, {unet_in_channels} {unet_sample_size} {unet_sample_size}]") print(f"timestep: [{1}]") # 此处应该是1,否则和后续的推理脚本不一致。 print(f"encoder_hidden_states: [{2}, {num_tokens} {text_hidden_size}]") # VAE ENCODER vae_encoder = pipeline.vae vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size print("# VAE ENCODER") print(f"sample: [{1}, {vae_in_channels}, {vae_sample_size}, {vae_sample_size}]") # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels vae_out_channels = vae_decoder.config.out_channels print("# VAE DECODER") print(f"latent_sample: [{1}, {vae_latent_channels}, {unet_sample_size}, {unet_sample_size}]") # SAFETY CHECKER safety_checker = pipeline.safety_checker clip_num_channels = safety_checker.config.vision_config.num_channels clip_image_size = safety_checker.config.vision_config.image_size print("# SAFETY CHECKER") print(f"clip_input: [{1}, {clip_num_channels}, {clip_image_size}, {clip_image_size}]") print(f"images: [{1}, {vae_sample_size}, {vae_sample_size}, {vae_out_channels}]")
- 执行以下命令获取shape信息。
python parse_models_shape.py
可以看到获取的shape信息如下图所示。
图1 shape信息
support.huaweicloud.com/bestpractice-modelarts/modelarts_10_2004.html
看了此文的人还看了
CDN加速
GaussDB
文字转换成语音
免费的服务器
如何创建网站
域名网站购买
私有云桌面
云主机哪个好
域名怎么备案
手机云电脑
SSL证书申请
云点播服务器
免费OCR是什么
电脑云桌面
域名备案怎么弄
语音转文字
文字图片识别
云桌面是什么
网址安全检测
网站建设搭建
国外CDN加速
SSL免费证书申请
短信批量发送
图片OCR识别
云数据库MySQL
个人域名购买
录音转文字
扫描图片识别文字
OCR图片识别
行驶证识别
虚拟电话号码
电话呼叫中心软件
怎么制作一个网站
Email注册网站
华为VNC
图像文字识别
企业网站制作
个人网站搭建
华为云计算
免费租用云托管
云桌面云服务器
ocr文字识别免费版
HTTPS证书申请
图片文字识别转换
国外域名注册商
使用免费虚拟主机
云电脑主机多少钱
鲲鹏云手机
短信验证码平台
OCR图片文字识别
SSL证书是什么
申请企业邮箱步骤
免费的企业用邮箱
云免流搭建教程
域名价格
推荐文章