AMFClassifier
Aggregated Mondrian Forest classifier for online learning.
This implementation is truly online [1], in the sense that a single pass over the data is performed and predictions can be produced at any time.
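Each node in a tree predicts according to the distribution of the labels it contains. This distribution is regularized using a "Jeffreys" prior with parameter dirichlet. For each class with count labels in the node and n_samples samples in it, the prediction of the node for that class is given by

$$\frac{\text{count} + \text{dirichlet}}{\text{n\_samples} + \text{dirichlet} \times \text{n\_classes}}$$

where n_classes is the number of classes.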
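The smoothed node prediction described above, (count + dirichlet) / (n_samples + dirichlet × n_classes), can be sketched in a few lines. This is a minimal illustration, assuming a node's label counts are stored in a plain dict; the function name `node_proba` is hypothetical and not part of River's API.

```python
def node_proba(counts: dict, classes: list, dirichlet: float = 0.5) -> dict:
    """Dirichlet-smoothed class probabilities for a single node.

    counts: number of samples of each class seen in the node (hypothetical layout).
    classes: all classes known to the model, so n_classes = len(classes).
    """
    n_samples = sum(counts.values())
    n_classes = len(classes)
    denominator = n_samples + dirichlet * n_classes
    return {c: (counts.get(c, 0) + dirichlet) / denominator for c in classes}


# A node holding 3 samples of class "a" and none of class "b"
print(node_proba({"a": 3}, classes=["a", "b"], dirichlet=0.5))
# {'a': 0.875, 'b': 0.125}
```

Even though this node never saw class "b", it still assigns it a non-zero probability; larger values of dirichlet pull the node's prediction further toward the uniform distribution.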
The prediction for a sample is obtained by aggregating the predictions of all the subtrees along the path from the root to the leaf that contains the sample. When use_aggregation is True, the aggregation weights are exponential weights computed from the log-loss, with learning rate step.
This computation is performed exactly thanks to a context tree weighting algorithm. More details can be found in the paper cited in the references below.
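The aggregation along the path can be illustrated with a minimal sketch written from the description above and the AMF paper, not from River's source code. It assumes a hypothetical `Node` record holding the node's own class probabilities (`proba`), the cumulative log-loss of those predictions (`loss`), and its children. Each node's log-weight is taken to be `-step * loss`; the context tree weighting recursion mixes "stop at this node" and "use the children" with prior probability 1/2 each, and the prediction is built bottom-up, starting from the leaf.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Node:
    # Hypothetical node layout, for illustration only.
    proba: dict                  # the node's own (smoothed) class probabilities
    loss: float = 0.0            # cumulative log-loss of this node's predictions
    children: list = field(default_factory=list)


def log_weight_tree(node: Node, step: float) -> float:
    """Context tree weighting: log of the aggregated weight of the subtree rooted at node."""
    own = -step * node.loss
    if not node.children:
        return own
    below = sum(log_weight_tree(child, step) for child in node.children)
    # log(0.5 * exp(own) + 0.5 * exp(below)), computed in a numerically stable way
    m = max(own, below)
    return m + math.log(0.5 * math.exp(own - m) + 0.5 * math.exp(below - m))


def aggregated_proba(path: list, step: float = 1.0) -> dict:
    """Aggregate predictions along a root-to-leaf path (path[0] is the root)."""
    pred = dict(path[-1].proba)  # start from the leaf's own prediction
    for node in reversed(path[:-1]):
        # 0.5 * w is the posterior probability of stopping at this node
        w = math.exp(-step * node.loss - log_weight_tree(node, step))
        pred = {c: 0.5 * w * node.proba[c] + (1 - 0.5 * w) * pred[c] for c in pred}
    return pred


leaf = Node(proba={"a": 0.25, "b": 0.75}, loss=0.3)
sibling = Node(proba={"a": 0.9, "b": 0.1}, loss=0.2)
root = Node(proba={"a": 0.7, "b": 0.3}, loss=1.5, children=[leaf, sibling])
print(aggregated_proba([root, leaf], step=1.0))
```

With step=0 every log-weight is 0, so the root and the leaf are mixed 50/50; larger values of step shift the weight toward nodes whose past predictions incurred a lower log-loss.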
The final predictions are the average class probabilities predicted by each of the n_estimators trees in the forest.
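Averaging over the trees can be sketched as below. The function name `forest_proba` is hypothetical, and treating a class absent from one tree's output as probability 0 is an assumption made for illustration.

```python
def forest_proba(tree_probas: list) -> dict:
    """Average the class probabilities predicted by each tree of a forest.

    Assumes a non-empty list; a class missing from one tree's output is
    counted as probability 0 for that tree (an illustrative assumption).
    """
    # Collect every class, keeping the order in which classes first appear
    classes = dict.fromkeys(c for proba in tree_probas for c in proba)
    n_trees = len(tree_probas)
    return {c: sum(p.get(c, 0.0) for p in tree_probas) / n_trees for c in classes}


# Two hypothetical trees predicting classes "a" and "b"
print(forest_proba([{"a": 0.8, "b": 0.2}, {"a": 0.6, "b": 0.4}]))
```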
Parameters
- n_estimators
Type → int
Default → 10
The number of trees in the forest.
- step
Type → float
Default → 1.0
Step-size for the aggregation weights. The default of 1 is suited to classification with the log-loss and is usually the best choice.
- use_aggregation
Type → bool
Default → True
Controls whether aggregation is used in the trees. It is highly recommended to leave it as True.
- dirichlet
Type → float
Default → 0.5
Regularization level of the class frequencies used for predictions in each node. A rule of thumb is to set this to 1 / n_classes, where n_classes is the expected number of classes that might appear. The default, dirichlet = 0.5, works well for binary classification problems.
- split_pure
Type → bool
Default → False
Controls whether nodes that contain only samples of the same class ("pure" nodes) should be split. The default, False, means pure nodes are not split, but True can sometimes be better.
- seed
Type → int | None
Default → None
Random seed for reproducibility.
Attributes
- models
Examples
```python
from river import datasets
from river import evaluate
from river import forest
from river import metrics

# First 500 samples of the Bananas binary classification dataset
dataset = datasets.Bananas().take(500)

model = forest.AMFClassifier(
    n_estimators=10,
    use_aggregation=True,
    dirichlet=0.5,
    seed=1
)

metric = metrics.Accuracy()

# Progressive validation: each sample is predicted before it is learned
print(evaluate.progressive_val_score(dataset, model, metric))
# Accuracy: 85.37%
```
Methods
learn_one
Update the model with a set of features x and a label y.
Parameters
- x
- y
predict_one
Predict the label of a set of features x.
Parameters
- x — dict[base.typing.FeatureName, Any]
- kwargs — Any
Returns
base.typing.ClfTarget | None: The predicted label.
predict_proba_one
Predict the probability of each label for a dictionary of features x.
Parameters
- x
Returns
A dictionary that associates a probability with each label.
Notes
For now, only the log-loss for multi-class classification (log_loss) is supported for computing the aggregation weights.
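How the log-loss feeds the aggregation weights can be sketched as follows, assuming that each node's log-weight starts at 0 and decreases by step times the log-loss of the node's own prediction for every labelled sample that passes through it. The names `log_loss` and `update_log_weight` are hypothetical and do not refer to River's internals.

```python
import math


def log_loss(proba: dict, y) -> float:
    """Log-loss of a probabilistic prediction `proba` for the true label `y`."""
    # Clip to avoid log(0); the 1e-15 floor is an illustrative choice
    return -math.log(max(proba.get(y, 0.0), 1e-15))


def update_log_weight(log_weight: float, proba: dict, y, step: float = 1.0) -> float:
    """Exponential-weights update: the node's log-weight decreases by step * loss."""
    return log_weight - step * log_loss(proba, y)


# A node that predicted {"a": 0.875, "b": 0.125} and then saw label "b"
w = update_log_weight(0.0, {"a": 0.875, "b": 0.125}, "b", step=1.0)
print(w)  # about -2.079, that is, -log(8)
```

Because the Dirichlet smoothing keeps every node probability strictly positive, the clipping is only a safeguard here.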
[1] Mourtada, J., Gaïffas, S., & Scornet, E. (2021). AMF: Aggregated Mondrian forests for online learning. Journal of the Royal Statistical Society Series B: Statistical Methodology, 83(3), 505-533.