Skip to content

Multi-armed bandits

River has a bandit module. It contains several multi-armed bandit policies, bandit environments, and utilities to benchmark policies on bandit problems.

Bandit environments in River implement the Gym interface. You can thus load them with gym.make. Note that Gym is intended for reinforcement learning algorithms, while bandit policies are the simplest form of reinforcement learning. Bandit policies learn by receiving a reward after each step, while reinforcement learning algorithms have to learn from feedback that may arrive at the end of a (long) sequence of steps.

import gymnasium as gym

for k in gym.envs.registry:
    if k.startswith('river_bandits'):
        print(k)

River's bandit module offers the bandit.evaluate function to benchmark several policies on a given environment. It takes as input a list of bandit policies, a bandit environment (the problem to solve), and a reward object.

import gymnasium as gym
from river import bandit
import pandas as pd
from tqdm import tqdm
from river import stats

policies=[
    bandit.EpsilonGreedy(epsilon=0.1),
    bandit.EpsilonGreedy(epsilon=0.01),
    bandit.EpsilonGreedy(epsilon=0),
]

env = gym.make(
    'river_bandits/KArmedTestbed-v0',
    max_episode_steps=1000
)

trace = bandit.evaluate(
    policies=policies,
    env=env,
    reward_stat=stats.Mean(),
    n_episodes=(n_episodes := 2000),
)

The bandit.evaluate function returns a generator containing the results at each step of the benchmark. This can be wrapped with a pandas.DataFrame to gather all the results.

trace_df = pd.DataFrame(tqdm(
    trace, position=0, total=(
        n_episodes *
        len(policies) *
        env._max_episode_steps
    )
))
trace_df.sample(5, random_state=42)

0%| | 0/6000000 [00:00<?, ?it/s]

0%| | 16326/6000000 [00:00<00:36, 163247.31it/s]

1%| | 33004/6000000 [00:00<00:36, 165313.33it/s]

1%| | 50258/6000000 [00:00<00:35, 168610.13it/s]

1%| | 67428/6000000 [00:00<00:34, 169826.95it/s]

1%|▏ | 84411/6000000 [00:00<00:35, 168198.42it/s]

2%|▏ | 101423/6000000 [00:00<00:34, 168843.98it/s]

2%|▏ | 118310/6000000 [00:00<00:34, 168683.62it/s]

2%|▏ | 135180/6000000 [00:00<00:34, 168673.18it/s]

3%|▎ | 152049/6000000 [00:00<00:43, 133706.60it/s]

3%|▎ | 168773/6000000 [00:01<00:40, 142455.26it/s]

3%|▎ | 186001/6000000 [00:01<00:38, 150382.10it/s]

3%|▎ | 202848/6000000 [00:01<00:37, 155432.08it/s]

4%|▎ | 219277/6000000 [00:01<00:36, 157953.96it/s]

4%|▍ | 236056/6000000 [00:01<00:35, 160799.88it/s]

4%|▍ | 252867/6000000 [00:01<00:35, 162935.29it/s]

4%|▍ | 269783/6000000 [00:01<00:34, 164767.46it/s]

5%|▍ | 286412/6000000 [00:01<00:34, 164497.23it/s]

5%|▌ | 303151/6000000 [00:01<00:34, 165353.88it/s]

5%|▌ | 319869/6000000 [00:01<00:34, 165896.19it/s]

6%|▌ | 336829/6000000 [00:02<00:33, 166998.46it/s]

6%|▌ | 353913/6000000 [00:02<00:33, 168144.49it/s]

6%|▌ | 370850/6000000 [00:02<00:33, 168508.61it/s]

6%|▋ | 387720/6000000 [00:02<00:33, 167692.14it/s]

7%|▋ | 404504/6000000 [00:02<00:33, 166668.01it/s]

7%|▋ | 421182/6000000 [00:02<00:33, 166509.99it/s]

7%|▋ | 438001/6000000 [00:02<00:33, 166416.75it/s]

8%|▊ | 455081/6000000 [00:02<00:33, 167720.33it/s]

8%|▊ | 471986/6000000 [00:02<00:32, 168114.21it/s]

8%|▊ | 488862/6000000 [00:02<00:32, 168305.05it/s]

8%|▊ | 505695/6000000 [00:03<00:32, 167868.91it/s]

9%|▊ | 522484/6000000 [00:03<00:32, 167475.27it/s]

9%|▉ | 539713/6000000 [00:03<00:32, 168910.38it/s]

9%|▉ | 556606/6000000 [00:03<00:32, 167333.06it/s]

10%|▉ | 573344/6000000 [00:03<00:32, 164913.69it/s]

10%|▉ | 589926/6000000 [00:03<00:32, 165179.47it/s]

10%|█ | 606590/6000000 [00:03<00:32, 165610.68it/s]

10%|█ | 623654/6000000 [00:03<00:32, 167106.21it/s]

11%|█ | 640370/6000000 [00:03<00:32, 166931.75it/s]

11%|█ | 657067/6000000 [00:04<00:32, 166466.49it/s]

11%|█ | 673901/6000000 [00:04<00:31, 167023.82it/s]

12%|█▏ | 690606/6000000 [00:04<00:31, 166256.42it/s]

12%|█▏ | 707363/6000000 [00:04<00:31, 166644.91it/s]

12%|█▏ | 724083/6000000 [00:04<00:31, 166807.23it/s]

12%|█▏ | 741001/6000000 [00:04<00:31, 167052.88it/s]

13%|█▎ | 757928/6000000 [00:04<00:31, 167713.74it/s]

13%|█▎ | 774701/6000000 [00:04<00:31, 167055.45it/s]

13%|█▎ | 791435/6000000 [00:04<00:31, 167135.82it/s]

13%|█▎ | 808150/6000000 [00:04<00:31, 166315.86it/s]

14%|█▍ | 825001/6000000 [00:05<00:31, 166623.05it/s]

14%|█▍ | 841971/6000000 [00:05<00:30, 167539.55it/s]

14%|█▍ | 858727/6000000 [00:05<00:30, 167231.75it/s]

15%|█▍ | 875804/6000000 [00:05<00:30, 168285.64it/s]

15%|█▍ | 892634/6000000 [00:05<00:30, 167839.29it/s]

15%|█▌ | 909443/6000000 [00:05<00:30, 167911.75it/s]

15%|█▌ | 926235/6000000 [00:05<00:30, 167858.11it/s]

16%|█▌ | 943022/6000000 [00:05<00:30, 167160.19it/s]

16%|█▌ | 959970/6000000 [00:05<00:30, 167851.67it/s]

16%|█▋ | 976757/6000000 [00:05<00:30, 167081.16it/s]

17%|█▋ | 993467/6000000 [00:06<00:29, 167075.06it/s]

17%|█▋ | 1010472/6000000 [00:06<00:29, 167959.61it/s]

17%|█▋ | 1027269/6000000 [00:06<00:29, 167650.05it/s]

17%|█▋ | 1044035/6000000 [00:06<00:29, 167101.94it/s]

18%|█▊ | 1060906/6000000 [00:06<00:29, 167578.63it/s]

18%|█▊ | 1078015/6000000 [00:06<00:29, 168625.71it/s]

18%|█▊ | 1095147/6000000 [00:06<00:28, 169430.78it/s]

19%|█▊ | 1112128/6000000 [00:06<00:28, 169543.32it/s]

19%|█▉ | 1129083/6000000 [00:06<00:28, 168708.01it/s]

19%|█▉ | 1145956/6000000 [00:06<00:28, 168515.88it/s]

19%|█▉ | 1162809/6000000 [00:07<00:28, 168167.74it/s]

20%|█▉ | 1179627/6000000 [00:07<00:28, 167908.56it/s]

20%|█▉ | 1196562/6000000 [00:07<00:28, 168335.24it/s]

20%|██ | 1213396/6000000 [00:07<00:28, 167608.08it/s]

21%|██ | 1230158/6000000 [00:07<00:28, 165232.89it/s]

21%|██ | 1247099/6000000 [00:07<00:28, 166465.56it/s]

21%|██ | 1263779/6000000 [00:07<00:28, 166563.81it/s]

21%|██▏ | 1280922/6000000 [00:07<00:28, 168010.30it/s]

22%|██▏ | 1297728/6000000 [00:07<00:28, 167155.78it/s]

22%|██▏ | 1314448/6000000 [00:07<00:28, 167091.14it/s]

22%|██▏ | 1331356/6000000 [00:08<00:27, 167683.09it/s]

22%|██▏ | 1348127/6000000 [00:08<00:27, 167689.63it/s]

23%|██▎ | 1364948/6000000 [00:08<00:27, 167842.28it/s]

23%|██▎ | 1381734/6000000 [00:08<00:27, 167810.25it/s]

23%|██▎ | 1398516/6000000 [00:08<00:27, 167326.81it/s]

24%|██▎ | 1415322/6000000 [00:08<00:27, 167544.59it/s]

24%|██▍ | 1432078/6000000 [00:08<00:27, 166630.04it/s]

24%|██▍ | 1448973/6000000 [00:08<00:27, 167320.34it/s]

24%|██▍ | 1465707/6000000 [00:08<00:27, 166761.94it/s]

25%|██▍ | 1482385/6000000 [00:08<00:27, 166764.05it/s]

25%|██▍ | 1499311/6000000 [00:09<00:26, 167508.35it/s]

25%|██▌ | 1516063/6000000 [00:09<00:26, 166476.95it/s]

26%|██▌ | 1533001/6000000 [00:09<00:26, 166817.15it/s]

26%|██▌ | 1549819/6000000 [00:09<00:26, 167218.31it/s]

26%|██▌ | 1566542/6000000 [00:09<00:26, 167172.81it/s]

26%|██▋ | 1583488/6000000 [00:09<00:26, 167853.22it/s]

27%|██▋ | 1600275/6000000 [00:09<00:26, 166720.41it/s]

27%|██▋ | 1617037/6000000 [00:09<00:26, 166983.97it/s]

27%|██▋ | 1633800/6000000 [00:09<00:26, 167175.40it/s]

28%|██▊ | 1650786/6000000 [00:09<00:25, 167967.61it/s]

28%|██▊ | 1667649/6000000 [00:10<00:25, 168163.89it/s]

