commit 6a4027afb74af461106996ab9c2df24ff46b2874 Author: ltcptgeneral Date: Wed Mar 6 21:00:38 2024 -0800 minimal example of bert idea diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8f14cb9 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +**/data/* \ No newline at end of file diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..f97326a --- /dev/null +++ b/test.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\art\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from torch.utils.data import Dataset\n", + "from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer\n", + "from tqdm import tqdm as progress_bar\n", + "import torch\n", + "import matplotlib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuda\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", + "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", + "Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-large-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.12.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.12.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.12.crossattention.output.dense.bias', 'bert.encoder.layer.12.crossattention.output.dense.weight', 'bert.encoder.layer.12.crossattention.self.key.bias', 'bert.encoder.layer.12.crossattention.self.key.weight', 'bert.encoder.layer.12.crossattention.self.query.bias', 'bert.encoder.layer.12.crossattention.self.query.weight', 'bert.encoder.layer.12.crossattention.self.value.bias', 'bert.encoder.layer.12.crossattention.self.value.weight', 'bert.encoder.layer.13.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.13.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.13.crossattention.output.dense.bias', 'bert.encoder.layer.13.crossattention.output.dense.weight', 'bert.encoder.layer.13.crossattention.self.key.bias', 'bert.encoder.layer.13.crossattention.self.key.weight', 'bert.encoder.layer.13.crossattention.self.query.bias', 'bert.encoder.layer.13.crossattention.self.query.weight', 'bert.encoder.layer.13.crossattention.self.value.bias', 'bert.encoder.layer.13.crossattention.self.value.weight', 'bert.encoder.layer.14.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.14.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.14.crossattention.output.dense.bias', 'bert.encoder.layer.14.crossattention.output.dense.weight', 'bert.encoder.layer.14.crossattention.self.key.bias', 'bert.encoder.layer.14.crossattention.self.key.weight', 'bert.encoder.layer.14.crossattention.self.query.bias', 'bert.encoder.layer.14.crossattention.self.query.weight', 'bert.encoder.layer.14.crossattention.self.value.bias', 'bert.encoder.layer.14.crossattention.self.value.weight', 'bert.encoder.layer.15.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.15.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.15.crossattention.output.dense.bias', 'bert.encoder.layer.15.crossattention.output.dense.weight', 'bert.encoder.layer.15.crossattention.self.key.bias', 'bert.encoder.layer.15.crossattention.self.key.weight', 'bert.encoder.layer.15.crossattention.self.query.bias', 'bert.encoder.layer.15.crossattention.self.query.weight', 'bert.encoder.layer.15.crossattention.self.value.bias', 'bert.encoder.layer.15.crossattention.self.value.weight', 'bert.encoder.layer.16.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.16.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.16.crossattention.output.dense.bias', 'bert.encoder.layer.16.crossattention.output.dense.weight', 'bert.encoder.layer.16.crossattention.self.key.bias', 'bert.encoder.layer.16.crossattention.self.key.weight', 'bert.encoder.layer.16.crossattention.self.query.bias', 'bert.encoder.layer.16.crossattention.self.query.weight', 'bert.encoder.layer.16.crossattention.self.value.bias', 'bert.encoder.layer.16.crossattention.self.value.weight', 'bert.encoder.layer.17.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.17.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.17.crossattention.output.dense.bias', 'bert.encoder.layer.17.crossattention.output.dense.weight', 'bert.encoder.layer.17.crossattention.self.key.bias', 'bert.encoder.layer.17.crossattention.self.key.weight', 'bert.encoder.layer.17.crossattention.self.query.bias', 'bert.encoder.layer.17.crossattention.self.query.weight', 'bert.encoder.layer.17.crossattention.self.value.bias', 'bert.encoder.layer.17.crossattention.self.value.weight', 'bert.encoder.layer.18.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.18.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.18.crossattention.output.dense.bias', 'bert.encoder.layer.18.crossattention.output.dense.weight', 'bert.encoder.layer.18.crossattention.self.key.bias', 'bert.encoder.layer.18.crossattention.self.key.weight', 'bert.encoder.layer.18.crossattention.self.query.bias', 'bert.encoder.layer.18.crossattention.self.query.weight', 'bert.encoder.layer.18.crossattention.self.value.bias', 'bert.encoder.layer.18.crossattention.self.value.weight', 'bert.encoder.layer.19.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.19.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.19.crossattention.output.dense.bias', 'bert.encoder.layer.19.crossattention.output.dense.weight', 'bert.encoder.layer.19.crossattention.self.key.bias', 'bert.encoder.layer.19.crossattention.self.key.weight', 'bert.encoder.layer.19.crossattention.self.query.bias', 'bert.encoder.layer.19.crossattention.self.query.weight', 'bert.encoder.layer.19.crossattention.self.value.bias', 'bert.encoder.layer.19.crossattention.self.value.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.2.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.2.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.self.value.weight', 'bert.encoder.layer.20.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.20.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.20.crossattention.output.dense.bias', 'bert.encoder.layer.20.crossattention.output.dense.weight', 'bert.encoder.layer.20.crossattention.self.key.bias', 'bert.encoder.layer.20.crossattention.self.key.weight', 'bert.encoder.layer.20.crossattention.self.query.bias', 'bert.encoder.layer.20.crossattention.self.query.weight', 'bert.encoder.layer.20.crossattention.self.value.bias', 'bert.encoder.layer.20.crossattention.self.value.weight', 'bert.encoder.layer.21.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.21.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.21.crossattention.output.dense.bias', 'bert.encoder.layer.21.crossattention.output.dense.weight', 'bert.encoder.layer.21.crossattention.self.key.bias', 'bert.encoder.layer.21.crossattention.self.key.weight', 'bert.encoder.layer.21.crossattention.self.query.bias', 'bert.encoder.layer.21.crossattention.self.query.weight', 'bert.encoder.layer.21.crossattention.self.value.bias', 'bert.encoder.layer.21.crossattention.self.value.weight', 'bert.encoder.layer.22.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.22.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.22.crossattention.output.dense.bias', 'bert.encoder.layer.22.crossattention.output.dense.weight', 'bert.encoder.layer.22.crossattention.self.key.bias', 'bert.encoder.layer.22.crossattention.self.key.weight', 'bert.encoder.layer.22.crossattention.self.query.bias', 'bert.encoder.layer.22.crossattention.self.query.weight', 'bert.encoder.layer.22.crossattention.self.value.bias', 'bert.encoder.layer.22.crossattention.self.value.weight', 'bert.encoder.layer.23.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.23.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.23.crossattention.output.dense.bias', 'bert.encoder.layer.23.crossattention.output.dense.weight', 'bert.encoder.layer.23.crossattention.self.key.bias', 'bert.encoder.layer.23.crossattention.self.key.weight', 'bert.encoder.layer.23.crossattention.self.query.bias', 'bert.encoder.layer.23.crossattention.self.query.weight', 'bert.encoder.layer.23.crossattention.self.value.bias', 'bert.encoder.layer.23.crossattention.self.value.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.output.dense.bias', 'bert.encoder.layer.3.crossattention.output.dense.weight', 'bert.encoder.layer.3.crossattention.self.key.bias', 'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 'bert.encoder.layer.3.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.self.value.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.4.crossattention.self.value.weight', 'bert.encoder.layer.5.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.5.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.key.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.5.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.self.value.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'lm_head.bias', 'lm_head.decoder.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "encoder = BertGenerationEncoder.from_pretrained(\"google-bert/bert-large-uncased\", bos_token_id=101, eos_token_id=102)\n", + "# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token\n", + "decoder = BertGenerationDecoder.from_pretrained(\"google-bert/bert-large-uncased\", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)\n", + "model = EncoderDecoderModel(encoder=encoder, decoder=decoder)\n", + "\n", + "# create tokenizer...\n", + "tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-large-uncased\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "class CodeDataset(Dataset):\n", + " def __init__(self):\n", + " with open(\"data/conala-train.json\") as f:\n", + " self.data = json.load(f)\n", + "\n", + " def __len__(self):\n", + " return len(self.data)\n", + "\n", + " def __getitem__(self, idx):\n", + " return self.data[idx][\"rewritten_intent\"], self.data[idx][\"snippet\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)\n", + "dataloader = CodeDataset()\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2379 [00:00