mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-25 17:49:10 +00:00
feasible training times by dropping to bert-base
This commit is contained in:
parent
6a4027afb7
commit
6cbffcdec1
5
test.py
5
test.py
@ -8,9 +8,9 @@ import matplotlib
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(device)
|
||||
|
||||
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-large-uncased", bos_token_id=101, eos_token_id=102)
|
||||
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102)
|
||||
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
|
||||
decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
|
||||
decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
|
||||
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
# create tokenizer...
|
||||
@ -43,7 +43,6 @@ for i in range(epochs):
|
||||
|
||||
for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):
|
||||
|
||||
print(question)
|
||||
input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
|
||||
label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user