mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-09-20 00:47:55 +00:00
88 lines
3.7 KiB
Plaintext
88 lines
3.7 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
  "def load_valid_words(file_path='wordle_words.txt'):\n",
  "    \"\"\"\n",
  "    Read a word list from disk and keep only its five-character entries.\n",
  "\n",
  "    Parameters:\n",
  "    - file_path (str): Location of the text file; each line is treated as one entry.\n",
  "\n",
  "    Returns:\n",
  "    - list[str]: Entries whose length is exactly 5 once surrounding whitespace\n",
  "      has been stripped, in the order they appear in the file.\n",
  "    \"\"\"\n",
  "    five_letter_words = []\n",
  "    with open(file_path, 'r') as handle:\n",
  "        for raw_line in handle:\n",
  "            # Strip once and reuse the result for both the length test and the output.\n",
  "            candidate = raw_line.strip()\n",
  "            if len(candidate) == 5:\n",
  "                five_letter_words.append(candidate)\n",
  "    return five_letter_words"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "from stable_baselines3 import PPO  # Or any other suitable RL algorithm\n",
  "from stable_baselines3.common.env_checker import check_env\n",
  "from letter_guess import LetterGuessingEnv\n",
  "\n",
  "# One seed shared by the environment and the learner so a rerun is repeatable.\n",
  "SEED = 42\n",
  "\n",
  "# LetterGuessingEnv.__init__ requires `seed`; calling it with only valid_words\n",
  "# raised TypeError (missing 1 required positional argument: 'seed').\n",
  "# NOTE(review): assumes `seed` is a plain int -- confirm against letter_guess.py.\n",
  "# load_valid_words() comes from the cell above and must be run first.\n",
  "env = LetterGuessingEnv(valid_words=load_valid_words(), seed=SEED)\n",
  "check_env(env)  # Optional: Verify the environment is compatible with SB3\n",
  "\n",
  "# Seeding PPO too keeps policy initialisation and sampling reproducible.\n",
  "model = PPO(\"MlpPolicy\", env, verbose=1, seed=SEED)\n",
  "\n",
  "# Train for a certain number of timesteps\n",
  "model.learn(total_timesteps=100000)\n",
  "\n",
  "# Save the model\n",
  "model.save(\"wordle_ppo_model\")"
 ]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "env",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.11.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|