{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\art\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from torch.utils.data import Dataset\n", "from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer\n", "from tqdm import tqdm as progress_bar\n", "import torch\n", "import matplotlib" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n", "if torch.cuda.is_available():\n", "    device = torch.device(\"cuda\")\n", "else:\n", "    device = torch.device(\"cpu\")\n", "print(device)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", "You are using a model of type bert to instantiate a model of type bert-generation. 
This is not supported for all configurations of models and can yield errors.\n", "Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-large-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 
'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.12.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.12.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.12.crossattention.output.dense.bias', 'bert.encoder.layer.12.crossattention.output.dense.weight', 'bert.encoder.layer.12.crossattention.self.key.bias', 'bert.encoder.layer.12.crossattention.self.key.weight', 'bert.encoder.layer.12.crossattention.self.query.bias', 'bert.encoder.layer.12.crossattention.self.query.weight', 'bert.encoder.layer.12.crossattention.self.value.bias', 'bert.encoder.layer.12.crossattention.self.value.weight', 'bert.encoder.layer.13.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.13.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.13.crossattention.output.dense.bias', 'bert.encoder.layer.13.crossattention.output.dense.weight', 'bert.encoder.layer.13.crossattention.self.key.bias', 'bert.encoder.layer.13.crossattention.self.key.weight', 'bert.encoder.layer.13.crossattention.self.query.bias', 'bert.encoder.layer.13.crossattention.self.query.weight', 'bert.encoder.layer.13.crossattention.self.value.bias', 'bert.encoder.layer.13.crossattention.self.value.weight', 'bert.encoder.layer.14.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.14.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.14.crossattention.output.dense.bias', 'bert.encoder.layer.14.crossattention.output.dense.weight', 
'bert.encoder.layer.14.crossattention.self.key.bias', 'bert.encoder.layer.14.crossattention.self.key.weight', 'bert.encoder.layer.14.crossattention.self.query.bias', 'bert.encoder.layer.14.crossattention.self.query.weight', 'bert.encoder.layer.14.crossattention.self.value.bias', 'bert.encoder.layer.14.crossattention.self.value.weight', 'bert.encoder.layer.15.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.15.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.15.crossattention.output.dense.bias', 'bert.encoder.layer.15.crossattention.output.dense.weight', 'bert.encoder.layer.15.crossattention.self.key.bias', 'bert.encoder.layer.15.crossattention.self.key.weight', 'bert.encoder.layer.15.crossattention.self.query.bias', 'bert.encoder.layer.15.crossattention.self.query.weight', 'bert.encoder.layer.15.crossattention.self.value.bias', 'bert.encoder.layer.15.crossattention.self.value.weight', 'bert.encoder.layer.16.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.16.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.16.crossattention.output.dense.bias', 'bert.encoder.layer.16.crossattention.output.dense.weight', 'bert.encoder.layer.16.crossattention.self.key.bias', 'bert.encoder.layer.16.crossattention.self.key.weight', 'bert.encoder.layer.16.crossattention.self.query.bias', 'bert.encoder.layer.16.crossattention.self.query.weight', 'bert.encoder.layer.16.crossattention.self.value.bias', 'bert.encoder.layer.16.crossattention.self.value.weight', 'bert.encoder.layer.17.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.17.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.17.crossattention.output.dense.bias', 'bert.encoder.layer.17.crossattention.output.dense.weight', 'bert.encoder.layer.17.crossattention.self.key.bias', 'bert.encoder.layer.17.crossattention.self.key.weight', 'bert.encoder.layer.17.crossattention.self.query.bias', 'bert.encoder.layer.17.crossattention.self.query.weight', 
'bert.encoder.layer.17.crossattention.self.value.bias', 'bert.encoder.layer.17.crossattention.self.value.weight', 'bert.encoder.layer.18.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.18.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.18.crossattention.output.dense.bias', 'bert.encoder.layer.18.crossattention.output.dense.weight', 'bert.encoder.layer.18.crossattention.self.key.bias', 'bert.encoder.layer.18.crossattention.self.key.weight', 'bert.encoder.layer.18.crossattention.self.query.bias', 'bert.encoder.layer.18.crossattention.self.query.weight', 'bert.encoder.layer.18.crossattention.self.value.bias', 'bert.encoder.layer.18.crossattention.self.value.weight', 'bert.encoder.layer.19.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.19.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.19.crossattention.output.dense.bias', 'bert.encoder.layer.19.crossattention.output.dense.weight', 'bert.encoder.layer.19.crossattention.self.key.bias', 'bert.encoder.layer.19.crossattention.self.key.weight', 'bert.encoder.layer.19.crossattention.self.query.bias', 'bert.encoder.layer.19.crossattention.self.query.weight', 'bert.encoder.layer.19.crossattention.self.value.bias', 'bert.encoder.layer.19.crossattention.self.value.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.2.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.2.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.self.value.weight', 'bert.encoder.layer.20.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.20.crossattention.output.LayerNorm.weight', 
'bert.encoder.layer.20.crossattention.output.dense.bias', 'bert.encoder.layer.20.crossattention.output.dense.weight', 'bert.encoder.layer.20.crossattention.self.key.bias', 'bert.encoder.layer.20.crossattention.self.key.weight', 'bert.encoder.layer.20.crossattention.self.query.bias', 'bert.encoder.layer.20.crossattention.self.query.weight', 'bert.encoder.layer.20.crossattention.self.value.bias', 'bert.encoder.layer.20.crossattention.self.value.weight', 'bert.encoder.layer.21.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.21.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.21.crossattention.output.dense.bias', 'bert.encoder.layer.21.crossattention.output.dense.weight', 'bert.encoder.layer.21.crossattention.self.key.bias', 'bert.encoder.layer.21.crossattention.self.key.weight', 'bert.encoder.layer.21.crossattention.self.query.bias', 'bert.encoder.layer.21.crossattention.self.query.weight', 'bert.encoder.layer.21.crossattention.self.value.bias', 'bert.encoder.layer.21.crossattention.self.value.weight', 'bert.encoder.layer.22.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.22.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.22.crossattention.output.dense.bias', 'bert.encoder.layer.22.crossattention.output.dense.weight', 'bert.encoder.layer.22.crossattention.self.key.bias', 'bert.encoder.layer.22.crossattention.self.key.weight', 'bert.encoder.layer.22.crossattention.self.query.bias', 'bert.encoder.layer.22.crossattention.self.query.weight', 'bert.encoder.layer.22.crossattention.self.value.bias', 'bert.encoder.layer.22.crossattention.self.value.weight', 'bert.encoder.layer.23.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.23.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.23.crossattention.output.dense.bias', 'bert.encoder.layer.23.crossattention.output.dense.weight', 'bert.encoder.layer.23.crossattention.self.key.bias', 'bert.encoder.layer.23.crossattention.self.key.weight', 
'bert.encoder.layer.23.crossattention.self.query.bias', 'bert.encoder.layer.23.crossattention.self.query.weight', 'bert.encoder.layer.23.crossattention.self.value.bias', 'bert.encoder.layer.23.crossattention.self.value.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.output.dense.bias', 'bert.encoder.layer.3.crossattention.output.dense.weight', 'bert.encoder.layer.3.crossattention.self.key.bias', 'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 'bert.encoder.layer.3.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.self.value.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.4.crossattention.self.value.weight', 'bert.encoder.layer.5.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.5.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.key.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.5.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.self.value.weight', 
'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 
'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'lm_head.bias', 'lm_head.decoder.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "CHECKPOINT = \"google-bert/bert-large-uncased\"\n", "\n", "# Load the tokenizer first so the special-token ids are read from its vocabulary\n", "# instead of being hard-coded (this cell previously passed 101 and 102 directly).\n", "tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)\n", "bos_token_id = tokenizer.cls_token_id\n", "eos_token_id = tokenizer.sep_token_id\n", "\n", "encoder = BertGenerationEncoder.from_pretrained(CHECKPOINT, bos_token_id=bos_token_id, eos_token_id=eos_token_id)\n", "# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token\n", "decoder = BertGenerationDecoder.from_pretrained(\n", "    CHECKPOINT,\n", "    add_cross_attention=True,\n", "    is_decoder=True,\n", "    bos_token_id=bos_token_id,\n", "    eos_token_id=eos_token_id,\n", ")\n", "model = EncoderDecoderModel(encoder=encoder, decoder=decoder)\n", "\n", "# Set explicitly, as the EncoderDecoderModel docs do before training with `labels`:\n", "# decoder inputs are built by shifting the labels right and prepending\n", "# decoder_start_token_id, with pad_token_id standing in for ignored positions.\n", "model.config.decoder_start_token_id = tokenizer.cls_token_id\n", "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "class CodeDataset(Dataset):\n", "    \"\"\"Map-style dataset of (intent, snippet) text pairs loaded from a JSON file.\n", "\n", "    Args:\n", "        path: JSON file expected to hold a list of records with the keys\n", "            \"rewritten_intent\" and \"snippet\". Defaults to the training file\n", "            this notebook has always read.\n", "    \"\"\"\n", "\n", "    def __init__(self, path=\"data/conala-train.json\"):\n", "        # Decode as UTF-8 explicitly: without it open() uses the locale codec\n", "        # (e.g. cp1252 on Windows), which can fail on or corrupt non-ASCII text.\n", "        with open(path, encoding=\"utf-8\") as f:\n", "            self.data = json.load(f)\n", "\n", "    def __len__(self):\n", "        \"\"\"Return the number of records loaded from the file.\"\"\"\n", "        return len(self.data)\n", "\n", "    def __getitem__(self, idx):\n", "        \"\"\"Return the (intent, snippet) pair stored at position ``idx``.\"\"\"\n", "        record = self.data[idx]\n", "        intent = record[\"rewritten_intent\"]\n", "        if intent is None:\n", "            # NOTE(review): assumes records with a null \"rewritten_intent\" keep the\n", "            # original question under \"intent\" -- confirm against the data file.\n", "            # Stays None when that key is absent, matching the previous behaviour.\n", "            intent = record.get(\"intent\")\n", "        return intent, record[\"snippet\"]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Move the model to its device before building the optimizer, which is the order\n", "# the torch.optim documentation recommends.\n", "model = model.to(device)\n", "\n", "# NOTE(review): 1e-3 is far above the 2e-5 to 5e-5 range usually used to fine-tune\n", "# BERT -- confirm it is intentional.\n", "optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)\n", "\n", "# Despite the name this is the Dataset itself, not a torch DataLoader; the name is\n", "# kept because later cells may refer to it.\n", "dataloader = CodeDataset()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/2379 [00:00