绘图

导入需要的模块

1
2
3
4
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats

画基本图形

折线图 plot

1
2
3
4
5
6
y=np.random.randn(100) 
plt.plot(y,'b-')
plt.xlabel('x')
plt.ylabel('y')
plt.title(u'title')
plt.show()

可选参数如下所示:

也可以通过更改参数来改变画图效果

1
2
3
x=np.cumsum(np.random.rand(100)) 
plt.plot(y,label='line label',color='r',linestyle='-',marker='o')
plt.show()

可选的参数有

散点图 scatter

例如:数据服从正态分布,相关系数是0.5

1
2
3
4
5
6
z=np.random.randn(100,2) 
z[:,1]=0.5*z[:,0]+np.sqrt(0.5)*z[:,1]
x=z[:,0]
y=z[:,1]
plt.scatter(x,y)
plt.show()


参数也是可以修改的例如:

1
2
3
4
5
6
z=np.random.randn(100,2) 
z[:,1]=0.5*z[:,0]+np.sqrt(0.5)*z[:,1]
x=z[:,0]
y=z[:,1]
plt.scatter(x,y,marker='s',c='r')
plt.show()

条形图 bar

画条形图 bar,需要两个一位数组,第一个是横坐标,每个条形图的开始位置;纵坐标是条形图的高度

1
2
3
4
y=np.random.rand(5) 
x=np.arange(5)
plt.bar(x,y)
plt.show()

修改他的显示属性,可以使用一个颜色数组来指定每个条形图的颜色。

y=np.random.rand(5);
x=np.arange(5);
colors=[‘#FF0000’,’#FFFF00’,’#00FF00’,’#00FFFF’,’#0000FF’]
plt.bar(x,y,width=0.5,color=colors,edgecolor=’#000000’,linewidth=5)
plt.show()

图表 pie

使用一个一维数组来表示,不要求累加和是1,可以使人以大小的正数

1
2
x=np.arange(1,8) labels=['label1','label2','label3','label4','label5','label6','label7'] plt.pie(x,labels=labels) 
plt.show()

直方图 hist

需要一个数组,bins参数表示将数据分成几组,默认是10组

1
2
3
x=np.random.randn(2000)
plt.hist(x,bins=30)
plt.show()

如果想要生成累计直方图需要使参数cumulative为true

1
2
3
x=np.random.randn(1000);  
plt.hist(x,bins=20,cumulative=True);
plt.show()

多图表

在同一个图上画出多张图表,需要首先使用figure()函数生成一个画板,画子图时需要使用sp=add_subplot(m,n,p)来表示子图。m表示行,n表示列,p表示第几个图。

返回的是子图的句柄用于设置一些参数。最后要想显示出来需要使用draw()函数,将这些子图画在画板上,然后用show()函数显示出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
fig = plt.figure() 
ax = fig.add_subplot(2, 2, 1)
y = np.random.randn(100)
plt.plot(y); ax.set_title('1')
y = np.random.rand(5)
x = np.arange(5)
ax = fig.add_subplot(2, 2, 2) plt.bar(x, y)
ax.set_title('2')
y = np.random.rand(5)
y = y / np.sum(y) y[y < .05] = .05
ax = fig.add_subplot(2, 2, 3)
plt.pie(y) ax.set_title('3')
plt.draw()
plt.show()

3D 曲面图

画线,使用plot,需要Axed3D(fig)来画出3D轴线,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

fig=plt.figure()
ax = Axes3D(fig)

x = np.arange(-10,10,0.1)
y = np.arange(-10,10,0.1)

#网格化数据
X, Y = np.meshgrid(x, y)
Z = np.sqrt(X**2 + Y**2)

ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='rainbow')
plt.show()

3D 曲线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import copy
from mpl_toolkits.mplot3d import Axes3D
x=np.linspace(0,6*np.pi,600);
z=copy.copy(x)
x=np.cos(z)
y=np.sin(z);
fig=plt.figure()
ax = Axes3D(fig)
ax.plot(x,y,zs=z)
plt.xlabel('x')
plt.ylabel('y')
ax.view_init(15,45)
plt.draw()
plt.show()

图像配置

字体

1
2
from matplotlib.font_manager import FontProperties  
font_song = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=15)

使用文字时指定参数 fontproperties=font_song 即可

或者

1
2
ax2.set_xlabel('window size', fontsize=9, fontproperties = 'Times New Roman')
# plt.xlabel('window size', fontdict={'family' : 'Times New Roman', 'size':8})

或者

将全局字体改为Times New Roman:

1
2
import matplotlib.pyplot as plt
plt.rc('font',family='Times New Roman')

如果出现类似如下错误:

1
apps/rhel6/Python-2.7.2/lib/python2.7/site-packages/matplotlib/font_manager.py:1224: UserWarning: findfont: Font family ['Playfair Display'] not found. Falling back to Bitstream Vera Sans(prop.get_family(), self.defaultFamily[fontext]))

则需要删除 fontList.cache 文件。这个文件有点不好找。

用如下命令获得目录:

1
2
import matplotlib as plt
plt.get_cachedir()

