ModelArts AI Development Platform - Model Adaptation: Obtaining Model Shapes

Time: 2025-01-03 09:38:46

Obtaining Model Shapes

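Model conversion in a later step requires the shape information of the model being converted. This section describes how to obtain the model shapes from a trained Stable Diffusion PyTorch model. There are two ways to do so:

  • Method 1: Obtain the shapes from the Stable Diffusion PyTorch model itself.
  • Method 2: Look in the ModelArts-Ascend code repository and read the known shape sizes from each model's configs file.

The rest of this section describes method 1.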
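Method 1 comes down to reading a handful of fields from each sub-model's configuration. As a rough sketch of that idea, and assuming the model directory follows the usual diffusers layout with a config.json under each component folder (unet, vae, text_encoder), the same fields can be inspected without loading any weights. The helper names `read_component_config` and `summarize_shape_fields` are hypothetical and exist only for this illustration.

```python
import json
from pathlib import Path


def read_component_config(model_dir, component):
    """Load <model_dir>/<component>/config.json as a dict."""
    config_path = Path(model_dir) / component / "config.json"
    with config_path.open(encoding="utf-8") as f:
        return json.load(f)


def pick(config, keys):
    """Keep only the requested keys; missing keys map to None."""
    return {key: config.get(key) for key in keys}


def summarize_shape_fields(model_dir):
    """Collect the config fields that the shape script relies on."""
    unet = read_component_config(model_dir, "unet")
    vae = read_component_config(model_dir, "vae")
    text = read_component_config(model_dir, "text_encoder")
    return {
        "unet": pick(unet, ("in_channels", "sample_size")),
        "vae": pick(vae, ("in_channels", "latent_channels",
                          "out_channels", "sample_size")),
        "text_encoder": pick(text, ("max_position_embeddings",
                                    "hidden_size")),
    }


if __name__ == "__main__":
    # Path used in this guide; adjust it if your model lives elsewhere.
    model_dir = "/home_host/work/runwayml/pytorch_models"
    if Path(model_dir).is_dir():
        print(json.dumps(summarize_shape_fields(model_dir), indent=2))
```

This only reports raw configuration values. The batch dimension and the way the values combine into full input shapes still come from the script in the steps below.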

  1. In the section on preparing the pipeline application, the SD PyTorch model was already downloaded to /home_host/work/runwayml/pytorch_models. Go to the working directory:

    cd /home_host/work

  2. Create a Python script named "parse_models_shape.py" to obtain the shapes. In the script, model_path is the path of the pytorch_models directory downloaded earlier.

    # parse_models_shape.py
    import torch
    import numpy as np
    from diffusers import StableDiffusionPipeline
    
    model_path = '/home_host/work/runwayml/pytorch_models'
    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32)
    
    # TEXT ENCODER
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    print("# TEXT ENCODER")
    print(f"input_ids: {np.array(text_input.input_ids.shape).tolist()}")
    
    # UNET
    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    print("# UNET")
    print(f"sample: [{2}, {unet_in_channels}, {unet_sample_size}, {unet_sample_size}]")
    print(f"timestep: [{1}]")  # This must be 1; otherwise it is inconsistent with the later inference script.
    print(f"encoder_hidden_states: [{2}, {num_tokens}, {text_hidden_size}]")
    
    # VAE ENCODER
    vae_encoder = pipeline.vae
    vae_in_channels = vae_encoder.config.in_channels
    vae_sample_size = vae_encoder.config.sample_size
    print("# VAE ENCODER")
    print(f"sample: [{1}, {vae_in_channels}, {vae_sample_size}, {vae_sample_size}]")
    
    # VAE DECODER
    vae_decoder = pipeline.vae
    vae_latent_channels = vae_decoder.config.latent_channels
    vae_out_channels = vae_decoder.config.out_channels
    print("# VAE DECODER")
    print(f"latent_sample: [{1}, {vae_latent_channels}, {unet_sample_size}, {unet_sample_size}]")
    
    # SAFETY CHECKER
    safety_checker = pipeline.safety_checker
    clip_num_channels = safety_checker.config.vision_config.num_channels
    clip_image_size = safety_checker.config.vision_config.image_size
    print("# SAFETY CHECKER")
    print(f"clip_input: [{1}, {clip_num_channels}, {clip_image_size}, {clip_image_size}]")
    print(f"images: [{1}, {vae_sample_size}, {vae_sample_size}, {vae_out_channels}]")

  3. Run the following command to obtain the shape information.

    python parse_models_shape.py

    The obtained shape information is displayed as shown in the following figure.

    Figure 1 Shape information
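The printed shapes are meant to be carried into the later model conversion step. As a minimal sketch of how they might be recorded, the hypothetical helper `format_shapes` below (not part of ModelArts or diffusers) joins name/shape pairs into a single `name:[d0,d1,...]` string. The exact syntax expected by the conversion tool is not specified here, so the output format is an assumption. The numbers are illustrative placeholders of the kind typically seen with a Stable Diffusion v1.x checkpoint and should be replaced with the values your own run prints.

```python
def format_shapes(shapes):
    """Join {name: [dims]} into 'name:[d0,d1];other:[d0]' (assumed format)."""
    parts = []
    for name, dims in shapes.items():
        # Reject anything that is not a positive integer dimension.
        if not all(isinstance(d, int) and d > 0 for d in dims):
            raise ValueError(f"invalid shape for {name}: {dims}")
        dim_text = ",".join(str(d) for d in dims)
        parts.append(f"{name}:[{dim_text}]")
    return ";".join(parts)


# Illustrative UNet input shapes; replace with your script's actual output.
unet_shapes = {
    "sample": [2, 4, 64, 64],
    "timestep": [1],
    "encoder_hidden_states": [2, 77, 768],
}

print(format_shapes(unet_shapes))
```

Keeping the shapes in a dictionary keyed by input name makes it easier to check them against the names printed by parse_models_shape.py before they are passed on.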

support.huaweicloud.com/bestpractice-modelarts/modelarts_10_2004.html