Skip to content

【第8章】encode_plus_untagged関数のtorch.Tensor変換について #46

@kawase621

Description

@kawase621

encode_plus_untagged関数で、最後にtorch.Tensorに変換する部分ですが、以下のように修正すべきではありませんか?

encoding = { k: torch.tensor([v]) for k, v in encoding.items() }

↓修正後

encoding = { k: torch.tensor(v) for k, v in encoding.items() }

torch.Tensorに変換する前にリスト化することで、encode_plus_untagged関数の返り値の要素であるinput_idsなどの値は「2次元Tensor」になっています。
その一方、encode_plus_tagged関数の返り値の要素であるinput_idsなどは値が「1次元Tensor」になっています。
性能評価もバッチ処理にて行うべくコードを書いている中で、encode_plus_untagged関数の返り値を使用してデータローダを作ろうとした際、input_idsなどの値が3次元のTensorになってしまい、BERTへ入力できないという問題が生じました。encode_plus_untagged関数のほうだけ、input_idsなどの値を2次元Tensorにしなくてはならない特別な理由がない限り、バッチ処理も想定して1次元Tensorにすべきと考えます。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions