-
Notifications
You must be signed in to change notification settings - Fork 77
Open
Description
encode_plus_untagged関数で、最後にtorch.Tensorに変換する部分ですが、以下のように修正すべきではありませんか?
encoding = { k: torch.tensor([v]) for k, v in encoding.items() }
↓修正後
encoding = { k: torch.tensor(v) for k, v in encoding.items() }
torch.Tensorに変換する前にリスト化することで、encode_plus_untagged関数の返り値の要素であるinput_idsなどの値は「2次元Tensor」になっています。
その一方、encode_plus_tagged関数の返り値の要素であるinput_idsなどは値が「1次元Tensor」になっています。
性能評価もバッチ処理にて行うべくコードを書いている中で、encode_plus_untagged関数の返り値を使用してデータローダを作ろうとした際、input_idsなどの値が3次元のTensorになってしまい、BERTへ入力できないという問題が生じました。encode_plus_untagged関数のほうだけ、input_idsなどの値を2次元Tensorにしなくてはならない特別な理由がない限り、バッチ処理も想定して1次元Tensorにすべきと考えます。
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels