DRAW: A Recurrent Neural Network For Image Generation
Karol Gregor, Ivo Danihelka, Alex Graves, Danilo Jimenez Rezende, & Daan Wierstra
Google DeepMind
ICML 2015
Presented by Zhe Gan, Duke University
October 2nd, 2015
Overview
DRAW: Deep Recurrent Attentive Writer
Idea: generate images sequentially, using
- a spatial attention mechanism that learns where to look
- a sequential variational auto-encoding framework built from a pair of recurrent nets
Demo: https://www.youtube.com/watch?v=Zt-7MI9eKEo
Basics: Conventional Variational Auto-Encoder
Generative model (decoder FNN): prior P(z), likelihood P(x|z)
Recognition model (encoder FNN): approximate posterior Q(z|x)
Objective: minimize the total description length

L = KL(Q(z|x) || P(z)) - E_{Q(z|x)}[log P(x|z)]    (1)

i.e., maximize the variational lower bound
Learning: stochastic gradient descent
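The two terms of the objective above have simple closed forms for a diagonal Gaussian posterior, a standard normal prior, and a Bernoulli likelihood. A minimal numpy sketch (the function names and sizes are hypothetical, chosen only for illustration):

```python
import numpy as np

def gaussian_kl(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dims."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

def bernoulli_nll(x, x_prob, eps=1e-8):
    """-log P(x | z) for a Bernoulli decoder with pixel probabilities x_prob."""
    return -np.sum(x * np.log(x_prob + eps) + (1 - x) * np.log(1 - x_prob + eps))

# When Q(z|x) exactly matches the prior N(0, I), the KL term vanishes.
print(gaussian_kl(np.zeros(10), np.zeros(10)))  # 0.0
```

The total loss for a data point is then `gaussian_kl(...) + bernoulli_nll(...)`, which is minimized by stochastic gradient descent.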
The DRAW Network
Three key differences from the conventional VAE:
- both the encoder and decoder are LSTM recurrent nets
- the decoder's outputs are successively added to a canvas
- attention decides where to read, where to write, and what to write
The DRAW Network
x̂_t = x - σ(c_{t-1})                                  (2)
r_t = read(x, x̂_t, h^dec_{t-1})                       (3)
h^enc_t = RNN^enc(h^enc_{t-1}, [r_t, h^dec_{t-1}])    (4)
z_t ~ Q(z_t | h^enc_t)                                (5)
h^dec_t = RNN^dec(h^dec_{t-1}, z_t)                   (6)
c_t = c_{t-1} + write(h^dec_t)                        (7)

c_t is called the canvas matrix; the final c_T is used to parameterize D(x | c_T).
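One pass of Eqs. (2)–(7) can be sketched in numpy. This is a toy illustration only: it uses the attention-free read (plain concatenation), a tanh cell standing in for the LSTMs, and made-up layer sizes; all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

A = B = 8          # image dimensions (flattened to A*B below)
n_h, n_z = 32, 10  # hidden and latent sizes (arbitrary)
x = rng.random(A * B)  # a toy "image" in [0, 1]

# Simple random linear layers standing in for the learned W(.) maps.
def linear(n_in, n_out):
    W = rng.normal(0, 0.1, (n_out, n_in))
    b = np.zeros(n_out)
    return lambda v: W @ v + b

rnn_enc = linear(n_h + 2 * A * B + n_h, n_h)   # consumes [h_enc, r_t, h_dec]
rnn_dec = linear(n_h + n_z, n_h)
to_mu, to_logsig = linear(n_h, n_z), linear(n_h, n_z)
write_op = linear(n_h, A * B)

sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))

h_enc = h_dec = np.zeros(n_h)
c = np.zeros(A * B)                            # canvas
for t in range(5):
    x_hat = x - sigmoid(c)                     # error image, Eq. (2)
    r = np.concatenate([x, x_hat])             # read without attention, Eq. (3)
    h_enc = np.tanh(rnn_enc(np.concatenate([h_enc, r, h_dec])))  # Eq. (4)
    mu, log_sig = to_mu(h_enc), to_logsig(h_enc)
    z = mu + np.exp(log_sig) * rng.normal(size=n_z)  # sample from Q, Eq. (5)
    h_dec = np.tanh(rnn_dec(np.concatenate([h_dec, z])))         # Eq. (6)
    c = c + write_op(h_dec)                    # canvas update, Eq. (7)

print(c.shape)  # (64,)
```

Note how the canvas is only ever added to: each time step refines the image rather than redrawing it from scratch.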
The DRAW Network
Notation: b = W(a) denotes a linear transformation (weight matrix plus bias) from vector a to b.
Approximate posterior: Q(z_t | h^enc_t) = N(z_t | μ_t, σ_t^2), with

μ_t = W(h^enc_t),  σ_t = exp(W(h^enc_t))              (8)

Loss function:

L = L^x + L^z,  L^x = -log D(x | c_T)                 (9)
L^z = Σ_{t=1}^T KL(Q(z_t | h^enc_t) || P(z_t))        (10)

Stochastic Data Generation

z̃_t ~ P(z_t)                                          (11)
h̃^dec_t = RNN^dec(h̃^dec_{t-1}, z̃_t)                   (12)
c̃_t = c̃_{t-1} + write(h̃^dec_t)                        (13)
x̃ ~ D(x | c̃_T)                                        (14)
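Generation needs no encoder at all: latents come from the prior and only the decoder runs. A minimal numpy sketch of Eqs. (11)–(14), again with a tanh cell in place of the LSTM and arbitrary, hypothetical sizes:

```python
import numpy as np

rng = np.random.default_rng(1)
A = B = 8
n_h, n_z, T = 32, 10, 5

# Random weights standing in for the trained decoder and write maps.
W_dec = rng.normal(0, 0.1, (n_h, n_h + n_z))
W_write = rng.normal(0, 0.1, (A * B, n_h))

h_dec = np.zeros(n_h)
c = np.zeros(A * B)
for t in range(T):
    z = rng.normal(size=n_z)                             # z_t ~ P(z_t) = N(0, I)
    h_dec = np.tanh(W_dec @ np.concatenate([h_dec, z]))  # decoder step, Eq. (12)
    c = c + W_write @ h_dec                              # canvas update, Eq. (13)

x_prob = 1.0 / (1.0 + np.exp(-c))                        # Bernoulli means from c_T
x_sample = (rng.random(A * B) < x_prob).astype(float)    # x ~ D(x | c_T), Eq. (14)
```

With a Bernoulli D, sampling reduces to thresholding uniform noise against the sigmoid of the final canvas.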
Read and Write Operations
Reading and Writing Without Attention
read(x, x̂_t, h^dec_{t-1}) = [x, x̂_t]                  (15)
write(h^dec_t) = W(h^dec_t)                            (16)
Selective Attention Model
read: from the A × B input image, obtain an N × N attention patch. How?
Use horizontal and vertical filterbanks F_X (N × A) and F_Y (N × B), where (i, j) indexes a point on the attention patch and (a, b) a point in the input image:

F_X[i, a] = (1/Z_X) exp(-(a - μ^i_X)² / (2σ²))         (17)
F_Y[j, b] = (1/Z_Y) exp(-(b - μ^j_Y)² / (2σ²))         (18)

μ^i_X = g_X + (i - N/2 - 0.5)δ                         (19)
μ^j_Y = g_Y + (j - N/2 - 0.5)δ                         (20)

Z_X and Z_Y normalize each row of the filterbank to sum to 1.
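Building such a filterbank is a few lines of numpy. A sketch of Eqs. (17)–(20) for one axis (the parameter values below are made up; the same function builds F_Y by passing B for the size):

```python
import numpy as np

def filterbank(g, delta, sigma2, N, size):
    """N x size Gaussian filterbank: row i is a 1-D Gaussian centred at
    mu_i = g + (i - N/2 - 0.5) * delta, as in Eqs. (17)-(20)."""
    i = np.arange(1, N + 1)                     # the paper indexes patches from 1
    mu = g + (i - N / 2 - 0.5) * delta
    grid = np.arange(size)
    F = np.exp(-((grid[None, :] - mu[:, None]) ** 2) / (2 * sigma2))
    # Z_X / Z_Y normalise each row so the filter responses sum to 1
    return F / np.maximum(F.sum(axis=1, keepdims=True), 1e-8)

F_X = filterbank(g=50.0, delta=4.0, sigma2=4.0, N=12, size=100)
print(F_X.shape)                                # (12, 100)
print(np.allclose(F_X.sum(axis=1), 1.0))        # True
```

Each row is a soft "pixel selector": larger σ² blurs the patch, larger δ spreads the N Gaussians further apart over the image.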
Selective Attention Model
The attention patch is determined by five parameters, dynamically computed at each time step:

(g̃_X, g̃_Y, log σ², log δ̃, log γ) = W(h^dec)           (21)

g_X = (A + 1)/2 · (g̃_X + 1)                            (22)
g_Y = (B + 1)/2 · (g̃_Y + 1)                            (23)
δ = (max(A, B) - 1)/(N - 1) · δ̃                        (24)

With this scaling, the initial patch roughly covers the whole input image.
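The scaling in Eqs. (22)–(24) can be checked directly: with all raw decoder outputs at zero, the centre lands mid-image and the stride spans it, so the glimpse roughly covers the whole input. A small sketch (function name and the 100 × 75 / N = 12 sizes are illustrative):

```python
import numpy as np

def attention_params(raw, A, B, N):
    """Map the five raw decoder outputs to attention parameters, Eqs. (21)-(24)."""
    g_tilde_x, g_tilde_y, log_sigma2, log_delta, log_gamma = raw
    g_x = (A + 1) / 2 * (g_tilde_x + 1)                  # Eq. (22)
    g_y = (B + 1) / 2 * (g_tilde_y + 1)                  # Eq. (23)
    delta = (max(A, B) - 1) / (N - 1) * np.exp(log_delta)  # Eq. (24)
    return g_x, g_y, np.exp(log_sigma2), delta, np.exp(log_gamma)

# Raw outputs of zero -> patch centred on the image, stride spanning it.
g_x, g_y, sigma2, delta, gamma = attention_params(np.zeros(5), A=100, B=75, N=12)
print(g_x, g_y, delta)  # 50.5 38.0 9.0
```

Emitting log σ², log δ̃, and log γ keeps those quantities positive without any constrained optimization.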
Reading and Writing With Attention
read: from the A × B input image, obtain an N × N attention patch:

read(x, x̂_t, h^dec_{t-1}) = γ [F_Y x F_X^T, F_Y x̂_t F_X^T]    (25)

write: from an N × N attention patch, back to the A × B image:

w_t = W(h^dec_t)                                       (26)
write(h^dec_t) = (1/γ̂) F̂_Y^T w_t F̂_X                  (27)

(Figure: image of size 100 × 75, patch size 12 × 12.)
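Both operations are just two matrix products with the filterbanks, transposed for write. A shape-level sketch of Eqs. (25)–(27) on a 100 × 75 image with a 12 × 12 patch (filterbank parameters are made up; x is stored as a B × A array so F_Y acts on rows and F_X on columns):

```python
import numpy as np

rng = np.random.default_rng(2)
A, B, N = 100, 75, 12
x = rng.random((B, A))                       # toy image, rows = vertical axis

def filterbank(g, delta, sigma2, N, size):
    """Row-normalised N x size Gaussian filterbank (Eqs. 17-20)."""
    i = np.arange(1, N + 1)
    mu = g + (i - N / 2 - 0.5) * delta
    grid = np.arange(size)
    F = np.exp(-((grid[None, :] - mu[:, None]) ** 2) / (2 * sigma2))
    return F / np.maximum(F.sum(axis=1, keepdims=True), 1e-8)

F_X = filterbank(50.0, 4.0, 4.0, N, A)       # N x A
F_Y = filterbank(38.0, 4.0, 4.0, N, B)       # N x B
gamma = 1.0

patch = gamma * F_Y @ x @ F_X.T              # Eq. (25): N x N glimpse
print(patch.shape)  # (12, 12)

# write: place an N x N patch w back onto the full canvas, Eq. (27)
w = rng.random((N, N))
canvas_update = (1.0 / gamma) * F_Y.T @ w @ F_X
print(canvas_update.shape)  # (75, 100)
```

Because the filters are Gaussian in g, δ, and σ², the whole read/write path is differentiable, so the attention parameters are trained by plain backpropagation.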
Experiments – Cluttered MNIST Classification
Experiments – MNIST Generation
Experiments – MNIST Generation with Two Digits
Experiments – Street View House Number Generation
Experiments – Generating CIFAR Images
Backup: Experimental Hyper-Parameters