
pure_inference_mode

A context manager for making inferences with no side-effects.

Calling predict_one on a pipeline updates the pipeline's unsupervised steps. This is the expected behavior in online machine learning. However, in some cases you may only want to produce predictions without updating anything.
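To make "updating the unsupervised steps" concrete, here is a minimal sketch in plain Python. OnlineScaler and pipeline_predict_one are hypothetical stand-ins, not River classes. The scaler keeps a running mean and variance using Welford's incremental update, and the prediction function follows the same call order that the logs in the examples show (learn_one, then transform_one, then predict). River's StandardScaler may compute its statistics differently.

```python
import math


class OnlineScaler:
    """Hypothetical stand-in for an online standard scaler (not River's class)."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations (Welford)

    def learn_one(self, x):
        # Welford's incremental update of the mean and variance
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    def transform_one(self, x):
        var = self.m2 / self.n if self.n else 0.0
        return (x - self.mean) / math.sqrt(var) if var > 0 else 0.0


def pipeline_predict_one(scaler, weight, x):
    # Hypothetical pipeline: the unsupervised step learns from x BEFORE transforming it
    scaler.learn_one(x)
    return weight * scaler.transform_one(x)


scaler = OnlineScaler()
for x in [1.0, 2.0, 3.0]:
    pipeline_predict_one(scaler, weight=0.5, x=x)

# Only predictions were requested, yet the scaler's state has changed
print(scaler.n, scaler.mean)  # 3 2.0
```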

This context manager overrides that behavior: while it is active, unsupervised estimators are not updated when predict_one is called.

Examples

Let's first see what methods are called if we just call predict_one.

>>> import io
>>> import logging
>>> from river import compose
>>> from river import datasets
>>> from river import linear_model
>>> from river import preprocessing
>>> from river import utils

>>> model = compose.Pipeline(
...     preprocessing.StandardScaler(),
...     linear_model.LinearRegression()
... )

>>> class_condition = lambda x: x.__class__.__name__ in ('StandardScaler', 'LinearRegression')

>>> logger = logging.getLogger()
>>> logger.setLevel(logging.DEBUG)

>>> logs = io.StringIO()
>>> sh = logging.StreamHandler(logs)
>>> sh.setLevel(logging.DEBUG)
>>> logger.addHandler(sh)

>>> with utils.log_method_calls(class_condition):
...     for x, y in datasets.TrumpApproval().take(1):
...         _ = model.predict_one(x)

>>> print(logs.getvalue())
StandardScaler.learn_one
StandardScaler.transform_one
LinearRegression.predict_one

Now let's use the context manager and see what methods get called.

>>> logger.removeHandler(sh)
>>> logs = io.StringIO()
>>> sh = logging.StreamHandler(logs)
>>> sh.setLevel(logging.DEBUG)
>>> logger.addHandler(sh)

>>> with utils.log_method_calls(class_condition), utils.pure_inference_mode():
...     for x, y in datasets.TrumpApproval().take(1):
...         _ = model.predict_one(x)

>>> print(logs.getvalue())
StandardScaler.transform_one
LinearRegression.predict_one

We can see that the scaler was not updated before transforming the data: StandardScaler.learn_one no longer appears in the log.