diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index f97326a..0000000 --- a/test.ipynb +++ /dev/null @@ -1,165 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\art\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "from torch.utils.data import Dataset\n", - "from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer\n", - "from tqdm import tqdm as progress_bar\n", - "import torch\n", - "import matplotlib" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cuda\n" - ] - } - ], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", - "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", - "Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-large-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.12.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.12.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.12.crossattention.output.dense.bias', 'bert.encoder.layer.12.crossattention.output.dense.weight', 'bert.encoder.layer.12.crossattention.self.key.bias', 'bert.encoder.layer.12.crossattention.self.key.weight', 'bert.encoder.layer.12.crossattention.self.query.bias', 'bert.encoder.layer.12.crossattention.self.query.weight', 'bert.encoder.layer.12.crossattention.self.value.bias', 'bert.encoder.layer.12.crossattention.self.value.weight', 'bert.encoder.layer.13.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.13.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.13.crossattention.output.dense.bias', 'bert.encoder.layer.13.crossattention.output.dense.weight', 'bert.encoder.layer.13.crossattention.self.key.bias', 'bert.encoder.layer.13.crossattention.self.key.weight', 'bert.encoder.layer.13.crossattention.self.query.bias', 'bert.encoder.layer.13.crossattention.self.query.weight', 'bert.encoder.layer.13.crossattention.self.value.bias', 'bert.encoder.layer.13.crossattention.self.value.weight', 'bert.encoder.layer.14.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.14.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.14.crossattention.output.dense.bias', 'bert.encoder.layer.14.crossattention.output.dense.weight', 'bert.encoder.layer.14.crossattention.self.key.bias', 'bert.encoder.layer.14.crossattention.self.key.weight', 'bert.encoder.layer.14.crossattention.self.query.bias', 'bert.encoder.layer.14.crossattention.self.query.weight', 'bert.encoder.layer.14.crossattention.self.value.bias', 'bert.encoder.layer.14.crossattention.self.value.weight', 'bert.encoder.layer.15.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.15.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.15.crossattention.output.dense.bias', 'bert.encoder.layer.15.crossattention.output.dense.weight', 'bert.encoder.layer.15.crossattention.self.key.bias', 'bert.encoder.layer.15.crossattention.self.key.weight', 'bert.encoder.layer.15.crossattention.self.query.bias', 'bert.encoder.layer.15.crossattention.self.query.weight', 'bert.encoder.layer.15.crossattention.self.value.bias', 'bert.encoder.layer.15.crossattention.self.value.weight', 'bert.encoder.layer.16.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.16.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.16.crossattention.output.dense.bias', 'bert.encoder.layer.16.crossattention.output.dense.weight', 'bert.encoder.layer.16.crossattention.self.key.bias', 'bert.encoder.layer.16.crossattention.self.key.weight', 'bert.encoder.layer.16.crossattention.self.query.bias', 'bert.encoder.layer.16.crossattention.self.query.weight', 'bert.encoder.layer.16.crossattention.self.value.bias', 'bert.encoder.layer.16.crossattention.self.value.weight', 'bert.encoder.layer.17.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.17.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.17.crossattention.output.dense.bias', 'bert.encoder.layer.17.crossattention.output.dense.weight', 'bert.encoder.layer.17.crossattention.self.key.bias', 'bert.encoder.layer.17.crossattention.self.key.weight', 'bert.encoder.layer.17.crossattention.self.query.bias', 'bert.encoder.layer.17.crossattention.self.query.weight', 'bert.encoder.layer.17.crossattention.self.value.bias', 'bert.encoder.layer.17.crossattention.self.value.weight', 'bert.encoder.layer.18.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.18.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.18.crossattention.output.dense.bias', 'bert.encoder.layer.18.crossattention.output.dense.weight', 'bert.encoder.layer.18.crossattention.self.key.bias', 'bert.encoder.layer.18.crossattention.self.key.weight', 'bert.encoder.layer.18.crossattention.self.query.bias', 'bert.encoder.layer.18.crossattention.self.query.weight', 'bert.encoder.layer.18.crossattention.self.value.bias', 'bert.encoder.layer.18.crossattention.self.value.weight', 'bert.encoder.layer.19.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.19.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.19.crossattention.output.dense.bias', 'bert.encoder.layer.19.crossattention.output.dense.weight', 'bert.encoder.layer.19.crossattention.self.key.bias', 'bert.encoder.layer.19.crossattention.self.key.weight', 'bert.encoder.layer.19.crossattention.self.query.bias', 'bert.encoder.layer.19.crossattention.self.query.weight', 'bert.encoder.layer.19.crossattention.self.value.bias', 'bert.encoder.layer.19.crossattention.self.value.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.2.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.2.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.self.value.weight', 'bert.encoder.layer.20.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.20.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.20.crossattention.output.dense.bias', 'bert.encoder.layer.20.crossattention.output.dense.weight', 'bert.encoder.layer.20.crossattention.self.key.bias', 'bert.encoder.layer.20.crossattention.self.key.weight', 'bert.encoder.layer.20.crossattention.self.query.bias', 'bert.encoder.layer.20.crossattention.self.query.weight', 'bert.encoder.layer.20.crossattention.self.value.bias', 'bert.encoder.layer.20.crossattention.self.value.weight', 'bert.encoder.layer.21.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.21.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.21.crossattention.output.dense.bias', 'bert.encoder.layer.21.crossattention.output.dense.weight', 'bert.encoder.layer.21.crossattention.self.key.bias', 'bert.encoder.layer.21.crossattention.self.key.weight', 'bert.encoder.layer.21.crossattention.self.query.bias', 'bert.encoder.layer.21.crossattention.self.query.weight', 'bert.encoder.layer.21.crossattention.self.value.bias', 'bert.encoder.layer.21.crossattention.self.value.weight', 'bert.encoder.layer.22.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.22.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.22.crossattention.output.dense.bias', 'bert.encoder.layer.22.crossattention.output.dense.weight', 'bert.encoder.layer.22.crossattention.self.key.bias', 'bert.encoder.layer.22.crossattention.self.key.weight', 'bert.encoder.layer.22.crossattention.self.query.bias', 'bert.encoder.layer.22.crossattention.self.query.weight', 'bert.encoder.layer.22.crossattention.self.value.bias', 'bert.encoder.layer.22.crossattention.self.value.weight', 'bert.encoder.layer.23.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.23.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.23.crossattention.output.dense.bias', 'bert.encoder.layer.23.crossattention.output.dense.weight', 'bert.encoder.layer.23.crossattention.self.key.bias', 'bert.encoder.layer.23.crossattention.self.key.weight', 'bert.encoder.layer.23.crossattention.self.query.bias', 'bert.encoder.layer.23.crossattention.self.query.weight', 'bert.encoder.layer.23.crossattention.self.value.bias', 'bert.encoder.layer.23.crossattention.self.value.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.output.dense.bias', 'bert.encoder.layer.3.crossattention.output.dense.weight', 'bert.encoder.layer.3.crossattention.self.key.bias', 'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 'bert.encoder.layer.3.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.self.value.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.4.crossattention.self.value.weight', 'bert.encoder.layer.5.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.5.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.key.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.5.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.self.value.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'lm_head.bias', 'lm_head.decoder.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - } - ], - "source": [ - "encoder = BertGenerationEncoder.from_pretrained(\"google-bert/bert-large-uncased\", bos_token_id=101, eos_token_id=102)\n", - "# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token\n", - "decoder = BertGenerationDecoder.from_pretrained(\"google-bert/bert-large-uncased\", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)\n", - "model = EncoderDecoderModel(encoder=encoder, decoder=decoder)\n", - "\n", - "# create tokenizer...\n", - "tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-large-uncased\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "class CodeDataset(Dataset):\n", - " def __init__(self):\n", - " with open(\"data/conala-train.json\") as f:\n", - " self.data = json.load(f)\n", - "\n", - " def __len__(self):\n", - " return len(self.data)\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.data[idx][\"rewritten_intent\"], self.data[idx][\"snippet\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)\n", - "dataloader = CodeDataset()\n", - "model = model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/2379 [00:00