mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-09 22:54:45 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\art\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from torch.utils.data import Dataset\n",
"from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer\n",
"from tqdm import tqdm as progress_bar\n",
"import torch\n",
"import matplotlib.pyplot as plt  # needed for the plt.plot/plt.show calls in the training cell"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n",
"You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n",
"Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-large-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.12.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.12.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.12.crossattention.output.dense.bias', 'bert.encoder.layer.12.crossattention.output.dense.weight', 'bert.encoder.layer.12.crossattention.self.key.bias', 'bert.encoder.layer.12.crossattention.self.key.weight', 'bert.encoder.layer.12.crossattention.self.query.bias', 'bert.encoder.layer.12.crossattention.self.query.weight', 'bert.encoder.layer.12.crossattention.self.value.bias', 'bert.encoder.layer.12.crossattention.self.value.weight', 'bert.encoder.layer.13.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.13.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.13.crossattention.output.dense.bias', 'bert.encoder.layer.13.crossattention.output.dense.weight', 'bert.encoder.layer.13.crossattention.self.key.bias', 'bert.encoder.layer.13.crossattention.self.key.weight', 'bert.encoder.layer.13.crossattention.self.query.bias', 'bert.encoder.layer.13.crossattention.self.query.weight', 'bert.encoder.layer.13.crossattention.self.value.bias', 'bert.encoder.layer.13.crossattention.self.value.weight', 'bert.encoder.layer.14.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.14.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.14.crossattention.output.dense.bias', 'bert.encoder.layer.14.crossattention.output.dense.weight', 'bert.encoder.layer.14.crossattention.self.key.bias', 'bert.encoder.layer.14.crossattention.self.key.weight', 'bert.encoder.layer.14.crossattention.self.query.bias', 'bert.encoder.layer.14.crossattention.self.query.weight', '\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"encoder = BertGenerationEncoder.from_pretrained(\"google-bert/bert-large-uncased\", bos_token_id=101, eos_token_id=102)\n",
"# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token\n",
"decoder = BertGenerationDecoder.from_pretrained(\"google-bert/bert-large-uncased\", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)\n",
"model = EncoderDecoderModel(encoder=encoder, decoder=decoder)\n",
"\n",
"# load the matching BERT tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-large-uncased\")"
]
},
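{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch, not part of the original run: before `model.generate()` can be used, the combined `EncoderDecoderModel` needs its special-token fields set on the config. The ids come from the `bert-large-uncased` tokenizer loaded above (`[CLS]`=101, `[SEP]`=102)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: wire the special tokens into the combined config for generation\n",
"model.config.decoder_start_token_id = tokenizer.cls_token_id  # decoding starts from [CLS]\n",
"model.config.eos_token_id = tokenizer.sep_token_id  # stop at [SEP]\n",
"model.config.pad_token_id = tokenizer.pad_token_id  # required for batched generate()"
]
},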
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"class CodeDataset(Dataset):\n",
"    def __init__(self):\n",
"        with open(\"data/conala-train.json\") as f:\n",
"            self.data = json.load(f)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx][\"rewritten_intent\"], self.data[idx][\"snippet\"]"
]
},
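{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged sketch, not in the original notebook: the training loop below walks the bare `Dataset` one example at a time. Wrapping it in a `DataLoader` would add shuffling, and a custom collate can drop CoNaLa examples whose `rewritten_intent` is null; the batch size here is an illustrative choice, not a value from the original."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"def collate(batch):\n",
"    # drop pairs whose rewritten_intent is null in conala-train.json\n",
"    batch = [(q, a) for q, a in batch if q is not None]\n",
"    return [q for q, _ in batch], [a for _, a in batch]\n",
"\n",
"# sketch only: batch_size and shuffle are illustrative\n",
"loader = DataLoader(CodeDataset(), batch_size=8, shuffle=True, collate_fn=collate)"
]
},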
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# note: 1e-3 is aggressive for fine-tuning BERT-scale models; 2e-5 to 5e-5 is more typical\n",
"optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)\n",
"# a bare Dataset, iterated example by example below (no batching or shuffling)\n",
"dataloader = CodeDataset()\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/2379 [00:00<?, ?it/s]C:\\Users\\art\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\transformers\\models\\encoder_decoder\\modeling_encoder_decoder.py:636: FutureWarning: Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no need to pass them yourself anymore.\n",
" warnings.warn(DEPRECATION_WARNING, FutureWarning)\n",
" 0%| | 3/2379 [01:18<16:16:50, 24.67s/it]"
]
}
],
"source": [
"losses = []\n",
"epochs = 10\n",
"for i in range(epochs):\n",
"\n",
"    epoch_loss = 0\n",
"\n",
"    for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):\n",
"\n",
"        # some CoNaLa examples have a null rewritten_intent; skip those\n",
"        if question is None:\n",
"            continue\n",
"\n",
"        input_ids = tokenizer(question, add_special_tokens=False, return_tensors=\"pt\").input_ids.to(device)\n",
"        label_ids = tokenizer(answer, return_tensors=\"pt\").input_ids.to(device)\n",
"\n",
"        # passing labels alone is enough: the model builds decoder_input_ids from them,\n",
"        # as the FutureWarning in the output above notes\n",
"        loss = model(input_ids=input_ids, labels=label_ids).loss\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        epoch_loss += loss.item()\n",
"\n",
"    losses.append(epoch_loss)\n",
"\n",
"plt.plot(losses, color=\"green\", label=\"Training Loss\")\n",
"plt.legend(loc='upper left')\n",
"plt.show()"
]
},
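{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged inference sketch, not part of the original run: greedy decoding with the fine-tuned model. It assumes the special-token config fields set earlier; the prompt and the `max_new_tokens=64` cap are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: generate a code snippet for a single intent\n",
"model.eval()\n",
"prompt = \"open a file and read its lines\"\n",
"input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors=\"pt\").input_ids.to(device)\n",
"with torch.no_grad():\n",
"    output_ids = model.generate(input_ids, max_new_tokens=64)\n",
"print(tokenizer.decode(output_ids[0], skip_special_tokens=True))"
]
}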
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}