
pure_inference_mode

A context manager for making inferences with no side-effects.

Calling predict_one on a pipeline updates the pipeline's unsupervised steps (for instance a StandardScaler) before the prediction is made. This is the expected behavior in online machine learning. However, in some cases you may want to produce predictions without updating anything.

This context manager overrides that behavior: while it is active, unsupervised estimators are not updated when predict_one (or predict_many) is called.

Examples

Let's first see which methods are called if we simply call predict_one.

import io
import logging
import pandas as pd
from river import compose
from river import datasets
from river import linear_model
from river import preprocessing
from river import utils

model = compose.Pipeline(
    preprocessing.StandardScaler(),
    linear_model.LinearRegression()
)

class_condition = lambda x: x.__class__.__name__ in ('StandardScaler', 'LinearRegression')

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition):
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_one(x)

print(logs.getvalue())
StandardScaler.learn_one
StandardScaler.transform_one
LinearRegression.predict_one

Now let's use the context manager and see what methods get called.

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition), compose.pure_inference_mode():
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_one(x)

print(logs.getvalue())
StandardScaler.transform_one
LinearRegression.predict_one

We can see that the scaler was not updated before transforming the data.
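The logs show which methods ran. A complementary check, sketched below under the assumption that the step keeps its state in plain instance attributes, is to snapshot the object's attributes before a call and compare them afterwards. ToyScaler and state_unchanged are hypothetical names used only for illustration. Note that a real estimator may hold state in nested objects: copy.deepcopy copies them, but the equality comparison is only meaningful if those objects define __eq__.

```python
import copy


class ToyScaler:
    """Hypothetical stand-in for an unsupervised step with running statistics."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def learn_one(self, x):
        # Incremental mean update.
        self.count += 1
        self.mean += (x - self.mean) / self.count

    def transform_one(self, x):
        return x - self.mean


def state_unchanged(obj, fn):
    """Call fn(obj) and report whether obj's attributes are identical afterwards."""
    before = copy.deepcopy(vars(obj))
    fn(obj)
    return vars(obj) == before


scaler = ToyScaler()
scaler.learn_one(10.0)

# Transform only: no side effects on the scaler.
print(state_unchanged(scaler, lambda s: s.transform_one(4.0)))  # True
# Learn then transform, as predict_one does outside pure inference mode.
print(state_unchanged(scaler, lambda s: (s.learn_one(4.0), s.transform_one(4.0))))  # False
```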

This also works with mini-batches, that is, when calling predict_many.

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition):
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_many(pd.DataFrame([x]))
print(logs.getvalue())
StandardScaler.learn_many
StandardScaler.transform_many
LinearRegression.predict_many

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition), compose.pure_inference_mode():
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_many(pd.DataFrame([x]))
print(logs.getvalue())
StandardScaler.transform_many
LinearRegression.predict_many