KNNRegressor¶
K-Nearest Neighbors regressor.
This non-parametric regression method keeps track of the last window_size training samples. Predictions are obtained by aggregating the target values of the n_neighbors stored samples that are closest to the query sample.
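The windowing and neighbor lookup described above can be sketched in plain Python. This is only an illustration under stated assumptions, not River's implementation: the class name TinyKNNRegressor, the euclidean helper, and the 0.0 fallback for an empty window are all hypothetical, and only mean aggregation is shown.

```python
import math
from collections import deque


def euclidean(a: dict, b: dict) -> float:
    """Euclidean distance between two feature dicts (missing keys count as 0)."""
    keys = set(a) | set(b)
    return math.sqrt(sum((a.get(k, 0.0) - b.get(k, 0.0)) ** 2 for k in keys))


class TinyKNNRegressor:
    """Hypothetical illustration only; not River's implementation."""

    def __init__(self, n_neighbors: int = 5, window_size: int = 1000):
        self.n_neighbors = n_neighbors
        # A deque with maxlen drops the oldest sample once the window is full.
        self.window = deque(maxlen=window_size)

    def learn_one(self, x: dict, y: float):
        self.window.append((x, y))
        return self

    def predict_one(self, x: dict) -> float:
        if not self.window:
            return 0.0  # assumption: fall back to 0.0 before any sample is seen
        # Sort the stored samples by distance to the query and keep the closest ones.
        nearest = sorted(self.window, key=lambda s: euclidean(x, s[0]))
        nearest = nearest[: self.n_neighbors]
        return sum(y for _, y in nearest) / len(nearest)


model = TinyKNNRegressor(n_neighbors=2, window_size=3)
samples = [({"a": 0.0}, 10.0), ({"a": 1.0}, 20.0), ({"a": 2.0}, 30.0), ({"a": 3.0}, 40.0)]
for x, y in samples:
    model.learn_one(x, y)

# The window now holds only the last three samples (a = 1, 2, 3),
# so the two nearest neighbors of a = 0 have targets 20 and 30.
prediction = model.predict_one({"a": 0.0})
print(prediction)  # 25.0
```

Sorting the whole window on every query is linear-logarithmic in window_size, which is acceptable for a sketch but is one reason a bounded window matters in a streaming setting.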
Parameters¶
-
n_neighbors
Type → int
Default → 5
The number of nearest neighbors to search for.
-
window_size
Type → int
Default → 1000
The maximum size of the window storing the last observed samples.
-
aggregation_method
Type → str
Default → mean
The method used to aggregate the target values of the neighbors. One of 'mean', 'median', or 'weighted_mean'.
-
min_distance_keep
Type → float
Default → 0.0
The minimum distance a new point must have from the samples already stored in order to be added to the window. E.g., a value of 0.0 will add even exact duplicates.
-
distance_func
Type → DistanceFunc | None
Default → None
An optional distance function that accepts the keyword arguments a= and b=, plus any custom set of kwargs. If not defined, the Minkowski distance is used with p=2 (Euclidean distance). See the Examples section for more details.
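To make the three aggregation_method options concrete, here is a minimal sketch of how the neighbors' target values might be combined. The aggregate function is hypothetical, and the 'weighted_mean' branch assumes inverse-distance weights with an exact match returned directly; the text above does not spell out the weighting, so treat that part as an assumption.

```python
import statistics


def aggregate(targets, distances, method="mean"):
    """Combine neighbor targets; hypothetical helper, not part of River."""
    if method == "mean":
        return statistics.fmean(targets)
    if method == "median":
        return statistics.median(targets)
    if method == "weighted_mean":
        # Assumption: a neighbor at distance 0 is an exact match, so use its target.
        for y, d in zip(targets, distances):
            if d == 0:
                return y
        # Assumption: closer neighbors get larger weights (1 / distance).
        weights = [1.0 / d for d in distances]
        return sum(w * y for w, y in zip(weights, targets)) / sum(weights)
    raise ValueError(f"Unknown aggregation method: {method!r}")


targets = [10.0, 20.0, 60.0]
distances = [1.0, 2.0, 4.0]

mean_value = aggregate(targets, distances, "mean")               # 30.0
median_value = aggregate(targets, distances, "median")           # 20.0
weighted_value = aggregate(targets, distances, "weighted_mean")  # 35 / 1.75 = 20.0
```

With these numbers the distant neighbor (target 60.0) pulls the plain mean up to 30.0, while the median and the distance-weighted mean both stay at 20.0.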
Examples¶
from river import datasets
from river import evaluate
from river import metrics
from river import neighbors
from river import preprocessing
dataset = datasets.TrumpApproval()
model = neighbors.KNNRegressor(window_size=50)
evaluate.progressive_val_score(dataset, model, metrics.RMSE())
RMSE: 1.427746
When defining a custom distance function, you can rely on functools.partial to set default parameter values. For instance, let's use the Manhattan distance instead of the default Euclidean distance:
import functools
from river.utils.math import minkowski_distance
model = (
    preprocessing.StandardScaler() |
    neighbors.KNNRegressor(
        window_size=50,
        distance_func=functools.partial(minkowski_distance, p=1)
    )
)
evaluate.progressive_val_score(dataset, model, metrics.RMSE())
RMSE: 1.460385
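A hand-written distance function can follow the same pattern. The sketch below assumes only what the distance_func description states, namely that the function is called with a= and b= and may take extra keyword arguments. The name weighted_manhattan and its weights parameter are hypothetical, and the block does not import River, so it only demonstrates the function's shape.

```python
import functools


def weighted_manhattan(a, b, weights=None):
    """Manhattan distance between two feature dicts with optional per-feature weights.

    Hypothetical example; features without an explicit weight default to 1.0.
    """
    weights = weights or {}
    keys = set(a) | set(b)
    return sum(weights.get(k, 1.0) * abs(a.get(k, 0.0) - b.get(k, 0.0)) for k in keys)


# Freeze the custom keyword argument so that only a= and b= remain to be supplied.
dist = functools.partial(weighted_manhattan, weights={"x": 2.0})

d = dist(a={"x": 1.0, "y": 1.0}, b={"x": 3.0, "y": 0.0})
print(d)  # 2.0 * |1 - 3| + 1.0 * |1 - 0| = 5.0
```

A partial object built this way could then be passed as distance_func, in the same manner as the minkowski_distance example.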
Methods¶
learn_one
Fits to a set of features x and a real-valued target y.
Parameters
- x — 'dict'
- y — 'base.typing.RegTarget'
Returns
Regressor: self
predict_one
Predict the output of features x.
Parameters
- x — 'dict'
Returns
base.typing.RegTarget: The prediction.
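The two methods are typically used together in a predict-then-learn loop, which is the idea behind progressive validation. The sketch below writes that loop out by hand with a stand-in model so that it runs without River installed; RunningMean and the three-sample stream are hypothetical, and any object exposing learn_one and predict_one could take the model's place.

```python
import math


class RunningMean:
    """Stand-in regressor exposing learn_one / predict_one; hypothetical."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0

    def learn_one(self, x, y):
        # Incremental update of the mean of all targets seen so far.
        self.n += 1
        self.mean += (y - self.mean) / self.n
        return self

    def predict_one(self, x):
        return self.mean


stream = [({"t": 1}, 2.0), ({"t": 2}, 4.0), ({"t": 3}, 6.0)]
model = RunningMean()
squared_errors = []

for x, y in stream:
    y_pred = model.predict_one(x)          # predict before the label is revealed
    squared_errors.append((y - y_pred) ** 2)
    model.learn_one(x, y)                  # then update the model with the true target

rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
print(f"RMSE: {rmse:.6f}")
```

Predicting before learning means every sample is scored by a model that has not yet seen it, so the resulting error is an out-of-sample estimate.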