28%|██▊ | 1684557/6000000 [00:10<00:25, 168436.89it/s]

28%|██▊ | 1701402/6000000 [00:10<00:25, 167942.68it/s]

29%|██▊ | 1718589/6000000 [00:10<00:25, 169107.42it/s]

29%|██▉ | 1735501/6000000 [00:10<00:25, 168023.67it/s]

29%|██▉ | 1752306/6000000 [00:10<00:25, 167312.69it/s]

29%|██▉ | 1769255/6000000 [00:10<00:25, 167957.68it/s]

30%|██▉ | 1786053/6000000 [00:10<00:25, 167435.22it/s]

30%|███ | 1803216/6000000 [00:10<00:24, 168683.13it/s]

30%|███ | 1820384/6000000 [00:10<00:24, 169576.80it/s]

31%|███ | 1837344/6000000 [00:11<00:24, 168965.21it/s]

31%|███ | 1854242/6000000 [00:11<00:24, 167624.71it/s]

31%|███ | 1871234/6000000 [00:11<00:24, 168304.81it/s]

31%|███▏ | 1888068/6000000 [00:11<00:24, 168069.48it/s]

32%|███▏ | 1904910/6000000 [00:11<00:24, 168172.09it/s]

32%|███▏ | 1921729/6000000 [00:11<00:24, 167833.77it/s]

32%|███▏ | 1938636/6000000 [00:11<00:24, 168201.27it/s]

33%|███▎ | 1955664/6000000 [00:11<00:23, 168821.11it/s]

33%|███▎ | 1972547/6000000 [00:11<00:23, 168070.75it/s]

33%|███▎ | 1989356/6000000 [00:11<00:23, 167769.07it/s]

33%|███▎ | 2006674/6000000 [00:12<00:23, 169381.79it/s]

34%|███▎ | 2023614/6000000 [00:12<00:23, 167737.65it/s]

34%|███▍ | 2040392/6000000 [00:12<00:23, 167070.09it/s]

34%|███▍ | 2057337/6000000 [00:12<00:23, 167775.78it/s]

35%|███▍ | 2074118/6000000 [00:12<00:23, 167561.21it/s]

35%|███▍ | 2091003/6000000 [00:12<00:23, 167898.73it/s]

35%|███▌ | 2108066/6000000 [00:12<00:23, 168713.30it/s]

35%|███▌ | 2124939/6000000 [00:12<00:23, 167540.83it/s]

36%|███▌ | 2141960/6000000 [00:12<00:22, 168333.86it/s]

36%|███▌ | 2158796/6000000 [00:12<00:22, 167799.45it/s]

36%|███▋ | 2175578/6000000 [00:13<00:22, 167431.71it/s]

37%|███▋ | 2192508/6000000 [00:13<00:22, 167987.13it/s]

37%|███▋ | 2209308/6000000 [00:13<00:22, 167682.19it/s]

37%|███▋ | 2226078/6000000 [00:13<00:22, 167257.30it/s]

37%|███▋ | 2243175/6000000 [00:13<00:22, 168363.04it/s]

38%|███▊ | 2260013/6000000 [00:13<00:22, 168073.38it/s]

38%|███▊ | 2277001/6000000 [00:13<00:22, 168281.28it/s]

38%|███▊ | 2294204/6000000 [00:13<00:21, 169398.02it/s]

39%|███▊ | 2311145/6000000 [00:13<00:21, 168804.05it/s]

39%|███▉ | 2328027/6000000 [00:13<00:21, 167586.17it/s]

39%|███▉ | 2345177/6000000 [00:14<00:21, 168749.15it/s]

39%|███▉ | 2362055/6000000 [00:14<00:21, 168543.73it/s]

40%|███▉ | 2379001/6000000 [00:14<00:21, 168500.69it/s]

40%|███▉ | 2395853/6000000 [00:14<00:21, 168449.91it/s]

40%|████ | 2412699/6000000 [00:14<00:21, 167585.60it/s]

40%|████ | 2429978/6000000 [00:14<00:21, 169135.40it/s]

41%|████ | 2446894/6000000 [00:14<00:21, 167514.91it/s]

41%|████ | 2463650/6000000 [00:14<00:21, 167518.03it/s]

41%|████▏ | 2480604/6000000 [00:14<00:20, 168119.06it/s]

42%|████▏ | 2497419/6000000 [00:14<00:20, 167360.57it/s]

42%|████▏ | 2514199/6000000 [00:15<00:20, 167488.61it/s]

42%|████▏ | 2531024/6000000 [00:15<00:20, 167714.70it/s]

42%|████▏ | 2548085/6000000 [00:15<00:20, 168578.47it/s]

43%|████▎ | 2564981/6000000 [00:15<00:20, 168690.43it/s]

43%|████▎ | 2581851/6000000 [00:15<00:20, 167112.50it/s]

43%|████▎ | 2598566/6000000 [00:15<00:20, 166629.84it/s]

44%|████▎ | 2615232/6000000 [00:15<00:20, 166363.92it/s]

44%|████▍ | 2631901/6000000 [00:15<00:20, 166458.68it/s]

44%|████▍ | 2648665/6000000 [00:15<00:20, 166807.49it/s]

44%|████▍ | 2665347/6000000 [00:15<00:20, 166496.34it/s]

45%|████▍ | 2682096/6000000 [00:16<00:19, 166790.94it/s]

45%|████▍ | 2699092/6000000 [00:16<00:19, 167736.46it/s]

45%|████▌ | 2715912/6000000 [00:16<00:19, 167871.97it/s]

46%|████▌ | 2732831/6000000 [00:16<00:19, 168265.61it/s]

46%|████▌ | 2749658/6000000 [00:16<00:19, 167411.93it/s]

46%|████▌ | 2766401/6000000 [00:16<00:19, 166674.92it/s]

46%|████▋ | 2783616/6000000 [00:16<00:19, 168305.07it/s]

47%|████▋ | 2800449/6000000 [00:16<00:19, 167113.65it/s]

47%|████▋ | 2817164/6000000 [00:16<00:19, 166013.92it/s]

47%|████▋ | 2833833/6000000 [00:16<00:19, 166213.54it/s]

48%|████▊ | 2850457/6000000 [00:17<00:19, 165638.65it/s]

48%|████▊ | 2867208/6000000 [00:17<00:18, 166192.83it/s]

48%|████▊ | 2883977/6000000 [00:17<00:18, 166636.46it/s]

48%|████▊ | 2900642/6000000 [00:17<00:18, 166058.35it/s]

49%|████▊ | 2917249/6000000 [00:17<00:18, 165801.82it/s]

49%|████▉ | 2934001/6000000 [00:17<00:18, 165986.53it/s]

49%|████▉ | 2950825/6000000 [00:17<00:18, 166657.59it/s]

49%|████▉ | 2967498/6000000 [00:17<00:18, 166677.27it/s]

50%|████▉ | 2984311/6000000 [00:17<00:18, 167109.04it/s]

50%|█████ | 3001023/6000000 [00:17<00:17, 166928.68it/s]

50%|█████ | 3018001/6000000 [00:18<00:17, 167051.69it/s]

51%|█████ | 3035009/6000000 [00:18<00:17, 167953.26it/s]

51%|█████ | 3051805/6000000 [00:18<00:17, 167686.89it/s]

51%|█████ | 3068988/6000000 [00:18<00:17, 168922.55it/s]

51%|█████▏ | 3085881/6000000 [00:18<00:17, 167421.50it/s]

52%|█████▏ | 3102627/6000000 [00:18<00:17, 166881.86it/s]

52%|█████▏ | 3119711/6000000 [00:18<00:17, 168057.09it/s]

52%|█████▏ | 3136520/6000000 [00:18<00:17, 167831.06it/s]

53%|█████▎ | 3153305/6000000 [00:18<00:17, 167335.07it/s]

53%|█████▎ | 3170265/6000000 [00:19<00:16, 168008.84it/s]

53%|█████▎ | 3187068/6000000 [00:19<00:16, 167554.52it/s]

53%|█████▎ | 3204001/6000000 [00:19<00:16, 167263.78it/s]

54%|█████▎ | 3220863/6000000 [00:19<00:16, 167665.37it/s]

54%|█████▍ | 3237684/6000000 [00:19<00:16, 167824.46it/s]

54%|█████▍ | 3255001/6000000 [00:19<00:16, 168737.94it/s]

55%|█████▍ | 3272029/6000000 [00:19<00:16, 169196.66it/s]

55%|█████▍ | 3288949/6000000 [00:19<00:16, 168824.24it/s]

55%|█████▌ | 3305984/6000000 [00:19<00:15, 169278.83it/s]

55%|█████▌ | 3322913/6000000 [00:19<00:15, 168844.37it/s]

56%|█████▌ | 3339798/6000000 [00:20<00:15, 166960.99it/s]

56%|█████▌ | 3357001/6000000 [00:20<00:15, 167627.12it/s]

56%|█████▌ | 3374236/6000000 [00:20<00:15, 169027.47it/s]

57%|█████▋ | 3391143/6000000 [00:20<00:15, 168062.82it/s]

57%|█████▋ | 3408001/6000000 [00:20<00:15, 168037.85it/s]

57%|█████▋ | 3424816/6000000 [00:20<00:15, 168068.08it/s]

57%|█████▋ | 3441866/6000000 [00:20<00:15, 168791.66it/s]

58%|█████▊ | 3458926/6000000 [00:20<00:15, 169329.83it/s]

58%|█████▊ | 3475861/6000000 [00:20<00:15, 168206.91it/s]

58%|█████▊ | 3492684/6000000 [00:20<00:14, 167377.90it/s]

58%|█████▊ | 3509630/6000000 [00:21<00:14, 167995.31it/s]

59%|█████▉ | 3526432/6000000 [00:21<00:14, 167692.03it/s]

59%|█████▉ | 3543203/6000000 [00:21<00:14, 167339.06it/s]

59%|█████▉ | 3560387/6000000 [00:21<00:14, 168679.78it/s]

60%|█████▉ | 3577430/6000000 [00:21<00:14, 169198.88it/s]

60%|█████▉ | 3594352/6000000 [00:21<00:14, 168820.80it/s]

