12 Improving training with metrics and augmentation

This chapter covers

  • Defining and computing precision, recall, and true/false positives/negatives
  • Using the F1 score versus other quality metrics
  • Balancing and augmenting data to reduce overfitting
  • Using TensorBoard to graph quality metrics

The close of the last chapter left us in a predicament. While we were able to get the mechanics of our deep learning project in place, none of the results were actually useful; the network simply classified everything as non-nodule! To make matters worse, the results seemed great on the surface, since we were looking at the overall percent of the training and validation sets that were classified correctly. With our data heavily skewed toward negative samples, blindly calling everything negative is a quick and easy way for our model to score well. Too bad doing so makes the model basically useless!

That means we’re still focused on the same part of figure 12.1 as we were in chapter 11. But now we’re working on getting our classification model working well instead of at all. This chapter is all about how to measure, quantify, express, and then improve on how well our model is doing its job.

Figure 12.1 Our end-to-end lung cancer detection project, with a focus on this chapter’s topic: step 4, classification

12.1 High-level plan for improvement

While a bit abstract, figure 12.2 shows us how we are going to approach that broad set of topics.

Let’s walk through this somewhat abstract map of the chapter in detail. We will be dealing with the issues we’re facing, like excessive focus on a single, narrow metric and the resulting behavior being useless in the general sense. In order to make some of this chapter’s concepts a bit more concrete, we’ll first employ a metaphor that puts our troubles in more tangible terms: in figure 12.2, (1) Guard Dogs and (2) Birds and Burglars.

Figure 12.2 The metaphors we’ll use to modify the metrics measuring our model to make it magnificent

After that, we will develop a graphical language to represent some of the core concepts needed to formally discuss the issues with the implementation from the last chapter: (3) Ratios: Recall and Precision. Once we have those concepts solidified, we’ll touch on some math using those concepts that will encapsulate a more robust way of grading our model’s performance and condensing it into a single number: (4) New Metric: F1 Score. We will implement the formula for those new metrics and look at how the resulting values change epoch by epoch during training. Finally, we’ll make some much-needed changes to our LunaDataset implementation aimed at improving our training results: (5) Balancing and (6) Augmentation. Then we will see if those experimental changes have the expected impact on our performance metrics.

By the time we’re through with this chapter, our trained model will be performing much better: (7) Workin’ Great! While it won’t be ready to drop into clinical use just yet, it will be capable of producing results that are clearly better than random. This will mean we have a workable implementation of step 4, nodule candidate classification; and once we’re finished, we can begin to think about how to incorporate steps 2 (segmentation) and 3 (grouping) into the project.

12.2 Good dogs vs. bad guys: False positives and false negatives

Instead of models and tumors, we’re going to consider the two guard dogs in figure 12.3, both fresh out of obedience school. They both want to alert us to burglars--a rare but serious situation that requires prompt attention.

Figure 12.3 The set of topics for this chapter, with a focus on the framing metaphor

Unfortunately, while both dogs are good dogs, neither is a good guard dog. Our terrier (Roxie) barks at just about everything, while our old hound dog (Preston) barks almost exclusively at burglars--but only if he happens to be awake when they arrive.

Roxie will alert us to a burglar just about every time. She will also alert us to fire engines, thunderstorms, helicopters, birds, the mail carrier, squirrels, passersby, and so on. If we follow up on every bark, we’ll almost never get robbed (only the sneakiest of sneak-thieves can slip past). Perfect! ... Except that being that diligent means we aren’t really saving any work by having a guard dog. Instead, we’ll be up every couple of hours, flashlight in hand, due to Roxie having smelled a cat, or heard an owl, or seen a late bus wander by. Roxie has a problematic number of false positives.

A false positive is an event that is classified as of interest or as a member of the desired class (positive as in “Yes, that’s the type of thing I’m interested in knowing about”) but that in truth is not really of interest. For the nodule-detection problem, it’s when an actually uninteresting candidate is flagged as a nodule and, hence, in need of a radiologist’s attention. For Roxie, these would be fire engines, thunderstorms, and so on. We will use an image of a cat as the canonical false positive in the next section and the figures that follow throughout the rest of the chapter.

Contrast false positives with true positives: items of interest that are classified correctly. These will be represented in the figures by a human burglar.

Meanwhile, if Preston barks, call the police, since that means someone has almost certainly broken in, the house is on fire, or Godzilla is attacking. Preston is a deep sleeper, however, and the sound of an in-progress home invasion isn’t likely to rouse him, so we’ll still get robbed just about every time someone tries. Again, while it’s better than nothing, we’re not really ending up with the peace of mind that motivated us to get a dog in the first place. Preston has a problematic number of false negatives.

A false negative is an event that is classified as not of interest or not a member of the desired class (negative as in “No, that’s not the type of thing I’m interested in knowing about”) but that in truth is actually of interest. For the nodule-detection problem, it’s when a nodule (that is, a potential cancer) goes undetected. For Preston, these would be the robberies that he sleeps through. We’ll get a bit creative here and use a picture of a rodent burglar for false negatives. They’re sneaky!

Contrast false negatives with true negatives: uninteresting items that are correctly identified as such. We’ll go with a picture of a bird for these.

Just to complete the metaphor, chapter 11’s model is basically a cat that refuses to meow at anything that isn’t a can of tuna (while stoically ignoring Roxie). Our focus at the end of the last chapter was on the percent correct for the overall training and validation sets. Clearly, that wasn’t a great way to grade ourselves, and as we can see from each of our dogs’ myopic focus on a single metric--like the number of true positives or true negatives--we need a metric with a broader focus to capture our overall performance.

12.3 Graphing the positives and negatives

Let’s start developing the visual language we’ll use to describe true/false positives/negatives. Please bear with us if our explanation gets repetitive; we want to make sure you develop a solid mental model for the ratios we’re going to discuss. Consider figure 12.4, which shows events that might be of interest to one of our guard dogs.

Figure 12.4 Cats, birds, rodents, and robbers make up our four classification quadrants. They are separated by a human label and the dog classification threshold.

We’ll use two thresholds in figure 12.4. The first is the human-decided dividing line that separates burglars from harmless animals. In concrete terms, this is the label that is given for each training or validation sample. The second is the dog-determined classification threshold that determines whether the dog will bark at something. For a deep learning model, this is the predicted value that the model produces when considering a sample.

The combination of these two thresholds divides our events into quadrants: true/false positives/negatives. We will shade the events of concern with a darker background (what with those bad guys sneaking around in the dark all the time).

Of course, reality is far more complicated. There is no Platonic ideal of a burglar, and no single point relative to the classification threshold at which all burglars will be located. Instead, figure 12.5 shows us that some burglars will be particularly sneaky, and some birds will be particularly annoying. We will also go ahead and enclose our instances in a graph. Our X-axis will remain the bark-worthiness of each event, as determined by one of our guard dogs. We’re going to have the Y-axis represent some vague set of qualities that we as humans are able to perceive, but our dogs cannot.

Since our model produces a binary classification, we can think of the prediction as comparing a single numerical output value to our classification threshold value. This is why we require the classification threshold line in figure 12.5 to be perfectly vertical.

Figure 12.5 Each type of event will have many possible instances that our guard dogs will need to evaluate.

Each possible burglar is different, so our guard dogs will need to evaluate many different situations, and that means more opportunities to make mistakes. We can see the clear diagonal line that separates the birds from the burglars, but Preston and Roxie can only perceive the X-axis here: they have a muddled, overlapped set of events in the middle of our graph. They must pick a vertical bark-worthiness threshold, which means it’s impossible for either one of them to do so perfectly. Sometimes the person hauling your appliances to their van is the repair person you hired to fix your washing machine, and sometimes burglars show up in a van that says “Washing Machine Repair” on the side. Expecting a dog to pick up on those nuances is bound to fail.

The actual input data we’re going to use has high dimensionality--we need to consider a ton of CT voxel values, along with more abstract things like candidate size, overall location in the lungs, and so on. The job of our model is to map each of these events and respective properties into this rectangle in such a way that we can separate those positive and negative events cleanly using a single vertical line (our classification threshold). This is done by the nn.Linear layers at the end of our model. The position of the vertical line corresponds exactly to the classificationThreshold_float we saw in section 11.6.1. There, we chose the hardcoded value 0.5 as our threshold.
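To make that concrete, here is a minimal sketch of such a classifier head (the layer and batch sizes are made up for illustration; this is not the actual LunaModel code):

import torch
import torch.nn as nn

head = nn.Linear(1152, 2)            # Many features in, two class logits out
features = torch.randn(8, 1152)      # A batch of 8 per-sample feature vectors
probs = nn.functional.softmax(head(features), dim=1)

classificationThreshold_float = 0.5  # The vertical line in our diagrams
pred_bool = probs[:, 1] > classificationThreshold_float

Each sample ends up as the single scalar probs[:, 1], and the comparison against the threshold is what draws the vertical line.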

Note that in reality, the data presented is not two-dimensional; it goes from very-high-dimensional after the second-to-last layer, to one-dimensional (here, our X-axis) at the output--just a single scalar per sample (which is then bisected by the classification threshold). Here, we use the second dimension (the Y-axis) to represent per-sample features that our model cannot see or use: things like age or gender of the patient, location of the nodule candidate in the lung, or even local aspects of the candidate that the model hasn’t utilized. It also gives us a convenient way to represent confusion between non-nodule and nodule samples.

The quadrant areas in figure 12.5 and the count of samples contained in each will be the values we use to discuss model performance, since we can use the ratios between these values to construct increasingly complex metrics that we can use to objectively measure how well we are doing. As they say, “the proof is in the proportions.”1 Next, we’ll use ratios between these event subsets to start defining better metrics.

12.3.1 Recall is Roxie’s strength

Recall is basically “Make sure you never miss any interesting events!” Formally, recall is the ratio of the true positives to the union of true positives and false negatives. We can see this depicted in figure 12.6.
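In symbols, writing TP for the count of true positives and FN for the count of false negatives:

recall = TP / (TP + FN)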

Figure 12.6 Recall is the ratio of the true positives to the union of true positives and false negatives. High recall minimizes false negatives.

Note In some contexts, recall is referred to as sensitivity.

To improve recall, minimize false negatives. In guard dog terms, that means if you’re unsure, bark at it, just in case. Don’t let any rodent thieves sneak by on your watch!

Roxie accomplishes having an incredibly high recall by pushing her classification threshold all the way to the left, such that it encompasses nearly all of the positive events in figure 12.7. Note how doing so means her recall value is near 1.0, which means 99% of robbers are barked at. Since that’s how Roxie defines success, in her mind, she’s doing a great job. Never mind the huge expanse of false positives!

Figure 12.7 Roxie’s choice of threshold prioritizes minimizing false negatives. Every last rat is barked at . . . and cats, and most birds.

12.3.2 Precision is Preston’s forte

Precision is basically “Never bark unless you’re sure.” To improve precision, minimize false positives. Preston won’t bark at something unless he’s certain it’s a burglar. More formally, precision is the ratio of the true positives to the union of true positives and false positives, as shown in figure 12.8.
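In symbols, with FP as the count of false positives:

precision = TP / (TP + FP)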

Figure 12.8 Precision is the ratio of the true positives to the union of true positives and false positives. High precision minimizes false positives.

Preston accomplishes having an incredibly high precision by pushing his classification threshold all the way to the right, such that it excludes as many uninteresting, negative events as he can manage (see figure 12.9). This is the opposite of Roxie’s approach and means Preston has a precision of nearly 1.0: 99% of the things he barks at are robbers. This also matches his definition of being a good guard dog, even though a large number of events pass undetected.

While neither precision nor recall can be the single metric used to grade our model, they are both useful numbers to have on hand during training. Let’s calculate and display these as part of our training program, and then we’ll discuss other metrics we can employ.

Figure 12.9 Preston’s choice of threshold prioritizes minimizing false positives. Cats get left alone; only burglars are barked at!

12.3.3 Implementing precision and recall in logMetrics

Both precision and recall are valuable metrics to be able to track during training, since they provide important insight into how the model is behaving. If either of them drops to zero (as we saw in chapter 11!), it’s likely that our model has started to behave in a degenerate manner. We can use the exact details of the behavior to guide where to investigate and experiment with getting training back on track. We’d like to update the logMetrics function to add precision and recall to the output we see for each epoch, to complement the loss and correctness metrics we already have.

We’ve been defining precision and recall in terms of “true positives” and the like thus far, so we will continue to do so in the code. It turns out that we are already computing some of the values we need, though we had named them differently.

Listing 12.1 training.py:315, LunaTrainingApp.logMetrics

neg_count = int(negLabel_mask.sum())
pos_count = int(posLabel_mask.sum())
 
trueNeg_count = neg_correct = int((negLabel_mask & negPred_mask).sum())
truePos_count = pos_correct = int((posLabel_mask & posPred_mask).sum())
 
falsePos_count = neg_count - neg_correct
falseNeg_count = pos_count - pos_correct

Here, we can see that neg_correct is the same thing as trueNeg_count! That actually makes sense, since non-nodule is our “negative” value (as in “a negative diagnosis”), and if the classifier gets the prediction correct, then that’s a true negative. Similarly, correctly labeled nodule samples are true positives.

We do need to add variables for our false positive and false negative values. That’s straightforward, since we can take the total number of negative (non-nodule) labels and subtract the count of the correct ones. What’s left is the count of non-nodule samples misclassified as positive; hence, they are false positives. The false negative calculation has the same form but uses nodule counts.

With those values, we can compute precision and recall and store them in metrics _dict.

Listing 12.2 training.py:333, LunaTrainingApp.logMetrics

precision = metrics_dict['pr/precision'] = \
  truePos_count / np.float32(truePos_count + falsePos_count)
recall  = metrics_dict['pr/recall'] = \
  truePos_count / np.float32(truePos_count + falseNeg_count)

Note the double assignment: while separate precision and recall variables aren’t strictly necessary, they improve the readability of the next section. We also extend the logging statement in logMetrics to include the new values, but we’ll skip that implementation for now (we’ll revisit logging later in the chapter).

12.3.4 Our ultimate performance metric: The F1 score

While useful, neither precision nor recall entirely captures what we need in order to be able to evaluate a model. As we’ve seen with Roxie and Preston, it’s possible to game either one individually by manipulating our classification threshold, resulting in a model that scores well on one or the other but does so at the expense of any real-world utility. We need something that combines both of those values in a way that prevents such gamesmanship. As we can see in figure 12.10, it’s time to introduce our ultimate metric.

The generally accepted way of combining precision and recall is by using the F1 score (https://en.wikipedia.org/wiki/F1_score). As with other metrics, the F1 score ranges between 0 (a classifier with no real-world predictive power) and 1 (a classifier that has perfect predictions). We will update logMetrics to include this as well.

Listing 12.3 training.py:338, LunaTrainingApp.logMetrics

metrics_dict['pr/f1_score'] = \
  2 * (precision * recall) / (precision + recall)

At first glance, this might seem more complicated than we need, and it might not be immediately obvious how the F1 score behaves when trading off precision for recall or vice versa. This formula has a lot of nice properties, however, and it compares favorably to several other, simpler alternatives that we might consider.

Figure 12.10 The set of topics for this chapter, with a focus on the final F1 score metric

One immediate possibility for a scoring function is to average the values for precision and recall together. Unfortunately, this gives both avg(p=1.0, r=0.0) and avg(p=0.5, r=0.5) the same score of 0.5, and as we discussed earlier, a classifier with either precision or recall of zero is usually worthless. Giving something useless the same nonzero score as something useful disqualifies averaging as a meaningful metric immediately.

Still, let’s visually compare averaging and the F1 score in figure 12.11. A few things stand out. First, we can see a lack of a curve or elbow in the contour lines for averaging. That’s what lets our precision or recall skew to one side or the other! There will never be a situation where it doesn’t make sense to maximize the score by having 100% recall (the Roxie approach) and then eliminating whichever false positives are easy to eliminate. That puts a floor of 0.5 on the averaged score right out of the gate! Having a quality metric that is trivial to score at least 50% on doesn’t feel right.

Figure 12.11 Computing the final score with avg(p, r). Lighter values are closer to 1.0.

Note What we are actually doing here is taking the arithmetic mean (https://en.wikipedia.org/wiki/Arithmetic_mean) of the precision and recall, both of which are rates rather than countable scalar values. Taking the arithmetic mean of rates doesn’t typically give meaningful results. The F1 score is another name for the harmonic mean (https://en.wikipedia.org/wiki/Harmonic_mean) of the two rates, which is a more appropriate way of combining those kinds of values.
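(You can see the harmonic-mean connection by rewriting the formula from listing 12.3: 2pr / (p + r) = 2 / (1/p + 1/r), which is exactly the harmonic mean of p and r.)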

Contrast that with the F1 score: when recall is high but precision is low, trading off a lot of recall for even a little precision will move the score closer to that balanced sweet spot. There’s a nice, deep elbow that is easy to slide into. That encouragement to have balanced precision and recall is what we want from our grading metric.

Let’s say we still want a simpler metric, but one that doesn’t reward skew at all. In order to correct for the weakness of addition, we might take the minimum of precision and recall (figure 12.12).

Figure 12.12 Computing the final score with min(p, r)

This is nice, because if either value is 0, the score is also 0, and the only way to get a score of 1.0 is to have both values be 1.0. However, it still leaves something to be desired, since making a model change that increased the recall from 0.7 to 0.9 while leaving precision constant at 0.5 wouldn’t improve the score at all, nor would dropping recall down to 0.6! Although this metric is certainly penalizing having an imbalance between precision and recall, it isn’t capturing a lot of nuance about the two values. As we have seen, it’s easy to trade one off for the other simply by moving the classification threshold. We’d like our metric to reflect those trades.

We’ll have to accept at least a bit more complexity to better meet our goals. We could multiply the two values together, as in figure 12.13. This approach keeps the nice property that if either value is 0, the score is 0, and a score of 1.0 means both inputs are perfect. It also favors a balanced trade-off between precision and recall at low values, though when it gets closer to perfect results, it becomes more linear. That’s not great, since we really need to push both up to have a meaningful improvement at that point.

Figure 12.13 Computing the final score with mult(p, r)

Note Here we’re effectively taking the geometric mean (https://en.wikipedia.org/wiki/Geometric_mean) of the two rates (mult(p, r) is its square), which also doesn’t produce meaningful results.

There’s also the issue of having almost the entire quadrant from (0, 0) to (0.5, 0.5) be very close to zero. As we’ll see, having a metric that’s sensitive to changes in that region is important, especially in the early stages of our model design.

While using multiplication as our scoring function is feasible (it doesn’t have any immediate disqualifications the way the previous scoring functions did), we will be using the F1 score to evaluate our classification model’s performance going forward.
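If you’d like to see these trade-offs numerically rather than graphically, here is a quick standalone sketch (the precision/recall pairs are made up for illustration) that evaluates each candidate scoring function:

def f1(p, r):
    return 2 * p * r / (p + r) if (p + r) else 0.0

for p, r in [(1.0, 0.0), (0.5, 0.5), (0.3, 0.9), (0.7, 0.7)]:
    print(f"p={p:.1f} r={r:.1f}  avg={(p + r) / 2:.2f}  "
          f"min={min(p, r):.2f}  mult={p * r:.2f}  f1={f1(p, r):.2f}")

Note how averaging hands the useless (1.0, 0.0) classifier the same 0.5 score as the balanced (0.5, 0.5) one, while min, mult, and F1 all send it to zero.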

Updating the logging output to include precision, recall, and F1 score

Now that we have our new metrics, adding them to our logging output is pretty straightforward. We’ll include precision, recall, and F1 in our main logging statement for each of our training and validation sets.

Listing 12.4 training.py:341, LunaTrainingApp.logMetrics

log.info(
  ("E{} {:8} {loss/all:.4f} loss, "
     + "{correct/all:-5.1f}% correct, "
     + "{pr/precision:.4f} precision, "   
     + "{pr/recall:.4f} recall, "         
     + "{pr/f1_score:.4f} f1 score"       
  ).format(
    epoch_ndx,
    mode_str,
    **metrics_dict,
  )
)

Format string updated

In addition, we’ll include the exact counts of correctly identified samples and of total samples, for the negative and positive classes separately.

Listing 12.5 training.py:353, LunaTrainingApp.logMetrics

log.info(
  ("E{} {:8} {loss/neg:.4f} loss, "
     + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
  ).format(
    epoch_ndx,
    mode_str + '_neg',
    neg_correct=neg_correct,
    neg_count=neg_count,
    **metrics_dict,
  )
)

The new version of the positive logging statement looks much the same.

12.3.5 How does our model perform with our new metrics?

Now that we’ve implemented our shiny new metrics, let’s take them for a spin; we’ll discuss the results after we show the output of the Bash shell session. You might want to read ahead while your system does its number crunching; how long a run takes will depend on your system’s CPU, GPU, and disk speeds.2 Our system with an SSD and GTX 1080 Ti took about 20 minutes per full epoch:

$ ../.venv/bin/python -m p2ch12.training
Starting LunaTrainingApp...
...
E1 LunaTrainingApp
 
.../p2ch12/training.py:274: RuntimeWarning: invalid value encountered in double_scalars
  metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)                                          
 
E1 trn      0.0025 loss,  99.8% correct, 0.0000 precision, 0.0000 recall, nan f1 score
E1 trn_neg  0.0000 loss, 100.0% correct (494735 of 494743)
E1 trn_pos  1.0000 loss,   0.0% correct (0 of 1215)
 
.../p2ch12/training.py:269: RuntimeWarning: invalid value encountered in long_scalars
  precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
 
E1 val      0.0025 loss,  99.8% correct, nan precision, 0.0000 recall, nan f1 score
E1 val_neg  0.0000 loss, 100.0% correct (54971 of 54971)
E1 val_pos  1.0000 loss,   0.0% correct (0 of 136)

The exact count and line numbers of these RuntimeWarning lines might be different from run to run.

Bummer. We’ve got some warnings, and given that some of the values we computed were nan, there’s probably a division by zero happening somewhere. Let’s see what we can figure out.

First, since none of the positive samples in the training set are getting classified as positive, that means both precision and recall are zero, which results in our F1 score calculation dividing by zero. Second, for our validation set, truePos_count and falsePos_count are both zero due to nothing being flagged as positive. It follows that the denominator of our precision calculation is also zero; that makes sense, as that’s where we’re seeing another RuntimeWarning.
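You can reproduce the nan in isolation: with zero positive predictions and zero positive labels, the denominators collapse to zero. A two-line check (assuming numpy is installed) shows the same warning-and-nan behavior:

import numpy as np
print(np.float32(0) / np.float32(0))   # Emits a RuntimeWarning and prints nan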

A handful of negative training samples are classified as positive (494735 of 494743 are classified as negative, so that leaves 8 samples misclassified). While that might seem odd at first, recall that we are collecting our training results throughout the epoch, rather than using the model’s end-of-epoch state as we do for the validation results. That means the first batch is literally producing random results. A few of the samples from that first batch being flagged as positive isn’t surprising.

Note Due to both the random initialization of the network weights and the random ordering of the training samples, individual runs will likely exhibit slightly different behavior. Having exactly reproducible behavior can be desirable but is out of scope for what we’re trying to do in part 2 of this book.

Well, that was somewhat painful. Switching to our new metrics resulted in going from A+ to “Zero, if you’re lucky”--and if we’re not lucky, the score is so bad that it’s not even a number. Ouch.

That said, in the long run, this is good for us. We’ve known that our model’s performance was garbage since chapter 11. If our metrics told us anything but that, it would point to a fundamental flaw in the metrics!

12.4 What does an ideal dataset look like?

Before we start crying into our cups over the current sorry state of affairs, let’s instead think about what we actually want our model to do. Figure 12.14 says that first we need to balance our data so that our model can train properly. Let’s build up the logical steps needed to get us there.

Figure 12.14 The set of topics for this chapter, with a focus on balancing our positive and negative samples

Recall figure 12.5 earlier, and the following discussion of classification thresholds. Getting better results by moving the threshold has limited effectiveness--there’s just too much overlap between the positive and negative classes to work with.3

Instead, we want to see an image like figure 12.15. Here, our label threshold is nearly vertical. That’s what we want, because it means the label threshold and our classification threshold can line up reasonably well. Similarly, most of the samples are concentrated at either end of the diagram. Both of these things require that our data be easily separable and that our model have the capacity to perform that separation. Our model currently has enough capacity, so that’s not the issue. Instead, let’s take a look at our data.

Figure 12.15 A well-trained model can cleanly separate data, making it easy to pick a classification threshold with few trade-offs.

Figure 12.16 An imbalanced dataset that roughly approximates the imbalance in our LUNA classification data

Recall that our data is wildly imbalanced. There’s a 400:1 ratio of negative samples to positive ones. That’s crushingly imbalanced! Figure 12.16 shows what that looks like. No wonder our “actually nodule” samples are getting lost in the crowd!

Now, let’s be perfectly clear: when we’re done, our model will be able to handle this kind of data imbalance just fine. We could probably even train the model all the way there without changing the balancing, assuming we were willing to wait for a gajillion epochs first.4 But we’re busy people with things to do, so rather than cook our GPU until the heat death of the universe, let’s try to make our training data look more ideal by changing the class balance we are training with.

12.4.1 Making the data look less like the actual and more like the “ideal”

The best thing to do would be to have relatively more positive samples. During the initial epoch of training, when we’re going from randomized chaos to something more organized, having so few training samples be positive means they get drowned out.

The method by which this happens is somewhat subtle, however. Recall that since our network weights are initially randomized, the per-sample output of the network is also randomized (but clamped to the range [0-1]).

Note Our loss function is nn.CrossEntropyLoss, which technically operates on the raw logits rather than the class probabilities. For our discussion, we’ll ignore that distinction and assume the loss and the label-prediction deltas are the same thing.

The predictions numerically close to the correct label do not result in much change to the weights of the network, while predictions that are significantly different from the correct answer are responsible for a much greater change to the weights. Since the output is random when the model is initialized with random weights, we can assume that of our ~500k training samples (495,958, to be exact), we’ll have the following approximate groups:

  1. 250,000 negative samples will be predicted to be negative (0.0 to 0.5) and result in at most a small change to the network weights toward predicting negative.

  2. 250,000 negative samples will be predicted to be positive (0.5 to 1.0) and result in a large swing toward the network weights predicting negative.

  3. 500 positive samples will be predicted to be negative and result in a swing toward the network weights predicting positive.

  4. 500 positive samples will be predicted to be positive and result in almost no change to the network weights.

Note Keep in mind that the actual predictions are real numbers between 0.0 and 1.0 inclusive, so these groups won’t have strict delineations.

Here’s the kicker, though: groups 1 and 4 can be any size, and they will continue to have close to zero impact on training. The only thing that matters is that groups 2 and 3 can counteract each other’s pull enough to prevent the network from collapsing to a degenerate “only output one thing” state. Since group 2 is 500 times larger than group 3 and we’re using a batch size of 32, roughly 500/32 = 15 batches will go by before seeing a single positive sample. That implies that 14 out of 15 training batches will be 100% negative and will only pull all model weights toward predicting negative. That lopsided pull is what produces the degenerate behavior we’ve been seeing.

Instead, we’d like to have just as many positive samples as negative ones. For the first part of training, then, half of both labels will be classified incorrectly, meaning that groups 2 and 3 should be roughly equal in size. We also want to make sure we present batches with a mix of negative and positive samples. Balance would result in the tug-of-war evening out, and the mixture of classes per batch will give the model a decent chance of learning to discriminate between the two classes. Since our LUNA data has only a small, fixed number of positive samples, we’ll have to settle for taking the positive samples that we have and presenting them repeatedly during training.

Discrimination

Here, we define discrimination as “the ability to separate two classes from each other.” Building and training a model that can tell “actually nodule” candidates from normal anatomical structures is the entire point of what we’re doing in part 2.

Some other definitions of discrimination are more problematic. While out of scope for the discussion of our work here, there is a larger issue with models trained from real-world data. If that real-world dataset is collected from sources that have a real-world-discriminatory bias (for example, racial bias in arrest and conviction rates, or anything collected from social media), and that bias is not corrected for during dataset preparation or training, then the resulting model will continue to exhibit the same biases present in the training data. Just as in humans, racism is learned.

This means almost any model trained from internet-at-large data sources will be compromised in some fashion, unless extreme care is taken to scrub those biases from the model. Note that like our goal in part 2, this is considered an unsolved problem.

 

Recall our professor from chapter 11 who had a final exam with 99 false answers and 1 true answer. The next semester, after being told “You should have a more even balance of true and false answers,” the professor decided to add a midterm with 99 true answers and 1 false one. “Problem solved!”

Clearly, the correct approach is to intermix true and false answers in a way that doesn’t allow the students to exploit the larger structure of the tests to answer things correctly. Whereas a student would pick up on a pattern like “odd questions are true, even questions are false,” the batching system used by PyTorch doesn’t allow the model to “notice” or utilize that kind of pattern. Our training dataset will need to be updated to alternate between positive and negative samples, as in figure 12.17.

The unbalanced data is the proverbial needle in the haystack we mentioned at the start of chapter 9. If you had to perform this classification work by hand, you’d probably start to empathize with Preston.

Figure 12.17 Batch after batch of imbalanced data will have nothing but negative events long before the first positive event, while balanced data can alternate every other sample.

We will not be doing any balancing for validation, however. Our model needs to function well in the real world, and the real world is imbalanced (after all, that’s where we got the raw data!).

How should we accomplish this balancing? Let’s discuss our choices.

Samplers can reshape datasets

One of the optional arguments to DataLoader is sampler=... . This allows the data loader to override the iteration order native to the dataset passed in and instead shape, limit, or reemphasize the underlying data as desired. This can be incredibly useful when working with a dataset that isn’t under your control. Taking a public dataset and reshaping it to meet your needs is far less work than reimplementing that dataset from scratch.

The downside is that many of the mutations we could accomplish with samplers require that we break encapsulation of the underlying dataset. For example, let’s assume we have a dataset like CIFAR-10 (www.cs.toronto.edu/~kriz/cifar.html) that consists of 10 equally weighted classes, and we want to instead have 1 class (say, “airplane”) now make up 50% of all of the training images. We could decide to use WeightedRandomSampler (http://mng.bz/8plK) and weight each of the “airplane” sample indexes higher, but constructing the weights argument requires that we know in advance which indexes are airplanes.
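Here is a minimal sketch of what that would look like if we did have the labels in hand, using a toy tensor dataset as a stand-in for CIFAR-10 (illustrative only, not project code):

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

labels = torch.randint(0, 10, (1000,))   # Toy 10-class labels; class 0 is "airplane"
dataset = TensorDataset(torch.randn(1000, 3, 32, 32), labels)

weights = torch.ones(len(labels))        # Weight each airplane sample 9x, so the ~100
weights[labels == 0] = 9.0               # airplanes carry as much total weight as the rest
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)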

As we discussed, the Dataset API only specifies that subclasses provide __len__ and __getitem__, but there is nothing direct we can use to ask “Which samples are airplanes?” We’d either have to load up every sample beforehand to inquire about the class of that sample, or we’d have to break encapsulation and hope the information we need is easily obtained from looking at the internal implementation of the Dataset subclass.

Since neither of those options is particularly ideal in cases where we have control over the dataset directly, the code for part 2 implements any needed data shaping inside the Dataset subclasses instead of relying on an external sampler.

Implementing class balancing in the dataset

We are going to directly change our LunaDataset to present a balanced, one-to-one ratio of positive and negative samples for training. We will keep separate lists of negative training samples and positive training samples, and alternate returning samples from each of those two lists. This will prevent the degenerate behavior of the model scoring well by simply answering “false” to every sample presented. In addition, the positive and negative classes will be intermixed so that the weight updates are forced to discriminate between the classes.

Let’s add a ratio_int to LunaDataset that will control the label for the Nth sample as well as keep track of our samples separated by label.

Listing 12.6 dsets.py:217, class LunaDataset

class LunaDataset(Dataset):
  def __init__(self,
         val_stride=0,
         isValSet_bool=None,
         ratio_int=0,
      ):
    self.ratio_int = ratio_int
    # ... line 228
    self.negative_list = [
      nt for nt in self.candidateInfo_list if not nt.isNodule_bool
    ]
    self.pos_list = [
      nt for nt in self.candidateInfo_list if nt.isNodule_bool
    ]
    # ... line 265
 
  def shuffleSamples(self):               
    if self.ratio_int:
      random.shuffle(self.negative_list)
      random.shuffle(self.pos_list)

We will call this at the top of each epoch to randomize the order of samples being presented.

With this, we now have dedicated lists for each label. Using these lists, it becomes much easier to return the label we want for a given index into the dataset. In order to make sure we’re getting the indexing right, we should sketch out the ordering we want. Let’s assume a ratio_int of 2, meaning a 2:1 ratio of negative to positive samples. That would mean every third index should be positive:

DS Index   0 1 2 3 4 5 6 7 8 9 ...
Label      + - - + - - + - - +
Pos Index  0     1     2     3
Neg Index    0 1   2 3   4 5

The relationship between the dataset index and the positive index is simple: divide the dataset index by 3 and then round down. The negative index is slightly more complicated, in that we have to subtract 1 from the dataset index and then subtract the most recent positive index as well.
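Before we look at the real implementation, here is a throwaway desk-check of that arithmetic; it reproduces the table above for a ratio_int of 2:

ratio_int = 2
for ndx in range(10):
    pos_ndx = ndx // (ratio_int + 1)
    if ndx % (ratio_int + 1):            # A nonzero remainder means a negative sample
        print(f"ndx {ndx}: negative_list[{ndx - 1 - pos_ndx}]")
    else:
        print(f"ndx {ndx}: pos_list[{pos_ndx}]")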

Implemented in our LunaDataset class, that looks like the following.

Listing 12.7 dsets.py:286, LunaDataset.__getitem__

def __getitem__(self, ndx):
  if self.ratio_int:                                   
    pos_ndx = ndx // (self.ratio_int + 1)
 
    if ndx % (self.ratio_int + 1):                     
      neg_ndx = ndx - 1 - pos_ndx
      neg_ndx %= len(self.negative_list)               
      candidateInfo_tup = self.negative_list[neg_ndx]
    else:
      pos_ndx %= len(self.pos_list)                    
      candidateInfo_tup = self.pos_list[pos_ndx]
  else:
    candidateInfo_tup = self.candidateInfo_list[ndx]   

A ratio_int of zero means use the native balance.

A nonzero remainder means this should be a negative sample.

Overflow results in wraparound.

Returns the Nth sample if not balancing classes

That can get a little hairy, but if you desk-check it, it will make sense. Keep in mind that with a low ratio, we’ll run out of positive samples before exhausting the dataset. We take care of that by taking the modulus of pos_ndx before indexing into self.pos_list. While the same kind of index overflow should never happen with neg_ndx due to the large number of negative samples, we do the modulus anyway, just in case we later decide to make a change that might cause it to overflow.

We’ll also make a change to our dataset’s length. Although this isn’t strictly necessary, it’s nice to speed up individual epochs. We’re going to hardcode our __len__ to be 200,000.

Listing 12.8 dsets.py:280, LunaDataset.__len__

def __len__(self):
  if self.ratio_int:
    return 200000
  else:
    return len(self.candidateInfo_list)

We’re no longer tied to a specific number of samples, and presenting “a full epoch” doesn’t really make sense when we would have to repeat positive samples many, many times to present a balanced training set. By picking 200,000 samples, we reduce the time between starting a training run and seeing results (faster feedback is always nice!), and we give ourselves a nice, clean number of samples per epoch. Feel free to adjust the length of an epoch to meet your needs.

For completeness, we also add a command-line parameter.

Listing 12.9 training.py:31, class LunaTrainingApp

class LunaTrainingApp:
  def __init__(self, sys_argv=None):
    # ... line 52
    parser.add_argument('--balanced',
      help="Balance the training data to half positive, half negative.",
      action='store_true',
      default=False,
    )

Then we pass that parameter into the LunaDataset constructor.

Listing 12.10 training.py:137, LunaTrainingApp.initTrainDl

def initTrainDl(self):
  train_ds = LunaDataset(
    val_stride=10,
    isValSet_bool=False,
    ratio_int=int(self.cli_args.balanced),    
  )

Here we rely on Python’s True being convertible to a 1.

We’re all set. Let’s run it!

12.4.2 Contrasting training with a balanced LunaDataset to previous runs

As a reminder, our unbalanced training run had results like these:

$ python -m p2ch12.training
...
E1 LunaTrainingApp
E1 trn      0.0185 loss,  99.7% correct, 0.0000 precision, 0.0000 recall, nan f1 score
E1 trn_neg  0.0026 loss, 100.0% correct (494717 of 494743)
E1 trn_pos  6.5267 loss,   0.0% correct (0 of 1215)
...
E1 val      0.0173 loss,  99.8% correct, nan precision, 0.0000 recall, nan f1 score
E1 val_neg  0.0026 loss, 100.0% correct (54971 of 54971)
E1 val_pos  5.9577 loss,   0.0% correct (0 of 136)

But when we run with --balanced, we see the following:

$ python -m p2ch12.training --balanced
...
E1 LunaTrainingApp
E1 trn      0.1734 loss,  92.8% correct, 0.9363 precision, 0.9194 recall, 0.9277 f1 score
E1 trn_neg  0.1770 loss,  93.7% correct (93741 of 100000)
E1 trn_pos  0.1698 loss,  91.9% correct (91939 of 100000)
...
E1 val      0.0564 loss,  98.4% correct, 0.1102 precision, 0.7941 recall, 0.1935 f1 score
E1 val_neg  0.0542 loss,  98.4% correct (54099 of 54971)
E1 val_pos  0.9549 loss,  79.4% correct (108 of 136)

This seems much better! On the validation set, we’ve given up less than 2% correct answers on the negative samples to gain 79% correct positive answers. We’re back into a solid B range again!5

As in chapter 11, however, this result is deceptive. Since there are 400 times as many negative samples as positive ones, even getting just 1% wrong means we’d be incorrectly classifying negative samples as positive four times more often than there are actually positive samples in total!
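To put numbers on it: getting just 1% of our 54,971 negative validation samples wrong means about 550 false positives, compared with only 136 actual positive samples in the entire validation set.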

Still, this is clearly better than the outright wrong behavior from chapter 11 and much better than a random coin flip. In fact, we’ve even crossed over into being (almost) legitimately useful in real-world scenarios. Recall our overworked radiologist poring over each and every speck of a CT: well, now we’ve got something that can do a reasonable job of screening out 95% of the false positives. That’s a huge help, since it translates into about a tenfold increase in productivity for the machine-assisted human.

Of course, there’s still that pesky issue of the roughly 20% of positive samples that were missed, which we should probably deal with. Perhaps some additional epochs of training would help. Let’s see (and again, expect to spend at least 10 minutes per epoch):

$ python -m p2ch12.training --balanced --epochs 20
...
E2 LunaTrainingApp
E2 trn      0.0432 loss,  98.7% correct, 0.9866 precision, 0.9879 recall, 0.9873 f1 score
E2 trn_neg  0.0545 loss,  98.7% correct (98663 of 100000)
E2 trn_pos  0.0318 loss,  98.8% correct (98790 of 100000)
E2 val      0.0603 loss,  98.5% correct, 0.1271 precision, 0.8456 recall, 0.2209 f1 score
E2 val_neg  0.0584 loss,  98.6% correct (54181 of 54971)
E2 val_pos  0.8471 loss,  84.6% correct (115 of 136)
...
E5 trn      0.0578 loss,  98.3% correct, 0.9839 precision, 0.9823 recall, 0.9831 f1 score
E5 trn_neg  0.0665 loss,  98.4% correct (98388 of 100000)
E5 trn_pos  0.0490 loss,  98.2% correct (98227 of 100000)
E5 val      0.0361 loss,  99.2% correct, 0.2129 precision, 0.8235 recall, 0.3384 f1 score
E5 val_neg  0.0336 loss,  99.2% correct (54557 of 54971)
E5 val_pos  1.0515 loss,  82.4% correct (112 of 136)
...
E10 trn      0.0212 loss,  99.5% correct, 0.9942 precision, 0.9953 recall, 0.9948 f1 score
E10 trn_neg  0.0281 loss,  99.4% correct (99421 of 100000)
E10 trn_pos  0.0142 loss,  99.5% correct (99530 of 100000)
E10 val      0.0457 loss,  99.3% correct, 0.2171 precision, 0.7647 recall, 0.3382 f1 score
E10 val_neg  0.0407 loss,  99.3% correct (54596 of 54971)
E10 val_pos  2.0594 loss,  76.5% correct (104 of 136)
...
E20 trn      0.0132 loss,  99.7% correct, 0.9964 precision, 0.9974 recall, 0.9969 f1 score
E20 trn_neg  0.0186 loss,  99.6% correct (99642 of 100000)
E20 trn_pos  0.0079 loss,  99.7% correct (99736 of 100000)
E20 val      0.0200 loss,  99.7% correct, 0.4780 precision, 0.7206 recall, 0.5748 f1 score
E20 val_neg  0.0133 loss,  99.8% correct (54864 of 54971)
E20 val_pos  2.7101 loss,  72.1% correct (98 of 136)

Ugh. That’s a lot of text to scroll past to get to the numbers we’re interested in. Let’s power through and focus on the val_pos XX.X% correct numbers (or skip ahead to the TensorBoard graph in the next section). After epoch 2, we were at 84.6%; that turned out to be the peak. By epoch 5 we had slipped to 82.4%, and by epoch 20 we were down to 72.1%--well below our second epoch!

Note As mentioned earlier, expect each run to have unique behavior due to random initialization of network weights and random selection and ordering of training samples per epoch.

The training set numbers don’t seem to be having the same problem. Negative training samples are classified correctly 99.6% of the time, and positive samples are 99.7% correct. What’s going on?

12.4.3 Recognizing the symptoms of overfitting

What we are seeing are clear signs of overfitting. Let’s take a look at the graph of our loss on positive samples, in figure 12.18.

Figure 12.18 Our positive loss showing clear signs of overfitting, as the training loss and validation loss are trending in different directions

Here, we can see that the training loss for our positive samples is nearly zero--each positive training sample gets a nearly perfect prediction. Our validation loss for positive samples is increasing, though, and that means our real-world performance is likely getting worse. At this point, it’s often best to stop the training script, since the model is no longer improving.

Tip Generally, if your model’s performance is improving on your training set while getting worse on your validation set, the model has started overfitting.

We must take care to examine the right metrics, however, since this trend is only happening on our positive loss. If we take a look at our overall loss, everything seems fine! That’s because our validation set is not balanced, so the overall loss is dominated by our negative samples. As shown in figure 12.19, we are not seeing the same divergent behavior for our negative samples. Instead, our negative loss looks great! That’s because we have 400 times more negative samples, so it’s much, much harder for the model to remember individual details. Our positive training set has only 1,215 samples, though. While we repeat those samples multiple times, that doesn’t make them harder to memorize. The model is shifting from generalized principles to essentially memorizing quirks of those 1,215 samples and claiming that anything that’s not one of those few samples is negative. This includes both negative training samples and everything in our validation set (both positive and negative).

Figure 12.19 Our negative loss showing no signs of overfitting

Clearly, some generalization is still going on, since we are classifying about 70% of the positive validation set correctly. We just need to change how we’re training the model so that our training set and validation set both trend in the right direction.

12.5 Revisiting the problem of overfitting

We touched on the concept of overfitting in chapter 5, and now it’s time to take a closer look at how to address this common situation. Our goal with training a model is to teach it to recognize the general properties of the classes we are interested in, as expressed in our dataset. Those general properties are present in some or all samples of the class and can be generalized and used to predict samples that haven’t been trained on. When the model starts to learn specific properties of the training set, overfitting occurs, and the model starts to lose the ability to generalize. In case that’s a bit too abstract, let’s use another analogy.

12.5.1 An overfit face-to-age prediction model

Let’s pretend we have a model that takes an image of a human face as input and outputs a predicted age in years. A good model would pick up on age signifiers like wrinkles, gray hair, hairstyle, clothing choices, and similar, and use those to build a general model of what different ages look like. When presented with a new picture, it would consider things like “conservative haircut” and “reading glasses” and “wrinkles” to conclude “around 65 years old.”

An overfit model, by contrast, instead remembers specific people by remembering identifying details. “That haircut and those glasses mean it’s Frank. He’s 62.8 years old”; “Oh, that scar means it’s Harry. He’s 39.3”; and so on. When shown a new person, the model won’t recognize the person and will have absolutely no idea what age to predict.

Even worse, if shown a picture of Frank Jr. (the spittin’ image of his dad, at least when he’s wearing his glasses!), the model will say, “I think that’s Frank. He’s 62.8 years old.” Never mind that Junior is 25 years younger!

Overfitting is usually due to having too few training samples when compared to the ability of the model to just memorize the answers. The median human can memorize the birthdays of their immediate family but would have to resort to generalizations when predicting the ages of any group larger than a small village.

Our face-to-age model has the capacity to simply memorize the photos of anyone who doesn’t look exactly their age. As we discussed in part 1, model capacity is a somewhat abstract concept, but is roughly a function of the number of parameters of the model times how efficiently those parameters are used. When a model has a high capacity relative to the amount of data needed to memorize the hard samples from the training set, it’s likely that the model will begin to overfit on those more difficult training samples.

12.6 Preventing overfitting with data augmentation

It’s time to take our model training from good to great. We need to cover one last step in figure 12.20.

Figure 12.20 The set of topics for this chapter, with a focus on data augmentation

We augment a dataset by applying synthetic alterations to individual samples, resulting in a new dataset with an effective size that is larger than the original. The typical goal is for the alterations to result in a synthetic sample that remains representative of the same general class as the source sample, but that cannot be trivially memorized alongside the original. When done properly, this augmentation can increase the training set size beyond what the model is capable of memorizing, resulting in the model being forced to increasingly rely on generalization, which is exactly what we want. Doing so is especially useful when dealing with limited data, as we saw in section 12.4.1.

Of course, not all augmentations are equally useful. Going back to our example of a face-to-age prediction model, we could trivially change the red channel of the four corner pixels of each image to a random value 0-255, which would result in a dataset 4 billion times larger than the original. Of course, this wouldn’t be particularly useful, since the model can pretty trivially learn to ignore the red dots in the image corners, and the rest of the image remains as easy to memorize as the single, unaugmented original image. Contrast that approach with flipping the image left to right. Doing so would only result in a dataset twice as large as the original, but each image would be quite a bit more useful for training purposes. The general properties of aging are not correlated left to right, so a mirrored image remains representative. Similarly, it’s rare for facial pictures to be perfectly symmetrical, so a mirrored version is unlikely to be trivially memorized alongside the original.

12.6.1 Specific data augmentation techniques

We are going to implement five specific types of data augmentation. Our implementation will allow us to experiment with any or all of them, individually or in aggregate. The five techniques are as follows:

  • Mirroring the image up-down, left-right, and/or front-back

  • Shifting the image around by a few voxels

  • Scaling the image up or down

  • Rotating the image around the head-foot axis

  • Adding noise to the image

For each technique, we want to make sure our approach maintains the training sample’s representative nature, while being different enough that the sample is useful to train with.

We’ll define a function getCtAugmentedCandidate that is responsible for taking our standard chunk-of-CT-with-candidate-inside and modifying it. Our main approach will define an affine transformation matrix (http://mng.bz/Edxq) and use it with the PyTorch affine_grid (https://pytorch.org/docs/stable/nn.html#affine-grid) and grid_sample (https://pytorch.org/docs/stable/nn.html#torch.nn.functional.grid_sample) functions to resample our candidate.

Listing 12.11 dsets.py:149, def getCtAugmentedCandidate

def getCtAugmentedCandidate(
    augmentation_dict,
    series_uid, center_xyz, width_irc,
    use_cache=True):
  if use_cache:
    ct_chunk, center_irc = \
      getCtRawCandidate(series_uid, center_xyz, width_irc)
  else:
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
 
  ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)

We first obtain ct_chunk, either from the cache or directly by loading the CT (something that will come in handy once we are creating our own candidate centers), and then convert it to a tensor. Next is the affine grid and sampling code.

Listing 12.12 dsets.py:162, def getCtAugmentedCandidate

transform_t = torch.eye(4)
# ...                        
# ... line 195
affine_t = F.affine_grid(
    transform_t[:3].unsqueeze(0).to(torch.float32),
    ct_t.size(),
    align_corners=False,
  )
 
augmented_chunk = F.grid_sample(
    ct_t,
    affine_t,
    padding_mode='border',
    align_corners=False,
  ).to('cpu')
# ... line 214
return augmented_chunk[0], center_irc

Modifications to transform_t will go here.

Without anything additional, this function won’t do much. Let’s see what it takes to add in some actual transforms.

Note It’s important to structure your data pipeline such that your caching steps happen before augmentation! Doing otherwise will result in your data being augmented once and then persisted in that state, which defeats the purpose.

Mirroring

When mirroring a sample, we keep the pixel values exactly the same and only change the orientation of the image. Since there’s no strong correlation between tumor growth and left-right or front-back, we should be able to flip those without changing the representative nature of the sample. The index-axis (referred to as Z in patient coordinates) corresponds to the direction of gravity in an upright human, however, so there’s a possibility of a difference in the top and bottom of a tumor. We are going to assume it’s fine, since quick visual investigation doesn’t show any gross bias. Were we working toward a clinically relevant project, we’d need to confirm that assumption with an expert.

Listing 12.13 dsets.py:165, def getCtAugmentedCandidate

for i in range(3):
  if 'flip' in augmentation_dict:
    if random.random() > 0.5:
      transform_t[i,i] *= -1

The grid_sample function maps the range [-1, 1] to the extents of both the old and new tensors (the rescaling happens implicitly if the sizes are different). This range mapping means that to mirror the data, all we need to do is multiply the relevant element of the transformation matrix by -1.
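
We can see that range mapping in action with a tiny standalone experiment (again, not project code): a one-voxel-thick “CT” whose values encode their own position, mirrored by negating the diagonal element for the X (column) axis of the sampling grid.

import torch
import torch.nn.functional as F

t = torch.arange(4.0).reshape(1, 1, 1, 1, 4)    # N x C x D x H x W; values mark position

transform_t = torch.eye(4)
transform_t[0, 0] *= -1                         # negate the X axis of the sampling grid

grid = F.affine_grid(transform_t[:3].unsqueeze(0), t.size(), align_corners=False)
print(F.grid_sample(t, grid, align_corners=False).reshape(-1))   # tensor([3., 2., 1., 0.])

The output comes back reversed, exactly as if we had flipped the underlying data.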

Shifting by a random offset

Shifting the nodule candidate around shouldn’t make a huge difference, since convolutions are translation invariant, though this will make our model more robust to imperfectly centered nodules. What will make a more significant difference is that the offset might not be an integer number of voxels; instead, the data will be resampled using trilinear interpolation, which can introduce some slight blurring. Voxels at the edge of the sample will be repeated (due to padding_mode='border'), which shows up as a smeared, streaky section along the border.

Listing 12.14 dsets.py:165, def getCtAugmentedCandidate

for i in range(3):
  # ... line 170
  if 'offset' in augmentation_dict:
    offset_float = augmentation_dict['offset']
    random_float = (random.random() * 2 - 1)
    transform_t[i,3] = offset_float * random_float

Note that our 'offset' parameter is the maximum offset, expressed in the same [-1, 1] scale that grid_sample expects.
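
To connect that normalized scale back to voxels: [-1, 1] spans the whole chunk along each axis, so one grid unit corresponds to half the axis length. A hypothetical helper (not in our project code) makes the conversion explicit; assuming our usual 32 × 48 × 48 candidate chunks, an 'offset' of 0.1 allows shifts of up to 1.6 voxels along the index axis and 2.4 voxels along the rows and columns.

def grid_offset_to_voxels(offset_float, axis_size_voxels):
    # The [-1, 1] grid range covers axis_size_voxels voxels,
    # so one grid unit is axis_size_voxels / 2 voxels.
    return offset_float * axis_size_voxels / 2

print(grid_offset_to_voxels(0.1, 32))   # 1.6 voxels along the index axis
print(grid_offset_to_voxels(0.1, 48))   # 2.4 voxels along rows and columns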

Scaling

Scaling the image slightly is very similar to mirroring and shifting. Doing so can also result in the same repeated edge voxels we just mentioned when discussing shifting the sample.

Listing 12.15 dsets.py:165, def getCtAugmentedCandidate

for i in range(3):
  # ... line 175
  if 'scale' in augmentation_dict:
    scale_float = augmentation_dict['scale']
    random_float = (random.random() * 2 - 1)
    transform_t[i,i] *= 1.0 + scale_float * random_float

Since random_float is symmetric about zero in the range [-1, 1], it doesn’t matter whether we add scale_float * random_float to 1.0 or subtract it.

Rotating

Rotation is the first augmentation technique we’re going to use where we have to carefully consider our data to ensure that we don’t break our sample with a conversion that causes it to no longer be representative. Recall that our CT slices have uniform spacing along the rows and columns (X- and Y-axes), but in the index (or Z) direction, the voxels are non-cubic. That means we can’t treat those axes as interchangeable.

One option is to resample our data so that our resolution along the index-axis is the same as along the other two, but that’s not a true solution because the data along that axis would be very blurry and smeared. Even if we interpolate more voxels, the fidelity of the data would remain poor. Instead, we’ll treat that axis as special and confine our rotations to the X-Y plane.

Listing 12.16 dsets.py:181, def getCtAugmentedCandidate

if 'rotate' in augmentation_dict:
  angle_rad = random.random() * math.pi * 2
  s = math.sin(angle_rad)
  c = math.cos(angle_rad)
 
  rotation_t = torch.tensor([
    [c, -s, 0, 0],
    [s, c, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
  ])
 
  transform_t @= rotation_t

Noise

Our final augmentation technique is different from the others in that it is actively destructive to our sample in a way that flipping or rotating the sample is not. If we add too much noise to the sample, it will swamp the real data and make it effectively impossible to classify. While shifting and scaling the sample would do something similar if we used extreme input values, we’ve chosen values that will only impact the edge of the sample. Noise will have an impact on the entire image.

Listing 12.17 dsets.py:208, def getCtAugmentedCandidate

if 'noise' in augmentation_dict:
  noise_t = torch.randn_like(augmented_chunk)
  noise_t *= augmentation_dict['noise']
 
  augmented_chunk += noise_t

The other augmentation types have increased the effective size of our dataset. Noise makes our model’s job harder. We’ll revisit this once we see some training results.

Examining augmented candidates

We can see the result of our efforts in figure 12.21. The upper-left image shows an un-augmented positive candidate, and the next five show the effect of each augmentation type in isolation. Finally, the bottom row shows the combined result three times.

Figure 12.21 Various augmentation types performed on a positive nodule sample

Since each __getitem__ call to the augmenting dataset reapplies the augmentations randomly, each image on the bottom row looks different. This also means it’s nearly impossible to generate an image exactly like this again! It’s also important to remember that sometimes the 'flip' augmentation will result in no flip. Returning always-flipped images is just as limiting as not flipping in the first place. Now let’s see if any of this makes a difference.
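
For reference, this chapter’s listings don’t show the hookup inside LunaDataset.__getitem__. A minimal sketch of the dispatch, assuming the dataset stores the dict as self.augmentation_dict and the candidate tuple is named candidateInfo_tup, might look like this:

if self.augmentation_dict:
    candidate_t, center_irc = getCtAugmentedCandidate(
        self.augmentation_dict,
        candidateInfo_tup.series_uid,
        candidateInfo_tup.center_xyz,
        width_irc,
        self.use_cache,
    )
else:
    # Fall back to the unaugmented (possibly cached) chunk, as in chapter 11.
    ...

With that dispatch in place, turning augmentation on and off is just a matter of what we put into augmentation_dict.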

12.6.2 Seeing the improvement from data augmentation

We are going to train additional models, one per augmentation type discussed in the last section, plus a final training run that combines all of the augmentation types. Once they’re finished, we’ll take a look at our numbers in TensorBoard.

In order to be able to turn our new augmentation types on and off, we need to expose the construction of augmentation_dict to our command-line interface. Arguments to our program will be added by parser.add_argument calls (not shown, but similar to the ones our program already has; we sketch plausible versions just after the listing), which will then be fed into the code that actually constructs augmentation_dict.

Listing 12.18 training.py:105, LunaTrainingApp.__init__

self.augmentation_dict = {}
if self.cli_args.augmented or self.cli_args.augment_flip:
  self.augmentation_dict['flip'] = True
if self.cli_args.augmented or self.cli_args.augment_offset:
  self.augmentation_dict['offset'] = 0.1                     
if self.cli_args.augmented or self.cli_args.augment_scale:
  self.augmentation_dict['scale'] = 0.2                      
if self.cli_args.augmented or self.cli_args.augment_rotate:
  self.augmentation_dict['rotate'] = True
if self.cli_args.augmented or self.cli_args.augment_noise:
  self.augmentation_dict['noise'] = 25.0                     

These values were empirically chosen to have a reasonable impact, but better values probably exist.
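
Here is a plausible sketch of the flag definitions feeding listing 12.18 (the flag names are chosen to match the cli_args attributes above; argparse converts the dashes to underscores):

parser.add_argument('--augmented',
    help="Augment the training data.",
    action='store_true', default=False)
parser.add_argument('--augment-flip',
    help="Augment the training data by randomly flipping along each axis.",
    action='store_true', default=False)
parser.add_argument('--augment-offset',
    help="Augment the training data by randomly shifting along each axis.",
    action='store_true', default=False)
# --augment-scale, --augment-rotate, and --augment-noise follow the same pattern.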

Now that we have those command-line arguments ready, you can either run the following commands or revisit p2_run_everything.ipynb and run cells 8 through 16. Either way you run it, expect these to take a significant time to finish:

$ .venv/bin/python -m p2ch12.prepcache

$ .venv/bin/python -m p2ch12.training --epochs 20 \
        --balanced sanity-bal

$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-flip   sanity-bal-flip

$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-offset sanity-bal-offset

$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-scale  sanity-bal-scale

$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-rotate sanity-bal-rotate

$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-noise  sanity-bal-noise

$ .venv/bin/python -m p2ch12.training --epochs 20 \
        --balanced --augmented sanity-bal-aug

You only need to prep the cache once per chapter, and if you already have the sanity-bal run from earlier in the chapter, there’s no need to rerun it.

While that’s running, we can start TensorBoard. Let’s direct it to only show these runs by changing the logdir parameter like so: ../path/to/tensorboard --logdir runs/p2ch12.

Depending on the hardware you have at your disposal, the training might take a long time. Feel free to skip the flip, offset, and scale training jobs and reduce the first and last runs to 11 epochs if you need to move things along more quickly. We chose 20 epochs because that helps those runs stand out from the others, but 11 should work as well.

If you let everything run to completion, your TensorBoard should have data like that shown in figure 12.22. We’re going to deselect everything except the validation data, to reduce clutter. When you’re looking at your data live, you can also change the smoothing value, which can help clarify the trend lines. Take a quick look at the figure, and then we’ll go over it in some detail.

Figure 12.22 Percent correctly classified, loss, F1 score, precision, and recall for the validation set from networks trained with a variety of augmentation schemes

The first thing to notice in the upper-left graph (“tag: correct/all”) is that the individual augmentation types are something of a jumble. Our unaugmented and fully augmented runs are on opposite sides of that jumble. That means that when combined, our augmentation is more than the sum of its parts. Also of interest is that our fully augmented run gets many more answers wrong. While that’s bad generally, if we look at the right column of graphs (which focus on the positive candidate samples we actually care about--the ones that are really nodules), we see that our fully augmented model is much better at finding the positive candidate samples. The recall for the fully augmented model is great! It’s also much better at not overfitting. As we saw earlier, our unaugmented model gets worse over time.

One interesting thing to note is that the noise-augmented model is worse at identifying nodules than the unaugmented model. This makes sense if we remember that we said noise makes the model’s job harder.

Another interesting thing to see in the live data (it’s somewhat lost in the jumble here) is that the rotation-augmented model is nearly as good as the fully augmented model when it comes to recall, and it has much better precision. Since our F1 score is precision limited (due to the higher number of negative samples), the rotation-augmented model also has a better F1 score.

We’ll stick with the fully augmented model going forward, since our use case requires high recall. The F1 score will still be used to determine which epoch to save as the best. In a real-world project, we might want to devote extra time to investigating whether a different combination of augmentation types and parameter values could yield better results.

12.7 Conclusion

We spent a lot of time and energy in this chapter reformulating how we think about our model’s performance. It’s easy to be misled by poor methods of evaluation, and it’s crucial to have a strong intuitive understanding of the factors that feed into evaluating a model well. Once those fundamentals are internalized, it’s much easier to spot when we’re being led astray.

We’ve also learned how to deal with data sources that aren’t sufficiently populated. Being able to synthesize representative training samples is incredibly useful. Situations where we have too much training data are rare indeed!

Now that we have a classifier that is performing reasonably, we’ll turn our attention to automatically finding candidate nodules to classify. Chapter 13 will start there; then, in chapter 14, we will feed those candidates back into the classifier we developed here and venture into building one more classifier to tell malignant nodules from benign ones.

12.8 Exercises

  1. The F1 score can be generalized to support values other than 1.

    1. Read https://en.wikipedia.org/wiki/F1_score, and implement F2 and F0.5 scores.

    2. Determine which of F1, F2, and F0.5 makes the most sense for this project. Track that value, and compare and contrast it with the F1 score.6

  2. Implement a WeightedRandomSampler approach to balancing the positive and negative training samples for LunaDataset with ratio_int set to 0.

    1. How did you get the required information about the class of each sample?

    2. Which approach was easier? Which resulted in more readable code?

  3. Experiment with different class-balancing schemes.

    1. What ratio results in the best score after two epochs? After 20?

    2. What if the ratio is a function of epoch_ndx?

  4. Experiment with different data augmentation approaches.

    1. Can any of the existing approaches be made more aggressive (noise, offset, and so on)?

    2. Does the inclusion of noise augmentation help or hinder your training results?

      • Are there other values that change this result?

    3. Research data augmentation that other projects have used. Are any applicable here?

      • Implement “mixup” augmentation for positive nodule candidates. Does it help?

  5. Change the initial normalization from nn.BatchNorm to something custom, and retrain the model.

    1. Can you get better results using fixed normalization?

    2. What normalization offset and scale make sense?

    3. Do nonlinear normalizations like square roots help?

  6. What other kinds of data can TensorBoard display besides those we’ve covered here?

    1. Can you have it display information about the weights of your network?

    2. What about intermediate results from running your model on a particular sample?

      • Does having the backbone of the model wrapped in an instance of nn.Sequential help or hinder this effort?

12.9 Summary

  • A binary label and a binary classification threshold combine to partition the dataset into four quadrants: true positives, true negatives, false negatives, and false positives. These four quantities provide the basis for our improved performance metrics.

  • Recall is the ability of a model to maximize true positives. Selecting every single item guarantees perfect recall--because all the correct answers are included--but also exhibits poor precision.

  • Precision is the ability of a model to minimize false positives. Selecting nothing guarantees perfect precision--because no incorrect answers are included--but also exhibits poor recall.

  • The F1 score combines precision and recall into a single metric that describes model performance. We use the F1 score to determine what impact changes to training or the model have on our performance.

  • Balancing the training set to have an equal number of positive and negative samples during training can result in the model performing better (defined as having a positive, increasing F1 score).

  • Data augmentation takes existing organic data samples and modifies them such that the resulting augmented sample is non-trivially different from the original, but remains representative of samples of the same class. This allows additional training without overfitting in situations where data is limited.

  • Common data augmentation strategies include changes in orientation, mirroring, rescaling, shifting by an offset, and adding noise. Depending on the project, other more specific strategies may also be relevant.


1. No one actually says this.

2. If it’s taking longer than that, make sure you’ve run the prepcache script.

3. Keep in mind that these images are just a representation of the classification space and do not represent ground truth.

4. It’s not clear if this is actually true, but it’s plausible, and the loss was getting better . . .

5. And remember that this is after only the 200,000 training samples presented, not the 500,000+ of the unbalanced dataset, so we got there in less than half the time.

6. Yep, that’s a hint it’s not the F1 score!
