Heatmap in Seaborn

A heatmap is an extremely compact way to display a large amount of data. In the finance world, color-coded blocks can give investors a quick glance at which stocks are up or down. In the scientific world, heatmaps allow researchers to visualize the expression level of thousands of genes.

The seaborn.heatmap() function expects a 2D list, 2D NumPy array, or pandas DataFrame as input. If a list or array is supplied, we can supply column and row labels via xticklabels and yticklabels, respectively. If a DataFrame is supplied, its column labels and index values are used to label the columns and rows, respectively.
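The two input styles can be contrasted in a minimal sketch. The 2×3 values and the labels "A", "B", "C", "row1", "row2" are made up for illustration only. With a plain NumPy array the labels are passed explicitly; with a DataFrame, Seaborn reads them from the columns and index.

```python
import matplotlib
matplotlib.use("Agg")  # draw off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up 2x3 values, purely for illustration
values = np.array([[0.1, -0.2, 0.3],
                   [-0.1, 0.2, 0.0]])

# 1) NumPy array: supply the labels explicitly
fig1, ax1 = plt.subplots()
sns.heatmap(values, ax=ax1,
            xticklabels=["A", "B", "C"],   # one label per column
            yticklabels=["row1", "row2"])  # one label per row

# 2) DataFrame: column labels and index values become the tick labels
df = pd.DataFrame(values, columns=["A", "B", "C"], index=["row1", "row2"])
fig2, ax2 = plt.subplots()
sns.heatmap(df, ax=ax2)

print([t.get_text() for t in ax2.get_xticklabels()])
print([t.get_text() for t in ax2.get_yticklabels()])
```

Both heatmaps end up with the same tick labels; the DataFrame version simply carries them along with the data.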

To get started, we will plot an overview of the performance of the six stocks using a heatmap. We define stock performance as the change in closing price compared with the previous close. This was already calculated earlier in this chapter (the Close_change column). Unfortunately, we can't pass the whole DataFrame to seaborn.heatmap() directly, because it expects company names as columns, dates as the index, and the change in closing price as values.
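The code that produced Close_change is not reproduced in this excerpt. Assuming a long-format frame with one row per date and company and a Close column (the frame name prices and the prices themselves are hypothetical), one way to compute the fractional change versus the previous close is groupby() followed by pct_change(), so each company is compared only with its own previous day:

```python
import pandas as pd

# Hypothetical long-format data: one row per (Date, Company), made-up prices
prices = pd.DataFrame({
    "Date": pd.to_datetime(["2017-06-01", "2017-06-02", "2017-06-05"] * 2),
    "Company": ["AAPL"] * 3 + ["IBM"] * 3,
    "Close": [100.0, 101.0, 99.99, 150.0, 147.0, 147.0],
})

# Sort so that "previous close" means the previous trading day of the same company
prices = prices.sort_values(["Company", "Date"])

# Relative change vs. previous close, per company; the first row of each
# company has no previous close, so its change is NaN
prices["Close_change"] = prices.groupby("Company")["Close"].pct_change()
print(prices)
```

Grouping matters here: without it, the first IBM row would be compared with the last AAPL row.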

If you are familiar with Microsoft Excel, you may have used pivot tables, a powerful technique for summarizing the levels or values of a particular variable. pandas includes such functionality. The following code excerpt uses the pandas.DataFrame.pivot() function to make a pivot table:

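# Reshape: one row per date, one column per company
stock_change = stock_df.pivot(index='Date', columns='Company', values='Close_change')
# Keep June 2017 only (label-based slice on the sorted Date index)
stock_change = stock_change.loc["2017-06-01":"2017-06-30"]
stock_change.head()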
Company          AAPL        IBM        JNJ       MSFT         PG        XOM
Date
2017-06-01   0.002749   0.000262   0.004133   0.003723   0.000454   0.002484
2017-06-02   0.014819  -0.004061   0.010095   0.023680   0.005220  -0.014870
2017-06-05  -0.009778   0.002368   0.002153   0.007246   0.001693   0.007799
2017-06-06   0.003378  -0.000262   0.003605   0.003320   0.000676   0.013605
2017-06-07   0.005957  -0.009123  -0.000611  -0.001793  -0.000338  -0.003694
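The reshaping step can be seen in isolation on a tiny frame. The sketch below uses made-up Close_change values (the names long_df, wide, june and dup are ours) and also shows one caveat worth knowing: pivot() refuses duplicate (Date, Company) pairs, whereas pivot_table() aggregates them (here with the mean).

```python
import pandas as pd

# Hypothetical long-format data with made-up Close_change values
long_df = pd.DataFrame({
    "Date": pd.to_datetime(["2017-05-31", "2017-06-01", "2017-06-02"] * 2),
    "Company": ["AAPL"] * 3 + ["IBM"] * 3,
    "Close_change": [0.001, 0.002, 0.015, -0.003, 0.0003, -0.004],
})

# One row per Date, one column per Company
wide = long_df.pivot(index="Date", columns="Company", values="Close_change")

# String slicing on the DatetimeIndex; both endpoints are inclusive
june = wide.loc["2017-06-01":"2017-06-30"]
print(june)

# A duplicated (Date, Company) pair makes pivot() raise ValueError ...
dup = pd.concat([long_df, long_df.iloc[[0]]])
try:
    dup.pivot(index="Date", columns="Company", values="Close_change")
    raised = False
except ValueError:
    raised = True

# ... while pivot_table() resolves duplicates with an aggregation function
agg = dup.pivot_table(index="Date", columns="Company",
                      values="Close_change", aggfunc="mean")
```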

With the pivot table ready, we can proceed to plot our first heatmap:

import matplotlib.pyplot as plt
import seaborn as sns
ax = sns.heatmap(stock_change)
plt.show()

The default heatmap is not very compact. We can resize the figure via plt.figure(figsize=(width, height)), and we can set the square parameter to True to create square-shaped blocks. To ease visual recognition, we can also add a thin border around the blocks via the linewidths parameter.
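A minimal sketch of these three layout controls together, on a made-up 5×6 grid of random values standing in for the pivoted stock changes (the 6×4 inch size is an arbitrary choice):

```python
import matplotlib
matplotlib.use("Agg")  # draw off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Made-up 5x6 grid standing in for the pivoted stock changes
rng = np.random.default_rng(0)
data = rng.normal(0, 0.01, size=(5, 6))

fig = plt.figure(figsize=(6, 4))  # width, height in inches
ax = sns.heatmap(data,
                 square=True,     # equal aspect ratio, so every cell is square
                 linewidths=.5)   # thin border between neighboring cells
```

Because sns.heatmap() draws on the current axes when no ax is given, the figure created by plt.figure() is the one that gets resized.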

By US stock market convention, green denotes a rise and red denotes a fall in prices, so we can adjust the cmap parameter to change the color map accordingly. However, neither Matplotlib nor Seaborn includes a plain red-green color map (Matplotlib's closest built-in, RdYlGn, passes through yellow), so we need to create our own.
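As a point of comparison, and not the approach this chapter takes, a red-white-green map can also be built with plain Matplotlib via matplotlib.colors.LinearSegmentedColormap.from_list(), which interpolates linearly between a list of anchor colors. The anchors and the name "red_green" below are our own illustrative choices.

```python
from matplotlib.colors import LinearSegmentedColormap

# Red at the negative end, white in the middle, green at the positive end
red_green = LinearSegmentedColormap.from_list("red_green", ["red", "white", "green"])

# A colormap maps a value in [0, 1] to an RGBA tuple
low = red_green(0.0)    # color for the most negative value
mid = red_green(0.5)    # color near the center
high = red_green(1.0)   # color for the most positive value
print(low, mid, high)
```

The chapter's own approach, using a Seaborn helper that works in the husl color system, is described next.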

At the end of Chapter 7, Visualizing Online Data, we briefly introduced functions for creating custom color maps. Here we will use seaborn.diverging_palette() to create the red-green color map, which requires us to specify the hue, saturation, and lightness in the husl color system for the negative and positive extents of the color map. You can also use the following code to launch an interactive widget in Jupyter Notebook to help select the colors:

%matplotlib notebook
import seaborn as sns

sns.choose_diverging_palette(as_cmap=True) 

import matplotlib.pyplot as plt

# Create a new red-green color map using the husl color system.
# h_neg and h_pos determine the hues at the negative and positive extents (0-359).
# s determines the color saturation (0-100).
# l determines the lightness (0-100).
# sep determines the size of the neutral band in the middle of the map.
# as_cmap=True returns a continuous Matplotlib colormap object instead of
# a short list of discrete colors.
rdgn = sns.diverging_palette(h_neg=10, h_pos=140, s=80, l=50,
                             sep=10, as_cmap=True)

# Change to square blocks (square=True), add a thin
# border (linewidths=.5), and change the color map
# to follow the US stock market convention (cmap=rdgn).
ax = sns.heatmap(stock_change, cmap=rdgn,
                 linewidths=.5, square=True)

# Prevent the x-axis labels from being cropped
plt.tight_layout()
plt.show()
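To see what the rdgn object actually is, the sketch below rebuilds it with the same parameters and samples it at both extremes and the middle. With as_cmap=True, diverging_palette() returns a Matplotlib colormap that maps a number in [0, 1] to an RGBA tuple. Comparing the red and green channels is our own quick sanity check, not something the chapter does.

```python
import matplotlib
matplotlib.use("Agg")  # no display needed; we only sample colors
import seaborn as sns

# Same parameters as the chapter's red-green map
rdgn = sns.diverging_palette(h_neg=10, h_pos=140, s=80, l=50,
                             sep=10, as_cmap=True)

neg_end = rdgn(0.0)   # RGBA used for the most negative value
centre = rdgn(0.5)    # RGBA used for the midpoint (light, because center="light" by default)
pos_end = rdgn(1.0)   # RGBA used for the most positive value
print(neg_end, centre, pos_end)
```

By default, sns.heatmap() stretches the color map between the minimum and maximum of the data, so zero is not guaranteed to fall on the light middle color. Passing center=0 to sns.heatmap() keeps zero at the center of a diverging map; the chapter's code does not set it.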

It can be hard to discern small differences in values when color is the only distinguishing factor. Adding text annotations to each color block helps readers understand the magnitude of the differences:

fig = plt.figure(figsize=(6, 8))

# Set annot=True to overlay the values.
# We can also assign a Python format specification to fmt.
# For example, ".2%" formats values as percentages with
# two decimal places.
ax = sns.heatmap(stock_change, cmap=rdgn,
                 annot=True, fmt=".2%",
                 linewidths=.5, cbar=False)
plt.show()
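The fmt string is applied to each cell value as a Python format specification, so ".2%" multiplies by 100 and keeps two decimal places. A small check, using the AAPL and IBM values from the 2017-06-02 row of the pivot table shown earlier (the one-row frame cells is ours):

```python
import matplotlib
matplotlib.use("Agg")  # draw off-screen so the sketch runs without a display
import pandas as pd
import seaborn as sns

# The format spec on its own
print(format(0.014819, ".2%"))  # 1.48%

# The same spec applied by Seaborn to every annotated cell
cells = pd.DataFrame([[0.014819, -0.004061]], columns=["AAPL", "IBM"])
ax = sns.heatmap(cells, annot=True, fmt=".2%", cbar=False)

# Each annotation is a matplotlib Text artist attached to the axes
labels = sorted(t.get_text() for t in ax.texts)
print(labels)
```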