60%|██████ | 3611316/6000000 [00:21<00:14, 169064.21it/s]

60%|██████ | 3628224/6000000 [00:21<00:14, 167683.42it/s]

61%|██████ | 3645001/6000000 [00:21<00:14, 167359.81it/s]

61%|██████ | 3662016/6000000 [00:21<00:13, 168189.07it/s]

61%|██████▏ | 3678837/6000000 [00:22<00:13, 167838.35it/s]

62%|██████▏ | 3696001/6000000 [00:22<00:13, 168364.14it/s]

62%|██████▏ | 3712897/6000000 [00:22<00:13, 168537.59it/s]

62%|██████▏ | 3729752/6000000 [00:22<00:13, 168111.21it/s]

62%|██████▏ | 3746801/6000000 [00:22<00:13, 168817.83it/s]

63%|██████▎ | 3763684/6000000 [00:22<00:13, 168185.57it/s]

63%|██████▎ | 3780606/6000000 [00:22<00:13, 168491.54it/s]

63%|██████▎ | 3797651/6000000 [00:22<00:13, 169075.63it/s]

64%|██████▎ | 3814560/6000000 [00:22<00:13, 168059.41it/s]

64%|██████▍ | 3831468/6000000 [00:22<00:12, 168362.26it/s]

64%|██████▍ | 3848306/6000000 [00:23<00:12, 168289.69it/s]

64%|██████▍ | 3865136/6000000 [00:23<00:12, 167924.21it/s]

65%|██████▍ | 3881930/6000000 [00:23<00:12, 167101.00it/s]

65%|██████▍ | 3898642/6000000 [00:23<00:12, 166681.57it/s]

65%|██████▌ | 3915311/6000000 [00:23<00:12, 166168.13it/s]

66%|██████▌ | 3931929/6000000 [00:23<00:12, 160223.08it/s]

66%|██████▌ | 3948509/6000000 [00:23<00:12, 161843.03it/s]

66%|██████▌ | 3964729/6000000 [00:23<00:13, 152853.12it/s]

66%|██████▋ | 3981390/6000000 [00:23<00:12, 156752.95it/s]

67%|██████▋ | 3998212/6000000 [00:23<00:12, 160058.45it/s]

67%|██████▋ | 4015059/6000000 [00:24<00:12, 162510.64it/s]

67%|██████▋ | 4031886/6000000 [00:24<00:11, 164202.11it/s]

67%|██████▋ | 4048484/6000000 [00:24<00:11, 164725.68it/s]

68%|██████▊ | 4065104/6000000 [00:24<00:11, 165160.73it/s]

68%|██████▊ | 4081899/6000000 [00:24<00:11, 165989.19it/s]

68%|██████▊ | 4098517/6000000 [00:24<00:11, 165750.99it/s]

69%|██████▊ | 4115576/6000000 [00:24<00:11, 167193.01it/s]

69%|██████▉ | 4132305/6000000 [00:24<00:11, 167218.24it/s]

69%|██████▉ | 4149034/6000000 [00:24<00:11, 166567.10it/s]

69%|██████▉ | 4165967/6000000 [00:24<00:10, 167390.69it/s]

70%|██████▉ | 4182711/6000000 [00:25<00:10, 166983.87it/s]

70%|██████▉ | 4199589/6000000 [00:25<00:10, 167517.18it/s]

70%|███████ | 4216343/6000000 [00:25<00:10, 166827.23it/s]

71%|███████ | 4233028/6000000 [00:25<00:10, 166103.18it/s]

71%|███████ | 4249641/6000000 [00:25<00:10, 165857.93it/s]

71%|███████ | 4266228/6000000 [00:25<00:10, 165626.92it/s]

71%|███████▏ | 4283108/6000000 [00:25<00:10, 166572.31it/s]

72%|███████▏ | 4299767/6000000 [00:25<00:10, 166411.59it/s]

72%|███████▏ | 4316610/6000000 [00:25<00:10, 167012.53it/s]

72%|███████▏ | 4333312/6000000 [00:25<00:10, 166392.91it/s]

73%|███████▎ | 4350061/6000000 [00:26<00:09, 166719.06it/s]

73%|███████▎ | 4366928/6000000 [00:26<00:09, 167300.66it/s]

73%|███████▎ | 4383659/6000000 [00:26<00:09, 166571.21it/s]

73%|███████▎ | 4400496/6000000 [00:26<00:09, 167105.61it/s]

74%|███████▎ | 4417208/6000000 [00:26<00:09, 166207.90it/s]

74%|███████▍ | 4433977/6000000 [00:26<00:09, 166641.86it/s]

74%|███████▍ | 4450643/6000000 [00:26<00:09, 166299.22it/s]

74%|███████▍ | 4467392/6000000 [00:26<00:09, 166650.41it/s]

75%|███████▍ | 4484058/6000000 [00:26<00:09, 166475.28it/s]

75%|███████▌ | 4500707/6000000 [00:26<00:09, 166378.52it/s]

75%|███████▌ | 4517346/6000000 [00:27<00:08, 166244.17it/s]

76%|███████▌ | 4533998/6000000 [00:27<00:08, 166324.64it/s]

76%|███████▌ | 4551001/6000000 [00:27<00:08, 167279.31it/s]

76%|███████▌ | 4567759/6000000 [00:27<00:08, 167365.58it/s]

76%|███████▋ | 4584496/6000000 [00:27<00:08, 164849.37it/s]

77%|███████▋ | 4601517/6000000 [00:27<00:08, 166437.41it/s]

77%|███████▋ | 4618169/6000000 [00:27<00:08, 165706.31it/s]

77%|███████▋ | 4634917/6000000 [00:27<00:08, 166231.29it/s]

78%|███████▊ | 4651545/6000000 [00:27<00:08, 166030.24it/s]

78%|███████▊ | 4668288/6000000 [00:27<00:08, 166444.81it/s]

78%|███████▊ | 4685133/6000000 [00:28<00:07, 167042.17it/s]

78%|███████▊ | 4702013/6000000 [00:28<00:07, 167565.44it/s]

79%|███████▊ | 4718875/6000000 [00:28<00:07, 167877.56it/s]

79%|███████▉ | 4735851/6000000 [00:28<00:07, 168440.17it/s]

79%|███████▉ | 4752696/6000000 [00:28<00:07, 167325.53it/s]

79%|███████▉ | 4769604/6000000 [00:28<00:07, 167845.79it/s]

80%|███████▉ | 4786391/6000000 [00:28<00:07, 167476.43it/s]

80%|████████ | 4803140/6000000 [00:28<00:07, 167425.08it/s]

80%|████████ | 4820340/6000000 [00:28<00:06, 168789.90it/s]

81%|████████ | 4837221/6000000 [00:28<00:06, 168150.35it/s]

81%|████████ | 4854038/6000000 [00:29<00:06, 167288.64it/s]

81%|████████ | 4871226/6000000 [00:29<00:06, 168655.83it/s]

81%|████████▏ | 4888094/6000000 [00:29<00:06, 167786.53it/s]

82%|████████▏ | 4905001/6000000 [00:29<00:06, 167730.04it/s]

82%|████████▏ | 4921858/6000000 [00:29<00:06, 167977.42it/s]

82%|████████▏ | 4938701/6000000 [00:29<00:06, 168111.53it/s]

83%|████████▎ | 4955514/6000000 [00:29<00:06, 166623.70it/s]

83%|████████▎ | 4972180/6000000 [00:29<00:06, 166541.24it/s]

83%|████████▎ | 4989003/6000000 [00:29<00:06, 167025.46it/s]

83%|████████▎ | 5005901/6000000 [00:29<00:05, 167607.53it/s]

84%|████████▎ | 5022664/6000000 [00:30<00:05, 166465.07it/s]

84%|████████▍ | 5039695/6000000 [00:30<00:05, 167608.45it/s]

84%|████████▍ | 5056463/6000000 [00:30<00:05, 167626.52it/s]

85%|████████▍ | 5073329/6000000 [00:30<00:05, 167932.66it/s]

85%|████████▍ | 5090186/6000000 [00:30<00:05, 168120.32it/s]

85%|████████▌ | 5107000/6000000 [00:30<00:05, 167620.12it/s]

85%|████████▌ | 5124001/6000000 [00:30<00:05, 167909.10it/s]

86%|████████▌ | 5140842/6000000 [00:30<00:05, 168056.17it/s]

86%|████████▌ | 5157795/6000000 [00:30<00:04, 168494.57it/s]

86%|████████▌ | 5174808/6000000 [00:31<00:04, 168980.75it/s]

87%|████████▋ | 5191707/6000000 [00:31<00:04, 168113.77it/s]

87%|████████▋ | 5208535/6000000 [00:31<00:04, 168160.58it/s]

87%|████████▋ | 5225352/6000000 [00:31<00:04, 168110.95it/s]

87%|████████▋ | 5242164/6000000 [00:31<00:04, 168068.38it/s]

88%|████████▊ | 5258972/6000000 [00:31<00:04, 167624.37it/s]

88%|████████▊ | 5275735/6000000 [00:31<00:04, 166927.50it/s]

88%|████████▊ | 5292429/6000000 [00:31<00:04, 166543.53it/s]

88%|████████▊ | 5309322/6000000 [00:31<00:04, 167252.07it/s]

89%|████████▉ | 5326048/6000000 [00:31<00:04, 167227.91it/s]

89%|████████▉ | 5343001/6000000 [00:32<00:03, 167359.45it/s]

89%|████████▉ | 5360199/6000000 [00:32<00:03, 168737.18it/s]

90%|████████▉ | 5377192/6000000 [00:32<00:03, 169091.96it/s]

90%|████████▉ | 5394128/6000000 [00:32<00:03, 169171.17it/s]

90%|█████████ | 5411106/6000000 [00:32<00:03, 169350.94it/s]

90%|█████████ | 5428042/6000000 [00:32<00:03, 168374.65it/s]

91%|█████████ | 5444881/6000000 [00:32<00:03, 168181.39it/s]

91%|█████████ | 5461701/6000000 [00:32<00:03, 167751.00it/s]

91%|█████████▏| 5478477/6000000 [00:32<00:03, 167401.47it/s]

92%|█████████▏| 5495620/6000000 [00:32<00:02, 168602.25it/s]