然后进去删除fontList.cache就可以了!

参考 matplotlib 字体改为 Times New Roman

窗口

窗口设置

开启一个窗口,num 设置子图数量,figsize 设置窗口大小,dpi 设置分辨率

1
fig = plt.figure(num=1, figsize=(15, 8),dpi=80)

多张子图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
fig=plt.figure()

ax1=fig.add_subplot(221)
ax1.plot(x,x)

ax2=fig.add_subplot(222)
ax2.plot(x,-x)

ax3=fig.add_subplot(223)
ax3.plot(x,x**2)

ax4=fig.add_subplot(224)
ax4.plot(x,np.log(x))

plt.show()

坐标轴

关闭刻度

  • 对于 plt
1
2
plt.xticks([])
plt.yticks([])
  • 对于 ax(matplotlib.axes._subplots.AxesSubplot)
1
2
ax.set_xticks([])
ax.set_yticks([])
1
2
ax.spines['top'].set_visible(False)  # 去掉上边框
ax.spines['right'].set_visible(False) # 去掉右边框
  • 全部坐标轴(上下左右)不显示
1
plt.axis('off')

坐标轴名称

1
2
plt.xlabel('Window Size',fontsize=14)
plt.ylabel('SNR',fontsize=14)

或者

1
2
3
4
ax1 = fig.add_subplot(111)
ax1.plot(x, snr, label="SNR")
ax1.set_ylabel('SNR')
ax1.set_xlabel('Window Size')

坐标轴范围

1
2
3
#设置坐标轴范围
plt.xlim((-5, 5))
plt.ylim((-2, 2))

图表标题

1
plt.title('Squares',fontsize=24)

刻度字号

1
plt.tick_params(axis='both',which='major',labelsize=14)

图例

同一图表中两条线

1
plt.legend(handles=[l1,l2],labels=['up','down'],loc='best')

双 y 轴

1
2
3
4
5
6
7
8
9
10
11
12
13
14
fig = plt.figure()
plt.xlabel('Window Size',fontsize=14)

ax1 = fig.add_subplot(111)
ax1.plot(x, snr, label="SNR")
ax1.set_ylabel('SNR')
plt.legend() # 添加图例

ax2 = ax1.twinx() # this is the important function
ax2.plot(x, lsd, 'r',linestyle='--', label="LSD")
ax2.set_ylabel('LSD')
plt.legend()

plt.show()

其中,ax = twinx() 意思是创建了一个独立的Y轴,共享了X轴。双坐标轴

这时两个图例是分开的,如果想要图例合并,则应为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
fig = plt.figure()
plt.xlabel('Window Size',fontsize=14)

ax1 = fig.add_subplot(111)
l1 = ax1.plot(x, snr, label="SNR")
ax1.set_ylabel('SNR')

ax2 = ax1.twinx() # this is the important function
l2 = ax2.plot(x, lsd, 'r',linestyle='--', label="LSD")
ax2.set_ylabel('LSD')

lns = l1+l2
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc=0)

plt.show()

保存图像

使用savefig(’filename._ext_’) ,其中ext支持png, pdf, ps, eps or svg格式。

1
plt.savefig('F:/where-you-want-to-save.png', dpi=300, bbox_inches="tight")
  • 保存文件,dpi指定保存文件的分辨率
  • bbox_inches="tight" 可以保存图上所有的信息,不会出现横纵坐标轴的描述存掉了的情况

论文图实例

实例 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# coding=utf-8

import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['Arial'] # 如果要显示中文字体,则在此处设为:SimHei
plt.rcParams['axes.unicode_minus'] = False # 显示负号

x = np.array([1, 2, 3, 4, 5, 6])
VGG_supervised = np.array([2.9749694, 3.9357018, 4.7440844, 6.482254, 8.720203, 13.687582])
VGG_unsupervised = np.array([2.1044724, 2.9757383, 3.7754183, 5.686206, 8.367847, 14.144531])
ourNetwork = np.array([2.0205495, 2.6509762, 3.1876223, 4.380781, 6.004548, 9.9298])

# label在图示(legend)中显示。若为数学公式,则最好在字符串前后添加"$"符号
# color:b:blue、g:green、r:red、c:cyan、m:magenta、y:yellow、k:black、w:white、、、
# 线型:- -- -. : ,
# marker:. , o v < * + 1
plt.figure(figsize=(10, 5),dpi=600)
plt.grid(linestyle="--") # 设置背景网格线为虚线
ax = plt.gca()
ax.spines['top'].set_visible(False) # 去掉上边框
ax.spines['right'].set_visible(False) # 去掉右边框


plt.plot(x, VGG_supervised, marker='o', color="blue", label="VGG-style Supervised Network", linewidth=1.5)
plt.plot(x, VGG_unsupervised, marker='o', color="green", label="VGG-style Unsupervised Network", linewidth=1.5)
plt.plot(x, ourNetwork, marker='o', color="red", label="ShuffleNet-style Network", linewidth=1.5)

