Matplotlib is a two-dimensional plotting library for Python. It can generate all kinds of plots, including histograms, scatter plots, line plots, dot plots, heat maps, and others. In this book, we will use the pyplot interface of matplotlib for all our visualization requirements.
In this recipe, we will introduce basic plotting mechanisms using pyplot. We will use pyplot in almost all our recipes for visualization in this book.
We used matplotlib version 1.3.1 for all the recipes in this book. From a Python interpreter session, you can check your installed version through the module's __version__ attribute:
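A minimal check, runnable in any Python session where matplotlib is installed (a current installation will most likely report a version newer than 1.3.1):

```python
# Print the version of matplotlib available in the current environment.
import matplotlib

print(matplotlib.__version__)
```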
Let's start by looking at how to plot simple graphs using matplotlib's pyplot module:
```python
#Recipe_2a.py
import numpy as np
import matplotlib.pyplot as plt

def simple_line_plot(x,y,figure_no):
    plt.figure(figure_no)
    plt.plot(x,y)
    plt.xlabel('x values')
    plt.ylabel('y values')
    plt.title('Simple Line')

def simple_dots(x,y,figure_no):
    plt.figure(figure_no)
    plt.plot(x,y,'or')
    plt.xlabel('x values')
    plt.ylabel('y values')
    plt.title('Simple Dots')

def simple_scatter(x,y,figure_no):
    plt.figure(figure_no)
    plt.scatter(x,y)
    plt.xlabel('x values')
    plt.ylabel('y values')
    plt.title('Simple scatter')

def scatter_with_color(x,y,labels,figure_no):
    plt.figure(figure_no)
    plt.scatter(x,y,c=labels)
    plt.xlabel('x values')
    plt.ylabel('y values')
    plt.title('Scatter with color')

if __name__ == "__main__":
    plt.close('all')
    # Sample x y data for line and simple dot plots
    x = np.arange(1,100,dtype=float)
    y = np.array([np.power(xx,2) for xx in x])
    figure_no=1
    simple_line_plot(x,y,figure_no)
    figure_no+=1
    simple_dots(x,y,figure_no)
    # Sample x,y data for scatter plot
    x = np.random.uniform(size=100)
    y = np.random.uniform(size=100)
    figure_no+=1
    simple_scatter(x,y,figure_no)
    figure_no+=1
    label = np.random.randint(2,size=100)
    scatter_with_color(x,y,label,figure_no)
    plt.show()
```
Now we will proceed to look at some advanced topics, including generating heat maps and labeling the x and y axes:
```python
#Recipe_2b.py
import numpy as np
import matplotlib.pyplot as plt

def x_y_axis_labeling(x,y,x_labels,y_labels,figure_no):
    plt.figure(figure_no)
    plt.plot(x,y,'+r')
    plt.margins(0.2)
    plt.xticks(x,x_labels,rotation='vertical')
    plt.yticks(y,y_labels)

def plot_heat_map(x,figure_no):
    plt.figure(figure_no)
    plt.pcolor(x)
    plt.colorbar()

if __name__ == "__main__":
    plt.close('all')
    x = np.array(range(1,6))
    y = np.array(range(100,600,100))
    x_label = ['element 1','element 2','element 3','element 4','element 5']
    y_label = ['weight1','weight2','weight3','weight4','weight5']
    x_y_axis_labeling(x,y,x_label,y_label,1)
    x = np.random.normal(loc=0.5,scale=0.2,size=(10,10))
    plot_heat_map(x,2)
    plt.show()
```
We will start by importing the required modules. While using pyplot, it's recommended that you import NumPy:
```python
import numpy as np
import matplotlib.pyplot as plt
```
Let's walk through the code, starting from the main block (the code under if __name__ == "__main__":). Figures from a previously run program may still be open, so it is good practice to close them all before creating new ones:
plt.close('all')
We will proceed by generating some data using NumPy to demonstrate plotting with pyplot:
```python
# Sample x y data for line and simple dot plots
x = np.arange(1,100,dtype=float)
y = np.array([np.power(xx,2) for xx in x])
```
We generated 99 elements in both our x and y variables: np.arange(1,100) stops before 100, so x runs from 1 to 99. Each y value is the square of the corresponding x value.
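The list comprehension squares one element at a time. NumPy can do the same in a single vectorized expression; a short sketch confirming the element count and that both approaches produce the same array:

```python
import numpy as np

# Same data as the recipe: 99 floats, 1.0 through 99.0 (the stop value 100 is excluded).
x = np.arange(1, 100, dtype=float)

# Element-by-element squaring, as written in the recipe.
y_loop = np.array([np.power(xx, 2) for xx in x])

# Vectorized alternative: NumPy squares the whole array in one operation.
y_vec = x ** 2  # equivalently np.power(x, 2) or np.square(x)

print(len(x), x[0], x[-1])
print(np.array_equal(y_loop, y_vec))
```

The vectorized form is shorter and avoids a Python-level loop, which matters as arrays grow.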
Let's proceed to a simple line plot:
```python
figure_no=1
simple_line_plot(x,y,figure_no)
```
When a program creates multiple plots, it's good practice to number each one. The figure_no variable holds these numbers. Let's look at the simple_line_plot function:
```python
def simple_line_plot(x,y,figure_no):
    plt.figure(figure_no)
    plt.plot(x,y)
    plt.xlabel('x values')
    plt.ylabel('y values')
    plt.title('Simple Line')
```
As you can see, we start by calling pyplot's figure function with the figure_no variable passed in from the main block; this creates a figure with that number, or reactivates it if it already exists. After this, we simply call the plot function with our x and y values. We make the plot meaningful by naming the x and y axes with the xlabel and ylabel functions, respectively. Finally, we give the plot a title with the title function. That is it: our first simple line plot is ready. The plot will not be displayed until the show() function is called. In our code, show() is invoked once, at the end of the main block, so that all the plots are displayed together. Our plot will look as follows:
Here, we plotted the values of x on the x axis and x squared on the y axis.
We created a simple line plot. It shows a smooth curve, since our y values are the squares of our x values.
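To see how figure numbers behave, here is a small sketch: calling plt.figure with a number that already exists reactivates that figure instead of creating a new one. The sketch also writes the plot to a file with plt.savefig rather than opening a window with show(); the Agg backend and the output file name are assumptions for a non-interactive setting, not part of the recipe.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: file output only, no GUI window
import matplotlib.pyplot as plt
import numpy as np

plt.close('all')
x = np.arange(1, 100, dtype=float)

fig_line = plt.figure(1)   # creates figure 1 and makes it current
plt.plot(x, x ** 2)
plt.title('Simple Line')

fig_other = plt.figure(2)  # creates figure 2; it is now the current figure
plt.plot(x, x)

fig_again = plt.figure(1)  # figure 1 already exists, so it is reactivated
open_figures = plt.get_fignums()

# Save the current figure (figure 1) instead of calling plt.show().
out_path = os.path.join(tempfile.gettempdir(), 'simple_line.png')  # hypothetical file name
plt.savefig(out_path)
plt.close('all')
```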
Let's move on to our next plot:
```python
figure_no+=1
simple_dots(x,y,figure_no)
```
We will increment our figure number and call the simple_dots function. We want to plot our x and y values as dots instead of a line. Let's look at the simple_dots function:
```python
def simple_dots(x,y,figure_no):
    plt.figure(figure_no)
    plt.plot(x,y,'or')
    plt.xlabel('x values')
    plt.ylabel('y values')
    plt.title('Simple Dots')
```
Every line is similar to our previous function except the following line:
plt.plot(x,y,'or')
The 'or' format string says that we want dots (o) and that the dots should be red (r). The following is the output of the preceding command:
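The 'or' string is shorthand for separate keyword arguments. A minimal sketch, assuming the non-interactive Agg backend, that draws the same data both ways and reads the marker, line style, and color back from the resulting Line2D objects:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import numpy as np

x = np.arange(1, 100, dtype=float)
y = x ** 2

# Short format string: 'o' = circle marker, 'r' = red; no line style is given,
# so no connecting line is drawn.
line_fmt, = plt.plot(x, y, 'or')

# Equivalent explicit keyword arguments.
line_kw, = plt.plot(x, y, marker='o', linestyle='None', color='r')

for line in (line_fmt, line_kw):
    print(line.get_marker(), line.get_linestyle(), to_rgba(line.get_color()))
plt.close('all')
```

The keyword form is more verbose but easier to read when many style options are combined.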
Let's move to our next plot.
We are going to see a scatter plot. Let's generate some data using NumPy:
```python
# Sample x,y data for scatter plot
x = np.random.uniform(size=100)
y = np.random.uniform(size=100)
```
We sampled 100 values each for x and y from a uniform distribution over [0, 1), which is NumPy's default range for uniform. Now we will proceed to call the simple_scatter function in order to generate our scatter plot:
```python
figure_no+=1
simple_scatter(x,y,figure_no)
```
In the simple_scatter function, all the lines are similar to the previous plotting routines except for the following line:
plt.scatter(x,y)
Instead of calling the plot function in pyplot, we invoked the scatter function. Our plot will look as follows:
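One practical difference from plot is that scatter accepts per-point properties. A sketch, assuming the Agg backend, that gives each point its own marker size through the s parameter (sizes in points squared; the size values are arbitrary, chosen only for illustration):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.random.uniform(size=100)
y = np.random.uniform(size=100)

# Arbitrary per-point sizes between 20 and 220 points squared.
sizes = 20 + 200 * np.random.uniform(size=100)

plt.figure(1)
points = plt.scatter(x, y, s=sizes)  # returns a PathCollection, one path per point
plt.title('Scatter with varying sizes')

print(type(points).__name__)        # PathCollection
print(points.get_offsets().shape)   # (100, 2): one (x, y) pair per point
plt.close('all')
```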
Let's move on to our final plot, a scatter plot in which the points are colored according to the class label they belong to:
```python
figure_no+=1
label = np.random.randint(2,size=100)
scatter_with_color(x,y,label,figure_no)
```
We increment the figure number in order to keep track of our graphs. In the next line, we assign a random label, either 0 or 1, to each point. Finally, we call the scatter_with_color function with our x, y, and label variables.
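Because x, y, and the labels are random, every run produces a different picture. If you need reproducible figures, one option (an addition, not part of the recipe) is to seed NumPy's global random generator before sampling; the seed value 42 is arbitrary:

```python
import numpy as np

# Seed NumPy's global random generator so the draws can be repeated.
np.random.seed(42)
x_first = np.random.uniform(size=100)
label_first = np.random.randint(2, size=100)

# Re-seeding with the same value replays exactly the same sequence of draws.
np.random.seed(42)
x_second = np.random.uniform(size=100)
label_second = np.random.randint(2, size=100)

print(np.array_equal(x_first, x_second))
print(np.array_equal(label_first, label_second))
# randint(2) draws from {0, 1}: the upper bound 2 is exclusive.
print(np.unique(label_first))
```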
In the scatter_with_color function, let's look at the line that differentiates this code from the previous scatter plot code:
plt.scatter(x,y,c=labels)
We pass our labels to the c parameter, which stands for color. Because the labels are numbers, scatter maps each value to a color through a colormap, so points with different labels get different colors. In our example, all the points labeled 0 get a color that is different from the points labeled 1, as follows:
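If you would rather choose the color of each class yourself than rely on the default colormap, one approach is to build a list of color names and pass that to c. A sketch, assuming the Agg backend; the class-to-color mapping is a hypothetical choice for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import numpy as np

x = np.random.uniform(size=100)
y = np.random.uniform(size=100)
labels = np.random.randint(2, size=100)

# Hypothetical mapping from class label to a named color.
class_colors = {0: 'blue', 1: 'red'}
point_colors = [class_colors[lab] for lab in labels]

plt.figure(1)
points = plt.scatter(x, y, c=point_colors)
plt.title('Scatter with explicit class colors')

face = points.get_facecolors()  # one RGBA row per point
plt.close('all')
```

With a list of color names, no colormap is involved, so the same class always gets the same color across runs and plots.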
Let's move on to plotting heat maps and labeling axes.
Once again, we will start with the main block:
```python
plt.close('all')
x = np.array(range(1,6))
y = np.array(range(100,600,100))
x_label = ['element 1','element 2','element 3','element 4','element 5']
y_label = ['weight1','weight2','weight3','weight4','weight5']
x_y_axis_labeling(x,y,x_label,y_label,1)
```
As a good practice, we close all the previous figures by calling the close function. We then generate some data. Our x is an array of five elements, from 1 to 5, and our y is an array of five elements, from 100 to 500. The x_label and y_label lists will serve as the tick labels for our plot. Finally, we invoke the x_y_axis_labeling function to demonstrate labeling the ticks on the x and y axes.
Let's look at the following function:
```python
def x_y_axis_labeling(x,y,x_labels,y_labels,figure_no):
    plt.figure(figure_no)
    plt.plot(x,y,'+r')
    plt.margins(0.2)
    plt.xticks(x,x_labels,rotation='vertical')
    plt.yticks(y,y_labels)
```
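We make a simple dot plot by calling pyplot's plot function. However, in this case, we want our points to be displayed as + instead of o. Hence, we specify the +r format string, where r selects our color of choice, red. The call to margins(0.2) pads each axis by 20 percent of the data range on each side, so that the outermost points are not drawn on the edges of the plot.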
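To see the effect of margins concretely, the sketch below reads back the axis limits after the call. It assumes matplotlib 2.0 or later with the default axes.autolimit_mode of 'data'; older releases such as 1.3.1 may round the limits differently.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.array(range(1, 6))           # 1..5, data span 4
y = np.array(range(100, 600, 100))  # 100..500, data span 400

plt.figure(1)
plt.plot(x, y, '+r')
plt.margins(0.2)  # pad each side by 20% of the data span

ax = plt.gca()
xlim = ax.get_xlim()  # about (0.2, 5.8): 1 - 0.8 and 5 + 0.8
ylim = ax.get_ylim()  # about (20, 580): 100 - 80 and 500 + 80
print(xlim, ylim)
plt.close('all')
```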
In the last two lines of x_y_axis_labeling, we specify the ticks for the x axis and the y axis. By calling the xticks function, we pass our x values as tick positions along with their labels. In addition, we ask for the label text to be rotated vertically so that the labels don't overlap each other. Similarly, we specify the ticks for the y axis with yticks. Let's look at our plot, as follows:
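The same labeling can be written with matplotlib's object-oriented Axes methods instead of pyplot's state-based functions. A sketch, assuming the Agg backend, that mirrors x_y_axis_labeling and then reads back the tick label text after one render:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.array(range(1, 6))
y = np.array(range(100, 600, 100))
x_labels = ['element 1', 'element 2', 'element 3', 'element 4', 'element 5']
y_labels = ['weight1', 'weight2', 'weight3', 'weight4', 'weight5']

fig, ax = plt.subplots()
ax.plot(x, y, '+r')
ax.margins(0.2)
ax.set_xticks(x)                                   # tick positions
ax.set_xticklabels(x_labels, rotation='vertical')  # text shown at those positions
ax.set_yticks(y)
ax.set_yticklabels(y_labels)

fig.canvas.draw()  # render once so the tick label text is populated
shown_x = [t.get_text() for t in ax.get_xticklabels()]
shown_y = [t.get_text() for t in ax.get_yticklabels()]
plt.close(fig)
```

Working with an explicit ax object avoids depending on which figure is current, which helps once a program juggles several figures.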
Let's see how to generate heat maps using pyplot:
```python
x = np.random.normal(loc=0.5,scale=0.2,size=(10,10))
plot_heat_map(x,2)
```
We first generate some data for our heat map: a 10 x 10 matrix filled with values drawn from a normal distribution whose mean is set by the loc parameter (0.5) and whose standard deviation is set by the scale parameter (0.2). We then invoke the plot_heat_map function with this matrix; the second parameter is the figure number:
```python
def plot_heat_map(x,figure_no):
    plt.figure(figure_no)
    plt.pcolor(x)
    plt.colorbar()
```
We call the pcolor function in order to generate a heat map. The next line invokes the colorbar function to display the color scale for our range of values:
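For larger matrices, the matplotlib documentation describes pcolormesh as a much faster alternative to pcolor that accepts the same basic call. A sketch of the same heat map built with it, assuming the Agg backend (the data are regenerated here, so the values differ from the recipe's):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.random.normal(loc=0.5, scale=0.2, size=(10, 10))

plt.figure(2)
mesh = plt.pcolormesh(x)   # same call pattern as pcolor(x); returns a QuadMesh
cbar = plt.colorbar(mesh)  # color scale spanning the data range
plt.title('Heat map with pcolormesh')

n_cells = mesh.get_array().size  # one value per cell: 10 * 10
plt.close('all')
```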
For more information on matplotlib, you can refer to the general matplotlib documentation at http://matplotlib.org/faq/usage_faq.html.
The following link is an excellent tutorial on pyplot:
http://matplotlib.org/users/pyplot_tutorial.html
Matplotlib provides excellent three-dimensional plotting capabilities. Refer to the following link for more information:
http://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html
The pylab module in matplotlib combines the namespace of NumPy with that of pyplot. Pylab can also be used to generate the various types of plots shown in this recipe.