92%|█████████▏| 5512514/6000000 [00:33<00:02, 168699.80it/s]

92%|█████████▏| 5529385/6000000 [00:33<00:02, 167705.59it/s]

92%|█████████▏| 5546510/6000000 [00:33<00:02, 168759.55it/s]

93%|█████████▎| 5563388/6000000 [00:33<00:02, 167739.91it/s]

93%|█████████▎| 5580165/6000000 [00:33<00:02, 167083.14it/s]

93%|█████████▎| 5597006/6000000 [00:33<00:02, 167474.47it/s]

94%|█████████▎| 5613873/6000000 [00:33<00:02, 167828.59it/s]

94%|█████████▍| 5630658/6000000 [00:33<00:02, 167732.14it/s]

94%|█████████▍| 5647433/6000000 [00:33<00:02, 167376.75it/s]

94%|█████████▍| 5664172/6000000 [00:33<00:02, 167222.03it/s]

95%|█████████▍| 5681173/6000000 [00:34<00:01, 168052.76it/s]

95%|█████████▍| 5697979/6000000 [00:34<00:01, 167906.75it/s]

95%|█████████▌| 5714771/6000000 [00:34<00:01, 167852.55it/s]

96%|█████████▌| 5731564/6000000 [00:34<00:01, 167872.91it/s]

96%|█████████▌| 5748490/6000000 [00:34<00:01, 168287.12it/s]

96%|█████████▌| 5765319/6000000 [00:34<00:01, 168169.17it/s]

96%|█████████▋| 5782137/6000000 [00:34<00:01, 167433.19it/s]

97%|█████████▋| 5799002/6000000 [00:34<00:01, 167794.96it/s]

97%|█████████▋| 5816126/6000000 [00:34<00:01, 168823.91it/s]

97%|█████████▋| 5833010/6000000 [00:34<00:00, 167528.52it/s]

98%|█████████▊| 5850001/6000000 [00:35<00:00, 168111.96it/s]

98%|█████████▊| 5866860/6000000 [00:35<00:00, 168252.15it/s]

98%|█████████▊| 5883687/6000000 [00:35<00:00, 167501.36it/s]

98%|█████████▊| 5900853/6000000 [00:35<00:00, 168740.17it/s]

99%|█████████▊| 5917729/6000000 [00:35<00:00, 167837.10it/s]

99%|█████████▉| 5934515/6000000 [00:35<00:00, 167744.09it/s]

99%|█████████▉| 5951309/6000000 [00:35<00:00, 167800.83it/s]

99%|█████████▉| 5968091/6000000 [00:35<00:00, 167294.26it/s]

100%|█████████▉| 5984822/6000000 [00:35<00:00, 167270.54it/s]

100%|██████████| 6000000/6000000 [00:35<00:00, 167035.97it/s]



episode step policy_idx arm reward reward_stat
1324896 441 632 0 5 0.883382 1.063547
3566176 1188 725 1 4 3.437877 0.828136
1109043 369 681 0 0 3.719861 1.653319
4286042 1428 680 2 0 1.043589 0.988311
5395174 1798 391 1 8 0.965791 1.490329

It is then straightforward to plot the average reward each policy obtains at each step, by averaging over episodes.

policy_names = {
    0: 'ε = 0.1',
    1: 'ε = 0.01',
    2: 'ε = 0 (greedy)'
}

(
    trace_df
    .assign(policy=trace_df.policy_idx.map(policy_names))
    .groupby(['step', 'policy'])
    ['reward'].mean()
    .unstack()
    .plot()
)


<Axes: xlabel='step'>


png

Controlling the evaluation loop

The bandit.evaluate function is useful for benchmarking. But in practice, you'll want to have control over your bandit policy. Indeed you'll want the freedom to pull arms (with the pull method) and update the policy (with the update method) at your discretion.

As an example, the following is a possible reimplementation of the bandit.evaluate function. Here we'll be measuring the rate at which each policy selects the optimal arm.

Note how the pull and update methods are used.

import copy

policies=[
    bandit.EpsilonGreedy(epsilon=0.1),
    bandit.EpsilonGreedy(epsilon=0.01),
    bandit.EpsilonGreedy(epsilon=0),
]

env = gym.make(
    'river_bandits/KArmedTestbed-v0',
    max_episode_steps=1000
)
n_episodes = 2000

trace = []

with tqdm(total=len(policies) * n_episodes * env._max_episode_steps, position=0) as progress:
    for policy in policies:
        for episode in range(n_episodes):
            episode_policy = policy.clone()
            episode_env = copy.deepcopy(env)
            episode_env.reset()
            step = 0
            while True:
                action = episode_policy.pull(range(episode_env.action_space.n))
                observation, reward, terminated, truncated, info = episode_env.step(action)
                best_action = observation
                episode_policy.update(action, reward)

                trace.append({
                    "episode": episode,
                    "step": step,
                    "policy": f"ε = {policy.epsilon}",
                    "is_action_optimal": action == best_action
                })
                step += 1
                progress.update()

                if terminated or truncated:
                    break

trace_df = pd.DataFrame(trace)

0%| | 0/6000000 [00:00<?, ?it/s]

0%| | 17069/6000000 [00:00<00:35, 170687.72it/s]

1%| | 34861/6000000 [00:00<00:34, 174940.85it/s]

1%| | 52571/6000000 [00:00<00:33, 175925.11it/s]

1%| | 70252/6000000 [00:00<00:33, 176273.02it/s]

1%|▏ | 87905/6000000 [00:00<00:33, 176365.45it/s]

2%|▏ | 105542/6000000 [00:00<00:33, 176121.21it/s]

2%|▏ | 123160/6000000 [00:00<00:33, 176138.13it/s]

2%|▏ | 141001/6000000 [00:00<00:33, 176655.42it/s]

3%|▎ | 159001/6000000 [00:00<00:32, 177651.86it/s]

3%|▎ | 176767/6000000 [00:01<00:32, 177169.43it/s]

3%|▎ | 194713/6000000 [00:01<00:32, 177867.13it/s]

4%|▎ | 212501/6000000 [00:01<00:32, 177131.16it/s]

4%|▍ | 230237/6000000 [00:01<00:32, 177198.72it/s]

4%|▍ | 248228/6000000 [00:01<00:32, 178013.78it/s]

4%|▍ | 266031/6000000 [00:01<00:32, 177009.20it/s]

5%|▍ | 283819/6000000 [00:01<00:32, 177268.18it/s]

5%|▌ | 301548/6000000 [00:01<00:32, 176771.15it/s]

5%|▌ | 319393/6000000 [00:01<00:32, 177272.83it/s]

6%|▌ | 337145/6000000 [00:01<00:31, 177345.61it/s]

6%|▌ | 354881/6000000 [00:02<00:31, 177242.63it/s]

6%|▌ | 372685/6000000 [00:02<00:31, 177478.88it/s]

7%|▋ | 390434/6000000 [00:02<00:31, 176855.11it/s]

7%|▋ | 408255/6000000 [00:02<00:31, 177258.94it/s]

7%|▋ | 426155/6000000 [00:02<00:31, 177778.36it/s]

7%|▋ | 443934/6000000 [00:02<00:31, 177745.46it/s]

8%|▊ | 461709/6000000 [00:02<00:31, 177180.85it/s]

8%|▊ | 479428/6000000 [00:02<00:31, 176356.21it/s]

8%|▊ | 497065/6000000 [00:02<00:31, 175537.73it/s]

9%|▊ | 514843/6000000 [00:02<00:31, 176204.32it/s]

9%|▉ | 532465/6000000 [00:03<00:31, 176177.58it/s]

9%|▉ | 550276/6000000 [00:03<00:30, 176752.74it/s]

9%|▉ | 567987/6000000 [00:03<00:30, 176858.31it/s]

10%|▉ | 585998/6000000 [00:03<00:30, 177830.48it/s]

10%|█ | 603799/6000000 [00:03<00:30, 177883.81it/s]

10%|█ | 621588/6000000 [00:03<00:30, 177207.57it/s]

11%|█ | 639331/6000000 [00:03<00:30, 177272.49it/s]

11%|█ | 657059/6000000 [00:03<00:30, 177159.37it/s]

11%|█ | 674952/6000000 [00:03<00:29, 177686.72it/s]

12%|█▏ | 692722/6000000 [00:03<00:29, 177440.07it/s]

12%|█▏ | 710467/6000000 [00:04<00:29, 176615.82it/s]

12%|█▏ | 728130/6000000 [00:04<00:29, 176373.50it/s]

12%|█▏ | 745961/6000000 [00:04<00:29, 176949.76it/s]

13%|█▎ | 763657/6000000 [00:04<00:29, 175796.02it/s]

13%|█▎ | 781239/6000000 [00:04<00:29, 174352.15it/s]

13%|█▎ | 798678/6000000 [00:04<00:29, 173859.78it/s]

14%|█▎ | 816066/6000000 [00:04<00:29, 173347.87it/s]

14%|█▍ | 833403/6000000 [00:04<00:29, 173082.66it/s]

14%|█▍ | 850713/6000000 [00:04<00:29, 172899.41it/s]

14%|█▍ | 868004/6000000 [00:04<00:29, 172705.38it/s]

15%|█▍ | 885319/6000000 [00:05<00:29, 172837.55it/s]

15%|█▌ | 902603/6000000 [00:05<00:29, 172758.78it/s]

15%|█▌ | 919880/6000000 [00:05<00:29, 171936.05it/s]

16%|█▌ | 937075/6000000 [00:05<00:29, 171392.11it/s]

16%|█▌ | 954215/6000000 [00:05<00:29, 171301.21it/s]

16%|█▌ | 971390/6000000 [00:05<00:29, 171434.92it/s]

16%|█▋ | 988605/6000000 [00:05<00:29, 171648.20it/s]

17%|█▋ | 1005988/6000000 [00:05<00:28, 172300.23it/s]

17%|█▋ | 1023219/6000000 [00:05<00:28, 171786.85it/s]

17%|█▋ | 1040569/6000000 [00:05<00:28, 172297.73it/s]

18%|█▊ | 1057800/6000000 [00:06<00:28, 172265.31it/s]

18%|█▊ | 1075027/6000000 [00:06<00:28, 172004.92it/s]

