The Island
On a mysterious island live two types of animals: yeebles and zooks, and they love to eat wild strawberries. A certain cave on the island is home to either a yeeble or a zook, and you'd like to know which it is. Unfortunately, it's very shy, and never comes out of its cave if you're nearby. You decide to find out by feeding it strawberries.

From previous experiments, you know that yeebles tend to eat around four pounds of strawberries in 24 hours if left alone, while zooks will eat about eight pounds. You put ten pounds of strawberries outside the cave, and come back a day later to find 6.3 pounds missing. Which kind of animal lives in the cave? How certain are you?
Classification
This is a classification problem. Given the result of a measurement, we want to figure out what class a member belongs to. In particular, this problem is binary, since the animal is either a yeeble or a zook, and can't be both. It's also univariate, since the only variable is the number of pounds of strawberries consumed in 24 hours.

Suppose our data look like this:
Fig. 1: Previous data for two classes of animal.
Our strategy will be to turn the previous data about strawberry consumption into a function that takes [lbs of strawberries] and converts it to [probability that the animal is a zook]. The way this is normally done in the machine learning community is through logistic regression. We assume that the probability of an animal being a zook based on the amount of strawberries eaten is of the form
$P\left(Z|S\right) = \frac{1}{1 + \exp\left(a - bS\right)}$,
where $a$ and $b$ are constants, and $S$ is the amount of strawberries eaten, in lbs. A regression is used to choose $a$ and $b$ so that the curve matches the data optimally. Just as in a linear regression, we choose the parameters that minimize a cost function; for logistic regression the standard choice is the negative log-likelihood (cross-entropy) rather than the sum of squared errors.
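For concreteness, here's a minimal sketch of such a fit in pure Python. The class centers (4 and 8 lbs) come from the problem, but the spread $\sigma = 1.15$ lbs used to generate synthetic data and the gradient-descent details are illustrative assumptions, not the original fitting procedure:

```python
import math
import random

random.seed(0)

# Synthetic "previous data": lbs of strawberries eaten in 24 hours.
# Class centers 4 (yeeble) and 8 (zook) come from the problem; the
# spread sigma = 1.15 is an illustrative assumption.
S = ([random.gauss(4.0, 1.15) for _ in range(500)] +
     [random.gauss(8.0, 1.15) for _ in range(500)])
labels = [0] * 500 + [1] * 500           # 0 = yeeble, 1 = zook

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Fit P(Z|S) = 1 / (1 + exp(a - b*S)) by full-batch gradient descent
# on the cross-entropy loss. Centering the feature first keeps the
# descent well conditioned.
m = sum(S) / len(S)
w, c = 0.0, 0.0                          # weight and bias for (S - m)
lr = 0.5
for _ in range(2000):
    gw = gc = 0.0
    for s, y in zip(S, labels):
        p = sigmoid(w * (s - m) + c)
        gw += (p - y) * (s - m)
        gc += (p - y)
    w -= lr * gw / len(S)
    c -= lr * gc / len(S)

b = w                 # slope in the original parameterization
a = w * m - c         # since sigmoid(w*(s - m) + c) = sigmoid(b*s - a)
print(a, b, a / b)    # a/b is the 50/50 crossover point, near 6 lbs here
```

Centering the feature before fitting is just a conditioning trick; the recovered $a$ and $b$ refer to the original parameterization.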
Why the logistic curve?
Why do we assume a logistic curve is the correct model for the probability? Clearly, a line wouldn't work. At minimum, we need something with a range of $\left(0,1\right)$, since a probability can't be less than 0 or greater than 1. But lots of functions have this property, including the sigmoids, of which the logistic curve is one. Choosing the logistic curve implies some assumptions about how the data are generated.

In particular, the logistic curve is the correct probability model if each class exhibits a normally distributed feature with equal variances and different center values.
Here's an illustration of that assumption. Suppose we make a histogram of lbs of strawberries eaten for both yeebles and zooks:
Fig. 2: Histogram of previous data from which we construct our model. This plot shows 10k data points, whereas the scatter plot shows only 100 for clarity.
These distributions are represented by
$P(S|Y) \propto \exp\left(-(S - S_Y)^2/2\sigma^2\right)$,
and
$P(S|Z) \propto \exp\left(-(S - S_Z)^2/2\sigma^2\right)$,
where $S_Y$ and $S_Z$ are the centers of the distributions and $\sigma$ is the standard deviation. We want to know the probability that an animal is a zook given a measurement $S'$. By Bayes' rule, this is
$P(Z|S') = \frac{P(Z)P(S'|Z)}{P(Z)P(S'|Z) + P(Y)P(S'|Y)}$,
where $P(Z)$ is the prior probability of finding a zook (independent of strawberry consumption), and $P(S'|Y),P(S'|Z)$ are called sampling distributions. They tell us the probability that the data would have been generated if each hypothesis were true. They're just the Gaussian distributions given above.
We're further going to assume that $P(Y)=P(Z)$, that is, that yeebles and zooks are equally common. If we knew otherwise, we would modify the prior accordingly; I'll talk a lot more about priors in other blog posts. With equal priors, the prior factors cancel. Substituting in the Gaussian distributions and reducing, we have
$P(Z|S') = \frac{1}{1 + \exp\left(\left(S_Z^2 - S_Y^2 - 2S'\left(S_Z - S_Y\right)\right)/2\sigma^2\right)}$.
This is a logistic curve. We have shown that under the assumption that the measured features are normally distributed with equal standard deviations, a logistic curve is the correct probability model to use. This analysis extends to multivariate distributions as well.
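If you don't trust the algebra, the reduction is easy to check numerically. This sketch compares the raw Bayes'-rule posterior against the logistic form with exponent $(S_Z^2 - S_Y^2 - 2S'(S_Z - S_Y))/2\sigma^2$, using an assumed $\sigma = 1.15$:

```python
import math

S_Y, S_Z, sigma = 4.0, 8.0, 1.15   # centers from the problem; sigma assumed

def gauss(s, mu):
    # Unnormalized Gaussian density (the shared normalization cancels)
    return math.exp(-((s - mu) ** 2) / (2 * sigma ** 2))

def posterior_bayes(s):
    # P(Z|S) straight from Bayes' rule with equal priors
    return gauss(s, S_Z) / (gauss(s, S_Z) + gauss(s, S_Y))

def posterior_logistic(s):
    # The reduced logistic form of the same posterior
    arg = (S_Z ** 2 - S_Y ** 2 - 2 * s * (S_Z - S_Y)) / (2 * sigma ** 2)
    return 1.0 / (1.0 + math.exp(arg))

for s in [0.5 * k for k in range(25)]:   # grid from 0 to 12 lbs
    assert abs(posterior_bayes(s) - posterior_logistic(s)) < 1e-12
print("logistic form matches Bayes' rule across the grid")
```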
Mathematical convenience
Logistic curves have some convenient properties, so we're lucky everything turned out this way. The whole shape of the curve is controlled by the argument of the exponential, $(S_Z^2 - S_Y^2 - 2S'\left(S_Z - S_Y\right))/2\sigma^2$, which is linear in $S'$. That means we can use a generalized linear regression to find the parameters $a$ and $b$ above, which in this case turn out to be $a = 18.25$ and $b = 3.058$.

We can also show (by integrating the tails of the two Gaussians past the decision boundary at the midpoint $(S_Y + S_Z)/2$) that the probability of making an incorrect classification is given by

$P(\textrm{fail}) = \frac{1}{2}\textrm{erfc}\left(\frac{S_Z - S_Y}{2\sqrt{2}\,\sigma}\right)$,

where $\textrm{erfc} = 1 - \textrm{erf}$ is the complementary error function ($\textrm{erf}$ is another sigmoid!). Knowing this lets us address questions like "how close can the distributions be before we can't classify very effectively?" You tell me what error rate is acceptable, and I'll tell you how far apart the distributions have to be to achieve it.
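Here's a sketch that checks the closed-form error rate $\frac{1}{2}\textrm{erfc}\left((S_Z - S_Y)/2\sqrt{2}\sigma\right)$ against a brute-force simulation (again with an assumed $\sigma = 1.15$). The simulated classifier just picks whichever center is nearer, which is the optimal rule for equal priors and equal spreads:

```python
import math
import random

S_Y, S_Z, sigma = 4.0, 8.0, 1.15   # sigma is an illustrative assumption

# Closed-form error rate for equal priors and equal sigma: each class
# contributes its Gaussian tail past the midpoint (S_Y + S_Z) / 2.
p_fail = 0.5 * math.erfc((S_Z - S_Y) / (2 * math.sqrt(2) * sigma))

# Monte Carlo check: classify by nearest center, count the mistakes.
random.seed(1)
n, errors = 200_000, 0
mid = (S_Y + S_Z) / 2
for _ in range(n):
    if random.random() < 0.5:                     # a yeeble wanders by
        errors += random.gauss(S_Y, sigma) > mid
    else:                                         # a zook wanders by
        errors += random.gauss(S_Z, sigma) < mid
print(p_fail, errors / n)   # the two estimates should agree closely
```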
The answer
If we observe that 6.3 lbs of strawberries are missing after a day, all we have to do is check what probability it corresponds to on our logistic curve. Here is what the best fit curve looks like:
Fig. 3: The best fit logistic curve for the data allows us to classify future members.
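Reading the probability off the curve is just a matter of plugging $S' = 6.3$ into the fitted logistic, using the values $a = 18.25$ and $b = 3.058$ quoted above:

```python
import math

a, b = 18.25, 3.058   # fitted parameters quoted earlier in the post
S_obs = 6.3           # lbs of strawberries missing after one day

p_zook = 1.0 / (1.0 + math.exp(a - b * S_obs))
print(p_zook)         # about 0.73
```

So the missing 6.3 lbs points to a zook, with roughly 73% probability: a lean, not a verdict.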
Mathematical inconvenience
I want to briefly break from the problem here to illustrate the limitations of the logistic model. We assumed above that the standard deviation of each distribution was the same, but that's not always a valid assumption. In fact, it's quite common to find Gaussian-distributed (or approximately Gaussian) processes where the standard deviation is proportional to the center value. If the standard deviations are different, the argument of the exponent is quadratic in $S'$, so a linear regression doesn't work any more.
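To see the quadratic term appear, here's a small numeric check with assumed unequal spreads. A function is quadratic exactly when its second differences on an even grid are constant and nonzero:

```python
import math

S_Y, S_Z = 4.0, 8.0
sig_Y, sig_Z = 1.0, 2.0   # unequal spreads, illustrative values

def log_odds(s):
    # log[P(S|Z) / P(S|Y)] for Gaussians with different sigmas,
    # keeping the normalization terms (they no longer cancel)
    lz = -((s - S_Z) ** 2) / (2 * sig_Z ** 2) - math.log(sig_Z)
    ly = -((s - S_Y) ** 2) / (2 * sig_Y ** 2) - math.log(sig_Y)
    return lz - ly

# Second differences of log-odds on an even grid
h = 0.5
d2 = [log_odds(s + h) - 2 * log_odds(s) + log_odds(s - h)
      for s in (2.0, 4.0, 6.0, 8.0)]
print(d2)   # all equal and nonzero: the log-odds is quadratic in S
```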
In fact, if we generalize the class distributions by adding higher-order terms to the exponent, the argument becomes a higher-order polynomial in $S'$, and we can use a polynomial regression to do our classification.
What if the features are not Gaussian distributed? Well, that sucks, as always. The sampling distributions won't be as nice, and $P(Z|S')$ may not be parameterizable as a polynomial at all. Then we would need to perform general nonlinear fitting. That can be done, of course, but not as efficiently.
I leave this as an exercise to the reader. ;)
- b
*: Gaussians actually extend over the entire real line, and I'm truncating at 0 lbs since it's hard to eat a negative weight of strawberries. The error incurred is small if the centers of the distributions are a few times greater than the widths.
Acknowledgement: Thanks to Thomas Stearns for helpful discussions.