Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
502 views
in Technique[技术] by (71.8m points)

machine learning - What is the role of "Flatten" in Keras?

I am trying to understand the role of the Flatten function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of shape (3, 2), and outputs 1-dimensional data of shape (1, 4):

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

This prints out that y has shape (1, 4). However, if I remove the Flatten line, then it prints out that y has shape (1, 3, 4).

I don't understand this. From my understanding of neural networks, the model.add(Dense(16, input_shape=(3, 2))) function is creating a hidden fully-connected layer, with 16 nodes. Each of these nodes is connected to each of the 3x2 input elements. Therefore, the 16 nodes at the output of this first layer are already "flat". So, the output shape of the first layer should be (1, 16). Then, the second layer takes this as an input, and outputs data of shape (1, 4).

So if the output of the first layer is already "flat" and of shape (1, 16), why do I need to further flatten it?

See Question&Answers more detail:os

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)

If you read the Keras documentation entry for Dense, you will see that this call:

Dense(16, input_shape=(5,3))

would result in a Dense network with 3 inputs and 16 outputs which would be applied independently for each of 5 steps. So, if D(x) transforms 3 dimensional vector to 16-d vector, what you'll get as output from your layer would be a sequence of vectors: [D(x[0,:]), D(x[1,:]),..., D(x[4,:])] with shape (5, 16). In order to have the behavior you specify you may first Flatten your input to a 15-d vector and then apply Dense:

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

EDIT: As some people struggled to understand - here you have an explaining image:

enter image description here


与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...