18%|█▊ | 1092422/6000000 [00:06<00:28, 172584.71it/s]

18%|█▊ | 1109688/6000000 [00:06<00:28, 172607.18it/s]

19%|█▉ | 1127011/6000000 [00:06<00:28, 172793.00it/s]

19%|█▉ | 1144291/6000000 [00:06<00:28, 172688.52it/s]

19%|█▉ | 1161560/6000000 [00:06<00:28, 172127.42it/s]

20%|█▉ | 1178943/6000000 [00:06<00:27, 172636.11it/s]

20%|█▉ | 1196226/6000000 [00:06<00:27, 172693.59it/s]

20%|██ | 1213728/6000000 [00:06<00:27, 173390.30it/s]

21%|██ | 1231068/6000000 [00:07<00:27, 172822.26it/s]

21%|██ | 1248397/6000000 [00:07<00:27, 172960.85it/s]

21%|██ | 1265731/6000000 [00:07<00:27, 173071.81it/s]

21%|██▏ | 1283039/6000000 [00:07<00:27, 172670.69it/s]

22%|██▏ | 1300509/6000000 [00:07<00:27, 173276.91it/s]

22%|██▏ | 1317838/6000000 [00:07<00:27, 172990.29it/s]

22%|██▏ | 1335138/6000000 [00:07<00:27, 172616.26it/s]

23%|██▎ | 1352524/6000000 [00:07<00:26, 172987.22it/s]

23%|██▎ | 1369860/6000000 [00:07<00:26, 173095.64it/s]

23%|██▎ | 1387170/6000000 [00:07<00:26, 172832.75it/s]

23%|██▎ | 1404512/6000000 [00:08<00:26, 173008.46it/s]

24%|██▎ | 1421837/6000000 [00:08<00:26, 173078.51it/s]

24%|██▍ | 1439146/6000000 [00:08<00:26, 173016.74it/s]

24%|██▍ | 1456585/6000000 [00:08<00:26, 173426.98it/s]

25%|██▍ | 1473936/6000000 [00:08<00:26, 173451.54it/s]

25%|██▍ | 1491282/6000000 [00:08<00:26, 172334.61it/s]

25%|██▌ | 1508518/6000000 [00:08<00:26, 172317.83it/s]

25%|██▌ | 1525751/6000000 [00:08<00:26, 170751.69it/s]

26%|██▌ | 1543001/6000000 [00:08<00:26, 171191.25it/s]

26%|██▌ | 1560123/6000000 [00:08<00:25, 171136.39it/s]

26%|██▋ | 1577239/6000000 [00:09<00:25, 171060.81it/s]

27%|██▋ | 1594347/6000000 [00:09<00:25, 170947.34it/s]

27%|██▋ | 1611443/6000000 [00:09<00:25, 170787.75it/s]

27%|██▋ | 1628691/6000000 [00:09<00:25, 171293.17it/s]

27%|██▋ | 1645917/6000000 [00:09<00:25, 171581.96it/s]

28%|██▊ | 1663160/6000000 [00:09<00:25, 171834.10it/s]

28%|██▊ | 1680477/6000000 [00:09<00:25, 172232.80it/s]

28%|██▊ | 1697791/6000000 [00:09<00:24, 172504.45it/s]

29%|██▊ | 1715042/6000000 [00:09<00:24, 172153.43it/s]

29%|██▉ | 1732551/6000000 [00:09<00:24, 173030.91it/s]

29%|██▉ | 1749855/6000000 [00:10<00:24, 172965.60it/s]

29%|██▉ | 1767158/6000000 [00:10<00:24, 172982.85it/s]

30%|██▉ | 1784613/6000000 [00:10<00:24, 173450.68it/s]

30%|███ | 1801959/6000000 [00:10<00:24, 172997.57it/s]

30%|███ | 1819260/6000000 [00:10<00:24, 172754.64it/s]

31%|███ | 1836536/6000000 [00:10<00:24, 172623.04it/s]

31%|███ | 1853799/6000000 [00:10<00:24, 172520.90it/s]

31%|███ | 1871183/6000000 [00:10<00:23, 172915.12it/s]

31%|███▏ | 1888482/6000000 [00:10<00:23, 172936.86it/s]

32%|███▏ | 1905853/6000000 [00:10<00:23, 173167.01it/s]

32%|███▏ | 1923170/6000000 [00:11<00:23, 172782.16it/s]

32%|███▏ | 1940648/6000000 [00:11<00:23, 173380.08it/s]

33%|███▎ | 1958097/6000000 [00:11<00:23, 173710.76it/s]

33%|███▎ | 1975563/6000000 [00:11<00:23, 173994.55it/s]

33%|███▎ | 1993100/6000000 [00:11<00:22, 174404.32it/s]

34%|███▎ | 2010541/6000000 [00:11<00:23, 172334.95it/s]

34%|███▍ | 2027781/6000000 [00:11<00:23, 170117.46it/s]

34%|███▍ | 2044802/6000000 [00:11<00:23, 168950.85it/s]

34%|███▍ | 2061704/6000000 [00:11<00:23, 167412.27it/s]

35%|███▍ | 2078451/6000000 [00:11<00:23, 166645.68it/s]

35%|███▍ | 2095119/6000000 [00:12<00:23, 166018.61it/s]

35%|███▌ | 2111839/6000000 [00:12<00:23, 166367.33it/s]

35%|███▌ | 2128478/6000000 [00:12<00:23, 166009.18it/s]

36%|███▌ | 2145080/6000000 [00:12<00:23, 165359.89it/s]

36%|███▌ | 2161617/6000000 [00:12<00:23, 165166.17it/s]

36%|███▋ | 2178135/6000000 [00:12<00:23, 164602.28it/s]

37%|███▋ | 2194789/6000000 [00:12<00:23, 165178.53it/s]

37%|███▋ | 2211308/6000000 [00:12<00:22, 165111.39it/s]

37%|███▋ | 2227820/6000000 [00:12<00:22, 165101.77it/s]

37%|███▋ | 2244331/6000000 [00:12<00:22, 165039.42it/s]

38%|███▊ | 2260836/6000000 [00:13<00:22, 164976.70it/s]

38%|███▊ | 2277334/6000000 [00:13<00:22, 164281.03it/s]

38%|███▊ | 2293794/6000000 [00:13<00:22, 164375.90it/s]

39%|███▊ | 2310233/6000000 [00:13<00:22, 164142.07it/s]

39%|███▉ | 2326813/6000000 [00:13<00:22, 164637.49it/s]

39%|███▉ | 2343278/6000000 [00:13<00:22, 163866.24it/s]

39%|███▉ | 2359871/6000000 [00:13<00:22, 164481.61it/s]

40%|███▉ | 2376343/6000000 [00:13<00:22, 164551.76it/s]

40%|███▉ | 2392799/6000000 [00:13<00:21, 164512.84it/s]

40%|████ | 2409251/6000000 [00:13<00:21, 164438.01it/s]

40%|████ | 2425696/6000000 [00:14<00:21, 164351.40it/s]

41%|████ | 2442132/6000000 [00:14<00:21, 163537.63it/s]

41%|████ | 2458671/6000000 [00:14<00:21, 164088.41it/s]

41%|████▏ | 2475081/6000000 [00:14<00:21, 164079.64it/s]

42%|████▏ | 2491490/6000000 [00:14<00:21, 163943.10it/s]

42%|████▏ | 2507941/6000000 [00:14<00:21, 164111.44it/s]

42%|████▏ | 2524353/6000000 [00:14<00:21, 163637.95it/s]

42%|████▏ | 2540718/6000000 [00:14<00:21, 162028.90it/s]

43%|████▎ | 2557149/6000000 [00:14<00:21, 162706.36it/s]

43%|████▎ | 2573530/6000000 [00:14<00:21, 163033.07it/s]

43%|████▎ | 2590036/6000000 [00:15<00:20, 163637.30it/s]

43%|████▎ | 2606667/6000000 [00:15<00:20, 164434.29it/s]

44%|████▎ | 2623113/6000000 [00:15<00:20, 164396.51it/s]

44%|████▍ | 2639633/6000000 [00:15<00:20, 164635.65it/s]

44%|████▍ | 2656098/6000000 [00:15<00:20, 164224.81it/s]

45%|████▍ | 2672579/6000000 [00:15<00:20, 164397.14it/s]

45%|████▍ | 2689037/6000000 [00:15<00:20, 164449.83it/s]

45%|████▌ | 2705483/6000000 [00:15<00:22, 146891.75it/s]

45%|████▌ | 2721968/6000000 [00:15<00:21, 151857.83it/s]

46%|████▌ | 2738440/6000000 [00:16<00:20, 155498.41it/s]

46%|████▌ | 2754815/6000000 [00:16<00:20, 157873.19it/s]

46%|████▌ | 2771361/6000000 [00:16<00:20, 160083.19it/s]

46%|████▋ | 2787984/6000000 [00:16<00:19, 161889.24it/s]

47%|████▋ | 2804529/6000000 [00:16<00:19, 162939.97it/s]

47%|████▋ | 2820887/6000000 [00:16<00:19, 162577.82it/s]

47%|████▋ | 2837496/6000000 [00:16<00:19, 163620.09it/s]

48%|████▊ | 2854001/6000000 [00:16<00:19, 164021.25it/s]

48%|████▊ | 2870426/6000000 [00:16<00:19, 163994.33it/s]

48%|████▊ | 2887148/6000000 [00:16<00:18, 164958.39it/s]

48%|████▊ | 2903807/6000000 [00:17<00:18, 165444.88it/s]

49%|████▊ | 2920588/6000000 [00:17<00:18, 166151.27it/s]

49%|████▉ | 2937209/6000000 [00:17<00:18, 165573.75it/s]

49%|████▉ | 2954001/6000000 [00:17<00:18, 166060.79it/s]

50%|████▉ | 2970611/6000000 [00:17<00:18, 165998.12it/s]

50%|████▉ | 2987256/6000000 [00:17<00:18, 166130.77it/s]

50%|█████ | 3004001/6000000 [00:17<00:18, 166349.02it/s]

50%|█████ | 3020705/6000000 [00:17<00:17, 166553.80it/s]

