diff --git a/.gitignore b/.gitignore index c2eab8f..8bb44b6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ **/data/* -/env \ No newline at end of file +/env +**/*.zip \ No newline at end of file diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb new file mode 100644 index 0000000..a3b8f14 --- /dev/null +++ b/dqn_wordle.ipynb @@ -0,0 +1,114 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "import gym_wordle\n", + "from stable_baselines3 import DQN\n", + "import numpy as np\n", + "import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env = gym.make(\"Wordle-v0\")\n", + "\n", + "print(env)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "total_timesteps = 100000\n", + "model = DQN(\"MlpPolicy\", env, verbose=0)\n", + "model.learn(total_timesteps=total_timesteps, progress_bar=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test(model):\n", + "\n", + " end_rewards = []\n", + "\n", + " for i in range(1000):\n", + " \n", + " state = env.reset()\n", + "\n", + " done = False\n", + "\n", + " while not done:\n", + "\n", + " action, _states = model.predict(state, deterministic=True)\n", + "\n", + " state, reward, done, info = env.step(action)\n", + " \n", + " end_rewards.append(reward == 0)\n", + " \n", + " return np.sum(end_rewards) / len(end_rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.save(\"dqn_wordle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = DQN.load(\"dqn_wordle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(test(model))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index f97326a..0000000 --- a/test.ipynb +++ /dev/null @@ -1,165 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\art\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "from torch.utils.data import Dataset\n", - "from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer\n", - "from tqdm import tqdm as progress_bar\n", - "import torch\n", - "import matplotlib" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cuda\n" - ] - } - ], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", - "You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.\n", - "Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-large-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.12.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.12.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.12.crossattention.output.dense.bias', 'bert.encoder.layer.12.crossattention.output.dense.weight', 'bert.encoder.layer.12.crossattention.self.key.bias', 'bert.encoder.layer.12.crossattention.self.key.weight', 'bert.encoder.layer.12.crossattention.self.query.bias', 'bert.encoder.layer.12.crossattention.self.query.weight', 'bert.encoder.layer.12.crossattention.self.value.bias', 'bert.encoder.layer.12.crossattention.self.value.weight', 'bert.encoder.layer.13.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.13.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.13.crossattention.output.dense.bias', 'bert.encoder.layer.13.crossattention.output.dense.weight', 'bert.encoder.layer.13.crossattention.self.key.bias', 'bert.encoder.layer.13.crossattention.self.key.weight', 'bert.encoder.layer.13.crossattention.self.query.bias', 'bert.encoder.layer.13.crossattention.self.query.weight', 'bert.encoder.layer.13.crossattention.self.value.bias', 'bert.encoder.layer.13.crossattention.self.value.weight', 'bert.encoder.layer.14.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.14.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.14.crossattention.output.dense.bias', 'bert.encoder.layer.14.crossattention.output.dense.weight', 'bert.encoder.layer.14.crossattention.self.key.bias', 'bert.encoder.layer.14.crossattention.self.key.weight', 'bert.encoder.layer.14.crossattention.self.query.bias', 'bert.encoder.layer.14.crossattention.self.query.weight', 'bert.encoder.layer.14.crossattention.self.value.bias', 'bert.encoder.layer.14.crossattention.self.value.weight', 'bert.encoder.layer.15.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.15.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.15.crossattention.output.dense.bias', 'bert.encoder.layer.15.crossattention.output.dense.weight', 'bert.encoder.layer.15.crossattention.self.key.bias', 'bert.encoder.layer.15.crossattention.self.key.weight', 'bert.encoder.layer.15.crossattention.self.query.bias', 'bert.encoder.layer.15.crossattention.self.query.weight', 'bert.encoder.layer.15.crossattention.self.value.bias', 'bert.encoder.layer.15.crossattention.self.value.weight', 'bert.encoder.layer.16.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.16.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.16.crossattention.output.dense.bias', 'bert.encoder.layer.16.crossattention.output.dense.weight', 'bert.encoder.layer.16.crossattention.self.key.bias', 'bert.encoder.layer.16.crossattention.self.key.weight', 'bert.encoder.layer.16.crossattention.self.query.bias', 'bert.encoder.layer.16.crossattention.self.query.weight', 'bert.encoder.layer.16.crossattention.self.value.bias', 'bert.encoder.layer.16.crossattention.self.value.weight', 'bert.encoder.layer.17.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.17.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.17.crossattention.output.dense.bias', 'bert.encoder.layer.17.crossattention.output.dense.weight', 'bert.encoder.layer.17.crossattention.self.key.bias', 'bert.encoder.layer.17.crossattention.self.key.weight', 'bert.encoder.layer.17.crossattention.self.query.bias', 'bert.encoder.layer.17.crossattention.self.query.weight', 'bert.encoder.layer.17.crossattention.self.value.bias', 'bert.encoder.layer.17.crossattention.self.value.weight', 'bert.encoder.layer.18.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.18.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.18.crossattention.output.dense.bias', 'bert.encoder.layer.18.crossattention.output.dense.weight', 'bert.encoder.layer.18.crossattention.self.key.bias', 'bert.encoder.layer.18.crossattention.self.key.weight', 'bert.encoder.layer.18.crossattention.self.query.bias', 'bert.encoder.layer.18.crossattention.self.query.weight', 'bert.encoder.layer.18.crossattention.self.value.bias', 'bert.encoder.layer.18.crossattention.self.value.weight', 'bert.encoder.layer.19.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.19.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.19.crossattention.output.dense.bias', 'bert.encoder.layer.19.crossattention.output.dense.weight', 'bert.encoder.layer.19.crossattention.self.key.bias', 'bert.encoder.layer.19.crossattention.self.key.weight', 'bert.encoder.layer.19.crossattention.self.query.bias', 'bert.encoder.layer.19.crossattention.self.query.weight', 'bert.encoder.layer.19.crossattention.self.value.bias', 'bert.encoder.layer.19.crossattention.self.value.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.2.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.2.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.self.value.weight', 'bert.encoder.layer.20.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.20.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.20.crossattention.output.dense.bias', 'bert.encoder.layer.20.crossattention.output.dense.weight', 'bert.encoder.layer.20.crossattention.self.key.bias', 'bert.encoder.layer.20.crossattention.self.key.weight', 'bert.encoder.layer.20.crossattention.self.query.bias', 'bert.encoder.layer.20.crossattention.self.query.weight', 'bert.encoder.layer.20.crossattention.self.value.bias', 'bert.encoder.layer.20.crossattention.self.value.weight', 'bert.encoder.layer.21.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.21.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.21.crossattention.output.dense.bias', 'bert.encoder.layer.21.crossattention.output.dense.weight', 'bert.encoder.layer.21.crossattention.self.key.bias', 'bert.encoder.layer.21.crossattention.self.key.weight', 'bert.encoder.layer.21.crossattention.self.query.bias', 'bert.encoder.layer.21.crossattention.self.query.weight', 'bert.encoder.layer.21.crossattention.self.value.bias', 'bert.encoder.layer.21.crossattention.self.value.weight', 'bert.encoder.layer.22.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.22.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.22.crossattention.output.dense.bias', 'bert.encoder.layer.22.crossattention.output.dense.weight', 'bert.encoder.layer.22.crossattention.self.key.bias', 'bert.encoder.layer.22.crossattention.self.key.weight', 'bert.encoder.layer.22.crossattention.self.query.bias', 'bert.encoder.layer.22.crossattention.self.query.weight', 'bert.encoder.layer.22.crossattention.self.value.bias', 'bert.encoder.layer.22.crossattention.self.value.weight', 'bert.encoder.layer.23.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.23.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.23.crossattention.output.dense.bias', 'bert.encoder.layer.23.crossattention.output.dense.weight', 'bert.encoder.layer.23.crossattention.self.key.bias', 'bert.encoder.layer.23.crossattention.self.key.weight', 'bert.encoder.layer.23.crossattention.self.query.bias', 'bert.encoder.layer.23.crossattention.self.query.weight', 'bert.encoder.layer.23.crossattention.self.value.bias', 'bert.encoder.layer.23.crossattention.self.value.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.output.dense.bias', 'bert.encoder.layer.3.crossattention.output.dense.weight', 'bert.encoder.layer.3.crossattention.self.key.bias', 'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 'bert.encoder.layer.3.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.self.value.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.4.crossattention.self.value.weight', 'bert.encoder.layer.5.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.5.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.key.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.5.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.self.value.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'lm_head.bias', 'lm_head.decoder.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - } - ], - "source": [ - "encoder = BertGenerationEncoder.from_pretrained(\"google-bert/bert-large-uncased\", bos_token_id=101, eos_token_id=102)\n", - "# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token\n", - "decoder = BertGenerationDecoder.from_pretrained(\"google-bert/bert-large-uncased\", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)\n", - "model = EncoderDecoderModel(encoder=encoder, decoder=decoder)\n", - "\n", - "# create tokenizer...\n", - "tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-large-uncased\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "class CodeDataset(Dataset):\n", - " def __init__(self):\n", - " with open(\"data/conala-train.json\") as f:\n", - " self.data = json.load(f)\n", - "\n", - " def __len__(self):\n", - " return len(self.data)\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.data[idx][\"rewritten_intent\"], self.data[idx][\"snippet\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)\n", - "dataloader = CodeDataset()\n", - "model = model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/2379 [00:00