From 6cbffcdec13ed8016563b0b6cd54ef9f87981fac Mon Sep 17 00:00:00 2001
From: ltcptgeneral
Date: Wed, 6 Mar 2024 21:48:09 -0800
Subject: [PATCH] feasible training times by dropping to bert-base

---
 test.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/test.py b/test.py
index 4bc953b..f4bde2e 100644
--- a/test.py
+++ b/test.py
@@ -8,9 +8,9 @@ import matplotlib
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(device)
 
-encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-large-uncased", bos_token_id=101, eos_token_id=102)
+encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102)
 # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
-decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
+decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
 model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
 
 # create tokenizer...
@@ -43,7 +43,6 @@ for i in range(epochs):
 
 	for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):
 
-		print(question)
 		input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
 		label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
 
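Note (not part of the applied diff): the motivation here is parameter count, since bert-base (~110M parameters per stack) is roughly a third the size of bert-large (~340M). The sketch below is a minimal sanity check of that reduction: it rebuilds the patched encoder-decoder for both checkpoints and prints total parameters. The model names, token IDs, and EncoderDecoderModel construction are taken straight from the diff; the build_model helper and the printed comparison are illustrative additions, not part of test.py.

    from transformers import BertGenerationDecoder, BertGenerationEncoder, EncoderDecoderModel

    def build_model(name):
        # Mirror the construction in test.py: BERT's [CLS] (id 101) as BOS
        # and [SEP] (id 102) as EOS, with cross-attention in the decoder.
        encoder = BertGenerationEncoder.from_pretrained(name, bos_token_id=101, eos_token_id=102)
        decoder = BertGenerationDecoder.from_pretrained(
            name, add_cross_attention=True, is_decoder=True,
            bos_token_id=101, eos_token_id=102,
        )
        return EncoderDecoderModel(encoder=encoder, decoder=decoder)

    for name in ("google-bert/bert-base-uncased", "google-bert/bert-large-uncased"):
        # Count every trainable and frozen parameter in the combined model.
        n_params = sum(p.numel() for p in build_model(name).parameters())
        print(f"{name}: {n_params / 1e6:.0f}M parameters")

The bert-base pair should come out around a third the size of the bert-large pair, which is what makes per-epoch training time feasible on a single GPU as the subject line claims.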