learn_during_predict

A context manager for fitting unsupervised steps during prediction.

Usually, the unsupervised parts of a pipeline are updated during learn_one. In online learning, however, it is possible to update them earlier, during the prediction step, just before they transform the incoming data. This context manager allows you to do so.

This usually brings a slight performance improvement. It is not the default behavior, however, because it is less intuitive and more difficult to test. It also means that you have to call predict_one before learn_one in order for the whole pipeline to be updated.
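The idea can be illustrated with a toy sketch in plain Python. This is not River's implementation: the names toy_learn_during_predict, ToyMeanCenterer, ToyPipeline and the module-level flag are all hypothetical, chosen only to show how a context manager could switch an unsupervised step from "update in learn_one" to "update inside predict_one".

```python
import contextlib

# Hypothetical flag for illustration; River's actual mechanism may differ.
_LEARN_DURING_PREDICT = False


@contextlib.contextmanager
def toy_learn_during_predict():
    """Toy stand-in for compose.learn_during_predict (illustration only)."""
    global _LEARN_DURING_PREDICT
    previous = _LEARN_DURING_PREDICT
    _LEARN_DURING_PREDICT = True
    try:
        yield
    finally:
        # Restore the previous mode even if the body raises.
        _LEARN_DURING_PREDICT = previous


class ToyMeanCenterer:
    """Unsupervised step: subtracts a running mean from a single number."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0

    def learn_one(self, x):
        self.n += 1
        self.mean += (x - self.mean) / self.n

    def transform_one(self, x):
        return x - self.mean


class ToyPipeline:
    """Records which methods run, mimicking the logs shown in the examples."""

    def __init__(self):
        self.centerer = ToyMeanCenterer()
        self.calls = []

    def predict_one(self, x):
        if _LEARN_DURING_PREDICT:
            # The unsupervised step sees x before transforming it.
            self.centerer.learn_one(x)
            self.calls.append("learn_one")
        self.calls.append("transform_one")
        # A real pipeline would pass this to a supervised model.
        return self.centerer.transform_one(x)


pipeline = ToyPipeline()
pipeline.predict_one(4.0)        # only transform_one runs
with toy_learn_during_predict():
    pipeline.predict_one(4.0)    # learn_one runs first, then transform_one
print(pipeline.calls)
```

In this sketch the centerer is only updated inside predict_one while the context manager is active, which is why the order of calls matters: skipping the prediction step would leave the unsupervised step without an update for that sample.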

Examples

Let's first see what methods are called if we just call predict_one.

import io
import logging
from river import compose
from river import datasets
from river import linear_model
from river import preprocessing
from river import utils

model = compose.Pipeline(
    preprocessing.StandardScaler(),
    linear_model.LinearRegression()
)

class_condition = lambda x: x.__class__.__name__ in ('StandardScaler', 'LinearRegression')

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition):
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_one(x)

print(logs.getvalue())
StandardScaler.transform_one
LinearRegression.predict_one

Now let's use the context manager and see what methods get called.

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition), compose.learn_during_predict():
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_one(x)

print(logs.getvalue())
StandardScaler.learn_one
StandardScaler.transform_one
LinearRegression.predict_one

We can see that the scaler got updated (StandardScaler.learn_one) before transforming the data.

This also works with mini-batches.

import pandas as pd

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition):
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_many(pd.DataFrame([x]))
print(logs.getvalue())
StandardScaler.transform_many
LinearRegression.predict_many

logs = io.StringIO()
sh = logging.StreamHandler(logs)
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)

with utils.log_method_calls(class_condition), compose.learn_during_predict():
    for x, y in datasets.TrumpApproval().take(1):
        _ = model.predict_many(pd.DataFrame([x]))
print(logs.getvalue())
StandardScaler.learn_many
StandardScaler.transform_many
LinearRegression.predict_many