

Thompson sampling.

Thompson sampling is often used with a Beta distribution. However, any probability distribution can be used, as long as it makes sense with the reward shape. For instance, a Beta distribution is meant to be used with binary rewards, while a Gaussian distribution is meant to be used with continuous rewards.
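To make the Beta case concrete, here is a minimal standard-library sketch of Thompson sampling with binary rewards. It does not use River; the function name `thompson_bernoulli`, the simulated arm probabilities, and the step count are all assumptions chosen for illustration.

```python
import random

def thompson_bernoulli(true_probs, n_steps, seed=42):
    """Run Beta-Bernoulli Thompson sampling on simulated arms (illustrative only)."""
    rng = random.Random(seed)
    n_arms = len(true_probs)
    # Beta(1, 1) prior per arm: alpha tracks successes, beta tracks failures.
    alpha = [1] * n_arms
    beta = [1] * n_arms
    pulls = [0] * n_arms
    for _ in range(n_steps):
        # Draw one plausible success rate per arm from its current Beta distribution.
        samples = [rng.betavariate(a, b) for a, b in zip(alpha, beta)]
        arm = max(range(n_arms), key=samples.__getitem__)
        # Simulate a binary reward from the chosen arm.
        reward = 1 if rng.random() < true_probs[arm] else 0
        alpha[arm] += reward
        beta[arm] += 1 - reward
        pulls[arm] += 1
    return pulls

# Hypothetical arms with success rates 0.2, 0.5 and 0.8.
pulls = thompson_bernoulli([0.2, 0.5, 0.8], n_steps=2000)
```

In a run like this, most pulls should end up on the arm with the highest success rate, because its Beta distribution concentrates around a larger value as evidence accumulates.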

The randomness of a distribution is controlled by its seed. The seed should not be set within the distribution, but should rather be defined in the policy parametrization. In other words, you should do this:

policy = ThompsonSampling(dist=proba.Beta(1, 1), seed=42) 

and not this:

policy = ThompsonSampling(dist=proba.Beta(1, 1, seed=42)) 
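One plausible reason for this rule, stated here as an assumption rather than something the text confirms, is that a seed baked into the distribution would be shared by every arm's copy. The standard-library sketch below shows the effect with plain `random.Random` objects; `draw_per_arm` is a hypothetical helper.

```python
import random

# Two arms whose samplers carry the same seed draw identical values,
# so sampling can never tell the arms apart.
arm_a = random.Random(42)
arm_b = random.Random(42)
same_seed_draws = [
    (arm_a.betavariate(1, 1), arm_b.betavariate(1, 1)) for _ in range(5)
]
all_identical = all(a == b for a, b in same_seed_draws)

# A single generator owned by the policy gives each arm a different draw
# while keeping the whole run reproducible.
def draw_per_arm(seed, n_arms=2):
    rng = random.Random(seed)
    return [rng.betavariate(1, 1) for _ in range(n_arms)]

first_run = draw_per_arm(42)
second_run = draw_per_arm(42)
```

Here `all_identical` is true, while `first_run` equals `second_run` and still contains distinct values per arm.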


  • dist


    A distribution to sample from.

  • burn_in


    The number of steps to use for the burn-in phase. Each arm is given the chance to be pulled during the burn-in phase. This is useful to mitigate selection bias.

  • seed

    Type: int | None


    Random number generator seed for reproducibility.


  • dist

  • ranking

    Return the list of arms in descending order of performance.
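As a rough illustration of what such a ranking could look like, the sketch below orders arms by average observed reward. Treating "performance" as the mean reward is an assumption, and the `ranking` function and its arguments are hypothetical rather than River's implementation.

```python
def ranking(total_reward, counts):
    """Return arm IDs sorted by mean reward, best first (illustrative only)."""
    def mean(arm):
        # Arms that were never pulled are treated as having a mean of 0.
        return total_reward[arm] / counts[arm] if counts[arm] else 0.0
    return sorted(counts, key=mean, reverse=True)

# Hypothetical totals: "a" earned 3 over 10 pulls, "b" earned 8 over 10, "c" was never pulled.
order = ranking({"a": 3, "b": 8, "c": 0}, {"a": 10, "b": 10, "c": 0})
```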


import gym
from river import bandit
from river import proba
from river import stats

env = gym.make(
_ = env.reset(seed=42)
_ = env.action_space.seed(123)

policy = bandit.ThompsonSampling(dist=proba.Beta(), seed=101)

metric = stats.Sum()
while True:
    action = next(policy.pull(range(env.action_space.n)))
    observation, reward, terminated, truncated, info = env.step(action)
    policy = policy.update(action, reward)
    metric = metric.update(reward)
    if terminated or truncated:
        break

Sum: 820.



Pull arm(s).

This method is a generator that yields the arm(s) that should be pulled. During the burn-in phase, all the arms that have not been pulled enough are yielded. Once the burn-in phase is over, the policy is allowed to choose the arm(s) that should be pulled. If you only want to pull one arm at a time during the burn-in phase, simply call next(policy.pull(arms)).


  • arm_ids ('list[ArmID]')
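The generator behaviour described above can be sketched in plain Python. This is an illustration of the described semantics, not River's source: the standalone `pull` function, the `counts` dictionary, and the `choose` callback standing in for the sampling step are all hypothetical.

```python
import random

def pull(arm_ids, counts, burn_in, choose):
    """Yield under-pulled arms during burn-in, else the policy's own choice."""
    in_burn_in = False
    for arm_id in arm_ids:
        if counts.get(arm_id, 0) < burn_in:
            in_burn_in = True
            yield arm_id
    if not in_burn_in:
        yield choose(arm_ids)

arms = ["a", "b", "c"]
rng = random.Random(0)

# Only "b" has not reached the burn-in threshold, so only "b" is yielded.
during = list(pull(arms, {"a": 1, "b": 0, "c": 1}, burn_in=1, choose=rng.choice))

# next() takes just the first under-pulled arm.
first = next(pull(arms, {}, burn_in=1, choose=rng.choice))

# With burn-in complete, exactly one arm is chosen by the policy.
after = list(pull(arms, {"a": 1, "b": 1, "c": 1}, burn_in=1, choose=rng.choice))
```

Because the function is a generator, wrapping it in `next(...)` stops after the first yielded arm, which is what makes the one-arm-at-a-time pattern work.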


Update an arm's state.


  • arm_id
  • reward_args
  • reward_kwargs
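For continuous rewards, updating an arm's state could amount to maintaining a running Gaussian estimate of its rewards. The sketch below does this with Welford's algorithm using only the standard library; the `GaussianArm` class is hypothetical and is not how River stores arm state.

```python
import math
import random

class GaussianArm:
    """Running Gaussian estimate of one arm's reward (Welford's algorithm)."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, reward):
        self.n += 1
        delta = reward - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (reward - self.mean)

    @property
    def sigma(self):
        # Sample standard deviation; fall back to 1.0 until two rewards are seen.
        return math.sqrt(self.m2 / (self.n - 1)) if self.n > 1 else 1.0

    def sample(self, rng):
        # Draw a plausible reward from the fitted Gaussian.
        return rng.gauss(self.mean, self.sigma)

arm = GaussianArm()
for reward in [2.0, 4.0, 6.0]:
    arm.update(reward)

draw = arm.sample(random.Random(7))
```

After the three updates the arm's mean is 4.0 and its standard deviation is 2.0; a policy would compare one such draw per arm and pull the arm with the largest value.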