KNNRegressor
K-Nearest Neighbors regressor.
This non-parametric regression method keeps track of the last window_size training samples. Predictions are obtained by aggregating the values of the closest n_neighbors stored samples with respect to a query sample.
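As a rough illustration of the idea described above, here is a minimal pure-Python sketch of a sliding-window nearest-neighbor regressor. It is not the library's implementation: the class name `WindowedKNNSketch`, the treatment of missing features as 0.0, and the 0.0 fallback for an empty window are all assumptions made for the example.

```python
import math
from collections import deque


class WindowedKNNSketch:
    """Hypothetical sliding-window KNN regressor (illustration only)."""

    def __init__(self, n_neighbors=5, window_size=1000):
        self.n_neighbors = n_neighbors
        # deque(maxlen=...) drops the oldest sample once the window is full
        self.window = deque(maxlen=window_size)

    @staticmethod
    def _euclidean(a, b):
        # Features are dicts; a missing key is treated as 0.0 (assumption)
        keys = set(a) | set(b)
        return math.sqrt(sum((a.get(k, 0.0) - b.get(k, 0.0)) ** 2 for k in keys))

    def learn_one(self, x, y):
        # Learning only stores the sample; no parameters are fitted
        self.window.append((x, y))

    def predict_one(self, x):
        if not self.window:
            return 0.0  # nothing stored yet (assumed fallback)
        # Sort stored samples by distance to the query and keep the closest ones
        nearest = sorted(self.window, key=lambda item: self._euclidean(x, item[0]))
        targets = [y for _, y in nearest[: self.n_neighbors]]
        return sum(targets) / len(targets)


model = WindowedKNNSketch(n_neighbors=2, window_size=3)
for x_val, y_val in [(0.0, 0.0), (1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]:
    model.learn_one({"x": x_val}, y_val)

# The sample with x=0 has been evicted; the two closest to x=2.9 are x=3 and x=2
prediction = model.predict_one({"x": 2.9})
```

With a window of three samples, the fourth call to `learn_one` pushes out the oldest one, so the prediction is the mean of the targets 30.0 and 20.0.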
Parameters
-
n_neighbors
Type → int
Default → 5
The number of nearest neighbors to search for.
-
window_size
Type → int
Default → 1000
The maximum size of the window storing the last observed samples.
-
aggregation_method
Type → str
Default → mean
The method used to aggregate the target values of the neighbors. One of 'mean', 'median', or 'weighted_mean'.
-
min_distance_keep
Type → float
Default → 0.0
The minimum distance (similarity) to consider adding a point to the window. E.g., a value of 0.0 will add even exact duplicates.
-
distance_func
Type → DistanceFunc | None
Default → None
An optional distance function that should accept the keyword arguments a= and b=, plus any custom set of kwargs. If not defined, the Minkowski distance is used with p=2 (Euclidean distance). See the example section for more details.
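To make the three aggregation_method options concrete, the sketch below combines the targets of a few neighbors in each way. The helper `aggregate` is hypothetical, and the use of inverse-distance weights for 'weighted_mean' (with an exact match returned directly) is an assumption about how such a weighting is commonly done, not a statement about the library's internals.

```python
import statistics


def aggregate(targets, distances, method="mean"):
    """Combine neighbor targets (hypothetical helper, illustration only)."""
    if method == "mean":
        return statistics.mean(targets)
    if method == "median":
        return statistics.median(targets)
    if method == "weighted_mean":
        # An exact match (distance 0) would get an infinite weight,
        # so this sketch returns its target directly.
        for y, d in zip(targets, distances):
            if d == 0:
                return y
        # Assumption: closer neighbors weigh more, using 1 / distance
        weights = [1.0 / d for d in distances]
        return sum(w * y for w, y in zip(weights, targets)) / sum(weights)
    raise ValueError(f"Unknown aggregation method: {method}")


targets = [10.0, 20.0, 60.0]
distances = [1.0, 2.0, 4.0]

mean_value = aggregate(targets, distances, "mean")
median_value = aggregate(targets, distances, "median")
weighted_value = aggregate(targets, distances, "weighted_mean")
```

In this toy case the distant neighbor with target 60.0 pulls the plain mean up to 30.0, while the median and the inverse-distance weighted mean both stay at 20.0.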
Examples
from river import datasets
from river import evaluate
from river import metrics
from river import neighbors
from river import preprocessing
dataset = datasets.TrumpApproval()
model = neighbors.KNNRegressor(window_size=50)
evaluate.progressive_val_score(dataset, model, metrics.RMSE())
RMSE: 1.427746
When defining a custom distance function, you can rely on functools.partial to set default parameter values. For instance, let's use the Manhattan distance (Minkowski distance with p=1) instead of the default Euclidean distance:
import functools
from river.utils.math import minkowski_distance
model = (
    preprocessing.StandardScaler() |
    neighbors.KNNRegressor(
        window_size=50,
        distance_func=functools.partial(minkowski_distance, p=1)
    )
)
evaluate.progressive_val_score(dataset, model, metrics.RMSE())
RMSE: 1.460385
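The same pattern should extend to a distance function written from scratch. The sketch below assumes that a and b arrive as feature dicts; the function `weighted_manhattan` and its `weights` keyword are hypothetical names introduced only for this illustration.

```python
import functools


def weighted_manhattan(a, b, weights=None):
    """Hypothetical distance: sum of per-feature |a - b|, scaled by optional weights."""
    weights = weights or {}
    keys = set(a) | set(b)
    # Missing features count as 0.0 and unlisted weights default to 1.0 (assumptions)
    return sum(
        weights.get(k, 1.0) * abs(a.get(k, 0.0) - b.get(k, 0.0))
        for k in keys
    )


# Freeze the extra keyword argument so that only a= and b= remain to be supplied
distance = functools.partial(weighted_manhattan, weights={"x": 2.0})

d = distance(a={"x": 1.0, "y": 5.0}, b={"x": 4.0, "y": 3.0})
```

Here the difference on "x" counts double, giving 2 * 3 + 1 * 2 = 8.0. A callable built this way could then be supplied as distance_func in place of the partial over minkowski_distance.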
Methods
learn_one
Fits to a set of features x and a real-valued target y.
Parameters
- x — 'dict'
- y — 'base.typing.RegTarget'
Returns
Regressor: self
predict_one
Predict the output of features x.
Parameters
- x — 'dict'
Returns
base.typing.RegTarget: The prediction.