Debugging a pipeline
river encourages users to make use of pipelines. The biggest pain point of pipelines is that it can be hard to understand what's happening to the data, especially when the pipeline is complex. Fortunately, the Pipeline class has a debug_one method that can help out.
Let's look at a fairly complex pipeline for predicting the number of bikes in 5 bike stations from the city of Toulouse. It doesn't matter if you understand the pipeline or not; the point of this notebook is to learn how to introspect a pipeline.
import datetime as dt
from river import compose
from river import datasets
from river import feature_extraction
from river import linear_model
from river import metrics
from river import preprocessing
from river import stats
from river import stream
X_y = datasets.Bikes()
X_y = stream.simulate_qa(X_y, moment='moment', delay=dt.timedelta(minutes=30))
def add_time_features(x):
    return {
        **x,
        'hour': x['moment'].hour,
        'day': x['moment'].weekday()
    }
model = add_time_features
model |= (
compose.Select('clouds', 'humidity', 'pressure', 'temperature', 'wind') +
feature_extraction.TargetAgg(by=['station', 'hour'], how=stats.Mean()) +
feature_extraction.TargetAgg(by='station', how=stats.EWMean())
)
model |= preprocessing.StandardScaler()
model |= linear_model.LinearRegression()
metric = metrics.MAE()
questions = {}
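for i, x, y in X_y:
    # Question: the target is not yet known, so store the prediction
    is_question = y is None
    if is_question:
        y_pred = model.predict_one(x)
        questions[i] = y_pred
    # Answer: the target has arrived, so score the stored prediction and learn
    else:
        metric.update(y, questions[i])
        # learn_one updates the pipeline in place; its return value is not needed
        model.learn_one(x, y)
        if i >= 30000 and i % 30000 == 0:
            print(i, metric)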
30000 MAE: 2.220942
60000 MAE: 2.270271
90000 MAE: 2.301302
120000 MAE: 2.275876
150000 MAE: 2.275224
180000 MAE: 2.289347
Let's start by looking at the pipeline. In a notebook, displaying the model renders one cell per step, and clicking a cell shows the current state of that step.
model
add_time_features
def add_time_features(x):
    return {
        **x,
        'hour': x['moment'].hour,
        'day': x['moment'].weekday()
    }
['clouds', 'humidity', 'pressure', 'temperature', 'wind']
{'keys': {'humidity', 'temperature', 'clouds', 'pressure', 'wind'}}
y_mean_by_station_and_hour
{'_feature_name': 'y_mean_by_station_and_hour',
'_groups': defaultdict(functools.partial(<function deepcopy at 0x7ff94faef940>, Mean: 0.),
{('metro-canal-du-midi', 0): Mean: 7.93981,
('metro-canal-du-midi', 1): Mean: 8.179704,
('metro-canal-du-midi', 2): Mean: 8.35824,
('metro-canal-du-midi', 3): Mean: 8.656051,
('metro-canal-du-midi', 4): Mean: 8.868445,
('metro-canal-du-midi', 5): Mean: 8.99656,
('metro-canal-du-midi', 6): Mean: 9.09966,
('metro-canal-du-midi', 7): Mean: 8.852642,
('metro-canal-du-midi', 8): Mean: 12.66712,
('metro-canal-du-midi', 9): Mean: 13.412186,
('metro-canal-du-midi', 10): Mean: 12.486815,
('metro-canal-du-midi', 11): Mean: 11.675479,
('metro-canal-du-midi', 12): Mean: 10.197409,
('metro-canal-du-midi', 13): Mean: 10.650855,
('metro-canal-du-midi', 14): Mean: 11.109123,
('metro-canal-du-midi', 15): Mean: 11.068934,
('metro-canal-du-midi', 16): Mean: 11.274958,
('metro-canal-du-midi', 17): Mean: 8.459136,
('metro-canal-du-midi', 18): Mean: 7.587469,
('metro-canal-du-midi', 19): Mean: 7.734677,
('metro-canal-du-midi', 20): Mean: 7.582465,
('metro-canal-du-midi', 21): Mean: 7.190665,
('metro-canal-du-midi', 22): Mean: 7.486895,
('metro-canal-du-midi', 23): Mean: 7.840791,
('place-des-carmes', 0): Mean: 4.720696,
('place-des-carmes', 1): Mean: 3.390295,
('place-des-carmes', 2): Mean: 2.232181,
('place-des-carmes', 3): Mean: 1.371981,
('place-des-carmes', 4): Mean: 1.051665,
('place-des-carmes', 5): Mean: 0.984993,
('place-des-carmes', 6): Mean: 2.039947,
('place-des-carmes', 7): Mean: 3.850369,
('place-des-carmes', 8): Mean: 3.792624,
('place-des-carmes', 9): Mean: 5.957182,
('place-des-carmes', 10): Mean: 8.575303,
('place-des-carmes', 11): Mean: 9.321546,
('place-des-carmes', 12): Mean: 10.511931,
('place-des-carmes', 13): Mean: 11.392745,
('place-des-carmes', 14): Mean: 10.735003,
('place-des-carmes', 15): Mean: 10.198787,
('place-des-carmes', 16): Mean: 9.941479,
('place-des-carmes', 17): Mean: 9.125579,
('place-des-carmes', 18): Mean: 7.660775,
('place-des-carmes', 19): Mean: 6.847649,
('place-des-carmes', 20): Mean: 9.626876,
('place-des-carmes', 21): Mean: 11.602929,
('place-des-carmes', 22): Mean: 10.405537,
('place-des-carmes', 23): Mean: 7.700904,
('place-esquirol', 0): Mean: 7.415789,
('place-esquirol', 1): Mean: 5.244396,
('place-esquirol', 2): Mean: 2.858635,
('place-esquirol', 3): Mean: 1.155929,
('place-esquirol', 4): Mean: 0.73306,
('place-esquirol', 5): Mean: 0.668546,
('place-esquirol', 6): Mean: 1.21265,
('place-esquirol', 7): Mean: 3.107535,
('place-esquirol', 8): Mean: 8.518696,
('place-esquirol', 9): Mean: 15.470588,
('place-esquirol', 10): Mean: 19.465005,
('place-esquirol', 11): Mean: 22.976512,
('place-esquirol', 12): Mean: 25.324159,
('place-esquirol', 13): Mean: 25.428847,
('place-esquirol', 14): Mean: 24.57762,
('place-esquirol', 15): Mean: 24.416851,
('place-esquirol', 16): Mean: 23.555125,
('place-esquirol', 17): Mean: 22.062564,
('place-esquirol', 18): Mean: 18.10623,
('place-esquirol', 19): Mean: 11.916638,
('place-esquirol', 20): Mean: 13.346362,
('place-esquirol', 21): Mean: 16.743318,
('place-esquirol', 22): Mean: 15.562088,
('place-esquirol', 23): Mean: 10.911134,
('place-jeanne-darc', 0): Mean: 6.541667,
('place-jeanne-darc', 1): Mean: 5.99892,
('place-jeanne-darc', 2): Mean: 5.598169,
('place-jeanne-darc', 3): Mean: 5.180556,
('place-jeanne-darc', 4): Mean: 4.779626,
('place-jeanne-darc', 5): Mean: 4.67063,
('place-jeanne-darc', 6): Mean: 4.611995,
('place-jeanne-darc', 7): Mean: 4.960718,
('place-jeanne-darc', 8): Mean: 5.552273,
('place-jeanne-darc', 9): Mean: 6.249573,
('place-jeanne-darc', 10): Mean: 5.735553,
('place-jeanne-darc', 11): Mean: 5.616142,
('place-jeanne-darc', 12): Mean: 5.787478,
('place-jeanne-darc', 13): Mean: 5.817699,
('place-jeanne-darc', 14): Mean: 5.657546,
('place-jeanne-darc', 15): Mean: 6.224604,
('place-jeanne-darc', 16): Mean: 5.796141,
('place-jeanne-darc', 17): Mean: 5.743089,
('place-jeanne-darc', 18): Mean: 5.674784,
('place-jeanne-darc', 19): Mean: 5.833068,
('place-jeanne-darc', 20): Mean: 6.015755,
('place-jeanne-darc', 21): Mean: 6.242541,
('place-jeanne-darc', 22): Mean: 6.141509,
('place-jeanne-darc', 23): Mean: 6.493028,
('pomme', 0): Mean: 3.301532,
('pomme', 1): Mean: 2.312914,
('pomme', 2): Mean: 2.144453,
('pomme', 3): Mean: 1.563622,
('pomme', 4): Mean: 0.947328,
('pomme', 5): Mean: 0.924175,
('pomme', 6): Mean: 1.287805,
('pomme', 7): Mean: 1.299456,
('pomme', 8): Mean: 2.94988,
('pomme', 9): Mean: 7.89396,
('pomme', 10): Mean: 11.791436,
('pomme', 11): Mean: 12.976854,
('pomme', 12): Mean: 13.962654,
('pomme', 13): Mean: 11.692257,
('pomme', 14): Mean: 11.180851,
('pomme', 15): Mean: 11.939586,
('pomme', 16): Mean: 12.267051,
('pomme', 17): Mean: 12.132993,
('pomme', 18): Mean: 11.399108,
('pomme', 19): Mean: 6.37021,
('pomme', 20): Mean: 5.279234,
('pomme', 21): Mean: 6.254257,
('pomme', 22): Mean: 6.568678,
('pomme', 23): Mean: 5.235756}),
'by': ['station', 'hour'],
'how': Mean: 0.,
'on': 'y'}
y_ewm_0.5_by_station
{'_feature_name': 'y_ewm_0.5_by_station',
'_groups': defaultdict(functools.partial(<function deepcopy at 0x7ff94faef940>, EWMean: 0.),
{('metro-canal-du-midi',): EWMean: 4.690531,
('place-des-carmes',): EWMean: 3.295317,
('place-esquirol',): EWMean: 31.539759,
('place-jeanne-darc',): EWMean: 22.449934,
('pomme',): EWMean: 11.803716}),
'by': ['station'],
'how': EWMean: 0.,
'on': 'y'}
StandardScaler
{'counts': Counter({'y_ewm_0.5_by_station': 182470,
'y_mean_by_station_and_hour': 182470,
'humidity': 182470,
'temperature': 182470,
'clouds': 182470,
'pressure': 182470,
'wind': 182470}),
'means': defaultdict(<class 'float'>,
{'clouds': 30.315131254453505,
'humidity': 62.24244533347998,
'pressure': 1017.0563060996391,
'temperature': 20.50980692716619,
'wind': 3.4184331122924543,
'y_ewm_0.5_by_station': 10.08331958752748,
'y_mean_by_station_and_hour': 9.410348580619415}),
'vars': defaultdict(<class 'float'>,
{'clouds': 1389.0025610928221,
'humidity': 349.59967918503554,
'pressure': 33.298307526514115,
'temperature': 34.70701720774977,
'wind': 4.473627075744674,
'y_ewm_0.5_by_station': 80.17355266024735,
'y_mean_by_station_and_hour': 33.98249801051089}),
'with_std': True}
LinearRegression
{'_weights': {'y_ewm_0.5_by_station': 9.264175276315452, 'y_mean_by_station_and_hour': 0.1980140007049781, 'humidity': 1.0125248437612895, 'temperature': -0.4211217806219201, 'clouds': -0.3269694794458286, 'pressure': 0.18137498909137595, 'wind': -0.04087954775179438},
'_y_name': None,
'clip_gradient': 1000000000000.0,
'initializer': Zeros (),
'intercept': 9.22315869068918,
'intercept_init': 0.0,
'intercept_lr': Constant({'learning_rate': 0.01}),
'l2': 0.0,
'loss': Squared({}),
'optimizer': SGD({'lr': Constant({'learning_rate': 0.01}), 'n_iterations': 182470})}
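The TargetAgg state shown above holds one running statistic per group, for example one mean per (station, hour) pair. The following is a minimal pure-Python sketch of that idea, not river's implementation: the class name GroupedRunningMean and the toy observations are made up for illustration.

```python
from collections import defaultdict


class GroupedRunningMean:
    """Hypothetical stand-in for a per-group target mean; not river's code."""

    def __init__(self, by):
        self.by = by
        self.counts = defaultdict(int)
        self.means = defaultdict(float)

    def _key(self, x):
        # The group is identified by the values of the `by` features.
        return tuple(x[name] for name in self.by)

    def learn_one(self, x, y):
        key = self._key(x)
        self.counts[key] += 1
        # Incremental mean update, so past targets never need to be stored.
        self.means[key] += (y - self.means[key]) / self.counts[key]

    def transform_one(self, x):
        name = "y_mean_by_" + "_and_".join(self.by)
        # .get avoids creating an entry for a group that was never seen.
        return {name: self.means.get(self._key(x), 0.0)}


agg = GroupedRunningMean(by=["station", "hour"])
# Toy observations, invented for this example.
agg.learn_one({"station": "pomme", "hour": 9}, 6)
agg.learn_one({"station": "pomme", "hour": 9}, 10)
agg.learn_one({"station": "pomme", "hour": 10}, 12)

features = agg.transform_one({"station": "pomme", "hour": 9})
print(features)  # {'y_mean_by_station_and_hour': 8.0}
```

Each group only needs a count and a mean, which is why the state stays small even after many observations.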
As mentioned above, the Pipeline class has a debug_one method. You can use it at any point to visualize what happens to an input x. For example, let's see what happens to the last seen x.
print(model.debug_one(x))
0. Input
--------
clouds: 88 (int)
description: overcast clouds (str)
humidity: 84 (int)
moment: 2016-10-05 09:57:18 (datetime)
pressure: 1,017.34000 (float)
station: pomme (str)
temperature: 17.45000 (float)
wind: 1.95000 (float)
1. add_time_features
--------------------
clouds: 88 (int)
day: 2 (int)
description: overcast clouds (str)
hour: 9 (int)
humidity: 84 (int)
moment: 2016-10-05 09:57:18 (datetime)
pressure: 1,017.34000 (float)
station: pomme (str)
temperature: 17.45000 (float)
wind: 1.95000 (float)
2. Transformer union
--------------------
2.0 Select
----------
clouds: 88 (int)
humidity: 84 (int)
pressure: 1,017.34000 (float)
temperature: 17.45000 (float)
wind: 1.95000 (float)
2.1 TargetAgg
-------------
y_mean_by_station_and_hour: 7.89396 (float)
2.2 TargetAgg1
--------------
y_ewm_0.5_by_station: 11.80372 (float)
clouds: 88 (int)
humidity: 84 (int)
pressure: 1,017.34000 (float)
temperature: 17.45000 (float)
wind: 1.95000 (float)
y_ewm_0.5_by_station: 11.80372 (float)
y_mean_by_station_and_hour: 7.89396 (float)
3. StandardScaler
-----------------
clouds: 1.54778 (float)
humidity: 1.16366 (float)
pressure: 0.04916 (float)
temperature: -0.51938 (float)
wind: -0.69426 (float)
y_ewm_0.5_by_station: 0.19214 (float)
y_mean_by_station_and_hour: -0.26013 (float)
4. LinearRegression
-------------------
Name                           Value    Weight  Contribution
Intercept                    1.00000   9.22316       9.22316
y_ewm_0.5_by_station         0.19214   9.26418       1.78000
humidity                     1.16366   1.01252       1.17823
temperature                 -0.51938  -0.42112       0.21872
wind                        -0.69426  -0.04088       0.02838
pressure                     0.04916   0.18137       0.00892
y_mean_by_station_and_hour  -0.26013   0.19801      -0.05151
clouds                       1.54778  -0.32697      -0.50608
Prediction: 11.87982
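The numbers in this trace can be checked by hand. The sketch below uses only values copied from the outputs above: it assumes the scaler computes (value - mean) / sqrt(variance), which is consistent with the printed figures, and it sums the contributions of the linear regression to recover the prediction.

```python
import math

# Values copied from the StandardScaler state and from step 0 of the trace.
raw_clouds = 88
mean_clouds = 30.315131254453505
var_clouds = 1389.0025610928221

# Assumed scaling rule: subtract the running mean, divide by the running std.
scaled_clouds = (raw_clouds - mean_clouds) / math.sqrt(var_clouds)
print(round(scaled_clouds, 5))  # 1.54778, as shown in step 3

# Contributions copied from the LinearRegression table in step 4.
contributions = {
    "Intercept": 9.22316,
    "y_ewm_0.5_by_station": 1.78000,
    "humidity": 1.17823,
    "temperature": 0.21872,
    "wind": 0.02838,
    "pressure": 0.00892,
    "y_mean_by_station_and_hour": -0.05151,
    "clouds": -0.50608,
}

# The prediction of a linear model is the sum of its contributions.
prediction = sum(contributions.values())
print(round(prediction, 5))  # 11.87982, the reported prediction
```

Each contribution is the scaled value multiplied by its weight, so a large contribution can come from either a large weight or an unusual input value.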
The pipeline does quite a few things, but debug_one shows what happens step by step. This is really useful for checking that the pipeline behaves as you expect. Remember that you can call debug_one whenever you wish, be it before, during, or after training a model.
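To make the idea of a step-by-step trace concrete, here is a minimal pure-Python sketch that applies a list of functions in turn and records every intermediate dictionary. It is an illustration only and does not reflect how river implements debug_one: the helpers trace_one, format_features, add_hour and select_numeric are hypothetical names.

```python
import datetime as dt


def format_features(x):
    # One "name: value (type)" line per feature, sorted for stable output.
    return [f"{k}: {v} ({type(v).__name__})" for k, v in sorted(x.items())]


def trace_one(steps, x):
    """Run x through each (name, function) pair and record every intermediate dict."""
    lines = ["0. Input", *format_features(x)]
    for i, (name, step) in enumerate(steps, start=1):
        x = step(x)
        lines += [f"{i}. {name}", *format_features(x)]
    return "\n".join(lines)


def add_hour(x):
    # Hypothetical feature step: derive the hour from the timestamp.
    return {**x, "hour": x["moment"].hour}


def select_numeric(x):
    # Hypothetical selection step: keep only two features.
    return {k: x[k] for k in ("clouds", "hour")}


steps = [("add_hour", add_hour), ("select_numeric", select_numeric)]
x = {"clouds": 88, "moment": dt.datetime(2016, 10, 5, 9, 57, 18)}
report = trace_one(steps, x)
print(report)
```

Because the trace only reads the input and calls each step, it can be run on any x without affecting what the steps have learned, provided the steps themselves are free of side effects.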