Matplotlib

Matplotlib

Matplotlib

Matplotlib is a Python 2D plotting module which produces publication quality figures in a variety of formats (jpg, png, etc). In this tutorial, you will learn the basics of how to use the Matplotlib module.

Plotting your first graph

Let's import Matplotlib library. When running Python using the command line, the graphs are typically shown in a separate window. In a Jupyter Notebook, you can simply output the graphs within the notebook itself by running the %matplotlib inline magic command.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

You can change the format to svg for better quality figures. You can also try the retina format and see which one looks better on your computer's screen.

In [2]:
%config InlineBackend.figure_format = 'retina'

You can also change the default style of plots. Let's go for our favourite style, ggplot.

In [3]:
plt.style.use("ggplot")

Now let's plot our first graph. We use the plot function to create the plot and we use the show function to display the plot. We place a semi-colon at the end of the show function to suppress the actual output of this function, which is not very useful as it looks something like matplotlib.lines.Line2D at 0x11a05e2e8.

In [4]:
plt.plot([1, 2, 4, 7, 5, 4])
plt.show(); 

So, it's as simple as calling the plot function with some data, and then calling the show function. If the plot function is given one array of data, it will use it as the coordinates on the vertical axis, and it will just use each data point's index in the array as the horizontal coordinate.

You can also provide two arrays: one for the horizontal axis x, and the second for the vertical axis y.

In [5]:
plt.plot([-3, -2, 0, 5], [1, 3, 2, 10])
plt.show();

The axes automatically match the extent of the data. We would like to give the graph a bit more room, so let's call the xlim and ylimfunctions to change the extent of each axis. Here, you can also specify a value of "None" for the default limit.

In [6]:
plt.plot([-3, -2, 0, 5], [1, 3, 2, 10])
plt.xlim(-5, 7)
plt.ylim(None, 12)
plt.show();

Now, let's plot a mathematical function. We use NumPy's linspace function to create an array x containing 500 floats ranging from -2 to 2, then we create a second array y computed as the square of x. While at it, we change the color to blue from the default style color red.

In [7]:
x = np.linspace(-2, 2, 500)
y = x**2
plt.plot(x, y, color='blue')
plt.show();

That's a bit dry, so let's add a title, and x and y labels, and also draw a grid.

In this particular case, since we are using the ggplot style, we get the grid for free, but in the default case, you can use the grid function for displaying a grid.

In [8]:
plt.plot(x, y, color='blue')
plt.title("Square function")
plt.xlabel("x")
plt.ylabel("y = x**2")
plt.grid(True)
plt.show();

Line style and color

By default, matplotlib draws a line between consecutive points. You can pass a 3rd argument to change the line's style and color. For example "b--" means "blue dashed line".

In [9]:
plt.plot(x, y, 'b--')
plt.title("Square function")
plt.xlabel("x")
plt.ylabel("y = x**2")
plt.show();

You can easily plot multiple lines on one graph. You simply call plot multiple times before calling show.

You can also draw simple points instead of lines. Here's an example with green dashes, red dotted line and blue triangles.

Check out the documentation for the full list of style & color options.

In [10]:
x = np.linspace(-1.4, 1.4, 30)
plt.plot(x, x, 'g--')
plt.plot(x, x**2, 'r:')
plt.plot(x, x**3, 'b^')
plt.show();

For each plot line, you can set extra attributes, such as the line width, the dash style, or the alpha level. See the full list of attributes in the documentation. You can also overwrite the current style's grid options using the grid function.

Subplots

A Matplotlib figure may contain multiple subplots. These subplots are organized in a grid. To create a subplot, just call the subplot function, specify the number of rows and columns in the figure, and the index of the subplot you want to draw on (starting from 1, then left to right, and top to bottom).

In [11]:
x = np.linspace(-1.4, 1.4, 30)
plt.subplot(2, 2, 1)  # 2 rows, 2 columns, 1st subplot = top left
plt.plot(x, x)
plt.subplot(2, 2, 2)  # 2 rows, 2 columns, 2nd subplot = top right
plt.plot(x, x**2)
plt.subplot(2, 2, 3)  # 2 rows, 2 columns, 3rd subplot = bottow left
plt.plot(x, x**3)
plt.subplot(2, 2, 4)  # 2 rows, 2 columns, 4th subplot = bottom right
plt.plot(x, x**4)
plt.show();

It is easy to create subplots that span across multiple grid cells.

