Label propagation with semi-supervised learning

Label propagation is a semi-supervised technique that uses both the labeled and the unlabeled data to learn about the unlabeled data. Quite often, data that would benefit from a classification algorithm is difficult to label. For example, labeling might be so expensive that only a subset is cost-effective to label manually. That said, there does seem to be slow but growing support for companies hiring taxonomists.

Getting ready

Another problem area is censored data. You can imagine a case where the passage of time limits your ability to gather labeled data. Say, for instance, you took measurements of patients and gave them an experimental drug. In some cases, you are able to measure the outcome of the drug because it happens fast enough, but you might also want to predict the outcome for the patients whose reaction is slower and has not been observed yet. The drug might cause a fatal reaction for some patients, and life-saving measures might need to be taken.
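As a minimal sketch of how censored outcomes could be encoded, the snippet below replaces every outcome that has not been observed yet with -1, the marker that scikit-learn's semi-supervised estimators treat as unlabeled. The arrays `outcome` and `observed` are hypothetical and exist only for illustration:

```python
import numpy as np

# Hypothetical outcome class per patient (0 = no response, 1 = response)
outcome = np.array([1, 0, 1, 1, 0, 1])
# Hypothetical flag: True where the outcome was seen before the study cutoff
observed = np.array([True, True, False, True, False, False])

# Keep the observed outcomes and mark the censored ones as -1 (unlabeled)
y_censored = np.where(observed, outcome, -1)
print(y_censored.tolist())  # [1, 0, -1, 1, -1, -1]
```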

How to do it…

In order to represent the semi-supervised or censored data, we'll need to do a little data preprocessing. First, we'll walk through a simple example, and then we'll move on to some more difficult cases:

>>> import numpy as np
>>> from sklearn import datasets
>>> d = datasets.load_iris()

Because we'll be modifying the data, let's make copies and add an unlabeled entry to the copy of the target names. This will make the data easier to identify later:

>>> X = d.data.copy()
>>> y = d.target.copy()
>>> names = d.target_names.copy()

>>> names = np.append(names, ['unlabeled'])
>>> names
array(['setosa', 'versicolor', 'virginica', 'unlabeled'],
       dtype='|S10')

Now, let's set a random subset of y to -1, which is the marker for the unlabeled case. This is also why we added unlabeled to the end of names: indexing names with -1 returns its last element:

>>> y[np.random.choice([True, False], len(y))] = -1

Our target now has a bunch of negative ones (-1) interspersed with the actual labels:

>>> y[:10]
array([-1, -1, -1, -1,  0,  0, -1, -1,  0, -1])

>>> names[y[:10]]
array(['unlabeled', 'unlabeled', 'unlabeled', 'unlabeled', 'setosa',
       'setosa', 'unlabeled', 'unlabeled', 'setosa', 'unlabeled'], 
        dtype='|S10')
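The random mask used above changes on every run, so the exact labels that get hidden will differ from the output shown. A minimal sketch of a repeatable version follows. It assumes a seeded generator (the seed 0 is arbitrary) and uses a stand-in target array with the same layout as the iris target, so that it runs without loading the dataset:

```python
import numpy as np

# Stand-in for d.target: 50 samples each of classes 0, 1, 2 (same layout as iris)
y = np.repeat([0, 1, 2], 50)
names = np.array(['setosa', 'versicolor', 'virginica', 'unlabeled'])

# Seeded generator so the same points are masked on every run
rng = np.random.default_rng(0)
mask = rng.random(len(y)) < 0.5   # hides roughly half of the points
y[mask] = -1

# Every masked entry maps to the last element of names
print(names[y[:10]])
print((y == -1).sum(), "of", len(y), "points are unlabeled")
```

Keeping the mask in its own variable also makes it possible to score predictions on the hidden points only.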

We clearly have a lot of unlabeled data, and the goal now is to use LabelPropagation to predict the labels:

>>> from sklearn import semi_supervised
>>> lp = semi_supervised.LabelPropagation()

>>> lp.fit(X, y)
LabelPropagation(alpha=1, gamma=20, kernel='rbf', max_iter=30,
                 n_neighbors=7, tol=0.001)
>>> preds = lp.predict(X)
>>> (preds == d.target).mean()
0.98666666666666669

Not too bad, though we scored the model on all the data, including the points whose labels it was given, so it's kind of cheating. Also, the classes in the iris dataset are fairly well separated.
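One way to get a less flattering but fairer number is to score only the points whose labels were hidden. The following is a sketch under stated assumptions rather than part of the original recipe: the seed is arbitrary, the mask is rebuilt here so the block is self-contained, and the fitted `transduction_` attribute is used to read off the label assigned to each training point:

```python
import numpy as np
from sklearn import datasets, semi_supervised

d = datasets.load_iris()
X, y_true = d.data, d.target

# Hide roughly half of the labels with a seeded mask (seed is arbitrary)
rng = np.random.default_rng(0)
mask = rng.random(len(y_true)) < 0.5
y = y_true.copy()
y[mask] = -1

lp = semi_supervised.LabelPropagation()
lp.fit(X, y)

# transduction_ holds the label assigned to every training point;
# compare it with the truth only where the label was hidden
hidden_acc = (lp.transduction_[mask] == y_true[mask]).mean()
print("accuracy on hidden labels:", hidden_acc)
```

The exact value depends on the seed and on the scikit-learn version, so no expected output is shown.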

While we're at it, let's look at LabelSpreading, the "sister" class of LabelPropagation. We'll make the technical distinction between LabelPropagation and LabelSpreading in the How it works... section of this recipe; for now, it's enough to say that they are extremely similar:

>>> ls = semi_supervised.LabelSpreading()

LabelSpreading is more robust to noise, as can be seen from the way it works:

>>> ls.fit(X, y)
LabelSpreading(alpha=0.2, gamma=20, kernel='rbf', max_iter=30, 
               n_neighbors=7, tol=0.001)

>>> (ls.predict(X) == d.target).mean()
0.96666666666666667

Don't take the fact that the label-spreading algorithm scored slightly lower here as an indication that it performs worse in general. The whole point is that we might give up some ability to predict well on the training set in order to work well in a wider range of situations.
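A single random mask is a thin basis for comparing the two estimators. As a rough sketch, both can be fitted on several different masks and scored on the hidden labels only. The helper `hidden_accuracy` and the choice of five seeds are illustrative assumptions, not part of the recipe:

```python
import numpy as np
from sklearn import datasets, semi_supervised

d = datasets.load_iris()
X, y_true = d.data, d.target

def hidden_accuracy(model, seed):
    """Fit on a half-masked copy of the labels; score only the hidden points."""
    rng = np.random.default_rng(seed)
    mask = rng.random(len(y_true)) < 0.5
    y = y_true.copy()
    y[mask] = -1
    model.fit(X, y)
    return (model.transduction_[mask] == y_true[mask]).mean()

seeds = range(5)  # arbitrary seeds, chosen only for repeatability
lp_scores = [hidden_accuracy(semi_supervised.LabelPropagation(), s) for s in seeds]
ls_scores = [hidden_accuracy(semi_supervised.LabelSpreading(), s) for s in seeds]

print("LabelPropagation:", np.round(lp_scores, 3))
print("LabelSpreading:  ", np.round(ls_scores, 3))
```

If the two sets of scores overlap from seed to seed, a small gap on any one run says little about which estimator is better in general.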

How it works…

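Label propagation works by creating a graph of the data points, with a weight placed on each edge. With the rbf kernel reported in the output above, scikit-learn computes the weight between points x_i and x_j as exp(-gamma * ||x_i - x_j||^2), so nearby points are connected more strongly than distant ones.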

The algorithm then works by having the labeled data points propagate their labels to the unlabeled data. This propagation is determined in part by the edge weights.

The edge weights can be placed in a matrix of transition probabilities, which lets us iteratively determine a good estimate of the actual labels.
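The iteration described above can be sketched in plain NumPy. This is an illustration of the idea under stated assumptions, not scikit-learn's actual implementation: the toy data, the RBF edge weights, the value of `gamma`, and the fixed count of 100 iterations are all chosen for the example:

```python
import numpy as np

# Toy 1-D data: two clusters, one labeled point in each (-1 = unlabeled)
X = np.array([[0.0], [0.2], [0.4], [3.0], [3.2], [3.4]])
y = np.array([0, -1, -1, 1, -1, -1])
n_classes = 2
gamma = 20.0  # assumed kernel width for this sketch

# Edge weights from an RBF kernel on squared Euclidean distance
sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
W = np.exp(-gamma * sq_dist)

# Row-normalize so each row is a set of transition probabilities
T = W / W.sum(axis=1, keepdims=True)

# Start with one-hot rows for labeled points and zeros for unlabeled ones
labeled = y != -1
F = np.zeros((len(y), n_classes))
F[labeled, y[labeled]] = 1.0
F_init = F.copy()

for _ in range(100):
    F = T @ F                      # each point absorbs its neighbors' labels
    F[labeled] = F_init[labeled]   # clamp the known labels back in place

pred = F.argmax(axis=1)
print(pred.tolist())  # [0, 0, 0, 1, 1, 1]
```

Resetting the labeled rows after every step is what keeps the known labels fixed while the unlabeled rows drift toward the labels of their strongest neighbors.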