51%|█████ | 3037362/6000000 [00:17<00:17, 165449.80it/s]

51%|█████ | 3053910/6000000 [00:17<00:17, 163976.85it/s]

51%|█████ | 3070312/6000000 [00:18<00:17, 163961.04it/s]

51%|█████▏ | 3086711/6000000 [00:18<00:17, 163801.88it/s]

52%|█████▏ | 3103390/6000000 [00:18<00:17, 164691.60it/s]

52%|█████▏ | 3119950/6000000 [00:18<00:17, 164961.40it/s]

52%|█████▏ | 3136448/6000000 [00:18<00:17, 164695.56it/s]

53%|█████▎ | 3152919/6000000 [00:18<00:17, 164191.98it/s]

53%|█████▎ | 3169422/6000000 [00:18<00:17, 164439.99it/s]

53%|█████▎ | 3185970/6000000 [00:18<00:17, 164748.96it/s]

53%|█████▎ | 3202446/6000000 [00:18<00:16, 164612.04it/s]

54%|█████▎ | 3218908/6000000 [00:18<00:16, 164041.58it/s]

54%|█████▍ | 3235313/6000000 [00:19<00:16, 163768.66it/s]

54%|█████▍ | 3251797/6000000 [00:19<00:16, 164087.33it/s]

54%|█████▍ | 3268207/6000000 [00:19<00:16, 164076.19it/s]

55%|█████▍ | 3284615/6000000 [00:19<00:16, 163844.51it/s]

55%|█████▌ | 3301033/6000000 [00:19<00:16, 163940.99it/s]

55%|█████▌ | 3317428/6000000 [00:19<00:16, 163934.14it/s]

56%|█████▌ | 3333985/6000000 [00:19<00:16, 164422.34it/s]

56%|█████▌ | 3350446/6000000 [00:19<00:16, 164475.94it/s]

56%|█████▌ | 3367001/6000000 [00:19<00:15, 164702.41it/s]

56%|█████▋ | 3383472/6000000 [00:19<00:15, 164586.05it/s]

57%|█████▋ | 3400001/6000000 [00:20<00:15, 164705.97it/s]

57%|█████▋ | 3416472/6000000 [00:20<00:15, 164275.09it/s]

57%|█████▋ | 3432900/6000000 [00:20<00:15, 163683.67it/s]

57%|█████▋ | 3449269/6000000 [00:20<00:15, 163530.49it/s]

58%|█████▊ | 3465694/6000000 [00:20<00:15, 163744.53it/s]

58%|█████▊ | 3482069/6000000 [00:20<00:15, 163636.25it/s]

58%|█████▊ | 3498553/6000000 [00:20<00:15, 163993.78it/s]

59%|█████▊ | 3515001/6000000 [00:20<00:15, 163939.04it/s]

59%|█████▉ | 3531396/6000000 [00:20<00:15, 162740.83it/s]

59%|█████▉ | 3548001/6000000 [00:20<00:14, 163559.20it/s]

59%|█████▉ | 3564361/6000000 [00:21<00:14, 163568.86it/s]

60%|█████▉ | 3580808/6000000 [00:21<00:14, 163835.09it/s]

60%|█████▉ | 3597211/6000000 [00:21<00:14, 163891.70it/s]

60%|██████ | 3613755/6000000 [00:21<00:14, 164353.27it/s]

61%|██████ | 3630191/6000000 [00:21<00:14, 163875.85it/s]

61%|██████ | 3646674/6000000 [00:21<00:14, 164160.24it/s]

61%|██████ | 3663091/6000000 [00:21<00:14, 163691.30it/s]

61%|██████▏ | 3679582/6000000 [00:21<00:14, 164055.22it/s]

62%|██████▏ | 3696246/6000000 [00:21<00:13, 164828.70it/s]

62%|██████▏ | 3712730/6000000 [00:21<00:13, 164741.79it/s]

62%|██████▏ | 3729205/6000000 [00:22<00:13, 164105.48it/s]

62%|██████▏ | 3745617/6000000 [00:22<00:13, 163670.33it/s]

63%|██████▎ | 3761985/6000000 [00:22<00:13, 163396.49it/s]

63%|██████▎ | 3778326/6000000 [00:22<00:13, 163097.32it/s]

63%|██████▎ | 3794637/6000000 [00:22<00:13, 163049.83it/s]

64%|██████▎ | 3810943/6000000 [00:22<00:13, 162553.79it/s]

64%|██████▍ | 3827265/6000000 [00:22<00:13, 162750.73it/s]

64%|██████▍ | 3843552/6000000 [00:22<00:13, 162785.40it/s]

64%|██████▍ | 3859898/6000000 [00:22<00:13, 162986.22it/s]

65%|██████▍ | 3876549/6000000 [00:22<00:12, 164039.20it/s]

65%|██████▍ | 3893001/6000000 [00:23<00:12, 164081.69it/s]

65%|██████▌ | 3909617/6000000 [00:23<00:12, 164702.95it/s]

65%|██████▌ | 3926088/6000000 [00:23<00:12, 164466.04it/s]

66%|██████▌ | 3942789/6000000 [00:23<00:12, 165226.16it/s]

66%|██████▌ | 3959312/6000000 [00:23<00:12, 165035.05it/s]

66%|██████▋ | 3975816/6000000 [00:23<00:12, 164940.59it/s]

67%|██████▋ | 3992476/6000000 [00:23<00:12, 165436.75it/s]

67%|██████▋ | 4009238/6000000 [00:23<00:11, 166088.57it/s]

67%|██████▋ | 4026048/6000000 [00:23<00:11, 166689.21it/s]

67%|██████▋ | 4042840/6000000 [00:23<00:11, 167056.02it/s]

68%|██████▊ | 4059570/6000000 [00:24<00:11, 167126.72it/s]

68%|██████▊ | 4076501/6000000 [00:24<00:11, 167780.85it/s]

68%|██████▊ | 4093539/6000000 [00:24<00:11, 168558.56it/s]

69%|██████▊ | 4110588/6000000 [00:24<00:11, 169136.14it/s]

69%|██████▉ | 4127502/6000000 [00:24<00:11, 168774.27it/s]

69%|██████▉ | 4144380/6000000 [00:24<00:11, 168196.97it/s]

69%|██████▉ | 4161277/6000000 [00:24<00:10, 168426.19it/s]

70%|██████▉ | 4178294/6000000 [00:24<00:10, 168946.76it/s]

70%|██████▉ | 4195190/6000000 [00:24<00:10, 168526.97it/s]

70%|███████ | 4212243/6000000 [00:24<00:10, 169124.21it/s]

70%|███████ | 4229182/6000000 [00:25<00:10, 169202.26it/s]

71%|███████ | 4246103/6000000 [00:25<00:10, 169119.35it/s]

71%|███████ | 4263016/6000000 [00:25<00:10, 168711.07it/s]

71%|███████▏ | 4280026/6000000 [00:25<00:10, 169125.60it/s]

72%|███████▏ | 4296939/6000000 [00:25<00:10, 169088.94it/s]

72%|███████▏ | 4313849/6000000 [00:25<00:10, 167669.90it/s]

72%|███████▏ | 4330906/6000000 [00:25<00:09, 168531.44it/s]

72%|███████▏ | 4347924/6000000 [00:25<00:09, 169023.16it/s]

73%|███████▎ | 4364922/6000000 [00:25<00:09, 169307.83it/s]

73%|███████▎ | 4381855/6000000 [00:25<00:09, 168867.73it/s]

73%|███████▎ | 4398904/6000000 [00:26<00:09, 169352.26it/s]

74%|███████▎ | 4415841/6000000 [00:26<00:09, 169284.90it/s]

74%|███████▍ | 4432771/6000000 [00:26<00:09, 168679.87it/s]

74%|███████▍ | 4449782/6000000 [00:26<00:09, 169106.66it/s]

74%|███████▍ | 4466730/6000000 [00:26<00:09, 169218.04it/s]

75%|███████▍ | 4483653/6000000 [00:26<00:08, 168901.42it/s]

75%|███████▌ | 4500622/6000000 [00:26<00:08, 169135.72it/s]

75%|███████▌ | 4517538/6000000 [00:26<00:08, 169141.59it/s]

76%|███████▌ | 4534595/6000000 [00:26<00:08, 169569.47it/s]

76%|███████▌ | 4551553/6000000 [00:26<00:08, 169400.66it/s]

76%|███████▌ | 4568494/6000000 [00:27<00:08, 169073.69it/s]

76%|███████▋ | 4585402/6000000 [00:27<00:08, 168299.43it/s]

77%|███████▋ | 4602233/6000000 [00:27<00:08, 168159.60it/s]

77%|███████▋ | 4619235/6000000 [00:27<00:08, 168714.46it/s]

77%|███████▋ | 4636232/6000000 [00:27<00:08, 169089.27it/s]

78%|███████▊ | 4653142/6000000 [00:27<00:07, 168546.53it/s]

78%|███████▊ | 4670001/6000000 [00:27<00:07, 168549.60it/s]

78%|███████▊ | 4686857/6000000 [00:27<00:07, 168413.27it/s]

78%|███████▊ | 4703707/6000000 [00:27<00:07, 168437.71it/s]

79%|███████▊ | 4720722/6000000 [00:27<00:07, 168948.53it/s]

79%|███████▉ | 4737618/6000000 [00:28<00:07, 167199.82it/s]

79%|███████▉ | 4754590/6000000 [00:28<00:07, 167948.36it/s]

80%|███████▉ | 4771415/6000000 [00:28<00:07, 168037.55it/s]

80%|███████▉ | 4788348/6000000 [00:28<00:07, 168421.02it/s]

80%|████████ | 4805365/6000000 [00:28<00:07, 168942.00it/s]

80%|████████ | 4822390/6000000 [00:28<00:06, 169333.55it/s]

81%|████████ | 4839600/6000000 [00:28<00:06, 170162.02it/s]

81%|████████ | 4856618/6000000 [00:28<00:06, 169559.11it/s]

81%|████████ | 4873576/6000000 [00:28<00:06, 169564.80it/s]

82%|████████▏ | 4890583/6000000 [00:28<00:06, 169714.98it/s]

82%|████████▏ | 4907556/6000000 [00:29<00:06, 169243.33it/s]

