mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-25 09:39:10 +00:00
new reward scheme
This commit is contained in:
parent
bbe9a1891c
commit
e799c14ece
795
dqn_wordle.ipynb
795
dqn_wordle.ipynb
@ -35,240 +35,495 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7c52630b65904d5e8e200be505d2121a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Output()"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using cpu device\n",
|
||||
"Wrapping the env in a DummyVecEnv.\n",
|
||||
"---------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 5.59 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 544 |\n",
|
||||
"| iterations | 1 |\n",
|
||||
"| time_elapsed | 3 |\n",
|
||||
"| total_timesteps | 2048 |\n",
|
||||
"---------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.77 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 245 |\n",
|
||||
"| iterations | 2 |\n",
|
||||
"| time_elapsed | 16 |\n",
|
||||
"| total_timesteps | 4096 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.021515464 |\n",
|
||||
"| clip_fraction | 0.335 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.00118 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 89.5 |\n",
|
||||
"| n_updates | 10 |\n",
|
||||
"| policy_gradient_loss | -0.0854 |\n",
|
||||
"| value_loss | 262 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.31 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 211 |\n",
|
||||
"| iterations | 3 |\n",
|
||||
"| time_elapsed | 29 |\n",
|
||||
"| total_timesteps | 6144 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02457875 |\n",
|
||||
"| clip_fraction | 0.465 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.161 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 118 |\n",
|
||||
"| n_updates | 20 |\n",
|
||||
"| policy_gradient_loss | -0.0987 |\n",
|
||||
"| value_loss | 217 |\n",
|
||||
"----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5.96 |\n",
|
||||
"| ep_rew_mean | 5.79 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 196 |\n",
|
||||
"| iterations | 4 |\n",
|
||||
"| time_elapsed | 41 |\n",
|
||||
"| total_timesteps | 8192 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02515613 |\n",
|
||||
"| clip_fraction | 0.447 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.151 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 138 |\n",
|
||||
"| n_updates | 30 |\n",
|
||||
"| policy_gradient_loss | -0.103 |\n",
|
||||
"| value_loss | 242 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 4.9 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 177 |\n",
|
||||
"| iterations | 5 |\n",
|
||||
"| time_elapsed | 57 |\n",
|
||||
"| total_timesteps | 10240 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.026685718 |\n",
|
||||
"| clip_fraction | 0.444 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.176 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 96.7 |\n",
|
||||
"| n_updates | 40 |\n",
|
||||
"| policy_gradient_loss | -0.111 |\n",
|
||||
"| value_loss | 211 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.19 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 164 |\n",
|
||||
"| iterations | 6 |\n",
|
||||
"| time_elapsed | 74 |\n",
|
||||
"| total_timesteps | 12288 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02762504 |\n",
|
||||
"| clip_fraction | 0.463 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.186 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 103 |\n",
|
||||
"| n_updates | 50 |\n",
|
||||
"| policy_gradient_loss | -0.115 |\n",
|
||||
"| value_loss | 200 |\n",
|
||||
"----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 5.5 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 155 |\n",
|
||||
"| iterations | 7 |\n",
|
||||
"| time_elapsed | 92 |\n",
|
||||
"| total_timesteps | 14336 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02694263 |\n",
|
||||
"| clip_fraction | 0.458 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.15 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 84.1 |\n",
|
||||
"| n_updates | 60 |\n",
|
||||
"| policy_gradient_loss | -0.116 |\n",
|
||||
"| value_loss | 225 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 7.27 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 154 |\n",
|
||||
"| iterations | 8 |\n",
|
||||
"| time_elapsed | 106 |\n",
|
||||
"| total_timesteps | 16384 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.024316464 |\n",
|
||||
"| clip_fraction | 0.412 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.173 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 126 |\n",
|
||||
"| n_updates | 70 |\n",
|
||||
"| policy_gradient_loss | -0.112 |\n",
|
||||
"| value_loss | 227 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 7.8 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 151 |\n",
|
||||
"| iterations | 9 |\n",
|
||||
"| time_elapsed | 121 |\n",
|
||||
"| total_timesteps | 18432 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022988513 |\n",
|
||||
"| clip_fraction | 0.391 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.206 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 139 |\n",
|
||||
"| n_updates | 80 |\n",
|
||||
"| policy_gradient_loss | -0.111 |\n",
|
||||
"| value_loss | 228 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 6.14 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 153 |\n",
|
||||
"| iterations | 10 |\n",
|
||||
"| time_elapsed | 133 |\n",
|
||||
"| total_timesteps | 20480 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022813996 |\n",
|
||||
"| clip_fraction | 0.372 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.199 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 117 |\n",
|
||||
"| n_updates | 90 |\n",
|
||||
"| policy_gradient_loss | -0.108 |\n",
|
||||
"| value_loss | 212 |\n",
|
||||
"-----------------------------------------\n"
|
||||
"Using cuda device\n",
|
||||
"Wrapping the env with a `Monitor` wrapper\n",
|
||||
"Wrapping the env in a DummyVecEnv.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -175 |\n",
|
||||
"| exploration_rate | 0.525 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 10000 |\n",
|
||||
"| fps | 4606 |\n",
|
||||
"| time_elapsed | 10 |\n",
|
||||
"| total_timesteps | 49989 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -208 |\n",
|
||||
"| exploration_rate | 0.0502 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 20000 |\n",
|
||||
"| fps | 1118 |\n",
|
||||
"| time_elapsed | 89 |\n",
|
||||
"| total_timesteps | 99980 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 24.6 |\n",
|
||||
"| n_updates | 12494 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -230 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 30000 |\n",
|
||||
"| fps | 856 |\n",
|
||||
"| time_elapsed | 175 |\n",
|
||||
"| total_timesteps | 149974 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 18.7 |\n",
|
||||
"| n_updates | 24993 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -242 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 40000 |\n",
|
||||
"| fps | 766 |\n",
|
||||
"| time_elapsed | 260 |\n",
|
||||
"| total_timesteps | 199967 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 24 |\n",
|
||||
"| n_updates | 37491 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -186 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 50000 |\n",
|
||||
"| fps | 722 |\n",
|
||||
"| time_elapsed | 346 |\n",
|
||||
"| total_timesteps | 249962 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 21.5 |\n",
|
||||
"| n_updates | 49990 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -183 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 60000 |\n",
|
||||
"| fps | 694 |\n",
|
||||
"| time_elapsed | 431 |\n",
|
||||
"| total_timesteps | 299957 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 17.6 |\n",
|
||||
"| n_updates | 62489 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -181 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 70000 |\n",
|
||||
"| fps | 675 |\n",
|
||||
"| time_elapsed | 517 |\n",
|
||||
"| total_timesteps | 349953 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 26.8 |\n",
|
||||
"| n_updates | 74988 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -196 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 80000 |\n",
|
||||
"| fps | 663 |\n",
|
||||
"| time_elapsed | 603 |\n",
|
||||
"| total_timesteps | 399936 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 22.5 |\n",
|
||||
"| n_updates | 87483 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -174 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 90000 |\n",
|
||||
"| fps | 653 |\n",
|
||||
"| time_elapsed | 688 |\n",
|
||||
"| total_timesteps | 449928 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 21.1 |\n",
|
||||
"| n_updates | 99981 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -155 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 100000 |\n",
|
||||
"| fps | 645 |\n",
|
||||
"| time_elapsed | 774 |\n",
|
||||
"| total_timesteps | 499920 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 22.8 |\n",
|
||||
"| n_updates | 112479 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -153 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 110000 |\n",
|
||||
"| fps | 638 |\n",
|
||||
"| time_elapsed | 860 |\n",
|
||||
"| total_timesteps | 549916 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 16 |\n",
|
||||
"| n_updates | 124978 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -164 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 120000 |\n",
|
||||
"| fps | 633 |\n",
|
||||
"| time_elapsed | 947 |\n",
|
||||
"| total_timesteps | 599915 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 17.8 |\n",
|
||||
"| n_updates | 137478 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -145 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 130000 |\n",
|
||||
"| fps | 628 |\n",
|
||||
"| time_elapsed | 1033 |\n",
|
||||
"| total_timesteps | 649910 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 17.8 |\n",
|
||||
"| n_updates | 149977 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -154 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 140000 |\n",
|
||||
"| fps | 624 |\n",
|
||||
"| time_elapsed | 1120 |\n",
|
||||
"| total_timesteps | 699902 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 20.9 |\n",
|
||||
"| n_updates | 162475 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -192 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 150000 |\n",
|
||||
"| fps | 621 |\n",
|
||||
"| time_elapsed | 1206 |\n",
|
||||
"| total_timesteps | 749884 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 18.3 |\n",
|
||||
"| n_updates | 174970 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -170 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 160000 |\n",
|
||||
"| fps | 618 |\n",
|
||||
"| time_elapsed | 1293 |\n",
|
||||
"| total_timesteps | 799869 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 17.7 |\n",
|
||||
"| n_updates | 187467 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -233 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 170000 |\n",
|
||||
"| fps | 615 |\n",
|
||||
"| time_elapsed | 1380 |\n",
|
||||
"| total_timesteps | 849855 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 21.6 |\n",
|
||||
"| n_updates | 199963 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -146 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 180000 |\n",
|
||||
"| fps | 613 |\n",
|
||||
"| time_elapsed | 1466 |\n",
|
||||
"| total_timesteps | 899847 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 19.4 |\n",
|
||||
"| n_updates | 212461 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -142 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 190000 |\n",
|
||||
"| fps | 611 |\n",
|
||||
"| time_elapsed | 1553 |\n",
|
||||
"| total_timesteps | 949846 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 22.9 |\n",
|
||||
"| n_updates | 224961 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -171 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 200000 |\n",
|
||||
"| fps | 609 |\n",
|
||||
"| time_elapsed | 1640 |\n",
|
||||
"| total_timesteps | 999839 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 20.3 |\n",
|
||||
"| n_updates | 237459 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
||||
],
|
||||
"text/plain": []
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<stable_baselines3.ppo.ppo.PPO at 0x2200a962b50>"
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"total_timesteps = 20_000\n",
|
||||
"model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||
"model.learn(total_timesteps=total_timesteps)"
|
||||
"total_timesteps = 1_000_000\n",
|
||||
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"dqn_wordle\")"
|
||||
"model.save(\"dqn_new_rewards\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -280,88 +535,28 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"Exception: code() argument 13 must be str, not int\n",
|
||||
" warnings.warn(\n",
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"Exception: code() argument 13 must be str, not int\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = PPO.load(\"dqn_wordle\")"
|
||||
"# model = DQN.load(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [18 8 5 13 5 2 3 2 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]]\n",
|
||||
"[[16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [18 8 5 13 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]]\n",
|
||||
"[[16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [18 8 5 13 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]]\n",
|
||||
"[[16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [18 8 5 13 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]]\n",
|
||||
"[[16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [18 1 11 5 4 2 1 3 2 3]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]]\n",
|
||||
"[[16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [18 1 11 5 4 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]]\n",
|
||||
"[[16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [18 1 11 5 4 3 3 3 2 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]]\n",
|
||||
"[[16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [18 8 5 13 5 1 3 3 2 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]]\n",
|
||||
"[[16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [18 8 5 13 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]]\n",
|
||||
"[[16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [18 8 5 13 5 3 3 1 1 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]]\n",
|
||||
"0\n"
|
||||
]
|
||||
}
|
||||
@ -369,7 +564,7 @@
|
||||
"source": [
|
||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||
"\n",
|
||||
"for i in range(10):\n",
|
||||
"for i in range(1000):\n",
|
||||
" \n",
|
||||
" state, info = env.reset()\n",
|
||||
"\n",
|
||||
@ -386,16 +581,62 @@
|
||||
" if info[\"correct\"]:\n",
|
||||
" wins += 1\n",
|
||||
"\n",
|
||||
"print(wins)\n"
|
||||
"print(wins)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n",
|
||||
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n",
|
||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
||||
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n",
|
||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n",
|
||||
" -130)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"state, reward"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print()"
|
||||
"blah = (14, 1, 9, 22, 5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"blah in info['guesses']"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -6,6 +6,7 @@ from sty import fg, bg, ef, rs
|
||||
from collections import Counter
|
||||
from gym_wordle.utils import to_english, to_array, get_words
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class WordList(gym.spaces.Discrete):
|
||||
@ -160,7 +161,14 @@ class WordleEnv(gym.Env):
|
||||
|
||||
self.n_rounds = 6
|
||||
self.n_letters = 5
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
self.info = {
|
||||
'correct': False,
|
||||
'guesses': set(),
|
||||
'known_positions': np.full(5, -1), # -1 for unknown, else letter index
|
||||
'known_letters': set(), # Letters known to be in the word
|
||||
'not_in_word': set(), # Letters known not to be in the word
|
||||
'tried_positions': defaultdict(set) # Positions tried for each letter
|
||||
}
|
||||
|
||||
def _highlighter(self, char: str, flag: int) -> str:
|
||||
"""Terminal renderer functionality. Properly highlights a character
|
||||
@ -192,13 +200,57 @@ class WordleEnv(gym.Env):
|
||||
"""
|
||||
self.round = 0
|
||||
self.solution = self.solution_space.sample()
|
||||
self.soln_hash = set(self.solution_space[self.solution])
|
||||
|
||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
self.info = {
|
||||
'correct': False,
|
||||
'guesses': set(),
|
||||
'known_positions': np.full(5, -1),
|
||||
'known_letters': set(),
|
||||
'not_in_word': set(),
|
||||
'tried_positions': defaultdict(set)
|
||||
}
|
||||
|
||||
self.simulate_first_guess()
|
||||
|
||||
return self.state, self.info
|
||||
|
||||
def simulate_first_guess(self):
|
||||
fixed_first_guess = "rates"
|
||||
fixed_first_guess_array = to_array(fixed_first_guess)
|
||||
|
||||
# Simulate the feedback for each letter in the fixed first guess
|
||||
feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array
|
||||
for i, letter in enumerate(fixed_first_guess_array):
|
||||
if letter in self.solution_space[self.solution]:
|
||||
if letter == self.solution_space[self.solution][i]:
|
||||
feedback[i] = 1 # Correct position
|
||||
else:
|
||||
feedback[i] = 2 # Correct letter, wrong position
|
||||
else:
|
||||
feedback[i] = 3 # Letter not in word
|
||||
|
||||
# Update the state to reflect the fixed first guess and its feedback
|
||||
self.state[0, :self.n_letters] = fixed_first_guess_array
|
||||
self.state[0, self.n_letters:] = feedback
|
||||
|
||||
# Update self.info based on the feedback
|
||||
for i, flag in enumerate(feedback):
|
||||
if flag == self.right_pos:
|
||||
# Mark letter as correctly placed
|
||||
self.info['known_positions'][i] = fixed_first_guess_array[i]
|
||||
elif flag == self.wrong_pos:
|
||||
# Note the letter is in the word but in a different position
|
||||
self.info['known_letters'].add(fixed_first_guess_array[i])
|
||||
elif flag == self.wrong_char:
|
||||
# Note the letter is not in the word
|
||||
self.info['not_in_word'].add(fixed_first_guess_array[i])
|
||||
|
||||
# Since we're simulating the first guess, increment the round counter
|
||||
self.round = 1
|
||||
|
||||
def render(self, mode: str = 'human'):
|
||||
"""Renders the Wordle environment.
|
||||
|
||||
@ -220,67 +272,69 @@ class WordleEnv(gym.Env):
|
||||
super().render(mode=mode)
|
||||
|
||||
def step(self, action):
|
||||
"""Run one step of the Wordle game. Every game must be previously
|
||||
initialized by a call to the `reset` method.
|
||||
|
||||
Args:
|
||||
action: Word guessed by the agent.
|
||||
|
||||
Returns:
|
||||
state (object): Wordle game state after the guess.
|
||||
reward (float): Reward associated with the guess.
|
||||
done (bool): Whether the game has ended.
|
||||
info (dict): Auxiliary diagnostic information.
|
||||
"""
|
||||
assert self.action_space.contains(action), 'Invalid word!'
|
||||
|
||||
action = self.action_space[action]
|
||||
solution = self.solution_space[self.solution]
|
||||
guessed_word = self.action_space[action]
|
||||
solution_word = self.solution_space[self.solution]
|
||||
|
||||
self.state[self.round][:self.n_letters] = action
|
||||
reward = 0
|
||||
correct_guess = np.array_equal(guessed_word, solution_word)
|
||||
|
||||
counter = Counter()
|
||||
for i, char in enumerate(action):
|
||||
flag_i = i + self.n_letters
|
||||
counter[char] += 1
|
||||
# Initialize flags for current guess
|
||||
current_flags = np.full(self.n_letters, self.wrong_char)
|
||||
|
||||
if char == solution[i]:
|
||||
self.state[self.round, flag_i] = self.right_pos
|
||||
elif counter[char] <= (char == solution).sum():
|
||||
self.state[self.round, flag_i] = self.wrong_pos
|
||||
# Track newly discovered information
|
||||
new_info = False
|
||||
|
||||
for i in range(self.n_letters):
|
||||
guessed_letter = guessed_word[i]
|
||||
if guessed_letter in solution_word:
|
||||
# Penalize for reusing a letter found to not be in the word
|
||||
if guessed_letter in self.info['not_in_word']:
|
||||
reward -= 2
|
||||
|
||||
# Handle correct letter in the correct position
|
||||
if guessed_letter == solution_word[i]:
|
||||
current_flags[i] = self.right_pos
|
||||
if self.info['known_positions'][i] != guessed_letter:
|
||||
reward += 10 # Large reward for new correct placement
|
||||
new_info = True
|
||||
self.info['known_positions'][i] = guessed_letter
|
||||
else:
|
||||
reward += 20 # Large reward for repeating correct placement
|
||||
else:
|
||||
current_flags[i] = self.wrong_pos
|
||||
if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
|
||||
reward += 10 # Reward for guessing a letter in a new position
|
||||
new_info = True
|
||||
else:
|
||||
reward -= 20 # Penalize for not leveraging known information
|
||||
self.info['known_letters'].add(guessed_letter)
|
||||
self.info['tried_positions'][guessed_letter].add(i)
|
||||
else:
|
||||
self.state[self.round, flag_i] = self.wrong_char
|
||||
# New incorrect letter
|
||||
if guessed_letter not in self.info['not_in_word']:
|
||||
reward -= 2 # Penalize for guessing a letter not in the word
|
||||
self.info['not_in_word'].add(guessed_letter)
|
||||
new_info = True
|
||||
else:
|
||||
reward -= 15 # Larger penalty for repeating an incorrect letter
|
||||
|
||||
# Update observation state with the current guess and flags
|
||||
self.state[self.round, :self.n_letters] = guessed_word
|
||||
self.state[self.round, self.n_letters:] = current_flags
|
||||
|
||||
# Check if the game is over
|
||||
done = self.round == self.n_rounds - 1 or correct_guess
|
||||
self.info['correct'] = correct_guess
|
||||
|
||||
if correct_guess:
|
||||
reward += 100 # Major reward for winning
|
||||
elif done:
|
||||
reward -= 50 # Penalty for losing without using new information effectively
|
||||
elif not new_info:
|
||||
reward -= 10 # Penalty if no new information was used in this guess
|
||||
|
||||
self.round += 1
|
||||
|
||||
correct = (action == solution).all()
|
||||
game_over = (self.round == self.n_rounds)
|
||||
|
||||
done = correct or game_over
|
||||
|
||||
reward = 0
|
||||
# correct spot
|
||||
reward += np.sum(self.state[:, 5:] == 1) * 2
|
||||
|
||||
# correct letter not correct spot
|
||||
reward += np.sum(self.state[:, 5:] == 2) * 1
|
||||
|
||||
# incorrect letter
|
||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
||||
|
||||
# guess same word as before
|
||||
hashable_action = tuple(action)
|
||||
if hashable_action in self.info['guesses']:
|
||||
reward += -10
|
||||
else: # guess different word
|
||||
reward += 10
|
||||
|
||||
self.info['guesses'].add(hashable_action)
|
||||
|
||||
# for game ending in win or loss
|
||||
reward += 10 if correct else -10 if done else 0
|
||||
|
||||
self.info['correct'] = correct
|
||||
|
||||
# observation, reward, terminated, truncated, info
|
||||
return self.state, reward, done, False, self.info
|
||||
|
189
test.ipynb
Normal file
189
test.ipynb
Normal file
@ -0,0 +1,189 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import defaultdict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def my_func()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"t = defaultdict(lambda: [0, 1, 2, 3, 4])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"defaultdict(<function __main__.<lambda>()>, {})"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"t"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[0, 1, 2, 3, 4]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"t['t']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"t"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"'x' in t"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = np.array([1, 1, 1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x[:] = 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([0, 0, 0])"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"'abcde'aaa\n",
|
||||
" 33221\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in New Issue
Block a user