Binary classification

Classification is about predicting an outcome from a fixed list of classes. The prediction is a probability distribution that assigns a probability to each possible outcome. For instance, a binary classifier might predict {False: 0.25, True: 0.75}, meaning it considers the positive class three times more likely than the negative one.

A labeled classification sample is made up of a set of features and a class. The class is a boolean in the case of binary classification. We'll use the phishing dataset as an example.

from river import datasets

dataset = datasets.Phishing()
dataset
Phishing websites.

This dataset contains features from web pages that are classified as phishing or not.

    Name  Phishing                                                    
    Task  Binary classification                                       
 Samples  1,250                                                       
Features  9                                                           
  Sparse  False                                                       
    Path  /home/runner/work/river/river/river/datasets/phishing.csv.gz

This is a streaming dataset, which can be looped over one sample at a time.

for x, y in dataset:
    pass
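Each iteration yields a dictionary of features and a label. As a quick sanity check, here is a minimal sketch that tallies the class balance of the stream, assuming nothing beyond the (features, label) pairs shown above:

from collections import Counter

# Tally how many samples belong to each class
counts = Counter(y for _, y in dataset)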

Let's take a look at the first sample.

x, y = next(iter(dataset))
x
{'empty_server_form_handler': 0.0,
 'popup_window': 0.0,
 'https': 0.0,
 'request_from_other_domain': 0.0,
 'anchor_from_other_domain': 0.0,
 'is_popular': 0.5,
 'long_url': 1.0,
 'age_of_domain': 1,
 'ip_in_url': 1}
y
True

A binary classifier's goal is to learn to predict a binary target y from some given features x. We'll try to do this with a logistic regression.

from river import linear_model

model = linear_model.LogisticRegression()
model.predict_proba_one(x)
{False: 0.5, True: 0.5}

The model hasn't been trained on any data, and therefore outputs a default probability of 50% for each class.
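This is no coincidence: the weights start at zero (note the Zeros initializer in the model's parameters shown further down), so the raw score is 0, and the sigmoid of 0 is exactly 0.5. A minimal sketch of that arithmetic:

import math

# With all weights at zero, the dot product w·x is 0, and sigmoid(0) = 0.5
def sigmoid(z):
    return 1 / (1 + math.exp(-z))

sigmoid(0.0)
0.5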

The model can be trained on the sample with learn_one, which updates the model's state in place.

model.learn_one(x, y)
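The effect of this update can be inspected directly; assuming the model exposes its coefficients through a weights attribute, as river's linear models do:

# One learned coefficient per feature seen so far
model.weights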

If we try to make a prediction on the same sample, we can see that the probabilities are different, because the model has learned something.

model.predict_proba_one(x)
{False: 0.494687699901455, True: 0.505312300098545}

Note that there is also a predict_one method if you're only interested in the most likely class rather than the full probability distribution.

model.predict_one(x)
True
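Under the hood, predict_one essentially picks the class with the highest predicted probability. A minimal equivalent sketch:

# Take the argmax over the predicted probability distribution
probas = model.predict_proba_one(x)
max(probas, key=probas.get)
True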

Typically, an online model makes a prediction, and then learns once the ground truth reveals itself. The prediction and the ground truth can be compared to measure the model's correctness. If you have a dataset available, you can loop over it, make a prediction, update the model, and compare the model's output with the ground truth. This is called progressive validation. Note that the predicted probabilities, rather than the hard predictions, are fed to the metric below, because ROC AUC is computed from scores.

from river import metrics

model = linear_model.LogisticRegression()

metric = metrics.ROCAUC()

for x, y in dataset:
    y_pred = model.predict_proba_one(x)
    model.learn_one(x, y)
    metric.update(y, y_pred)

metric
ROCAUC: 89.36%

This is a common way to evaluate an online model. In fact, there is a dedicated evaluate.progressive_val_score function that does this for you.

from river import evaluate

model = linear_model.LogisticRegression()
metric = metrics.ROCAUC()

evaluate.progressive_val_score(dataset, model, metric)
ROCAUC: 89.36%
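For longer streams, it can be useful to monitor performance as the evaluation runs; assuming progressive_val_score supports a print_every parameter, as in recent river versions, the running metric can be printed at regular intervals:

model = linear_model.LogisticRegression()
metric = metrics.ROCAUC()

# Print the running ROC AUC every 250 samples
evaluate.progressive_val_score(dataset, model, metric, print_every=250)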

A common way to improve the performance of a logistic regression is to scale the data. This can be done by using a preprocessing.StandardScaler. In particular, we can define a pipeline to organise our model into a sequence of steps:

from river import compose
from river import preprocessing

model = compose.Pipeline(
    preprocessing.StandardScaler(),
    linear_model.LogisticRegression()
)

model
StandardScaler ( with_std=True )
LogisticRegression ( optimizer=SGD ( lr=Constant ( learning_rate=0.01 ) ) loss=Log ( weight_pos=1. weight_neg=1. ) l2=0. l1=0. intercept_init=0. intercept_lr=Constant ( learning_rate=0.01 ) clip_gradient=1e+12 initializer=Zeros () )

metric = metrics.ROCAUC()
evaluate.progressive_val_score(dataset, model, metric)
ROCAUC: 95.04%
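As a side note, the same pipeline can be written more concisely with the | operator, which river's compose module provides as shorthand for chaining steps:

model = preprocessing.StandardScaler() | linear_model.LogisticRegression()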

That concludes this getting-started introduction to binary classification! You can now move on to the next steps.