SMOTE
This is a tutorial walking through the basic functionality of the smote function of this package, which implements the Synthetic Minority Over-sampling TEchnique (SMOTE) (Chawla et al., 2002).
This technique is useful when a statistical model performs poorly because of class imbalance. For example, take an outcome variable with 2 classes where one class has 10 items and the other 90. A model can then reach 90% accuracy simply by always predicting the largest class, while never correctly identifying a member of the smallest class. The largest and smallest classes are known as the majority and minority class, respectively.
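To make the 10/90 example concrete, here is a small sketch with hypothetical labels. A "model" that ignores its input and always predicts the majority class scores 90% accuracy, yet its recall on the minority class (the fraction of minority samples it identifies) is 0.

```julia
# Hypothetical labels matching the example: 10 minority and 90 majority samples.
y = vcat(fill(:minority, 10), fill(:majority, 90))

# A trivial "model" that always predicts the majority class, whatever the input.
predict_majority(n) = fill(:majority, n)
yhat = predict_majority(length(y))

# Fraction of all predictions that are correct.
accuracy = count(yhat .== y) / length(y)

# Fraction of the minority samples that were predicted as minority.
minority_recall = count((yhat .== :minority) .& (y .== :minority)) / count(y .== :minority)

println("accuracy = $accuracy, minority recall = $minority_recall")
# prints "accuracy = 0.9, minority recall = 0.0"
```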
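The idea of SMOTE is to reduce the class imbalance by generating synthetic points in between existing points of the minority class. In the original method, each new point is placed at a random position on the line segment between a minority sample and one of its nearest minority-class neighbours. This is shown below in more detail.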
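The core interpolation step, as described by Chawla et al. (2002), can be written as x_new = x + λ(neighbour − x), where λ is drawn uniformly from [0, 1). λ = 0 returns x itself and values close to 1 land next to the neighbour. The sketch below shows only this single step. The helper name interpolate is hypothetical and is not part of this package's API.

```julia
using Random

# Hypothetical helper: one SMOTE-style interpolation step.
# Returns a random point on the segment from x towards neighbor.
function interpolate(rng::AbstractRNG, x::AbstractVector, neighbor::AbstractVector)
    λ = rand(rng)                       # uniform in [0, 1)
    return x .+ λ .* (neighbor .- x)    # x + λ (neighbor - x)
end

rng = MersenneTwister(42)
x = [0.0, 0.0]
neighbor = [1.0, 2.0]
p = interpolate(rng, x, neighbor)       # some point λ .* [1.0, 2.0]
```

Because x is the origin here, the result is λ times the neighbour, so its second coordinate is always twice its first.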
In between
In the first part of this tutorial, we generate a few random points that play the role of a minority class. Next, we plot these points, generate a few synthetic (new) points, and show that the synthetic points lie between the original points.
begin
    using CairoMakie
    using DataFrames
    using Resample
    using StableRNGs: StableRNG
end
Let's start by generating some random data:
df = let
    rng = StableRNG(2)
    DataFrame(; A=rand(rng, 4), B=rand(rng, 4))
end
Row | A | B |
---|---|---|
1 | 0.975324 | 0.421822 |
2 | 0.128897 | 0.394399 |
3 | 0.527714 | 0.849473 |
4 | 0.867262 | 0.929645 |
Plotted, these four points look as follows:
To generate new synthetic points (new_points), we use the smote function from this package. The last argument is the number of synthetic points to generate:
new_data = let
    rng = StableRNG(1)
    new_points = smote(rng, df, 4)
    DataFrame(new_points)
end
Row | A | B |
---|---|---|
1 | 0.338128 | 0.546072 |
2 | 0.892834 | 0.809474 |
3 | 0.969836 | 0.447614 |
4 | 0.349973 | 0.64666 |
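For intuition about what a call like smote(rng, df, 4) does, here is a minimal from-scratch sketch of the algorithm as described by Chawla et al. (2002): pick a random minority point, pick one of its k nearest minority neighbours, and interpolate between them. This is not the implementation used by this package. The name smote_sketch, the keyword k, and the features × observations layout are assumptions made for this illustration.

```julia
using Random

# Hypothetical SMOTE sketch (NOT this package's implementation).
# X: features × observations matrix of minority-class points (one column per point).
# Returns a features × n matrix of synthetic points.
function smote_sketch(rng::AbstractRNG, X::AbstractMatrix{<:Real}, n::Integer; k::Integer=5)
    d, m = size(X)
    m >= 2 || throw(ArgumentError("need at least two minority points"))
    k = min(k, m - 1)                      # at most m - 1 other points can be neighbours
    out = Matrix{Float64}(undef, d, n)
    for j in 1:n
        i = rand(rng, 1:m)                 # pick a random minority point
        x = view(X, :, i)
        # Euclidean distance from x to every point; Inf excludes x itself.
        dists = [l == i ? Inf : sqrt(sum(abs2, view(X, :, l) .- x)) for l in 1:m]
        nearest = partialsortperm(dists, 1:k)   # indices of the k nearest neighbours
        nb = view(X, :, rand(rng, nearest))     # choose one of them at random
        λ = rand(rng)                            # uniform in [0, 1)
        out[:, j] .= x .+ λ .* (nb .- x)         # point on the segment from x to nb
    end
    return out
end

# Usage: four minority points at the corners of the unit square, one per column.
rng = MersenneTwister(1)
X = [0.0 1.0 0.0 1.0;
     0.0 0.0 1.0 1.0]
S = smote_sketch(rng, X, 4; k=2)
```

Since every synthetic point is a convex combination of two corners, all of S stays inside the unit square.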
Plotting the original 2-dimensional points of df together with the synthetic points in new_data gives:
where the dashed lines connect each pair of original points. As expected, the synthetic points lie between the original minority points.
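The "in between" property can also be checked numerically instead of visually. The sketch below uses a hypothetical helper, on_segment, that projects a point p onto the line through a and b. The point lies between a and b when the projection coefficient λ is in [0, 1] and the projection coincides with p, up to a tolerance.

```julia
# Hypothetical helper: true when p lies (up to atol) on the segment from a to b.
function on_segment(p, a, b; atol=1e-9)
    ab = b .- a
    λ = sum((p .- a) .* ab) / sum(abs2, ab)   # projection coefficient of p onto the line a -> b
    closest = a .+ λ .* ab                    # nearest point on that line
    return -atol <= λ <= 1 + atol && all(isapprox.(p, closest; atol=atol))
end

on_segment([0.5, 1.0], [0, 0], [1, 2])   # true: halfway between the endpoints
on_segment([1.0, 0.0], [0, 0], [1, 2])   # false: not on the line
on_segment([2.0, 4.0], [0, 0], [1, 2])   # false: on the line, but beyond b
```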
Next, we generate synthetic points for a larger point cloud.
Cloud
When doing this for many points, we will see that the cloud of points becomes denser as the synthetic points are added. In the first part of this tutorial, we passed the data as a DataFrame (or any other object for which Tables.istable returns true), but we can also pass a matrix:
mat = randn(2, 400) .* 100
2×400 Matrix{Float64}:
  -9.24884   42.416    107.225    198.61    …   163.558   30.6655  -31.9358   46.8046
 -183.394   -56.0835   -71.0079   -40.0269       -7.91094  138.941   -30.1962   15.9629
which looks as follows when plotted:
Now, we generate 300 synthetic points:
new_mat = smote(mat, 300)
2×300 Matrix{Float64}:
 102.331   -0.386574    51.8594   -40.1784   …   114.752   -58.337   -45.6469   106.402
 -89.1872  -1.46835   -160.438    -96.9686       -55.4952   77.5844   29.8038    77.2466
which looks as follows:
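In this example, smote(mat, 300) turns a 2×400 matrix into a 2×300 matrix. The matrix therefore appears to hold one point per column and one feature per row. This layout is inferred from the sizes shown above, not taken from the package documentation. If your data has one observation per row instead, as in a table, a sketch of transposing it first with permutedims (hypothetical data):

```julia
using Random

rng = MersenneTwister(0)
# Hypothetical data: 400 points with 2 features, one observation per row.
rows_as_obs = randn(rng, 400, 2) .* 100

# Transpose to features × observations (2×400), the same layout as `mat` above.
cols_as_obs = permutedims(rows_as_obs)
```

permutedims returns a new matrix. transpose(rows_as_obs) would instead return a lazy wrapper around the original data.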