13 Using segmentation to find suspected nodules

This chapter covers

  • Segmenting data with a pixel-to-pixel model
  • Performing segmentation with U-Net
  • Understanding mask prediction using Dice loss
  • Evaluating a segmentation model’s performance

In the last four chapters, we have accomplished a lot. We’ve learned about CT scans and lung tumors, datasets and data loaders, and metrics and monitoring. We have also applied many of the things we learned in part 1, and we have a working classifier. We are still operating in a somewhat artificial environment, however, since we require hand-annotated nodule candidate information to load into our classifier. We don’t have a good way to create that input automatically. Just feeding the entire CT into our model--that is, plugging in overlapping 32 × 32 × 32 patches of data--would result in 31 × 31 × 7 = 6,727 patches per CT, or about 10 times the number of annotated samples we have. We’d need to overlap the edges; our classifier expects the nodule candidate to be centered, and even then the inconsistent positioning would probably present issues.

As we explained in chapter 9, our project uses multiple steps to solve the problem of locating possible nodules, identifying them, with an indication of their possible malignancy. This is a common approach among practitioners, while in deep learning research there is a tendency to demonstrate the ability of individual models to solve complex problems in an end-to-end fashion. The multistage project design we use in this book gives us a good excuse to introduce new concepts step by step.

13.1 Adding a second model to our project

In the previous two chapters, we worked on step 4 of our plan shown in figure 13.1: classification. In this chapter, we’ll go back not just one but two steps. We need to find a way to tell our classifier where to look. To do this, we are going to take raw CT scans and find everything that might be a nodule.1 This is the highlighted step 2 in the figure. To find these possible nodules, we have to flag voxels that look like they might be part of a nodule, a process known as segmentation. Then, in chapter 14, we will deal with step 3 and provide the bridge by transforming the segmentation masks from this image into location annotations.

Figure 13.1 Our end-to-end lung cancer detection project, with a focus on this chapter’s topic: step 2, segmentation

By the time we’re finished with this chapter, we’ll have created a new model with an architecture that can perform per-pixel labeling, or segmentation. The code that will accomplish this will be very similar to the code from the last chapter, especially if we focus on the larger structure. All of the changes we’re going to make will be smaller and targeted. As we see in figure 13.2, we need to make updates to our model (step 2A in the figure), dataset (2B), and training loop (2C) to account for the new model’s inputs, outputs, and other requirements. (Don’t worry if you don’t recognize each component in each of these steps in step 2 on the right side of the diagram. We’ll go through the details when we get to each step.) Finally, we’ll examine the results we get when running our new model (step 3 in the figure).

Figure 13.2 The new model architecture for segmentation, along with the model, dataset, and training loop updates we will implement

Breaking down figure 13.2 into steps, our plan for this chapter is as follows:

  1. Segmentation. First we will learn how segmentation works with a U-Net model, including what the new model components are and what happens to them as we go through the segmentation process. This is step 1 in figure 13.2.

  2. Update. To implement segmentation, we need to change our existing code base in three main places, shown in the substeps on the right side of figure 13.2. The code will be structurally very similar to what we developed for classification, but will differ in detail:

    1. Update the model (step 2A). We will integrate a preexisting U-Net into our segmentation model. Our model in chapter 12 output a simple true/false classification; our model in this chapter will instead output an entire image.

    2. Change the dataset (step 2B). We need to change our dataset to not only deliver bits of the CT but also provide masks for the nodules. The classification dataset consisted of 3D crops around nodule candidates, but we’ll need to collect both full CT slices and 2D crops for segmentation training and validation.

    3. Adapt the training loop (step 2C). We need to adapt the training loop so we bring in a new loss to optimize. We’ll also do things like displaying images of our segmentation results in TensorBoard and saving our model weights to disk.

  3. Results. Finally, we’ll see the fruits of our efforts when we look at the quantitative segmentation results.

13.2 Various types of segmentation

To get started, we need to talk about different flavors of segmentation. For this project, we will be using semantic segmentation, which is the act of classifying individual pixels in an image using labels just like those we’ve seen for our classification tasks, for example, “bear,” “cat,” “dog,” and so on. If done properly, this will result in distinct chunks or regions that signify things like “all of these pixels are part of a cat.” This takes the form of a label mask or heatmap that identifies areas of interest. We will have a simple binary label: true values will correspond to nodule candidates, and false values mean uninteresting healthy tissue. This partially meets our need to find nodule candidates that we will later feed into our classification network.

Before we get into the details, we should briefly discuss other approaches we could take to finding our nodule candidates. For example, instance segmentation labels individual objects of interest with distinct labels. So whereas semantic segmentation would label a picture of two people shaking hands with two labels (“person” and “background”), instance segmentation would have three labels (“person1,” “person2,” and “background”) with a boundary somewhere around the clasped hands. While this could be useful for us to distinguish “nodule1” from “nodule2,” we will instead use grouping to identify individual nodules. That approach will work well for us since nodules are unlikely to touch or overlap.

Another approach to these kinds of tasks is object detection, which locates an item of interest in an image and puts a bounding box around the item. While both instance segmentation and object detection could be great for our uses, their implementations are somewhat complex, and we don’t feel they are the best things for you to learn next. Also, training object-detection models typically requires much more computational resources than our approach requires. If you’re feeling up to the challenge, the YOLOv3 paper is a more entertaining read than most deep learning research papers.2 For us, though, semantic segmentation it is.

Note As we go through the code examples in this chapter, we’re going to rely on you checking the code from GitHub for much of the larger context. We’ll be omitting code that’s uninteresting or similar to what’s come before in earlier chapters, so that we can focus on the crux of the issue at hand.

13.3 Semantic segmentation: Per-pixel classification

Often, segmentation is used to answer questions of the form “Where is a cat in this picture?” Obviously, most pictures of a cat, like figure 13.3, have a lot of non-cat in them; there’s the table or wall in the background, the keyboard the cat is sitting on, that kind of thing. Being able to say “This pixel is part of the cat, and this other pixel is part of the wall” requires fundamentally different model output and a different internal structure from the classification models we’ve worked with thus far. Classification can tell us whether a cat is present, while segmentation will tell us where we can find it.

Figure 13.3 Classification results in one or more binary flags, while segmentation produces a mask or heatmap.

If your project requires differentiating between a near cat and a far cat, or a cat on the left versus a cat on the right, then segmentation is probably the right approach. The image-consuming classification models that we’ve implemented so far can be thought of as funnels or magnifying glasses that take a large bunch of pixels and focus them down into a single “point” (or, more accurately, a single set of class predictions), as shown in figure 13.4. Classification models provide answers of the form “Yes, this huge pile of pixels has a cat in it, somewhere,” or “No, no cats here.” This is great when you don’t care where the cat is, just that there is (or isn’t) one in the image.

Figure 13.4 The magnifying glass model structure for classification

Repeated layers of convolution and downsampling mean the model starts by consuming raw pixels to produce specific, detailed detectors for things like texture and color, and then builds up higher-level conceptual feature detectors for parts like eyes and ears and mouth and nose3 that finally result in “cat” versus “dog.” Due to the increasing receptive field of the convolutions after each downsampling layer, those higher-level detectors can use information from an increasingly large area of the input image.

Unfortunately, since segmentation needs to produce an image-like output, ending up at a single classification-like list of binary-ish flags won’t work. As we recall from section 11.4, downsampling is key to increasing the receptive fields of the convolutional layers, and is what helps reduce the array of pixels that make up an image to a single list of classes. Notice figure 13.5, which repeats figure 11.6.

Figure 13.5 The convolutional architecture of a LunaModel block, consisting of two 3 × 3 convolutions followed by a max pool. The final pixel has a 6 × 6 receptive field.

In the figure, our inputs flow from the left to right in the top row and are continued in the bottom row. In order to work out the receptive field--the area influencing the single pixel at bottom right--we can go backward. The max-pool operation has 2 × 2 inputs producing each final output pixel. The 3 × 3 conv in the middle of the bottom row looks at one adjacent pixel (including diagonally) in each direction, so the total receptive field of the convolutions that result in the 2 × 2 output is 4 × 4. The 3 × 3 convolution in the top row then adds an additional pixel of context in each direction, so the receptive field of the single output pixel at bottom right is a 6 × 6 field in the input at top left. With the downsampling from the max pool, the receptive field of the next block of convolutions will have double the width, and each additional downsampling will double it again, while shrinking the size of the output.

We’ll need a different model architecture if we want our output to be the same size as our input. One simple model to use for segmentation would have repeated convolutional layers without any downsampling. Given appropriate padding, that would result in output the same size as the input (good), but a very limited receptive field (bad) due to the limited reach based on how much overlap multiple layers of small convolutions will have. The classification model uses each downsampling layer to double the effective reach of the following convolutions; and without that increase in effective field size, each segmented pixel will only be able to consider a very local neighborhood.
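
To make that concrete, here is a minimal sketch (not part of the project's code) of such a no-downsampling model: stacked 3 × 3 convolutions with padding keep the output the same height and width as the input, but each output pixel sees only a small neighborhood.

import torch.nn as nn

simple_seg_model = nn.Sequential(
  nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
  nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
  nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
  nn.Conv2d(16, 1, kernel_size=3, padding=1),    # per-pixel score; same height and width as the input
)
# Four 3 x 3 conv layers: each output pixel has a receptive field of only 9 x 9 input pixels.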

Note Assuming 3 × 3 convolutions, the receptive field size for a simple model of stacked convolutions is 2 * L + 1, with L being the number of convolutional layers.

Four layers of 3 × 3 convolutions will have a receptive field of 9 × 9 per output pixel. By inserting a 2 × 2 max pool between the second and third convolutions, and another at the end, we increase the receptive field to ...

Note See if you can figure out the math yourself; when you’re done, check back here.

... 16 × 16. The final series of conv-conv-pool has a receptive field of 6 × 6, but that happens after the first max pool, which makes the final effective receptive field 12 × 12 in the original input resolution. The first two conv layers add a total border of 2 pixels around the 12 × 12, for a total of 16 × 16.
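
If you’d like to verify that arithmetic empirically, here is a small sketch (not part of the project's code) that backpropagates from a single output pixel and checks which input pixels receive a nonzero gradient. We substitute average pooling for max pooling only because max pooling routes gradients through a single element per window, which would hide part of the receptive field; the receptive-field geometry is the same.

import torch
import torch.nn as nn

net = nn.Sequential(                            # conv-conv-pool, conv-conv-pool, as described above
  nn.Conv2d(1, 8, 3, padding=1), nn.Conv2d(8, 8, 3, padding=1), nn.AvgPool2d(2),
  nn.Conv2d(8, 8, 3, padding=1), nn.Conv2d(8, 8, 3, padding=1), nn.AvgPool2d(2),
)
x = torch.zeros(1, 1, 64, 64, requires_grad=True)
net(x)[0, :, 8, 8].sum().backward()             # backprop from one output location
touched = x.grad[0, 0] != 0                     # which input pixels influenced that output?
print(touched.any(dim=1).sum().item(), touched.any(dim=0).sum().item())   # prints: 16 16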

So the question remains: how can we improve the receptive field of an output pixel while maintaining a 1:1 ratio of input pixels to output pixels? One common answer is to use a technique called upsampling, which takes an image of a given resolution and produces an image of a higher resolution. Upsampling at its simplest just means replacing each pixel with an N × N block of pixels, each with the same value as the original input pixel. The possibilities only get more complex from there, with options like linear interpolation and learned deconvolution.
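
In PyTorch terms, a few of those options look like the following sketch (illustrative only): nn.Upsample in 'nearest' mode repeats each pixel into an N × N block, 'bilinear' mode performs linear interpolation, and nn.ConvTranspose2d is a learned deconvolution.

import torch
import torch.nn as nn

x = torch.arange(4.0).reshape(1, 1, 2, 2)                       # a tiny 2 x 2 "image"
print(nn.Upsample(scale_factor=2, mode='nearest')(x))           # each pixel becomes a 2 x 2 block
print(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)(x))   # linear interpolation
print(nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)(x).shape)   # learned; torch.Size([1, 1, 4, 4])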

13.3.1 The U-Net architecture

Before we end up diving down a rabbit hole of possible upsampling algorithms, let’s get back to our goal for the chapter. Per figure 13.6, step 1 is to get familiar with a foundational segmentation algorithm called U-Net.

Figure 13.6 The new model architecture for segmentation, that we will be working with

The U-Net architecture is a design for a neural network that can produce pixel-wise output and that was invented for segmentation. As you can see from the highlight in figure 13.6, a diagram of the U-Net architecture looks a bit like the letter U, which explains the origins of the name. We also immediately see that it is quite a bit more complicated than the mostly sequential structure of the classifiers we are familiar with. We’ll see a more detailed version of the U-Net architecture shortly, in figure 13.7, and learn exactly what each of those components is doing. Once we understand the model architecture, we can work on training one to solve our segmentation task.

Figure 13.7 From the U-Net paper, with annotations. Source: The base of this figure is courtesy Olaf Ronneberger et al., from the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation,” which can be found at https://arxiv.org/abs/1505.04597 and https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.

The U-Net architecture shown in figure 13.7 was an early breakthrough for image segmentation. Let’s take a look and then walk through the architecture.

In this diagram, the boxes represent intermediate results and the arrows represent operations between them. The U-shape of the architecture comes from the multiple resolutions at which the network operates. In the top row is the full resolution (512 × 512 for us), the row below has half that, and so on. The data flows from top left to bottom center through a series of convolutions and downscaling, as we saw in the classifiers and looked at in detail in chapter 8. Then we go up again, using upscaling convolutions to get back to the full resolution. Unlike the original U-Net, we will be padding things so we don’t lose pixels off the edges, so our resolution is the same on the left and on the right.

Earlier network designs already had this U-shape, which people attempted to use to address the limited receptive field size of fully convolutional networks. To address this limited field size, they used a design that copied, inverted, and appended the focusing portions of an image-classification network to create a symmetrical model that goes from fine detail to wide receptive field and back to fine detail.

Those earlier network designs had problems converging, however, most likely due to the loss of spatial information during downsampling. Once information reaches a large number of very downscaled images, the exact location of object boundaries gets harder to encode and therefore reconstruct. To address this, the U-Net authors added the skip connections we see at the center of the figure. We first touched on skip connections in chapter 8, although they are employed differently here than in the ResNet architecture. In U-Net, skip connections short-circuit inputs along the downsampling path into the corresponding layers in the upsampling path. These layers receive as input both the upsampled results of the wide receptive field layers from lower in the U as well as the output of the earlier fine detail layers via the “copy and crop” bridge connections. This is the key innovation behind U-Net (which, interestingly, predated ResNet).

All of this means those final detail layers are operating with the best of both worlds. They’ve got both information about the larger context surrounding the immediate area and fine detail data from the first set of full-resolution layers.
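
The core of one "up" step can be sketched in a few lines (illustrative, not the exact code of the implementation we’ll use): the coarse features from lower in the U are upsampled, then concatenated channel-wise with the saved fine-detail features from the down path before the next pair of convolutions. With padded convolutions, as we’ll be using, the spatial sizes match and no cropping is needed before the concatenation.

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
x_coarse = torch.randn(1, 128, 64, 64)          # wide-receptive-field features from lower in the U
x_skip = torch.randn(1, 64, 128, 128)           # fine-detail features saved from the down path
x = torch.cat([up(x_coarse), x_skip], dim=1)    # (1, 128, 128, 128), fed to the next two 3 x 3 convs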

The “conv 1x1” layer at far right, in the head of the network, changes the number of channels from 64 to 2 (the original paper had 2 output channels; we have 1 in our case). This is somewhat akin to the fully connected layer we used in our classification network, but per-pixel, channel-wise: it’s a way to convert from the number of filters used in the last upsampling step to the number of output classes needed.
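
In code, that final step is just a convolution with a 1 × 1 kernel (a sketch):

import torch
import torch.nn as nn

head = nn.Conv2d(64, 1, kernel_size=1)      # 64 filters in, 1 output channel in our case
x = torch.randn(1, 64, 512, 512)
print(head(x).shape)                        # torch.Size([1, 1, 512, 512]): one value per pixel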

13.4 Updating the model for segmentation

It’s time to move through step 2A in figure 13.8. We’ve had enough theory about segmentation and history about U-Net; now we want to update our code, starting with the model. Instead of just outputting a binary classification that gives us a single output of true or false, we integrate a U-Net to get to a model that’s capable of outputting a probability for every pixel: that is, performing segmentation. Rather than implementing a custom U-Net segmentation model from scratch, we’re going to appropriate an existing implementation from an open source repository on GitHub.

The U-Net implementation at https://github.com/jvanvugt/pytorch-unet seems to meet our needs well.4 It’s MIT licensed (copyright 2018 Joris), it’s contained in a single file, and it has a number of parameter options for us to tweak. The file is included in our code repository at util/unet.py, along with a link to the original repository and the full text of the license used.

Note While it’s less of an issue for personal projects, it’s important to be aware of the license terms attached to open source software you use for a project. The MIT license is one of the most permissive open source licenses, and it still places requirements on users of MIT licensed code! Also be aware that authors retain copyright even if they publish their work in a public forum (yes, even on GitHub), and if they do not include a license, that does not mean the work is in the public domain. Quite the opposite! It means you don’t have any license to use the code, any more than you’d have the right to wholesale copy a book you borrowed from the library.

We suggest taking some time to inspect the code and, based on the knowledge you have built up until this point, identify the building blocks of the architecture as they are reflected in the code. Can you spot skip connections? A particularly worthy exercise for you is to draw a diagram that shows how the model is laid out, just by looking at the code.

Now that we have found a U-Net implementation that fits the bill, we need to adapt it so that it works well for our needs. In general, it’s a good idea to keep an eye out for situations where we can use something off the shelf. It’s important to have a sense of what models exist, how they’re implemented and trained, and whether any parts can be scavenged and applied to the project we’re working on at any given moment. While that broader knowledge is something that comes with time and experience, it’s a good idea to start building that toolbox now.

13.4.1 Adapting an off-the-shelf model to our project

We will now make some changes to the classic U-Net, justifying them along the way. A useful exercise for you will be to compare results between the vanilla model and the one after the tweaks, preferably removing one at a time to see the effect of each change (this is also called an ablation study in research circles).

Figure 13.8 The outline of this chapter, with a focus on the changes needed for our segmentation model

First, we’re going to pass the input through batch normalization. This way, we won’t have to normalize the data ourselves in the dataset; and, more importantly, we will get normalization statistics (read mean and standard deviation) estimated over individual batches. This means when a batch is dull for some reason--that is, when there is nothing to see in all the CT crops fed into the network--it will be scaled more strongly. The fact that samples in batches are picked randomly at every epoch will minimize the chances of a dull sample ending up in an all-dull batch, and hence those dull samples getting overemphasized.

Second, since the output values are unconstrained, we are going to pass the output through an nn.Sigmoid layer to restrict the output to the range [0, 1]. Third, we will reduce the total depth and number of filters we allow our model to use. While this is jumping ahead of ourselves a bit, the capacity of the model using the standard parameters far outstrips our dataset size. This means we’re unlikely to find a pretrained model that matches our exact needs. Finally, although this is not a modification, it’s important to note that our output is a single channel, with each pixel of output representing the model’s estimate of the probability that the pixel in question is part of a nodule.

This wrapping of U-Net can be done rather simply by implementing a model with three attributes: one each for the two features we want to add, and one for the U-Net itself--which we can treat just like any prebuilt module here. We will also pass any keyword arguments we receive into the U-Net constructor.

Listing 13.1 model.py:17, class UNetWrapper

class UNetWrapper(nn.Module):
  def __init__(self, **kwargs):                                    
    super().__init__()
 
    self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])   
    self.unet = UNet(**kwargs)                                     
    self.final = nn.Sigmoid()
 
    self._init_weights()                                           

kwargs is a dictionary containing all keyword arguments passed to the constructor.

BatchNorm2d wants us to specify the number of input channels, which we take from the keyword argument.

The U-Net: a small thing to include here, but it’s really doing all the work.

Just as for the classifier in chapter 11, we use our custom weight initialization. The function is copied over, so we will not show the code again.

The forward method is a similarly straightforward sequence. We could use an instance of nn.Sequential as we saw in chapter 8, but we’ll be explicit here for both clarity of code and clarity of stack traces.5

Listing 13.2 model.py:50, UNetWrapper.forward

def forward(self, input_batch):
  bn_output = self.input_batchnorm(input_batch)
  un_output = self.unet(bn_output)
  fn_output = self.final(un_output)
  return fn_output

Note that we’re using nn.BatchNorm2d here. This is because U-Net is fundamentally a two-dimensional segmentation model. We could adapt the implementation to use 3D convolutions, in order to use information across slices. The memory usage of a straightforward 3D implementation would be considerably greater, so much so that we would have to chop up the CT scan to process it. Also, the fact that pixel spacing in the Z direction is much larger than in-plane makes a nodule less likely to be present across many slices. These considerations make a fully 3D approach less attractive for our purposes. Instead, we’ll adapt our 3D data to be segmented a slice at a time, providing adjacent slices for context (for example, detecting that a bright lump is indeed a blood vessel gets much easier alongside neighboring slices). Since we’re sticking with presenting the data in 2D, we’ll use channels to represent the adjacent slices. Our treatment of the third dimension is similar to how we applied a fully connected model to images in chapter 7: the model will have to relearn the adjacency relationships we’re throwing away along the axial direction, but that’s not difficult for the model to accomplish, especially with the limited number of slices given for context owing to the small size of the target structures.

13.5 Updating the dataset for segmentation

Our source data for this chapter remains unchanged: we’re consuming CT scans and annotation data about them. But our model expects input and will produce output of a different form than we had previously. As we hint at in step 2B of figure 13.9, our previous dataset produced 3D data, but we need to produce 2D data now.

Figure 13.9 The outline of this chapter, with a focus on the changes needed for our segmentation dataset

The original U-Net implementation did not use padded convolutions, which means while the output segmentation map was smaller than the input, every pixel of that output had a fully populated receptive field. None of the input pixels that fed into the determination of that output pixel were padded, fabricated, or otherwise incomplete. Thus the output of the original U-Net will tile perfectly, so it can be used with images of any size (except at the edges of the input image, where some context will be missing by definition).

There are two problems with us taking the same pixel-perfect approach for our problem. The first is related to the interaction between convolution and downsampling, and the second is related to the nature of our data being three-dimensional.

13.5.1 U-Net has very specific input size requirements

The first issue is that the sizes of the input and output patches for U-Net are very specific. In order to have the two-pixel loss per convolution line up evenly before and after downsampling (especially when considering the further convolutional shrinkage at that lower resolution), only certain input sizes will work. The U-Net paper used 572 × 572 image patches, which resulted in 388 × 388 output maps. The input images are bigger than our 512 × 512 CT slices, and the output is quite a bit smaller! That would mean any nodules near the edge of the CT scan slice wouldn’t be segmented at all. Although this setup works well when dealing with very large images, it’s not ideal for our use case.

We will address this issue by setting the padding flag of the U-Net constructor to True. This will mean we can use input images of any size, and we will get output of the same size. We may lose some fidelity near the edges of the image, since the receptive field of pixels located there will include regions that have been artificially padded, but that’s a compromise we decide to live with.
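
As an illustration, constructing the wrapped model with padding enabled might look like the following sketch. The keyword arguments are the ones exposed by the jvanvugt UNet constructor and passed through by our wrapper; the particular depth and wf values shown here are illustrative rather than prescriptive, and the import assumes the chapter’s package layout.

import torch
from p2ch13.model import UNetWrapper        # assumes the chapter's package layout

seg_model = UNetWrapper(
  in_channels=7,                            # for example, 3 context slices on each side of the center slice
  n_classes=1,                              # one output channel: "is this pixel part of a nodule?"
  depth=3,                                  # shallower than the default, per section 13.4.1
  wf=4,                                     # 2**wf filters in the first layer
  padding=True,                             # keep the output the same size as the input
  batch_norm=True,
  up_mode='upconv',
)
with torch.no_grad():
  print(seg_model(torch.randn(2, 7, 512, 512)).shape)   # torch.Size([2, 1, 512, 512])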

13.5.2 U-Net trade-offs for 3D vs. 2D data

The second issue is that our 3D data doesn’t line up exactly with U-Net’s 2D expected input. Simply taking our 512 × 512 × 128 image and feeding it into a converted-to-3D U-Net class won’t work, because we’ll exhaust our GPU memory. Each image is 2⁹ by 2⁹ by 2⁷, with 2² bytes per voxel. The first layer of U-Net is 64 channels, or 2⁶. That’s an exponent of 9 + 9 + 7 + 2 + 6 = 33, or 8 GB just for the first convolutional layer. There are two convolutional layers (16 GB); and then each downsampling halves the resolution but doubles the channels, which is another 2 GB for each layer after the first downsample (remember, halving the resolution results in one-eighth the data, since we’re working with 3D data). So we’ve hit 20 GB before we even get to the second downsample, much less anything on the upsample side of the model or anything dealing with autograd.

Note There are a number of clever and innovative ways to get around these problems, and we in no way suggest that this is the only approach that will ever work.6 We do feel that this approach is one of the simplest that gets the job done to the level we need for our project in this book. We’d rather keep things simple so that we can focus on the fundamental concepts; the clever stuff can come later, once you’ve mastered the basics.

As anticipated, instead of trying to do things in 3D, we’re going to treat each slice as a 2D segmentation problem and cheat our way around the issue of context in the third dimension by providing neighboring slices as separate channels. Instead of the traditional “red,” “green,” and “blue” channels that we’re familiar with from photographic images, our main channels will be “two slices above,” “one slice above,” “the slice we’re actually segmenting,” “one slice below,” and so on.

This approach isn’t without trade-offs, however. We lose the direct spatial relationship between slices when represented as channels, as all channels will be linearly combined by the convolution kernels with no notion of them being one or two slices away, above or below. We also lose the wider receptive field in the depth dimension that would come from a true 3D segmentation with downsampling. Since CT slices are often thicker than the resolution in rows and columns, we do get a somewhat wider view than it seems at first, and this should be enough, considering that nodules typically span a limited number of slices.

Another aspect to consider, one that is relevant for both the current and fully 3D approaches, is that we are now ignoring the exact slice thickness. This is something our model will eventually have to learn to be robust against, by being presented with data with different slice spacings.

In general, there isn’t an easy flowchart or rule of thumb that can give canned answers to questions about which trade-offs to make, or whether a given set of compromises compromises too much. Careful experimentation is key, however, and systematically testing hypothesis after hypothesis can help narrow down which changes and approaches are working well for the problem at hand. Although it’s tempting to make a flurry of changes while waiting for the last set of results to compute, resist that impulse.

That’s important enough to repeat: do not test multiple modifications at the same time. There is far too high a chance that one of the changes will interact poorly with the other, and you’ll be left without solid evidence that either one is worth investigating further. With that said, let’s start building out our segmentation dataset.

13.5.3 Building the ground truth data

The first thing we need to address is that we have a mismatch between our human-labeled training data and the actual output we want to get from our model. We have annotated points, but we want a per-voxel mask that indicates whether any given voxel is part of a nodule. We’ll have to build that mask ourselves from the data we have and then do some manual checking to make sure the routine that builds the mask is performing well.

Validating these manually constructed heuristics at scale can be difficult. We aren’t going to attempt to do anything comprehensive when it comes to making sure each and every nodule is properly handled by our heuristics. If we had more resources, approaches like “collaborate with (or pay) someone to create and/or verify everything by hand” might be an option, but since this isn’t a well-funded endeavor, we’ll rely on checking a handful of samples and using a very simple “does the output look reasonable?” approach.

To that end, we’ll design our approaches and our APIs to make it easy to investigate the intermediate steps that our algorithms are going through. While this might result in slightly clunky function calls returning huge tuples of intermediate values, being able to easily grab results and plot them in a notebook makes the clunk worth it.

Bounding boxes

We are going to begin by converting the nodule locations that we have into bounding boxes that cover the entire nodule (note that we’ll only do this for actual nodules). If we assume that the nodule locations are roughly centered in the mass, we can trace outward from that point in all three dimensions until we hit low-density voxels, indicating that we’ve reached normal lung tissue (which is mostly filled with air). Let’s follow this algorithm in figure 13.10.

Figure 13.10 An algorithm for finding a bounding box around a lung nodule

We start the origin of our search (O in the figure) at the voxel at the annotated center of our nodule. We then examine the density of the voxels adjacent to our origin on the column axis, marked with a question mark (?). Since both of the examined voxels contain dense tissue, shown here in lighter colors, we continue our search. After incrementing our column search distance to 2, we find that the left voxel has a density below our threshold, and so we stop our search at 2.

Next, we perform the same search in the row direction. Again, we start at the origin, and this time we search up and down. After our search distance becomes 3, we encounter a low-density voxel in both the upper and lower search locations. We only need one to stop our search!

We’ll skip showing the search in the third dimension. Our final bounding box is five voxels wide and seven voxels tall. Here’s what that looks like in code, for the index direction.

Listing 13.3 dsets.py:131, Ct.buildAnnotationMask

center_irc = xyz2irc(
  candidateInfo_tup.center_xyz,                                   
  self.origin_xyz,
  self.vxSize_xyz,
  self.direction_a,
)
ci = int(center_irc.index)                                        
cr = int(center_irc.row)
cc = int(center_irc.col)

index_radius = 2
try:
  while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
     self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
     index_radius += 1
except IndexError:                                                
  index_radius -= 1

candidateInfo_tup here is the same as we’ve seen previously: as returned by getCandidateInfoList.

Gets the center voxel indices, our starting point

The search described previously

The safety net for indexing beyond the size of the tensor

We first grab the center data and then do the search in a while loop. As a slight complication, our search might fall off the boundary of our tensor. We are not terribly concerned about that case and are lazy, so we just catch the index exception.7

Note that we stop incrementing the very approximate radius values after the density drops below threshold, so our bounding box should contain a one-voxel border of low-density tissue (at least on one side; since nodules can be adjacent to regions like the lung wall, we have to stop searching in both directions when we hit air on either side). Since we check both center_index + index_radius and center_index - index_radius against that threshold, that one-voxel boundary will only exist on the edge closest to our nodule location. This is why we need those locations to be relatively centered. Since some nodules are adjacent to the boundary between the lung and denser tissue like muscle or bone, we can’t trace each direction independently, as some edges would end up incredibly far away from the actual nodule.

We then repeat the same radius-expansion process with row_radius and col_radius (this code is omitted for brevity). Once that’s done, we can set a box in our bounding-box mask array to True (we’ll see the definition of boundingBox_a in just a moment; it’s not surprising).

OK, let’s wrap all this up in a function. We loop over all nodules. For each nodule, we perform the search shown earlier (which we elide from listing 13.4). Then, in a Boolean tensor boundingBox_a, we mark the bounding box we found.

After the loop, we do a bit of cleanup by taking the intersection between the bounding-box mask and the tissue that’s denser than our threshold of -700 HU (or 0.3 g/cc). That’s going to clip off the corners of our boxes (at least, the ones not embedded in the lung wall), and make it conform to the contours of the nodule a bit better.

Listing 13.4 dsets.py:127, Ct.buildAnnotationMask

def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
  boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)

  for candidateInfo_tup in positiveInfo_list:                            
    # ... line 169
    boundingBox_a[
       ci - index_radius: ci + index_radius + 1,
       cr - row_radius: cr + row_radius + 1,
       cc - col_radius: cc + col_radius + 1] = True                      

  mask_a = boundingBox_a & (self.hu_a > threshold_hu)                    

  return mask_a

Starts with an all-False tensor of the same size as the CT

Loops over the nodules. As a reminder that we are only looking at nodules, we call the variable positiveInfo_list.

After we get the nodule radius (the search itself is left out), we mark the bounding box.

Restricts the mask to voxels above our density threshold

Let’s take a look at figure 13.11 to see what these masks look like in practice. Additional images in full color can be found in the p2ch13_explore_data.ipynb notebook.

Figure 13.11 Three nodules from ct.positive_mask, highlighted in white

The bottom-right nodule mask demonstrates a limitation of our rectangular bounding-box approach by including a portion of the lung wall. It’s certainly something we could fix, but since we’re not yet convinced that’s the best use of our time and attention, we’ll let it remain as is for now.8 Next, we’ll go about adding this mask to our CT class.
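
If you’d like to reproduce this kind of check yourself, a minimal sketch (not the notebook’s exact code, and assuming the chapter’s package layout) looks like this, using the Ct helpers from this chapter’s dsets.py:

import matplotlib.pyplot as plt
from p2ch13.dsets import getCandidateInfoList, getCt    # assumes the chapter's package layout

nodule_list = [c for c in getCandidateInfoList() if c.isNodule_bool]
ct = getCt(nodule_list[0].series_uid)                    # any series with a nodule will do
slice_ndx = int(ct.positive_mask.sum(axis=(1, 2)).argmax())   # the slice with the most mask voxels
plt.imshow(ct.hu_a[slice_ndx], cmap='gray', vmin=-1000, vmax=1000)
plt.imshow(ct.positive_mask[slice_ndx], alpha=0.3)       # overlay the nodule mask
plt.show()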

Calling mask creation during CT initialization

Now that we can take a list of nodule information tuples and turn them into a CT-shaped binary “Is this a nodule?” mask, let’s embed those masks into our CT object. First, we’ll filter our candidates into a list containing only nodules, and then we’ll use that list to build the annotation mask. Finally, we’ll collect the set of unique array indexes that have at least one voxel of the nodule mask. We’ll use this to shape the data we use for validation.

Listing 13.5 dsets.py:99, Ct.__init__

def __init__(self, series_uid):
  # ... line 116
  candidateInfo_list = getCandidateInfoDict()[self.series_uid]
 
  self.positiveInfo_list = [
    candidate_tup
    for candidate_tup in candidateInfo_list
    if candidate_tup.isNodule_bool                                       
  ]
  self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
  self.positive_indexes = (self.positive_mask.sum(axis=(1,2))            
                .nonzero()[0].tolist())                                  

Filters for nodules

Gives us a 1D vector (over the slices) with the number of voxels flagged in the mask in each slice

Takes indices of the mask slices that have a nonzero count, which we make into a list

Keen eyes might have noticed the getCandidateInfoDict function. The definition isn’t surprising; it’s just a reformulation of the same information as in the getCandidateInfoList function, but pregrouped by series_uid.

Listing 13.6 dsets.py:87

@functools.lru_cache(1)                                        
def getCandidateInfoDict(requireOnDisk_bool=True):
  candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
  candidateInfo_dict = {}
 
  for candidateInfo_tup in candidateInfo_list:
    candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
                    []).append(candidateInfo_tup)              
 
  return candidateInfo_dict

This can be useful to keep Ct init from being a performance bottleneck.

Takes the list of candidates for the series UID from the dict, defaulting to a fresh, empty list if we cannot find it. Then appends the present candidateInfo_tup to it.

Caching chunks of the mask in addition to the CT

In earlier chapters, we cached chunks of CT centered around nodule candidates, since we didn’t want to have to read and parse all of a CT’s data every time we wanted a small chunk of the CT. We’ll want to do the same thing with our new positive_mask, so we need to also return it from our Ct.getRawCandidate function. This works out to an additional line of code and an edit to the return statement.

Listing 13.7 dsets.py:178, Ct.getRawCandidate

def getRawCandidate(self, center_xyz, width_irc):
  center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
             self.direction_a)
 
  slice_list = []
  # ... line 203
  ct_chunk = self.hu_a[tuple(slice_list)]
  pos_chunk = self.positive_mask[tuple(slice_list)]   
 
  return ct_chunk, pos_chunk, center_irc              

Newly added

New value returned here

This will, in turn, be cached to disk by the getCtRawCandidate function, which opens the CT, gets the specified raw candidate including the nodule mask, and clips the CT values before returning the CT chunk, mask, and center information.

Listing 13.8 dsets.py:212

@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
  ct = getCt(series_uid)
  ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
                             width_irc)
  ct_chunk.clip(-1000, 1000, ct_chunk)
  return ct_chunk, pos_chunk, center_irc

The prepcache script precomputes and saves all these values for us, helping keep training quick.

Cleaning up our annotation data

Another thing we’re going to take care of in this chapter is doing some better screening on our annotation data. It turns out that several of the candidates listed in candidates.csv are present multiple times. To make it even more interesting, those entries are not exact duplicates of one another. Instead, it seems that the original human annotations weren’t sufficiently cleaned before being entered in the file. They might be annotations on the same nodule on different slices, which might even have been beneficial for our classifier.

We’ll do a bit of a hand wave here and provide a cleaned-up annotations file. In order to fully walk through the provenance of this cleaned file, you’ll need to know that the LUNA dataset is derived from another dataset called the Lung Image Database Consortium image collection (LIDC-IDRI)9 and includes detailed annotation information from multiple radiologists. We’ve already done the legwork to get the original LIDC annotations, pull out the nodules, dedupe them, and save them to the file /data/part2/luna/annotations_with_malignancy.csv.

With that file, we can update our getCandidateInfoList function to pull our nodules from our new annotations file. First, we loop over the new annotations for the actual nodules. Using the CSV reader,10 we need to convert the data to the appropriate types before we stick them into our CandidateInfoTuple data structure.

Listing 13.9 dsets.py:43, def getCandidateInfoList

candidateInfo_list = []
with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
  for row in list(csv.reader(f))[1:]:                                   
    series_uid = row[0]
    annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
    annotationDiameter_mm = float(row[4])
    isMal_bool = {'False': False, 'True': True}[row[5]]
    candidateInfo_list.append(                                          
      CandidateInfoTuple(
        True,                                                           
        True,                                                           
        isMal_bool,
        annotationDiameter_mm,
        series_uid,
        annotationCenter_xyz,
      )
    )

For each line in the annotations file that represents one nodule, ...

... we add a record to our list.

isNodule_bool

hasAnnotation_bool

Similarly, we loop over candidates from candidates.csv as before, but this time we only use the non-nodules. As these are not nodules, the nodule-specific information will just be filled with False and 0.

Listing 13.10 dsets.py:62, def getCandidateInfoList

with open('data/part2/luna/candidates.csv', "r") as f:
  for row in list(csv.reader(f))[1:]:                  
    series_uid = row[0]
    # ... line 72
    if not isNodule_bool:                              
      candidateInfo_list.append(                       
        CandidateInfoTuple(
          False,                                       
          False,                                       
          False,                                       
          0.0,
          series_uid,
          candidateCenter_xyz,
        )
      )

For each line in the candidates file ...

... but only the non-nodules (we have the others from earlier) ...

... we add a candidate record.

isNodule_bool

hasAnnotation_bool

isMal_bool

Other than the addition of the hasAnnotation_bool and isMal_bool flags (which we won’t use in this chapter), the new annotations will slot in and be usable just like the old ones.

Note You might be wondering why we haven’t discussed the LIDC before now. As it turns out, the LIDC has a large amount of tooling that’s already been constructed around the underlying dataset, which is specific to the LIDC. You could even get ready-made masks from PyLIDC. That tooling presents a somewhat unrealistic picture of what sort of support a given dataset might have, since the LIDC is anomalously well supported. What we’ve done with the LUNA data is much more typical and provides for better learning, since we’re spending our time manipulating the raw data rather than learning an API that someone else cooked up.

13.5.4 Implementing Luna2dSegmentationDataset

Compared to previous chapters, we are going to take a different approach to the training and validation split in this chapter. We will have two classes: one acting as a general base class suitable for validation data, and one subclassing the base for the training set, with randomization and a cropped sample.

While this approach is somewhat more complicated in some ways (the classes aren’t perfectly encapsulated, for example), it actually simplifies the logic of selecting randomized training samples and the like. It also becomes extremely clear which code paths impact both training and validation, and which are isolated to training only. Without this, we found that some of the logic can become nested or intertwined in ways that make it hard to follow. This is important because our training data will look significantly different from our validation data!

Note Other class arrangements are also viable; we considered having two entirely separate Dataset subclasses, for example. Standard software engineering design principles apply, so try to keep your structure relatively simple, and try to not copy and paste code, but don’t invent complicated frameworks to prevent having to duplicate three lines of code.
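
Condensed to its skeleton, the arrangement we just described looks like this (a preview; the full methods appear in the listings later in this section):

from torch.utils.data import Dataset

class Luna2dSegmentationDataset(Dataset):
  # Base class: builds the series and sample lists and returns full CT slices,
  # which is what we want for validation.
  def __getitem__(self, ndx):
    series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
    return self.getitem_fullSlice(series_uid, slice_ndx)

class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
  # Training subclass: overrides __getitem__ to return randomized crops around
  # positive candidates instead of full slices.
  def __getitem__(self, ndx):
    candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
    return self.getitem_trainingCrop(candidateInfo_tup)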

The data that we produce will be two-dimensional CT slices with multiple channels. The extra channels will hold adjacent slices of CT. Recall figure 4.2, shown here as figure 13.12; we can see that each slice of CT scan can be thought of as a 2D grayscale image.

Figure 13.12 Each slice of a CT scan represents a different position in space.

How we combine those slices is up to us. For the input to our classification model, we treated those slices as a 3D array of data and used 3D convolutions to process each sample. For our segmentation model, we are going to instead treat each slice as a single channel, and produce a multichannel 2D image. Doing so will mean that we are treating each slice of CT scan as if it was a color channel of an RGB image, like we saw in figure 4.1, repeated here as figure 13.13. Each input slice of the CT will get stacked together and consumed just like any other 2D image. The channels of our stacked CT image won’t correspond to colors, but nothing about 2D convolutions requires the input channels to be colors, so it works out fine.

Figure 13.13 Each channel of a photographic image represents a different color.

For validation, we’ll need to produce one sample per slice of CT that has an entry in the positive mask, for each validation CT we have. Since different CT scans can have different slice counts,11 we’re going to introduce a new function that caches the size of each CT scan and its positive mask to disk. We need this to be able to quickly construct the full size of a validation set without having to load each CT at Dataset initialization. We’ll continue to use the same caching decorator as before. Populating this data will also take place during the prepcache.py script, which we must run once before we start any model training.

Listing 13.11 dsets.py:220

@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
  ct = Ct(series_uid)
  return int(ct.hu_a.shape[0]), ct.positive_indexes

The majority of the Luna2dSegmentationDataset.__init__ method is similar to what we’ve seen before. We have a new contextSlices_count parameter, as well as an augmentation_dict similar to what we introduced in chapter 12.

The handling for the flag indicating whether this is meant to be a training or validation set needs to change somewhat. Since we’re no longer training on individual nodules, we will have to partition the list of series, taken as a whole, into training and validation sets. This means an entire CT scan, along with all nodule candidates it contains, will be in either the training set or the validation set.

Listing 13.12 dsets.py:242, .__init__

if isValSet_bool:
  assert val_stride > 0, val_stride
  self.series_list = self.series_list[::val_stride]   
  assert self.series_list
elif val_stride > 0:
  del self.series_list[::val_stride]                  
  assert self.series_list

Starting with a series list containing all our series, we keep only every val_stride-th element, starting with 0.

If we are training, we delete every val_stride-th element instead.

Speaking of validation, we’re going to have two different modes we can validate our training with. First, when fullCt_bool is True, we will use every slice in the CT for our dataset. This will be useful when we’re evaluating end-to-end performance, since we need to pretend that we’re starting off with no prior information about the CT. We’ll use the second mode for validation during training, which is when we’re limiting ourselves to only the CT slices that have a positive mask present.

As we now only want certain CT series to be considered, we loop over the series UIDs we want and get the total number of slices and the list of interesting ones.

Listing 13.13 dsets.py:250, .__init__

self.sample_list = []
for series_uid in self.series_list:
  index_count, positive_indexes = getCtSampleSize(series_uid)
 
  if self.fullCt_bool:
    self.sample_list += [(series_uid, slice_ndx)      
               for slice_ndx in range(index_count)]
  else:
    self.sample_list += [(series_uid, slice_ndx)      
               for slice_ndx in positive_indexes]

Here we extend sample_list with every slice of the CT by using range ...

... while here we take only the interesting slices.

Doing it this way will keep our validation relatively quick and ensure that we’re getting complete stats for true positives and false negatives, but we’re making the assumption that other slices will have false positive and true negative stats relatively similar to the ones we evaluate during validation.

Once we have the set of series_uid values we’ll be using, we can filter our candidateInfo_list to contain only nodule candidates with a series_uid that is included in that set of series. Additionally, we’ll create another list that has only the positive candidates so that during training, we can use those as our training samples.

Listing 13.14 dsets.py:261, .__init__

self.candidateInfo_list = getCandidateInfoList()                   
 
series_set = set(self.series_list)                                 
self.candidateInfo_list = [cit for cit in self.candidateInfo_list
               if cit.series_uid in series_set]                    
 
self.pos_list = [nt for nt in self.candidateInfo_list
          if nt.isNodule_bool]                                     

This is cached.

Makes a set for faster lookup

Filters out the candidates from series not in our set

For the data balancing yet to come, we want a list of actual nodules.

Our __getitem__ implementation will also be a bit fancier by delegating a lot of the logic to a function that makes it easier to retrieve a specific sample. At the core of it, we’d like to retrieve our data in three different forms. First, we have the full slice of the CT, as specified by a series_uid and slice_ndx. Second, we have a cropped area around a nodule, which we’ll use for training data (we’ll explain in a bit why we’re not using full slices). Finally, the DataLoader is going to ask for samples via an integer ndx, and the dataset will need to return the appropriate type based on whether it’s training or validation.

The base class or subclass __getitem__ functions will convert from the integer ndx to either the full slice or training crop, as appropriate. As mentioned, our validation set’s __getitem__ just calls another function to do the real work. Before that, it wraps the index around into the sample list in order to decouple the epoch size (given by the length of the dataset) from the actual number of samples.

Listing 13.15 dsets.py:281, .__getitem__

def __getitem__(self, ndx):
  series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]  
  return self.getitem_fullSlice(series_uid, slice_ndx)

The modulo operation does the wrapping.
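
The wrapping pairs with a __len__ that reports a hand-picked epoch size rather than the true number of samples. A sketch of what that might look like (the value here is hypothetical; check the chapter’s code for the actual one):

def __len__(self):
  return 300000        # a fixed "epoch" size, decoupled from len(self.sample_list)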

That was easy, but we still need to implement the interesting functionality from the getitem_fullSlice method.

Listing 13.16 dsets.py:285, .getitem_fullSlice

def getitem_fullSlice(self, series_uid, slice_ndx):
  ct = getCt(series_uid)
  ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512))    
 
  start_ndx = slice_ndx - self.contextSlices_count
  end_ndx = slice_ndx + self.contextSlices_count + 1
  for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
    context_ndx = max(context_ndx, 0)                                 
    context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
    ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
  ct_t.clamp_(-1000, 1000)
 
  pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)
 
  return ct_t, pos_t, ct.series_uid, slice_ndx

Preallocates the output

When we reach beyond the bounds of the ct_a, we duplicate the first or last slice.

Splitting the functions like this means we can always ask a dataset for a specific slice (or cropped training chunk, which we’ll see in the next section) indexed by series UID and position. Only for the integer indexing do we go through __getitem__, which then gets a sample from the (shuffled) list.

Aside from ct_t and pos_t, the rest of the tuple we return is all information that we include for debugging and display. We don’t need any of it for training.

13.5.5 Designing our training and validation data

Before we get into the implementation for our training dataset, we need to explain why our training data will look different from our validation data. Instead of the full CT slices, we’re going to train on 64 × 64 crops around our positive candidates (the actually-a-nodule candidates). These 64 × 64 patches will be taken randomly from a 96 × 96 crop centered on the nodule. We will also include three slices of context in both directions as additional “channels” to our 2D segmentation.

We’re doing this to make training more stable, and to converge more quickly. The only reason we know to do this is because we tried to train on whole CT slices, but we found the results unsatisfactory. After some experimentation, we found that the 64 × 64 semirandom crop approach worked well, so we decided to use that for the book. When you work on your own projects, you’ll need to do that kind of experimentation for yourself!

We believe the whole-slice training was unstable essentially due to a class-balancing issue. Since each nodule is so small compared to the whole CT slice, we were right back in a needle-in-a-haystack situation similar to the one we got out of in the last chapter, where our positive samples were swamped by the negatives. In this case, we’re talking about pixels rather than nodules, but the concept is the same. By training on crops, we’re keeping the number of positive pixels the same and reducing the negative pixel count by several orders of magnitude.

Because our segmentation model is pixel-to-pixel and takes images of arbitrary size, we can get away with training and validating on samples with different dimensions. Validation uses the same convolutions with the same weights, just applied to a larger set of pixels (and so with fewer border pixels to fill in with edge data).

One caveat to this approach is that since our validation set contains orders of magnitude more negative pixels, our model will have a huge false positive rate during validation. There are many more opportunities for our segmentation model to get tricked! It doesn’t help that we’re going to be pushing for high recall as well. We’ll discuss that more in section 13.6.3.

13.5.6 Implementing TrainingLuna2dSegmentationDataset

With that out of the way, let’s get back to the code. Here’s the training set’s __getitem__. It looks just like the one for the validation set, except that we now sample from pos_list and call getitem_trainingCrop with the candidate info tuple, since we need the series and the exact center location, not just the slice.

Listing 13.17 dsets.py:320, .__getitem__

def __getitem__(self, ndx):
  candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
  return self.getitem_trainingCrop(candidateInfo_tup)

To implement getitem_trainingCrop, we will use a getCtRawCandidate function similar to the one we used during classification training. Here, we’re passing in a different size crop, but the function is unchanged except that it now returns an additional array with a crop of ct.positive_mask as well.

We limit our pos_a to the center slice that we’re actually segmenting, and then construct our 64 × 64 random crops of the 96 × 96 we were given by getCtRawCandidate. Once we have those, we return a tuple with the same items as our validation dataset.

Listing 13.18 dsets.py:324, .getitem_trainingCrop

def getitem_trainingCrop(self, candidateInfo_tup):
  ct_a, pos_a, center_irc = getCtRawCandidate(     
    candidateInfo_tup.series_uid,
    candidateInfo_tup.center_xyz,
    (7, 96, 96),
  )
  pos_a = pos_a[3:4]                               
 
  row_offset = random.randrange(0,32)              
  col_offset = random.randrange(0,32)
  ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
                 col_offset:col_offset+64]).to(torch.float32)
  pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
                   col_offset:col_offset+64]).to(torch.long)
 
  slice_ndx = center_irc.index
 
  return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx

Gets the candidate with a bit of extra surrounding

Taking a one-element slice keeps the third dimension, which will be the (single) output channel.

With two random numbers between 0 and 31, we crop both CT and mask.

You might have noticed that data augmentation is missing from our dataset implementation. We’re going to handle that a little differently this time around: we’ll augment our data on the GPU.

13.5.7 Augmenting on the GPU

One of the key concerns when it comes to training a deep learning model is avoiding bottlenecks in your training pipeline. Well, that’s not quite true--there will always be a bottleneck.12 The trick is to make sure the bottleneck is at the resource that’s the most expensive or difficult to upgrade, and that your usage of that resource isn’t wasteful.

Some common places to see bottlenecks are as follows:

  • In the data-loading pipeline, either in raw I/O or in decompressing data once it’s in RAM. We addressed this with our diskcache library usage.

  • In CPU preprocessing of the loaded data. This is often data normalization or augmentation.

  • In the training loop on the GPU. This is typically where we want our bottleneck to be, since total deep learning system costs for GPUs are usually higher than for storage or CPU.

  • Less commonly, the bottleneck can sometimes be the memory bandwidth between CPU and GPU. This implies that the GPU isn’t doing much work compared to the data size that’s being sent in.

Since GPUs can be 50 times faster than CPUs when working on tasks that fit GPUs well, it often makes sense to move those tasks to the GPU from the CPU in cases where CPU usage is becoming high. This is especially true if the data gets expanded during this processing; by moving the smaller input to the GPU first, the expanded data is kept local to the GPU, and less memory bandwidth is used.

In our case, we’re going to move data augmentation to the GPU. This will keep our CPU usage light, and the GPU will easily be able to accommodate the additional workload. Far better to have the GPU busy with a small bit of extra work than idle waiting for the CPU to struggle through the augmentation process.
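In code terms, the ordering we have in mind looks roughly like the following sketch. The variable names mirror what we will see later in listing 13.25, but treat this as the shape of the idea rather than the project’s actual code:

input_g = input_t.to('cuda', non_blocking=True)   # small CPU-to-GPU copy
label_g = label_t.to('cuda', non_blocking=True)
input_g, label_g = augmentation_model(input_g, label_g)   # the extra work stays on the GPU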

We’ll accomplish this by using a second model, similar to all the other subclasses of nn.Module we’ve seen so far in this book. The main difference is that we’re not interested in backpropagating gradients through the model, and the forward method will be doing decidedly different things. There will be some slight modifications to the actual augmentation routines since we’re working with 2D data for this chapter, but otherwise, the augmentation will be very similar to what we saw in chapter 12. The model will consume tensors and produce different tensors, just like the other models we’ve implemented.

Our model’s __init__ takes the same data augmentation arguments--flip, offset, and so on--that we used in the last chapter, and assigns them to self.

Listing 13.19 model.py:56, class SegmentationAugmentation

class SegmentationAugmentation(nn.Module):
  def __init__(
      self, flip=None, offset=None, scale=None, rotate=None, noise=None
  ):
    super().__init__()
 
    self.flip = flip
    self.offset = offset
    # ... line 64

Our augmentation forward method takes the input and the label, and calls out to build the transform_t tensor that will then drive our affine_grid and grid_sample calls. Those calls should feel very familiar from chapter 12.

Listing 13.20 model.py:68, SegmentationAugmentation.forward

def forward(self, input_g, label_g):
  transform_t = self._build2dTransformMatrix()
  transform_t = transform_t.expand(input_g.shape[0], -1, -1)    
  transform_t = transform_t.to(input_g.device, torch.float32)
  affine_t = F.affine_grid(transform_t[:,:2],                   
      input_g.size(), align_corners=False)
 
  augmented_input_g = F.grid_sample(input_g,
      affine_t, padding_mode='border',
      align_corners=False)
  augmented_label_g = F.grid_sample(label_g.to(torch.float32),
      affine_t, padding_mode='border',
      align_corners=False)                                      
 
  if self.noise:
    noise_t = torch.randn_like(augmented_input_g)
    noise_t *= self.noise
 
    augmented_input_g += noise_t
 
  return augmented_input_g, augmented_label_g > 0.5             

Note that we’re augmenting 2D data.

The first dimension of the transformation is the batch, but we only want the first two rows of the 3 × 3 matrices per batch item.

We need the same transformation applied to CT and mask, so we use the same grid. Because grid_sample only works with floats, we convert here.

Just before returning, we convert the mask back to Booleans by comparing to 0.5. The interpolation that grid_sample performs results in fractional values.

Now that we know what we need to do with transform_t to get our data out, let’s take a look at the _build2dTransformMatrix function that actually creates the transformation matrix we use.

Listing 13.21 model.py:90, ._build2dTransformMatrix

def _build2dTransformMatrix(self):
  transform_t = torch.eye(3)                    
 
  for i in range(2):                            
    if self.flip:
      if random.random() > 0.5:
        transform_t[i,i] *= -1
  # ... line 108
  if self.rotate:
    angle_rad = random.random() * math.pi * 2   
    s = math.sin(angle_rad)
    c = math.cos(angle_rad)
 
    rotation_t = torch.tensor([                 
      [c, -s, 0],
      [s, c, 0],
      [0, 0, 1]])
 
    transform_t @= rotation_t                   
 
  return transform_t

Creates a 3 × 3 matrix, but we will drop the last row later.

Again, we’re augmenting 2D data here.

Takes a random angle in radians, so in the range 0 .. 2π

Rotation matrix for the 2D rotation by the random angle in the first two dimensions

Applies the rotation to the transformation matrix using the Python matrix multiplication operator

Other than the slight differences to deal with 2D data, our GPU augmentation code looks very similar to our CPU augmentation code. That’s great, because it means we’re able to write code that doesn’t have to care very much about where it runs. The primary difference isn’t in the core implementation: it’s how we wrapped that implementation into a nn.Module subclass. While we’ve been thinking about models as exclusively a deep learning tool, this shows us that with PyTorch, tensors can be used quite a bit more generally. Keep this in mind when you start your next project--the range of things you can accomplish with a GPU-accelerated tensor is pretty large!

13.6 Updating the training script for segmentation

We have a model. We have data. We need to use them, and you won’t be surprised when step 2C of figure 13.14 suggests we should train our new model with the new data.

Figure 13.14 The outline of this chapter, with a focus on the changes needed for our training loop

To be more precise about the process of training our model, we will update three things affecting the outcome from the training code we got in chapter 12:

  • We need to instantiate the new model (unsurprisingly).

  • We will introduce a new loss: the Dice loss.

  • We will also look at an optimizer other than the venerable SGD we’ve used so far. We’ll stick with a popular one and use Adam.

But we will also step up our bookkeeping, by

  • Logging images for visual inspection of the segmentation to TensorBoard

  • Performing more metrics logging in TensorBoard

  • Saving our best model based on the validation

Overall, the training script p2ch13/training.py is even more similar to what we used for classification training in chapter 12 than the adapted code we’ve seen so far. Any significant changes will be covered here in the text, but be aware that some of the minor tweaks are skipped. For the full story, check the source.

13.6.1 Initializing our segmentation and augmentation models

Our initModel method is very unsurprising. We are using the UNetWrapper class and giving it our configuration parameters--which we will look at in detail shortly. Also, we now have a second model for augmentation. Just like before, we can move the model to the GPU if desired and possibly set up multi-GPU training using DataParallel. We skip these administrative tasks here.

Listing 13.22 training.py:133, .initModel

def initModel(self):
  segmentation_model = UNetWrapper(
    in_channels=7,
    n_classes=1,
    depth=3,
    wf=4,
    padding=True,
    batch_norm=True,
    up_mode='upconv',
  )
 
  augmentation_model = SegmentationAugmentation(**self.augmentation_dict)
 
  # ... line 154
  return segmentation_model, augmentation_model

For input into UNet, we’ve got seven input channels: 3 + 3 context slices, and 1 slice that is the focus for what we’re actually segmenting. We have one output class indicating whether this voxel is part of a nodule. The depth parameter controls how deep the U goes; each downsampling operation adds 1 to the depth. Using wf=4 means the first layer will have 2**wf == 16 filters, which doubles with each downsampling. We want the convolutions to be padded so that we get an output image the same size as our input. We also want batch normalization inside the network after each activation function, and our upsampling function should be an upconvolution layer, as implemented by nn.ConvTranspose2d (see util/unet.py, line 123).

13.6.2 Using the Adam optimizer

The Adam optimizer (https://arxiv.org/abs/1412.6980) is an alternative to using SGD when training our models. Adam maintains a separate learning rate for each parameter and automatically updates that learning rate as training progresses. Due to these automatic updates, we typically won’t need to specify a non-default learning rate when using Adam, since it will quickly determine a reasonable learning rate by itself.

Here’s how we instantiate Adam in code.

Listing 13.23 training.py:156, .initOptimizer

def initOptimizer(self):
  return Adam(self.segmentation_model.parameters())

It’s generally accepted that Adam is a reasonable optimizer to start most projects with.13 There is often a configuration of stochastic gradient descent with Nesterov momentum that will outperform Adam, but finding the correct hyperparameters to use when initializing SGD for a given project can be difficult and time consuming.
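For comparison, a hypothetical SGD setup might look like the sketch below; the learning rate and momentum values are illustrative guesses, not numbers tuned for this project:

from torch.optim import SGD

optimizer = SGD(self.segmentation_model.parameters(),
        lr=0.001, momentum=0.99, nesterov=True)   # every value here would need tuning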

There have been a large number of variations on Adam--AdaMax, RAdam, Ranger, and so on--that each have strengths and weaknesses. Delving into the details of those is outside the scope of this book, but we think it’s important to know that those alternatives exist. We’ll stick with plain Adam for this chapter.

13.6.3 Dice loss

The Sørensen-Dice coefficient (https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient), also known as the Dice loss, is a common loss metric for segmentation tasks. One advantage of using Dice loss over a per-pixel cross-entropy loss is that Dice handles the case where only a small portion of the overall image is flagged as positive. As we recall from section 11.10, unbalanced training data can be problematic when using cross-entropy loss. That’s exactly the situation we have here--most of a CT scan isn’t a nodule. Luckily, with Dice, that won’t pose as much of a problem.

The Sørensen-Dice coefficient is based on the ratio of correctly segmented pixels to the sum of the predicted and actual pixels. Those ratios are laid out in figure 13.15. On the left, we see an illustration of the Dice score. It is twice the joint area (true positives, striped) divided by the sum of the entire predicted area and the entire ground-truth marked area (the overlap being counted twice). On the right are two prototypical examples of high agreement/high Dice score and low agreement/low Dice score.

Figure 13.15 The ratios that make up the Dice score

That might sound familiar; it’s the same ratio that we saw in chapter 12. We’re basically going to be using a per-pixel F1 score!

Note This is a per-pixel F1 score where the “population” is one image’s pixels. Since the population is entirely contained within one training sample, we can use it for training directly. In the classification case, the F1 score is not calculable over a single minibatch, and, hence, we cannot use it for training directly.
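As a toy illustration (this is not part of the project’s code), computing hard Dice on two small Boolean masks gives exactly the per-pixel F1 score, 2 * TP / (2 * TP + FP + FN):

import numpy as np

label = np.zeros((8, 8), dtype=bool)
label[2:5, 2:5] = True               # 9 ground-truth positive pixels
pred = np.zeros((8, 8), dtype=bool)
pred[3:6, 3:6] = True                # 9 predicted pixels, 4 of them overlapping

tp = (pred & label).sum()                       # 4 true positives
dice = 2 * tp / (pred.sum() + label.sum())      # 8 / 18, about 0.44
print(dice)

Checking against the F1 definition: with 4 true positives, 5 false positives, and 5 false negatives, F1 is 2 * 4 / (2 * 4 + 5 + 5), the same 8/18.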

Since our label_g is effectively a Boolean mask, we can multiply it with our predictions to get our true positives. Note that we aren’t thresholding prediction_g into a Boolean here; a loss defined that way wouldn’t be differentiable. Instead, we’re replacing the number of true positives with the sum of the predicted values for the pixels where the ground truth is 1. This converges to the same thing as the predicted values approach 1, but sometimes the predicted values will be uncertain predictions in the 0.4 to 0.6 range. Those undecided values will contribute roughly the same amount to our gradient updates, no matter which side of 0.5 they happen to fall on. A Dice coefficient utilizing continuous predictions is sometimes referred to as soft Dice.

There’s one tiny complication. Since we’re wanting a loss to minimize, we’re going to take our ratio and subtract it from 1. Doing so will invert the slope of our loss function so that in the high-overlap case, our loss is low; and in the low-overlap case, it’s high. Here’s what that looks like in code.

Listing 13.24 training.py:315, .diceLoss

def diceLoss(self, prediction_g, label_g, epsilon=1):
  diceLabel_g = label_g.sum(dim=[1,2,3])                      
  dicePrediction_g = prediction_g.sum(dim=[1,2,3])
  diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])
 
  diceRatio_g = (2 * diceCorrect_g + epsilon) \
    / (dicePrediction_g + diceLabel_g + epsilon)              
 
  return 1 - diceRatio_g                                      

Sums over everything except the batch dimension to get the positively labeled, (softly) positively detected, and (softly) correct positives per batch item

The Dice ratio. To avoid problems when we accidentally have neither predictions nor labels, we add 1 to both numerator and denominator.

To make it a loss, we take 1 - Dice ratio, so lower loss is better.

We’re going to update our computeBatchLoss function to call self.diceLoss. Twice. We’ll compute the normal Dice loss for the training sample, as well as for only the pixels included in label_g. By multiplying our predictions (which, remember, are floating-point values) times the label (which are effectively Booleans), we’ll get pseudo-predictions that got every negative pixel “exactly right” (since all the values for those pixels are multiplied by the false-is-zero values from label_g). The only pixels that will generate loss are the false negative pixels (everything that should have been predicted true, but wasn’t). This will be helpful, since recall is incredibly important for our overall project; after all, we can’t classify tumors properly if we don’t detect them in the first place!

Listing 13.25 training.py:282, .computeBatchLoss

def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
           classificationThreshold=0.5):
  input_t, label_t, series_list, _slice_ndx_list = batch_tup
 
  input_g = input_t.to(self.device, non_blocking=True)              
  label_g = label_t.to(self.device, non_blocking=True)
 
  if self.segmentation_model.training and self.augmentation_dict:   
    input_g, label_g = self.augmentation_model(input_g, label_g)
 
  prediction_g = self.segmentation_model(input_g)                   
 
  diceLoss_g = self.diceLoss(prediction_g, label_g)                 
  fnLoss_g = self.diceLoss(prediction_g * label_g, label_g)
  # ... line 313
  return diceLoss_g.mean() + fnLoss_g.mean() * 8                    

Transfers to GPU

Augments as needed if we are training. In validation, we would skip this.

Runs the segmentation model ...

... and applies our fine Dice loss

Oops. What is this?

Let’s talk a bit about what we’re doing with our return statement of diceLoss_g .mean() + fnLoss_g.mean() * 8.

Loss weighting

In chapter 12, we discussed shaping our dataset so that our classes were not wildly imbalanced. That helped training converge, since the positive and negative samples present in each batch were able to counteract the general pull of the other, and the model had to learn to discriminate between them to improve. We’re approximating that same balance here by cropping down our training samples to include fewer non-positive pixels; but it’s incredibly important to have high recall, and we need to make sure that as we train, we’re providing a loss that reflects that fact.

We are going to have a weighted loss that favors one class over the other. What we’re saying by multiplying fnLoss_g by 8 is that getting the entire population of our positive pixels right is eight times more important than getting the entire population of negative pixels right (nine, if you count the one in diceLoss_g). Since the area covered by the positive mask is much, much smaller than the whole 64 × 64 crop, that also means each individual positive pixel wields that much more influence when it comes to backpropagation.

We’re willing to trade away many correctly predicted negative pixels in the general Dice loss to gain one correct pixel in the false negative loss. Since the general Dice loss is a strict superset of the false negative loss, the only correct pixels available to make that trade are ones that start as true negatives (all of the true positive pixels are already included in the false negative loss, so there’s no trade to be made).

Since we’re willing to sacrifice huge swaths of true negative pixels in the pursuit of having better recall, we should expect a large number of false positives in general.14 We’re doing this because recall is very, very important to our use case, and we’d much rather have some false positives than even a single false negative.

We should note that this approach only works when using the Adam optimizer. When using SGD, the push to overpredict would lead to every pixel coming back as positive. Adam’s ability to fine-tune the learning rate means stressing the false negative loss doesn’t become overpowering.

Collecting metrics

Since we’re going to purposefully skew our numbers for better recall, let’s see just how tilted things will be. In our classification computeBatchLoss, we compute various per-sample values that we used for metrics and the like. We also compute similar values for the overall segmentation results. These true positive and other metrics were previously computed in logMetrics, but due to the size of the result data (recall that each single CT slice from the validation set is a quarter-million pixels!), we need to compute these summary stats live in the computeBatchLoss function.

Listing 13.26 training.py:297, .computeBatchLoss

start_ndx = batch_ndx * batch_size
end_ndx = start_ndx + input_t.size(0)
 
with torch.no_grad():
  predictionBool_g = (prediction_g[:, 0:1]
            > classificationThreshold).to(torch.float32)        
 
  tp = (   predictionBool_g *  label_g).sum(dim=[1,2,3])        
  fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
  fp = (   predictionBool_g * (~label_g)).sum(dim=[1,2,3])
 
  metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g   
  metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
  metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
  metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp

We threshold the prediction to get “hard” Dice but convert to float for the later multiplication.

Computing true positives, false positives, and false negatives is similar to what we did when computing the Dice loss.

We store our metrics to a large tensor for future reference. This is per batch item rather than averaged over the batch.

As we discussed at the beginning of this section, we can compute our true positives and so on by multiplying our prediction (or its negation) and our label (or its negation) together. Since we’re not as worried about the exact values of our predictions here (it doesn’t really matter if we flag a pixel as 0.6 or 0.9--as long as it’s over the threshold, we’ll call it part of a nodule candidate), we are going to create predictionBool_g by comparing it to our threshold of 0.5.

13.6.4 Getting images into TensorBoard

One of the nice things about working on segmentation tasks is that the output is easily represented visually. Being able to eyeball our results can be a huge help for determining whether a model is progressing well (but perhaps needs more training), or if it has gone off the rails (so we need to stop wasting our time with further training). There are many ways we could package up our results as images, and many ways we could display them. TensorBoard has great support for this kind of data, and we already have TensorBoard SummaryWriter instances integrated with our training runs, so we’re going to use TensorBoard. Let’s see what it takes to get everything hooked up.

We’ll add a logImages function to our main application class and call it with both our training and validation data loaders. While we are at it, we will make another change to our training loop: we’re only going to perform validation and image logging on the first and then every fifth epoch. We do this by checking the epoch number against a new constant, validation_cadence.

When training, we’re trying to balance a few things:

  • Getting a rough idea of how our model is training without having to wait very long

  • Spending the bulk of our GPU cycles training, rather than validating

  • Making sure we are still performing well on the validation set

The first point means we need to have relatively short epochs so that we get to call logMetrics more often. The second, however, means we want to train for a relatively long time before calling doValidation. The third means we need to call doValidation regularly, rather than once at the end of training or something unworkable like that. By only doing validation on the first and then every fifth epoch, we can meet all of those goals. We get an early signal of training progress, spend the bulk of our time training, and have periodic check-ins with the validation set as we go along.

Listing 13.27 training.py:210, SegmentationTrainingApp.main

def main(self):
  # ... line 217
  self.validation_cadence = 5
  for epoch_ndx in range(1, self.cli_args.epochs + 1):              
    # ... line 228
    trnMetrics_t = self.doTraining(epoch_ndx, train_dl)             
    self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)                 
 
    if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:  
      # ... line 239
      self.logImages(epoch_ndx, 'trn', train_dl)                    
      self.logImages(epoch_ndx, 'val', val_dl)

Our outermost loop, over the epochs

Trains for one epoch

Logs the (scalar) metrics from training after each epoch

Only every validation cadence-th interval ...

... we validate the model and log images.

There isn’t a single right way to structure our image logging. We are going to grab a handful of CTs from both the training and validation sets. For each CT, we will select 6 evenly spaced slices, end to end, and show both the ground truth and our model’s output. We chose 6 slices only because TensorBoard will show 12 images at a time, and we can arrange the browser window to have a row of label images over the model output. Arranging things this way makes it easy to visually compare the two, as we can see in figure 13.16.

Figure 13.16 Top row: label data for training. Bottom row: output from the segmentation model.

Also note the small slider-dot on the prediction images. That slider will allow us to view previous versions of the images with the same label (such as val/0_prediction_3, but at an earlier epoch). Being able to see how our segmentation output changes over time can be useful when we’re trying to debug something or make tweaks to achieve a specific result. As training progresses, TensorBoard will limit the number of images viewable from the slider to 10, probably to avoid overwhelming the browser with a huge number of images.

The code that produces this output starts by getting 12 series from the pertinent data loader and 6 images from each series.

Listing 13.28 training.py:326, .logImages

def logImages(self, epoch_ndx, mode_str, dl):
  self.segmentation_model.eval()                                    
 
  images = sorted(dl.dataset.series_list)[:12]                      
  for series_ndx, series_uid in enumerate(images):
    ct = getCt(series_uid)
 
    for slice_ndx in range(6):
      ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5              
      sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)
 
      ct_t, label_t, series_uid, ct_ndx = sample_tup

Sets the model to eval

Takes (the same) 12 CTs by bypassing the data loader and using the dataset directly. The series list might be shuffled, so we sort.

Selects six equidistant slices throughout the CT
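As a quick sanity check of that spacing formula, here is what it produces for a hypothetical CT with 121 slices (the real count comes from ct.hu_a.shape[0]):

slice_count = 121   # made-up value, for illustration only
print([slice_ndx * (slice_count - 1) // 5 for slice_ndx in range(6)])
# [0, 24, 48, 72, 96, 120] -- six slices covering the CT end to end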

After that, we feed ct_t into the model. This looks very much like what we see in computeBatchLoss; see p2ch13/training.py for details if desired.

Once we have prediction_a, we need to build an image_a that will hold RGB values to display. We’re using np.float32 values, which need to be in a range from 0 to 1. Our approach will cheat a little by adding together various images and masks to get data in the range 0 to 2, and then multiplying the entire array by 0.5 to get it back into the right range.

Listing 13.29 training.py:346, .logImages

ct_t[:-1,:,:] /= 2000
ct_t[:-1,:,:] += 0.5
 
ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
 
image_a = np.zeros((512, 512, 3), dtype=np.float32)
image_a[:,:,:] = ctSlice_a.reshape((512,512,1))          
image_a[:,:,0] += prediction_a & (1 - label_a)    
image_a[:,:,0] += (1 - prediction_a) & label_a           
image_a[:,:,1] += ((1 - prediction_a) & label_a) * 0.5   
 
image_a[:,:,1] += prediction_a & label_a                 
image_a *= 0.5
image_a.clip(0, 1, image_a)

CT intensity is assigned to all RGB channels to provide a grayscale base image.

False positives are flagged as red and overlaid on the image.

False negatives are orange.

True positives are green.

Our goal is to have a grayscale CT at half intensity, overlaid with predicted-nodule (or, more correctly, nodule-candidate) pixels in various colors. We’re going to use red for all pixels that are incorrect (false positives and false negatives). This will mostly be false positives, which we don’t care about too much (since we’re focused on recall). 1 - label_a inverts the label, and that multiplied by the prediction_a gives us only the predicted pixels that aren’t in a candidate nodule. False negatives get a half-strength mask added to green, which means they will show up as orange (1.0 red and 0.5 green renders as orange in RGB). Every correctly predicted pixel inside a nodule is set to green; since we got those pixels right, no red will be added, and so they will render as pure green.

After that, we renormalize our data to the 0...1 range and clamp it (in case we start displaying augmented data here, which would cause speckles when the noise was outside our expected CT range). All that remains is to save the data to TensorBoard.

Listing 13.30 training.py:361, .logImages

writer = getattr(self, mode_str + '_writer')
writer.add_image(
  f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
  image_a,
  self.totalTrainingSamples_count,
  dataformats='HWC',
)

This looks very similar to the writer.add_scalar calls we’ve seen before. The dataformats='HWC' argument tells TensorBoard that the order of axes in our image has our RGB channels as the third axis. Recall that our network layers often specify outputs that are B × C × H × W, and we could put that data directly into TensorBoard as well if we specified 'CHW'.
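For instance, a single image taken from a B × C × H × W batch could be logged directly, along the lines of the following sketch (the tag name is made up; this is not code from p2ch13/training.py):

one_chw_t = prediction_g[0].detach().cpu()     # shape C x H x W, one item from the batch
writer.add_image(
  f'{mode_str}/raw_prediction_{slice_ndx}',    # hypothetical tag, for illustration
  one_chw_t,
  self.totalTrainingSamples_count,
  dataformats='CHW',
)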

We also want to save the ground truth that we’re using to train, which will form the top row of our TensorBoard CT slices we saw earlier in figure 13.16. The code for that is similar enough to what we just saw that we’ll skip it. Again, check p2ch13/training.py if you want the details.

13.6.5 Updating our metrics logging

To give us an idea how we are doing, we compute per-epoch metrics: in particular, true positives, false negatives, and false positives. This is what the following listing does. Nothing here will be particularly surprising.

Listing 13.31 training.py:400, .logMetrics

sum_a = metrics_a.sum(axis=1)
allLabel_count = sum_a[METRICS_TP_NDX] + sum_a[METRICS_FN_NDX]
metrics_dict['percent_all/tp'] = \
  sum_a[METRICS_TP_NDX] / (allLabel_count or 1) * 100
metrics_dict['percent_all/fn'] = \
  sum_a[METRICS_FN_NDX] / (allLabel_count or 1) * 100
metrics_dict['percent_all/fp'] = \
  sum_a[METRICS_FP_NDX] / (allLabel_count or 1) * 100    

Can be larger than 100% since we’re comparing to the total number of pixels labeled as candidate nodules, which is a tiny fraction of each image

We are going to start scoring our models as a way to determine whether a particular training run is the best we’ve seen so far. In chapter 12, we said we’d be using the F1 score for our model ranking, but our goals are different here. We need to make sure our recall is as high as possible, since we can’t classify a potential nodule if we don’t find it in the first place!

We will use our recall to determine the “best” model. As long as the F1 score is reasonable for that epoch,15 we just want to get recall as high as possible. Screening out any false positives will be the responsibility of the classification model.

Listing 13.32 training.py:393, .logMetrics

def logMetrics(self, epoch_ndx, mode_str, metrics_t):
  # ... line 453
  score = metrics_dict['pr/recall']
 
  return score

When we add similar code to our classification training loop in the next chapter, we’ll use the F1 score.

Back in the main training loop, we’ll keep track of the best_score we’ve seen so far in this training run. When we save our model, we’ll include a flag that indicates whether this is the best score we’ve seen so far. Recall from section 13.6.4 that we’re only calling the doValidation function for the first and then every fifth epochs. That means we’re only going to check for a best score on those epochs. That shouldn’t be a problem, but it’s something to keep in mind if you need to debug something happening on epoch 7. We do this checking just before we save the images.

Listing 13.33 training.py:210, SegmentationTrainingApp.main

def main(self):
  best_score = 0.0
  for epoch_ndx in range(1, self.cli_args.epochs + 1):         
      # if validation is wanted
      # ... line 233
      valMetrics_t = self.doValidation(epoch_ndx, val_dl)
      score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)  
      best_score = max(score, best_score)
 
      self.saveModel('seg', epoch_ndx, score == best_score)    

The epoch-loop we already saw

Computes the score. As we saw earlier, we take the recall.

Now we only need to write saveModel. The third parameter indicates whether we also want to save this as the best model.

Let’s take a look at how we persist our model to disk.

13.6.6 Saving our model

PyTorch makes it pretty easy to save our model to disk. Under the hood, torch.save uses the standard Python pickle library, which means we could pass our model instance in directly, and it would save properly. That’s not considered the ideal way to persist our model, however, since we lose some flexibility.

Instead, we will save only the parameters of our model. Doing this allows us to load those parameters into any model that expects parameters of the same shape, even if the class doesn’t match the model those parameters were saved under. The save-parameters-only approach allows us to reuse and remix our models in more ways than saving the entire model.

We can get at our model’s parameters using the model.state_dict() function.

Listing 13.34 training.py:480, .saveModel

def saveModel(self, type_str, epoch_ndx, isBest=False):
  # ... line 496
  model = self.segmentation_model
  if isinstance(model, torch.nn.DataParallel):
    model = model.module                             
 
  state = {
    'sys_argv': sys.argv,
    'time': str(datetime.datetime.now()),
    'model_state': model.state_dict(),               
    'model_name': type(model).__name__,
    'optimizer_state' : self.optimizer.state_dict(), 
    'optimizer_name': type(self.optimizer).__name__,
    'epoch': epoch_ndx,
    'totalTrainingSamples_count': self.totalTrainingSamples_count,
  }
  torch.save(state, file_path)

Gets rid of the DataParallel wrapper, if it exists

The important part

Preserves momentum, and so on

We set file_path to something like data-unversioned/part2/models/p2ch13/seg_2019-07-10_02.17.22_ch12.50000.state. The .50000. part is the number of training samples we’ve presented to the model so far, while the other parts of the path are obvious.

tip By saving the optimizer state as well, we could resume training seamlessly. While we don’t provide an implementation of this, it could be useful if your access to computing resources is likely to be interrupted. Details on loading a model and optimizer to restart training can be found in the official documentation (https://pytorch.org/tutorials/beginner/saving_loading_models.html).
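A minimal sketch of what that restart could look like, assuming a model and optimizer constructed exactly as in initModel and initOptimizer (this is not code shipped with the project):

state = torch.load(file_path, map_location='cpu')        # the dictionary we saved above
segmentation_model.load_state_dict(state['model_state'])
optimizer.load_state_dict(state['optimizer_state'])      # restores momentum and friends
epoch_start = state['epoch'] + 1                         # pick up at the next epoch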

If the current model has the best score we’ve seen so far, we save a second copy of state with a .best.state filename. This might get overwritten later by another, higher-score version of the model. By focusing only on this best file, we can divorce customers of our trained model from the details of how each epoch of training went (assuming, of course, that our score metric is of high quality).

Listing 13.35 training.py:514, .saveModel

if isBest:
  best_path = os.path.join(
    'data-unversioned', 'part2', 'models',
    self.cli_args.tb_prefix,
    f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
  shutil.copyfile(file_path, best_path)
 
  log.info("Saved model params to {}".format(best_path))
 
with open(file_path, 'rb') as f:
  log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())

We also output the SHA1 of the model we just saved. Similar to sys.argv and the timestamp we put into the state dictionary, this can help us debug exactly what model we’re working with if things become confused later (for example, if a file gets renamed incorrectly).

We will update our classification training script in the next chapter with a similar routine for saving the classification model. In order to diagnose a CT, we’ll need to have both models.

13.7 Results

Now that we’ve made all of our code changes, we’ve hit the last section in step 3 of figure 13.17. It’s time to run python -m p2ch13.training --epochs 20 --augmented final_seg. Let’s see what our results look like!

Figure 13.17 The outline of this chapter, with a focus on the results we see from training

Here is what our training metrics look like if we limit ourselves to the epochs we have validation metrics for (we’ll be looking at those metrics next, so this will keep it an apples-to-apples comparison):

E1 trn      0.5235 loss, 0.2276 precision, 0.9381 recall, 0.3663 f1 score 
E1 trn_all  0.5235 loss,  93.8% tp, 6.2% fn,     318.4% fp                
...
E5 trn      0.2537 loss, 0.5652 precision, 0.9377 recall, 0.7053 f1 score 
E5 trn_all  0.2537 loss,  93.8% tp, 6.2% fn,      72.1% fp                
...
E10 trn      0.2335 loss, 0.6011 precision, 0.9459 recall, 0.7351 f1 score
E10 trn_all  0.2335 loss,  94.6% tp, 5.4% fn,      62.8% fp               
...
E15 trn      0.2226 loss, 0.6234 precision, 0.9536 recall, 0.7540 f1 score
E15 trn_all  0.2226 loss,  95.4% tp,  4.6% fn,      57.6% fp          
...
E20 trn      0.2149 loss, 0.6368 precision, 0.9584 recall, 0.7652 f1 score
E20 trn_all  0.2149 loss,  95.8% tp,  4.2% fn,      54.7% fp          

In these rows, we are particularly interested in the F1 score--it is trending up. Good!

TPs are trending up, too. Great! And FNs and FPs are trending down.

Overall, it looks pretty good. True positives and the F1 score are trending up, false positives and negatives are trending down. That’s what we want to see! The validation metrics will tell us whether these results are legitimate. Keep in mind that since we’re training on 64 × 64 crops, but validating on whole 512 × 512 CT slices, we are almost certainly going to have drastically different TP:FN:FP ratios. Let’s see:

E1 val      0.9441 loss, 0.0219 precision, 0.8131 recall, 0.0426 f1 score
E1 val_all  0.9441 loss,  81.3% tp,  18.7% fn,    3637.5% fp
 
E5 val      0.9009 loss, 0.0332 precision, 0.8397 recall, 0.0639 f1 score
E5 val_all  0.9009 loss,  84.0% tp,  16.0% fn,    2443.0% fp
 
E10 val      0.9518 loss, 0.0184 precision, 0.8423 recall, 0.0360 f1 score
E10 val_all  0.9518 loss,  84.2% tp,  15.8% fn,    4495.0% fp              
 
E15 val      0.8100 loss, 0.0610 precision, 0.7792 recall, 0.1132 f1 score
E15 val_all  0.8100 loss,  77.9% tp,  22.1% fn,    1198.7% fp
 
E20 val      0.8602 loss, 0.0427 precision, 0.7691 recall, 0.0809 f1 score
E20 val_all  0.8602 loss,  76.9% tp,  23.1% fn,    1723.9% fp

The highest TP rate (great). Note that the TP rate is the same as recall. But FPs are 4495%--that sounds like a lot.

Ouch--false positive rates over 4,000%? Yes, actually, that’s expected. Our validation slice area is 2^18 pixels (512 is 2^9), while our training crop is only 2^12. That means we’re validating on a slice surface that’s 2^6 = 64 times bigger! Having a false positive count that’s also 64 times bigger makes sense. Remember that our true positive rate won’t have changed meaningfully, since it would all have been included in the 64 × 64 sample we trained on in the first place. This situation also results in very low precision, and, hence, a low F1 score. That’s a natural result of how we’ve structured the training and validation, so it’s not a cause for alarm.

What’s problematic, however, is our recall (and, hence, our true positive rate). Our recall plateaus between epochs 5 and 10 and then starts to drop. It’s pretty obvious that we begin overfitting very quickly, and we can see further evidence of that in figure 13.18--while the training recall keeps trending upward, the validation recall decreases after 3 million samples. This is how we identified overfitting in chapter 5, in particular figure 5.14.

Figure 13.18 The validation set recall, showing signs of overfitting when recall goes down after epoch 10 (3 million samples)

Note Always keep in mind that TensorBoard will smooth your data lines by default. The lighter ghost line behind the solid color shows the raw values.

The U-Net architecture has a lot of capacity, and even with our reduced filter and depth counts, it’s able to memorize our training set pretty quickly. One upside is that we don’t end up needing to train the model for very long!

Recall is our top priority for segmentation, since we’ll let issues with precision be handled downstream by the classification models. Reducing those false positives is the entire reason we have those classification models! This skewed situation does mean it is more difficult than we’d like to evaluate our model. We could instead use the F2 score, which weights recall more heavily (or F5, or F10 ...), but we’d have to pick an N high enough to almost completely discount precision. We’ll skip the intermediates and just score our model by recall, and use our human judgment to make sure a given training run isn’t being pathological about it. Since we’re training on the Dice loss, rather than directly on recall, it should work out.
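For reference, the general F-beta score weights recall beta times as heavily as precision; plugging in numbers close to our epoch 5 validation row shows how large beta would have to be before precision stops dominating (the function below is just an illustration, not part of the training code):

def f_beta(precision, recall, beta=2.0):
  return ((1 + beta**2) * precision * recall
      / (beta**2 * precision + recall))

print(f_beta(0.03, 0.84, beta=2.0))    # still dragged down by precision
print(f_beta(0.03, 0.84, beta=10.0))   # only a very large beta mostly ignores precision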

This is one of the situations where we are cheating a little, because we (the authors) have already done the training and evaluation for chapter 14, and we know how all of this is going to turn out. There isn’t any good way to look at this situation and know that the results we’re seeing will work. Educated guesses are helpful, but they are no substitute for actually running experiments until something clicks.

As it stands, our results are good enough to use going forward, even if our metrics have some pretty extreme values. We’re one step closer to finishing our end-to-end project!

13.8 Conclusion

In this chapter, we’ve discussed a new way of structuring models for pixel-to-pixel segmentation; introduced U-Net, an off-the-shelf, proven model architecture for those kinds of tasks; and adapted an implementation for our own use. We’ve also changed our dataset to provide data for our new model’s training needs, including small crops for training and a limited set of slices for validation. Our training loop now has the ability to save images to TensorBoard, and we have moved augmentation from the dataset into a separate model that can operate on the GPU. Finally, we looked at our training results and discussed how even though the false positive rate (in particular) looks different from what we might hope, our results will be acceptable given our requirements for them from the larger project. In chapter 14, we will pull together the various models we’ve written into a cohesive, end-to-end whole.

13.9 Exercises

  1. Implement the model-wrapper approach to augmentation (like what we used for segmentation training) for the classification model.

    1. What compromises did you have to make?

    2. What impact did the change have on training speed?

  2. Change the segmentation Dataset implementation to have a three-way split for training, validation, and test sets.

    1. What fraction of the data did you use for the test set?

    2. Do performance on the test set and the validation set seem consistent with each other?

    3. How badly does training suffer with the smaller training set?

  3. Make the model try to segment malignant versus benign in addition to is-nodule status.

    1. How does your metrics reporting need to change? Your image generation?

    2. What kind of results do you see? Is the segmentation good enough to skip the classification step?

  4. Can you train the model on a combination of 64 × 64 crops and whole-CT slices?16

  5. Can you find additional sources of data to use beyond just the LUNA (or LIDC) data?

13.10 Summary

  • Segmentation flags individual pixels or voxels for membership in a class. This is in contrast to classification, which operates at the level of the entire image.

  • U-Net was a breakthrough model architecture for segmentation tasks.

  • Using segmentation followed by classification, we can implement detection with relatively modest data and computation requirements.

  • Naive approaches to 3D segmentation can quickly use too much RAM for current-generation GPUs. Carefully limiting the scope of what is presented to the model can help limit RAM usage.

  • It is possible to train a segmentation model on image crops while validating on whole-image slices. This flexibility can be important for class balancing.

  • Loss weighting is an emphasis on the loss computed from certain classes or subsets of the training data, to encourage the model to focus on the desired results. It can complement class balancing and is a useful tool when trying to tweak model training performance.

  • TensorBoard can display 2D images generated during training and will save a history of how those models changed over the training run. This can be used to visually track changes to model output as training progresses.

  • Model parameters can be saved to disk and loaded back to reconstitute a model that was saved earlier. The exact model implementation can change as long as there is a 1:1 mapping between old and new parameters.


1.We expect to mark quite a few things that are not nodules; thus, we use the classification step to reduce the number of these.

2.Joseph Redmon and Ali Farhadi, “YOLOv3: An Incremental Improvement,” https://pjreddie.com/media/files/papers/YOLOv3.pdf. Perhaps check it out once you’ve finished the book.

3.... “head, shoulders, knees, and toes, knees and toes,” as my (Eli’s) toddlers would sing.

4.The implementation included here differs from the official paper by using average pooling instead of max pooling to downsample. The most recent version on GitHub has changed to use max pool.

5.In the unlikely event our code throws any exceptions--which it clearly won’t, will it?

6.For example, Stanislav Nikolov et al., “Deep Learning to Achieve Clinically Applicable Segmentation of Head and Neck Anatomy for Radiotherapy,” https://arxiv.org/pdf/1809.04430.pdf.

7.The bug here is that the wraparound at 0 will go undetected. It does not matter much to us. As an exercise, implement proper bounds checking.

8.Fixing this issue would not do a great deal to teach you about PyTorch.

9.Samuel G. Armato 3rd et al., “The Lung Image Database Consortium (LIDC) and Image Database Resource Initiative (IDRI): A Completed Reference Database of Lung Nodules on CT Scans,” Medical Physics 38, no. 2 (2011): 915-31, https://pubmed.ncbi.nlm.nih.gov/21452728/. See also Bruce Vendt, LIDC-IDRI, Cancer Imaging Archive, http://mng.bz/mBO4.

10.If you do this a lot, the pandas library that just released 1.0 in 2020 is a great tool to make this faster. We stick with the CSV reader included in the standard Python distribution here.

11.Most CT scanners produce 512 × 512 slices, and we’re not going to worry about the ones that do something different.

12.Otherwise, your model would train instantly!

13.See http://cs231n.github.io/neural-networks-3.

14.Roxie would be proud!

15.And yes, “reasonable” is a bit of a dodge. “Nonzero” is a good starting place, if you’d like something more specific.

16.Hint: Each sample tuple to be batched together must have the same shape for each corresponding tensor, but the next batch could have different samples with different shapes.