In [12]:
plt.subplot(2, 2, 1)  # 2 rows, 2 columns, 1st subplot = top left
plt.plot(x, x)
plt.subplot(2, 2, 2)  # 2 rows, 2 columns, 2nd subplot = top right
plt.plot(x, x**2)
plt.subplot(2, 1, 2)  # 2 rows, *1* column, 2nd subplot = bottom
plt.plot(x, x**3)
plt.show();

If you need even more flexibility in subplot positioning, check out the GridSpec documentation.

Text and annotations

You can call text to add text at any location in the graph. Just specify the horizontal and vertical coordinates and the text, and optionally some extra attributes. Any text in matplotlib may contain TeX equation expressions, see the documentation for more details. Below, ha is an alias for horizontalalignment.

In [13]:
x = np.linspace(-1.5, 1.5, 30)
px = 0.8
py = px**2

plt.plot(x, x**2, "b-", px, py, "ro")

plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=15, color='blue', horizontalalignment="center")
plt.text(px, py, "x = %0.2f\ny = %0.2f"%(px, py), rotation=50, color='gray')

plt.show();

For more text properties, visit the documentation.

Labels and legends

The simplest way to add a legend is to set a label on all lines, then just call the legend function.

In [14]:
x = np.linspace(-1.4, 1.4, 50)
plt.plot(x, x**2, "r--", label="Square function")
plt.plot(x, x**3, "b-", label="Cube function")
plt.legend(loc="lower right")
plt.show();

Lines

You can draw lines simply using the plot function. However, it is often convenient to create a utility function that plots a (seemingly) infinite line across the graph, given a slope and an intercept. You can also use the hlines and vlines functions that plot horizontal and vertical line segments.

In [15]:
def plot_line(axis, slope, intercept, **kargs):
    xmin, xmax = axis.get_xlim()
    plt.plot([xmin, xmax], [xmin*slope+intercept, xmax*slope+intercept], **kargs)

x = np.random.randn(1000)
y = 0.5*x + 5 + np.random.randn(1000)
plt.axis([-2.5, 2.5, -5, 15])
plt.scatter(x, y, alpha=0.2)
plt.plot(1, 0, "ro", color='black')
plt.vlines(1, -5, 0, color="green", linewidth=0.75)
plt.hlines(0, -2.5, 1, color="green", linewidth=0.75)
plot_line(axis=plt.gca(), slope=0.5, intercept=5, color="blue")
plt.grid(True)
plt.show();

Histograms

You can plot histograms using the hist function.

In [16]:
data = [1, 1.1, 1.8, 2, 2.1, 3.2, 3, 3, 3, 3]
plt.subplot(2,1,1)
plt.hist(data, bins = 10, rwidth=0.8)

plt.subplot(2,1,2)
plt.hist(data, bins = [1, 1.5, 2, 2.5, 3], rwidth=0.95)
plt.xlabel("Value")
plt.ylabel("Frequency")

plt.show();
In [17]:
data1 = np.random.randn(100)
data2 = np.random.randn(100) + 3
data3 = np.random.randn(100) + 6

plt.hist(data1, bins=5, color='g', alpha=0.75, label='bar hist') # default histtype='bar'
plt.hist(data2, color='b', alpha=0.65, histtype='stepfilled', label='stepfilled hist')
plt.hist(data3, color='r', label='bar hist')

plt.legend()
plt.show();

Scatterplots

To draw a scatterplot, simply provide the x and y coordinates of the points and call the scatter function.

In [18]:
x, y = np.random.rand(2, 10)
plt.scatter(x, y)
plt.show();

You may also optionally specify the scale of each point.

In [19]:
scale = np.random.rand(10)
scale = 500 * scale ** 2
plt.scatter(x, y, s=scale)
plt.show();

As usual, there are a number of other attributes you can set, such as the fill and edge colors and the alpha level.

In [20]:
for color in ['red', 'green', 'blue']:
    n = 10
    x, y = np.random.rand(2, n)
    scale = 500.0 * np.random.rand(n) ** 2
    plt.scatter(x, y, s=scale, c=color, alpha=0.3, edgecolors='blue')

plt.show();

Boxplots

Boxplots can be displayed using the boxplot function.

In [21]:
data1 = np.random.rand(10)*2 + 5
plt.boxplot(x=data1)
plt.title("Boxplot")
plt.show();

Saving a figure

Saving a figure to disk is as simple as calling savefig with the name of the file (or a file object). The available image formats depend on the graphics backend you use.

In [22]:
x = np.linspace(-1.4, 1.4, 30)
plt.plot(x, x**2)
plt.savefig("my_square_function.png", transparent=True);