group_labels = ['Top 0-5%', 'Top 5-10%', 'Top 10-20%', 'Top 20-50%', 'Top 50-70%', ' Top 70-100%'] # x轴刻度的标识
plt.xticks(x, group_labels, fontsize=12, fontweight='bold') # 默认字体大小为10
plt.yticks(fontsize=12, fontweight='bold')
# plt.title("example", fontsize=12, fontweight='bold') # 默认字体大小为12
plt.xlabel("Performance Percentile", fontsize=13, fontweight='bold')
plt.ylabel("4pt-Homography RMSE", fontsize=13, fontweight='bold')
plt.xlim(0.9, 6.1) # 设置x轴的范围
plt.ylim(1.5, 16)

# plt.legend() #显示各曲线的图例
plt.legend(loc=0, numpoints=1)
leg = plt.gca().get_legend()
ltext = leg.get_texts()
plt.setp(ltext, fontsize=12, fontweight='bold') # 设置图例字体的大小和粗细

plt.savefig('./filename.svg', format='svg') # 建议保存为svg格式,再用inkscape转为矢量图emf后插入word中
plt.show()

参考 Matplotlib画各种论文图

实例 2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np

figure(num=None, figsize=(2.8, 1.7), dpi=300)
# figsize的2.8和1.7指的是英寸,dpi指定图片分辨率。那么图片就是(2.8*300)*(1.7*300)像素大小

plt.plot(test_mean_1000S_n, 'royalblue', label='without threshold')
plt.plot(test_mean_1000S, 'darkorange', label='with threshold')
# 画图,并指定颜色

plt.xticks(fontproperties = 'Times New Roman', fontsize=8)
plt.yticks(np.arange(0, 1.1, 0.2), fontproperties = 'Times New Roman', fontsize=8)
# 指定横纵坐标的字体以及字体大小,记住是fontsize不是size。yticks上我还用numpy指定了坐标轴的变化范围。

plt.legend(loc='lower right', prop={'family':'Times New Roman', 'size':8})
# 图上的legend,记住字体是要用prop以字典形式设置的,而且字的大小是size不是fontsize,这个容易和xticks的命令弄混

plt.title('1000 samples', fontdict={'family' : 'Times New Roman', 'size':8})
# 指定图上标题的字体及大小

plt.xlabel('iterations', fontdict={'family' : 'Times New Roman', 'size':8})
plt.ylabel('accuracy', fontdict={'family' : 'Times New Roman', 'size':8})
# 指定横纵坐标描述的字体及大小

plt.savefig('F:/where-you-want-to-save.png', dpi=300, bbox_inches="tight")
# 保存文件,dpi指定保存文件的分辨率
# bbox_inches="tight" 可以保存图上所有的信息,不会出现横纵坐标轴的描述存掉了的情况

plt.show()
# 记住,如果你要show()的话,一定要先savefig,再show。如果你先show了,存出来的就是一张白纸。

参考 期刊论文写作之【python matplotlib 画图设置】

实例 3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import matplotlib.pyplot as plt
figure, ax = plt.subplots()
item = ['A','B','C']
num1 = [2.5, 2.6, 2.7]
num2 = [2.75, 2.85, 2.95]
x=[1,2,3]

plt.plot(x,num1,label='a',linestyle='--',color='r',marker='D')
plt.plot(x,num2,label='b',linestyle='--',color='b',marker='o')
plt.yticks([2.4,2.6,2.8,3.0]) #设置x,y坐标值
plt.xticks(x)

plt.tick_params(labelsize=16)
labels = ax.get_xticklabels() + ax.get_yticklabels()
[label.set_fontname('Times New Roman') for label in labels]
font1 = {'family' : 'Times New Roman',
'weight' : 'normal',
'size' : 16,
}
plt.legend(prop=font1,loc=4)

plt.grid(axis="y")
plt.xlabel('Item',font1)
plt.ylabel('Value',font1)
plt.title('Line Chart',font1)
plt.show()

CNN 结构图

如何用 matplotlib 画论文中的CNN结构图

图片处理

导入模块

1
2
from PIL import Image
import matplotlib.pyplot as plt

基本信息

1
2
3
4
5
6
7
8
9
10
img = Image.open('1.jpg')

print(type(img))
<class 'PIL.JpegImagePlugin.JpegImageFile'>

print(img)
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=28x28 at 0x7F153040C6D8>

print(img.shape)
AttributeError: 'JpegImageFile' object has no attribute 'shape'
1
2
3
4
5
6
7
8
img_1 = np.array(Image.open('1.jpg')).astype('float')

print(img_1.shape)
(28, 28, 3)

img_2 = np.array(Image.open('1.jpg').convert('L')).astype('float')
print(img_2.shape)
(28, 28)

显示单个图片

1
2
3
4
5
plt.figure("Image") # 图像窗口名称
plt.imshow(img)
plt.axis('off') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()

显示灰度图像

1
2
3
4
5
6
7
8
img = img.convert('L')  #变为灰度图像

plt.figure("Image")
# 这里必须加 cmap='gray' ,否则尽管原图像是灰度图(下图1),但是显示的是伪彩色图像(下图2)(如果不加的话)
plt.imshow(img,cmap='gray')
plt.axis('off')
plt.title('image')
plt.show()