Multi-armed bandits¶
River has a bandit module. It contains several multi-armed bandit policies, bandit environments, and utilities to benchmark policies on bandit problems.
Bandit environments in River implement the Gym interface. You can thus load them with gym.make. Note that Gym is intended for reinforcement learning algorithms, while bandit policies are the simplest form of reinforcement learning. Bandit policies learn by receiving a reward after each step, while reinforcement learning algorithms have to learn from feedback that may arrive at the end of a (long) sequence of steps.
import gymnasium as gym
for k in gym.envs.registry:
if k.startswith('river_bandits'):
print(k)
River's bandit module offers the bandit.evaluate function to benchmark several policies on a given environment. It takes as input a list of bandit policies, a bandit environment (the problem to solve), and a reward object.
import gymnasium as gym
from river import bandit
import pandas as pd
from tqdm import tqdm
from river import stats
policies=[
bandit.EpsilonGreedy(epsilon=0.1),
bandit.EpsilonGreedy(epsilon=0.01),
bandit.EpsilonGreedy(epsilon=0),
]
env = gym.make(
'river_bandits/KArmedTestbed-v0',
max_episode_steps=1000
)
trace = bandit.evaluate(
policies=policies,
env=env,
reward_stat=stats.Mean(),
n_episodes=(n_episodes := 2000),
)
The bandit.evaluate function returns a generator containing the results at each step of the benchmark. This can be wrapped with a pandas.DataFrame to gather all the results.
trace_df = pd.DataFrame(tqdm(
trace, position=0, total=(
n_episodes *
len(policies) *
env._max_episode_steps
)
))
trace_df.sample(5, random_state=42)
0%| | 0/6000000 [00:00<?, ?it/s]
0%| | 16183/6000000 [00:00<00:36, 161818.19it/s]
1%| | 33105/6000000 [00:00<00:35, 166164.00it/s]
1%| | 50352/6000000 [00:00<00:35, 169038.44it/s]
1%| | 67358/6000000 [00:00<00:35, 169436.36it/s]
1%|▏ | 84358/6000000 [00:00<00:34, 169638.20it/s]
2%|▏ | 101322/6000000 [00:00<00:35, 167936.85it/s]
2%|▏ | 118119/6000000 [00:00<00:35, 164568.51it/s]
2%|▏ | 135008/6000000 [00:00<00:35, 165917.62it/s]
3%|▎ | 151612/6000000 [00:00<00:42, 136508.89it/s]
3%|▎ | 168627/6000000 [00:01<00:40, 145461.43it/s]
3%|▎ | 185742/6000000 [00:01<00:38, 152535.15it/s]
3%|▎ | 202565/6000000 [00:01<00:36, 156966.49it/s]
4%|▎ | 219449/6000000 [00:01<00:36, 160377.32it/s]
4%|▍ | 236542/6000000 [00:01<00:35, 163448.50it/s]
4%|▍ | 253317/6000000 [00:01<00:34, 164709.54it/s]
5%|▍ | 270136/6000000 [00:01<00:34, 165735.39it/s]
5%|▍ | 287328/6000000 [00:01<00:34, 167568.81it/s]
5%|▌ | 304189/6000000 [00:01<00:33, 167875.71it/s]
5%|▌ | 321134/6000000 [00:01<00:33, 168343.82it/s]
6%|▌ | 338282/6000000 [00:02<00:33, 169278.19it/s]
6%|▌ | 355258/6000000 [00:02<00:33, 169420.36it/s]
6%|▌ | 372454/6000000 [00:02<00:33, 170178.38it/s]
6%|▋ | 389607/6000000 [00:02<00:32, 170582.67it/s]
7%|▋ | 406676/6000000 [00:02<00:32, 170612.57it/s]
7%|▋ | 423745/6000000 [00:02<00:32, 170337.29it/s]
7%|▋ | 441001/6000000 [00:02<00:32, 170861.75it/s]
8%|▊ | 458414/6000000 [00:02<00:32, 171837.38it/s]
8%|▊ | 475601/6000000 [00:02<00:32, 171787.07it/s]
8%|▊ | 492782/6000000 [00:02<00:32, 171348.74it/s]
9%|▊ | 510001/6000000 [00:03<00:32, 171100.30it/s]
9%|▉ | 527113/6000000 [00:03<00:32, 170869.56it/s]
9%|▉ | 544201/6000000 [00:03<00:31, 170858.49it/s]
9%|▉ | 561290/6000000 [00:03<00:31, 170866.74it/s]
10%|▉ | 578637/6000000 [00:03<00:31, 171643.51it/s]
10%|▉ | 595826/6000000 [00:03<00:31, 171714.71it/s]
10%|█ | 612998/6000000 [00:03<00:31, 171657.71it/s]
11%|█ | 630266/6000000 [00:03<00:31, 171960.49it/s]
11%|█ | 647550/6000000 [00:03<00:31, 172223.48it/s]
11%|█ | 664773/6000000 [00:03<00:31, 171816.81it/s]
11%|█▏ | 681955/6000000 [00:04<00:31, 171207.22it/s]
12%|█▏ | 699077/6000000 [00:04<00:31, 170514.59it/s]
12%|█▏ | 716130/6000000 [00:04<00:31, 170314.75it/s]
12%|█▏ | 733162/6000000 [00:04<00:32, 163227.79it/s]
13%|█▎ | 750143/6000000 [00:04<00:31, 165136.88it/s]
13%|█▎ | 767407/6000000 [00:04<00:31, 167334.55it/s]
13%|█▎ | 784379/6000000 [00:04<00:31, 168035.59it/s]
13%|█▎ | 801336/6000000 [00:04<00:30, 168488.40it/s]
14%|█▎ | 818726/6000000 [00:04<00:30, 170095.66it/s]
14%|█▍ | 835752/6000000 [00:05<00:30, 169969.33it/s]
14%|█▍ | 852843/6000000 [00:05<00:30, 170249.04it/s]
15%|█▍ | 870019/6000000 [00:05<00:30, 170697.84it/s]
15%|█▍ | 887432/6000000 [00:05<00:29, 171721.58it/s]
15%|█▌ | 904609/6000000 [00:05<00:29, 171248.10it/s]
15%|█▌ | 921776/6000000 [00:05<00:29, 171373.22it/s]
16%|█▌ | 939001/6000000 [00:05<00:29, 171241.35it/s]
16%|█▌ | 956419/6000000 [00:05<00:29, 172118.00it/s]
16%|█▌ | 973633/6000000 [00:05<00:29, 171326.58it/s]
17%|█▋ | 990768/6000000 [00:05<00:29, 170523.75it/s]
17%|█▋ | 1008001/6000000 [00:06<00:29, 170611.74it/s]
17%|█▋ | 1025179/6000000 [00:06<00:29, 170956.70it/s]
17%|█▋ | 1042276/6000000 [00:06<00:29, 170896.72it/s]
18%|█▊ | 1059367/6000000 [00:06<00:28, 170704.60it/s]
18%|█▊ | 1076672/6000000 [00:06<00:28, 171403.92it/s]
18%|█▊ | 1093813/6000000 [00:06<00:28, 171289.54it/s]
19%|█▊ | 1110943/6000000 [00:06<00:28, 171105.86it/s]
19%|█▉ | 1128063/6000000 [00:06<00:28, 171131.03it/s]
19%|█▉ | 1145720/6000000 [00:06<00:28, 172758.50it/s]
19%|█▉ | 1162997/6000000 [00:06<00:28, 172138.45it/s]
20%|█▉ | 1180212/6000000 [00:07<00:28, 171562.95it/s]
20%|█▉ | 1197488/6000000 [00:07<00:27, 171918.48it/s]
20%|██ | 1214681/6000000 [00:07<00:27, 171618.67it/s]
21%|██ | 1231844/6000000 [00:07<00:27, 171375.50it/s]
21%|██ | 1248982/6000000 [00:07<00:27, 171101.52it/s]
21%|██ | 1266140/6000000 [00:07<00:27, 171243.01it/s]
21%|██▏ | 1283564/6000000 [00:07<00:27, 172137.82it/s]
22%|██▏ | 1300779/6000000 [00:07<00:27, 171430.49it/s]
22%|██▏ | 1317923/6000000 [00:07<00:27, 171254.18it/s]
22%|██▏ | 1335049/6000000 [00:07<00:27, 170856.71it/s]
23%|██▎ | 1352329/6000000 [00:08<00:27, 171436.46it/s]
23%|██▎ | 1369563/6000000 [00:08<00:26, 171705.80it/s]
23%|██▎ | 1386735/6000000 [00:08<00:26, 170962.90it/s]
23%|██▎ | 1403893/6000000 [00:08<00:26, 171144.57it/s]
24%|██▎ | 1421009/6000000 [00:08<00:26, 170875.11it/s]
24%|██▍ | 1438098/6000000 [00:08<00:26, 170444.98it/s]
24%|██▍ | 1455143/6000000 [00:08<00:26, 170257.82it/s]
25%|██▍ | 1472479/6000000 [00:08<00:26, 171182.68it/s]
25%|██▍ | 1489598/6000000 [00:08<00:26, 170791.00it/s]
25%|██▌ | 1506678/6000000 [00:08<00:26, 169255.95it/s]
25%|██▌ | 1524001/6000000 [00:09<00:26, 170148.30it/s]
26%|██▌ | 1541163/6000000 [00:09<00:26, 170583.98it/s]
26%|██▌ | 1558224/6000000 [00:09<00:26, 170094.33it/s]
26%|██▋ | 1575235/6000000 [00:09<00:26, 169534.51it/s]
27%|██▋ | 1592269/6000000 [00:09<00:25, 169772.76it/s]
27%|██▋ | 1609264/6000000 [00:09<00:25, 169823.81it/s]
27%|██▋ | 1626373/6000000 [00:09<00:25, 170201.57it/s]
27%|██▋ | 1643555/6000000 [00:09<00:25, 170684.14it/s]
28%|██▊ | 1660624/6000000 [00:09<00:25, 170235.63it/s]
28%|██▊ | 1677688/6000000 [00:09<00:25, 170355.91it/s]
28%|██▊ | 1694863/6000000 [00:10<00:25, 170771.86it/s]
29%|██▊ | 1712063/6000000 [00:10<00:25, 171136.89it/s]
29%|██▉ | 1729178/6000000 [00:10<00:25, 170762.74it/s]
29%|██▉ | 1746255/6000000 [00:10<00:24, 170256.84it/s]
29%|██▉ | 1763560/6000000 [00:10<00:24, 171089.45it/s]
30%|██▉ | 1780670/6000000 [00:10<00:24, 170545.27it/s]
30%|██▉ | 1797726/6000000 [00:10<00:24, 170224.86it/s]
30%|███ | 1815001/6000000 [00:10<00:24, 170618.60it/s]
31%|███ | 1832206/6000000 [00:10<00:24, 171041.80it/s]
31%|███ | 1849311/6000000 [00:10<00:24, 170656.73it/s]
31%|███ | 1866378/6000000 [00:11<00:24, 170343.09it/s]
31%|███▏ | 1883644/6000000 [00:11<00:24, 171033.77it/s]
32%|███▏ | 1900748/6000000 [00:11<00:23, 170938.75it/s]
32%|███▏ | 1917843/6000000 [00:11<00:23, 170511.44it/s]
32%|███▏ | 1934922/6000000 [00:11<00:23, 170594.04it/s]
33%|███▎ | 1951982/6000000 [00:11<00:23, 170169.79it/s]
33%|███▎ | 1969000/6000000 [00:11<00:23, 169466.83it/s]
33%|███▎ | 1986002/6000000 [00:11<00:23, 169590.81it/s]
33%|███▎ | 2003052/6000000 [00:11<00:23, 169861.34it/s]
34%|███▎ | 2020039/6000000 [00:11<00:23, 169276.79it/s]
34%|███▍ | 2037001/6000000 [00:12<00:23, 169093.48it/s]
34%|███▍ | 2054161/6000000 [00:12<00:23, 169841.10it/s]
35%|███▍ | 2071488/6000000 [00:12<00:22, 170865.07it/s]
35%|███▍ | 2088576/6000000 [00:12<00:22, 170698.28it/s]
35%|███▌ | 2105647/6000000 [00:12<00:22, 170464.60it/s]
35%|███▌ | 2122694/6000000 [00:12<00:22, 170352.62it/s]
36%|███▌ | 2139730/6000000 [00:12<00:22, 169914.81it/s]
36%|███▌ | 2156722/6000000 [00:12<00:22, 167470.62it/s]
36%|███▌ | 2173477/6000000 [00:12<00:23, 162733.56it/s]
37%|███▋ | 2190368/6000000 [00:12<00:23, 164532.53it/s]
37%|███▋ | 2207395/6000000 [00:13<00:22, 166216.93it/s]
37%|███▋ | 2224374/6000000 [00:13<00:22, 167271.46it/s]
37%|███▋ | 2241356/6000000 [00:13<00:22, 168026.12it/s]
38%|███▊ | 2258455/6000000 [00:13<00:22, 168906.41it/s]
38%|███▊ | 2275462/6000000 [00:13<00:22, 169252.23it/s]
38%|███▊ | 2292467/6000000 [00:13<00:21, 169487.51it/s]
38%|███▊ | 2309800/6000000 [00:13<00:21, 170636.62it/s]
39%|███▉ | 2326868/6000000 [00:13<00:21, 168387.61it/s]
39%|███▉ | 2343824/6000000 [00:13<00:21, 168732.81it/s]
39%|███▉ | 2361103/6000000 [00:13<00:21, 169938.84it/s]
40%|███▉ | 2378461/6000000 [00:14<00:21, 171024.71it/s]
40%|███▉ | 2395610/6000000 [00:14<00:21, 171162.13it/s]
40%|████ | 2412747/6000000 [00:14<00:20, 171222.18it/s]
40%|████ | 2429896/6000000 [00:14<00:20, 171301.08it/s]
41%|████ | 2447028/6000000 [00:14<00:20, 170838.54it/s]
41%|████ | 2464277/6000000 [00:14<00:20, 171331.78it/s]
41%|████▏ | 2481451/6000000 [00:14<00:20, 171451.49it/s]
42%|████▏ | 2498836/6000000 [00:14<00:20, 172167.63it/s]
42%|████▏ | 2516054/6000000 [00:14<00:20, 171610.62it/s]
42%|████▏ | 2533216/6000000 [00:14<00:20, 171344.66it/s]
43%|████▎ | 2550352/6000000 [00:15<00:20, 170946.72it/s]
43%|████▎ | 2567448/6000000 [00:15<00:20, 170936.36it/s]
43%|████▎ | 2584542/6000000 [00:15<00:19, 170884.36it/s]
43%|████▎ | 2601631/6000000 [00:15<00:19, 170362.17it/s]
44%|████▎ | 2618668/6000000 [00:15<00:20, 166755.25it/s]
44%|████▍ | 2635360/6000000 [00:15<00:20, 161666.38it/s]
44%|████▍ | 2652361/6000000 [00:15<00:20, 164081.84it/s]
44%|████▍ | 2669701/6000000 [00:15<00:19, 166809.25it/s]
45%|████▍ | 2686829/6000000 [00:15<00:19, 168126.54it/s]
45%|████▌ | 2703949/6000000 [00:15<00:19, 169034.90it/s]
45%|████▌ | 2720871/6000000 [00:16<00:19, 167612.78it/s]
46%|████▌ | 2738019/6000000 [00:16<00:19, 168757.18it/s]
46%|████▌ | 2755023/6000000 [00:16<00:19, 169137.51it/s]
46%|████▌ | 2772020/6000000 [00:16<00:19, 169383.18it/s]
46%|████▋ | 2789494/6000000 [00:16<00:18, 170982.73it/s]
47%|████▋ | 2806659/6000000 [00:16<00:18, 171181.28it/s]
47%|████▋ | 2823962/6000000 [00:16<00:18, 171731.86it/s]
47%|████▋ | 2841138/6000000 [00:16<00:18, 171215.74it/s]
48%|████▊ | 2858491/6000000 [00:16<00:18, 171905.55it/s]
48%|████▊ | 2875686/6000000 [00:16<00:18, 171915.83it/s]
48%|████▊ | 2892879/6000000 [00:17<00:18, 171728.44it/s]
49%|████▊ | 2910101/6000000 [00:17<00:17, 171873.60it/s]
49%|████▉ | 2927456/6000000 [00:17<00:17, 172374.92it/s]
49%|████▉ | 2944694/6000000 [00:17<00:17, 171329.22it/s]
49%|████▉ | 2962048/6000000 [00:17<00:17, 171988.39it/s]
50%|████▉ | 2979249/6000000 [00:17<00:17, 171452.29it/s]
50%|████▉ | 2996607/6000000 [00:17<00:17, 172086.84it/s]
50%|█████ | 3013817/6000000 [00:17<00:17, 171690.32it/s]
51%|█████ | 3031018/6000000 [00:17<00:17, 171783.45it/s]
51%|█████ | 3048198/6000000 [00:17<00:17, 171396.13it/s]
51%|█████ | 3065395/6000000 [00:18<00:17, 171566.64it/s]
51%|█████▏ | 3082553/6000000 [00:18<00:17, 171420.98it/s]
52%|█████▏ | 3099696/6000000 [00:18<00:16, 171143.21it/s]
52%|█████▏ | 3116939/6000000 [00:18<00:16, 171526.75it/s]
52%|█████▏ | 3134092/6000000 [00:18<00:16, 171061.91it/s]
53%|█████▎ | 3151199/6000000 [00:18<00:16, 170793.58it/s]
53%|█████▎ | 3168405/6000000 [00:18<00:16, 171171.41it/s]
53%|█████▎ | 3185688/6000000 [00:18<00:16, 171665.77it/s]
53%|█████▎ | 3202855/6000000 [00:18<00:16, 168477.98it/s]
54%|█████▎ | 3219926/6000000 [00:18<00:16, 169135.36it/s]
54%|█████▍ | 3237001/6000000 [00:19<00:16, 169515.91it/s]
54%|█████▍ | 3254387/6000000 [00:19<00:16, 170807.06it/s]
55%|█████▍ | 3271474/6000000 [00:19<00:16, 169345.77it/s]
55%|█████▍ | 3288416/6000000 [00:19<00:16, 169340.40it/s]
55%|█████▌ | 3305950/6000000 [00:19<00:15, 171125.03it/s]
55%|█████▌ | 3323150/6000000 [00:19<00:15, 171383.73it/s]
56%|█████▌ | 3340292/6000000 [00:19<00:15, 171358.95it/s]
56%|█████▌ | 3357530/6000000 [00:19<00:15, 171661.84it/s]
56%|█████▌ | 3374815/6000000 [00:19<00:15, 172014.38it/s]
57%|█████▋ | 3392102/6000000 [00:20<00:15, 172268.33it/s]
57%|█████▋ | 3409330/6000000 [00:20<00:15, 172072.67it/s]
57%|█████▋ | 3426538/6000000 [00:20<00:15, 171370.64it/s]
57%|█████▋ | 3443677/6000000 [00:20<00:14, 171132.74it/s]
58%|█████▊ | 3460791/6000000 [00:20<00:14, 170608.90it/s]
58%|█████▊ | 3477853/6000000 [00:20<00:14, 170089.20it/s]
58%|█████▊ | 3495001/6000000 [00:20<00:14, 170130.97it/s]
59%|█████▊ | 3512217/6000000 [00:20<00:14, 170733.06it/s]
59%|█████▉ | 3529385/6000000 [00:20<00:14, 171014.90it/s]
59%|█████▉ | 3546488/6000000 [00:20<00:14, 170736.90it/s]
59%|█████▉ | 3563624/6000000 [00:21<00:14, 170920.39it/s]
60%|█████▉ | 3580717/6000000 [00:21<00:14, 170123.02it/s]
60%|█████▉ | 3597731/6000000 [00:21<00:14, 169919.97it/s]
60%|██████ | 3615058/6000000 [00:21<00:13, 170919.96it/s]
61%|██████ | 3632151/6000000 [00:21<00:13, 170890.10it/s]
61%|██████ | 3649311/6000000 [00:21<00:13, 171100.58it/s]
61%|██████ | 3666422/6000000 [00:21<00:13, 170465.79it/s]
61%|██████▏ | 3683902/6000000 [00:21<00:13, 171758.60it/s]
62%|██████▏ | 3701079/6000000 [00:21<00:13, 171499.02it/s]
62%|██████▏ | 3718230/6000000 [00:21<00:13, 171425.08it/s]
62%|██████▏ | 3735374/6000000 [00:22<00:13, 170059.90it/s]
63%|██████▎ | 3752485/6000000 [00:22<00:13, 170370.52it/s]
63%|██████▎ | 3769525/6000000 [00:22<00:13, 169581.81it/s]
63%|██████▎ | 3786486/6000000 [00:22<00:13, 168581.93it/s]
63%|██████▎ | 3803699/6000000 [00:22<00:12, 169634.70it/s]
64%|██████▎ | 3820665/6000000 [00:22<00:12, 169551.43it/s]
64%|██████▍ | 3837622/6000000 [00:22<00:12, 169398.69it/s]
64%|██████▍ | 3854770/6000000 [00:22<00:12, 170017.60it/s]
65%|██████▍ | 3872004/6000000 [00:22<00:12, 170711.55it/s]
65%|██████▍ | 3889077/6000000 [00:22<00:12, 170294.94it/s]
65%|██████▌ | 3906108/6000000 [00:23<00:12, 169916.28it/s]
65%|██████▌ | 3923234/6000000 [00:23<00:12, 170315.62it/s]
66%|██████▌ | 3940267/6000000 [00:23<00:12, 169141.13it/s]
66%|██████▌ | 3957184/6000000 [00:23<00:12, 168161.27it/s]
66%|██████▌ | 3974060/6000000 [00:23<00:12, 168337.57it/s]
67%|██████▋ | 3990896/6000000 [00:23<00:11, 167473.48it/s]
67%|██████▋ | 4007968/6000000 [00:23<00:11, 168437.71it/s]
67%|██████▋ | 4024837/6000000 [00:23<00:11, 168509.90it/s]
67%|██████▋ | 4041691/6000000 [00:23<00:11, 168515.10it/s]
68%|██████▊ | 4058866/6000000 [00:23<00:11, 169480.27it/s]
68%|██████▊ | 4075875/6000000 [00:24<00:11, 169661.31it/s]
68%|██████▊ | 4092842/6000000 [00:24<00:11, 169499.91it/s]
69%|██████▊ | 4110001/6000000 [00:24<00:11, 169659.04it/s]
69%|██████▉ | 4127230/6000000 [00:24<00:10, 170442.43it/s]
69%|██████▉ | 4144275/6000000 [00:24<00:10, 169737.15it/s]
69%|██████▉ | 4161250/6000000 [00:24<00:10, 169427.96it/s]
70%|██████▉ | 4178530/6000000 [00:24<00:10, 170431.90it/s]
70%|██████▉ | 4195575/6000000 [00:24<00:10, 169611.46it/s]
70%|███████ | 4212538/6000000 [00:24<00:10, 169260.20it/s]
70%|███████ | 4229465/6000000 [00:24<00:10, 169196.49it/s]
71%|███████ | 4246386/6000000 [00:25<00:10, 168617.88it/s]
71%|███████ | 4263342/6000000 [00:25<00:10, 168896.44it/s]
71%|███████▏ | 4280260/6000000 [00:25<00:10, 168978.35it/s]
72%|███████▏ | 4297159/6000000 [00:25<00:10, 168157.73it/s]
72%|███████▏ | 4314002/6000000 [00:25<00:10, 168230.54it/s]
72%|███████▏ | 4331169/6000000 [00:25<00:09, 169256.55it/s]
72%|███████▏ | 4348096/6000000 [00:25<00:09, 168670.89it/s]
73%|███████▎ | 4365001/6000000 [00:25<00:09, 168577.92it/s]
73%|███████▎ | 4382327/6000000 [00:25<00:09, 169975.62it/s]
73%|███████▎ | 4399375/6000000 [00:25<00:09, 170115.65it/s]
74%|███████▎ | 4416411/6000000 [00:26<00:09, 170186.56it/s]
74%|███████▍ | 4433560/6000000 [00:26<00:09, 170575.28it/s]
74%|███████▍ | 4450618/6000000 [00:26<00:09, 169946.11it/s]
74%|███████▍ | 4467614/6000000 [00:26<00:09, 169490.30it/s]
75%|███████▍ | 4484659/6000000 [00:26<00:08, 169774.76it/s]
75%|███████▌ | 4501750/6000000 [00:26<00:08, 170110.13it/s]
75%|███████▌ | 4518762/6000000 [00:26<00:08, 169530.37it/s]
76%|███████▌ | 4535784/6000000 [00:26<00:08, 169733.13it/s]
76%|███████▌ | 4552758/6000000 [00:26<00:08, 169682.58it/s]
76%|███████▌ | 4569727/6000000 [00:26<00:08, 169063.77it/s]
76%|███████▋ | 4586907/6000000 [00:27<00:08, 169878.10it/s]
77%|███████▋ | 4603896/6000000 [00:27<00:08, 169559.54it/s]
77%|███████▋ | 4620853/6000000 [00:27<00:08, 169295.64it/s]
77%|███████▋ | 4638001/6000000 [00:27<00:08, 169381.29it/s]
78%|███████▊ | 4654940/6000000 [00:27<00:07, 169291.87it/s]
78%|███████▊ | 4671888/6000000 [00:27<00:07, 169346.46it/s]
78%|███████▊ | 4688958/6000000 [00:27<00:07, 169750.42it/s]
78%|███████▊ | 4705934/6000000 [00:27<00:07, 168989.62it/s]
79%|███████▊ | 4722834/6000000 [00:27<00:07, 168863.63it/s]
79%|███████▉ | 4739753/6000000 [00:27<00:07, 168958.94it/s]
79%|███████▉ | 4756657/6000000 [00:28<00:07, 168980.23it/s]
80%|███████▉ | 4773627/6000000 [00:28<00:07, 169193.77it/s]
80%|███████▉ | 4790815/6000000 [00:28<00:07, 169997.32it/s]
80%|████████ | 4807832/6000000 [00:28<00:07, 170046.94it/s]
80%|████████ | 4824837/6000000 [00:28<00:06, 169911.76it/s]
81%|████████ | 4841921/6000000 [00:28<00:06, 170188.05it/s]
81%|████████ | 4858940/6000000 [00:28<00:06, 169483.78it/s]
81%|████████▏ | 4875890/6000000 [00:28<00:06, 169117.99it/s]
82%|████████▏ | 4892901/6000000 [00:28<00:06, 169411.09it/s]
82%|████████▏ | 4909843/6000000 [00:28<00:06, 169082.99it/s]
82%|████████▏ | 4926981/6000000 [00:29<00:06, 169768.42it/s]
82%|████████▏ | 4944184/6000000 [00:29<00:06, 170441.48it/s]
83%|████████▎ | 4961402/6000000 [00:29<00:06, 170959.91it/s]
83%|████████▎ | 4978798/6000000 [00:29<00:05, 171857.35it/s]
83%|████████▎ | 4996092/6000000 [00:29<00:05, 172180.33it/s]
84%|████████▎ | 5013311/6000000 [00:29<00:05, 171644.02it/s]
84%|████████▍ | 5030559/6000000 [00:29<00:05, 171892.98it/s]
84%|████████▍ | 5047749/6000000 [00:29<00:05, 171512.96it/s]
84%|████████▍ | 5064901/6000000 [00:29<00:05, 171097.65it/s]
85%|████████▍ | 5082012/6000000 [00:29<00:05, 170514.73it/s]
85%|████████▍ | 5099065/6000000 [00:30<00:05, 170464.50it/s]
85%|████████▌ | 5116112/6000000 [00:30<00:05, 170132.19it/s]
86%|████████▌ | 5133148/6000000 [00:30<00:05, 170198.00it/s]
86%|████████▌ | 5150444/6000000 [00:30<00:04, 171023.36it/s]
86%|████████▌ | 5167547/6000000 [00:30<00:04, 170498.75it/s]
86%|████████▋ | 5184598/6000000 [00:30<00:04, 169584.05it/s]
87%|████████▋ | 5201917/6000000 [00:30<00:04, 170658.04it/s]
87%|████████▋ | 5219021/6000000 [00:30<00:04, 170768.16it/s]
87%|████████▋ | 5236099/6000000 [00:30<00:04, 170283.52it/s]
88%|████████▊ | 5253129/6000000 [00:30<00:04, 169881.92it/s]
88%|████████▊ | 5270501/6000000 [00:31<00:04, 171024.46it/s]
88%|████████▊ | 5287605/6000000 [00:31<00:04, 170220.67it/s]
88%|████████▊ | 5304629/6000000 [00:31<00:04, 169424.75it/s]
89%|████████▊ | 5321574/6000000 [00:31<00:04, 169429.07it/s]
89%|████████▉ | 5338518/6000000 [00:31<00:03, 169341.28it/s]
89%|████████▉ | 5355453/6000000 [00:31<00:03, 168688.82it/s]
90%|████████▉ | 5372366/6000000 [00:31<00:03, 168817.62it/s]
90%|████████▉ | 5389249/6000000 [00:31<00:03, 168482.24it/s]
90%|█████████ | 5406098/6000000 [00:31<00:03, 167536.32it/s]
90%|█████████ | 5423055/6000000 [00:31<00:03, 168141.01it/s]
91%|█████████ | 5439871/6000000 [00:32<00:03, 167578.20it/s]
91%|█████████ | 5456782/6000000 [00:32<00:03, 168031.81it/s]
91%|█████████ | 5473587/6000000 [00:32<00:03, 167705.17it/s]
92%|█████████▏| 5490359/6000000 [00:32<00:03, 167472.51it/s]
92%|█████████▏| 5507422/6000000 [00:32<00:02, 168414.40it/s]
92%|█████████▏| 5524265/6000000 [00:32<00:02, 167835.63it/s]
92%|█████████▏| 5541050/6000000 [00:32<00:02, 167259.07it/s]
93%|█████████▎| 5557997/6000000 [00:32<00:02, 167916.73it/s]
93%|█████████▎| 5574790/6000000 [00:32<00:02, 167888.92it/s]
93%|█████████▎| 5592001/6000000 [00:32<00:02, 168748.76it/s]
93%|█████████▎| 5609312/6000000 [00:33<00:02, 170049.62it/s]
94%|█████████▍| 5626318/6000000 [00:33<00:02, 169967.29it/s]
94%|█████████▍| 5643457/6000000 [00:33<00:02, 170391.65it/s]
94%|█████████▍| 5660500/6000000 [00:33<00:01, 170401.24it/s]
95%|█████████▍| 5677541/6000000 [00:33<00:01, 169630.05it/s]
95%|█████████▍| 5694604/6000000 [00:33<00:01, 169924.55it/s]
95%|█████████▌| 5712001/6000000 [00:33<00:01, 170577.63it/s]
95%|█████████▌| 5729059/6000000 [00:33<00:01, 170557.73it/s]
96%|█████████▌| 5746115/6000000 [00:33<00:01, 170108.92it/s]
96%|█████████▌| 5763127/6000000 [00:33<00:01, 169892.12it/s]
96%|█████████▋| 5780398/6000000 [00:34<00:01, 170731.86it/s]
97%|█████████▋| 5797508/6000000 [00:34<00:01, 170839.14it/s]
97%|█████████▋| 5814593/6000000 [00:34<00:01, 169889.09it/s]
97%|█████████▋| 5831584/6000000 [00:34<00:00, 169350.59it/s]
97%|█████████▋| 5848552/6000000 [00:34<00:00, 169447.08it/s]
98%|█████████▊| 5865498/6000000 [00:34<00:00, 168606.57it/s]
98%|█████████▊| 5882756/6000000 [00:34<00:00, 169787.66it/s]
98%|█████████▊| 5899805/6000000 [00:34<00:00, 169996.13it/s]
99%|█████████▊| 5916923/6000000 [00:34<00:00, 170349.13it/s]
99%|█████████▉| 5934001/6000000 [00:34<00:00, 170163.98it/s]
99%|█████████▉| 5951165/6000000 [00:35<00:00, 170602.42it/s]
99%|█████████▉| 5968226/6000000 [00:35<00:00, 170251.09it/s]
100%|█████████▉| 5985252/6000000 [00:35<00:00, 169655.33it/s]
100%|██████████| 6000000/6000000 [00:35<00:00, 169658.61it/s]
| episode | step | policy_idx | arm | reward | reward_stat | |
|---|---|---|---|---|---|---|
| 1324896 | 441 | 632 | 0 | 8 | 3.377683 | 1.507971 |
| 3566176 | 1188 | 725 | 1 | 6 | 0.360026 | 1.360048 |
| 1109043 | 369 | 681 | 0 | 1 | 0.774966 | 1.071004 |
| 4286042 | 1428 | 680 | 2 | 0 | 1.215224 | 1.738224 |
| 5395174 | 1798 | 391 | 1 | 0 | 1.958754 | 0.912307 |
It is then straightforward to plot the average reward each policy obtains at each step, by averaging over episodes.
policy_names = {
0: 'ε = 0.1',
1: 'ε = 0.01',
2: 'ε = 0 (greedy)'
}
(
trace_df
.assign(policy=trace_df.policy_idx.map(policy_names))
.groupby(['step', 'policy'])
['reward'].mean()
.unstack()
.plot()
)
<Axes: xlabel='step'>