82%|████████▏ | 4924481/6000000 [00:29<00:06, 169184.45it/s]

82%|████████▏ | 4941400/6000000 [00:29<00:06, 168929.89it/s]

83%|████████▎ | 4958386/6000000 [00:29<00:06, 169207.73it/s]

83%|████████▎ | 4975324/6000000 [00:29<00:06, 169259.00it/s]

83%|████████▎ | 4992251/6000000 [00:29<00:05, 169012.68it/s]

83%|████████▎ | 5009172/6000000 [00:29<00:05, 169070.96it/s]

84%|████████▍ | 5026314/6000000 [00:29<00:05, 169773.05it/s]

84%|████████▍ | 5043292/6000000 [00:29<00:05, 169130.55it/s]

84%|████████▍ | 5060206/6000000 [00:29<00:05, 168998.25it/s]

85%|████████▍ | 5077107/6000000 [00:30<00:05, 168427.29it/s]

85%|████████▍ | 5094080/6000000 [00:30<00:05, 168814.26it/s]

85%|████████▌ | 5111358/6000000 [00:30<00:05, 169998.40it/s]

85%|████████▌ | 5128359/6000000 [00:30<00:05, 169349.62it/s]

86%|████████▌ | 5145519/6000000 [00:30<00:05, 170019.39it/s]

86%|████████▌ | 5162522/6000000 [00:30<00:04, 169429.84it/s]

86%|████████▋ | 5179466/6000000 [00:30<00:04, 169327.91it/s]

87%|████████▋ | 5196400/6000000 [00:30<00:04, 169326.99it/s]

87%|████████▋ | 5213334/6000000 [00:30<00:04, 168790.12it/s]

87%|████████▋ | 5230293/6000000 [00:30<00:04, 169028.44it/s]

87%|████████▋ | 5247303/6000000 [00:31<00:04, 169346.40it/s]

88%|████████▊ | 5264408/6000000 [00:31<00:04, 169855.30it/s]

88%|████████▊ | 5281411/6000000 [00:31<00:04, 169907.44it/s]

88%|████████▊ | 5298403/6000000 [00:31<00:04, 169404.66it/s]

89%|████████▊ | 5315400/6000000 [00:31<00:04, 169571.79it/s]

89%|████████▉ | 5332358/6000000 [00:31<00:03, 169425.29it/s]

89%|████████▉ | 5349301/6000000 [00:31<00:03, 169231.17it/s]

89%|████████▉ | 5366246/6000000 [00:31<00:03, 169295.63it/s]

90%|████████▉ | 5383260/6000000 [00:31<00:03, 169548.36it/s]

90%|█████████ | 5400215/6000000 [00:31<00:03, 169421.07it/s]

90%|█████████ | 5417183/6000000 [00:32<00:03, 169497.13it/s]

91%|█████████ | 5434133/6000000 [00:32<00:03, 168485.94it/s]

91%|█████████ | 5450983/6000000 [00:32<00:03, 168456.97it/s]

91%|█████████ | 5468210/6000000 [00:32<00:03, 169595.33it/s]

91%|█████████▏| 5485171/6000000 [00:32<00:03, 169207.41it/s]

92%|█████████▏| 5502093/6000000 [00:32<00:02, 168365.55it/s]

92%|█████████▏| 5518999/6000000 [00:32<00:02, 168571.76it/s]

92%|█████████▏| 5536085/6000000 [00:32<00:02, 169255.10it/s]

93%|█████████▎| 5553012/6000000 [00:32<00:02, 168551.21it/s]

93%|█████████▎| 5570030/6000000 [00:32<00:02, 169036.99it/s]

93%|█████████▎| 5587070/6000000 [00:33<00:02, 169442.35it/s]

93%|█████████▎| 5604016/6000000 [00:33<00:02, 168946.70it/s]

94%|█████████▎| 5620912/6000000 [00:33<00:02, 168863.73it/s]

94%|█████████▍| 5637799/6000000 [00:33<00:02, 168756.50it/s]

94%|█████████▍| 5654888/6000000 [00:33<00:02, 169392.54it/s]

95%|█████████▍| 5671834/6000000 [00:33<00:01, 169410.01it/s]

95%|█████████▍| 5688776/6000000 [00:33<00:01, 169369.25it/s]

95%|█████████▌| 5705779/6000000 [00:33<00:01, 169567.19it/s]

95%|█████████▌| 5722754/6000000 [00:33<00:01, 169619.98it/s]

96%|█████████▌| 5739905/6000000 [00:33<00:01, 170184.79it/s]

96%|█████████▌| 5756924/6000000 [00:34<00:01, 169321.05it/s]

96%|█████████▌| 5774088/6000000 [00:34<00:01, 170012.91it/s]

97%|█████████▋| 5791120/6000000 [00:34<00:01, 170102.65it/s]

97%|█████████▋| 5808131/6000000 [00:34<00:01, 170058.89it/s]

97%|█████████▋| 5825138/6000000 [00:34<00:01, 169008.30it/s]

97%|█████████▋| 5842041/6000000 [00:34<00:00, 168405.15it/s]

98%|█████████▊| 5858883/6000000 [00:34<00:00, 168209.38it/s]

98%|█████████▊| 5875859/6000000 [00:34<00:00, 168669.53it/s]

98%|█████████▊| 5892727/6000000 [00:34<00:00, 168562.31it/s]

98%|█████████▊| 5909584/6000000 [00:34<00:00, 168209.02it/s]

99%|█████████▉| 5926631/6000000 [00:35<00:00, 168883.35it/s]

99%|█████████▉| 5943520/6000000 [00:35<00:00, 168436.07it/s]

99%|█████████▉| 5960365/6000000 [00:35<00:00, 168357.14it/s]

100%|█████████▉| 5977249/6000000 [00:35<00:00, 168500.78it/s]

100%|█████████▉| 5994314/6000000 [00:35<00:00, 169142.84it/s]

100%|██████████| 6000000/6000000 [00:35<00:00, 168910.64it/s]

colors = {
    'ε = 0.1': 'tab:blue',
    'ε = 0.01': 'tab:red',
    'ε = 0': 'tab:green'
}

(
    trace_df
    .groupby(['step', 'policy'])
    ['is_action_optimal'].mean()
    .unstack()
    .plot()
)


<Axes: xlabel='step'>


png

Handling drift

The environment used above is a toy situation used for introducing bandits. It is stationary, meaning that the expected reward of each arm does not change over time.

In practice, arms are dynamic, and their performance can vary over time. A simple example of this is the Candy Cane Contest that was hosted on Kaggle in 2020. The expected reward of each arm diminishes each time it is pulled.

The way bandit policies in River deal with drift depends on the method. For the bandit.EpsilonGreedy policy, it makes sense to use a rolling average as the reward object. What this means is that the empirical reward the policy calculates for each arm is a rolling average, rather than a global one.

from river import proba, utils

policies=[
    bandit.EpsilonGreedy(
        epsilon=0.1,
        seed=42
    ),
    bandit.EpsilonGreedy(
        epsilon=0.3,
        reward_obj=utils.Rolling(stats.Mean(), window_size=50),
        seed=42
    ),
    bandit.ThompsonSampling(
        reward_obj=proba.Beta(),
        seed=42
    )
]

env = gym.make('river_bandits/CandyCaneContest-v0')

trace = bandit.evaluate(
    policies=policies,
    env=env,
    n_episodes=(n_episodes := 30),
    seed=42
)

trace_df = pd.DataFrame(tqdm(
    trace, position=0, total=(
        n_episodes *
        len(policies) *
        env._max_episode_steps
    )
))

0%| | 0/180000 [00:00<?, ?it/s]

1%| | 1265/180000 [00:00<00:14, 12648.90it/s]

1%|▏ | 2530/180000 [00:00<00:14, 12455.50it/s]

2%|▏ | 3809/180000 [00:00<00:13, 12605.82it/s]

3%|▎ | 5094/180000 [00:00<00:13, 12701.33it/s]

4%|▎ | 6365/180000 [00:00<00:14, 12401.37it/s]

4%|▍ | 7686/180000 [00:00<00:13, 12668.29it/s]

5%|▌ | 9003/180000 [00:00<00:13, 12829.82it/s]

6%|▌ | 10306/180000 [00:00<00:13, 12892.25it/s]

6%|▋ | 11600/180000 [00:00<00:13, 12906.67it/s]

7%|▋ | 12892/180000 [00:01<00:13, 12651.35it/s]

8%|▊ | 14159/180000 [00:01<00:13, 12568.17it/s]

9%|▊ | 15417/180000 [00:01<00:13, 12521.25it/s]

9%|▉ | 16722/180000 [00:01<00:12, 12677.30it/s]

10%|▉ | 17991/180000 [00:01<00:12, 12570.86it/s]

11%|█ | 19263/180000 [00:01<00:12, 12612.01it/s]

11%|█▏ | 20575/180000 [00:01<00:12, 12762.27it/s]

12%|█▏ | 21882/180000 [00:01<00:12, 12853.92it/s]

13%|█▎ | 23168/180000 [00:01<00:12, 12659.23it/s]

14%|█▎ | 24435/180000 [00:01<00:12, 12455.11it/s]

14%|█▍ | 25758/180000 [00:02<00:12, 12678.03it/s]

15%|█▌ | 27057/180000 [00:02<00:11, 12765.28it/s]

16%|█▌ | 28335/180000 [00:02<00:12, 12588.44it/s]

16%|█▋ | 29595/180000 [00:02<00:12, 12442.13it/s]

17%|█▋ | 30841/180000 [00:02<00:12, 12282.95it/s]

18%|█▊ | 32146/180000 [00:02<00:11, 12507.44it/s]

19%|█▊ | 33457/180000 [00:02<00:11, 12683.96it/s]

19%|█▉ | 34758/180000 [00:02<00:11, 12777.23it/s]

20%|██ | 36037/180000 [00:02<00:11, 12544.63it/s]

21%|██ | 37416/180000 [00:02<00:11, 12909.30it/s]

22%|██▏ | 38745/180000 [00:03<00:10, 13021.19it/s]

22%|██▏ | 40077/180000 [00:03<00:10, 13109.81it/s]

23%|██▎ | 41403/180000 [00:03<00:10, 13151.57it/s]

