torchinfoでHugging Faceのモデルのsummaryを出力しようとするとRuntimeError: Failed to run torchinfo.

kerasのmodel.summary()と似たようなことをpytorchのモデルで行ってくれるライブラリとして、torchinfoがある。

GitHub - TylerYep/torchinfo: View model summaries in PyTorch!
View model summaries in PyTorch! Contribute to TylerYep/torchinfo development by creating an account on GitHub.

llama2がどんな構造でパラメータ数なのか気になったのでtorchinfoを使ってみた所、表題のエラー。

elyza/ELYZA-japanese-Llama-2-7b · Hugging Face
We’re on a journey to advance and democratize artificial intelligence through open source and open science.
from transformers import AutoModelForCausalLM
from torchinfo import summary      


model_name = "elyza/ELYZA-japanese-Llama-2-7b"
model = AutoModelForCausalLM.from_pretrained(model_name)

batch_size = 1
max_len = 8
input_size = (batch_size, max_len)

summary(model, input_size=input_size)


RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

スタックトレースを見ろと書いてあるので、エラーログを遡ってみるとヒントを発見。

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

モデルにはlongかintを入力しなければならないのに、floatが入力されたのでエラーが出た様子。

summaryでdtypeを指定できないか、引数を見てみる。

def summary(
    model: nn.Module,
    input_size: INPUT_SIZE_TYPE | None = None,
    input_data: INPUT_DATA_TYPE | None = None,
    batch_dim: int | None = None,
    cache_forward_pass: bool | None = None,
    col_names: Iterable[str] | None = None,
    col_width: int = 25,
    depth: int = 3,
    device: torch.device | str | None = None,
    dtypes: list[torch.dtype] | None = None,
    mode: str | None = None,
    row_settings: Iterable[str] | None = None,
    verbose: int | None = None,
    **kwargs: Any,
) -> ModelStatistics:

    Args:

    (中略)

        dtypes (List[torch.dtype]):
                If you use input_size, torchinfo assumes your input uses FloatTensors.
                If your model use a different data type, specify that dtype.
                For multiple inputs, specify the size of both inputs, and
                also specify the types of each parameter here.
                Default: None

dtypesはデフォルトだとFloatTensorsになっている。

画像や時系列はfloatを入力することが多いからだろうか?

LLMのトークンはintなので、dtypesを明示的に指定して再実行すれば良さそう。

torchの整数型には、int8, int16, int32, int64がある。

Tensor Attributes — PyTorch 2.2 documentation

スタックトレースのログによるとlongかintにすればいいみたいなので、torch.intかtorch.longを指定すればいい。

今回はtorch.intで実行してみた。

summary(model, input_size=input_size, dtypes=[torch.int])
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
LlamaForCausalLM                                        [1, 32, 8, 128]           --
├─LlamaModel: 1-1                                       [1, 32, 8, 128]           --
│    └─Embedding: 2-1                                   [1, 8, 4096]              131,072,000
│    └─ModuleList: 2-2                                  --                        --
│    │    └─LlamaDecoderLayer: 3-1                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-2                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-3                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-4                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-5                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-6                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-7                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-8                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-9                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-10                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-11                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-12                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-13                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-14                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-15                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-16                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-17                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-18                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-19                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-20                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-21                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-22                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-23                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-24                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-25                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-26                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-27                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-28                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-29                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-30                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-31                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-32                     [1, 8, 4096]              202,383,360
│    └─LlamaRMSNorm: 2-3                                [1, 8, 4096]              4,096
├─Linear: 1-2                                           [1, 8, 32000]             131,072,000
=========================================================================================================
Total params: 6,738,415,616
Trainable params: 6,738,415,616
Non-trainable params: 0
Total mult-adds (G): 6.74
=========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 106.38
Params size (MB): 26953.66
Estimated Total Size (MB): 27060.04
=========================================================================================================

無事にモデルの情報が出力された。

デコーダーが32層あり、num_hidden_layers=32と一致していた。

パラメータ数は67億で、確かに7Bだった。

コメント

タイトルとURLをコピーしました