Controlling the evaluation loop¶
The bandit.evaluate function is useful for benchmarking. But in practice, you'll want to have control over your bandit policy. Indeed you'll want the freedom to pull arms (with the pull method) and update the policy (with the update method) at your discretion.
As an example, the following is a possible reimplementation of the bandit.evaluate function. Here we'll be measuring the rate at which each policy selects the optimal arm.
Note how the pull and update methods are used.
import copy
policies=[
bandit.EpsilonGreedy(epsilon=0.1),
bandit.EpsilonGreedy(epsilon=0.01),
bandit.EpsilonGreedy(epsilon=0),
]
env = gym.make(
'river_bandits/KArmedTestbed-v0',
max_episode_steps=1000
)
n_episodes = 2000
trace = []
with tqdm(total=len(policies) * n_episodes * env._max_episode_steps, position=0) as progress:
for policy in policies:
for episode in range(n_episodes):
episode_policy = policy.clone()
episode_env = copy.deepcopy(env)
episode_env.reset()
step = 0
while True:
action = episode_policy.pull(range(episode_env.action_space.n))
observation, reward, terminated, truncated, info = episode_env.step(action)
best_action = observation
episode_policy.update(action, reward)
trace.append({
"episode": episode,
"step": step,
"policy": f"ε = {policy.epsilon}",
"is_action_optimal": action == best_action
})
step += 1
progress.update()
if terminated or truncated:
break
trace_df = pd.DataFrame(trace)
0%| | 0/6000000 [00:00<?, ?it/s]
0%| | 17267/6000000 [00:00<00:34, 172669.34it/s]
1%| | 35009/6000000 [00:00<00:33, 175460.99it/s]
1%| | 52825/6000000 [00:00<00:33, 176691.36it/s]
1%| | 70502/6000000 [00:00<00:33, 176722.31it/s]
1%|▏ | 88175/6000000 [00:00<00:33, 176396.81it/s]
2%|▏ | 106003/6000000 [00:00<00:33, 177033.49it/s]
2%|▏ | 123772/6000000 [00:00<00:33, 177245.37it/s]
2%|▏ | 141522/6000000 [00:00<00:33, 177324.02it/s]
3%|▎ | 159255/6000000 [00:00<00:33, 172702.51it/s]
3%|▎ | 176968/6000000 [00:01<00:33, 174041.18it/s]
3%|▎ | 194665/6000000 [00:01<00:33, 174925.80it/s]
4%|▎ | 212429/6000000 [00:01<00:32, 175742.33it/s]
4%|▍ | 230237/6000000 [00:01<00:32, 176445.07it/s]
4%|▍ | 247998/6000000 [00:01<00:32, 176793.03it/s]
4%|▍ | 265808/6000000 [00:01<00:32, 177185.06it/s]
5%|▍ | 283531/6000000 [00:01<00:32, 177055.59it/s]
5%|▌ | 301248/6000000 [00:01<00:32, 177088.96it/s]
5%|▌ | 318959/6000000 [00:01<00:32, 177054.54it/s]
6%|▌ | 336699/6000000 [00:01<00:31, 177158.10it/s]
6%|▌ | 354416/6000000 [00:02<00:31, 177152.47it/s]
6%|▌ | 372132/6000000 [00:02<00:31, 177072.59it/s]
6%|▋ | 389987/6000000 [00:02<00:31, 177515.57it/s]
7%|▋ | 407739/6000000 [00:02<00:31, 177215.99it/s]
7%|▋ | 425461/6000000 [00:02<00:31, 177143.47it/s]
7%|▋ | 443186/6000000 [00:02<00:31, 177175.06it/s]
8%|▊ | 460938/6000000 [00:02<00:31, 177277.62it/s]
8%|▊ | 478666/6000000 [00:02<00:31, 177016.58it/s]
8%|▊ | 496379/6000000 [00:02<00:31, 177048.46it/s]
9%|▊ | 514097/6000000 [00:02<00:30, 177085.91it/s]
9%|▉ | 531853/6000000 [00:03<00:30, 177226.81it/s]
9%|▉ | 549576/6000000 [00:03<00:30, 177089.19it/s]
9%|▉ | 567415/6000000 [00:03<00:30, 177478.12it/s]
10%|▉ | 585185/6000000 [00:03<00:30, 177543.17it/s]
10%|█ | 602998/6000000 [00:03<00:30, 177718.93it/s]
10%|█ | 620775/6000000 [00:03<00:30, 177732.14it/s]
11%|█ | 638549/6000000 [00:03<00:30, 177714.98it/s]
11%|█ | 656321/6000000 [00:03<00:30, 177663.68it/s]
11%|█ | 674163/6000000 [00:03<00:29, 177889.40it/s]
12%|█▏ | 691952/6000000 [00:03<00:29, 177888.49it/s]
12%|█▏ | 709758/6000000 [00:04<00:29, 177937.95it/s]
12%|█▏ | 727552/6000000 [00:04<00:29, 177710.74it/s]
12%|█▏ | 745324/6000000 [00:04<00:29, 177653.40it/s]
13%|█▎ | 763090/6000000 [00:04<00:29, 177543.02it/s]
13%|█▎ | 780845/6000000 [00:04<00:29, 176345.57it/s]
13%|█▎ | 798482/6000000 [00:04<00:29, 174917.82it/s]
14%|█▎ | 815977/6000000 [00:04<00:29, 174004.28it/s]
14%|█▍ | 833380/6000000 [00:04<00:29, 173208.20it/s]
14%|█▍ | 850703/6000000 [00:04<00:29, 173013.17it/s]
14%|█▍ | 868006/6000000 [00:04<00:29, 172495.48it/s]
15%|█▍ | 885258/6000000 [00:05<00:29, 172501.36it/s]
15%|█▌ | 902509/6000000 [00:05<00:29, 172338.98it/s]
15%|█▌ | 919777/6000000 [00:05<00:29, 172438.91it/s]
16%|█▌ | 937022/6000000 [00:05<00:29, 172352.94it/s]
16%|█▌ | 954258/6000000 [00:05<00:29, 169395.22it/s]
16%|█▌ | 971542/6000000 [00:05<00:29, 170412.18it/s]
16%|█▋ | 988837/6000000 [00:05<00:29, 171164.44it/s]
17%|█▋ | 1005978/6000000 [00:05<00:29, 171236.80it/s]
17%|█▋ | 1023215/6000000 [00:05<00:29, 171574.92it/s]
17%|█▋ | 1040446/6000000 [00:05<00:28, 171794.26it/s]
18%|█▊ | 1057692/6000000 [00:06<00:28, 171991.10it/s]
18%|█▊ | 1074970/6000000 [00:06<00:28, 172226.69it/s]
18%|█▊ | 1092195/6000000 [00:06<00:28, 172211.95it/s]
18%|█▊ | 1109418/6000000 [00:06<00:28, 172172.74it/s]
19%|█▉ | 1126636/6000000 [00:06<00:28, 172131.91it/s]
19%|█▉ | 1143857/6000000 [00:06<00:28, 172153.40it/s]
19%|█▉ | 1161073/6000000 [00:06<00:28, 171868.42it/s]
20%|█▉ | 1178291/6000000 [00:06<00:28, 171959.10it/s]
20%|█▉ | 1195488/6000000 [00:06<00:27, 171872.52it/s]
20%|██ | 1212676/6000000 [00:06<00:27, 171619.78it/s]
20%|██ | 1229844/6000000 [00:07<00:27, 171635.29it/s]
21%|██ | 1247008/6000000 [00:07<00:27, 171429.79it/s]
21%|██ | 1264303/6000000 [00:07<00:27, 171882.53it/s]
21%|██▏ | 1281539/6000000 [00:07<00:27, 172023.31it/s]
22%|██▏ | 1298742/6000000 [00:07<00:27, 170461.11it/s]
22%|██▏ | 1315971/6000000 [00:07<00:27, 171005.02it/s]
22%|██▏ | 1333159/6000000 [00:07<00:27, 171265.79it/s]
23%|██▎ | 1350472/6000000 [00:07<00:27, 171822.68it/s]
23%|██▎ | 1367721/6000000 [00:07<00:26, 172019.99it/s]
23%|██▎ | 1384953/6000000 [00:07<00:26, 172108.91it/s]
23%|██▎ | 1402165/6000000 [00:08<00:26, 172013.72it/s]
24%|██▎ | 1419440/6000000 [00:08<00:26, 172233.03it/s]
24%|██▍ | 1436783/6000000 [00:08<00:26, 172589.88it/s]
24%|██▍ | 1454043/6000000 [00:08<00:26, 172496.91it/s]
25%|██▍ | 1471299/6000000 [00:08<00:26, 172514.52it/s]
25%|██▍ | 1488771/6000000 [00:08<00:26, 173173.35it/s]
25%|██▌ | 1506089/6000000 [00:08<00:26, 172720.46it/s]
25%|██▌ | 1523362/6000000 [00:08<00:26, 171939.20it/s]
26%|██▌ | 1540557/6000000 [00:08<00:25, 171634.56it/s]
26%|██▌ | 1557855/6000000 [00:08<00:25, 172034.57it/s]
26%|██▋ | 1575060/6000000 [00:09<00:25, 171332.56it/s]
27%|██▋ | 1592195/6000000 [00:09<00:25, 171320.38it/s]
27%|██▋ | 1609328/6000000 [00:09<00:25, 170975.07it/s]
27%|██▋ | 1626571/6000000 [00:09<00:25, 171408.03it/s]
27%|██▋ | 1643886/6000000 [00:09<00:25, 171927.49it/s]
28%|██▊ | 1661086/6000000 [00:09<00:25, 171949.13it/s]
28%|██▊ | 1678383/6000000 [00:09<00:25, 172253.54it/s]
28%|██▊ | 1695609/6000000 [00:09<00:25, 172072.94it/s]
29%|██▊ | 1712982/6000000 [00:09<00:24, 172567.80it/s]
29%|██▉ | 1730263/6000000 [00:09<00:24, 172637.84it/s]
29%|██▉ | 1747527/6000000 [00:10<00:24, 172613.62it/s]
29%|██▉ | 1764789/6000000 [00:10<00:24, 172429.83it/s]
30%|██▉ | 1782033/6000000 [00:10<00:24, 172047.72it/s]
30%|██▉ | 1799272/6000000 [00:10<00:24, 172148.61it/s]
30%|███ | 1816560/6000000 [00:10<00:24, 172365.49it/s]
31%|███ | 1833883/6000000 [00:10<00:24, 172622.12it/s]
31%|███ | 1851146/6000000 [00:10<00:24, 172497.19it/s]
31%|███ | 1868553/6000000 [00:10<00:23, 172966.46it/s]
31%|███▏ | 1885986/6000000 [00:10<00:23, 173374.76it/s]
32%|███▏ | 1903324/6000000 [00:10<00:23, 172999.31it/s]
32%|███▏ | 1920625/6000000 [00:11<00:23, 172995.39it/s]
32%|███▏ | 1937925/6000000 [00:11<00:23, 172977.05it/s]
33%|███▎ | 1955223/6000000 [00:11<00:23, 172590.27it/s]
33%|███▎ | 1972534/6000000 [00:11<00:23, 172744.53it/s]
33%|███▎ | 1989817/6000000 [00:11<00:23, 172769.81it/s]
33%|███▎ | 2007095/6000000 [00:11<00:23, 171835.58it/s]
34%|███▎ | 2024280/6000000 [00:11<00:23, 169712.12it/s]
34%|███▍ | 2041258/6000000 [00:11<00:23, 167972.95it/s]
34%|███▍ | 2058062/6000000 [00:11<00:23, 167033.48it/s]
35%|███▍ | 2074770/6000000 [00:11<00:23, 166517.68it/s]
35%|███▍ | 2091425/6000000 [00:12<00:23, 165572.88it/s]
35%|███▌ | 2107985/6000000 [00:12<00:23, 165213.07it/s]
35%|███▌ | 2124508/6000000 [00:12<00:23, 165204.70it/s]
36%|███▌ | 2141030/6000000 [00:12<00:23, 164975.43it/s]
36%|███▌ | 2157528/6000000 [00:12<00:23, 164900.89it/s]
36%|███▌ | 2174120/6000000 [00:12<00:23, 165202.88it/s]
37%|███▋ | 2190641/6000000 [00:12<00:23, 164946.95it/s]
37%|███▋ | 2207136/6000000 [00:12<00:23, 164475.08it/s]
37%|███▋ | 2223584/6000000 [00:12<00:22, 164328.80it/s]
37%|███▋ | 2240187/6000000 [00:12<00:22, 164835.43it/s]
38%|███▊ | 2256888/6000000 [00:13<00:22, 165483.20it/s]
38%|███▊ | 2273437/6000000 [00:13<00:22, 165403.80it/s]
38%|███▊ | 2290033/6000000 [00:13<00:22, 165568.07it/s]
38%|███▊ | 2306646/6000000 [00:13<00:22, 165735.67it/s]
39%|███▊ | 2323220/6000000 [00:13<00:22, 165736.07it/s]
39%|███▉ | 2339794/6000000 [00:13<00:22, 165684.59it/s]
39%|███▉ | 2356363/6000000 [00:13<00:22, 165011.69it/s]
40%|███▉ | 2372929/6000000 [00:13<00:21, 165204.65it/s]
40%|███▉ | 2389474/6000000 [00:13<00:21, 165276.48it/s]
40%|████ | 2406074/6000000 [00:13<00:21, 165492.05it/s]
40%|████ | 2422660/6000000 [00:14<00:21, 165600.70it/s]
41%|████ | 2439221/6000000 [00:14<00:21, 165288.25it/s]
41%|████ | 2455874/6000000 [00:14<00:21, 165659.88it/s]
41%|████ | 2472505/6000000 [00:14<00:21, 165853.05it/s]
41%|████▏ | 2489091/6000000 [00:14<00:21, 165436.68it/s]
42%|████▏ | 2505715/6000000 [00:14<00:21, 165676.91it/s]
42%|████▏ | 2522283/6000000 [00:14<00:21, 164948.23it/s]
42%|████▏ | 2538859/6000000 [00:14<00:20, 165189.12it/s]
43%|████▎ | 2555379/6000000 [00:14<00:20, 165176.29it/s]
43%|████▎ | 2571904/6000000 [00:14<00:20, 165196.98it/s]
43%|████▎ | 2588439/6000000 [00:15<00:20, 165242.16it/s]
43%|████▎ | 2604964/6000000 [00:15<00:20, 164676.85it/s]
44%|████▎ | 2621469/6000000 [00:15<00:20, 164788.37it/s]
44%|████▍ | 2638001/6000000 [00:15<00:20, 164833.80it/s]
44%|████▍ | 2654485/6000000 [00:15<00:20, 164721.30it/s]
45%|████▍ | 2670958/6000000 [00:15<00:20, 164478.54it/s]
45%|████▍ | 2687407/6000000 [00:15<00:20, 163991.95it/s]
45%|████▌ | 2703925/6000000 [00:15<00:20, 164345.20it/s]
45%|████▌ | 2720557/6000000 [00:15<00:19, 164933.42it/s]
46%|████▌ | 2737051/6000000 [00:15<00:19, 164812.95it/s]
46%|████▌ | 2753887/6000000 [00:16<00:19, 165872.57it/s]
46%|████▌ | 2770532/6000000 [00:16<00:19, 166043.61it/s]
46%|████▋ | 2787193/6000000 [00:16<00:19, 166212.35it/s]
47%|████▋ | 2803851/6000000 [00:16<00:19, 166319.89it/s]
47%|████▋ | 2820507/6000000 [00:16<00:19, 166390.63it/s]
47%|████▋ | 2837147/6000000 [00:16<00:19, 166325.22it/s]
48%|████▊ | 2853780/6000000 [00:16<00:19, 165460.26it/s]
48%|████▊ | 2870328/6000000 [00:16<00:18, 165298.76it/s]
48%|████▊ | 2886859/6000000 [00:16<00:18, 165256.23it/s]
48%|████▊ | 2903543/6000000 [00:16<00:18, 165728.02it/s]
49%|████▊ | 2920117/6000000 [00:17<00:18, 165385.47it/s]
49%|████▉ | 2936656/6000000 [00:17<00:18, 165235.54it/s]
49%|████▉ | 2953197/6000000 [00:17<00:18, 165285.20it/s]
49%|████▉ | 2969726/6000000 [00:17<00:18, 164864.00it/s]
50%|████▉ | 2986213/6000000 [00:17<00:18, 164848.92it/s]
50%|█████ | 3002747/6000000 [00:17<00:18, 164994.67it/s]
50%|█████ | 3019247/6000000 [00:17<00:18, 163836.93it/s]
51%|█████ | 3035858/6000000 [00:17<00:18, 164512.11it/s]
51%|█████ | 3052311/6000000 [00:17<00:18, 163371.03it/s]
51%|█████ | 3068870/6000000 [00:17<00:17, 164028.98it/s]
51%|█████▏ | 3085276/6000000 [00:18<00:17, 163945.81it/s]
52%|█████▏ | 3101894/6000000 [00:18<00:17, 164611.85it/s]
52%|█████▏ | 3118531/6000000 [00:18<00:17, 165136.82it/s]
52%|█████▏ | 3135065/6000000 [00:18<00:17, 165197.58it/s]
53%|█████▎ | 3151637/6000000 [00:18<00:17, 165353.41it/s]
53%|█████▎ | 3168173/6000000 [00:18<00:17, 165329.68it/s]
53%|█████▎ | 3184707/6000000 [00:18<00:17, 164815.26it/s]
53%|█████▎ | 3201293/6000000 [00:18<00:16, 165125.15it/s]
54%|█████▎ | 3217979/6000000 [00:18<00:16, 165643.45it/s]
54%|█████▍ | 3234544/6000000 [00:18<00:16, 165252.17it/s]
54%|█████▍ | 3251070/6000000 [00:19<00:16, 164946.15it/s]
54%|█████▍ | 3267566/6000000 [00:19<00:16, 164893.55it/s]
55%|█████▍ | 3284270/6000000 [00:19<00:16, 165535.61it/s]
55%|█████▌ | 3300824/6000000 [00:19<00:16, 165257.78it/s]
55%|█████▌ | 3317351/6000000 [00:19<00:16, 164849.43it/s]
56%|█████▌ | 3333894/6000000 [00:19<00:16, 165022.23it/s]
56%|█████▌ | 3350397/6000000 [00:19<00:16, 164821.38it/s]
56%|█████▌ | 3366996/6000000 [00:19<00:15, 165170.89it/s]
56%|█████▋ | 3383514/6000000 [00:19<00:15, 165150.55it/s]
57%|█████▋ | 3400030/6000000 [00:19<00:15, 164570.14it/s]
57%|█████▋ | 3416596/6000000 [00:20<00:15, 164895.09it/s]
57%|█████▋ | 3433086/6000000 [00:20<00:15, 164854.61it/s]
57%|█████▋ | 3449662/6000000 [00:20<00:15, 165124.36it/s]
58%|█████▊ | 3466175/6000000 [00:20<00:15, 165002.69it/s]
58%|█████▊ | 3482676/6000000 [00:20<00:15, 164561.59it/s]
58%|█████▊ | 3499247/6000000 [00:20<00:15, 164903.99it/s]
59%|█████▊ | 3515777/6000000 [00:20<00:15, 165020.12it/s]
59%|█████▉ | 3532280/6000000 [00:20<00:14, 164518.33it/s]
59%|█████▉ | 3548760/6000000 [00:20<00:14, 164600.38it/s]
59%|█████▉ | 3565221/6000000 [00:20<00:14, 164601.43it/s]
60%|█████▉ | 3581690/6000000 [00:21<00:14, 164625.89it/s]
60%|█████▉ | 3598153/6000000 [00:21<00:14, 164344.31it/s]
60%|██████ | 3614588/6000000 [00:21<00:14, 163937.05it/s]
61%|██████ | 3630983/6000000 [00:21<00:14, 163682.81it/s]
61%|██████ | 3647509/6000000 [00:21<00:14, 164154.11it/s]
61%|██████ | 3664091/6000000 [00:21<00:14, 164652.18it/s]
61%|██████▏ | 3680691/6000000 [00:21<00:14, 165054.62it/s]
62%|██████▏ | 3697197/6000000 [00:21<00:13, 164810.78it/s]
62%|██████▏ | 3713679/6000000 [00:21<00:13, 164700.23it/s]
62%|██████▏ | 3730150/6000000 [00:21<00:13, 163863.75it/s]
62%|██████▏ | 3746538/6000000 [00:22<00:14, 158095.60it/s]
63%|██████▎ | 3762390/6000000 [00:22<00:14, 153017.33it/s]
63%|██████▎ | 3778901/6000000 [00:22<00:14, 156487.57it/s]
63%|██████▎ | 3795164/6000000 [00:22<00:13, 158272.39it/s]
64%|██████▎ | 3811368/6000000 [00:22<00:13, 159377.51it/s]
64%|██████▍ | 3827722/6000000 [00:22<00:13, 160605.00it/s]
64%|██████▍ | 3844098/6000000 [00:22<00:13, 161538.93it/s]
64%|██████▍ | 3860662/6000000 [00:22<00:13, 162759.49it/s]
65%|██████▍ | 3877001/6000000 [00:22<00:13, 162936.21it/s]
65%|██████▍ | 3893451/6000000 [00:23<00:12, 163403.41it/s]
65%|██████▌ | 3909916/6000000 [00:23<00:12, 163776.33it/s]
65%|██████▌ | 3926363/6000000 [00:23<00:12, 163983.16it/s]
66%|██████▌ | 3942875/6000000 [00:23<00:12, 164323.47it/s]
66%|██████▌ | 3959310/6000000 [00:23<00:12, 163368.33it/s]
66%|██████▋ | 3975650/6000000 [00:23<00:12, 162954.42it/s]
67%|██████▋ | 3991969/6000000 [00:23<00:12, 163022.34it/s]
67%|██████▋ | 4008273/6000000 [00:23<00:12, 162485.15it/s]
67%|██████▋ | 4024743/6000000 [00:23<00:12, 163144.57it/s]
67%|██████▋ | 4041222/6000000 [00:23<00:11, 163635.97it/s]
68%|██████▊ | 4057908/6000000 [00:24<00:11, 164599.62it/s]
68%|██████▊ | 4074689/6000000 [00:24<00:11, 165560.33it/s]
68%|██████▊ | 4091326/6000000 [00:24<00:11, 165800.86it/s]
68%|██████▊ | 4108271/6000000 [00:24<00:11, 166892.71it/s]
69%|██████▉ | 4125008/6000000 [00:24<00:11, 167035.74it/s]
69%|██████▉ | 4141722/6000000 [00:24<00:11, 167065.50it/s]
69%|██████▉ | 4158480/6000000 [00:24<00:11, 167218.89it/s]
70%|██████▉ | 4175274/6000000 [00:24<00:10, 167434.27it/s]
70%|██████▉ | 4192018/6000000 [00:24<00:10, 167179.94it/s]
70%|███████ | 4208812/6000000 [00:24<00:10, 167405.59it/s]
70%|███████ | 4225553/6000000 [00:25<00:10, 166528.30it/s]
71%|███████ | 4242310/6000000 [00:25<00:10, 166838.10it/s]
71%|███████ | 4259001/6000000 [00:25<00:10, 166719.72it/s]
71%|███████▏ | 4275795/6000000 [00:25<00:10, 167082.60it/s]
72%|███████▏ | 4292504/6000000 [00:25<00:10, 166855.92it/s]
72%|███████▏ | 4309479/6000000 [00:25<00:10, 167720.30it/s]
72%|███████▏ | 4326252/6000000 [00:25<00:10, 164821.24it/s]
72%|███████▏ | 4343025/6000000 [00:25<00:10, 165680.44it/s]
73%|███████▎ | 4360001/6000000 [00:25<00:09, 166761.52it/s]
73%|███████▎ | 4376872/6000000 [00:25<00:09, 167340.05it/s]
73%|███████▎ | 4393612/6000000 [00:26<00:09, 167038.92it/s]
74%|███████▎ | 4410606/6000000 [00:26<00:09, 167894.93it/s]
74%|███████▍ | 4427435/6000000 [00:26<00:09, 168011.53it/s]
74%|███████▍ | 4444239/6000000 [00:26<00:09, 167432.08it/s]
74%|███████▍ | 4460985/6000000 [00:26<00:09, 167328.79it/s]
75%|███████▍ | 4477720/6000000 [00:26<00:09, 167168.70it/s]
75%|███████▍ | 4494438/6000000 [00:26<00:09, 166783.23it/s]
75%|███████▌ | 4511118/6000000 [00:26<00:08, 166112.66it/s]
75%|███████▌ | 4527843/6000000 [00:26<00:08, 166449.79it/s]
76%|███████▌ | 4544489/6000000 [00:26<00:08, 165864.48it/s]
76%|███████▌ | 4561077/6000000 [00:27<00:08, 165363.55it/s]
76%|███████▋ | 4577852/6000000 [00:27<00:08, 166073.05it/s]
77%|███████▋ | 4594461/6000000 [00:27<00:08, 165931.00it/s]
77%|███████▋ | 4611173/6000000 [00:27<00:08, 166284.62it/s]
77%|███████▋ | 4628005/6000000 [00:27<00:08, 166893.08it/s]
77%|███████▋ | 4644695/6000000 [00:27<00:08, 166267.17it/s]
78%|███████▊ | 4661323/6000000 [00:27<00:08, 166027.59it/s]
78%|███████▊ | 4677945/6000000 [00:27<00:07, 166083.85it/s]
78%|███████▊ | 4694558/6000000 [00:27<00:07, 166094.39it/s]
79%|███████▊ | 4711168/6000000 [00:27<00:07, 165503.90it/s]
79%|███████▉ | 4727719/6000000 [00:28<00:07, 165441.76it/s]
79%|███████▉ | 4744273/6000000 [00:28<00:07, 165468.71it/s]
79%|███████▉ | 4760821/6000000 [00:28<00:07, 165270.77it/s]
80%|███████▉ | 4777354/6000000 [00:28<00:07, 165288.31it/s]
80%|███████▉ | 4794001/6000000 [00:28<00:07, 165473.20it/s]
80%|████████ | 4810703/6000000 [00:28<00:07, 165935.96it/s]
80%|████████ | 4827297/6000000 [00:28<00:07, 165621.33it/s]
81%|████████ | 4843968/6000000 [00:28<00:06, 165946.30it/s]
81%|████████ | 4860874/6000000 [00:28<00:06, 166876.61it/s]
81%|████████▏ | 4877562/6000000 [00:28<00:06, 166741.27it/s]
82%|████████▏ | 4894237/6000000 [00:29<00:06, 165898.79it/s]
82%|████████▏ | 4910886/6000000 [00:29<00:06, 166074.22it/s]
82%|████████▏ | 4927521/6000000 [00:29<00:06, 166156.50it/s]
82%|████████▏ | 4944138/6000000 [00:29<00:06, 166152.65it/s]
83%|████████▎ | 4960754/6000000 [00:29<00:06, 165776.42it/s]
83%|████████▎ | 4977333/6000000 [00:29<00:06, 165742.70it/s]
83%|████████▎ | 4993908/6000000 [00:29<00:06, 165621.16it/s]
84%|████████▎ | 5010471/6000000 [00:29<00:05, 164997.71it/s]
84%|████████▍ | 5027046/6000000 [00:29<00:05, 165221.83it/s]
84%|████████▍ | 5043894/6000000 [00:29<00:05, 166194.79it/s]
84%|████████▍ | 5060515/6000000 [00:30<00:05, 165906.40it/s]
85%|████████▍ | 5077107/6000000 [00:30<00:05, 164662.55it/s]
85%|████████▍ | 5093852/6000000 [00:30<00:05, 165490.77it/s]
85%|████████▌ | 5110543/6000000 [00:30<00:05, 165913.04it/s]
85%|████████▌ | 5127136/6000000 [00:30<00:05, 164960.63it/s]
86%|████████▌ | 5143710/6000000 [00:30<00:05, 165192.03it/s]
86%|████████▌ | 5160231/6000000 [00:30<00:05, 164898.29it/s]
86%|████████▋ | 5177011/6000000 [00:30<00:04, 165762.96it/s]
87%|████████▋ | 5193865/6000000 [00:30<00:04, 166592.44it/s]
87%|████████▋ | 5210526/6000000 [00:30<00:04, 165265.06it/s]
87%|████████▋ | 5227168/6000000 [00:31<00:04, 165608.06it/s]
87%|████████▋ | 5243981/6000000 [00:31<00:04, 166358.05it/s]
88%|████████▊ | 5260638/6000000 [00:31<00:04, 166420.39it/s]
88%|████████▊ | 5277282/6000000 [00:31<00:04, 166392.83it/s]
88%|████████▊ | 5294001/6000000 [00:31<00:04, 166599.82it/s]
89%|████████▊ | 5310662/6000000 [00:31<00:04, 166319.84it/s]
89%|████████▉ | 5327295/6000000 [00:31<00:04, 165931.90it/s]
89%|████████▉ | 5344036/6000000 [00:31<00:03, 166371.47it/s]
89%|████████▉ | 5360674/6000000 [00:31<00:03, 165960.13it/s]
90%|████████▉ | 5377353/6000000 [00:31<00:03, 166205.00it/s]
90%|████████▉ | 5394061/6000000 [00:32<00:03, 166465.03it/s]
90%|█████████ | 5410708/6000000 [00:32<00:03, 166161.04it/s]
90%|█████████ | 5427325/6000000 [00:32<00:03, 166117.90it/s]
91%|█████████ | 5444154/6000000 [00:32<00:03, 166765.84it/s]
91%|█████████ | 5461052/6000000 [00:32<00:03, 167427.81it/s]
91%|█████████▏| 5477883/6000000 [00:32<00:03, 167691.19it/s]
92%|█████████▏| 5494665/6000000 [00:32<00:03, 167727.10it/s]
92%|█████████▏| 5511438/6000000 [00:32<00:02, 167594.54it/s]
92%|█████████▏| 5528198/6000000 [00:32<00:02, 167453.34it/s]
92%|█████████▏| 5544944/6000000 [00:32<00:02, 167314.82it/s]
93%|█████████▎| 5561900/6000000 [00:33<00:02, 167986.85it/s]
93%|█████████▎| 5578718/6000000 [00:33<00:02, 168043.47it/s]
93%|█████████▎| 5595523/6000000 [00:33<00:02, 167707.80it/s]
94%|█████████▎| 5612294/6000000 [00:33<00:02, 167409.13it/s]
94%|█████████▍| 5629036/6000000 [00:33<00:02, 167403.12it/s]
94%|█████████▍| 5645792/6000000 [00:33<00:02, 167448.83it/s]
94%|█████████▍| 5662537/6000000 [00:33<00:02, 167365.06it/s]
95%|█████████▍| 5679274/6000000 [00:33<00:01, 166693.68it/s]
95%|█████████▍| 5695953/6000000 [00:33<00:01, 166721.45it/s]
95%|█████████▌| 5712644/6000000 [00:33<00:01, 166775.18it/s]
95%|█████████▌| 5729414/6000000 [00:34<00:01, 167049.90it/s]
96%|█████████▌| 5746160/6000000 [00:34<00:01, 167171.62it/s]
96%|█████████▌| 5762922/6000000 [00:34<00:01, 167305.94it/s]
96%|█████████▋| 5779763/6000000 [00:34<00:01, 167636.16it/s]
97%|█████████▋| 5796527/6000000 [00:34<00:01, 167479.19it/s]
97%|█████████▋| 5813276/6000000 [00:34<00:01, 166916.17it/s]
97%|█████████▋| 5830001/6000000 [00:34<00:01, 166963.08it/s]
97%|█████████▋| 5846711/6000000 [00:34<00:00, 166996.96it/s]
98%|█████████▊| 5863411/6000000 [00:34<00:00, 166630.62it/s]
98%|█████████▊| 5880199/6000000 [00:34<00:00, 167004.22it/s]
98%|█████████▊| 5896924/6000000 [00:35<00:00, 167077.74it/s]
99%|█████████▊| 5913633/6000000 [00:35<00:00, 166965.78it/s]
99%|█████████▉| 5930330/6000000 [00:35<00:00, 166963.87it/s]
99%|█████████▉| 5947027/6000000 [00:35<00:00, 166887.29it/s]
99%|█████████▉| 5963716/6000000 [00:35<00:00, 166777.33it/s]
100%|█████████▉| 5980405/6000000 [00:35<00:00, 166809.59it/s]
100%|█████████▉| 5997087/6000000 [00:35<00:00, 166629.25it/s]
100%|██████████| 6000000/6000000 [00:35<00:00, 168164.58it/s]
colors = {
'ε = 0.1': 'tab:blue',
'ε = 0.01': 'tab:red',
'ε = 0': 'tab:green'
}
(
trace_df
.groupby(['step', 'policy'])
['is_action_optimal'].mean()
.unstack()
.plot()
)
<Axes: xlabel='step'>

