mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-10 07:04:45 +00:00
feasible training times by dropping to bert-base
This commit is contained in:
parent
6a4027afb7
commit
6cbffcdec1
5
test.py
5
test.py
@ -8,9 +8,9 @@ import matplotlib
|
|||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
print(device)
|
print(device)
|
||||||
|
|
||||||
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-large-uncased", bos_token_id=101, eos_token_id=102)
|
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102)
|
||||||
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
|
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
|
||||||
decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
|
decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
|
||||||
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
|
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
# create tokenizer...
|
# create tokenizer...
|
||||||
@ -43,7 +43,6 @@ for i in range(epochs):
|
|||||||
|
|
||||||
for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):
|
for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):
|
||||||
|
|
||||||
print(question)
|
|
||||||
input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
|
input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
|
||||||
label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
|
label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user