前面笔者已经提到了,paddlepaddle官方文档的一大特点——文档写的好,就是看不懂。在这里没有“诋毁”的意思,毕竟官方文档写的严谨是应该的。
开篇
在学习过程中,我们经常会用到paddle.fluid.data
,这个方法定义的是网络的输入层,可以说是一个网络的入口。fluid.data
有一个重要的参数shape,通常需要根据定义的reader中的数据shape,来确定该shape的值。在学习过程中,案例代码中shape参数值通常都是确定好的,学习者可以很好的运行代码。但是拓展一下,需要自己来确定shape值的时候,很多初学者就会有雾里看花的感觉,无从下手。基于笔者踩过的坑,和大家一起学习。
官方文档中的fluid.data
首先看一下在官方文档中如何说明shape:
paddle.fluid.data(name, shape, dtype='float32', lod_level=0)
- shape (list|tuple)- 声明维度信息的list或tuple。 在示例代码中也给出了解释:
# Creates a variable with fixed size [3, 2, 1]
# User can only feed data of the same shape to x
x = fluid.data(name='x', shape=[3, 2, 1], dtype='float32')
# Creates a variable with changable batch size -1.
# Users can feed data of any batch size into y,
# but size of each data sample has to be [2, 1]
y = fluid.data(name='y', shape=[-1, 2, 1], dtype='float32')
大意就是:
- x的shape中第一个维度是固定的,则feed的数据要和该shape完全一致
- y的shape中第一个维度为-1,则可以是任意值,feed的数据第一个维度是可变的,但是同样要满足数组运算的法则
深入解读
文档中给出的信息就这么多,不知道您有没有明白。不明白没关系,下面再详细说一下笔者的理解。
对于固定的shape很好理解,比如一张shape为m * m * 3的彩色图片,在fluid.data中shape定义为[3, m, m],则表示每次输入一个三通道,每个通道为m * m的数组。
对于可变shape,通常应用于每次输入的是一组batch数据。比如在fluid.data中shape定义为[-1, 3, m, m],则表示为每次输入一个batch size大小的三通道,每个通道为m * m的数组。
笔者在学习过程中经常思考一个问题,(m, m, 3)表示彩色图片通道数在后面,然而在fluid.data中如果没有batch,则用(3, m, m)表示输入的是三通道的彩色图片。为什么会这样呢?
答案就在于,在卷积过程中,卷积计算是按照通道进行计算的。3通道的图片会和3通道的卷积核进行卷积计算。
拓展延伸
对于初学者可能对图片如何用数组表示还不是非常清楚,笔者在这里详细介绍一下。 先看一下原图片:
该图片的shape为:(32, 32, 3),长宽为32像素,每个像素点由类似于[180 191 203]的数组表示,该像素点数组的三个值表示RGB三个通道,每个通道取值范围为0-255。
图片输出为数组则为:
[[[180 191 203]
[180 188 199]
[196 200 209]
...
[160 167 184]
[156 163 182]
[162 170 188]]
[[175 189 197]
[179 190 199]
[193 193 198]
...
单通道图片为:
单通道图片shape:(32, 32),这里每个像素由一个0-255的值表示。
图片输出为数组则为:
[[180 180 196 ... 160 156 162]
[175 179 193 ... 157 154 157]
[177 186 185 ... 158 156 159]
...
[180 175 185 ... 134 117 87]
[180 175 174 ... 140 154 102]
[162 177 167 ... 184 193 131]]
实例代码
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
img = Image.open('work/dog.png')
image = np.array(img)
print('图片shape:')
print(image.shape)
print('输出为:')
print(image)
图片shape:
(32, 32, 3)
输出为:
[[[180 191 203]
[180 188 199]
[196 200 209]
...
[160 167 184]
[156 163 182]
[162 170 188]]
[[175 189 197]
[179 190 199]
[193 193 198]
...
[157 164 180]
[154 161 179]
[157 164 182]]
[[177 195 202]
[186 199 208]
[185 181 182]
...
[158 165 182]
[156 163 181]
[159 166 184]]
...
[[180 193 206]
[175 187 199]
[185 191 199]
...
[134 120 110]
[117 102 90]
[ 87 74 61]]
[[180 193 206]
[175 184 200]
[174 175 182]
...
[140 124 112]
[154 139 127]
[102 89 77]]
[[162 172 187]
[177 186 198]
[167 169 171]
...
[184 173 169]
[193 181 175]
[131 115 106]]]
# 原图
plt.imshow(image)
plt.axis('off')
plt.show()
# 单通道图片
s = image[:, :, 0]
print('单通道图片shape:')
print(s.shape)
print('输出为:')
print(s)
# 结果
单通道图片shape:
(32, 32)
输出为:
[[180 180 196 ... 160 156 162]
[175 179 193 ... 157 154 157]
[177 186 185 ... 158 156 159]
...
[180 175 185 ... 134 117 87]
[180 175 174 ... 140 154 102]
[162 177 167 ... 184 193 131]]
# 查看各个通道,通道0
plt.imshow(image[:, :, 0], cmap='gray')
plt.axis('off')
plt.show()
# 查看各个通道,通道1
plt.imshow(image[:, :, 1], cmap='gray')
plt.axis('off')
plt.show()
# 查看各个通道,通道2
plt.imshow(image[:, :, 2], cmap='gray')
plt.axis('off')
plt.show()