cse151b-final-project/dqn_letter_gssr.ipynb

88 lines
3.7 KiB
Plaintext
Raw Normal View History

2024-03-19 18:52:10 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def load_valid_words(file_path='wordle_words.txt'):\n",
"    \"\"\"\n",
"    Load valid five-letter words from a specified text file.\n",
"\n",
"    Each line is stripped of surrounding whitespace; only lines whose\n",
"    stripped length is exactly 5 are kept, in file order.\n",
"\n",
"    Parameters:\n",
"    - file_path (str): The path to the text file containing valid words.\n",
"\n",
"    Returns:\n",
"    - list[str]: A list of valid words loaded from the file.\n",
"\n",
"    Raises:\n",
"    - OSError: If the file cannot be opened (propagated from open()).\n",
"    \"\"\"\n",
"    # Explicit encoding so the result does not depend on the platform locale.\n",
"    with open(file_path, 'r', encoding='utf-8') as file:\n",
"        # Strip each line once, then filter on the stripped length.\n",
"        stripped_lines = (line.strip() for line in file)\n",
"        valid_words = [word for word in stripped_lines if len(word) == 5]\n",
"    return valid_words"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "LetterGuessingEnv.__init__() missing 1 required positional argument: 'seed'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[2], line 5\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mstable_baselines3\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01menv_checker\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m check_env\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mletter_guess\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m LetterGuessingEnv\n\u001b[1;32m----> 5\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mLetterGuessingEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalid_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mload_valid_words\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Make sure to load your valid words\u001b[39;00m\n\u001b[0;32m 6\u001b[0m check_env(env) \u001b[38;5;66;03m# Optional: Verify the environment is compatible with SB3\u001b[39;00m\n\u001b[0;32m 8\u001b[0m model \u001b[38;5;241m=\u001b[39m PPO(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMlpPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, env, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
"\u001b[1;31mTypeError\u001b[0m: LetterGuessingEnv.__init__() missing 1 required positional argument: 'seed'"
]
}
],
"source": [
"from stable_baselines3 import PPO  # Or any other suitable RL algorithm\n",
"from stable_baselines3.common.env_checker import check_env\n",
"from letter_guess import LetterGuessingEnv\n",
"\n",
"# LetterGuessingEnv requires a `seed` argument; omitting it raises\n",
"# \"TypeError: __init__() missing 1 required positional argument: 'seed'\".\n",
"# The same value is reused for PPO so training runs are reproducible.\n",
"# NOTE(review): assumes `seed` takes an int - confirm against letter_guess.py.\n",
"SEED = 42\n",
"\n",
"env = LetterGuessingEnv(valid_words=load_valid_words(), seed=SEED)  # Make sure to load your valid words\n",
"check_env(env)  # Optional: Verify the environment is compatible with SB3\n",
"\n",
"model = PPO(\"MlpPolicy\", env, verbose=1, seed=SEED)\n",
"\n",
"# Train for a certain number of timesteps\n",
"model.learn(total_timesteps=100000)\n",
"\n",
"# Save the model\n",
"model.save(\"wordle_ppo_model\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}