24%|██▎ | 42719/180000 [00:03<00:10, 13034.56it/s]

24%|██▍ | 44035/180000 [00:03<00:10, 13071.25it/s]

25%|██▌ | 45343/180000 [00:03<00:10, 13050.98it/s]

26%|██▌ | 46650/180000 [00:03<00:10, 13049.17it/s]

27%|██▋ | 47961/180000 [00:03<00:10, 13066.53it/s]

27%|██▋ | 49268/180000 [00:03<00:10, 12932.90it/s]

28%|██▊ | 50562/180000 [00:03<00:10, 12919.98it/s]

29%|██▉ | 51860/180000 [00:04<00:09, 12937.77it/s]

30%|██▉ | 53160/180000 [00:04<00:09, 12951.85it/s]

30%|███ | 54456/180000 [00:04<00:09, 12630.73it/s]

31%|███ | 55763/180000 [00:04<00:09, 12758.59it/s]

32%|███▏ | 57066/180000 [00:04<00:09, 12838.38it/s]

32%|███▏ | 58368/180000 [00:04<00:09, 12892.19it/s]

33%|███▎ | 59667/180000 [00:04<00:09, 12916.13it/s]

34%|███▍ | 60960/180000 [00:04<00:09, 12829.28it/s]

35%|███▍ | 62244/180000 [00:04<00:09, 12743.41it/s]

35%|███▌ | 63561/180000 [00:04<00:09, 12864.21it/s]

36%|███▌ | 64878/180000 [00:05<00:08, 12954.23it/s]

37%|███▋ | 66174/180000 [00:05<00:09, 12589.16it/s]

38%|███▊ | 67515/180000 [00:05<00:08, 12824.95it/s]

38%|███▊ | 68823/180000 [00:05<00:08, 12894.47it/s]

39%|███▉ | 70128/180000 [00:05<00:08, 12935.82it/s]

40%|███▉ | 71427/180000 [00:05<00:08, 12949.47it/s]

40%|████ | 72723/180000 [00:05<00:08, 12689.49it/s]

41%|████ | 74026/180000 [00:05<00:08, 12788.94it/s]

42%|████▏ | 75327/180000 [00:05<00:08, 12849.70it/s]

43%|████▎ | 76613/180000 [00:05<00:08, 12839.86it/s]

43%|████▎ | 77917/180000 [00:06<00:07, 12899.12it/s]

44%|████▍ | 79208/180000 [00:06<00:07, 12830.83it/s]

45%|████▍ | 80511/180000 [00:06<00:07, 12888.83it/s]

45%|████▌ | 81806/180000 [00:06<00:07, 12906.42it/s]

46%|████▌ | 83097/180000 [00:06<00:07, 12863.03it/s]

47%|████▋ | 84384/180000 [00:06<00:07, 12752.46it/s]

48%|████▊ | 85705/180000 [00:06<00:07, 12887.38it/s]

48%|████▊ | 87009/180000 [00:06<00:07, 12927.40it/s]

49%|████▉ | 88303/180000 [00:06<00:07, 12690.55it/s]

50%|████▉ | 89574/180000 [00:07<00:07, 12529.10it/s]

50%|█████ | 90828/180000 [00:07<00:07, 12377.71it/s]

51%|█████ | 92139/180000 [00:07<00:06, 12591.80it/s]

52%|█████▏ | 93453/180000 [00:07<00:06, 12750.55it/s]

53%|█████▎ | 94755/180000 [00:07<00:06, 12825.95it/s]

53%|█████▎ | 96039/180000 [00:07<00:06, 12558.12it/s]

54%|█████▍ | 97341/180000 [00:07<00:06, 12691.80it/s]

55%|█████▍ | 98658/180000 [00:07<00:06, 12828.62it/s]

56%|█████▌ | 99943/180000 [00:07<00:06, 12657.83it/s]

56%|█████▌ | 101241/180000 [00:07<00:06, 12750.55it/s]

57%|█████▋ | 102518/180000 [00:08<00:06, 12685.62it/s]

58%|█████▊ | 103836/180000 [00:08<00:05, 12825.94it/s]

58%|█████▊ | 105121/180000 [00:08<00:05, 12832.45it/s]

59%|█████▉ | 106419/180000 [00:08<00:05, 12873.99it/s]

60%|█████▉ | 107712/180000 [00:08<00:05, 12887.95it/s]

61%|██████ | 109002/180000 [00:08<00:05, 12795.41it/s]

61%|██████▏ | 110304/180000 [00:08<00:05, 12859.09it/s]

62%|██████▏ | 111606/180000 [00:08<00:05, 12905.77it/s]

63%|██████▎ | 112923/180000 [00:08<00:05, 12981.88it/s]

63%|██████▎ | 114222/180000 [00:08<00:05, 12732.39it/s]

64%|██████▍ | 115521/180000 [00:09<00:05, 12807.12it/s]

65%|██████▍ | 116835/180000 [00:09<00:04, 12905.49it/s]

66%|██████▌ | 118143/180000 [00:09<00:04, 12955.07it/s]

66%|██████▋ | 119453/180000 [00:09<00:04, 12997.53it/s]

67%|██████▋ | 120754/180000 [00:09<00:04, 12711.38it/s]

68%|██████▊ | 122081/180000 [00:09<00:04, 12874.91it/s]

69%|██████▊ | 123405/180000 [00:09<00:04, 12978.06it/s]

69%|██████▉ | 124728/180000 [00:09<00:04, 13050.45it/s]

70%|███████ | 126034/180000 [00:09<00:04, 12771.70it/s]

71%|███████ | 127420/180000 [00:09<00:04, 13091.54it/s]

72%|███████▏ | 128732/180000 [00:10<00:03, 13063.86it/s]

72%|███████▏ | 130040/180000 [00:10<00:03, 12812.51it/s]

73%|███████▎ | 131334/180000 [00:10<00:03, 12844.36it/s]

74%|███████▎ | 132620/180000 [00:10<00:03, 12795.77it/s]

74%|███████▍ | 133921/180000 [00:10<00:03, 12858.06it/s]

75%|███████▌ | 135210/180000 [00:10<00:03, 12867.26it/s]

76%|███████▌ | 136506/180000 [00:10<00:03, 12888.26it/s]

77%|███████▋ | 137796/180000 [00:10<00:03, 12687.10it/s]

77%|███████▋ | 139066/180000 [00:10<00:03, 12630.10it/s]

78%|███████▊ | 140365/180000 [00:10<00:03, 12736.27it/s]

79%|███████▊ | 141667/180000 [00:11<00:02, 12820.38it/s]

79%|███████▉ | 142965/180000 [00:11<00:02, 12863.04it/s]

80%|████████ | 144252/180000 [00:11<00:02, 12523.85it/s]

81%|████████ | 145587/180000 [00:11<00:02, 12762.62it/s]

82%|████████▏ | 146866/180000 [00:11<00:02, 12692.60it/s]

82%|████████▏ | 148185/180000 [00:11<00:02, 12836.84it/s]

83%|████████▎ | 149470/180000 [00:11<00:02, 12818.82it/s]

84%|████████▍ | 150753/180000 [00:11<00:02, 12764.74it/s]

84%|████████▍ | 152072/180000 [00:11<00:02, 12890.82it/s]

85%|████████▌ | 153381/180000 [00:11<00:02, 12946.77it/s]

86%|████████▌ | 154683/180000 [00:12<00:01, 12966.52it/s]

87%|████████▋ | 155980/180000 [00:12<00:01, 12709.75it/s]

87%|████████▋ | 157253/180000 [00:12<00:01, 12692.43it/s]

88%|████████▊ | 158524/180000 [00:12<00:01, 12527.25it/s]

89%|████████▉ | 159778/180000 [00:12<00:01, 12529.51it/s]

89%|████████▉ | 161064/180000 [00:12<00:01, 12621.70it/s]

90%|█████████ | 162327/180000 [00:12<00:01, 12289.27it/s]

91%|█████████ | 163657/180000 [00:12<00:01, 12583.76it/s]

92%|█████████▏| 164967/180000 [00:12<00:01, 12731.20it/s]

92%|█████████▏| 166266/180000 [00:13<00:01, 12806.50it/s]

93%|█████████▎| 167549/180000 [00:13<00:00, 12621.03it/s]

94%|█████████▍| 168813/180000 [00:13<00:00, 12381.29it/s]

94%|█████████▍| 170085/180000 [00:13<00:00, 12477.91it/s]

95%|█████████▌| 171393/180000 [00:13<00:00, 12649.24it/s]

96%|█████████▌| 172660/180000 [00:13<00:00, 12448.24it/s]

97%|█████████▋| 173907/180000 [00:13<00:00, 12366.06it/s]

97%|█████████▋| 175145/180000 [00:13<00:00, 12179.00it/s]

98%|█████████▊| 176457/180000 [00:13<00:00, 12448.10it/s]

99%|█████████▊| 177738/180000 [00:13<00:00, 12549.00it/s]

99%|█████████▉| 179052/180000 [00:14<00:00, 12720.05it/s]

100%|██████████| 180000/180000 [00:14<00:00, 12761.56it/s]

We can compare the performance of each policy by checking the average reward at the end of each episode.

(
    trace_df
    .groupby(['policy_idx', 'episode'])
    .last()
    .groupby('policy_idx')
    .reward_stat.mean()
)


policy_idx
0    736.1
1    817.0
2    854.0
Name: reward_stat, dtype: float64

We see that using a rolling average gives a boost to the epsilon greedy strategy. However, we see that the bandit.ThompsonSampling policy performs even better, even though no particular care was given to drift. A natural next step would thus be to see how it could be improved to handle drift. For instance, its dist parameter could be wrapped with a utils.Rolling:

policy = bandit.ThompsonSampling(
    reward_obj=utils.Rolling(proba.Beta(), window_size=50),
    seed=42
)

Bandits can be used for several tasks. They can be used for content personalization, as well as online model selection (see model_selection.BanditRegressor). The policies in River are therefore designed to be flexible, so that they can be used in conjunction with other River modules. For instance, the reward_obj in bandit.EpsilonGreedy can be a metric, a probability distribution, or a statistic. This works because objects in River adher to a coherent get/update interface.