Handling drift¶
The environment used above is a toy situation used for introducing bandits. It is stationary, meaning that the expected reward of each arm does not change over time.
In practice, arms are dynamic, and their performance can vary over time. A simple example of this is the Candy Cane Contest that was hosted on Kaggle in 2020. The expected reward of each arm diminishes each time it is pulled.
The way bandit policies in River deal with drift depends on the method. For the bandit.EpsilonGreedy policy, it makes sense to use a rolling average as the reward object. What this means is that the empirical reward the policy calculates for each arm is a rolling average, rather than a global one.
from river import proba, utils
policies=[
bandit.EpsilonGreedy(
epsilon=0.1,
seed=42
),
bandit.EpsilonGreedy(
epsilon=0.3,
reward_obj=utils.Rolling(stats.Mean(), window_size=50),
seed=42
),
bandit.ThompsonSampling(
reward_obj=proba.Beta(),
seed=42
)
]
env = gym.make('river_bandits/CandyCaneContest-v0')
trace = bandit.evaluate(
policies=policies,
env=env,
n_episodes=(n_episodes := 30),
seed=42
)
trace_df = pd.DataFrame(tqdm(
trace, position=0, total=(
n_episodes *
len(policies) *
env._max_episode_steps
)
))
0%| | 0/180000 [00:00<?, ?it/s]
1%| | 1308/180000 [00:00<00:13, 13076.77it/s]
1%|▏ | 2642/180000 [00:00<00:13, 13229.64it/s]
2%|▏ | 3972/180000 [00:00<00:13, 13252.71it/s]
3%|▎ | 5298/180000 [00:00<00:13, 13248.01it/s]
4%|▎ | 6623/180000 [00:00<00:13, 13156.42it/s]
4%|▍ | 7965/180000 [00:00<00:12, 13239.89it/s]
5%|▌ | 9304/180000 [00:00<00:12, 13288.03it/s]
6%|▌ | 10635/180000 [00:00<00:12, 13288.91it/s]
7%|▋ | 11964/180000 [00:00<00:12, 13206.33it/s]
7%|▋ | 13285/180000 [00:01<00:12, 13195.33it/s]
8%|▊ | 14619/180000 [00:01<00:12, 13237.33it/s]
9%|▉ | 15951/180000 [00:01<00:12, 13258.93it/s]
10%|▉ | 17289/180000 [00:01<00:12, 13290.28it/s]
10%|█ | 18619/180000 [00:01<00:12, 13228.85it/s]
11%|█ | 19975/180000 [00:01<00:12, 13327.93it/s]
12%|█▏ | 21322/180000 [00:01<00:11, 13369.64it/s]
13%|█▎ | 22674/180000 [00:01<00:11, 13411.75it/s]
13%|█▎ | 24016/180000 [00:01<00:11, 13099.24it/s]
14%|█▍ | 25416/180000 [00:01<00:11, 13361.46it/s]
15%|█▍ | 26754/180000 [00:02<00:11, 13308.50it/s]
16%|█▌ | 28086/180000 [00:02<00:11, 13279.63it/s]
16%|█▋ | 29415/180000 [00:02<00:11, 13235.23it/s]
17%|█▋ | 30740/180000 [00:02<00:11, 13172.33it/s]
18%|█▊ | 32100/180000 [00:02<00:11, 13294.50it/s]
19%|█▊ | 33430/180000 [00:02<00:11, 13228.23it/s]
19%|█▉ | 34755/180000 [00:02<00:10, 13232.87it/s]
20%|██ | 36079/180000 [00:02<00:11, 12975.11it/s]
21%|██ | 37491/180000 [00:02<00:10, 13311.32it/s]
22%|██▏ | 38844/180000 [00:02<00:10, 13372.76it/s]
22%|██▏ | 40191/180000 [00:03<00:10, 13396.36it/s]
23%|██▎ | 41544/180000 [00:03<00:10, 13431.94it/s]
24%|██▍ | 42888/180000 [00:03<00:10, 13333.39it/s]
25%|██▍ | 44229/180000 [00:03<00:10, 13350.84it/s]
25%|██▌ | 45565/180000 [00:03<00:10, 13348.67it/s]
26%|██▌ | 46901/180000 [00:03<00:09, 13342.05it/s]
27%|██▋ | 48236/180000 [00:03<00:10, 13105.82it/s]
28%|██▊ | 49603/180000 [00:03<00:09, 13270.56it/s]
28%|██▊ | 50931/180000 [00:03<00:09, 13260.49it/s]
29%|██▉ | 52260/180000 [00:03<00:09, 13265.21it/s]
30%|██▉ | 53587/180000 [00:04<00:09, 13200.97it/s]
31%|███ | 54908/180000 [00:04<00:09, 13122.97it/s]
31%|███ | 56231/180000 [00:04<00:09, 13153.78it/s]
32%|███▏ | 57562/180000 [00:04<00:09, 13200.25it/s]
33%|███▎ | 58886/180000 [00:04<00:09, 13211.22it/s]
33%|███▎ | 60208/180000 [00:04<00:09, 13002.90it/s]
34%|███▍ | 61587/180000 [00:04<00:08, 13232.69it/s]
35%|███▍ | 62928/180000 [00:04<00:08, 13282.88it/s]
36%|███▌ | 64275/180000 [00:04<00:08, 13334.26it/s]
36%|███▋ | 65612/180000 [00:04<00:08, 13344.52it/s]
37%|███▋ | 66947/180000 [00:05<00:08, 13233.06it/s]
38%|███▊ | 68286/180000 [00:05<00:08, 13273.76it/s]
39%|███▊ | 69626/180000 [00:05<00:08, 13310.96it/s]
39%|███▉ | 70962/180000 [00:05<00:08, 13325.19it/s]
40%|████ | 72295/180000 [00:05<00:08, 13161.29it/s]
41%|████ | 73654/180000 [00:05<00:08, 13287.65it/s]
42%|████▏ | 74994/180000 [00:05<00:07, 13315.40it/s]
42%|████▏ | 76328/180000 [00:05<00:07, 13321.75it/s]
43%|████▎ | 77661/180000 [00:05<00:07, 13321.56it/s]
44%|████▍ | 78994/180000 [00:05<00:07, 13233.64it/s]
45%|████▍ | 80319/180000 [00:06<00:07, 13233.38it/s]
45%|████▌ | 81645/180000 [00:06<00:07, 13235.49it/s]
46%|████▌ | 82969/180000 [00:06<00:07, 13232.45it/s]
47%|████▋ | 84293/180000 [00:06<00:07, 13077.51it/s]
48%|████▊ | 85665/180000 [00:06<00:07, 13262.91it/s]
48%|████▊ | 86997/180000 [00:06<00:07, 13274.57it/s]
49%|████▉ | 88329/180000 [00:06<00:06, 13282.35it/s]
50%|████▉ | 89658/180000 [00:06<00:06, 13269.67it/s]
51%|█████ | 90986/180000 [00:06<00:06, 13193.21it/s]
51%|█████▏ | 92327/180000 [00:06<00:06, 13257.57it/s]
52%|█████▏ | 93660/180000 [00:07<00:06, 13275.09it/s]
53%|█████▎ | 94999/180000 [00:07<00:06, 13308.21it/s]
54%|█████▎ | 96330/180000 [00:07<00:06, 13153.73it/s]
54%|█████▍ | 97710/180000 [00:07<00:06, 13340.93it/s]
55%|█████▌ | 99055/180000 [00:07<00:06, 13372.35it/s]
56%|█████▌ | 100393/180000 [00:07<00:05, 13350.99it/s]
57%|█████▋ | 101729/180000 [00:07<00:05, 13329.74it/s]
57%|█████▋ | 103063/180000 [00:07<00:05, 13245.47it/s]
58%|█████▊ | 104394/180000 [00:07<00:05, 13260.08it/s]
59%|█████▊ | 105721/180000 [00:07<00:05, 13244.42it/s]
59%|█████▉ | 107046/180000 [00:08<00:05, 13227.42it/s]
60%|██████ | 108369/180000 [00:08<00:05, 13062.78it/s]
61%|██████ | 109734/180000 [00:08<00:05, 13235.32it/s]
62%|██████▏ | 111074/180000 [00:08<00:05, 13283.56it/s]
62%|██████▏ | 112403/180000 [00:08<00:05, 13230.54it/s]
63%|██████▎ | 113733/180000 [00:08<00:05, 13250.47it/s]
64%|██████▍ | 115059/180000 [00:08<00:04, 13181.49it/s]
65%|██████▍ | 116384/180000 [00:08<00:04, 13201.50it/s]
65%|██████▌ | 117714/180000 [00:08<00:04, 13226.77it/s]
66%|██████▌ | 119049/180000 [00:08<00:04, 13260.94it/s]
67%|██████▋ | 120376/180000 [00:09<00:04, 13104.58it/s]
68%|██████▊ | 121737/180000 [00:09<00:04, 13252.22it/s]
68%|██████▊ | 123073/180000 [00:09<00:04, 13283.28it/s]
69%|██████▉ | 124422/180000 [00:09<00:04, 13343.14it/s]
70%|██████▉ | 125769/180000 [00:09<00:04, 13373.90it/s]
71%|███████ | 127107/180000 [00:09<00:03, 13304.58it/s]
71%|███████▏ | 128439/180000 [00:09<00:03, 13307.94it/s]
72%|███████▏ | 129770/180000 [00:09<00:03, 13297.07it/s]
73%|███████▎ | 131100/180000 [00:09<00:03, 13290.64it/s]
74%|███████▎ | 132430/180000 [00:09<00:03, 13157.54it/s]
74%|███████▍ | 133773/180000 [00:10<00:03, 13238.33it/s]
75%|███████▌ | 135098/180000 [00:10<00:03, 13231.94it/s]
76%|███████▌ | 136425/180000 [00:10<00:03, 13237.15it/s]
77%|███████▋ | 137752/180000 [00:10<00:03, 13246.28it/s]
77%|███████▋ | 139077/180000 [00:10<00:03, 13159.76it/s]
78%|███████▊ | 140409/180000 [00:10<00:02, 13206.92it/s]
79%|███████▊ | 141747/180000 [00:10<00:02, 13257.49it/s]
79%|███████▉ | 143077/180000 [00:10<00:02, 13268.96it/s]
80%|████████ | 144404/180000 [00:10<00:02, 13123.07it/s]
81%|████████ | 145760/180000 [00:10<00:02, 13252.38it/s]
82%|████████▏ | 147108/180000 [00:11<00:02, 13317.02it/s]
82%|████████▏ | 148461/180000 [00:11<00:02, 13379.89it/s]
83%|████████▎ | 149816/180000 [00:11<00:02, 13430.16it/s]
84%|████████▍ | 151160/180000 [00:11<00:02, 13368.30it/s]
85%|████████▍ | 152509/180000 [00:11<00:02, 13404.40it/s]
85%|████████▌ | 153851/180000 [00:11<00:01, 13408.72it/s]
86%|████████▌ | 155193/180000 [00:11<00:01, 13411.79it/s]
87%|████████▋ | 156535/180000 [00:11<00:01, 13280.36it/s]
88%|████████▊ | 157890/180000 [00:11<00:01, 13358.11it/s]
88%|████████▊ | 159227/180000 [00:12<00:01, 13284.30it/s]
89%|████████▉ | 160556/180000 [00:12<00:01, 13259.32it/s]
90%|████████▉ | 161883/180000 [00:12<00:01, 13228.16it/s]
91%|█████████ | 163206/180000 [00:12<00:01, 13183.86it/s]
91%|█████████▏| 164550/180000 [00:12<00:01, 13254.63it/s]
92%|█████████▏| 165891/180000 [00:12<00:01, 13297.04it/s]
93%|█████████▎| 167221/180000 [00:12<00:00, 13282.81it/s]
94%|█████████▎| 168550/180000 [00:12<00:00, 13206.20it/s]
94%|█████████▍| 169902/180000 [00:12<00:00, 13297.96it/s]
95%|█████████▌| 171249/180000 [00:12<00:00, 13347.18it/s]
96%|█████████▌| 172584/180000 [00:13<00:00, 13347.37it/s]
97%|█████████▋| 173919/180000 [00:13<00:00, 13334.11it/s]
97%|█████████▋| 175253/180000 [00:13<00:00, 13215.41it/s]
98%|█████████▊| 176607/180000 [00:13<00:00, 13306.12it/s]
99%|█████████▉| 177966/180000 [00:13<00:00, 13389.65it/s]
100%|█████████▉| 179319/180000 [00:13<00:00, 13430.73it/s]
100%|██████████| 180000/180000 [00:13<00:00, 13268.41it/s]
We can compare the performance of each policy by checking the average reward at the end of each episode.
(
trace_df
.groupby(['policy_idx', 'episode'])
.last()
.groupby('policy_idx')
.reward_stat.mean()
)
policy_idx
0 736.1
1 817.0
2 854.0
Name: reward_stat, dtype: float64
We see that using a rolling average gives a boost to the epsilon greedy strategy. However, we see that the bandit.ThompsonSampling policy performs even better, even though no particular care was given to drift. A natural next step would thus be to see how it could be improved to handle drift. For instance, its dist parameter could be wrapped with a utils.Rolling:
policy = bandit.ThompsonSampling(
reward_obj=utils.Rolling(proba.Beta(), window_size=50),
seed=42
)
Bandits can be used for several tasks. They can be used for content personalization, as well as online model selection (see model_selection.BanditRegressor). The policies in River are therefore designed to be flexible, so that they can be used in conjunction with other River modules. For instance, the reward_obj in bandit.EpsilonGreedy can be a metric, a probability distribution, or a statistic. This works because objects in River adher to a coherent get/update interface.