@@ -35,7 +35,7 @@ def forward(self, batch, key=None):
3535
3636class TransformerEmbedder (AbstractEncoder ):
3737 """Some transformer encoder layers"""
38- def __init__ (self , n_embed , n_layer , vocab_size , max_seq_len = 77 , device = "cuda" ):
38+ def __init__ (self , n_embed , n_layer , vocab_size , max_seq_len = 77 , device = "cuda" if torch . cuda . is_available () else "cpu" ):
3939 super ().__init__ ()
4040 self .device = device
4141 self .transformer = TransformerWrapper (num_tokens = vocab_size , max_seq_len = max_seq_len ,
@@ -52,7 +52,7 @@ def encode(self, x):
5252
5353class BERTTokenizer (AbstractEncoder ):
5454 """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55- def __init__ (self , device = "cuda" , vq_interface = True , max_length = 77 ):
55+ def __init__ (self , device = "cuda" if torch . cuda . is_available () else "cpu" , vq_interface = True , max_length = 77 ):
5656 super ().__init__ ()
5757 from transformers import BertTokenizerFast # TODO: add to requirements
5858 self .tokenizer = BertTokenizerFast .from_pretrained ("bert-base-uncased" )
@@ -80,7 +80,7 @@ def decode(self, text):
8080class BERTEmbedder (AbstractEncoder ):
8181 """Uses the BERT tokenizr model and add some transformer encoder layers"""
8282 def __init__ (self , n_embed , n_layer , vocab_size = 30522 , max_seq_len = 77 ,
83- device = "cuda" , use_tokenizer = True , embedding_dropout = 0.0 ):
83+ device = "cuda" if torch . cuda . is_available () else "cpu" , use_tokenizer = True , embedding_dropout = 0.0 ):
8484 super ().__init__ ()
8585 self .use_tknz_fn = use_tokenizer
8686 if self .use_tknz_fn :
@@ -136,7 +136,7 @@ def encode(self, x):
136136
137137class FrozenCLIPEmbedder (AbstractEncoder ):
138138 """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139- def __init__ (self , version = "openai/clip-vit-large-patch14" , device = "cuda" , max_length = 77 ):
139+ def __init__ (self , version = "openai/clip-vit-large-patch14" , device = "cuda" if torch . cuda . is_available () else "cpu" , max_length = 77 ):
140140 super ().__init__ ()
141141 self .tokenizer = CLIPTokenizer .from_pretrained (version )
142142 self .transformer = CLIPTextModel .from_pretrained (version )
@@ -231,4 +231,4 @@ def forward(self, x):
231231if __name__ == "__main__" :
232232 from ldm .util import count_params
233233 model = FrozenCLIPEmbedder ()
234- count_params (model , verbose = True )
234+ count_params (model , verbose = True )
0 commit comments